# Mesmerize -- Holographic stimulation
### Adapted from Kushal lab "Mesmerize-core"
#### Updated March 16, 2023 - K.steinke


## Activate the virtual environment for mesmerize-core.
### For instructions on how to set this up, see: https://github.com/EricThomson/CCN_caiman_mesmerize_workshop_2023

n/
n/

## Notebook set up

In [None]:
# libraries to load:
from mesmerize_core import *
import numpy as np
from copy import deepcopy
import pandas as pd

from fastplotlib import ImageWidget, Plot, GridPlot   # ImageWidget, Plot and GridPlot are all used for the visualizations below
from ipywidgets import VBox, IntSlider, Layout

pd.set_option("display.max_colwidth", 120)   # display option: show up to 120 characters per DataFrame cell

## Set the paths

In [None]:
# It is important that you set the parent path: it must correspond to your caiman_data folder.
# NOTE: this is a user-specific absolute path -- change it to your own caiman_data location.
# batch_path and movie_path below are both built from this directory.

set_parent_raw_data_path("/home/steinkki/caiman_data/")  # top level raw data directory

batch_path = get_parent_raw_data_path().joinpath("mesmerize-batch/batch.pickle")  # this is where the caiman outputs will be organized

movie_path = get_parent_raw_data_path().joinpath("example_movies/Sue_2x_3000_40_-46.tif") # this is the movie to analyze
# An input movie must be located somewhere inside the raw data path or the batch path!

## Create a new batch

In [None]:
# This creates a new pandas DataFrame with the columns that are necessary for mesmerize.
# In mesmerize this is called the batch DataFrame.
# You can add additional columns relevant to your experiment, but do not modify columns used by mesmerize.

# create a new batch, stored at batch_path
df = create_batch(batch_path)


# to load existing batches use `load_batch()`
# NOTE(review): re-running create_batch() when a batch file already exists at batch_path presumably
# fails -- use load_batch() instead in that case (confirm against the mesmerize-core docs).
# df = load_batch(batch_path)

## Motion Correction Parameters

In [None]:

# The triple-quoted string below is an unused string literal: the code inside it does NOT run.
# It is kept as a reference copy of the demo's two parameter sets (mcorr_params1 / mcorr_params2).
'''
# This is set up with only one set of parameters currently: 

mcorr_params1 =\
{
  'main': # this key is necessary for specifying that these are the "main" params for the algorithm
    {
        'max_shifts': [24, 24],
        'strides': [48, 48],
        'overlaps': [24, 24],
        'max_deviation_rigid': 3,
        'border_nan': 'copy',
        'pw_rigid': True,
        'gSig_filt': None
    },
}

# Another set of params, useful for gridsearches for example:
mcorr_params2 =\
{
  'main':
    {
        'max_shifts': [24, 24],
        'strides': [24, 24],
        'overlaps': [12, 12],
        'max_deviation_rigid': 3,
        'border_nan': 'copy',
        'pw_rigid': True,
        'gSig_filt': None
    },
}

'''
# Base motion-correction parameters. The next cell uses a for loop that deep-copies this dict
# and varies "max_shifts", so several parameter settings can be compared.

# (demo version: copy the mcorr_params2 dict to make some changes)
#new_params = deepcopy(mcorr_params2)

new_params=\
{
  'main':  # the "main" params for the mcorr algorithm
    {
        'max_shifts': [24, 24],
        'strides': [24, 24],
        'overlaps': [12, 12],
        'max_deviation_rigid': 3,
        'border_nan': 'copy',
        'pw_rigid': True,
        'gSig_filt': None
    },
}



## Add a Batch Item:

This is a combination of:
1. algorithm to run, algo
2. input movie to run the algorithm on, input_movie_path
3. parameters for the specified algorithm, params
4. a name for you to keep track of things, usually the same as the movie filename, item_name

