# Online processing of volumetric data
This is a simple demo on simulated toy 3D data that performs motion correction, source extraction and deconvolution, comparing CaImAn batch with CaImAn online (OnACID).

In [None]:
# Enable IPython's autoreload extension so edited modules are re-imported
# automatically. `get_ipython` is only injected into the namespace by
# IPython/Jupyter, so a NameError simply means we run as a plain script.
try:
    get_ipython().run_line_magic('load_ext', 'autoreload')
    # run_line_magic takes the magic's argument line as a string
    get_ipython().run_line_magic('autoreload', '2')
except NameError:
    # Not running under IPython: autoreload is irrelevant here.
    pass
except Exception as err:
    # Autoreload is a convenience only; report the problem but keep going.
    print(f'Could not enable autoreload: {err}')

import logging
import matplotlib.pyplot as plt
import numpy as np
import os
from scipy.stats.qmc import Halton

import caiman as cm
from caiman.utils.visualization import nb_view_patches3d
import caiman.source_extraction.cnmf as cnmf
from caiman.source_extraction.cnmf.utilities import gaussian_filter

import bokeh.plotting as bpl
bpl.output_notebook()

# Logging configuration for the 'caiman' logger.
# Point `logfile` at a path to write the log to a file instead of the stream.
logfile = None
logger = logging.getLogger('caiman')
# WARNING keeps the notebook quiet; switch to logging.INFO for (much) more output.
logger.setLevel(logging.WARNING)
logfmt = logging.Formatter('%(relativeCreated)12d [%(filename)s:%(funcName)20s():%(lineno)s] [%(process)d] %(message)s')
# Choose the destination once, then attach the shared formatter to it.
handler = logging.StreamHandler() if logfile is None else logging.FileHandler(logfile)
handler.setFormatter(logfmt)
logger.addHandler(handler)


Define a function to create some toy data

