# Motion correction and cell detection
---

Presynaptic axonal terminals expressing GCaMP5f

In [None]:
import cv2
# Best effort: cap OpenCV's internal thread pool. A failure here must not stop the
# notebook, so any exception is reported and ignored. (Note that `except ():` would
# catch an empty tuple of exception types, i.e. nothing at all.)
try:
    cv2.setNumThreads(8)
except Exception as err:
    print(f'cv2.setNumThreads(8) failed, continuing with OpenCV defaults: {err}')

import os
import sys
import glob
import yaml

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import caiman as cm
from caiman.motion_correction import MotionCorrect
from caiman.source_extraction.cnmf import cnmf as cnmf
from caiman.source_extraction.cnmf import params as params
from caiman.source_extraction.cnmf import utilities as util

from skimage.util import montage
from skimage.filters import rank
from skimage import morphology
from skimage import exposure
from skimage import measure
from skimage import io

from scipy.ndimage import measurements

import bokeh.plotting as bpl
import holoviews as hv
# Route Bokeh / HoloViews output into the notebook.
# NOTE(review): presumably required by the CaImAn `*_nb` viewers used below
# (plot_contours_nb, nb_view_components) — confirm.
bpl.output_notebook()
hv.notebook_extension('bokeh')

# Parameters

#### Input file paths

In [None]:
samp_name = 'A0005'
# Sample folder: <sys.path[0] with every 'neuro' substring removed>/data_neuro/<samp_name>.
# NOTE(review): relies on sys.path[0] being the notebook directory — confirm for your kernel.
samp_path = os.path.join(sys.path[0].replace('neuro', ''), 'data_neuro', samp_name)

# Sample metadata (YAML, parsed with safe_load); its 'Reg_time' entry is used by the CSV export.
with open(f'{samp_path}/{samp_name}_meta.yaml') as f:
    samp_meta = yaml.safe_load(f)

# Registration (recording) to process and the files derived from its name.
reg_name = 'pre_gcamp'
reg_path = f'{samp_path}/{reg_name}.tif'

# Pre-computed C-order memmap; set to None to use the one produced by the motion-correction cells.
reg_memmap = f'{samp_path}/{reg_name}_d1_320_d2_320_d3_1_order_C_frames_1091.mmap'
# Saved CNMF fit / refit to reload; None -> run the fit / refit instead.
reg_fit = f'{samp_path}/{reg_name}_fit.hdf5'
reg_refit = None  # f'{samp_path}/{reg_name}_refit.hdf5'

#### CaImAn parameters

In [None]:
# All CaImAn settings for this sample, collected into a single CNMFParams object (`opts`).
# Every parameter is defined exactly once so that editing a value here always takes effect.

# data params
file_path = [reg_path]
fr = 1                      # imaging rate in frames per second
decay_time = 3              # length of a typical transient in seconds (see source/Getting_Started.rst)
dxy = (0.311, 0.311)        # spatial resolution of FOV in pixels per um

# patch params
rf = 100                    # half-size of the patches in pixels. e.g., if rf=25, patches are 50x50
stride = 50                 # amount of overlap between the patches in pixels

# pre-process params
noise_method = 'logmexp'    # PSD averaging method for computing the noise std (single definition)
only_init = False

# motion correction params
max_deviation_rigid = 3     # maximum shifts deviation allowed for patch with respect to rigid shifts
max_shifts = (10, 10)       # maximum allowed rigid shifts (in pixels)
strides = (80, 80)          # start a new patch for pw-rigid motion correction every x pixels
overlaps = (20, 20)         # overlap between patches (size of patch strides+overlaps)
pw_rigid = True             # flag for performing non-rigid motion correction

# init params
K = 10                      # number of components to be found (per patch or whole FOV depending on whether rf=None)
gSig = [2, 2]               # radius of average spatial components (in pixels)
ssub = 4                    # spatial subsampling during initialization
tsub = 2                    # temporal subsampling during initialization
method_init = 'graph_nmf'   # initialization method ('sparse_nmf' NOT WORKING!), 'graph_nmf'
seed_method = 'auto'        # methods for choosing seed pixels during greedy_roi or corr_pnr initialization

# merge params
merge_thr = 0.75            # trace correlation threshold for merging two components.
merge_parallel = False      # perform merging in parallel