In [None]:
# The string literal below is unused reference code from the mesmerize-core demo. It must stay
# disabled: mcorr_params1 / mcorr_params2 only exist inside the string in the previous cell.
'''
#%% FROM THE DEMO: %%#

# add an item to the batch
df.caiman.add_item(
    algo='mcorr',  #algorhythm to use
    input_movie_path=movie_path,  #path to the movie to analyze
    params=mcorr_params1,  #parameters to use
    item_name=movie_path.stem,  # filename of the movie, but can be anything
)

df  #view the dataframe

# add other param variant to the batch
df.caiman.add_item(
  algo='mcorr',
  item_name=movie_path.stem,
  input_movie_path=movie_path,
  params=mcorr_params2
)

df
'''
# Try a few different parameters, since we don't know in advance which to use.

# Use a for loop to add one mcorr batch item per max_shifts value.
for shifts in [1, 6, 12]: # some variants of max_shifts
    # deep-copy before editing so the params dict already passed to add_item() is not mutated
    new_params = deepcopy(new_params)
    new_params["main"]["max_shifts"] = (shifts, shifts) # assign the "max_shifts"
    
    df.caiman.add_item(
      algo='mcorr',  # algorithm to use
      item_name=movie_path.stem, # set the item name
      input_movie_path=movie_path, # give the path to the movie
      params=new_params  # parameters to use
    )

df # view the data frame

In [None]:
# batches are indexable: grab the first batch item so it can be inspected / run below
row = df.iloc[0]

# to see the docstring, look at the documentation for how to do this:
# https://mesmerize-core.readthedocs.io/en/latest/api/common.html#mesmerize_core.CaimanSeriesExtensions

# Look at the parameters that differ between the mcorr items sharing this item name.
# This must be the last expression of the cell, otherwise Jupyter does not display the table.
diffs = df.caiman.get_params_diffs(algo="mcorr", item_name=row["item_name"])
diffs  # show the parameters

## Run batches

In [None]:
'''
#%% FROM THE DEMO: %%#

# run the first "batch item" (will only run one).
# this will run in a subprocess by default on Linux & Mac
# on windows it will run locally
process = row.caiman.run()

# reload dataframe from disk when done
df = df.caiman.reload_from_disk()

'''
# Run every batch item that has not been run yet, one item at a time.

for i, row in df.iterrows():
    if row["outputs"] is not None:  # already run (e.g. when re-executing this cell): skip it
        continue

    process = row.caiman.run()
    # On Linux & Mac run() launches a subprocess and may return before the item has finished.
    # Wait for it, so items run one after another and the reload/success check below
    # sees the outputs.
    process.wait()

    # on Windows you MUST reload the batch dataframe after every iteration because it uses the `local` backend.
    # this is unnecessary on Linux & Mac
    # "DummyProcess" is used for local backend so this is automatic
    if process.__class__.__name__ == "DummyProcess":
        df = df.caiman.reload_from_disk()

df = df.caiman.reload_from_disk()  # ALWAYS RELOAD after running a loop!!

# Check that the algorithm ran successfully for EVERY item (True = success), not just the first one.
print("Did the algorithm run successfully?")
df["outputs"].apply(lambda x: x["success"])

## Visualization with Fastplotlib

In [None]:
# get the movie and mcorr so you can look at them
# Note: tiff input files returns it as a memmaped array (if possible) with lazy loading
# Note: will try to use a mesmerize LazyArray if the file cannot be memmaped

index = 0  # you can change the index to look at the mcorr results of different batch items

input_movie = df.iloc[index].caiman.get_input_movie()  # get input movie as memmap

mcorr_movie = df.iloc[index].mcorr.get_output()  # load mcorr output movie, also as a memmaped array

# Visualize the raw movie and the mcorr movie side by side.
# window_funcs does frame averaging on the "t" (time) dimension, using the mean of 17 frames;
# this helps with visually inspecting motion.
mcorr_iw = ImageWidget(
    data=[input_movie, mcorr_movie],  # two movies defined from above
    window_funcs={"t": (np.mean, 17)},
    vmin_vmax_sliders=True,
    cmap="gnuplot2"
)

# show() must be the last expression of the cell, otherwise Jupyter does not display the widget
mcorr_iw.show()  # look at the two movies side by side