In [None]:
def gen_data(p=1, noise=.05, T=500, framerate=30, firerate=2., motion=True, init_batch=200):
    """Simulate a toy 3D calcium-imaging movie together with its ground truth.

    N = 20 neurons are placed at Halton-sequence positions inside a
    (70, 50, 10) volume and given Gaussian-filtered footprints. Random spike
    trains are turned into calcium traces by an AR(p) recursion, a constant
    background and white Gaussian noise are added and, if requested, every
    frame is displaced by a smooth random rigid shift.

    Note: the global NumPy random state is reset with ``np.random.seed(42)``,
    so repeated calls return the same data.

    Args:
        p: order of the autoregressive model; must be 1 or 2.
        noise: standard deviation of the additive Gaussian noise.
        T: number of frames. With ``motion=True`` the shift traces are built
            from ``T - 10`` random samples, so T must be larger than 10.
        framerate: frame rate; ``firerate / framerate`` is the per-frame
            spike probability.
        firerate: spike rate, in the same time unit as ``framerate``.
        motion: whether to apply simulated rigid motion to the frames.
        init_batch: number of leading frames during which the second half of
            the neurons is kept silent.

    Returns:
        Tuple ``(Y, C, S, A, b, f, centers, dims, shifts)``:
            Y: float32 movie of shape ``(T,) + dims``.
            C: float32 calcium traces, shape ``(N, T)``.
            S: boolean spike trains, shape ``(N, T)``.
            A: footprints flattened in Fortran order, shape ``(prod(dims), N)``.
            b: flattened (Fortran order) constant spatial background.
            f: temporal background, all ones, length ``T``.
            centers: integer neuron centers, shape ``(N, 3)``.
            dims: the volume size ``(70, 50, 10)``.
            shifts: array of shape ``(T, 3)``; the frames are displaced by
                ``-shifts``. ``None`` when ``motion`` is False.

    Raises:
        ValueError: if ``p`` is neither 1 nor 2.
    """
    if p == 2:
        gamma = np.array([1.5, -.55])
    elif p == 1:
        gamma = np.array([.9])
    else:
        raise ValueError(f'p must be 1 or 2, got {p!r}')
    dims = (70, 50, 10)  # size of image
    sig = (4, 4, 2)      # neurons size
    bkgrd = 1.           # background magnitude
    N = 20               # number of neurons
    np.random.seed(42)
    # Quasi-random centers that stay at least `sig` away from every border
    centers = np.round(np.array(sig) + (np.array(dims)-2*np.array(sig)) * 
                       Halton(d=3, scramble=False).random(n=N)).astype(int)
    
    # Bernoulli spikes with probability firerate/framerate per frame
    S = np.random.rand(N, T) < firerate / float(framerate)
    S[:, 0] = 0
    S[N//2:,:init_batch] = 0 # half of the neurons aren't active in the initial batch
    # Run the AR(p) recursion in place to turn spikes into calcium traces
    C = S.astype(np.float32)
    for i in range(2, T):
        if p == 2:
            C[:, i] += gamma[0] * C[:, i - 1] + gamma[1] * C[:, i - 2]
        else:
            C[:, i] += gamma[0] * C[:, i - 1]
            
    if motion:
        sig_m = np.array(sig)
        # One trace per axis: white noise smoothed by an 11-tap moving average
        # scaled by that axis' neuron size; 'full' convolution yields T samples.
        shifts = -np.transpose([np.convolve(np.random.randn(T-10), np.ones(11)/11*s) for s in sig_m])
    else:
        sig_m = np.zeros(3, dtype=int)
        shifts = None
        
    # Footprints live on a grid padded by 2*sig_m per side so that shifted
    # frames can be cropped back to `dims` afterwards.
    A = np.zeros(tuple(np.array(dims) + sig_m * 4) + (N,), dtype='float32')
    for i in range(N):
        A[tuple(centers[i] + sig_m*2) + (i,)] = 1.
    # Blur each unit impulse spatially; sigma 0 on the last axis keeps neurons separate
    A = gaussian_filter(A, sig + (0,), truncate=1.5)
    # Scale every footprint to unit L2 norm (on the padded grid)
    A /= np.sqrt(np.sum(np.sum(np.sum(A**2,0),0),0))  
    f = np.ones(T, dtype='float32')
    b = bkgrd * np.ones(A.shape[:-1], dtype='float32')  

    # Movie as (pixels x frames) matrix: background + neurons + noise
    Yr = np.outer(b.reshape(-1, order='F'), f) + A.reshape((-1, N), order='F').dot(C)
    Yr += noise * np.random.randn(*Yr.shape)
    Y = Yr.T.reshape((-1,) + tuple(np.array(dims) + sig_m * 4), order='F').astype(np.float32)
    if motion:
        # Displace each frame by -shifts, then crop the padding off again
        Y = np.array([cm.motion_correction.apply_shifts_dft(img, (sh[0], sh[1], sh[2]), 0,
                                                            is_freq=False, border_nan='copy')
                           for img, sh in zip(Y, -shifts)])
        Y = Y[:, 2*sig_m[0]:-2*sig_m[0], 2*sig_m[1]:-2*sig_m[1], 2*sig_m[2]:-2*sig_m[2]]
        A = A[2*sig_m[0]:-2*sig_m[0], 2*sig_m[1]:-2*sig_m[1], 2*sig_m[2]:-2*sig_m[2]]
        b = b[2*sig_m[0]:-2*sig_m[0], 2*sig_m[1]:-2*sig_m[1], 2*sig_m[2]:-2*sig_m[2]]
    return Y, C, S, A.reshape((-1, N), order='F'), b.reshape(-1, order='F'), f, centers, dims, shifts

### Select file(s) to be processed
Create a file with a toy 3d dataset.

In [None]:
# Simulate the toy dataset with default settings; the ground truth stays in memory
Y, C, S, A, b, f, centers, dims, shifts = gen_data()
N, T = C.shape  # rows and columns of the ground-truth trace matrix

# Write the simulated movie into CaImAn's example_movies folder
fname = os.path.join(cm.paths.caiman_datadir(), 'example_movies', 'demoMovie3D.nwb')
cm.movie(Y).save(fname)
print(fname)

In [None]:
# Sanity check of the simulated ground truth: how much do the footprints
# overlap spatially and how correlated are the calcium traces?
gram = A.T.dot(A)          # pairwise inner products of the footprints
corr_C = np.corrcoef(C)    # pairwise correlation of the traces

plt.figure(figsize=(9,3))

plt.subplot(121)
plt.colorbar(plt.imshow(gram))
plt.title('overlap of A')

plt.subplot(122)
plt.colorbar(plt.imshow(corr_C))
plt.title('correlation of C')

# Cell output: largest entry of each matrix after subtracting the identity
np.max(gram - np.eye(N)), np.max(corr_C - np.eye(N))

### Inspect the data
First, view a max-projection of the correlation image

In [None]:
# Reload the movie from disk and compute its local correlation image
Y = cm.load(fname)
Cn = cm.local_correlations(Y, swap_dim=False)

# Lay the three max-projections out on one canvas whose size follows the
# volume dimensions (with 20% spare room for labels).
d1, d2, d3 = dims
x = int(1.2 * (d1 + d3))
y = int(1.2 * (d2 + d3))
scale = 6/x
fig = plt.figure(figsize=(scale*x, scale*y))

# One entry per panel: axes rectangle, image, title, x-label, y-label
panels = (
    ([1-d1/x, 1-d2/y, d1/x, d2/y], Cn.max(2).T, 'Max.proj. z', 'x', 'y'),
    ([0, 1-d2/y, d3/x, d2/y], Cn.max(0), 'Max.proj. x', 'z', 'y'),
    ([1-d1/x, 0, d1/x, d3/y], Cn.max(1).T, 'Max.proj. y', 'x', 'z'),
)
axes = []
for rect, image, title, xlabel, ylabel in panels:
    # add_axes makes the new axes current, so the pyplot calls below target it
    axes.append(fig.add_axes(rect))
    plt.imshow(image, cmap='gray')
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
axz, axy, axx = axes
plt.show()

## Play the movie (optional). 
This requires loading the movie into memory, which in general is not needed by the pipeline. Displaying the movie uses the OpenCV library. Press `q` to close the video panel.

In [None]:
# Play only the slice at index 5 of the last axis, at 2x magnification
# (presumably one z-plane of the volume, given dims = (d1, d2, d3) — confirm)
Y[...,5].play(magnification=2)

## Set parameters

In [None]:
# Settings shared by the batch and the online run; the keys are CNMFParams names.
params_dict = dict(
    fnames=fname,                  # filename(s) to be processed
    fr=30,                         # frame rate (Hz)
    K=N,                           # (upper bound on) number of components
    is3D=True,                     # flag for volumetric data
    decay_time=1,                  # length of typical transient in seconds
    gSig=(4, 4, 2),                # gaussian width of a 3D gaussian kernel, which approximates a neuron
    p=1,                           # order of the autoregressive system
    nb=1,                          # number of background components
    only_init=False,               # whether to run only the initialization
    normalize_init=False,          # whether to equalize the movies during initialization
    motion_correct=True,           # flag for performing motion correction
    max_shifts=(4, 4, 2),          # maximum allowed rigid shifts (in pixels)
    nonneg_movie=False,            # flag for producing a non-negative movie
    init_batch=200,                # length of mini batch for initialization
    init_method='cnmf',            # initialization method for initial batch
    batch_update_suff_stat=True,   # flag for updating sufficient statistics (used for updating shapes)
    thresh_overlap=0,              # space overlap threshold for screening new components
)
opts = cnmf.params.CNMFParams(params_dict=params_dict)

## Run batch version for comparison

In [None]:
%%capture
# The cell magic above must stay on the first line of the cell.
#%% start a cluster for parallel processing (if a cluster already exists it will be closed and a new session will be opened)
# 'dview' only exists in the namespace if this cell has been run before
if 'dview' in locals():
    cm.stop_server(dview=dview)
# n_processes=None leaves the number of workers to setup_cluster
c, dview, n_processes = cm.cluster.setup_cluster(
    backend='multiprocessing', n_processes=None, single_thread=False)

In [None]:
# %% fit with batch object
cnmB = cnmf.CNMF(n_processes=n_processes, params=opts, dview=dview)
cnmB.fit_file(motion_correct=True)

In [None]:
# STOP CLUSTER
cm.stop_server(dview=dview)

### View the results
View components per plane

In [None]:
# Interactive viewer of the batch components on a max-projection taken along axis 2
cnmB.estimates.nb_view_components_3d(image_type='max', dims=dims, axis=2, cmap='viridis');

### Compare with ground truth

In [None]:
def plot_A(cnm):
    """Compare the inferred spatial footprints of `cnm` with the ground truth.

    Draws max-projections of the inferred footprints, the true footprints and
    the movie side by side, then shows the matrix of inner products between
    matched inferred and true footprints and prints min/mean/max of its
    diagonal.

    Relies on the notebook globals `A`, `N`, `dims` and `Y` created when the
    toy data was generated.

    Args:
        cnm: fitted CNMF/OnACID object; `cnm.estimates.A` may be a sparse
            matrix or a dense array of shape (pixels, components).
    """
    # Densify the footprints once so that sparse and dense inputs are
    # handled the same way everywhere below.
    A_est = cnm.estimates.A
    A_est = A_est.toarray() if hasattr(A_est, 'toarray') else np.asarray(A_est)

    # For every true neuron pick the inferred component it correlates with most
    order = list(map(np.argmax, np.corrcoef(A.T, A_est.T)[:N,N:]))
    plt.subplot(131)
    plt.imshow(A_est.T.reshape((-1,)+dims, order='F').max(0).max(-1))
    plt.title('inferred A')
    plt.subplot(132)
    plt.imshow(A.T.reshape((-1,)+dims, order='F').max(0).max(-1))
    plt.title('true A')
    plt.subplot(133)
    plt.imshow(Y.max(0).max(-1))
    plt.title('max Y projection');

    plt.figure(figsize=(5,3))
    # Inner products between matched inferred footprints (rows) and true ones (columns)
    overlap = A_est.T[order].dot(A)
    plt.colorbar(plt.imshow(overlap))
    plt.title('overlap')
    plt.show()
    overlap = overlap.diagonal()
    print(f'Overlap of neural shapes   Min: {overlap.min():.4f},  Mean: {overlap.mean():.4f},  Max: {overlap.max():.4f}')
    
plot_A(cnmB)

In [None]:
def plot_C(cnm):
    """Compare the inferred calcium traces of `cnm` with the ground truth.

    Plots inferred (matched to the true neurons) and true traces, shows the
    correlation matrix between them and prints min/mean/max of its diagonal.
    A warning is issued when two true neurons are matched to the same
    inferred component, because the diagonal statistics are then misleading.

    Relies on the notebook globals `C` and `N` created when the toy data was
    generated.

    Args:
        cnm: fitted CNMF/OnACID object providing `cnm.estimates.C`
            (components x frames).
    """
    import warnings

    # For every true neuron pick the inferred trace it correlates with most
    order = list(map(np.argmax, np.corrcoef(C, cnm.estimates.C)[:N,N:]))
    # A set drops duplicates, so a shorter set means a non-unique matching
    if len(order) != len(set(order)):
        warnings.warn('Matching between true and inferred components is not unique: '
                      f'{len(order) - len(set(order))} inferred component(s) are assigned '
                      'to more than one true neuron.')

    plt.figure(figsize=(12,5))
    plt.subplot(211)
    plt.plot(cnm.estimates.C[order].T)
    plt.title('inferred C')
    plt.subplot(212)
    plt.plot(C.T)
    plt.title('true C')

    plt.figure(figsize=(5,3))
    # Block of the joint correlation matrix relating true (rows) to matched inferred (columns) traces
    corr = np.corrcoef(C, cnm.estimates.C[order])[:N,N:]
    plt.colorbar(plt.imshow(corr))
    plt.title('correlation')
    plt.show()
    corr = corr.diagonal()
    print(f'Correlation of (denoised) fluor. C   Min: {corr.min():.4f},  Mean: {corr.mean():.4f},  Max: {corr.max():.4f}')

plot_C(cnmB)

In [None]:
def plot_shifts(cnm):
    """Plot the shifts estimated by `cnm` above the simulated true shifts.

    Also prints, per axis, the correlation between true and estimated shifts.
    Relies on the notebook globals `shifts` and `T` created when the toy data
    was generated.

    Args:
        cnm: fitted CNMF/OnACID object providing `cnm.estimates.shifts` and
            `cnm.params.motion['pw_rigid']`.
    """
    plt.figure(figsize=(12,5))

    # Top panel: estimated shifts
    plt.subplot(211)
    if cnm.params.motion['pw_rigid']:
        raw = cnm.estimates.shifts
        # Bring the estimates into frame-first order
        # NOTE(review): the exact layout of piecewise-rigid shifts is not
        # visible here; the (1,2,0) transpose is kept as in the original.
        est_shifts = np.array(raw) if len(raw) == T else np.transpose(raw, (1,2,0))
        plt.plot(est_shifts[:,0])
        corr = np.corrcoef(np.transpose(shifts), est_shifts.T[:,0])
    else:
        plt.plot(cnm.estimates.shifts)
        corr = np.corrcoef(np.transpose(shifts), np.transpose(cnm.estimates.shifts))
    # Off-diagonal 3x3 block pairs each true axis with its estimated counterpart
    print('Correlation with true shifts  ', corr[:3,3:].diagonal())
    plt.title('inferred shifts')
    plt.ylabel('pixels')

    # Bottom panel: ground-truth shifts, one line per axis
    plt.subplot(212)
    for k, axis_name in enumerate(('x','y','z')):
        plt.plot(np.array(shifts)[:,k], label=axis_name)
    plt.legend()
    plt.title('true shifts')
    plt.xlabel('frames')
    plt.ylabel('pixels')
    
plot_shifts(cnmB)

## Run online version

In [None]:
# only half of the neurons are active in the initial batch
# The simulation keeps half of the neurons silent during the initial batch,
# so initialize with only N//2 components
params_dict['K'] = N//2

In [None]:
# %% fit with online object
# Rebuild the parameter object so that the changed 'K' takes effect
opts = cnmf.params.CNMFParams(params_dict=params_dict)
# Create the online (OnACID) object and run it on the file(s) named in opts
cnmO = cnmf.online_cnmf.OnACID(params=opts)
cnmO.fit_online();

### View the results
View components per plane

In [None]:
# Interactive viewer of the online components on a max-projection taken along axis 2
cnmO.estimates.nb_view_components_3d(image_type='max', dims=dims, axis=2, cmap='viridis');

### Compare with ground truth

In [None]:
# Spatial footprints: online result vs. ground truth
plot_A(cnmO)

In [None]:
# Calcium traces: online result vs. ground truth
plot_C(cnmO)

In [None]:
# Motion shifts: online estimate vs. ground truth
plot_shifts(cnmO)