# spatial and temporal params
nb = 2                          # number of global background components
method_deconvolution = 'oasis'  # method for solving the constrained deconvolution problem ('oasis','cvx' or 'cvxpy') if method cvxpy, primary and secondary (if problem unfeasible for approx solution)
noise_range = [0.25, 0.5]       # range of normalized frequencies over which to compute the PSD for noise determination
p = 1                           # order of the autoregressive system
# noise_method is the single value set under "pre-process params" above.

# quality params
quality_dict = {'min_SNR': 8,        # trace SNR threshold. Traces with SNR above this will get accepted
                'SNR_lowest': 7,     # minimum required trace SNR. Traces with SNR below this will get rejected
                'rval_thr': 0.4,     # space correlation threshold. Components with correlation higher than this will get accepted
                'rval_lowest': -2,   # minimum required space correlation. Components with correlation below this will get rejected
                'use_cnn': False}    # flag for using the CNN classifier


param_dict = {'fnames': file_path,
              'fr': fr,
              'decay_time': decay_time,
              'dxy': dxy,
              'rf': rf,
              'stride': stride,
              'noise_method': noise_method,
              'only_init': only_init,
              'max_deviation_rigid': max_deviation_rigid,
              'max_shifts': max_shifts,
              'strides': strides,
              'overlaps': overlaps,
              'pw_rigid': pw_rigid,
              'K': K,
              'gSig': gSig,
              'ssub': ssub,
              'tsub': tsub,
              'method_init': method_init,
              'seed_method': seed_method,
              'merge_thr': merge_thr,
              'merge_parallel': merge_parallel,
              'nb': nb,
              'method_deconvolution': method_deconvolution,
              'noise_range': noise_range,
              'p': p}

# Quality thresholds are passed alongside the other parameters.
param_dict.update(quality_dict)
opts = params.CNMFParams(params_dict=param_dict)

In [None]:
# Show the quality-evaluation settings stored in the CNMFParams object built above.
print(opts.quality)

# Motion correction

In [None]:
# if True - play the raw Ca-channel movie (display only)
display_movie = True
if display_movie:
    m_orig = cm.load_movie_chain(reg_path)
    # NOTE(review): the third resize factor presumably downsamples the time axis for faster playback — confirm.
    ds_ratio = 0.31
    m_orig.resize(1, 1, ds_ratio).play(
        q_max=99.5, fr=50, magnification=1)

In [None]:
# start a cluster for parallel processing (if a cluster already exists it will be closed and a new session will be opened)
# At notebook top level locals() is the global namespace, so this detects a dview left by an earlier cell.
if 'dview' in locals():
    cm.stop_server(dview=dview)
c, dview, n_processes = cm.cluster.setup_cluster(
    backend='local', n_processes=None, single_thread=False)

# Motion correction configured from the 'motion' parameter group of `opts`
# (e.g. max_shifts, strides, overlaps, pw_rigid set in the parameters cell).
mc = MotionCorrect(reg_path, dview=dview, **opts.get_group('motion'))

# Run the correction and save the corrected movie (mc.mmap_file is used by the memmap cell below).
mc.motion_correct(save_movie=True)
# Load the corrected movie for display.
# NOTE(review): fname_tot_els is presumably the piecewise-rigid ("elastic") output, as pw_rigid=True — confirm.
m_els = cm.load(mc.fname_tot_els)
# Border handling for the memmap saved later: 0 when borders were filled by 'copy', else CaImAn's border width.
border_to_0 = 0 if mc.border_nan == 'copy' else mc.border_to_0 

In [None]:
# display results
# Raw movie and motion-corrected movie concatenated along axis 2 and played together.
display_movie = True
# Passed to play(save_movie=...) so the playback is also written to disk.
save_avi = True
if display_movie:
    m_orig = cm.load_movie_chain(reg_path)
    ds_ratio = 0.2
    # NOTE(review): subtracting min_mov*nonneg_movie presumably undoes the non-negativity offset
    # applied during correction so both movies share an intensity scale — confirm.
    cm.concatenate([m_orig.resize(1, 1, ds_ratio) - mc.min_mov*mc.nonneg_movie,
                    m_els.resize(1, 1, ds_ratio)], 
                   axis=2).play(fr=30, q_max=99.5, magnification=1, offset=0, save_movie=save_avi)