In [None]:
mcorr_iw.plot.canvas.close()  # close the widget's canvas to save processing power

In [None]:

## With ImageWidget you can view the raw movie and every mcorr result simultaneously! ##
# How smoothly this plays depends on the hard drive capabilities

#-----------------------------------------------

movies = [df.iloc[0].caiman.get_input_movie()]   # first item is just the raw movie
subplot_names = ["raw"]   # subplot titles
means = [df.iloc[0].caiman.get_projection("mean")]  # we will use the mean images later

# add all the mcorr outputs to the list (one subplot per batch item)
for i, row in df.iterrows():
    movies.append(row.mcorr.get_output())  # add to the list of movies to plot
    subplot_names.append(f"ix: {i}")  # subplot title to show dataframe index
    means.append(row.caiman.get_projection("mean"))  # mean images which we'll use later

#-----------------------------------------------

# To look at the different parameters. This is printed because Jupyter only displays
# the last expression of a cell, and that must be the widget below.
print(df.caiman.get_params_diffs(algo="mcorr", item_name=df.iloc[0]["item_name"]))

#-----------------------------------------------

# create the widget
mcorr_iw_multiple = ImageWidget(
    data=movies,  # list of movies
    window_funcs={"t": (np.mean, 17)}, # window_funcs is also a kwarg
    vmin_vmax_sliders=True,
    names=subplot_names,  # subplot names used for titles
    cmap="gnuplot2"
)

# if you want to modify the windows later:
#mcorr_iw_multiple.window_funcs["t"].window_size = 5

# show() is the last expression so the widget is actually displayed
mcorr_iw_multiple.show()


In [None]:
# You can use frame_apply to look for motion aberrations:
# subtract each movie's mean image from every frame that movie displays.
# The dict is built from `means`, so it has exactly one entry per movie in the widget, however many
# mcorr items there are. Hard-coding keys 0..5 would index past the end of `means` for smaller batches.
# `m=m` binds each mean image when the lambda is created, which avoids the late-binding closure pitfall.
subtract_means = {i: (lambda frame, m=m: frame - m) for i, m in enumerate(means)}
mcorr_iw_multiple.frame_apply = subtract_means  # apply the functions; each returns a 2D frame

#-----------------------------------------------
# Different colormaps can help with seeing the motion:
for sp in mcorr_iw_multiple.plot:
    sp.graphics[0].cmap = "jet"

In [None]:
#-----------------------------------------------
# Remove the frame_apply: assigning an empty dict clears the per-movie mean subtraction set above.
mcorr_iw_multiple.frame_apply = dict()

## Dataframe clean up

In [None]:
# make a list of rows we want to keep using the uuids
rows_keep = [df.iloc[3].uuid]   # change to the index number that you want to keep
rows_keep

# You can remove batch items (i.e. rows) using df.caiman.remove_item(<item_uuid>)


# WINDOWS NOTE: calling remove_item() will raise a PermissionError if you have the memmap file open.
# WORK AROUND: Kill the kernel.

# Remove every row that is not in rows_keep. Loop over a snapshot (list) of the uuids rather than
# df.iterrows(), because remove_item() removes rows from df while the loop is running.
for item_uuid in list(df["uuid"]):
    if item_uuid not in rows_keep:
        df.caiman.remove_item(item_uuid)

df

# CNMF

In [None]:
# Enter parameters into main (CaImAn CNMF parameters -- see the CaImAn CNMFParams docs for their meaning):

params_cnmf =\
{
    'main': # indicates that these are the "main" params for the CNMF algo
        {
            'fr': 30, # framerate, very important!
            'p': 1,
            'nb': 2,
            'merge_thr': 0.85,
            'rf': 15,
            'stride': 6, # "stride" for cnmf, "strides" for mcorr
            'K': 4,
            'gSig': [4, 4],
            'ssub': 1,
            'tsub': 1,
            'method_init': 'greedy_roi',
            'min_SNR': 2.0,
            'rval_thr': 0.7,
            'use_cnn': True,
            'min_cnn_thr': 0.8,
            'cnn_lowest': 0.1,
            'decay_time': 0.4,
        },
    'refit': True, # If `True`, run a second iteration of CNMF
}


