# Quick alignment of spots

In [None]:
# Basic imports
import os,sys,re
from importlib import reload
import numpy as np
import pandas as pd
#import torch
# Print the id of the current Python process.
print(os.getpid())

In [None]:
# ChromAn related imports
sys.path.append(r'/lab/weissman_imaging/puzheng/Softwares/') # parent folder of ChromAn
import ImageAnalysis3 as ia3
import h5py
from ImageAnalysis3.classes import _allowed_kwds
import ast
from ChromAn.src import file_io

from ChromAn.src.file_io import dax_process
from ChromAn.src.file_io import data_organization
from ChromAn.src.visual_tools import interactive
from ChromAn.src.correction_tools.alignment import generate_drift_crops

In [None]:
# Root folder of the dataset analyzed in this notebook (hard-coded absolute path).
data_folder = r'/lab/weissman_imaging/puzheng/PE_LT/20230926-4T1Hek3_preEdit_400k0907'
# Scan data_folder with ChromAn's search_fovs_in_folders; its two return values are kept as
# `folders` and `fovs` (presumably acquisition subfolders and field-of-view names — confirm against ChromAn).
folders, fovs = data_organization.search_fovs_in_folders(data_folder)
# Analysis inputs/outputs live in the 'Analysis' subfolder of data_folder.
analysis_folder = os.path.join(data_folder, 'Analysis')

In [None]:
# Build ChromAn's Color_Usage object from <analysis_folder>/Color_Usage.csv.
color_usage_df = data_organization.Color_Usage(os.path.join(analysis_folder, "Color_Usage.csv"))

In [None]:
# Collect the per-fov result files (*.hdf5) stored directly in the analysis folder,
# then make sure the output subfolders used by later steps exist.
save_folder = analysis_folder


def _fov_id_from_filename(filename):
    """Return the integer fov id encoded as '<prefix>_<digits>.hdf5', or None when absent."""
    _id_match = re.match(r'.*_([0-9]+)\.hdf5$', os.path.basename(filename))
    return int(_id_match.group(1)) if _id_match is not None else None


def _ensure_folder(folder, label):
    """Create `folder` when it does not exist yet, report which case applied, return the path."""
    if not os.path.exists(folder):
        os.makedirs(folder)
        print(f"Creating {label}: {folder}")
    else:
        print(f"Use {label}: {folder}")
    return folder


_hdf5_filenames = [os.path.join(save_folder, _fl) for _fl in os.listdir(save_folder)
                   if _fl.split(os.extsep)[-1] == 'hdf5']
# Files without a trailing numeric fov id cannot be ordered: report and skip them
# instead of failing on the regular-expression match.
_unparsed_filenames = [_fl for _fl in _hdf5_filenames if _fov_id_from_filename(_fl) is None]
if _unparsed_filenames:
    print(f"Skip {len(_unparsed_filenames)} hdf5 files without fov id: {_unparsed_filenames}")
# sort numerically by fov id (an empty folder simply yields an empty list)
save_filenames = sorted([_fl for _fl in _hdf5_filenames if _fov_id_from_filename(_fl) is not None],
                        key=_fov_id_from_filename)

# extract fov_id
save_fov_ids = [_fov_id_from_filename(_fl) for _fl in save_filenames]

debug = False

print(f"{len(save_filenames)} fovs detected")


segmentation_folder = _ensure_folder(os.path.join(analysis_folder, 'Segmentation'), 'segmentation_folder')

cand_spot_folder = _ensure_folder(os.path.join(analysis_folder, 'CandSpots'), 'cand_spot_folder')

# Sibling of CandSpots; built from analysis_folder so that a 'CandSpots' substring
# elsewhere in the path can never be rewritten.
decoder_folder = os.path.join(analysis_folder, 'Decoder')
if debug:
    # in debug mode, write into the first unused versioned subfolder v0, v1, ...
    _version = 0
    while os.path.exists(os.path.join(decoder_folder, f'v{_version}')):
        _version += 1
    decoder_folder = os.path.join(decoder_folder, f'v{_version}')
decoder_folder = _ensure_folder(decoder_folder, 'decoder_folder')

