# Motion correction only
---

In [None]:
import cv2
# Best-effort cap on the number of threads OpenCV may use. The original
# `except():` clause named an empty tuple, which matches no exception at all,
# so any failure here would have propagated instead of being ignored.
try:
    cv2.setNumThreads(8)
except Exception:
    pass

import os
import sys
import glob
import yaml

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import caiman as cm
from caiman.motion_correction import MotionCorrect
from caiman.source_extraction.cnmf import cnmf as cnmf
from caiman.source_extraction.cnmf import params as params
from caiman.source_extraction.cnmf import utilities as util

from skimage.util import montage
from skimage.filters import rank
from skimage import morphology
from skimage import exposure
from skimage import measure
from skimage import io

from scipy.ndimage import measurements

import bokeh.plotting as bpl
import holoviews as hv
# Set up notebook output for Bokeh and load the HoloViews notebook extension
# with the 'bokeh' backend.
bpl.output_notebook()
hv.notebook_extension('bokeh')

# Parameters

#### Input file path

In [None]:
# Sample identifier and recording stage; together they select the input movie.
samp_name = 'A0013'
samp_type = 'pre'

# Sample directory: the first sys.path entry with every 'neuro' substring
# removed, followed by data_neuro/<sample name>.
samp_path = os.path.join(sys.path[0].replace('neuro', ''), 'data_neuro', samp_name)

# Full path of the input TIFF movie.
reg_path = f'{samp_path}/{samp_name}_{samp_type}.tif'
print(reg_path)

#### CaImAn parameters

In [None]:
# All CaImAn settings are declared as plain variables first, then gathered
# into one dictionary that is handed to params.CNMFParams.

# data params
file_path = [reg_path]
fr = 1                      # imaging rate in frames per second
decay_time = 3              # length of a typical transient in seconds (see source/Getting_Started.rst)
dxy = (0.311, 0.311)        # spatial resolution of FOV in pixels per um

# patch params
rf = 100                    # half-size of the patches in pixels. e.g., if rf=25, patches are 50x50
stride = 50                 # amount of overlap between the patches in pixels

# pre-process params
# noise_method is listed under both the pre-process and the spatial/temporal
# settings; it is defined once here and passed through a single dictionary key.
noise_method = 'logmexp'    # PSD averaging method for computing the noise std
only_init = False

# motion correction params
max_deviation_rigid = 3     # maximum shifts deviation allowed for patch with respect to rigid shifts
max_shifts = (10, 10)       # maximum allowed rigid shifts (in pixels)
strides = (80, 80)          # start a new patch for pw-rigid motion correction every x pixels
overlaps = (20, 20)         # overlap between patches (size of patch strides+overlaps)
pw_rigid = True             # flag for performing non-rigid motion correction

# init params
K = 10                      # number of components to be found (per patch or whole FOV depending on whether rf=None)
gSig = [2, 2]               # radius of average spatial components (in pixels)
ssub = 4                    # spatial subsampling during initialization
tsub = 2                    # temporal subsampling during initialization
method_init = 'graph_nmf'   # initialization method ('sparse_nmf' NOT WORKING!), 'graph_nmf'
seed_method = 'auto'        # methods for choosing seed pixels during greedy_roi or corr_pnr initialization

# merge params
merge_thr = 0.75            # trace correlation threshold for merging two components.
merge_parallel = False      # perform merging in parallel

# spatial and temporal params
nb = 2                           # number of global background components
method_deconvolution = 'oasis'   # method for solving the constrained deconvolution problem ('oasis','cvx' or 'cvxpy') if method cvxpy, primary and secondary (if problem unfeasible for approx solution)
noise_range = [0.25, 0.5]        # range of normalized frequencies over which to compute the PSD for noise determination
p = 1                            # order of the autoregressive system

# quality params
quality_dict = {'min_SNR': 8,        # trace SNR threshold. Traces with SNR above this will get accepted
                'SNR_lowest': 7,     # minimum required trace SNR. Traces with SNR below this will get rejected
                'rval_thr': 0.4,     # space correlation threshold. Components with correlation higher than this will get accepted
                'rval_lowest': -2,   # minimum required space correlation. Components with correlation below this will get rejected
                'use_cnn': False}    # flag for using the CNN classifier


