In [None]:
# the first GUI is for observing first pass clicking outputs (no refit), with edits 240513 for streamlined QC
# use cnmf-qc folder as intermediate site - rename folders with _qc and _qc_done during processing
# goal is to look through every component one by one and add points to split/resurrect components
# then reclick them with cleangui-qc, refit on the ds, and then perform mask matching with minimal additional QC
# after all that, perform full refit with raw on matched subset of components

In [37]:
# Qt backend so the trace figure opens in its own window next to napari
%matplotlib qt
# Folder holding this dataset's CNMF outputs; the loading cell chdirs into it
data_dir = input("Enter the data directory:")

# Parameters (unlikely to change)
# NOTE(review): only Athresh is referenced in the visible cells of this section; the others look unused here
n_range_lim = 10 # size of n_range below which SNR considered unreliable
Athresh = 0.05 # overlap threshold - automatically split anything below it (used to pick overlapping pairs for trace correlation)
cr_thresh = 0.9 # component-raw correlation threshold below which component deemed suspicious quality
pb_thresh = 0.95 # component-best parent correlation threshold above which component deemed likely merge

Enter the data directory: H:\CNMFoutputs\charlotte\mask_matching\231030_DG60PL67_charlotte_done\231030_DG60PL67_mouse_Yael_Done_qc_doneqc


In [38]:
%%time
## LOADING EVERYTHING UP - TAKES ~20 sec

# load packages

import napari
from magicgui import magicgui, widgets
import time

from IPython import get_ipython
from IPython.display import clear_output
import os
import matplotlib.pyplot as plt
import numpy as np
from scipy.sparse import csc_matrix
from scipy import signal as sg
import scipy
import pickle

from tifffile.tifffile import imwrite,imread
from tqdm.auto import tqdm,trange

from copy import deepcopy
import h5py

import caiman as cm
from caiman.source_extraction.cnmf import cnmf,params
from caiman.paths import caiman_datadir
from caiman.utils.visualization import get_contours

# Enable autoreload so edits to imported modules are picked up without a kernel restart.
# __IPYTHON__ only exists under IPython; outside it the NameError makes this a no-op.
try:
    if __IPYTHON__:
        get_ipython().run_line_magic('load_ext', 'autoreload')
        get_ipython().run_line_magic('autoreload', '2')
except NameError:
    pass

def load_pickle(file_path):
    """
    Read back the object stored in a pickle file.

    Args:
    - file_path (str): Path to the pickle file.

    Returns:
    - object: The unpickled object (e.g. the vars_dict dictionary or the contour list).

    Note: pickle.load can run arbitrary code; only use it on trusted, locally produced files.
    """
    with open(file_path, mode='rb') as handle:
        return pickle.load(handle)

## Loading all the inputs
os.chdir(data_dir)
#cnmf_path = caiman_datadir()+'/example_movies/demoMovie3DYxxbnobg_20240318170305_cnmf.hdf5'
cnmf_path = os.path.join(data_dir, 'ch0_means_movie_nobg_cnmf.hdf5')

# CNMF model whose components are being QC'd
cnmf_model = cnmf.load_CNMF(cnmf_path,
                            n_processes=1,
                            dview=None)
print(f"Successfully loaded CNMF model")

# Motion-corrected movie memmap. Fail with a clear message if it is missing, and sort so the
# choice is deterministic if several memmap files are present.
mmap_candidates = sorted(name for name in os.listdir() if 'memmap__' in name)
if not mmap_candidates:
    raise FileNotFoundError(f"No 'memmap__' file found in {os.getcwd()}")
mc_memmapped_fname = mmap_candidates[0]
Yr, dims, T = cm.load_memmap(mc_memmapped_fname)
# Yr is (pixels, T); reshape (Fortran order) into a (T, *dims) movie
images = np.array(np.reshape(Yr.T, [T] + list(dims), order='F'))
print(f"Successfully loaded data")

# Earlier in-notebook contour computation, kept for reference; contours are now loaded from a pickle below.
#d = cnmf_model.estimates.A.shape[0]
#dims = cnmf_model.estimates.dims
#axis = 2
#order = list(range(4))
#order.insert(0, order.pop(axis))
#index_permut = np.reshape(np.arange(d), dims, order='F').transpose(
#        order[:-1]).reshape(d, order='F')
#A = csc_matrix(cnmf_model.estimates.A)[index_permut, :]
#dims = tuple(np.array(dims)[order[:3]])
#d1, d2, d3 = dims
#nr, T = cnmf_model.estimates.C.shape
#image_cells = np.array(A.mean(axis=1)).reshape(dims, order='F')
#coors = get_contours(A, dims, thr=Cthr)
# coors: one entry per component, read here via n['coordinates'] (presumably one array per plane) and later via .get('CoM')
coors = load_pickle(os.path.join(data_dir, 'ch0_means_movie_nobg_coors.pickle'))
print(f"Successfully loaded contours")

cc = [[l for l in n['coordinates']] for n in coors] # x,y values of contour coordinates for each component
cc1 = [[(l[:, 0]) for l in n['coordinates']] for n in coors] # x values of contour coordinates for each component
cc2 = [[(l[:, 1]) for l in n['coordinates']] for n in coors] # y values of contour coordinates for each component
# NOTE(review): `length` is not used in the visible cells (the inner `cc` is comprehension-local and does not clobber `cc`)
length = np.ravel([list(map(len, cc)) for cc in cc1])
# shapes[j][i]: rows [i, *reversed(point)] for component j's contour on plane i, i.e. (plane, y, x) given the (x, y)
# point order stated above; drawn later as 3-D napari points (show_comp drops rows containing NaN)
shapes = [[np.vstack([np.append(i,np.flip(pt)) for pt in cc[j][i]]) for i in range(len(cc[j]))] for j in range(len(cc))]

# Line up all static inputs
SNRs = cnmf_model.estimates.SNR_comp
SNR_min = cnmf_model.estimates.SNRmin
SOL = np.argsort(-SNRs)  # component indices by descending SNR
# Footprints as one volume per component: A (pixels x K) -> dims + (K,) -> (K, dims[2], dims[0], dims[1])
spcomps = np.reshape(cnmf_model.estimates.A.toarray(),cnmf_model.estimates.dims + (-1,),order='F')
spcomps = spcomps.transpose([3,2,0,1])
# Movie in the matching axis order: (T, dims[2], dims[0], dims[1])
images2 = images.transpose([0,3,1,2])
#SOL = np.argsort(-cnmf_model.estimates.SNR_comp) 
C = cnmf_model.estimates.C
CY = cnmf_model.estimates.C + cnmf_model.estimates.YrA # temporal loadings
R = cnmf_model.estimates.Craw # masks applied to raw movie

CYsav = cnmf_model.estimates.CYsav # smoothened CY curve
def sav_calc(sraw):
    """Savitzky-Golay smooth a 1-D trace (window of 3 samples, linear fit)."""
    window_length, polyorder = 3, 1
    return sg.savgol_filter(sraw, window_length, polyorder)
# Smooth every raw-mask trace with sav_calc
Rsav = np.zeros(R.shape)
for row_idx, raw_row in enumerate(R):
    Rsav[row_idx, :] = sav_calc(raw_row)


def _low_decile_mean(sorted_traces):
    """Mean of the first ceil(n_frames / 10) columns of a row-wise sorted 2-D array (bottom 10% per trace)."""
    n_low = int(np.ceil(sorted_traces.shape[1] / 10))
    return np.mean(sorted_traces[:, :n_low], axis=1)


CYsavsort = np.sort(CYsav, axis=1)
CYsavb10 = _low_decile_mean(CYsavsort)  # bottom 10% mean
Rsavsort = np.sort(Rsav, axis=1)
Rsavb10 = _low_decile_mean(Rsavsort)  # bottom 10% mean

# Per-component reference level (presumably F0): mean over the n_range frames when the model
# provides them, otherwise the bottom-10% mean of the smoothed traces.
n_range = cnmf_model.estimates.n_range
if n_range is not None:
    CYf = np.mean(CY[:, n_range], axis=1)
    Rf = np.mean(R[:, n_range], axis=1)
else:
    CYf = CYsavb10
    Rf = Rsavb10

Cn = cnmf_model.estimates.Cn # correlation image (not necessary)
keepargs = cnmf_model.estimates.keepargs
# Keep the SNR-descending order, restricted to components listed in keepargs. The lookup set is
# built once; `x in list(keepargs)` rebuilt and scanned a list for every component.
keep_set = set(keepargs)
SOL = np.array([x for x in SOL if x in keep_set]) # initial search order list
print(f"Successfully generated search list")

import scipy.stats  # explicit, rather than relying on `import scipy` exposing the submodule

# Pairwise spatial overlap of footprints: upper triangle only, self-overlap zeroed
A1 = csc_matrix(cnmf_model.estimates.A)
nr = A1.shape[1]
A_corr = scipy.sparse.triu(A1.T @ A1)
A_corr.setdiag(0)
A_corr = A_corr.tocsc()
C_corr = scipy.sparse.lil_matrix(A_corr.shape)
for ii in range(nr):
    # find() returns (rows, cols, values) of the row's nonzero overlaps; call it once per row
    found = scipy.sparse.find(A_corr[ii, :])
    overlap_indices = found[1][found[2] > Athresh]
    if len(overlap_indices) > 0:
        # check the correlation of the calcium traces for each overlapping component
        corr_values = [scipy.stats.pearsonr(C[ii, :], C[jj, :])[0] for jj in overlap_indices]
        C_corr[ii, overlap_indices] = corr_values
C_tot = C_corr + C_corr.T  # symmetric trace correlations for overlapping pairs
CYR_corr = np.zeros(nr)
for ii in range(nr):
    # agreement between each component's trace and its mask applied to the raw movie
    CYR_corr[ii] = scipy.stats.pearsonr(R[ii,:],CY[ii,:])[0]
print(f"Successfully calculated correlations")

# Initialize all running variables in a single dictionary - just lists of arguments/component IDs (SOL, CSL, CMG, CKG, saved merge groups, trash)
# The QC steps need the dictionary saved by the earlier filtering step (its "SMG" merge groups drive the GUI),
# so stop with a clear error when it is missing. Without this check, a missing file leaves vars_dict1 undefined
# (NameError). It can also silently reuse a stale vars_dict1 still in the kernel from another dataset and write
# that out as this folder's _ori backup.
save_path = 'ch0_means_movie_nobg_compfilt.pickle'
ori_path = 'ch0_means_movie_nobg_compfilt_ori.pickle'
if save_path not in os.listdir():
    raise FileNotFoundError(f"{save_path} not found in {os.getcwd()} - the saved vars_dict is required for QC")
vars_dict1 = load_pickle(save_path)
# duplicate original version if not already saved
if ori_path not in os.listdir():
    with open(ori_path, 'wb') as f:
        pickle.dump(vars_dict1, f)
print(f"Successfully initialized")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Successfully loaded CNMF model
Successfully loaded data
Successfully loaded contours
Successfully generated search list
Successfully calculated correlations
Successfully initialized
CPU times: total: 16.5 s
Wall time: 22.8 s


In [39]:
%%time
spcomps_merge = np.asarray([np.sum(spcomps[vars_dict1["SMG"][component],...],axis=0) for component in range(len(vars_dict1["SMG"]))])
spcomps_merge_e = np.concatenate((np.zeros(spcomps[0,...][np.newaxis,...].shape,dtype=np.float32),spcomps_merge))
spcomps_merge_e2 = np.argmax(spcomps_merge_e,axis=0)

CPU times: total: 18 s
Wall time: 18 s


