# Motion correction and cell detection
---

Post-synaptic neurons expressing rGeco

In [None]:
import cv2
try:
    # Best-effort: cap OpenCV's internal thread count. The original `except():`
    # matched an empty tuple of exception types, i.e. it caught nothing; catch
    # real exceptions so a failing call cannot stop the notebook.
    cv2.setNumThreads(8)
except Exception:
    pass

import os
import sys
import glob
import logging
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import caiman as cm
from caiman.motion_correction import MotionCorrect
from caiman.source_extraction.cnmf import cnmf as cnmf
from caiman.source_extraction.cnmf import params as params
from caiman.source_extraction.cnmf import utilities as util

from skimage.util import montage
from skimage.filters import rank
from skimage import morphology
from skimage import exposure
from skimage import measure

from scipy.ndimage import measurements

import bokeh.plotting as bpl
import holoviews as hv
# Send bokeh plots to the notebook output and load the bokeh backend for holoviews.
bpl.output_notebook()
hv.notebook_extension('bokeh')

# Parameters

#### Input files

In [None]:
samp_name = 'A0005'
# Data directory: <sys.path[0] with every 'neuro' substring removed>/data_neuro/<samp_name>
samp_path = os.path.join(sys.path[0].replace('neuro', ''), 'data_neuro', samp_name)

reg_name = 'post_rgeco'
reg_path = f'{samp_path}/{reg_name}.tif'

# Previously produced outputs (presumably from an earlier run of the cells below):
# the C-order memmap of the motion-corrected movie and the saved CNMF fit.
reg_memmap = f'{samp_path}/{reg_name}_d1_320_d2_320_d3_1_order_C_frames_1091.mmap'
reg_fit = f'{samp_path}/{reg_name}_fit.hdf5'

#### CaImAn parameters

In [None]:
# data params
file_path = [reg_path]
fr = 1                      # imaging rate in frames per second
decay_time = 3              # length of a typical transient in seconds (see source/Getting_Started.rst)
dxy = (0.311, 0.311)        # spatial resolution of FOV in pixels per um

# patch params
rf = 100                    # half-size of the patches in pixels. e.g., if rf=25, patches are 50x50
stride = 50                 # amount of overlap between the patches in pixels

# pre-process params
# NOTE: each parameter is defined exactly once; a second definition further down
# would silently override this one.
noise_range = [0.25, 0.5]   # range of normalized frequencies over which to compute the PSD for noise determination
noise_method = 'logmexp'    # PSD averaging method for computing the noise std
only_init = False

# motion correction params
max_deviation_rigid = 3     # maximum shifts deviation allowed for patch with respect to rigid shifts
max_shifts = (10, 10)       # maximum allowed rigid shifts (in pixels)
strides = (80, 80)          # start a new patch for pw-rigid motion correction every x pixels
overlaps = (20, 20)         # overlap between patches (size of patch strides+overlaps)
pw_rigid = True             # flag for performing non-rigid motion correction

# init params
K = 2                       # number of components to be found (per patch or whole FOV depending on whether rf=None)
gSig = [10, 10]             # radius of average spatial components (in pixels)
ssub = 3                    # spatial subsampling during initialization
tsub = 2                    # temporal subsampling during initialization
method_init = 'graph_nmf'   # initialization method ('sparse_nmf' NOT WORKING!), 'graph_nmf'
seed_method = 'auto'        # methods for choosing seed pixels during greedy_roi or corr_pnr initialization

# merge params
merge_thr = 0.2             # trace correlation threshold for merging two components
merge_parallel = False      # perform merging in parallel

# spatial and temporal params
nb = 3                      # number of global background components
method_deconvolution = 'oasis'  # method for solving the constrained deconvolution problem ('oasis','cvx' or 'cvxpy') if method cvxpy, primary and secondary (if problem unfeasible for approx solution)
p = 1                       # order of the autoregressive system

# quality params
min_SNR = 3                 # trace SNR threshold. Traces with SNR above this will get accepted
SNR_lowest = 0.1            # minimum required trace SNR. Traces with SNR below this will get rejected
rval_thr = 0.5              # space correlation threshold. Components with correlation higher than this will get accepted
rval_lowest = -2            # minimum required space correlation. Components with correlation below this will get rejected
use_cnn = False             # flag for using the CNN classifier