In [None]:
from ChromAn.src.file_io.image_crop import generate_neighboring_crop,crop_neighboring_area
from scipy.stats import scoreatpercentile
import matplotlib.pyplot as plt
def _rescaling(im, vmin=None, vmax=None):
    """Linearly rescale an image to uint8.

    Parameters
    ----------
    im : array-like
        Input intensities.
    vmin, vmax : scalar, optional
        Intensities mapped to 0 and 255 respectively; values outside this window are
        clipped. Default to the minimum / maximum of `im`.

    Returns
    -------
    numpy.ndarray
        uint8 array with the same shape as `im`.
    """
    if vmin is None:
        vmin = np.min(im)
    if vmax is None:
        vmax = np.max(im)
    if vmax == vmin:
        # Flat intensity window: the normalization below would divide by zero and
        # cast NaN to uint8, so return an all-zero image explicitly.
        return np.zeros(np.shape(im), dtype=np.uint8)
    _res_im = np.clip(im, vmin, vmax)
    _res_im = (_res_im - vmin) / (vmax - vmin)
    # scale to 0..255; the cast truncates the fractional part
    _res_im = (_res_im * np.iinfo(np.uint8).max).astype(np.uint8)
    return _res_im

def rescale_by_percentile(im, min_max_percent=(30, 99.95)):
    """Rescale an image to uint8 between two intensity percentiles.

    Parameters
    ----------
    im : array-like
        Input intensities.
    min_max_percent : sequence of two numbers
        Percentiles (0-100) of `im`; the smaller one maps to 0 and the larger one to 255,
        whatever order they are given in.

    Returns
    -------
    numpy.ndarray
        uint8 array with the same shape as `im`.
    """
    from scipy.stats import scoreatpercentile
    vmin = scoreatpercentile(im, min(min_max_percent))
    vmax = scoreatpercentile(im, max(min_max_percent))
    return _rescaling(im, vmin=vmin, vmax=vmax)

In [None]:
# Display the Color_Usage object built from Color_Usage.csv.
color_usage_df

In [None]:
# Load one field of view: pick its hdf5 file and pull every dataset into memory.
_save_filename = save_filenames[4]

with h5py.File(_save_filename, 'r') as _f:
    # 'merfish' group; spot arrays keep only the rows whose first column is positive
    _merfish_group = _f['merfish']
    _merfish_spots = [_entry[_entry[:,0] > 0] for _entry in _merfish_group['spots'][:]]
    _merfish_ids = _merfish_group['ids'][:]
    _merfish_ims = _merfish_group['ims'][:]
    _merfish_drifts = _merfish_group['drifts'][:]

    # 'rna' group, same layout plus a 'channels' dataset
    _rna_group = _f['rna']
    _rna_spots = [_entry[_entry[:,0] > 0] for _entry in _rna_group['spots'][:]]
    _rna_ids = _rna_group['ids'][:]
    _rna_ims = _rna_group['ims'][:]
    _rna_drifts = _rna_group['drifts'][:]
    _rna_channels = _rna_group['channels'][:]

    # the DAPI image is stored as a file-level attribute
    _dapi_im = _f.attrs['dapi_im']
    print(_f.attrs.keys())

In [None]:
# Inspect the 'channels' entry at index 10 of the rna group.
_rna_channels[10]

In [None]:
## quick visualization: open all merfish images in ChromAn's interactive viewer, named by their ids
# the notebook backend is selected here, right before the viewer is opened
%matplotlib notebook
interactive.imshow_mark_3d(_merfish_ims, image_names=_merfish_ids)

In [None]:
# Inspect the id stored at index 10 of the rna group (same index as the channel lookup).
_rna_ids[10]

In [None]:
# Open the last merfish image next to the rna image at index 10 in the interactive viewer.
interactive.imshow_mark_3d([_merfish_ims[-1], _rna_ims[10]])

In [None]:
# Inspect the filtered spot array at index 12 of the rna group.
_rna_spots[12]

In [None]:
## quick visualization: open all rna images in ChromAn's interactive viewer, named by their ids
# the notebook backend is selected here, right before the viewer is opened
%matplotlib notebook
interactive.imshow_mark_3d(_rna_ims, image_names=_rna_ids)