In [40]:
# allow split/resurrect points and trash points (e.g. if the timecourse looks really bad)
# Resume from previously saved points if present; otherwise start with empty (0, 3) integer arrays.
qcdata_path = 'ch0_means_movie_nobg_compqc.npy'
if qcdata_path not in os.listdir():
    qc_data = {
        "qc_spots": np.empty((0, 3), dtype='int'),
        "trash_spots": np.empty((0, 3), dtype='int'),
    }
else:
    # np.save stored the dict as a 0-d object array; [()] unwraps it
    qc_data = np.load(qcdata_path, allow_pickle=True)[()]

In [41]:
# Initialize viewer and start GUI
viewer = napari.Viewer()
movie_tzyx = images.transpose([0, 3, 1, 2])  # (T, z, y, x), same order as spcomps' spatial axes
viewer.add_image(movie_tzyx, name='cells', colormap='gray')

# Trace window: component traces on the left axis, mean raw-mask trace on the twin right axis.
# (These names live at notebook scope, so no `global` statement is needed here.)
fig1, ax11 = plt.subplots()  # dF/F
ax12 = ax11.twinx()
fig1.suptitle('Fluorescence signal plots')
ax11.plot(CY[0, :].T / 100, c='green', label='initialization')
leg1 = fig1.legend(loc="upper left")
###
# Key bindings to speed up selections
def _toggle_visibility(viewer, layer_name):
    """Show the named layer if hidden, hide it if shown."""
    target = viewer.layers[layer_name]
    target.visible = not target.visible


def _shift_slider(viewer, axis, delta):
    """Move the dims slider of `axis` by `delta` steps (axis 0 = frame, axis 1 = plane of the (T, z, y, x) movie)."""
    viewer.dims.set_current_step(axis, viewer.dims.current_step[axis] + delta)


@viewer.bind_key('n')
def pressN(viewer):
    clickN()

@viewer.bind_key('p')
def pressP(viewer):
    clickP()

# 'b' (presumably "back") is an alias for 'p'
@viewer.bind_key('b')
def pressB(viewer):
    clickP()

@viewer.bind_key('c')
def pressC(viewer):
    _toggle_visibility(viewer, 'all components')

@viewer.bind_key('v')
def pressV(viewer):
    _toggle_visibility(viewer, 'component contours')

@viewer.bind_key('Up')
def jump_up(viewer):
    _shift_slider(viewer, 1, -1)

@viewer.bind_key('Down')
def jump_down(viewer):
    _shift_slider(viewer, 1, 1)

@viewer.bind_key('Left')
def jump_left(viewer):
    _shift_slider(viewer, 0, -1)

@viewer.bind_key('Right')
def jump_right(viewer):
    _shift_slider(viewer, 0, 1)
    
# Dock buttons equivalent to the 'n' / 'p' keys. `btn` is the PushButton's value and is unused;
# with auto_call=True magicgui calls the function when the button is pressed.
@magicgui(
    auto_call=True,btn={"widget_type": "PushButton", "text": "Next component"}
)
def click_next(btn):
    clickN()

@magicgui(
    auto_call=True,btn={"widget_type": "PushButton", "text": "Previous component"}
)
def click_prev(btn):
    clickP()

# Save and close button inside napari
@magicgui(
    auto_call=True,btn={"widget_type": "PushButton", "text": "Save and close GUI"}
)
def save_btn(btn):
    """Write the QC and trash points to qcdata_path, then close the plots and the viewer.

    Point rows are (z, y, x) floats from mouse clicks. Napari places pixel centres on integer
    coordinates, so round to the nearest voxel; plain astype truncation would send any click in
    the upper half of a voxel to its neighbour.
    """
    qc_data["qc_spots"] = np.round(viewer.layers['qc spots'].data).astype('int')
    qc_data["trash_spots"] = np.round(viewer.layers['trash spots'].data).astype('int')
    np.save(qcdata_path,qc_data)
    plt.close('all')
    viewer.close()

# Aligning widgets
layout = widgets.Container(
    widgets=[click_next,click_prev,save_btn], layout="vertical", labels=False
)
###
@magicgui(
    auto_call=True,
    component={"widget_type": "Slider", "min": 1, "max": len(vars_dict1["SMG"]), "step": 1, "orientation": "vertical"},
    layout="horizontal",
)
def show_comp(component: int = 1, init: bool = False):
    """Display merge group `component` (1-based index into vars_dict1["SMG"]).

    Replaces the 'component' image and 'component contours' points layers, centres the camera
    on the group's first member, and redraws the trace figure.

    Args:
    - component (int): 1-based merge-group number (the slider value).
    - init (bool): True for the first call, made before the 'qc spots' layer exists. Otherwise
      'qc spots' is moved back to the top of the layer stack and made active, so that new clicks
      keep landing in it rather than in the freshly added layers.
    """
    global leg1
    group = vars_dict1["SMG"][component-1]
    # Drop the previous group's layers (absent on the first call)
    for stale_name in ('component', 'component contours'):
        try:
            viewer.layers.remove(stale_name)
        except (KeyError, ValueError):
            pass
    plt.sca(ax11)
    plt.cla()
    plt.sca(ax12)
    plt.cla()
    ax11.set_xlabel('Frame')
    ax11.set_ylabel('Signal')
    ax11.tick_params(axis='y', labelcolor='green')
    ax12.tick_params(axis='y', labelcolor='darkorange')
    viewer.add_image(np.sum(spcomps[group,...],axis=0),name='component',colormap='green',opacity=1,blending='additive',visible=False)
    # Contour points of every member on every plane, without NaN rows and without empty planes
    contour_pts = [pts[~np.isnan(pts).any(axis=1)] for member in group for pts in shapes[member]]
    viewer.add_points(np.vstack([pts for pts in contour_pts if pts.size > 0]),name='component contours',symbol='disc',size=2,face_color='lime',visible=True)
    com = coors[group[0]].get('CoM')
    viewer.camera.center = com
    viewer.dims.set_point(1,com[0])
    viewer.camera.zoom = 3
    if not init:
        # Locate 'qc spots' by name instead of relying on fixed layer indices
        qc_layer = viewer.layers['qc spots']
        viewer.layers.move(viewer.layers.index(qc_layer), len(viewer.layers))
        viewer.layers.selection.active = qc_layer
    ax11.plot(CY[group,:].T/100,c='green',label='components')
    ax12.plot(np.mean(R[group,:],axis=0).T,c='darkorange',ls='--',label='components raw')
    leg1.remove()
    leg1 = fig1.legend(loc="upper left")
    fig1.canvas.draw_idle()

def clickN():
    """Advance the component slider by one, stopping at its maximum (no out-of-range value)."""
    slider = show_comp.component
    slider.value = min(slider.value + 1, slider.max)

def clickP():
    """Step the component slider back by one, stopping at its minimum (no out-of-range value)."""
    slider = show_comp.component
    slider.value = max(slider.value - 1, slider.min)

# Dock the slider and buttons, draw group 1 (init=True because 'qc spots' does not exist yet), then add the
# label volume and the two editable point layers. Resulting layer order: cells, component, component contours,
# all components, trash spots, qc spots.
viewer.window.add_dock_widget(show_comp)
viewer.window.add_dock_widget(layout)
show_comp(init=True)
# Voxel label = 1 + index of the merge group with the largest summed weight (0 = none)
viewer.add_labels(spcomps_merge_e2,name='all components')
# Point rows are (z, y, x); out_of_slice_display presumably keeps points visible on nearby planes
viewer.add_points(data=qc_data["trash_spots"],name='trash spots',ndim=3,face_color='red',size=6,out_of_slice_display=True)
viewer.add_points(data=qc_data["qc_spots"],name='qc spots',ndim=3,face_color='lime',size=6,out_of_slice_display=True)

<Points layer 'qc spots' at 0x24f6e8651e0>

In [25]:
# Optional: stretch z by Z_SCALE for display. Every layer on the (z, y, x) grid must share the scale:
# otherwise points clicked on the stretched view are stored in scaled units and no longer index spcomps.
Z_SCALE = 4
viewer.layers['cells'].scale = [1, Z_SCALE, 1, 1]
for layer_name in ('all components', 'trash spots', 'qc spots'):
    viewer.layers[layer_name].scale = [Z_SCALE, 1, 1]
# NOTE: 'component' / 'component contours' are re-created by show_comp at scale 1 and will not line up while scaled.

In [6]:
# Inspect the QC / trash points (integer (z, y, x) rows)
print(qc_data)

{'qc_spots': array([[  5, 273,  28],
       [  5, 293,  38],
       [  5, 220,   7],
       ...,
       [ 14, 387, 370],
       [ 14, 406, 401],
       [ 14, 417, 401]]), 'trash_spots': array([], shape=(0, 3), dtype=int32)}


In [7]:
# once all components have been evaluated, run this to update the vars_dict for downstream qc
# i.e. this should 1. trash the groups with trash spots, and then 2. resurrect all components with qc spots (including merged groups)

# This cell rewrites vars_dict1['SMG'] at the end, after which the labels in spcomps_merge_e2 (built from the
# old SMG) point at the wrong groups. Refuse to run a second time on the modified dictionary.
assert spcomps_merge.shape[0] == len(vars_dict1['SMG']), (
    "vars_dict1['SMG'] no longer matches the merge-label volume (this cell was probably already run); "
    "reload the original vars_dict1 and rebuild spcomps_merge before running it again")


def _groups_under(spots):
    """0-based vars_dict1['SMG'] indices of the merge groups labelled at each (z, y, x) spot.

    Spots on background (label 0) are skipped; repeated groups are kept.
    """
    labels = [spcomps_merge_e2[z, y, x] for z, y, x in spots[:, :3]]
    return [label - 1 for label in labels if label != 0]


# Every component with nonzero weight under any QC spot, de-duplicated in first-seen order
qc_spotcomps = []
for z, y, x in qc_data['qc_spots'][:, :3]:
    qc_spotcomps.extend(np.where(spcomps[:, z, y, x] != 0)[0].tolist())
qc_spotcomps2 = list(dict.fromkeys(qc_spotcomps))

# Resurrect: spotted components that are not yet in any saved merge group
qc_notres = [x for group in vars_dict1['SMG'] for x in group]# + vars_dict1['trash']
in_saved_groups = set(qc_notres)
qc_spotcomps_res = [x for x in qc_spotcomps2 if x not in in_saved_groups]

# Split: merge groups under a QC spot release all of their members back to the search list
qc_spotcomps_split_mrgcomp = _groups_under(qc_data['qc_spots'])
qc_spotcomps_split = [x for i in qc_spotcomps_split_mrgcomp for x in vars_dict1['SMG'][i]]
qc_spotcomps_all = qc_spotcomps_res + qc_spotcomps_split

# Trash: merge groups under a trash spot
trash_spotcomps_mrgcomp = _groups_under(qc_data['trash_spots'])

# New search order: all components by descending SNR, keeping only QC'd ones whose smoothed
# bottom-10% baseline exceeds Sigmin (the set is built once instead of once per element)
sigmin = cnmf_model.estimates.Sigmin
eligible = {x for x in qc_spotcomps_all if CYsavb10[x] > sigmin}
ref_SOL = np.argsort(-cnmf_model.estimates.SNR_comp)
ref_SOL = np.array([x for x in ref_SOL if x in eligible]) # initial search order list
print(ref_SOL)
print(len(ref_SOL))

vars_dict1['SOL'] = list(ref_SOL)
# Drop trashed groups and groups that were split back into individual components
removed_groups = set(trash_spotcomps_mrgcomp) | set(qc_spotcomps_split_mrgcomp)
vds = [group for j, group in enumerate(vars_dict1['SMG']) if j not in removed_groups]
vars_dict1['SMG'] = vds
print(vars_dict1)