# Add a batch item like before:
df.caiman.add_item(
    algo='cnmf', # algo is cnmf
    input_movie_path=df.iloc[0],  # use mcorr output from a completed batch item (this is for one only!)
    params=params_cnmf,   # parameters from above
    item_name=df.iloc[0]["item_name"], # use the same item name
)

## perform a parameter search to find the parameters to use for the dataset! ##

from itertools import product

# variants of several parameters
gSig_variants = [6, 8]
K_variants = [4, 8]
merge_thr_variants = [0.8, 0.95]

new_params_cnmf = deepcopy(params_cnmf)  # always use deepcopy like before
parameter_grid = product(gSig_variants, K_variants, merge_thr_variants)  # every combination: 2 * 2 * 2 = 8

# a single for loop to go through all the various parameter combinations
for gSig, K, merge_thr in parameter_grid:
    # copy again so dicts already handed to add_item() are never mutated
    new_params_cnmf = deepcopy(new_params_cnmf)
    
    new_params_cnmf["main"]["gSig"] = [gSig, gSig]
    new_params_cnmf["main"]["K"] = K
    new_params_cnmf["main"]["merge_thr"] = merge_thr
    
    # add param combination variant to batch
    df.caiman.add_item(
        algo="cnmf",
        item_name=df.iloc[0]["item_name"],
        input_movie_path=df.iloc[0],
        params=new_params_cnmf
    )
    
    
df  # if you want to view the generated batch items (only displayed when it is the last line of a cell)
df.caiman.get_params_diffs(algo="cnmf", item_name=df.iloc[1]["item_name"])   # just look at the unique diffs


## Run the batch items through CNMF

In [None]:
# if you want to filter the rows: boolean masks combined with & (each comparison needs parentheses)
df[
    (df["algo"] == "cnmf") &  # algo
    (df["item_name"] == df.iloc[0]["item_name"])  # item name
]

In [None]:
# Run only the CNMF items for this item name, one at a time, skipping items that have already run.
cnmf_item_name = df.iloc[0]["item_name"]

for i, row in df[
    (df["algo"] == "cnmf") &
    (df["item_name"] == cnmf_item_name)
].iterrows():
    if row["outputs"] is not None:  # already run (e.g. when re-executing this cell): skip it
        continue

    process = row.caiman.run()
    # On Linux & Mac run() launches a subprocess and may return before the item has finished.
    # Wait for it, so the success check below sees the outputs.
    process.wait()

    # on Windows you MUST reload the batch dataframe after every iteration because it uses the `local` backend.
    # this is unnecessary on Linux & Mac
    # "DummyProcess" is used for local backend so this is automatic
    if process.__class__.__name__ == "DummyProcess":
        df = df.caiman.reload_from_disk()

# reload to pick up the CNMF outputs written by the runs above
df = df.caiman.reload_from_disk()

# see which CNMF batch items completed successfully (True = success)
df[df["algo"] == "cnmf"]["outputs"].apply(lambda x: x["success"])

# CNMF Visualization

## WARNING: Still in development, use with caution! 

In [None]:
# Look at a projection of the first batch item. To look at a different projection, swap "mean"
# for another projection type (e.g. "max" or "std" -- check the mesmerize-core docs).
# get_projection() and the plotting call are function calls, so they need parentheses.
# fastplotlib (already used throughout this notebook) is used for the display.
projection_plot = Plot()
projection_plot.add_image(df.iloc[0].caiman.get_projection("mean"), cmap="gnuplot2")
projection_plot.show()

## Load outputs
CNMF pandas extensions API: https://mesmerize-core.readthedocs.io/en/latest/api/cnmf.html

In [None]:
# Change this to plot the outputs for different batch items.
# NOTE(review): index 1 is presumably the first CNMF item, assuming the clean-up left a single mcorr
# item at index 0 -- adjust if your batch differs.
index = 1

