In [1]:
%matplotlib qt
data_dir = input("Enter the data directory:")

# Parameters (unlikely to change)
n_range_lim = 10 # size of n_range below which SNR considered unreliable
Athresh = 0.05 # overlap threshold - automatically split anything below it
cr_thresh = 0.9 # component-raw correlation threshold below which component deemed suspicious quality
pb_thresh = 0.95 # component-best parent correlation threshold above which component deemed likely merge

Enter the data directory: H:\CNMFoutputs\jordan\tierM\20230822_SL1PL32_slice_Done_qc_doneqc


In [2]:
%%time
## LOADING EVERYTHING UP - TAKES ~20 sec

# load packages

import napari
from magicgui import magicgui, widgets
import time

from IPython import get_ipython
from IPython.display import clear_output
import os
import matplotlib.pyplot as plt
import numpy as np
from scipy.sparse import csc_matrix
from scipy import signal as sg
from scipy.ndimage import center_of_mass
import scipy
import pickle

from tifffile.tifffile import imwrite,imread
from tqdm.auto import tqdm,trange

from copy import deepcopy
import h5py

import caiman as cm
from caiman.source_extraction.cnmf import cnmf,params
from caiman.paths import caiman_datadir
from caiman.utils.visualization import get_contours

try:
    if __IPYTHON__:
        get_ipython().run_line_magic('load_ext', 'autoreload')
        get_ipython().run_line_magic('autoreload', '2')
except NameError:
    pass

def load_pickle(file_path):
    """
    Load a dictionary from a pickle file.

    Args:
    - file_path (str): Path to the pickle file.

    Returns:
    - dict: Loaded dictionary.
    """
    with open(file_path, 'rb') as f:
        data = pickle.load(f)
    return data

## Loading all the inputs
os.chdir(data_dir)
#cnmf_path = caiman_datadir()+'/example_movies/demoMovie3DYxxbnobg_20240318170305_cnmf.hdf5'
cnmf_path = os.path.join(data_dir, 'ch0_means_movie_nobg_cnmf_merged_refit.hdf5')

# CNMFE model
cnmf_model = cnmf.load_CNMF(cnmf_path, 
                            n_processes=1,
                            dview=None)
print(f"Successfully loaded merged/refit CNMF model")

mc_memmapped_fname = [i for i in os.listdir() if 'memmap__' in i][0]
Yr, dims, T = cm.load_memmap(mc_memmapped_fname)
images = np.array(np.reshape(Yr.T, [T] + list(dims), order='F')) 
print(f"Successfully loaded data")

# Line up static inputs
spcomps = np.reshape(cnmf_model.estimates.A.toarray(),cnmf_model.estimates.dims + (-1,),order='F')
spcomps = spcomps.transpose([3,2,0,1]).astype(np.float32)
CY = cnmf_model.estimates.C + cnmf_model.estimates.YrA # temporal loadings

spcomps_e = np.concatenate((np.zeros(spcomps[0,...][np.newaxis,...].shape,dtype=np.float32),spcomps))
spcomps_e2 = np.argmax(spcomps_e,axis=0)

coms = np.zeros((spcomps.shape[0],3))
for ncomp in trange(spcomps.shape[0]):
    coms[ncomp,...] = center_of_mass(spcomps[ncomp,...])

spcomps_l = np.zeros(spcomps.shape).astype('int')
for ncomp in trange(spcomps.shape[0]):
    spcomps_l[ncomp,...] = spcomps[ncomp,...]>0
print(f"Successfully initialized")

Successfully loaded merged/refit CNMF model
Successfully loaded data


  0%|          | 0/185 [00:00<?, ?it/s]

  0%|          | 0/185 [00:00<?, ?it/s]

Successfully initialized
CPU times: total: 51.3 s
Wall time: 56.8 s


In [3]:
%%time
spcomps_f = np.reshape(spcomps,(spcomps.shape[0],) + (-1,))
A_corr = np.matmul(spcomps_f,spcomps_f.T)
SOL_m = [np.where(A_corr[i,:]>0)[0][np.argsort(-A_corr[i,np.where(A_corr[i,:]>0)]).tolist()[0]].tolist() for i in range(A_corr.shape[0])]
SOL_m = [[x for x in SOL_m[i] if x != i] for i in range(len(SOL_m))]
SOL = [i for i in range(len(SOL_m)) if len(SOL_m[i]) !=0]