# One entry per parameter (no duplicate keys), passed to CaImAn in one go.
param_dict = {'fnames': file_path,
              'fr': fr,
              'decay_time': decay_time,
              'dxy': dxy,
              'rf': rf,
              'stride': stride,
              'noise_range': noise_range,
              'noise_method': noise_method,
              'only_init': only_init,
              'max_deviation_rigid': max_deviation_rigid,
              'max_shifts': max_shifts,
              'strides': strides,
              'overlaps': overlaps,
              'pw_rigid': pw_rigid,
              'K': K,
              'gSig': gSig,
              'ssub': ssub,
              'tsub': tsub,
              'method_init': method_init,
              'seed_method': seed_method,
              'merge_thr': merge_thr,
              'merge_parallel': merge_parallel,
              'nb': nb,
              'method_deconvolution': method_deconvolution,
              'p': p,
              'min_SNR': min_SNR,
              'SNR_lowest': SNR_lowest,
              'rval_thr': rval_thr,
              'rval_lowest': rval_lowest,
              'use_cnn': use_cnn}

opts = params.CNMFParams(params_dict=param_dict)

In [None]:
# Checkpoint: confirms the parameter cells above ran (replaces a stray debug print('a')).
print('CaImAn parameters are set')

# Motion correction

In [None]:
# Set to True to play the raw Ca-channel movie before motion correction.
display_movie = True
if display_movie:
    ds_ratio = 0.31  # resize factor for the third movie axis (presumably time) during playback
    m_orig = cm.load_movie_chain(reg_path)
    m_orig.resize(1, 1, ds_ratio).play(q_max=99.5,
                                       fr=50,
                                       magnification=1)

In [None]:
# start a cluster for parallel processing (if a cluster already exists it will be closed and a new session will be opened)
if 'dview' in locals():
    cm.stop_server(dview=dview)
c, dview, n_processes = cm.cluster.setup_cluster(
    backend='local', n_processes=None, single_thread=False)

mc = MotionCorrect(reg_path, dview=dview, **opts.get_group('motion'))
mc.motion_correct(save_movie=True)

# Load the corrected movie from whichever correction actually ran: the pw-rigid
# output (fname_tot_els) only exists when pw_rigid is enabled, otherwise use the
# rigid output (fname_tot_rig), as in the CaImAn demo pipeline.
fname_mc = mc.fname_tot_els if mc.pw_rigid else mc.fname_tot_rig
m_els = cm.load(fname_mc)
# Border handling for the memmap saved below (0 when borders are copied, per the CaImAn demo).
border_to_0 = 0 if mc.border_nan == 'copy' else mc.border_to_0

In [None]:
# Play the raw movie (minus the offset that MotionCorrect added when nonneg_movie
# is set) next to the motion-corrected movie; optionally save the playback.
display_movie = True
save_avi = True
if display_movie:
    ds_ratio = 0.2
    m_orig = cm.load_movie_chain(reg_path)
    cm.concatenate(
        [
            m_orig.resize(1, 1, ds_ratio) - mc.min_mov * mc.nonneg_movie,
            m_els.resize(1, 1, ds_ratio),
        ],
        axis=2,  # presumably the width axis, i.e. side by side
    ).play(fr=30, q_max=99.5, magnification=1, offset=0, save_movie=save_avi)

In [None]:
# memory map the file in order 'C' saving
reg_mc = cm.save_memmap(mc.mmap_file, base_name=f'{reg_name}_', order='C',
                        border_to_0=border_to_0) # exclude borders

# Fit & refit

In [None]:
# Load the C-order memmap. Reuse the cached file named by reg_memmap only if it
# really exists on disk (reg_memmap is always a string, so a type check alone
# would never fall through); otherwise use the file written by the
# motion-correction cell (reg_mc).
if isinstance(reg_memmap, str) and os.path.isfile(reg_memmap):
    Yr, dims, T = cm.load_memmap(reg_memmap)
else:
    Yr, dims, T = cm.load_memmap(reg_mc)

# Reshape the flattened data into a movie of shape [T] + dims.
reg_images = np.reshape(Yr.T, [T] + list(dims), order='F')
# Local correlation image used as the background for contour plots; NaNs -> 0.
Cn = cm.local_correlations(reg_images, swap_dim=False)
Cn[np.isnan(Cn)] = 0

plt.figure(figsize=(8, 8))
plt.imshow(Cn, cmap='magma')
plt.show()

## Fit section

#### Start/restart cluster

Start cluster

In [None]:
#%% (re)start the local cluster for parallel processing; an existing cluster is shut down first
if 'dview' in locals():
    cm.stop_server(dview=dview)
c, dview, n_processes = cm.cluster.setup_cluster(
    backend='local',
    n_processes=None,
    single_thread=False,
)