cnmf_movie = df.iloc[index].caiman.get_input_movie()  # get the motion corrected input movie as a memmap
contours, coms = df.iloc[index].cnmf.get_contours("all", swap_dim=False) # contours (and centers of mass) of the spatial components
# get_contours() accepts other arguments too -- read the documentation.

temporal = df.iloc[index].cnmf.get_temporal("all")  # and temporal components

# indices of good / bad components, used to plot them in different colors
ixs_good = df.iloc[index].cnmf.get_good_components()   #for plotting these in different colors 
ixs_bad = df.iloc[index].cnmf.get_bad_components()

## Visualization

In [None]:
from fastplotlib.graphics.line_slider import LineSlider # LineSlider is very new and experimental

# for the image data and contours
iw_cnmf = ImageWidget(cnmf_movie, vmin_vmax_sliders=True, cmap="gnuplot2")

# add the contours to the plot within the widget: good components cyan, bad components magenta
contours_graphic = iw_cnmf.plot.add_line_collection(contours, colors="cyan", name="contours")
contours_graphic[ixs_good].colors = "cyan"
contours_graphic[ixs_bad].colors = "magenta"


# temporal plot
plot_temporal = Plot()

temporal_graphic = plot_temporal.add_line_collection(temporal, colors="cyan", name="temporal")
temporal_graphic[ixs_good].colors = "cyan"
temporal_graphic[ixs_bad].colors = "magenta"

# a vertical line that is synchronized to the image widget "t" (timepoint) slider
_ls = LineSlider(x_pos=0, bounds=(temporal.min(), temporal.max()), slider=iw_cnmf.sliders["t"])
plot_temporal.add_graphic(_ls)

# Call show() on both first: auto-scaling only works after show() has been called.
temporal_canvas = plot_temporal.show()
iw_cnmf_canvas = iw_cnmf.show()

# Auto-scale temporal plot
plot_temporal.auto_scale()
plot_temporal.camera.scale.x = 0.85

# Stack them. This must be the last expression of the cell, otherwise Jupyter does not display it.
VBox([temporal_canvas, iw_cnmf_canvas])



In [None]:
# Click callback (used with Graphic.link below): recolor the contour closest to the clicked point.
def euclidean(source, target, event, new_data):
    """Map a click event to the nearest contour and set that contour's colors.

    Parameters
    ----------
    source
        Graphic that emitted the event (not used).
    target
        Line collection whose ``graphics`` are the contours to search.
    event
        Click event; ``event.pick_info["index"]`` is the clicked position.
    new_data
        Value handed to ``target._set_feature`` for the "colors" feature (e.g. "w").

    Returns
    -------
    None
    """
    # Center of mass of every contour, ignoring vertices that contain NaN.
    centers = []
    for contour in target.graphics:
        points = contour.data()
        valid_points = points[~np.isnan(points).any(axis=1)]
        centers.append(valid_points.mean(axis=0))

    # Pad the click position with a trailing 0 so it has as many coordinates as the contour
    # vertices (presumably (x, y) -> (x, y, z); TODO confirm).
    click = np.append(np.array(event.pick_info["index"]), [0])

    # Contour whose center is closest to the click (Euclidean distance).
    # argsort()[0] rather than argmin() so NaN centers are ranked last.
    distances = np.linalg.norm(np.asarray(centers) - click, axis=1)
    nearest = int(distances.argsort()[0])

    target._set_feature(feature="colors", new_data=new_data, indices=nearest)

    return None

# so we can view them one by one, first hide all temporal components
temporal_graphic[:].present = False

image_graphic = iw_cnmf.plot["image"]

# link image to contours: a click on the image calls euclidean(), which sets the nearest contour's colors to "w"
image_graphic.link(
    "click",
    target=contours_graphic,
    feature="colors", 
    new_data="w", 
    callback=euclidean
)

# link contour color changes (which are triggered by the click events as defined above) to everything else

# thickness of contour: a contour whose color changes is drawn thicker
contours_graphic.link("colors", target=contours_graphic, feature="thickness", new_data=5)

# toggle temporal component when contour changes color
contours_graphic.link("colors", target=temporal_graphic, feature="present", new_data=True)
# autoscale temporal plot to the current temporal component
temporal_graphic[:].present.add_event_handler(plot_temporal.auto_scale)