In [None]:
# memory map the file in order 'C' saving
# Re-save the motion-corrected movie as a C-ordered memmap; reg_mc holds the new file name
# and is loaded by the next cell when reg_memmap is not set.
reg_mc = cm.save_memmap(mc.mmap_file, base_name=f'{reg_name}_', order='C',
                        border_to_0=border_to_0) # exclude borders

# Fit & refit

In [None]:
# Load the C-ordered memmap: the pre-computed reg_memmap if a path is set, else the one just saved (reg_mc).
if isinstance(reg_memmap, str):
    Yr, dims, T = cm.load_memmap(reg_memmap)
else:
    Yr, dims, T = cm.load_memmap(reg_mc)

# Rebuild the movie as (T, *dims) frames.
# NOTE(review): Yr is presumably pixels x frames, hence the transpose before the Fortran-order reshape.
reg_images = np.reshape(Yr.T, [T] + list(dims), order='F') 
# Local-correlation summary image, used as the background of all later plots; NaNs are zeroed.
Cn = cm.local_correlations(reg_images, swap_dim=False)
Cn[np.isnan(Cn)] = 0

plt.figure(figsize=(8, 8))
plt.imshow(Cn, cmap='magma')
plt.show()

# Optionally write the registered movie as a multi-frame TIFF next to the input.
save_reg_img = True
if save_reg_img:
    io.imsave(f'{samp_path}/{reg_name}_mov_cor.tif', reg_images)

## Fit section

#### Start/restart cluster

Start cluster

In [None]:
#%% start a cluster for parallel processing (if a cluster already exists it will be closed and a new session will be opened)
# At notebook top level locals() is the global namespace, so this detects a dview left by an earlier cell.
if 'dview' in locals():
    cm.stop_server(dview=dview)
c, dview, n_processes = cm.cluster.setup_cluster(
    backend='local', n_processes=None, single_thread=False)

Restart cluster

In [None]:
#%% restart cluster to clean up memory
# NOTE: stop_server is called unconditionally, so `dview` must already exist (run "Start cluster" first).
cm.stop_server(dview=dview)
c, dview, n_processes = cm.cluster.setup_cluster(
    backend='local', n_processes=None, single_thread=False)

#### Fit

In [None]:
# Load a previously saved fit when reg_fit is a path; otherwise run CNMF on the registered movie.
if isinstance(reg_fit, str):
    cnm = cnmf.load_CNMF(reg_fit, n_processes=1, dview=dview)
else:
    # Build the (unfitted) model only when it is actually going to be fitted.
    cnm = cnmf.CNMF(n_processes, params=opts, dview=dview)
    cnm = cnm.fit(reg_images)
    save_results = True
    if save_results:
        # Same name pattern as reg_fit, so the next run can reload this fit.
        cnm.save(f'{samp_path}/{reg_name}_fit.hdf5')

# Quick look at all found components over the correlation image.
cnm.estimates.plot_contours_nb(img=Cn, cmap='magma')
cnm.estimates.nb_view_components(img=Cn, idx=cnm.estimates.idx_components,cmap='magma')

#### Filtering

In [None]:
# Push the current quality thresholds into the model's params, so edits to quality_dict
# take effect without refitting.
cnm.params.set('quality', quality_dict)
print(cnm.params.quality)

# Split components into accepted (idx_components) and rejected (idx_components_bad), both printed below.
cnm.estimates.evaluate_components(reg_images, cnm.params, dview=dview)

# NOTE(review): presumably trims each spatial footprint relative to its peak (maxthr=0.3) — confirm in CaImAn docs.
cnm.estimates.threshold_spatial_components(maxthr=0.3, dview=dview)
cnm.estimates.plot_contours_nb(img=Cn, idx=cnm.estimates.idx_components, cmap='magma')

print(f'{len(cnm.estimates.idx_components)} good components: {cnm.estimates.idx_components}')
print(f'{len(cnm.estimates.idx_components_bad)} bad')

cnm.estimates.nb_view_components(img=Cn, idx=cnm.estimates.idx_components)

In [None]:
# Inspect the components rejected by evaluate_components.
cnm.estimates.nb_view_components(img=Cn, idx=cnm.estimates.idx_components_bad)