Restart cluster

In [None]:
#%% restart cluster to clean up memory
# Guarded so this cell also works when no cluster has been started yet
# (an unconditional stop_server(dview=dview) raises NameError in that case).
if 'dview' in locals():
    cm.stop_server(dview=dview)
c, dview, n_processes = cm.cluster.setup_cluster(
    backend='local', n_processes=None, single_thread=False)

#### Fit

In [None]:
# Reuse a saved fit only when the file actually exists (reg_fit is always a
# string, so a type check alone never triggered a fresh fit); otherwise fit
# CNMF from scratch and optionally save the result.
if isinstance(reg_fit, str) and os.path.isfile(reg_fit):
    cnm = cnmf.load_CNMF(reg_fit, n_processes=1, dview=dview)
else:
    cnm = cnmf.CNMF(n_processes, params=opts, dview=dview)
    cnm = cnm.fit(reg_images)
    save_results = True
    if save_results:
        cnm.save(f'{samp_path}/{reg_name}_fit.hdf5')

cnm.estimates.plot_contours_nb(img=Cn, cmap='magma')

## Refit section

#### Filtering

In [None]:
# Classify components with the quality thresholds from the parameter cell, CNN classifier disabled.
cnm.params.set('quality', {'use_cnn': False})
cnm.estimates.evaluate_components(reg_images, cnm.params, dview=dview)

# Optional size filter (disabled). Use 512**2 here: in Python `512^2` is
# bitwise XOR and evaluates to 514.
# min_size = 100               # minimal component area in px
# max_size = 512**2            # maximal component area in px
# cnm.estimates.threshold_spatial_components(maxthr=0.5, dview=dview)
# cnm.estimates.remove_small_large_neurons(min_size_neuro=min_size, max_size_neuro=max_size)

cnm.estimates.plot_contours_nb(img=Cn, idx=cnm.estimates.idx_components, cmap='magma')

print(f'{len(cnm.estimates.idx_components)} good components: {cnm.estimates.idx_components}')
print(f'{len(cnm.estimates.idx_components_bad)} bad')

#### Refit
_Note: this refit step is currently NOT working._

In [None]:
# Refit starting from the current estimates; save the result when requested.
save_results = True
cnm2 = cnm.refit(reg_images, dview=dview)
if save_results:
    cnm2.save(f'{samp_path}/{reg_name}_refit.hdf5')

In [None]:
# Evaluate the refit components (using cnm.params), then show the accepted
# contours of the original fit followed by those of the refit.
cnm2.estimates.evaluate_components(reg_images, cnm.params, dview=dview)
for model in (cnm, cnm2):
    model.estimates.plot_contours_nb(img=Cn, idx=model.estimates.idx_components, cmap='magma')

#### Finalization

Select the CNMF model (fit or refit) used for the final output

In [None]:
# Model used for all downstream output (set to cnm2 to use the refit instead).
fin_cnm = cnm
accepted = fin_cnm.estimates.idx_components
print(accepted)

# Plot & output

#### Plotting functions

In [None]:
def comp_contour_plot(samp_cnmf, samp_img):
    """Overlay the footprints and contours of accepted components on a control image.

    Parameters
    ----------
    samp_cnmf : CaImAn CNMF object
        Fitted model; reads ``estimates.A`` (sparse, one column per component)
        and ``estimates.idx_components``. If ``idx_components`` is None, every
        component is drawn.
    samp_img : 2-D array
        Control image (e.g. the correlation image ``Cn``); its shape is used to
        reshape each column of ``A`` back into a 2-D footprint.
    """
    from scipy import ndimage

    # Stack of footprints, one 2-D frame per component.
    A = samp_cnmf.estimates.A.toarray().reshape(samp_img.shape + (-1,), order='F').transpose([2, 0, 1])
    idx = samp_cnmf.estimates.idx_components
    if idx is None:
        idx = range(A.shape[0])

    plt.figure(figsize=(10, 10))
    plt.imshow(samp_img, cmap='magma')

    for i in idx:
        # Binary footprint: 1 inside the component, 0 elsewhere.
        A_mask = (A[i] != 0).astype(float)
        A_center = ndimage.center_of_mass(A_mask)
        # find_contours returns a list of (n_points, 2) arrays of different
        # lengths; keep it a list (np.asarray on a ragged list raises ValueError
        # in NumPy >= 1.24) and tolerate an empty result.
        A_contours = measure.find_contours(A_mask, level=0.5)

        plt.imshow(np.ma.masked_where(A_mask == 0, A_mask), cmap='jet', alpha=.5)
        # Draw every contour piece of this component in the same colour.
        color = None
        for contour in A_contours:
            (line,) = plt.plot(contour[:, 1], contour[:, 0], linewidth=2, color=color)
            color = line.get_color()
        plt.annotate(f'ROI {i+1}',
                     (A_center[1], A_center[0]),
                     textcoords="offset points",
                     xytext=(2, 2),
                     ha='center',
                     color='white',
                     weight='bold',
                     fontsize=10)
    plt.axis('off')
    plt.show()