CPU times: total: 8.58 s
Wall time: 8.58 s


In [4]:
%%time
merge_path = f'ch0_means_movie_nobg_compmerge.npy'
if merge_path in os.listdir(data_dir):
    merge_dict1 = np.load(os.path.join(data_dir, merge_path),allow_pickle=True)[()]
    print(merge_dict1)
else:
    merge_dict1 = {
        "SOL": SOL, # list of components to be searched (exclude those with zero matches - update each iteration)
        "SOL_m": SOL_m, # search list for given component
        "MpointsC": [[] for i in range(len(SOL_m))], # merge points for each component in SOL
        "MpointsA": list(), # all match points - all points in MpointsC plus COMs of points that had matches selected!
    }
    print(merge_dict1)

{'SOL': [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 13, 14, 15, 17, 18, 19, 20, 21, 23, 24, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 143, 144, 146, 147, 148, 149, 150, 151, 154, 155, 156, 157, 158, 161, 162, 163, 164, 165, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 181, 182, 183, 184], 'SOL_m': [[65, 66, 52, 4], [107, 79], [54, 53], [80], [66, 0, 116, 69, 67], [], [60, 54, 56, 53], [184, 135], [127, 125, 126, 112], [95, 138, 38, 98], [70, 32, 177], [74, 73, 139], [], [113], [94, 135], [93, 91], [], [18, 63, 136, 62], [17, 62

In [5]:
# Initialize viewer and start GUI
viewer = napari.Viewer()
viewer.add_image(images.transpose([0,3,1,2]),name='cells',colormap='gray')

merge_dict2 = deepcopy(merge_dict1)
merge_dict3 = deepcopy(merge_dict2)

#global fig1, ax11, ax12, leg1
#fig1, ax11 = plt.subplots() # dF/F
#ax12 = ax11.twinx()
#fig1.suptitle('Fluorescence signal plots')
#ax11.plot(CY[0,:].T/100,c='green',label='initialization')
#leg1 = fig1.legend(loc="upper left")
###
# Key bindings to speed up selections
@viewer.bind_key('n')
def pressN(viewer):
    clickN()

@viewer.bind_key('u')
def pressU(viewer):
    clickU()

#@viewer.bind_key('c')
#def pressC(viewer):
#    viewer.layers['all components'].visible = not viewer.layers['all components'].visible

#@viewer.bind_key('v')
#def pressV(viewer):
#    viewer.layers['component contours'].visible = not viewer.layers['component contours'].visible

@viewer.bind_key('Up')
def jump_up(viewer):
    viewer.dims.set_current_step(1, viewer.dims.current_step[1] - 1)

@viewer.bind_key('Down')
def jump_down(viewer):
    viewer.dims.set_current_step(1, viewer.dims.current_step[1] + 1)

@viewer.bind_key('Left')
def jump_left(viewer):
    viewer.dims.set_current_step(0, viewer.dims.current_step[0] - 1)

@viewer.bind_key('Right')
def jump_right(viewer):
    viewer.dims.set_current_step(0, viewer.dims.current_step[0] + 1)
    
@magicgui(
    auto_call=True,btn={"widget_type": "PushButton", "text": "Next component"}
)
def click_next(btn):
    clickN()

@magicgui(
    auto_call=True,btn={"widget_type": "PushButton", "text": "Undo"}
)
def click_undo(btn):
    clickU()

# Save and close button inside napari
@magicgui(
    auto_call=True,btn={"widget_type": "PushButton", "text": "Save and close GUI"}
)
def save_btn(btn):
    #qc_data["qc_spots"] = viewer.layers['qc spots'].data.astype('int')
    #qc_data["trash_spots"] = viewer.layers['trash spots'].data.astype('int')
    np.save(merge_path,merge_dict1)
    plt.close('all')
    viewer.close()

# Message displayer inside napari to print any messages
lab = widgets.Label()
def print_lab(message=None):
    if message is None:
        lab.value = ""
    else:
        lab.value = message

# Progress displayer inside napari
prlab = widgets.Label()
def prlab_update():
    prlab.value = str(len(merge_dict1['SOL'])) + " masks to go"

# Aligning widgets
layout = widgets.Container(
    widgets=[click_next,click_undo,lab,save_btn,prlab], layout="vertical", labels=False
)

def show_comp(comp):
    try:
        viewer.layers.remove('component contour')
    except:
        pass
    try:
        viewer.layers.remove('neighbor contours')
    except:
        pass
    try:
        viewer.layers.remove('current merge points')
    except:
        pass
    try:
        viewer.layers.remove('all merge points')
    except:
        pass
     ##   
    #plt.sca(ax11)
    #plt.cla()
    #plt.sca(ax12)
    #plt.cla()
    #ax11.set_xlabel('Frame')
    #ax11.set_ylabel('Signal')
    #ax11.tick_params(axis='y', labelcolor='green')
    #ax12.tick_params(axis='y', labelcolor='darkorange')

    viewer.add_labels(spcomps_l[comp,...],name='component contour',color={0:'transparent',1:'red'})
    viewer.layers['component contour'].contour = 1

    nconts = []
    for j in merge_dict1['SOL_m'][comp]:
        nconts.append((j+1)*spcomps_l[j,...])
    viewer.add_labels(np.sum(np.array(nconts),axis=0),name='neighbor contours')
    viewer.layers['neighbor contours'].contour = 1

    #for j in add_r.tolist():
    #    add_r_l[spcomps_l == j+1] = j+1
    #viewer.add_labels(add_r_l,name='neighbor contours')
    #viewer.layers['neighbor contours'].contour = 1
    
    viewer.camera.center = coms[comp]
    viewer.dims.set_point(1,coms[comp][0])
    viewer.camera.zoom = 3

    viewer.add_points(data=merge_dict1["MpointsA"],name='all merge points',ndim=3,face_color='red',size=6,out_of_slice_display=True,visible=False)
    viewer.add_points(data=merge_dict1["MpointsC"][merge_dict1["SOL"][0]],name='current merge points',ndim=3,face_color='lime',size=6,out_of_slice_display=True)

    #ax11.plot(CY[comp,:].T/100,c='green',label='component')
    #ax12.plot(CY[merge_dict1["SOL_m"][comp],:].T/100,c='darkorange',ls='--',label='neighbors')
    #global leg1
    #leg1.remove()
    #leg1 = fig1.legend(loc="upper left")
    #fig1.canvas.draw_idle()

def clickN():
    global merge_dict1, merge_dict2
    #, merge_dict3
    #merge_dict3 = deepcopy(merge_dict2)
    merge_dict2 = deepcopy(merge_dict1)
    newspots = viewer.layers['current merge points'].data.astype('int')
    merge_dict1["MpointsC"][merge_dict1["SOL"][0]] = newspots
    if len(merge_dict1["MpointsA"]) == 0:
        merge_dict1["MpointsA"] = np.vstack((newspots,coms[merge_dict1["SOL"][0]]))
    else:
        merge_dict1["MpointsA"] = np.vstack((merge_dict1["MpointsA"],newspots,coms[merge_dict1["SOL"][0]]))
    
    if len(merge_dict1["SOL"]) > 1:
        merge_dict1["SOL"] = merge_dict1["SOL"][1:]
        prlab_update()
        show_comp(merge_dict1["SOL"][0])
    else:
        print_lab("Congratulations - all done! Remember to press Save!")

def clickU():
    merge_dict1 = deepcopy(merge_dict2)
    prlab_update()
    print_lab("")
    show_comp(merge_dict1["SOL"][0])

viewer.window.add_dock_widget(layout)
show_comp(merge_dict1["SOL"][0])
prlab_update()

  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKING,
  TYPE_CHECKIN

In [81]:
merge_dict1

{'SOL': [7,
  8,
  9,
  10,
  11,
  12,
  13,
  14,
  15,
  16,
  17,
  18,
  19,
  20,
  21,
  22,
  23,
  24,
  25,
  26,
  27,
  28,
  30,
  31,
  32,
  33,
  34,
  35,
  36,
  37,
  38,
  39,
  40,
  41,
  42,
  43,
  44,
  45,
  46,
  47,
  48,
  49,
  50,
  51,
  52,
  53,
  55,
  56,
  57,
  58,
  59,
  60,
  61,
  62,
  63,
  64,
  65,
  66,
  67,
  68,
  69,
  70,
  71,
  73,
  74,
  75,
  76,
  78,
  80,
  81,
  82,
  83,
  84,
  85,
  86,
  88,
  89,
  90,
  91,
  92,
  94,
  99,
  100,
  101,
  102,
  103,
  104,
  105,
  106,
  107,
  109,
  112,
  113],
 'SOL_m': [[],
  [2, 86, 3],
  [1],
  [5, 86, 1],
  [],
  [3, 86],
  [7, 47, 88],
  [6, 88],
  [104],
  [10, 75],
  [9, 76],
  [12],
  [11, 13],
  [12],
  [21, 22, 113],
  [17, 19, 16, 20],
  [17, 15, 20],
  [15, 16, 19, 20, 18],
  [19, 17],
  [15, 17, 18],
  [16, 17, 15],
  [14, 113],
  [113, 14],
  [25, 27, 26, 24],
  [28, 23, 25, 26],
  [23, 27, 28, 24, 26],
  [23, 27, 25, 24],
  [23, 25, 26],
  [24, 25],
  [],
  [34, 3

In [77]:
len(merge_dict1["MpointsA"])

0

In [72]:
np.vstack((viewer.layers['current merge points'].data.astype('int'),coms[merge_dict1["SOL"][0]]))

array([[ 11.        , 193.        , 363.        ],
       [ 11.        , 236.        , 384.        ],
       [ 10.86844453, 214.51415577, 369.21463607]])

In [68]:
coms[merge_dict1["SOL"][0]]

array([ 10.86844453, 214.51415577, 369.21463607])

In [76]:
clickN()

ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 0 and the array at index 1 has size 3

In [50]:
def show_comp(comp):
    try:
        viewer.layers.remove('component contour')
    except:
        pass
    try:
        viewer.layers.remove('neighbor contours')
    except:
        pass
    try:
        viewer.layers.remove('current merge points')
    except:
        pass
    try:
        viewer.layers.remove('all merge points')
    except:
        pass
     ##   
    plt.sca(ax11)
    plt.cla()
    plt.sca(ax12)
    plt.cla()
    ax11.set_xlabel('Frame')
    ax11.set_ylabel('Signal')
    ax11.tick_params(axis='y', labelcolor='green')
    ax12.tick_params(axis='y', labelcolor='darkorange')

    viewer.add_labels(spcomps_l[comp,...],name='component contour',color={0:'transparent',1:'red'})
    viewer.layers['component contour'].contour = 1

    nconts = []
    for j in merge_dict1['SOL_m'][comp]:
        nconts.append((j+1)*spcomps_l[j,...])
    viewer.add_labels(np.sum(np.array(nconts),axis=0),name='neighbor contours')
    viewer.layers['neighbor contours'].contour = 1

    #for j in add_r.tolist():
    #    add_r_l[spcomps_l == j+1] = j+1
    #viewer.add_labels(add_r_l,name='neighbor contours')
    #viewer.layers['neighbor contours'].contour = 1
    
    viewer.camera.center = coms[comp]
    viewer.dims.set_point(1,coms[comp][0])
    viewer.camera.zoom = 3

    viewer.add_points(data=merge_dict1["MpointsC"][merge_dict1["SOL"][0]],name='current merge points',ndim=3,face_color='lime',size=6,out_of_slice_display=True)
    viewer.add_points(data=merge_dict1["MpointsA"],name='all merge points',ndim=3,face_color='red',size=6,out_of_slice_display=True,visible=False)

    ax11.plot(CY[comp,:].T/100,c='green',label='component')
    ax12.plot(CY[merge_dict1["SOL_m"][comp],:].T/100,c='darkorange',ls='--',label='neighbors')
    global leg1
    leg1.remove()
    leg1 = fig1.legend(loc="upper left")
    fig1.canvas.draw_idle()

In [29]:
len(merge_dict1["SOL_m"])

115

In [51]:
show_comp(merge_dict1["SOL"][0])

  TYPE_CHECKING,


In [34]:
coms[merge_dict1["SOL"][0]]

array([ 10.86844453, 214.51415577, 369.21463607])

In [38]:
spcomps_l[0,...].shape

(25, 433, 495)

In [47]:
len(nconts)

NameError: name 'nconts' is not defined

In [65]:
plt.imshow(0,facecolor=viewer.layers['neighbor contours'].colormap.colors[2])

AttributeError: AxesImage.set() got an unexpected keyword argument 'facecolor'

In [63]:
merge_dict1['SOL_m'][merge_dict1["SOL"][0]]

[2, 86, 3]