# Each key appears exactly once; a repeated key in a dict literal is silently
# overridden by its last occurrence, which hides configuration mistakes.
param_dict = {'fnames': file_path,
              'fr': fr,
              'decay_time': decay_time,
              'dxy': dxy,
              'rf': rf,
              'stride': stride,
              'noise_method': noise_method,
              'only_init': only_init,
              'max_deviation_rigid': max_deviation_rigid,
              'max_shifts': max_shifts,
              'strides': strides,
              'overlaps': overlaps,
              'pw_rigid': pw_rigid,
              'K': K,
              'gSig': gSig,
              'ssub': ssub,
              'tsub': tsub,
              'method_init': method_init,
              'seed_method': seed_method,
              'merge_thr': merge_thr,
              'merge_parallel': merge_parallel,
              'nb': nb,
              'method_deconvolution': method_deconvolution,
              'noise_range': noise_range,
              'p': p}

# Fold the component-quality thresholds into the same dictionary and build
# the CaImAn parameter object from it.
param_dict.update(quality_dict)
opts = params.CNMFParams(params_dict=param_dict)

# Motion correction

In [None]:
# Set to True to play back the input Ca-channel movie
display_movie = False
if display_movie:
    # load_movie_chain expects a list of file names; a bare string would be
    # iterated character by character, so wrap the single path in a list.
    m_orig = cm.load_movie_chain([reg_path])
    ds_ratio = 0.31  # resize factor applied along the third axis passed to resize()
    m_orig.resize(1, 1, ds_ratio).play(
        q_max=99.5, fr=50, magnification=1)

In [None]:
# Start a cluster for parallel processing. If a cluster from an earlier run of
# this cell still exists, it is closed first and a new session is opened.
if 'dview' in locals():
    cm.stop_server(dview=dview)
c, dview, n_processes = cm.cluster.setup_cluster(
    backend='local', n_processes=None, single_thread=False)

# Run motion correction on the input movie using the 'motion' parameter group.
mc = MotionCorrect(reg_path, dview=dview, **opts.get_group('motion'))

mc.motion_correct(save_movie=True)

# Load the corrected movie through mc.mmap_file, the same attribute the
# memory-mapping step reads, so this also works when pw_rigid is switched off
# (fname_tot_els is only produced by the piecewise-rigid path).
m_els = cm.load(mc.mmap_file)

# With border_nan == 'copy' no border has to be excluded; otherwise use the
# border width reported by the motion correction object.
border_to_0 = 0 if mc.border_nan == 'copy' else mc.border_to_0

In [None]:
# Save the motion-corrected movie as a memory-mapped file in 'C' order,
# excluding the borders given by border_to_0.
reg_mc = cm.save_memmap(
    mc.mmap_file,
    base_name=f'{samp_name}_{samp_type}_',
    order='C',
    border_to_0=border_to_0,
)

# Reload the memory map and reshape it into a (T, *dims) image stack.
Yr, dims, T = cm.load_memmap(reg_mc)
reg_images = np.reshape(Yr.T, (T, *dims), order='F')

# Local correlation image; NaN entries are replaced by zero before plotting.
Cn = cm.local_correlations(reg_images, swap_dim=False)
Cn[np.isnan(Cn)] = 0

plt.figure(figsize=(8, 8))
plt.imshow(Cn, cmap='magma')
plt.show()

# Optionally write the corrected stack next to the input movie.
save_reg_img = True
if save_reg_img:
    io.imsave(f'{samp_path}/{samp_name}_{samp_type}_mov_cor.tif', reg_images)

In [None]:
#%% Shut down the cluster, then delete the log files left in the working directory
cm.stop_server(dview=dview)

# Every entry in the current directory whose name contains '_LOG_' is removed.
log_files = glob.glob('*_LOG_*')
for stale_log in log_files:
    os.remove(stale_log)