In [None]:
# close both canvases (helpful on a slow GPU)

plot_temporal.canvas.close()
iw_cnmf.plot.canvas.close()

## View the reconstructed movie, residuals, and reconstructed background
We can get each of these as a mesmerize LazyArray which allows fast visualization of larger-than-RAM arrays that can be computed on the fly.

In [None]:
# reconstructed movie, A * C (a mesmerize LazyArray, computed on the fly)
rcm = df.iloc[index].cnmf.get_rcm()
rcm  # NOTE: only shown when it is the last line of a cell

#rcm.shape   # behaves like a numpy array, easy to work with
# rcm.nbytes_gb  # show the size of the array
#rcm.max # Some lazy arrays contain pre-computed min and max for the array, and other useful properties

# visualize the max, min, mean and std images of the rcm
gp = GridPlot((2, 2), controllers="sync")    # make a 2 by 2 grid of subplots with synced controllers

for sp, img in zip(gp, [rcm.max_image, rcm.min_image, rcm.mean_image, rcm.std_image]):
    sp.add_image(img)
    
gp.show()

In [None]:
gp.canvas.close()  # close the projections grid to save GPU

# Visualize RCM, RCB and Residuals

In [None]:
rcb = df.iloc[index].cnmf.get_rcb()  # reconstructed background, b * f
residuals = df.iloc[index].cnmf.get_residuals()

iw_cnmf_grid = ImageWidget(
    data=[cnmf_movie, rcm, rcb, residuals],
    vmin_vmax_sliders=True,
    cmap="gnuplot2",
    names=["movie", "A * C", "b * f", "residuals"]
)

# overlay the contours on every subplot: good components cyan, bad components magenta
for subplot in iw_cnmf_grid.plot:
    _contours = subplot.add_line_collection(contours, thickness=1.0, name="contours")
    _contours[ixs_good].colors = "cyan"
    _contours[ixs_bad].colors = "magenta"

## Options for looking at the data (run these in a new cell once the widget is shown): ##
#
# # temporarily hide bad components
# for subplot in iw_cnmf_grid.plot:
#     subplot["contours"][ixs_bad].present = False
#
# # hide good components
# for subplot in iw_cnmf_grid.plot:
#     subplot["contours"][ixs_good].present = False
#
# # make everything un-hidden, indexing [:] means "everything"
# for subplot in iw_cnmf_grid.plot:
#     subplot["contours"][:].present = True

# show() must be the last expression of the cell. Any expression after it (such as a bare
# string literal) would be displayed instead of the widget.
iw_cnmf_grid.show()

In [None]:
iw_cnmf_grid.plot.canvas.close()   # close the canvas to save GPU

## Visualize movie, rcm, and stack of temporal components
This example shows only good components, but as shown before you can also compare between good and bad components if you want. You could also use a GridPlot or ImageWidget to view the contours on top of the residuals and reconstructed background to evaluate if CNMF captured everything.

In [None]:
# 1 row, 3 columns; the first 2 subplots share controller 0, so they pan/zoom together
cnmf_grid_more = GridPlot((1, 3), controllers=[[0, 0, 1]], names=[["movie", "rcm", "temporal"]])

# movie and rcm, rcm is a lazy array and behaves similar to numpy arrays
movie_graphic = cnmf_grid_more["movie"].add_image(cnmf_movie[0], cmap="gnuplot2")
rcm_graphic = cnmf_grid_more["rcm"].add_image(rcm[0], cmap="gnuplot2")

# contours for good components
contours_good, coms = df.iloc[index].cnmf.get_contours("good", swap_dim=False)

# random colors for contours and temporal components
# make an RGBA array for each color
rand_colors = np.random.rand(len(contours_good), 4)  # [n_contours, RGBA]
rand_colors[:, -1] = 1 # set alpha = 1

# get temporal of only good components
temporal_good = df.iloc[index].cnmf.get_temporal("good")