## Refit section

#### Refit

In [None]:
# Load a previously saved refit when reg_refit is a path; otherwise refit the model from `cnm`.
if isinstance(reg_refit, str):
    cnm2 = cnmf.load_CNMF(reg_refit, n_processes=1, dview=dview)
else:
    cnm2 = cnm.refit(reg_images, dview=dview)
    save_results = True
    if save_results:
        # Match the reg_refit path pattern ('{reg_name}_refit.hdf5') from the input-paths cell,
        # so a saved refit can be reloaded by setting reg_refit.
        cnm2.save(f'{samp_path}/{reg_name}_refit.hdf5')

In [None]:
# Evaluate the refitted components with the refitted model's own parameters.
cnm2.estimates.evaluate_components(reg_images, cnm2.params, dview=dview)

# Accepted contours before (cnm) and after (cnm2) the refit, for comparison.
cnm.estimates.plot_contours_nb(img=Cn, idx=cnm.estimates.idx_components, cmap='magma')
cnm2.estimates.plot_contours_nb(img=Cn, idx=cnm2.estimates.idx_components, cmap='magma')

cnm2.estimates.nb_view_components(img=Cn, idx=cnm2.estimates.idx_components)

#### Finalization

CNMF selection

In [None]:
# Choose the model used for all outputs below (cnm: initial fit, cnm2: refit).
fin_cnm = cnm
print(fin_cnm.estimates.idx_components)

Restart cluster

In [None]:
#%% restart cluster to clean up memory
# NOTE: stop_server is called unconditionally, so `dview` must already exist.
cm.stop_server(dview=dview)
c, dview, n_processes = cm.cluster.setup_cluster(
    backend='local', n_processes=None, single_thread=False)

# Plot & output

#### Plot func