In [None]:
## mass plot
%matplotlib inline
spot_ids = np.arange(len(_merfish_spots[list(_merfish_ids).index(53)]))

for _sid in spot_ids[:60]:
    sel_center = _merfish_spots[list(_merfish_ids).index(53)][_sid,1:4]

    crop = generate_neighboring_crop(sel_center, 150, single_im_size=np.array(_dapi_im.shape))
    sel_local_dapi_im = _dapi_im[crop.to_slices()]
    sel_local_im = _merfish_ims[list(_merfish_ids).index(53)][crop.to_slices()]
    _sel_ims = np.array([sel_local_im, np.zeros(np.shape(sel_local_im)), sel_local_dapi_im])
    sel_im_proj = np.array([rescale_by_percentile(_img.max(0)) 
                            for _img in _sel_ims]).transpose(1,2,0)
    #intbc_ids = _ids[_ids <= 21]
    edit_ids = _rna_ids[(_rna_ids >= 97) & (_rna_ids <= 117)]
    # crop edit images:
    edit_ims = []
    for _id in edit_ids:
        _idx = list(_rna_ids).index(_id)
        # new center
        _im, _dft = _rna_ims[_idx], _rna_drifts[_idx]
        _local_im = crop_neighboring_area(_im, sel_center-_dft, 15)
        edit_ims.append(_local_im)
    
    fig, axes = plt.subplots(1, len(edit_ims), sharex=True, sharey=True, figsize=(len(edit_ims),1.5), dpi=150)
    plt.subplots_adjust(wspace=0.1, hspace=0.1)
    for _iax, ax in enumerate(np.ravel(axes)):
        ax.imshow(edit_ims[_iax].max(0), 
                  vmin=scoreatpercentile(edit_ims, 10), 
                  vmax=np.max(edit_ims), cmap='Greys_r')
        ax.set_axis_off()
        #if _iax < 8:
        ax.set_title(f"{edit_ids[_iax]}", fontsize=10)
        #else:
        #    pass
            #ax.set_title("unedited", fontsize=10)
    #fig.show()

In [None]:
# site 1 image
%matplotlib inline
fig, axes = plt.subplots(1, len(edit_ims), sharex=True, sharey=True, figsize=(len(edit_ims),1.5), dpi=150)
plt.subplots_adjust(wspace=0.1, hspace=0.1)
for _iax, ax in enumerate(np.ravel(axes)):
    ax.imshow(edit_ims[_iax].max(0), 
              vmin=scoreatpercentile(edit_ims, 10), 
              vmax=np.max(edit_ims), cmap='Greys_r')
    ax.set_axis_off()
    if _iax < 8:
        ax.set_title(f"{_iax+1}", fontsize=10)
    else:
        pass
        #ax.set_title("unedited", fontsize=10)
fig.show()

In [None]:
%matplotlib inline

## Load images
_save_filename = save_filenames[2]