# add contours to both movie and rcm subplots
contours_movie = cnmf_grid_more["movie"].add_line_collection(contours_good, colors=rand_colors)
contours_rcm = cnmf_grid_more["rcm"].add_line_collection(contours_good, colors=rand_colors)

# line stack of temporal components
temporal_stack = cnmf_grid_more["temporal"].add_line_stack(temporal_good, colors=rand_colors, thickness=3.0, separate=15)

# plot single temporal, just like before
plot_temporal_single = Plot()
temporal_graphic = plot_temporal_single.add_line_collection(temporal_good, colors=rand_colors)

# since this is a GridPlot and not an ImageWidget we need to define sliders
slider = IntSlider(min=0, max=cnmf_movie.shape[0] - 1, value=0, step=1)

# Vertical line sliders. The bounds come from the *good* temporal components, because those are
# the traces plotted here. The stack's top bound adds the vertical offset of its last line.
_ls = LineSlider(x_pos=0, bounds=(temporal_good.min(), temporal_good.max()), slider=slider)
_ls2 = LineSlider(
    x_pos=0,
    bounds=(temporal_good.min(), temporal_good.max() + temporal_stack.graphics[-1].position.y),
    slider=slider,
)
plot_temporal_single.add_graphic(_ls)
cnmf_grid_more["temporal"].add_graphic(_ls2)


def update_frame(change):
    """Slider observer: show frame ``change["new"]`` of both the movie and the rcm."""
    ix = change["new"]
    movie_graphic.data = cnmf_movie[ix]
    rcm_graphic.data = rcm[ix]


slider.observe(update_frame, "value")


@plot_temporal_single.renderer.add_event_handler("resize")
def update_slider_width(*args):
    """Keep the slider as wide as the single-temporal plot whenever that plot is resized."""
    width, _ = plot_temporal_single.renderer.logical_size
    slider.layout = Layout(width=f"{width}px")


# last expression of the cell, so Jupyter displays the stacked layout
VBox([plot_temporal_single.show(), cnmf_grid_more.show(), slider])

#### If you need to autoscale

In [None]:
# re-fit the single-temporal plot to its data and shrink it horizontally, then re-fit the temporal stack
plot_temporal_single.auto_scale()
plot_temporal_single.camera.scale.x = 0.85
cnmf_grid_more["temporal"].auto_scale()

### If you would like to make it interactive:

In [None]:
# so we can view them one by one, first hide all of the single-plot temporal components
temporal_graphic[:].present = False

# link image to contours: clicking the movie recolors the nearest contour "w" via euclidean()
movie_graphic.link(
    "click",
    target=contours_movie,
    feature="colors", 
    new_data="w", 
    callback=euclidean  # we can re-use it from before
)

# link image to contours: same behavior for the rcm subplot
rcm_graphic.link(
    "click",
    target=contours_rcm,
    feature="colors", 
    new_data="w", 
    callback=euclidean  # we can re-use it from before
)

# contours colors -> contour thickness
contours_movie.link("colors", target=contours_movie, feature="thickness", new_data=5)
contours_rcm.link("colors", target=contours_rcm, feature="thickness", new_data=5)

# contours_movie <-> contours_rcm (bidirectional: a color change in either updates the other)
contours_rcm.link("colors", target=contours_movie, feature="colors", new_data="w", bidirectional=True)

# temporal stack events: clicking a trace colors it "w", and a color change makes it thicker
temporal_stack.link("click", target=temporal_stack, feature="colors", new_data="w")
temporal_stack.link("colors", target=temporal_stack, feature="thickness", new_data=4)

# contours <-> temporal stack
contours_movie.link("colors", target=temporal_stack, feature="colors", new_data="w", bidirectional=True)

# temporal stack -> temporal single (show the matching single trace)
temporal_stack.link("colors", target=temporal_graphic, feature="present", new_data=True)

# autoscale temporal plot to the current temporal component
temporal_graphic[:].present.add_event_handler(plot_temporal_single.auto_scale)

In [None]:
# Close the canvases (frees GPU resources)

plot_temporal_single.canvas.close()
cnmf_grid_more.canvas.close()