In [None]:
def contour_grid_plot(samp_cnmf, samp_img, columns=4):
    """Plot each accepted component's contour on its own panel of a grid.

    Parameters
    ----------
    samp_cnmf : caiman CNMF object
        ``estimates.A`` holds one flattened (Fortran-order) spatial footprint per
        column; ``estimates.idx_components`` selects which components are drawn.
    samp_img : 2-D ndarray
        Background image (e.g. the correlation image); its shape is used to unflatten A.
    columns : int, optional
        Panels per row (default 4). The number of rows grows with the number of
        accepted components, so any number of components fits.
    """
    A = samp_cnmf.estimates.A.toarray().reshape(samp_img.shape + (-1,), order='F').transpose([2, 0, 1])
    comp_idx = list(samp_cnmf.estimates.idx_components)

    rows = max(1, -(-len(comp_idx) // columns))  # ceiling division
    panel_size = 15 / columns
    fig = plt.figure(figsize=(15, panel_size * max(rows, columns)))

    # Panel position comes from enumerate, not from the component index,
    # which may exceed rows*columns.
    for panel, i in enumerate(comp_idx, start=1):
        # Binarize the footprint; the 0.5 iso-line is its outline.
        A_bin = (A[i] != 0).astype(float)

        fig.add_subplot(rows, columns, panel)
        plt.imshow(samp_img, cmap='magma')
        # find_contours returns a list of (n, 2) arrays of differing lengths; iterate it directly.
        for contour in measure.find_contours(A_bin, level=0.5):
            plt.plot(contour[:, 1], contour[:, 0], linewidth=1, color='r')
        plt.title(f'ROI {i+1}')
        plt.axis('off')
    plt.show()

def contour_set_save(samp_cnmf, samp_img, save_path):
    """Save one PNG per accepted component showing its contour over ``samp_img``.

    Files are written to ``{save_path}/pre_contours/ROI_{i+1}.png`` (the directory
    is created if needed), where ``i`` is the 0-based component index.

    Parameters
    ----------
    samp_cnmf : caiman CNMF object
        ``estimates.A`` (flattened Fortran-order footprints, one per column) and
        ``estimates.idx_components`` are read.
    samp_img : 2-D ndarray
        Background image, drawn with a black-to-green colormap.
    save_path : str
        Parent output directory.
    """
    import matplotlib.colors

    img_path = f'{save_path}/pre_contours'
    os.makedirs(img_path, exist_ok=True)

    cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", ["black", "green"])

    A = samp_cnmf.estimates.A.toarray().reshape(samp_img.shape + (-1,), order='F').transpose([2, 0, 1])

    for i in samp_cnmf.estimates.idx_components:
        # Binarize the footprint; the 0.5 iso-line is its outline.
        A_bin = (A[i] != 0).astype(float)

        fig = plt.figure(figsize=(10, 10))
        plt.imshow(samp_img, cmap=cmap)
        # find_contours returns a list of (n, 2) arrays of differing lengths; iterate it directly.
        for contour in measure.find_contours(A_bin, level=0.5):
            plt.plot(contour[:, 1], contour[:, 0], linewidth=1, color='r')
        plt.title(f'ROI {i+1}')
        plt.axis('off')
        plt.savefig(f'{img_path}/ROI_{i+1}.png')
        # Close each figure after saving so figures do not accumulate across ROIs.
        plt.close(fig)

def comp_contour_plot(samp_cnmf, samp_img, save_file=None):
    """Overlay the contours of all accepted spatial components (A) on a control image.

    Each component gets a semi-transparent filled footprint, a contour line in its
    own color from the Matplotlib color cycle, and an ``ROI {i+1}`` label at the
    footprint's center of mass.

    Parameters
    ----------
    samp_cnmf : caiman CNMF object
        ``estimates.A`` (flattened Fortran-order footprints, one per column) and
        ``estimates.idx_components`` are read.
    samp_img : 2-D ndarray
        Background image; its shape is used to unflatten A.
    save_file : str, optional
        If a string, the figure is saved there at 300 dpi; otherwise it is shown.
    """
    from itertools import cycle
    from scipy.ndimage import center_of_mass

    A = samp_cnmf.estimates.A.toarray().reshape(samp_img.shape + (-1,), order='F').transpose([2, 0, 1])

    plt.figure(figsize=(10, 10))
    plt.imshow(samp_img, cmap='magma')
    # Public replacement for the private ax._get_lines.prop_cycler (removed in Matplotlib 3.8):
    # walk the configured color cycle so each ROI gets a distinct color.
    colors = cycle(plt.rcParams['axes.prop_cycle'].by_key()['color'])

    for i in samp_cnmf.estimates.idx_components:
        # Binary footprint (1 inside, 0 outside); masked copy so only the ROI is tinted.
        A_bin = (A[i] != 0).astype(float)
        A_frame = np.ma.masked_where(A_bin == 0, A_bin)
        A_center = center_of_mass(A_bin)

        plt.imshow(A_frame, cmap='jet', alpha=.5)
        color = next(colors)
        # find_contours returns a list of (n, 2) arrays of differing lengths; iterate it directly.
        for cont in measure.find_contours(A_bin, level=0.5):
            plt.plot(cont[:, 1], cont[:, 0], linewidth=1.5, alpha=.75, color=color)
        plt.annotate(f'ROI {i+1}',
                     (A_center[1], A_center[0]),
                     textcoords="offset points",
                     xytext=(2, 2),
                     ha='center',
                     color='white',
                     weight='bold',
                     fontsize=10)
    plt.axis('off')
    plt.tight_layout()

    if isinstance(save_file, str):
        plt.savefig(save_file, dpi=300)
    else:
        plt.show()

# def img_comp_save(samp_cnmf, samp_img,save_path):
#     img_path = f'{save_path}/pre_corr_imgs'
#     if not os.path.exists(img_path):
#         os.makedirs(img_path)

def dF_cascade_plot(samp_cnmf, y_shift=0.5, save_file=None):
    """Draw the ΔF/F traces of the accepted components as a vertical cascade.

    Traces come from ``samp_cnmf.estimates.F_dff`` (filled in this notebook by
    ``estimates.detrend_df_f``). They are drawn in ``estimates.idx_components``
    order, each one ``y_shift`` lower than the previous. A vertical bar 1.0 unit
    tall, labelled "100% ΔF/F", is drawn left of the traces as a scale reference.

    Parameters
    ----------
    samp_cnmf : caiman CNMF object
    y_shift : float
        Vertical spacing between consecutive traces.
    save_file : str, optional
        If a string, the figure is saved there at 300 dpi; otherwise it is shown.
    """
    plt.figure(figsize=(20, 8))

    offset = 0
    for comp in samp_cnmf.estimates.idx_components:
        trace = samp_cnmf.estimates.F_dff[comp]
        plt.plot(trace + offset, alpha=.5, label=f'ROI {comp+1}')
        offset -= y_shift

    # Scale bar: ymax - ymin == 1.0 ΔF/F == 100 %.
    plt.vlines(x=[-20], ymin=[-0.2], ymax=[0.8], linewidth=3, color='k')
    plt.text(x=-60, y=-0.5, s="100% ΔF/F", size=15, rotation=90.)
    plt.axis('off')
    plt.legend(loc=1)
    plt.tight_layout()

    if not isinstance(save_file, str):
        plt.show()
        return
    plt.savefig(save_file, dpi=300)

#### ΔF/F calc & ctrl plot

In [None]:
# Compute detrended ΔF/F for the selected model.
# NOTE(review): presumably fills estimates.F_dff, which dF_cascade_plot and save_prof_df read — confirm.
fin_cnm.estimates.detrend_df_f(quantileMin=5, frames_window=500,
                            flag_auto=True, use_fast=False, detrend_only=False)

In [None]:
# contour_grid_plot(fin_cnm, Cn)
# Save the contour overview and the ΔF/F cascade as PNGs in the sample folder.
comp_contour_plot(fin_cnm, Cn, save_file=f'{samp_path}/{samp_name}_comp_contours.png')
dF_cascade_plot(fin_cnm, save_file=f'{samp_path}/{samp_name}_dF_profiles.png')

#### Custom plot

In [None]:
# video output

def contour_mov_save(samp_cnmf, samp_img, comp_i, save_path):
    """Save an MP4 of the registered movie with one component's contour overlaid.

    The movie is contrast-equalized with ``exposure.equalize_adapthist``
    (``clip_limit=0.25``) and rendered with a black-to-green colormap and fixed
    color limits. The contour of component ``comp_i`` is drawn once, in red, on top.

    Parameters
    ----------
    samp_cnmf : caiman CNMF object
        ``estimates.A`` holds one flattened (Fortran-order) spatial footprint per column.
    samp_img : ndarray, shape (frames, d1, d2)
        Registered movie.
    comp_i : int
        0-based component index; labelled ``ROI {comp_i+1}``.
    save_path : str
        Output directory.

    The output file is ``{save_path}/{samp_name}_{reg_name}_ROI{comp_i+1}.mp4``.
    NOTE(review): ``samp_name`` and ``reg_name`` are read from notebook globals,
    not passed in.
    """
    import matplotlib.colors
    import matplotlib.animation as animation

    cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", ["black", "green"])
    # NOTE(review): astype(int) truncates fractional intensities before equalization — confirm intended.
    samp_img = exposure.equalize_adapthist(samp_img.astype(int), clip_limit=0.25)
    # Fixed color limits so brightness is comparable across frames.
    v_min, v_max = np.min(samp_img), np.max(samp_img)

    A = samp_cnmf.estimates.A.toarray().reshape(samp_img.shape[1:] + (-1,), order='F').transpose([2, 0, 1])
    # Binarize the footprint; the 0.5 iso-line is its outline. find_contours returns a
    # list of (n, 2) arrays of differing lengths, so keep it as a list.
    contours = measure.find_contours((A[comp_i] != 0).astype(float), level=0.5)

    fig = plt.figure(figsize=(10, 10))
    # One animated image artist per frame; ArtistAnimation shows one list of artists per frame.
    frames = [[plt.imshow(frame, cmap=cmap, animated=True, vmin=v_min, vmax=v_max)]
              for frame in samp_img]
    # Static decorations shared by every frame: drawn once instead of once per frame.
    plt.title(f'ROI {comp_i+1}')
    plt.axis('off')
    for contour in contours:
        plt.plot(contour[:, 1], contour[:, 0], linewidth=1, color='r', alpha=.75)

    ani = animation.ArtistAnimation(fig, frames, interval=50, blit=True,
                                    repeat_delay=1000)
    ani.save(f'{save_path}/{samp_name}_{reg_name}_ROI{comp_i+1}.mp4')
    # Release the figure (it holds one image artist per movie frame).
    plt.close(fig)

# Export a movie for component index 10 (0-based, labelled "ROI 11") into the sample folder.
contour_mov_save(fin_cnm, reg_images, 10, samp_path)

# Alternative (disabled): write an equalized 8-bit movie with OpenCV instead of Matplotlib.
# # OpenCV
# size = 320, 320
# duration = 2
# fps = 25
# out = cv2.VideoWriter(f'{samp_path}/{samp_name}_{reg_name}.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, (size[1], size[0]), False)

# eq_images = exposure.equalize_adapthist(reg_images.astype(int), clip_limit=0.03)
# eq_images[eq_images > np.min(eq_images)/0.1] = np.min(eq_images)

# ratio = np.amax(eq_images) / 256
# eq_images = (eq_images/ratio).astype('uint8')

# for frame in eq_images:
#     out.write(frame.astype('uint8'))
# out.release()

# contour_set_save(fin_cnm, exposure.equalize_adapthist(img_ctrl_green, clip_limit=0.025), samp_path)

#### Output CSV saving

In [None]:
def save_prof_df(samp_cnmf, samp_img, samp_name, reg_time, save_path):
    """Export per-frame profiles of every accepted component to a long-format CSV.

    For each component in ``samp_cnmf.estimates.idx_components``, one row per frame
    is written with:
    - the mean raw intensity inside the component's spatial footprint (``profile_raw``);
    - the CNMF temporal component ``C`` (``profile_C``);
    - the detrended ΔF/F trace ``F_dff`` (``profile_ddf``).

    Parameters
    ----------
    samp_cnmf : caiman CNMF object
        ``estimates.A`` (sparse, one Fortran-order flattened footprint per column),
        ``estimates.C``, ``estimates.F_dff`` and ``estimates.idx_components`` are read.
        ``F_dff`` must already be computed (``detrend_df_f`` in this notebook).
    samp_img : ndarray, shape (frames, d1, d2)
        Registered movie the raw profiles are measured on.
    samp_name : str
        Written to the ``reg_name`` column and used as the CSV file prefix.
    reg_time : float
        Total recording time; frame times are spread evenly over ``[0, reg_time]``.
    save_path : str
        Output directory; the file is ``{save_path}/{samp_name}_pre_comp_df.csv``.
    """
    columns = ['reg_name',     # registration name
               'indx',         # frame index
               'time',         # registration time
               'comp',         # component num (1-based)
               'profile_raw',  # component raw value, mean over the spatial mask
               'profile_C',    # component denoised value
               'profile_ddf']  # component detrended ΔF/F value

    A = samp_cnmf.estimates.A.toarray().reshape(samp_img.shape[1:] + (-1,), order='F').transpose([2, 0, 1])

    frame_num = samp_cnmf.estimates.C.shape[1]
    i_col = range(frame_num)
    time_col = np.linspace(0, reg_time, num=frame_num)
    reg_name_col = np.full(frame_num, samp_name)

    component_dfs = []
    for component_num in samp_cnmf.estimates.idx_components:
        # Binary footprint: every pixel with non-zero weight belongs to the component.
        A_mask = A[component_num] != 0

        # Mean raw intensity inside the footprint, computed for all frames at once.
        est_raw = samp_img[:, A_mask].mean(axis=1)

        component_dfs.append(pd.DataFrame({
            'reg_name': reg_name_col,
            'indx': i_col,
            'time': time_col,
            'comp': np.full(samp_img.shape[0], component_num + 1),
            'profile_raw': est_raw,
            'profile_C': samp_cnmf.estimates.C[component_num],        # temporal component
            'profile_ddf': samp_cnmf.estimates.F_dff[component_num],  # detrended temporal component
        }))

    # Concatenate once: repeated concat inside the loop is quadratic, and concatenating
    # onto an empty frame is deprecated in recent pandas.
    if component_dfs:
        output_df = pd.concat(component_dfs, ignore_index=True)
    else:
        output_df = pd.DataFrame(columns=columns)

    output_df.to_csv(f'{save_path}/{samp_name}_pre_comp_df.csv', index=False)
    print(output_df.head())

# Export the selected model's profiles; reg_time comes from the sample YAML metadata ('Reg_time').
save_prof_df(samp_cnmf=fin_cnm,
             samp_img=reg_images,
             samp_name=samp_name,
             reg_time=samp_meta['Reg_time'],
             save_path=samp_path)

#### Stop cluster and clean up log files