def contour_grid_plot(samp_cnmf, samp_img, columns=4):
    """Draw each accepted component's contour on its own panel over a control image.

    Panels are filled left-to-right, top-to-bottom in the order of
    ``estimates.idx_components``, on a grid with ``columns`` columns and as many
    rows as needed. (The grid used to be a fixed 4x4 indexed by component
    number, which failed for any component index >= 16.)

    Parameters
    ----------
    samp_cnmf : CaImAn CNMF object
        Fitted model; reads ``estimates.A`` and ``estimates.idx_components``.
        If ``idx_components`` is None, every component is drawn.
    samp_img : 2-D array
        Background image for every panel; its shape is used to reshape ``A``.
    columns : int, optional
        Number of panels per row (default 4).
    """
    # Stack of footprints, one 2-D frame per component.
    A = samp_cnmf.estimates.A.toarray().reshape(samp_img.shape + (-1,), order='F').transpose([2, 0, 1])
    idx = samp_cnmf.estimates.idx_components
    idx = list(range(A.shape[0]) if idx is None else idx)

    rows = max(1, (len(idx) + columns - 1) // columns)
    # Keep the original 15x15 figure for a 4x4 grid and scale the height with the row count.
    fig = plt.figure(figsize=(15, 15 * rows / columns))

    for pos, i in enumerate(idx, start=1):
        A_mask = (A[i] != 0).astype(float)
        # List of (n_points, 2) arrays of varying length; do not np.asarray it.
        A_contours = measure.find_contours(A_mask, level=0.5)

        fig.add_subplot(rows, columns, pos)
        plt.imshow(samp_img, cmap='magma')
        for contour in A_contours:
            plt.plot(contour[:, 1], contour[:, 0], linewidth=1, color='r')
        plt.title(f'ROI {i+1}')
        plt.axis('off')
    plt.show()


def comp_df_plot(samp_cnmf, y_shift=0.5):
    """Plot the ΔF/F traces of the accepted components, stacked vertically.

    Parameters
    ----------
    samp_cnmf : CaImAn CNMF object
        Model whose ``estimates.F_dff`` (one row per component) has been
        computed, e.g. with ``estimates.detrend_df_f``. If
        ``estimates.idx_components`` is None, every component is plotted.
    y_shift : float, optional
        Downward offset between consecutive traces, in ΔF/F units.

    Raises
    ------
    ValueError
        If ``estimates.F_dff`` has not been computed yet.
    """
    F_dff = samp_cnmf.estimates.F_dff
    if F_dff is None:
        raise ValueError('estimates.F_dff is not computed; run estimates.detrend_df_f first')
    idx = samp_cnmf.estimates.idx_components
    if idx is None:
        idx = range(F_dff.shape[0])

    plt.figure(figsize=(20, 8))
    for n, i in enumerate(idx):
        # n-th trace is shifted down by n * y_shift so traces do not overlap.
        plt.plot(F_dff[i] - n * y_shift, alpha=.5, label=f'ROI {i+1}')

    # Scale bar spanning 1.0 ΔF/F (from -0.2 to 0.8), labelled as 100 %.
    plt.vlines(x=[-20], ymin=[-0.2], ymax=[0.8], linewidth=3, color='k')
    plt.text(x=-60, y=-0.5, s="100% ΔF/F", size=15, rotation=90.)
    plt.axis('off')
    plt.legend(loc=1)
    plt.show()

#### ΔF/F calculation & control plots

In [None]:
# Compute detrended ΔF/F on the selected model (presumably stored in
# fin_cnm.estimates.F_dff, which comp_df_plot reads).
fin_cnm.estimates.detrend_df_f(
    quantileMin=5,
    frames_window=500,
    flag_auto=True,
    use_fast=False,
    detrend_only=False,
)

In [None]:
# Control plots for the selected model: per-ROI contour panels over the
# correlation image, then the stacked ΔF/F traces.
contour_grid_plot(fin_cnm, Cn)
comp_df_plot(fin_cnm, y_shift=0.5)