[724 198 380 604 482 409 627 544 223 134  54 204 565  21 357  17 185 615
 521 180  26 580 507 730 729 767  52 893 748 310 737 734 475 826 162 578
 744 888 370  14 691 158 338 494 621 433 736 638 358  10 859 836  25 350
 379 557 371 339 778 192   2 344 295 709 851 352 883 353 317 264 740 571
 133 355 208 882 474 596 189 312 885 487 746 890 503 488 168 351 478 529
 156 167 769 551  39 130 519  13 693 838 815 577 207 579 172 583 622 511
 739 169 418 658 830 527 288 887 894 176 411 594 135 531 852 152 807  48
 368 151 140 159 628 650 875 590 294 555 141 702 897 164 163  29 862 570
  61 508 735 589 174 624 754 863 731 401 827 321 743 523 792 206 165 193
 595 328 625 683 139 346 157 177 547 173 569  57 536 524 522 534 118  82
 566 309   7 528 764  42 476 591 423  79 821 501 564 533 486 202 761 452
 458  72 495 833 274 647 286 337 472 425 672 710 132  30 879 315 669  95
 793  19 247 758 774  66  84 251 585 387 753 613 145 804 567  58 703 525
 282 878  70 538 191 786 750 257 689 608 349  81 32

In [8]:
# for qc, conditional on nonzero SOL above
# Only overwrite the saved dictionary when the QC pass produced components to re-click.
# NOTE(review): when ref_SOL is empty, SMG changes made by trash spots are not saved - confirm this is intended.
if len(ref_SOL) > 0:
    save_path = 'ch0_means_movie_nobg_compfilt.pickle'
    with open(save_path, mode='wb') as pickle_file:
        pickle.dump(vars_dict1, pickle_file)

In [18]:
# Inspect the current filtering dictionary
vars_dict1

{'SOL': [],
 'CSL': [],
 'CMG': [],
 'CKG': [],
 'SMG': [[451, 269, 458, 274, 280, 466],
  [448, 268],
  [482, 304, 502, 271, 454],
  [260, 242, 422],
  [236],
  [409, 228, 212, 393, 233],
  [188, 184, 376],
  [369],
  [113],
  [511, 487, 534],
  [525, 526, 707],
  [522, 688],
  [497, 694],
  [684],
  [78],
  [625, 447],
  [425, 475, 456],
  [647, 621],
  [636, 819],
  [831],
  [302],
  [389, 419],
  [165],
  [125],
  [142, 117, 128, 134],
  [457, 485],
  [473, 501],
  [97, 63, 76],
  [318, 313, 285, 316, 500, 289],
  [311, 314, 345],
  [237, 246],
  [424, 421, 455, 453],
  [443],
  [232],
  [216, 244, 213, 247],
  [479, 490, 641, 664],
  [528],
  [529, 709],
  [150, 168, 169, 154],
  [161, 153],
  [480, 492, 516, 336, 355],
  [273, 243, 278, 252],
  [164],
  [158, 156],
  [347],
  [86],
  [608, 437],
  [623, 430, 605, 396, 411],
  [431, 412],
  [614],
  [423],
  [70],
  [590],
  [575, 604, 602, 780, 597],
  [610, 603, 786],
  [756],
  [616, 587],
  [203],
  [370],
  [367, 358, 579, 57

In [None]:
# fixing 20230824_SL1_wells001, 220622_PK44_wells001 duplicate components

In [23]:
# Number of saved merge groups
len(vars_dict1['SMG'])

202

In [8]:
# Total memberships across groups vs distinct component IDs; a mismatch means some ID sits in more than one group
ln = np.sum(np.array([len(group) for group in vars_dict1['SMG']]))
ids = set.union(*(set(group) for group in vars_dict1['SMG']))

In [9]:
# Total number of memberships (compare with len(ids))
ln

410

In [10]:
# Number of distinct component IDs across all groups
len(ids)

409

In [11]:
import collections
# Component IDs that occur more than once across all merge groups
a = [member for group in vars_dict1['SMG'] for member in group]
member_counts = collections.Counter(a)
b = [item for item, count in member_counts.items() if count > 1]

In [12]:
# Groups that contain any duplicated ID
[x for x in vars_dict1['SMG'] if len(set(x) & set(b)) != 0]

[[-1], [-1]]

In [14]:
# Components with nonzero footprint weight at the hand-picked voxel (z=7, y=366, x=414) - presumably the cell behind the duplicate groups
np.where(spcomps[:,7,366,414]!=0)[0].tolist()

[892]

In [17]:
# Is component 892 already in a saved group?
[x for x in vars_dict1['SMG'] if 892 in x]

[]

In [19]:
# (n_components, n_frames) - confirms the valid component index range
cnmf_model.estimates.C.shape

(893, 165)

In [20]:
# Indices (into SMG) of the groups that contain duplicated IDs
[i for i in range(len(vars_dict1['SMG'])) if len(set(vars_dict1['SMG'][i]) & set(b)) != 0]

[142, 143]

In [14]:
# Spot-check a single group
vars_dict1['SMG'][21]

[663, 835, 681, 881]

In [27]:
# Full list of merge groups
vars_dict1['SMG']

[[451, 269, 458, 274, 280, 466],
 [448, 268],
 [482, 304, 502, 271, 454],
 [260, 242, 422],
 [236],
 [409, 228, 212, 393, 233],
 [188, 184, 376],
 [369],
 [113],
 [511, 487, 534],
 [525, 526, 707],
 [522, 688],
 [497, 694],
 [684],
 [78],
 [625, 447],
 [425, 475, 456],
 [647, 621],
 [636, 819],
 [831],
 [302],
 [389, 419],
 [165],
 [125],
 [142, 117, 128, 134],
 [457, 485],
 [473, 501],
 [97, 63, 76],
 [318, 313, 285, 316, 500, 289],
 [311, 314, 345],
 [237, 246],
 [424, 421, 455, 453],
 [443],
 [232],
 [216, 244, 213, 247],
 [479, 490, 641, 664],
 [528],
 [529, 709],
 [150, 168, 169, 154],
 [161, 153],
 [480, 492, 516, 336, 355],
 [273, 243, 278, 252],
 [164],
 [158, 156],
 [347],
 [86],
 [608, 437],
 [623, 430, 605, 396, 411],
 [431, 412],
 [614],
 [423],
 [70],
 [590],
 [575, 604, 602, 780, 597],
 [610, 603, 786],
 [756],
 [616, 587],
 [203],
 [370],
 [367, 358, 579, 573],
 [398],
 [799],
 [140],
 [82, 53, 44, 47, 33, 58],
 [49],
 [52],
 [174, 157],
 [298, 507, 499, 344, 342],
 [337

In [22]:
# Preview SMG with the two duplicate groups removed (not assigned)
[vars_dict1['SMG'][i] for i in range(len(vars_dict1['SMG'])) if i not in [142, 143]]

[[451, 269, 458, 274, 280, 466],
 [448, 268],
 [482, 304, 502, 271, 454],
 [260, 242, 422],
 [236],
 [409, 228, 212, 393, 233],
 [188, 184, 376],
 [369],
 [113],
 [511, 487, 534],
 [525, 526, 707],
 [522, 688],
 [497, 694],
 [684],
 [78],
 [625, 447],
 [425, 475, 456],
 [647, 621],
 [636, 819],
 [831],
 [302],
 [389, 419],
 [165],
 [125],
 [142, 117, 128, 134],
 [457, 485],
 [473, 501],
 [97, 63, 76],
 [318, 313, 285, 316, 500, 289],
 [311, 314, 345],
 [237, 246],
 [424, 421, 455, 453],
 [443],
 [232],
 [216, 244, 213, 247],
 [479, 490, 641, 664],
 [528],
 [529, 709],
 [150, 168, 169, 154],
 [161, 153],
 [480, 492, 516, 336, 355],
 [273, 243, 278, 252],
 [164],
 [158, 156],
 [347],
 [86],
 [608, 437],
 [623, 430, 605, 396, 411],
 [431, 412],
 [614],
 [423],
 [70],
 [590],
 [575, 604, 602, 780, 597],
 [610, 603, 786],
 [756],
 [616, 587],
 [203],
 [370],
 [367, 358, 579, 573],
 [398],
 [799],
 [140],
 [82, 53, 44, 47, 33, 58],
 [49],
 [52],
 [174, 157],
 [298, 507, 499, 344, 342],
 [337

In [24]:
# Drop the groups that contain duplicated component IDs (b, from the Counter cell). The indices are derived
# from the data rather than hard-coded as [142, 143], so re-running this cell is a no-op instead of
# deleting two further, unrelated groups.
dup_ids = set(b)
dup_group_idx = {i for i, group in enumerate(vars_dict1['SMG']) if set(group) & dup_ids}
vars_dict1['SMG'] = [group for i, group in enumerate(vars_dict1['SMG']) if i not in dup_group_idx]
len(vars_dict1['SMG'])

200

In [25]:
# Add component 892 as its own group; skip it if it is already in a group so that re-running does not duplicate it
if not any(892 in group for group in vars_dict1['SMG']):
    vars_dict1['SMG'].append([892])

In [26]:
# Group count after the repair
len(vars_dict1['SMG'])

201

In [28]:
# Final check of the repaired dictionary before saving
print(vars_dict1)

{'SOL': [], 'CSL': [], 'CMG': [], 'CKG': [], 'SMG': [[451, 269, 458, 274, 280, 466], [448, 268], [482, 304, 502, 271, 454], [260, 242, 422], [236], [409, 228, 212, 393, 233], [188, 184, 376], [369], [113], [511, 487, 534], [525, 526, 707], [522, 688], [497, 694], [684], [78], [625, 447], [425, 475, 456], [647, 621], [636, 819], [831], [302], [389, 419], [165], [125], [142, 117, 128, 134], [457, 485], [473, 501], [97, 63, 76], [318, 313, 285, 316, 500, 289], [311, 314, 345], [237, 246], [424, 421, 455, 453], [443], [232], [216, 244, 213, 247], [479, 490, 641, 664], [528], [529, 709], [150, 168, 169, 154], [161, 153], [480, 492, 516, 336, 355], [273, 243, 278, 252], [164], [158, 156], [347], [86], [608, 437], [623, 430, 605, 396, 411], [431, 412], [614], [423], [70], [590], [575, 604, 602, 780, 597], [610, 603, 786], [756], [616, 587], [203], [370], [367, 358, 579, 573], [398], [799], [140], [82, 53, 44, 47, 33, 58], [49], [52], [174, 157], [298, 507, 499, 344, 342], [337], [698, 678, 51

In [29]:
# Persist the repaired dictionary (overwrites the saved filtering pickle)
save_path = 'ch0_means_movie_nobg_compfilt.pickle'
with open(save_path, mode='wb') as pickle_file:
    pickle.dump(vars_dict1, pickle_file)

{'qc_spots': array([[  5,  83,  67],
       [  5,  96,  72],
       [  5, 136,  53],
       ...,
       [  7, 339, 115],
       [ 11,  96, 388],
       [ 12,  66, 352]]), 'trash_spots': array([], shape=(0, 3), dtype=int32)}


[386 566 624 278 227 266 397 426 219 627 857 472 448 533 606 240 308 616
 645 408 482 275  63 688 659 668 686 239 269 418 424 777 622  95  74 436
 396 633  88 468 241 756 383 306 180 457 787 744 609 394 110 129 674 594
  50 340 256 585 206  54 267  46 332 315  86 800 246  91 430 653  64 765
 757 526  44 187 570 325 503 862  72 427 119 880  48 681 445 611 763 334
 293 575 567 329 124 663 388 432 489 584 398 603 665 200 417 280 201 399
 851 285 708  36 805 311 797 734 295 532 707 760 639 630 221 250  84 540
 400 203 458 752 287 669 346 451 519 803 494  34  96 498 163 271 613 351
 850 598 605 413  15 511 865 354 845 156 125 698 192  66 874 876 695 191
 350 664 565 706 691  35 827 685 870 342 463 879 145  52 562 133 339 772
 341 423 253 193 152 357 530 385 780 793 380 148 289  59 111 682 338 582
 615 189 578 290 802 561 233 643 889 326 840 469  49 248 480 788 701 623
  43 855 884 186 467  69 348  39 374 320 579 638 548 327 212 576 528 462
 726 165 382 739 736 464 330 126 259 260 804 516 23

In [None]:
#find the post-merge/refit ds GUI further down!
# Qt backend so the trace figure opens in its own window next to napari
%matplotlib qt
# Folder holding this dataset's CNMF outputs; the loading cell chdirs into it
data_dir = input("Enter the data directory:")

# Parameters (unlikely to change)
# NOTE(review): only Athresh is referenced in the visible cells of this section; the others look unused here
n_range_lim = 10 # size of n_range below which SNR considered unreliable
Athresh = 0.05 # overlap threshold - automatically split anything below it (used to pick overlapping pairs for trace correlation)
cr_thresh = 0.9 # component-raw correlation threshold below which component deemed suspicious quality
pb_thresh = 0.95 # component-best parent correlation threshold above which component deemed likely merge

In [None]:
%%time
## LOADING EVERYTHING UP - TAKES ~20 sec

# load packages

import napari
from magicgui import magicgui, widgets
import time

from IPython import get_ipython
from IPython.display import clear_output
import os
import matplotlib.pyplot as plt
import numpy as np
from scipy.sparse import csc_matrix
from scipy import signal as sg
import scipy
import pickle

from tifffile.tifffile import imwrite,imread
from tqdm.auto import tqdm,trange

from copy import deepcopy
import h5py

import caiman as cm
from caiman.source_extraction.cnmf import cnmf,params
from caiman.paths import caiman_datadir
from caiman.utils.visualization import get_contours

# Enable autoreload so edits to imported modules are picked up without a kernel restart.
# __IPYTHON__ only exists under IPython; outside it the NameError makes this a no-op.
try:
    if __IPYTHON__:
        get_ipython().run_line_magic('load_ext', 'autoreload')
        get_ipython().run_line_magic('autoreload', '2')
except NameError:
    pass

def load_pickle(file_path):
    """
    Read back the object stored in a pickle file.

    Args:
    - file_path (str): Path to the pickle file.

    Returns:
    - object: The unpickled object (e.g. the vars_dict dictionary or the contour list).

    Note: pickle.load can run arbitrary code; only use it on trusted, locally produced files.
    """
    with open(file_path, mode='rb') as handle:
        return pickle.load(handle)

## Loading all the inputs
os.chdir(data_dir)
#cnmf_path = caiman_datadir()+'/example_movies/demoMovie3DYxxbnobg_20240318170305_cnmf.hdf5'
cnmf_path = os.path.join(data_dir, 'ch0_means_movie_nobg_cnmf.hdf5')

# CNMF model whose components are being cleaned up
cnmf_model = cnmf.load_CNMF(cnmf_path,
                            n_processes=1,
                            dview=None)
print(f"Successfully loaded CNMF model")

# Motion-corrected movie memmap. Fail with a clear message if it is missing, and sort so the
# choice is deterministic if several memmap files are present.
mmap_candidates = sorted(name for name in os.listdir() if 'memmap__' in name)
if not mmap_candidates:
    raise FileNotFoundError(f"No 'memmap__' file found in {os.getcwd()}")
mc_memmapped_fname = mmap_candidates[0]
Yr, dims, T = cm.load_memmap(mc_memmapped_fname)
# Yr is (pixels, T); reshape (Fortran order) into a (T, *dims) movie
images = np.array(np.reshape(Yr.T, [T] + list(dims), order='F'))
print(f"Successfully loaded data")

# Earlier in-notebook contour computation, kept for reference; contours are now loaded from a pickle below.
#d = cnmf_model.estimates.A.shape[0]
#dims = cnmf_model.estimates.dims
#axis = 2
#order = list(range(4))
#order.insert(0, order.pop(axis))
#index_permut = np.reshape(np.arange(d), dims, order='F').transpose(
#        order[:-1]).reshape(d, order='F')
#A = csc_matrix(cnmf_model.estimates.A)[index_permut, :]
#dims = tuple(np.array(dims)[order[:3]])
#d1, d2, d3 = dims
#nr, T = cnmf_model.estimates.C.shape
#image_cells = np.array(A.mean(axis=1)).reshape(dims, order='F')
#coors = get_contours(A, dims, thr=Cthr)
# coors: one entry per component, read here via n['coordinates'] (presumably one array per plane) and later via .get('CoM')
coors = load_pickle(os.path.join(data_dir, 'ch0_means_movie_nobg_coors.pickle'))
print(f"Successfully loaded contours")

cc = [[l for l in n['coordinates']] for n in coors] # x,y values of contour coordinates for each component
cc1 = [[(l[:, 0]) for l in n['coordinates']] for n in coors] # x values of contour coordinates for each component
cc2 = [[(l[:, 1]) for l in n['coordinates']] for n in coors] # y values of contour coordinates for each component
# NOTE(review): `length` is not used in the visible cells (the inner `cc` is comprehension-local and does not clobber `cc`)
length = np.ravel([list(map(len, cc)) for cc in cc1])
# shapes[j][i]: rows [i, *reversed(point)] for component j's contour on plane i, i.e. (plane, y, x) given the (x, y)
# point order stated above; drawn later as 3-D napari points (show_comp drops rows containing NaN)
shapes = [[np.vstack([np.append(i,np.flip(pt)) for pt in cc[j][i]]) for i in range(len(cc[j]))] for j in range(len(cc))]

# Line up all static inputs
SNRs = cnmf_model.estimates.SNR_comp
SNR_min = cnmf_model.estimates.SNRmin
SOL = np.argsort(-SNRs)  # component indices by descending SNR
# Footprints as one volume per component: A (pixels x K) -> dims + (K,) -> (K, dims[2], dims[0], dims[1])
spcomps = np.reshape(cnmf_model.estimates.A.toarray(),cnmf_model.estimates.dims + (-1,),order='F')
spcomps = spcomps.transpose([3,2,0,1])
# Movie in the matching axis order: (T, dims[2], dims[0], dims[1])
images2 = images.transpose([0,3,1,2])
#SOL = np.argsort(-cnmf_model.estimates.SNR_comp) 
C = cnmf_model.estimates.C
CY = cnmf_model.estimates.C + cnmf_model.estimates.YrA # temporal loadings
R = cnmf_model.estimates.Craw # masks applied to raw movie

CYsav = cnmf_model.estimates.CYsav # smoothened CY curve
def sav_calc(sraw):
    """Savitzky-Golay smooth a 1-D trace (window of 3 samples, linear fit)."""
    window_length, polyorder = 3, 1
    return sg.savgol_filter(sraw, window_length, polyorder)
# Smooth every raw-mask trace with sav_calc
Rsav = np.zeros(R.shape)
for row_idx, raw_row in enumerate(R):
    Rsav[row_idx, :] = sav_calc(raw_row)


def _low_decile_mean(sorted_traces):
    """Mean of the first ceil(n_frames / 10) columns of a row-wise sorted 2-D array (bottom 10% per trace)."""
    n_low = int(np.ceil(sorted_traces.shape[1] / 10))
    return np.mean(sorted_traces[:, :n_low], axis=1)


CYsavsort = np.sort(CYsav, axis=1)
CYsavb10 = _low_decile_mean(CYsavsort)  # bottom 10% mean
Rsavsort = np.sort(Rsav, axis=1)
Rsavb10 = _low_decile_mean(Rsavsort)  # bottom 10% mean

# Per-component reference level (presumably F0): mean over the n_range frames when the model
# provides them, otherwise the bottom-10% mean of the smoothed traces.
n_range = cnmf_model.estimates.n_range
if n_range is not None:
    CYf = np.mean(CY[:, n_range], axis=1)
    Rf = np.mean(R[:, n_range], axis=1)
else:
    CYf = CYsavb10
    Rf = Rsavb10

Cn = cnmf_model.estimates.Cn # correlation image (not necessary)
keepargs = cnmf_model.estimates.keepargs
# Keep the SNR-descending order, restricted to components listed in keepargs. The lookup set is
# built once; `x in list(keepargs)` rebuilt and scanned a list for every component.
keep_set = set(keepargs)
SOL = np.array([x for x in SOL if x in keep_set]) # initial search order list
print(f"Successfully generated search list")

import scipy.stats  # explicit, rather than relying on `import scipy` exposing the submodule

# Pairwise spatial overlap of footprints: upper triangle only, self-overlap zeroed
A1 = csc_matrix(cnmf_model.estimates.A)
nr = A1.shape[1]
A_corr = scipy.sparse.triu(A1.T @ A1)
A_corr.setdiag(0)
A_corr = A_corr.tocsc()
C_corr = scipy.sparse.lil_matrix(A_corr.shape)
for ii in range(nr):
    # find() returns (rows, cols, values) of the row's nonzero overlaps; call it once per row
    found = scipy.sparse.find(A_corr[ii, :])
    overlap_indices = found[1][found[2] > Athresh]
    if len(overlap_indices) > 0:
        # check the correlation of the calcium traces for each overlapping component
        corr_values = [scipy.stats.pearsonr(C[ii, :], C[jj, :])[0] for jj in overlap_indices]
        C_corr[ii, overlap_indices] = corr_values
C_tot = C_corr + C_corr.T  # symmetric trace correlations for overlapping pairs
CYR_corr = np.zeros(nr)
for ii in range(nr):
    # agreement between each component's trace and its mask applied to the raw movie
    CYR_corr[ii] = scipy.stats.pearsonr(R[ii,:],CY[ii,:])[0]
print(f"Successfully calculated correlations")

# Initialize all running variables in a single dictionary - just lists of arguments/component IDs (SOL, CSL, CMG, CKG, saved merge groups, trash)
# The cleanup GUI needs the dictionary saved by the earlier filtering step (its "SMG" merge groups drive the GUI),
# so stop with a clear error when it is missing. Without this check, a missing file leaves vars_dict1 undefined
# (NameError). It can also silently reuse a stale vars_dict1 still in the kernel from another dataset and write
# that out as this folder's _ori backup.
save_path = 'ch0_means_movie_nobg_compfilt.pickle'
ori_path = 'ch0_means_movie_nobg_compfilt_ori.pickle'
if save_path not in os.listdir():
    raise FileNotFoundError(f"{save_path} not found in {os.getcwd()} - the saved vars_dict is required for QC")
vars_dict1 = load_pickle(save_path)
# duplicate original version if not already saved
if ori_path not in os.listdir():
    with open(ori_path, 'wb') as f:
        pickle.dump(vars_dict1, f)
print(f"Successfully initialized")

In [None]:
%%time
spcomps_merge = np.asarray([np.sum(spcomps[vars_dict1["SMG"][component],...],axis=0) for component in range(len(vars_dict1["SMG"]))])
spcomps_merge_e = np.concatenate((np.zeros(spcomps[0,...][np.newaxis,...].shape,dtype=np.float32),spcomps_merge))
spcomps_merge_e2 = np.argmax(spcomps_merge_e,axis=0)

In [None]:
# Initialize viewer and start GUI
viewer = napari.Viewer()
movie_tzyx = images.transpose([0, 3, 1, 2])  # (T, z, y, x), same order as spcomps' spatial axes
viewer.add_image(movie_tzyx, name='cells', colormap='gray')

# Trace window: component traces on the left axis, mean raw-mask trace on the twin right axis.
# (These names live at notebook scope, so no `global` statement is needed here.)
fig1, ax11 = plt.subplots()  # dF/F
ax12 = ax11.twinx()
fig1.suptitle('Fluorescence signal plots')
ax11.plot(CY[0, :].T / 100, c='green', label='initialization')
leg1 = fig1.legend(loc="upper left")

@magicgui(
    auto_call=True,
    component={"widget_type": "Slider", "min": 1, "max": len(vars_dict1["SMG"]), "step": 1, "orientation": "vertical"},
    layout="horizontal",
)
def show_comp(component: int = 1):
    """Display merge group `component` (1-based index into vars_dict1["SMG"]).

    Replaces the 'component' image and 'component contours' points layers, jumps the plane
    slider to the group's first member, and redraws the trace figure.
    """
    global leg1
    group = vars_dict1["SMG"][component-1]
    # Drop the previous group's layers (absent on the first call)
    for stale_name in ('component', 'component contours'):
        try:
            viewer.layers.remove(stale_name)
        except (KeyError, ValueError):
            pass
    plt.sca(ax11)
    plt.cla()
    plt.sca(ax12)
    plt.cla()
    ax11.set_xlabel('Frame')
    ax11.set_ylabel('Signal')
    ax11.tick_params(axis='y', labelcolor='green')
    ax12.tick_params(axis='y', labelcolor='darkorange')
    viewer.add_image(np.sum(spcomps[group,...],axis=0),name='component',colormap='green',opacity=1,blending='additive',visible=False)
    # Contour points of every member on every plane, without NaN rows and without empty planes
    contour_pts = [pts[~np.isnan(pts).any(axis=1)] for member in group for pts in shapes[member]]
    viewer.add_points(np.vstack([pts for pts in contour_pts if pts.size > 0]),name='component contours',symbol='disc',size=2,face_color='lime',visible=True)
    # Camera recentring/zoom is disabled in this variant; only the plane slider follows the group
    #viewer.camera.center = coors[group[0]].get('CoM')
    viewer.dims.set_point(1,coors[group[0]].get('CoM')[0])
    #viewer.camera.zoom = 3
    ax11.plot(CY[group,:].T/100,c='green',label='components')
    ax12.plot(np.mean(R[group,:],axis=0).T,c='darkorange',ls='--',label='components raw')
    leg1.remove()
    leg1 = fig1.legend(loc="upper left")
    fig1.canvas.draw_idle()

# Dock the slider, draw group 1, and overlay the merge-label volume. There is no dedicated points layer here:
# points are placed in a napari Points layer added by hand (read back as viewer.layers['Points']).
viewer.window.add_dock_widget(show_comp)
show_comp()
viewer.add_labels(spcomps_merge_e2,name='all components')

In [None]:
# this part adjusts vars_dict for cleanup clicking in cleangui using points placed above
# run first line before closing napari!
# Requires a points layer that kept napari's default name 'Points'; each row is one point's coordinates.
spots = viewer.layers['Points'].data
print(spots.shape[0])  # number of points placed

In [None]:
# we are adding to the search list the components that overlap the points but have not already been added to components
# ignoring any SNR cutoffs - but keeping sigmin cutoff!
# Columns 1-3 of the 4-D points are (z, y, x); column 0 (presumably the frame axis) is ignored. Napari puts
# pixel centres on integer coordinates, so round to the nearest voxel rather than truncating.
spot_zyx = np.round(spots[:, 1:4]).astype('int')
spotcomps = []
for r in trange(spot_zyx.shape[0]):
    z, y, x = spot_zyx[r]
    spotcomps.extend(np.where(spcomps[:, z, y, x] != 0)[0].tolist())

spotcomps2 = list(dict.fromkeys(spotcomps))  # de-duplicate, keeping first-seen order

# Only components not already in a saved merge group (set built once, not once per element)
already_grouped = {x for group in vars_dict1['SMG'] for x in group}
spotcomps3 = [x for x in spotcomps2 if x not in already_grouped]

eligible = {x for x in spotcomps3 if CYsavb10[x] > cnmf_model.estimates.Sigmin}
SOL = np.argsort(-SNRs)
SOL = np.array([x for x in SOL if x in eligible]) # initial search order list
print(SOL)
print(len(SOL))

vars_dict1['SOL'] = list(SOL)
print(vars_dict1)

In [None]:
# save vars_dict1 and run cleangui-qc!
save_path = 'ch0_means_movie_nobg_compfilt.pickle'
with open(save_path, mode='wb') as pickle_file:
    pickle.dump(vars_dict1, pickle_file)

In [None]:
#USE THIS FOR POST-MERGE/REFIT
#find the post-merge/refit ds GUI further down!
# Qt backend so figures open in their own windows next to napari
%matplotlib qt
# Folder holding this dataset's merged / refit CNMF outputs; the loading cell chdirs into it
data_dir = input("Enter the data directory:")

In [None]:
%%time
## LOADING EVERYTHING UP - TAKES ~20 sec

# load packages

import napari
from magicgui import magicgui, widgets
import time

from IPython import get_ipython
from IPython.display import clear_output
import os
import matplotlib.pyplot as plt
import numpy as np
from scipy.sparse import csc_matrix
from scipy import signal as sg
import scipy
import pickle

from tifffile.tifffile import imwrite,imread
from tqdm.auto import tqdm,trange

from copy import deepcopy
import h5py

import caiman as cm
from caiman.source_extraction.cnmf import cnmf,params
from caiman.paths import caiman_datadir
from caiman.utils.visualization import get_contours

# Turn on autoreload when running inside IPython/Jupyter. In a plain interpreter
# the __IPYTHON__ builtin does not exist and the resulting NameError is ignored.
try:
    if __IPYTHON__:
        for _magic, _magic_arg in (('load_ext', 'autoreload'), ('autoreload', '2')):
            get_ipython().run_line_magic(_magic, _magic_arg)
except NameError:
    pass

def load_pickle(file_path):
    """
    Read and return the object stored in a pickle file.

    Args:
    - file_path (str): Path to the pickle file.

    Returns:
    - The unpickled object (in this notebook, the vars_dict dictionary).

    Note: unpickling can execute arbitrary code -- only load trusted files.
    """
    with open(file_path, mode='rb') as handle:
        return pickle.load(handle)

## Loading all the inputs
# Work inside data_dir: the memmap lookup below uses relative file names.
os.chdir(data_dir)
#cnmf_path = caiman_datadir()+'/example_movies/demoMovie3DYxxbnobg_20240318170305_cnmf.hdf5'
cnmf_path_merged = os.path.join(data_dir, 'ch0_means_movie_nobg_cnmf_merged.hdf5')
cnmf_path_merged_refit = os.path.join(data_dir, 'ch0_means_movie_nobg_cnmf_merged_refit.hdf5')

# CNMFE model
cnmf_model_merged = cnmf.load_CNMF(cnmf_path_merged, n_processes=1, dview=None)
cnmf_model_merged_refit = cnmf.load_CNMF(cnmf_path_merged_refit, n_processes=1, dview=None)
print(f"Successfully loaded CNMF model")

# Motion-corrected movie. os.listdir() order is filesystem-dependent, so sort the
# candidates to make the choice reproducible, and fail with a clear message
# (instead of a bare IndexError) when no memmap file is present.
memmap_candidates = sorted(f for f in os.listdir() if 'memmap__' in f)
if not memmap_candidates:
    raise FileNotFoundError(f"No file containing 'memmap__' found in {os.getcwd()}")
if len(memmap_candidates) > 1:
    print(f"Several memmap files found; using {memmap_candidates[0]}")
mc_memmapped_fname = memmap_candidates[0]
Yr, dims, T = cm.load_memmap(mc_memmapped_fname)
# (T,) + dims movie; order='F' as for the footprint reshape below
images = np.array(np.reshape(Yr.T, [T] + list(dims), order='F'))
print(f"Successfully loaded data")

# Line up all static inputs
# Footprints: A (pixels x K) -> dims + (K,) -> (K, dims[2], dims[0], dims[1]), the same
# spatial axis order as the movie shown in the viewer (images.transpose([0,3,1,2])).
spcomps_m = np.reshape(cnmf_model_merged.estimates.A.toarray(),cnmf_model_merged.estimates.dims + (-1,),order='F')
spcomps_m = spcomps_m.transpose([3,2,0,1]).astype(np.float32)
CY_m = cnmf_model_merged.estimates.C + cnmf_model_merged.estimates.YrA # temporal loadings
spcomps_mr = np.reshape(cnmf_model_merged_refit.estimates.A.toarray(),cnmf_model_merged_refit.estimates.dims + (-1,),order='F')
spcomps_mr = spcomps_mr.transpose([3,2,0,1]).astype(np.float32)
CY_mr = cnmf_model_merged_refit.estimates.C + cnmf_model_merged_refit.estimates.YrA # temporal loadings
print(f"Setup done")

In [None]:
%%time
# Label volumes for napari: put an all-zero plane in front of the footprints so that
# argmax gives 0 where no footprint is positive, and k for 0-based component k-1.
spcomps_m_e = np.concatenate((np.zeros_like(spcomps_m[0], dtype=np.float32)[np.newaxis], spcomps_m))
spcomps_mr_e = np.concatenate((np.zeros_like(spcomps_mr[0], dtype=np.float32)[np.newaxis], spcomps_mr))
spcomps_m_e2 = spcomps_m_e.argmax(axis=0)
spcomps_mr_e2 = spcomps_mr_e.argmax(axis=0)

In [None]:
# Initialize viewer and start GUI
viewer = napari.Viewer()
# Movie as (frame, dims[2], dims[0], dims[1]) -- the same spatial order as the footprint stacks.
viewer.add_image(images.transpose([0,3,1,2]),name='cells',colormap='gray')

# NOTE: `global` at notebook top level is a no-op; these are module globals anyway
# and are used by show_comp below.
global fig1, ax11, ax12, leg1
fig1, ax11 = plt.subplots() # dF/F
ax12 = ax11.twinx()  # right-hand axis for the second model's trace
fig1.suptitle('Fluorescence signal plots')
# Placeholder trace so the initial figure legend has an entry.
ax11.plot(CY_m[0,:].T/100,c='green',label='initialization')
leg1 = fig1.legend(loc="upper left")

@magicgui(
    auto_call=True,
    component={"widget_type": "Slider", "min": 1, "max": np.max([CY_mr.shape[0],CY_m.shape[0]]), "step": 1, "orientation": "vertical"},
    layout="horizontal",
)
def show_comp(component: int = 1):
    """Show component `component` (1-based) of the merged and merged+refit models.

    Adds the two footprints as additive image layers (green = merged, orange =
    refit) and plots their C + YrA traces (scaled by 1/100) on the twin axes of
    fig1. The slider spans the larger of the two component counts, so an index
    may exist in only one model; only that model is drawn instead of raising
    IndexError.

    Args:
        component: 1-based component index from the slider.
    """
    global leg1
    idx = component - 1

    # Remove the layers of the previously shown component, if present.
    for layer_name in ('component merged', 'component refit'):
        if layer_name in viewer.layers:
            viewer.layers.remove(layer_name)

    ax11.cla()
    ax12.cla()
    ax11.set_xlabel('Frame')
    ax11.set_ylabel('Signal')
    ax11.tick_params(axis='y', labelcolor='green')
    ax12.tick_params(axis='y', labelcolor='darkorange')

    if idx < spcomps_m.shape[0]:
        viewer.add_image(spcomps_m[idx,...],name='component merged',colormap='green',opacity=1,blending='additive',visible=True)
    if idx < spcomps_mr.shape[0]:
        viewer.add_image(spcomps_mr[idx,...],name='component refit',colormap='darkorange',opacity=1,blending='additive',visible=True)
    if idx < CY_m.shape[0]:
        ax11.plot(CY_m[idx,:].T/100,c='green',label='component merged')
    if idx < CY_mr.shape[0]:
        ax12.plot(CY_mr[idx,:].T/100,c='darkorange',label='component refit')

    # Draw the left-axis (green) trace above the twin axis; cla() recreates the
    # axes patch, so it has to be hidden again on every call.
    ax11.set_zorder(1)
    ax11.patch.set_visible(False)

    # Rebuild the figure legend so it names the traces currently plotted
    # (otherwise it keeps the 'initialization' entry it was created with).
    leg1.remove()
    leg1 = fig1.legend(loc="upper left")
    fig1.canvas.draw_idle()

# Dock the slider, draw component 1, then overlay label volumes of every component
# of each model (label 0 = background, label k = 0-based component k-1).
viewer.window.add_dock_widget(show_comp)
show_comp()
viewer.add_labels(spcomps_m_e2,name='all components merged')
viewer.add_labels(spcomps_mr_e2,name='all components refit')

In [None]:
#USE THIS FOR POST-RAW
# Section: compare the ds model (cnmf_mr_raw) with its raw refit (cnmf_mr_raw_refit).
%matplotlib qt
data_dir = input("Enter the data directory:")

In [None]:
%%time
## LOADING EVERYTHING UP - TAKES ~20 sec

# load packages

import napari
from magicgui import magicgui, widgets
import time

from IPython import get_ipython
from IPython.display import clear_output
import os
import matplotlib.pyplot as plt
import numpy as np
from scipy.sparse import csc_matrix
from scipy import signal as sg
import scipy
import pickle

from tifffile.tifffile import imwrite,imread
from tqdm.auto import tqdm,trange

from copy import deepcopy
import h5py

import caiman as cm
from caiman.source_extraction.cnmf import cnmf,params
from caiman.paths import caiman_datadir
from caiman.utils.visualization import get_contours

# Turn on autoreload when running inside IPython/Jupyter. In a plain interpreter
# the __IPYTHON__ builtin does not exist and the resulting NameError is ignored.
try:
    if __IPYTHON__:
        for _magic, _magic_arg in (('load_ext', 'autoreload'), ('autoreload', '2')):
            get_ipython().run_line_magic(_magic, _magic_arg)
except NameError:
    pass

def load_pickle(file_path):
    """
    Read and return the object stored in a pickle file.

    Args:
    - file_path (str): Path to the pickle file.

    Returns:
    - The unpickled object (in this notebook, the vars_dict dictionary).

    Note: unpickling can execute arbitrary code -- only load trusted files.
    """
    with open(file_path, mode='rb') as handle:
        return pickle.load(handle)

## Loading all the inputs
# Work inside data_dir: the memmap lookup below uses relative file names.
os.chdir(data_dir)
#cnmf_path = caiman_datadir()+'/example_movies/demoMovie3DYxxbnobg_20240318170305_cnmf.hdf5'
# Variable names are shared with the merge/refit section: here "merged" is the ds model
# (mr_raw) and "merged_refit" its raw refit (labelled 'ds' / 'raw' in the GUI below).
cnmf_path_merged = os.path.join(data_dir, 'ch0_means_movie_nobg_cnmf_mr_raw.hdf5')
cnmf_path_merged_refit = os.path.join(data_dir, 'ch0_means_movie_nobg_cnmf_mr_raw_refit.hdf5')

# CNMFE model
cnmf_model_merged = cnmf.load_CNMF(cnmf_path_merged, n_processes=1, dview=None)
cnmf_model_merged_refit = cnmf.load_CNMF(cnmf_path_merged_refit, n_processes=1, dview=None)
print(f"Successfully loaded CNMF model")

# Motion-corrected movie. os.listdir() order is filesystem-dependent, so sort the
# candidates to make the choice reproducible, and fail with a clear message
# (instead of a bare IndexError) when no memmap file is present.
memmap_candidates = sorted(f for f in os.listdir() if 'memmap__' in f)
if not memmap_candidates:
    raise FileNotFoundError(f"No file containing 'memmap__' found in {os.getcwd()}")
if len(memmap_candidates) > 1:
    print(f"Several memmap files found; using {memmap_candidates[0]}")
mc_memmapped_fname = memmap_candidates[0]
Yr, dims, T = cm.load_memmap(mc_memmapped_fname)
# (T,) + dims movie; order='F' as for the footprint reshape below
images = np.array(np.reshape(Yr.T, [T] + list(dims), order='F'))
print(f"Successfully loaded data")

# Line up all static inputs
# Footprints: A (pixels x K) -> dims + (K,) -> (K, dims[2], dims[0], dims[1]), the same
# spatial axis order as the movie shown in the viewer (images.transpose([0,3,1,2])).
spcomps_m = np.reshape(cnmf_model_merged.estimates.A.toarray(),cnmf_model_merged.estimates.dims + (-1,),order='F')
spcomps_m = spcomps_m.transpose([3,2,0,1]).astype(np.float32)
CY_m = cnmf_model_merged.estimates.C + cnmf_model_merged.estimates.YrA # temporal loadings
spcomps_mr = np.reshape(cnmf_model_merged_refit.estimates.A.toarray(),cnmf_model_merged_refit.estimates.dims + (-1,),order='F')
spcomps_mr = spcomps_mr.transpose([3,2,0,1]).astype(np.float32)
CY_mr = cnmf_model_merged_refit.estimates.C + cnmf_model_merged_refit.estimates.YrA # temporal loadings
print(f"Setup done")

In [None]:
%%time
# Label volumes for napari: put an all-zero plane in front of the footprints so that
# argmax gives 0 where no footprint is positive, and k for 0-based component k-1.
spcomps_m_e = np.concatenate((np.zeros_like(spcomps_m[0], dtype=np.float32)[np.newaxis], spcomps_m))
spcomps_mr_e = np.concatenate((np.zeros_like(spcomps_mr[0], dtype=np.float32)[np.newaxis], spcomps_mr))
spcomps_m_e2 = spcomps_m_e.argmax(axis=0)
spcomps_mr_e2 = spcomps_mr_e.argmax(axis=0)

In [None]:
# Initialize viewer and start GUI
viewer = napari.Viewer()
# Movie as (frame, dims[2], dims[0], dims[1]) -- the same spatial order as the footprint stacks.
viewer.add_image(images.transpose([0,3,1,2]),name='cells',colormap='gray')

# NOTE: `global` at notebook top level is a no-op; these are module globals anyway
# and are used by show_comp below.
global fig1, ax11, ax12, leg1
fig1, ax11 = plt.subplots() # dF/F
ax12 = ax11.twinx()  # right-hand axis for the raw-refit trace
fig1.suptitle('Fluorescence signal plots')
# Placeholder trace so the initial figure legend has an entry.
ax11.plot(CY_m[0,:].T/100,c='green',label='initialization')
leg1 = fig1.legend(loc="upper left")

@magicgui(
    auto_call=True,
    component={"widget_type": "Slider", "min": 1, "max": np.max([CY_mr.shape[0],CY_m.shape[0]]), "step": 1, "orientation": "vertical"},
    layout="horizontal",
)
def show_comp(component: int = 1):
    """Show component `component` (1-based) of the ds model and its raw refit.

    Adds the two footprints as additive image layers (green = ds, orange = raw)
    and plots their C + YrA traces (scaled by 1/100) on the twin axes of fig1.
    The slider spans the larger of the two component counts, so an index may
    exist in only one model; only that model is drawn instead of raising
    IndexError.

    Args:
        component: 1-based component index from the slider.
    """
    global leg1
    idx = component - 1

    # Remove the layers of the previously shown component, if present.
    for layer_name in ('component ds', 'component raw'):
        if layer_name in viewer.layers:
            viewer.layers.remove(layer_name)

    ax11.cla()
    ax12.cla()
    ax11.set_xlabel('Frame')
    ax11.set_ylabel('Signal')
    ax11.tick_params(axis='y', labelcolor='green')
    ax12.tick_params(axis='y', labelcolor='darkorange')

    if idx < spcomps_m.shape[0]:
        viewer.add_image(spcomps_m[idx,...],name='component ds',colormap='green',opacity=1,blending='additive',visible=True)
    if idx < spcomps_mr.shape[0]:
        viewer.add_image(spcomps_mr[idx,...],name='component raw',colormap='darkorange',opacity=1,blending='additive',visible=True)
    if idx < CY_m.shape[0]:
        ax11.plot(CY_m[idx,:].T/100,c='green',label='component ds')
    if idx < CY_mr.shape[0]:
        ax12.plot(CY_mr[idx,:].T/100,c='darkorange',label='component raw')

    # Draw the left-axis (green) trace above the twin axis; cla() recreates the
    # axes patch, so it has to be hidden again on every call.
    ax11.set_zorder(1)
    ax11.patch.set_visible(False)

    # Rebuild the figure legend so it names the traces currently plotted
    # (otherwise it keeps the 'initialization' entry it was created with).
    leg1.remove()
    leg1 = fig1.legend(loc="upper left")
    fig1.canvas.draw_idle()

# Dock the slider, draw component 1, then overlay label volumes of every component
# of each model (label 0 = background, label k = 0-based component k-1).
viewer.window.add_dock_widget(show_comp)
show_comp()
viewer.add_labels(spcomps_m_e2,name='all components ds')
viewer.add_labels(spcomps_mr_e2,name='all components raw')

In [None]:
# visualizing reference and moving masks
# Each directory must contain ch0_means_movie_nobg_cnmf_merged_refit.hdf5 and a
# 'memmap__' movie file (both are loaded by the next cell).
%matplotlib qt
ref_dir = input("Enter the reference CNMF directory (e.g. in vivo/slice):")
mov_dir = input("Enter the moving CNMF directory (e.g. wells):")

In [None]:
%%time
## LOADING EVERYTHING UP - TAKES ~20 sec

# load packages

import napari
from magicgui import magicgui, widgets
import time

from IPython import get_ipython
from IPython.display import clear_output
import os
import matplotlib.pyplot as plt
import numpy as np
from scipy.sparse import csc_matrix
from scipy import signal as sg
import scipy
import pickle

from tifffile.tifffile import imwrite,imread
from tqdm.auto import tqdm,trange

from copy import deepcopy
import h5py

import caiman as cm
from caiman.source_extraction.cnmf import cnmf,params
from caiman.paths import caiman_datadir
from caiman.utils.visualization import get_contours

# Turn on autoreload when running inside IPython/Jupyter. In a plain interpreter
# the __IPYTHON__ builtin does not exist and the resulting NameError is ignored.
try:
    if __IPYTHON__:
        for _magic, _magic_arg in (('load_ext', 'autoreload'), ('autoreload', '2')):
            get_ipython().run_line_magic(_magic, _magic_arg)
except NameError:
    pass

def load_pickle(file_path):
    """
    Read and return the object stored in a pickle file.

    Args:
    - file_path (str): Path to the pickle file.

    Returns:
    - The unpickled object (in this notebook, the vars_dict dictionary).

    Note: unpickling can execute arbitrary code -- only load trusted files.
    """
    with open(file_path, mode='rb') as handle:
        return pickle.load(handle)

# CNMFE model: merged+refit results of the reference and moving datasets
cnmf_ref_path = os.path.join(ref_dir, 'ch0_means_movie_nobg_cnmf_merged_refit.hdf5')
cnmf_mov_path = os.path.join(mov_dir, 'ch0_means_movie_nobg_cnmf_merged_refit.hdf5')

cnmf_ref = cnmf.load_CNMF(cnmf_ref_path, n_processes=1, dview=None)
cnmf_mov = cnmf.load_CNMF(cnmf_mov_path, n_processes=1, dview=None)

# Line up all static inputs: footprints as (K, dims[2], dims[0], dims[1]), traces as C + YrA
spcomps_r = np.reshape(cnmf_ref.estimates.A.toarray(),cnmf_ref.estimates.dims + (-1,),order='F')
spcomps_r = spcomps_r.transpose([3,2,0,1]).astype(np.float32)
CY_r = cnmf_ref.estimates.C + cnmf_ref.estimates.YrA # temporal loadings
spcomps_m = np.reshape(cnmf_mov.estimates.A.toarray(),cnmf_mov.estimates.dims + (-1,),order='F')
spcomps_m = spcomps_m.transpose([3,2,0,1]).astype(np.float32)
CY_m = cnmf_mov.estimates.C + cnmf_mov.estimates.YrA # temporal loadings
print(f"Successfully loaded CNMF model")


def _find_memmap(directory):
    """Return the name of the (alphabetically first) 'memmap__' file in `directory`.

    Sorting makes the choice reproducible (os.listdir order is filesystem-dependent);
    a missing file raises FileNotFoundError instead of a bare IndexError.
    """
    candidates = sorted(f for f in os.listdir(directory) if 'memmap__' in f)
    if not candidates:
        raise FileNotFoundError(f"No file containing 'memmap__' found in {directory}")
    return candidates[0]


# os.path.join builds the correct path whether or not the directory was entered
# with a trailing separator (string concatenation silently broke without one).
mc_memmapped_fname_r = _find_memmap(ref_dir)
Yr_r, dims_r, T_r = cm.load_memmap(os.path.join(ref_dir, mc_memmapped_fname_r))
images_r = np.array(np.reshape(Yr_r.T, [T_r] + list(dims_r), order='F'))

mc_memmapped_fname_m = _find_memmap(mov_dir)
Yr_m, dims_m, T_m = cm.load_memmap(os.path.join(mov_dir, mc_memmapped_fname_m))
images_m = np.array(np.reshape(Yr_m.T, [T_m] + list(dims_m), order='F'))
print(f"Successfully loaded data")

# Registration outputs live in base_dir/<first two '_'-separated fields of the ref dir name>.
# normpath + basename gives the ref dir's own name with or without a trailing separator.
base_dir = r'/mnt/ssd_cache/pkalugin/stitch_warp'  # NOTE: machine-specific absolute path
ref_name_parts = os.path.basename(os.path.normpath(ref_dir)).split('_')
if len(ref_name_parts) < 2:
    raise ValueError(f"Cannot derive the warp folder name from {ref_dir!r}")
warp_dir = os.path.join(base_dir, '_'.join(ref_name_parts[:2]))
dfield_path = os.path.join(warp_dir, 'bigwarp dfield.tif')
dfield = imread(dfield_path)
# will need to adjust this to several moving images per group!
ref_img = imread(os.path.join(warp_dir, [f for f in os.listdir(warp_dir) if 'fr0_coor' in f][0]))
mov_img = imread(os.path.join(warp_dir, [f for f in os.listdir(warp_dir) if 'fr1_coor_rot' in f][0]))

import SimpleITK as sitk

# Vector image for the displacement-field transform: after the transpose the last axis
# holds the displacement components, hence isVector=True.
displacement_image = sitk.GetImageFromArray(np.double(dfield.transpose([0,2,3,1])), isVector=True)

tx = sitk.DisplacementFieldTransform(displacement_image)
tx.SetInterpolator(sitk.sitkLinear)

mov_img_d = sitk.GetArrayFromImage(sitk.Resample(sitk.GetImageFromArray(np.double(mov_img)),tx))

# Crop both images to their common extent (works for any number of dimensions).
print(mov_img_d.shape)
print(ref_img.shape)
shape_keep = tuple(min(r, m) for r, m in zip(ref_img.shape, mov_img_d.shape))
print(shape_keep)
crop = tuple(slice(0, s) for s in shape_keep)
mov_img_d = mov_img_d[crop]
ref_img = ref_img[crop]
print(mov_img_d.shape)
print(ref_img.shape)

In [None]:
# Sanity check: moving-dataset footprint stack shape, then movie shape (one per line).
print(spcomps_m.shape, images_m.shape, sep='\n')

In [None]:
# this part just performs the bigstitcher rigid body move on the spcomps_m
# will need to adjust for multiple moving images
from scipy.ndimage import affine_transform
import xml.etree.ElementTree as ET
# NOTE(review): not used in this cell (warp_dir from the loading cell is what is read).
base_dir = r'/mnt/ssd_cache/pkalugin/stitch_warp'

# float64 outputs (np.zeros default) with the same shapes as the moving footprints / movie
spcomps_m_rot = np.zeros(spcomps_m.shape)
images_m_rot = np.zeros(images_m.shape)

# n selects the BigStitcher XML: dataset.xml for n == 0, dataset<n>.xml otherwise.
n = 0
if n==0:
    tree = ET.parse(warp_dir+f'/dataset.xml')
else:
    tree = ET.parse(warp_dir+f'/dataset{n:01}.xml')
root = tree.getroot()
# 3x4 affine (3x3 linear part | translation column) parsed from a fixed position in the XML.
# NOTE(review): root[2][1][0][1] is positional -- presumably the registration of the moving
# view; verify against dataset.xml if its layout changes.
w_mat = np.array(root[2][1][0][1].text.split()).astype(np.float64).reshape([3,4])
# Invert the 4x4 homogeneous matrix and take the first two translation terms from the
# inverse, swapped between rows 0 and 1; the linear part is left as parsed.
# NOTE(review): presumably compensates for axis-order/direction conventions between
# BigStitcher and scipy's affine_transform (which maps output coords to input coords) -- confirm.
w_mat_i = np.linalg.inv(np.append(w_mat,np.array([0,0,0,1])[np.newaxis,:],axis=0))
w_mat[0,3] = w_mat_i[1,3]
w_mat[1,3] = w_mat_i[0,3]

# Footprints are (dims[2], dims[0], dims[1]) per component: move the first axis last so each
# volume is in the movie's (dims[0], dims[1], dims[2]) order, transform, then move it back.
for ncomp in trange(spcomps_m.shape[0]):
    spcomps_m_rot[ncomp,...] = affine_transform(spcomps_m[ncomp,...].transpose([1,2,0]),w_mat).transpose([2,0,1])

# Movie frames are already in (dims[0], dims[1], dims[2]) order.
for nfr in trange(images_m.shape[0]):
    images_m_rot[nfr,...] = affine_transform(images_m[nfr,...],w_mat)

print(f"Setup done")

In [None]:
# Zero out weights below 0.001 in the rotated footprints (in place), presumably to
# drop interpolation residue left by affine_transform.
np.putmask(spcomps_m_rot, spcomps_m_rot < 0.001, 0)

In [None]:
def _warp_volume(volume):
    """Resample one 3-D array through the BigWarp displacement transform `tx`."""
    warped_image = sitk.Resample(sitk.GetImageFromArray(np.double(volume)), tx)
    return sitk.GetArrayFromImage(warped_image)


spcomps_m_warp = np.zeros(spcomps_m_rot.shape)
images_m_warp = np.zeros(images_m_rot.shape)

for ncomp in trange(spcomps_m.shape[0]):
    spcomps_m_warp[ncomp,...] = _warp_volume(spcomps_m_rot[ncomp,...])

for nfr in trange(images_m.shape[0]):
    # Put the frame into the footprints' (dims[2], dims[0], dims[1]) order for warping,
    # then restore the movie's (dims[0], dims[1], dims[2]) order.
    frame_for_warp = images_m_rot[nfr,...].transpose([2,0,1])
    images_m_warp[nfr,...] = _warp_volume(frame_for_warp).transpose([1,2,0])

In [None]:
%%time
# Label volumes for napari: put an all-zero plane in front of each footprint stack so that
# argmax gives 0 where no footprint is positive, and k for 0-based component k-1.
spcomps_r_e = np.concatenate((np.zeros_like(spcomps_r[0], dtype=np.float32)[np.newaxis], spcomps_r))
spcomps_m_rot_e = np.concatenate((np.zeros_like(spcomps_m_rot[0], dtype=np.float32)[np.newaxis], spcomps_m_rot.astype(np.float32)))
spcomps_m_warp_e = np.concatenate((np.zeros_like(spcomps_m_warp[0], dtype=np.float32)[np.newaxis], spcomps_m_warp.astype(np.float32)))
spcomps_r_e2 = spcomps_r_e.argmax(axis=0)
spcomps_m_rot_e2 = spcomps_m_rot_e.argmax(axis=0)
spcomps_m_warp_e2 = spcomps_m_warp_e.argmax(axis=0)

In [None]:
# Initialize viewer and start GUI
viewer = napari.Viewer()
# Reference movie, moving movie after the rigid move, and after the BigWarp warp, all as
# (frame, dims[2], dims[0], dims[1]).
# NOTE(review): all three layers are named 'cells'; napari should make duplicate names unique
# (e.g. 'cells [1]') -- check the layer list to tell them apart.
viewer.add_image(images_r.transpose([0,3,1,2]),name='cells',colormap='gray')
viewer.add_image(images_m_rot.transpose([0,3,1,2]),name='cells',colormap='gray')
viewer.add_image(images_m_warp.transpose([0,3,1,2]),name='cells',colormap='gray')

# NOTE: `global` at notebook top level is a no-op; these are module globals anyway
# and are used by show_comp below.
global fig1, ax11, ax12, leg1
fig1, ax11 = plt.subplots() # dF/F
ax12 = ax11.twinx()  # right-hand axis for the reference trace
fig1.suptitle('Fluorescence signal plots')
# Placeholder trace so the initial figure legend has an entry.
ax11.plot(CY_m[0,:].T/100,c='green',label='initialization')
leg1 = fig1.legend(loc="upper left")

@magicgui(
    auto_call=True,
    component={"widget_type": "Slider", "min": 1, "max": np.max([CY_r.shape[0],CY_m.shape[0]]), "step": 1, "orientation": "vertical"},
    layout="horizontal",
)
def show_comp(component: int = 1):
    """Show component `component` (1-based) of the moving and reference models.

    The moving footprint is the rotated + warped one (spcomps_m_warp, green);
    the reference footprint is orange. Their C + YrA traces (scaled by 1/100)
    go on the twin axes of fig1. The same index is shown for both datasets,
    which does not imply the two components correspond. The datasets usually
    have different component counts while the slider spans the larger one, so
    for an index present in only one model only that model is drawn instead of
    raising IndexError.

    Args:
        component: 1-based component index from the slider.
    """
    global leg1
    idx = component - 1

    # Remove the layers of the previously shown component, if present.
    for layer_name in ('component mov', 'component ref'):
        if layer_name in viewer.layers:
            viewer.layers.remove(layer_name)

    ax11.cla()
    ax12.cla()
    ax11.set_xlabel('Frame')
    ax11.set_ylabel('Signal')
    ax11.tick_params(axis='y', labelcolor='green')
    ax12.tick_params(axis='y', labelcolor='darkorange')

    if idx < spcomps_m_warp.shape[0]:
        viewer.add_image(spcomps_m_warp[idx,...],name='component mov',colormap='green',opacity=1,blending='additive',visible=True)
    if idx < spcomps_r.shape[0]:
        viewer.add_image(spcomps_r[idx,...],name='component ref',colormap='darkorange',opacity=1,blending='additive',visible=True)
    if idx < CY_m.shape[0]:
        ax11.plot(CY_m[idx,:].T/100,c='green',label='component mov')
    if idx < CY_r.shape[0]:
        ax12.plot(CY_r[idx,:].T/100,c='darkorange',label='component ref')

    # Draw the left-axis (green) trace above the twin axis; cla() recreates the
    # axes patch, so it has to be hidden again on every call.
    ax11.set_zorder(1)
    ax11.patch.set_visible(False)

    # Rebuild the figure legend so it names the traces currently plotted
    # (otherwise it keeps the 'initialization' entry it was created with).
    leg1.remove()
    leg1 = fig1.legend(loc="upper left")
    fig1.canvas.draw_idle()

# Dock the slider, draw component 1, then overlay label volumes (label 0 = background,
# label k = 0-based component k-1) for the reference, rotated and warped moving footprints.
viewer.window.add_dock_widget(show_comp)
show_comp()
viewer.add_labels(spcomps_r_e2,name='all components ref')
viewer.add_labels(spcomps_m_rot_e2,name='all components rot')
viewer.add_labels(spcomps_m_warp_e2,name='all components mov')

In [None]:
# Scratch: smallest positive reference footprint weight.
np.amin(spcomps_r[spcomps_r>0])

In [None]:
spcomps_m_rot.shape

In [None]:
%matplotlib qt
# Index 10 along the first spatial axis of reference footprint 77.
plt.imshow(spcomps_r.astype(np.float32)[77,10,...])

In [None]:
# True if rotated footprint 0 is exactly zero at this (dataset-specific) voxel.
spcomps_m_rot.astype(np.float32)[0,10,200,200] == 0

In [None]:
%matplotlib inline
# Compare all component weights at one voxel before and after zeroing weights < 0.001.
# Work on a copy: a plain assignment would alias spcomps_m_rot, so the in-place
# thresholding would also modify it and the two curves would always coincide.
spcomps_m_rot_z = spcomps_m_rot.copy()
spcomps_m_rot_z[spcomps_m_rot_z<0.001] = 0
plt.plot(spcomps_m_rot[:,10,239,248])
plt.plot(spcomps_m_rot_z[:,10,239,248])

In [None]:
%matplotlib qt

In [None]:
# Scratch viewer: reference movie with the first spatial axis displayed 4x stretched (scale=[1,4,1,1]).
viewer = napari.Viewer()
viewer.add_image(images_r.transpose([0,3,1,2]),scale=[1,4,1,1],name='cells',colormap='gray')

In [None]:
# Same display for `images`, the movie loaded by an earlier section.
viewer = napari.Viewer()
viewer.add_image(images.transpose([0,3,1,2]),scale=[1,4,1,1],name='cells',colormap='gray')

In [None]:
images_r.shape

In [None]:
# Time course of a single (dataset-specific) voxel of the reference movie.
plt.plot(images_r[:,215,275,4])

In [None]:
spcomps_r.shape

In [None]:
cnmf_ref.estimates.A.shape

In [None]:
cnmf_ref_path

In [None]:
# Load the reference dataset's CNMF model without the _merged suffix
# (presumably the first-pass result -- verify).
cnmf_path = os.path.join(ref_dir, 'ch0_means_movie_nobg_cnmf.hdf5')

# CNMFE model
cnmf_model = cnmf.load_CNMF(cnmf_path, 
                            n_processes=1,
                            dview=None)

In [None]:
# Load the component-filtering dict saved next to the reference CNMF results, if present.
# The existence check looks in ref_dir, so the load must read from ref_dir too; a bare
# relative name would be resolved against the current working directory instead.
save_path = 'ch0_means_movie_nobg_compfilt.pickle'
if save_path in os.listdir(ref_dir):
    vars_dict1 = load_pickle(os.path.join(ref_dir, save_path))

In [None]:
# Scratch: inspect the loaded vars_dict1.
print(vars_dict1)

In [None]:
#keepargs = cnmf_model.estimates.keepargs
# NOTE(review): keepargs is not assigned in this chunk (the line above is commented out);
# this cell and the next rely on a value left in the kernel.
keepargs

In [None]:
keepargs.shape

In [None]:
# Footprints of cnmf_model in the same (K, dims[2], dims[0], dims[1]) layout as elsewhere
# (not cast to float32 here).
spcomps = np.reshape(cnmf_model.estimates.A.toarray(),cnmf_model.estimates.dims + (-1,),order='F')
spcomps = spcomps.transpose([3,2,0,1])
spcomps.shape

In [None]:
# Components with non-zero weight at a given (dataset-specific) voxel.
np.where(spcomps[:,12,192,328]!=0)

In [None]:
viewer.layers['Points'].data

In [None]:
np.where(spcomps[:,8,255,345]!=0)

In [None]:
plt.imshow(spcomps[484,13,...])

In [None]:
#%matplotlib inline
# C + YrA traces of hand-picked components, for visual comparison.
plt.plot(cnmf_model.estimates.C[[455, 484],...].T+cnmf_model.estimates.YrA[[455, 484],...].T)

In [None]:
#%matplotlib inline
plt.plot(cnmf_model.estimates.C[[504,672],...].T+cnmf_model.estimates.YrA[[504,672],...].T)

In [None]:
#%matplotlib inline
plt.plot(cnmf_model.estimates.C[[504,672],...].T)

In [None]:
%matplotlib inline
plt.plot(cnmf_model.estimates.C[[445, 477, 624, 654],...].T+cnmf_model.estimates.YrA[[445, 477, 624, 654],...].T)

In [None]:
%matplotlib inline
# NOTE(review): CY and CYsav are not defined in this chunk -- presumably C + YrA traces and a
# smoothed copy (e.g. Savitzky-Golay via the `sg` import) from an earlier cell; verify.
plt.plot(CY[[274, 290, 446, 456],...].T)
plt.plot(CYsav[[274, 290, 446, 456],...].T)

In [None]:
# Residual CY - CYsav: the quantity whose std is used as the noise estimate below.
plt.plot((CY-CYsav)[[274, 290, 446, 456],...].T)

In [None]:
# Per-component signal estimate: sort each CYsav trace along time and average its
# top and bottom 10% of samples (ceil(T/10) samples each).
CYsavsort = np.sort(CYsav,axis=1)
CYsavt10 = np.mean(CYsavsort[:,int(CYsavsort.shape[1]-np.ceil(CYsavsort.shape[1]/10)):],axis=1) # top 10% mean
CYsavb10 = np.mean(CYsavsort[:,:int(np.ceil(CYsavsort.shape[1]/10))],axis=1) # bottom 10% mean
# Signal = top-10% mean minus a baseline: the bottom-10% mean, or, when n_range (frame
# indices, defined outside this chunk) is given, the mean over those frames.
if n_range is None:
    sig = CYsavt10 - CYsavb10
else:
    sig = CYsavt10 - np.mean(CYsav[:,n_range],axis=1) # baseline mean

noisestd = np.std(CY - CYsav,axis=1) # noise calculated across whole movie
#SNR = sig/noisestd

In [None]:
sig[[274, 290, 446, 456]]

In [None]:
noisestd[[274, 290, 446, 456]]

In [None]:
cnmf_model.estimates.C.shape

In [None]:
%matplotlib qt

In [None]:
SNRs[[271, 297, 301, 323, 455, 484]]

In [None]:
# NOTE(review): this cell reads SNR_min while the histogram cell below uses SNRmin --
# only one spelling may be defined in the kernel; check which.
SNR_min

In [None]:
SNR = sig/noisestd
SNRsort = np.sort(SNR)
# SNR value at the 25th percentile of components
SNRsort[int(np.floor(SNRsort.shape[0]/4))]

In [None]:
# SNR value at the 20th percentile of components
SNRsort[int(np.floor(0.2*SNRsort.shape[0]))]

In [None]:
cnmf_model.estimates.Sigmin

In [None]:
CY.shape

In [None]:
# SNR distribution with the SNRmin threshold (red), mean (green) and median (blue).
plt.hist(SNR,bins=50)
plt.vlines(SNRmin,0,plt.gca().get_ylim()[1],color='r')
plt.vlines(np.mean(SNR),0,plt.gca().get_ylim()[1],color='g')
plt.vlines(np.median(SNR),0,plt.gca().get_ylim()[1],color='b')

In [None]:
np.median(SNR)

In [None]:
# NOTE(review): `component` is not defined in this cell; it relies on a leftover kernel variable.
np.mean(R[vars_dict1["SMG"][component],:],axis=0).T

In [None]:
# One trace per merge group (vars_dict1["SMG"]): the mean raw trace R of its member components,
# and the mean of the members' (R - Rf) / Rf traces (Rf presumably a per-component baseline -- verify).
merge_traces_raw = np.asarray([np.mean(R[vars_dict1["SMG"][component],:],axis=0) for component in range(len(vars_dict1["SMG"]))])
merge_traces_raw_dff = np.asarray([np.mean((R[vars_dict1["SMG"][component],:].T-Rf[np.array(vars_dict1["SMG"][component])])/Rf[np.array(vars_dict1["SMG"][component])],axis=1) for component in range(len(vars_dict1["SMG"]))])

In [None]:
merge_traces_raw_dff.shape

In [None]:
# Ward hierarchical clustering of the per-group dF/F traces, shown as a dendrogram.
from scipy.cluster.hierarchy import dendrogram, linkage

data = merge_traces_raw_dff
linkage_data = linkage(data, method='ward', metric='euclidean')
dendrogram(linkage_data)

plt.show() 

In [None]:
linkage_data.shape

In [None]:
# Same Ward clustering cut into 4 flat clusters.
from sklearn.cluster import AgglomerativeClustering

hierarchical_cluster = AgglomerativeClustering(n_clusters=4, metric='euclidean', linkage='ward')
labels = hierarchical_cluster.fit_predict(data)

In [None]:
labels[69]

In [None]:
# Group traces sorted by cluster label ...
plt.imshow(merge_traces_raw_dff[np.argsort(labels),...],cmap='turbo')

In [None]:
# ... and in their original order.
plt.imshow(merge_traces_raw_dff,cmap='turbo')

In [None]:
# Scratch: debugging how the merged label volume is built.
spcomps_merge_e.shape

In [None]:
np.amax(spcomps_merge_e2)

In [None]:
len(vars_dict1["SMG"])

In [None]:
type(spcomps)

In [None]:
spcomps.dtype

In [None]:
spcomps_add = np.zeros(spcomps[0,...][np.newaxis,...].shape,dtype=np.float32)

In [None]:
spcomps_add.dtype

In [None]:
# NOTE(review): spcompse is only assigned further down -- these cells were run out of order.
spcompse.dtype

In [None]:
spcompse

In [None]:
spcompse = np.concatenate((spcomps_add,spcomps))

In [None]:
spcomps[vars_dict1["SMG"][0],...].shape

In [None]:
images.shape

In [None]:
spcomps[0,...][np.newaxis,...].shape

In [None]:
spcompse = np.append(np.zeros(spcomps[0,...][np.newaxis,...].shape,dtype=np.float32),spcomps,axis=0)

In [None]:
spcompse.shape

In [None]:
spcompse2 = np.argmax(spcompse,axis=0)

In [None]:
# NOTE(review): spcomps2 is not defined in this chunk -- possibly meant spcompse2 from the cell above.
np.amax(spcomps2)

In [None]:
# Sum the member footprints of each merge group into one footprint per group.
# NOTE(review): range(len(...)-1) skips the last SMG group -- confirm this is intended.
spcomps_merge = np.asarray([np.sum(spcomps[vars_dict1["SMG"][component],...],axis=0) for component in range(len(vars_dict1["SMG"])-1)])

In [None]:
spcomps_merge.shape

In [None]:
# NOTE(review): after the NaN-filling cell below has run, voxels that are NaN in every group
# make nanargmax raise ValueError ("All-NaN slice encountered").
spcomps_merge2 = np.nanargmax(spcomps_merge,0)

In [None]:
spcomps_merge2

In [None]:
# NOTE(review): fancy indexing with the whole list of SMG groups only works when every group
# has the same length.
spcomps[vars_dict1["SMG"],...].shape

In [None]:
# Assigning None into a float array stores NaN, so zero weights become NaN.
spcomps_merge[spcomps_merge == 0] = None

In [None]:
spcomps_merge