for _save_filename in save_filenames:
    try:
        print(_save_filename)
        with h5py.File(_save_filename, 'r') as _f:
            _all_spots = _f['merfish']['spots'][:]
            _spots_list = [_s[_s[:,0] > 0] for _s in _f['merfish']['spots'][:]]
            _ids = _f['merfish']['ids'][:]
            _ims = _f['merfish']['ims'][:]
            _drifts = _f['merfish']['drifts'][:]
            _dapi_im = _f.attrs['dapi_im']
            print(_f.attrs.keys())

        ## mass plot
        from ChromAn.src.file_io.image_crop import generate_neighboring_crop
        from scipy.stats import scoreatpercentile

        figure_folder = os.path.join(analysis_folder, 'saved_figures_20231012')
        if not os.path.exists(figure_folder):
            os.makedirs(figure_folder)

        # crop intbc images:
        spot_ids = np.arange(len(_spots_list[list(_ids).index(53)]))
        for _sid in spot_ids:
            sel_center = _spots_list[list(_ids).index(53)][_sid,1:4]
            
            crop = generate_neighboring_crop(sel_center, 150, single_im_size=np.array(_dapi_im.shape))
            sel_local_dapi_im = _dapi_im[crop.to_slices()]
            sel_local_im = _ims[list(_ids).index(53)][crop.to_slices()]
            _sel_ims = np.array([sel_local_im, np.zeros(np.shape(sel_local_im)), sel_local_dapi_im])
            sel_im_proj = np.array([rescale_by_percentile(_img.max(0)) 
                                    for _img in _sel_ims]).transpose(1,2,0)

            #intbc_ids = _ids[_ids <= 21]
            edit_ids = _ids[(_ids > 21) & (_ids < 52)]

            # crop edit images:
            edit_ims = []
            for _id in edit_ids:
                _idx = list(_ids).index(_id)
                # new center
                _im, _dft = _ims[_idx], _drifts[_idx]
                _local_im = crop_neighboring_area(_im, sel_center-_dft, 15)
                edit_ims.append(_local_im)
                #break
            # site 1 image
            site1_edit_ims = np.array(edit_ims[0::3])
            fig, axes = plt.subplots(1,len(site1_edit_ims)-1, sharex=True, sharey=True, figsize=(4,1.2), dpi=150)
            plt.subplots_adjust(wspace=0.1, hspace=0.1)
            for _iax, ax in enumerate(np.ravel(axes)):
                ax.imshow(site1_edit_ims[_iax].max(0), 
                          vmin=scoreatpercentile(site1_edit_ims, 0), 
                          vmax=np.max(site1_edit_ims), cmap='Greys_r')
                ax.set_axis_off()
                if _iax < 8:
                    ax.set_title(f"{_iax+1}", fontsize=10)
                else:
                    pass
                    #ax.set_title("unedited", fontsize=10)

            fig.suptitle(f'Emx1 Edits, spot:{_sid}', fontsize=12)
            fig.savefig(os.path.join(figure_folder,
                                    os.path.basename(_save_filename).replace('.hdf5',f'_Emx1edits_decode_{_sid}.png')),
                        transparent=True,
                       )
            plt.show()
            # dapi image
            fig,ax = plt.subplots(figsize=(4,4), dpi=150)
            ax.imshow(sel_im_proj)
            ax.set_title(f"R:SV40, B:DAPI", fontsize=10)
            ax.set_axis_off()
            fig.savefig(os.path.join(figure_folder,
                                    os.path.basename(_save_filename).replace('.hdf5',f'_Emx1edits_sv40Dapi_{_sid}.png')),
                        transparent=True,
                       )
            plt.show()
    except:
        print(f"Fail for savefile: {_save_filename}")

In [None]:
# Pick the entry with id 53 once, then take its spot coordinates (columns 1-3) and its image.
_sv40_idx = list(_ids).index(53)
sv40_centers = _spots_list[_sv40_idx][:,1:4]
sv40_im = _ims[_sv40_idx]

In [None]:
# Open the DAPI image and the id-53 image side by side in the interactive viewer.
interactive.imshow_mark_3d([_dapi_im, sv40_im], image_names=['DAPI', 'SV40'])

In [None]:
# Scatter the filtered spots of every merfish entry in one interactive figure.
%matplotlib notebook
import matplotlib.pyplot as plt
plt.figure()
# one scatter call per entry of _spots_list: column 3 on the x-axis, column 2 on the y-axis
for _spots in _spots_list:
    plt.scatter(_spots[:,3],_spots[:,2], s= 1)

In [None]:
# Imported before its first use (the call below previously relied on an import made
# by an earlier cell, because this cell imported the name only after calling it).
from ChromAn.src.file_io.image_crop import crop_neighboring_area


def _crop_readouts(readout_ids, center, ids, ims, drifts, crop_size=15):
    """Crop the image of every id in `readout_ids` around `center` minus that image's drift.

    `ids`, `ims` and `drifts` are the parallel id / image / drift collections; each id is
    located by its first occurrence in `ids`. Returns the crops as a list, in the order of
    `readout_ids`.
    """
    _id_lookup = list(ids)
    _crops = []
    for _readout_id in readout_ids:
        _readout_idx = _id_lookup.index(_readout_id)
        _crops.append(crop_neighboring_area(ims[_readout_idx], center - drifts[_readout_idx], crop_size))
    return _crops


# reference crop: spot number 12 of id 53, size-100 neighborhood in the id-53 image
_ref_idx = list(_ids).index(53)
sel_center = _spots_list[_ref_idx][12,1:4]
sel_local_im = crop_neighboring_area(_ims[_ref_idx], sel_center, 100)

# ids up to 21 are treated as integration barcodes, ids 22..51 as edits
intbc_ids = _ids[_ids <= 21]
edit_ids = _ids[(_ids > 21) & (_ids < 52)]

# crop intbc images:
intbc_ims = _crop_readouts(intbc_ids, sel_center, _ids, _ims, _drifts)
# crop edit images:
edit_ims = _crop_readouts(edit_ids, sel_center, _ids, _ims, _drifts)

In [None]:
# Static view: max-projection along axis 0 of the reference crop.
fig, ax = plt.subplots()
ax.imshow(sel_local_im.max(0))

# Interactive view: reference crop followed by every integration-barcode crop, named by id.
interactive.imshow_mark_3d([sel_local_im] + intbc_ims, image_names=['ref'] + list(intbc_ids))

In [None]:
# Integration barcodes: max-projection of each crop on a 3x7 grid, fixed intensity window.
fig, axes = plt.subplots(3,7)
for _iax, ax in enumerate(np.ravel(axes)):
    if _iax >= len(intbc_ims):
        # fewer crops than panels: blank the unused panel instead of raising IndexError
        ax.set_axis_off()
        continue
    ax.imshow(intbc_ims[_iax].max(0), vmin=4000, vmax=20000)
    # panels are numbered by position (1-based), not by id
    ax.set_title(f"{_iax+1}")
fig.suptitle('Integration barcode')

In [None]:
# Edits: max-projection of the first 27 edit crops on a 3x9 grid, fixed intensity window.
fig, axes = plt.subplots(3,9)
for _iax, ax in enumerate(np.ravel(axes)):
    if _iax >= len(edit_ims):
        # fewer crops than panels: blank the unused panel instead of raising IndexError
        ax.set_axis_off()
        continue
    ax.imshow(edit_ims[_iax].max(0) , vmin=4000, vmax=20000)
fig.suptitle('Edits')

In [None]:
from ChromAn.src.file_io.image_crop import generate_neighboring_crop

In [None]:
# NOTE: these two functions are also defined in an earlier cell of this notebook;
# running this cell re-binds both names, so the two copies must stay in sync.
def _rescaling(im, vmin=None, vmax=None):
    """Linearly rescale an image to uint8.

    Parameters
    ----------
    im : array-like
        Input intensities.
    vmin, vmax : scalar, optional
        Intensities mapped to 0 and 255 respectively; values outside this window are
        clipped. Default to the minimum / maximum of `im`.

    Returns
    -------
    numpy.ndarray
        uint8 array with the same shape as `im`.
    """
    if vmin is None:
        vmin = np.min(im)
    if vmax is None:
        vmax = np.max(im)
    if vmax == vmin:
        # Flat intensity window: the normalization below would divide by zero and
        # cast NaN to uint8, so return an all-zero image explicitly.
        return np.zeros(np.shape(im), dtype=np.uint8)
    _res_im = np.clip(im, vmin, vmax)
    _res_im = (_res_im - vmin) / (vmax - vmin)
    # scale to 0..255; the cast truncates the fractional part
    _res_im = (_res_im * np.iinfo(np.uint8).max).astype(np.uint8)
    return _res_im

def rescale_by_percentile(im, min_max_percent=(30, 99.95)):
    """Rescale an image to uint8 between two intensity percentiles.

    Parameters
    ----------
    im : array-like
        Input intensities.
    min_max_percent : sequence of two numbers
        Percentiles (0-100) of `im`; the smaller one maps to 0 and the larger one to 255,
        whatever order they are given in.

    Returns
    -------
    numpy.ndarray
        uint8 array with the same shape as `im`.
    """
    from scipy.stats import scoreatpercentile
    vmin = scoreatpercentile(im, min(min_max_percent))
    vmax = scoreatpercentile(im, max(min_max_percent))
    return _rescaling(im, vmin=vmin, vmax=vmax)

In [None]:
%matplotlib inline
site1_edit_ims = np.array(edit_ims[0::3])

fig, axes = plt.subplots(1,len(site1_edit_ims), sharex=True, sharey=True, figsize=(4,1.2), dpi=150)
plt.subplots_adjust(wspace=0.1, hspace=0.1)
for _iax, ax in enumerate(np.ravel(axes)):
    ax.imshow(site1_edit_ims[_iax].max(0), 
              vmin=scoreatpercentile(site1_edit_ims, 0), 
              vmax=np.max(site1_edit_ims), cmap='Greys_r')
    ax.set_axis_off()
    if _iax < 8:
        ax.set_title(f"{_iax+1}", fontsize=10)
    else:
        ax.set_title("unedited", fontsize=10)
    
fig.suptitle('Emx1 Edits', fontsize=12)
plt.show()


In [None]:
%matplotlib inline
# mass plot
from ChromAn.src.file_io.image_crop import generate_neighboring_crop
from scipy.stats import scoreatpercentile

figure_folder = os.path.join(analysis_folder, 'saved_figures_20231012')
if not os.path.exists(figure_folder):
    os.makedirs(figure_folder)

# crop intbc images:
spot_id = np.arange(30)
for _sid in spot_id:
    sel_center = _spots_list[list(_ids).index(53)][_sid,1:4]

    
    crop = generate_neighboring_crop(sel_center, 150, single_im_size=np.array(_dapi_im.shape))
    sel_local_dapi_im = _dapi_im[crop.to_slices()]
    sel_local_im = _ims[list(_ids).index(53)][crop.to_slices()]
    _sel_ims = np.array([sel_local_im, np.zeros(np.shape(sel_local_im)), sel_local_dapi_im])
    sel_im_proj = np.array([rescale_by_percentile(_img.max(0)) 
                            for _img in _sel_ims]).transpose(1,2,0)

    #intbc_ids = _ids[_ids <= 21]
    edit_ids = _ids[(_ids > 21) & (_ids < 52)]

    # crop edit images:
    edit_ims = []
    for _id in edit_ids:
        _idx = list(_ids).index(_id)
        # new center
        _im, _dft = _ims[_idx], _drifts[_idx]
        _local_im = crop_neighboring_area(_im, sel_center-_dft, 15)
        edit_ims.append(_local_im)
        #break
    # site 1 image
    site1_edit_ims = np.array(edit_ims[0::3])
    fig, axes = plt.subplots(1,len(site1_edit_ims), sharex=True, sharey=True, figsize=(4,1.2), dpi=150)
    plt.subplots_adjust(wspace=0.1, hspace=0.1)
    for _iax, ax in enumerate(np.ravel(axes)):
        ax.imshow(site1_edit_ims[_iax].max(0), 
                  vmin=scoreatpercentile(site1_edit_ims, 0), 
                  vmax=np.max(site1_edit_ims), cmap='Greys_r')
        ax.set_axis_off()
        if _iax < 8:
            ax.set_title(f"{_iax+1}", fontsize=10)
        else:
            ax.set_title("unedited", fontsize=10)

    fig.suptitle(f'Emx1 Edits, spot:{_sid}', fontsize=12)
    fig.savefig(os.path.join(figure_folder,
                            os.path.basename(_save_filename).replace('.hdf5',f'_Emx1edits_decode_{_sid}.png')),
                transparent=True,
               )
    plt.show()
    # dapi image
    fig,ax = plt.subplots(figsize=(4,4), dpi=150)
    ax.imshow(sel_im_proj)
    ax.set_title(f"R:SV40, B:DAPI", fontsize=10)
    ax.set_axis_off()
    fig.savefig(os.path.join(figure_folder,
                            os.path.basename(_save_filename).replace('.hdf5',f'_Emx1edits_sv40Dapi_{_sid}.png')),
                transparent=True,
               )
    plt.show()

In [None]:
# Interactive view: the current reference crop followed by every edit crop, named by id.
%matplotlib notebook
interactive.imshow_mark_3d([sel_local_im] + edit_ims, image_names=['ref'] + list(edit_ids))

In [None]:
# Count how many ids are <= 21 (the same condition that selects intbc_ids).
np.sum(_ids <= 21)