# Multisession registration with CaImAn

This notebook will help to demonstrate how to use CaImAn on movies recorded in multiple sessions. CaImAn has in-built functions that align movies from two or more sessions and try to recognize components that are imaged in some or all of these recordings.

The basic function for this is `caiman.base.rois.register_ROIs()`. It takes two sets of spatial components and finds components present in both using an intersection over union metric and the Hungarian algorithm for optimal matching.
`caiman.base.rois.register_multisession()` takes a list of spatial components, aligns sessions 1 and 2, keeps the union of the matched and unmatched components to register it with session 3 and so on.

In [None]:
from IPython import get_ipython
from matplotlib import pyplot as plt
import numpy as np
import h5py
import glob
import pims_nd2
import pandas as pd
from collections import OrderedDict
import plotly.graph_objects as go
import plotly.io as pio
import seaborn as sns
import os
import scipy.stats as stats
from itertools import combinations

from caiman.base.rois import register_multisession
from caiman.utils import visualization
from caiman.source_extraction.cnmf import cnmf as cnmf
from caiman.utils.utils import download_demo

import sys
sys.path.append('..')
from placecode.cross_registration_functions import CellTrackingSingleAnimal as celltrack

# Turn on the autoreload extension when the code runs under IPython/Jupyter.
# `__IPYTHON__` does not exist in a plain Python interpreter, so the lookup
# raising NameError is treated as "not running under IPython".
try:
    running_in_ipython = bool(__IPYTHON__)
except NameError:
    running_in_ipython = False

if running_in_ipython:
    for magic_name, magic_arg in (('load_ext', 'autoreload'), ('autoreload', '2')):
        get_ipython().run_line_magic(magic_name, magic_arg)

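Initializing the files for cross-registration: for each session, load the saved CNMF results and collect their spatial components and FOV dimensions.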

In [None]:
# Per-session containers. They are rebuilt from scratch on every run of this
# cell, so re-running it does not append duplicate entries.
files_list = []      # path of the CNMF results file of each session
templates_list = []  # stays empty: template loading is disabled (see below)
cnmf_list = []       # loaded CNMF objects, one per session
spatials = []        # spatial components (estimates.A) of each session
dims = []            # `dims` attribute of each loaded CNMF object

# NOTE: `conditions`, `home_folder` and `animal` are not defined in the
# visible part of this notebook; they must exist in the kernel beforehand.
for condition in conditions:

    # Locate the CNMF results file of this session. The matches are sorted so
    # that the choice is deterministic when several .hdf5 files are present,
    # and a missing file is reported explicitly instead of as an IndexError.
    session_pattern = f"{home_folder}/{animal}_{condition}/*.hdf5"
    matching_files = sorted(glob.glob(session_pattern))
    if not matching_files:
        raise FileNotFoundError(
            f"No .hdf5 CNMF results file found for condition '{condition}' "
            f"(searched: {session_pattern})"
        )
    fpath = matching_files[0]
    files_list.append(fpath)

    # Load and store the CNMF object.
    # NOTE(review): the second positional argument of load_CNMF looks like
    # `n_processes` rather than a file mode - confirm whether 'r' is intended.
    cnmf_ind = cnmf.load_CNMF(fpath, 'r')
    cnmf_list.append(cnmf_ind)

    # Store the spatial components.
    spatial = cnmf_ind.estimates.A
    spatials.append(spatial)

    # Store the FOV dimensions.
    dim = cnmf_ind.dims
    dims.append(dim)

    # Template loading (currently disabled):
    # template_path=glob.glob(f"{raw_data_folder}/*{condition}/*.nd2")[0]
    # nikon_movie=pims_nd2.ND2_Reader(template_path)
    # template=np.mean(nikon_movie[:600],axis=0)
    # templates_list.append(template)


In [None]:
from matplotlib.colors import LinearSegmentedColormap

In [None]:
# cmap = LinearSegmentedColormap.from_list('custom_cmap', ['black', 'red'], N=2)

# plt.imshow(template3,cmap=cmap)

## Use `register_multisession()`

The function `register_multisession()` takes the following main arguments (the call below passes `A` and `dims` together with `max_dist`; the `templates` argument is commented out there):
- `A`: A list of ndarrays or scipy.sparse.csc matrices with (# pixels X # component ROIs) for each session
- `dims`: Dimensions of the FOV, needed to restore spatial components to a 2D image
- `templates`: List of ndarray matrices of size `dims`, template image of each session

In [None]:
# Register the components of all sessions against each other. The FOV
# dimensions of the first session are used; templates are currently not passed.
spatial_union, assignments, matchings = register_multisession(
    A=spatials,
    dims=dims[0],
    max_dist=10,
    # templates=templates,
)

Finding the cells that were cross-registered between each pair of consecutive sessions.
The resulting dictionary is keyed by the first session of each compared pair.

In [None]:
# Assignment table: one row per entry of `assignments`, one column per session.
df = pd.DataFrame(data=assignments, columns=conditions)

# For every pair of consecutive sessions, keep only the rows that have an
# index in both sessions. Each result is stored under the first session's name.
cr_reg_cells_two_ses = OrderedDict()
for i, condition in enumerate(conditions[:-1]):
    pair_columns = [condition, conditions[i+1]]
    df_per_sessions = df[pair_columns].dropna().astype(int)
    cr_reg_cells_two_ses[condition] = df_per_sessions
    

In [None]:
# Select the rows of `assignments` that contain no NaN at all, i.e. the
# components that received an index in every session.
present_in_all_sessions = ~np.isnan(assignments).any(axis=1)
cr_reg_cells = assignments[present_in_all_sessions].astype(int)
cr_reg_cells_df = pd.DataFrame(data=cr_reg_cells, columns=conditions)


In [None]:
# Collect, per condition, the cells of each functional type so that they can
# later be compared across sessions.
output_file = f"D:/sd_project_pbox/results/{animal}"

pc_cells = OrderedDict()
non_pc_cells = OrderedDict()
silent_cells = OrderedDict()

# (name passed to isolating_cell_type, dictionary receiving the result)
cell_type_stores = (
    ('place_cells_tuned_vector', pc_cells),
    ('non_place_cells_tuned_vector', non_pc_cells),
    ('silent_cells', silent_cells),
)

for condition in conditions:
    # Results file of this condition.
    output_df_path = f'{output_file}/{condition}/output_info.h5'

    for cell_type_name, store in cell_type_stores:
        store[condition] = celltrack.isolating_cell_type(output_df_path, cell_type_name)

Find the cross-registered place cells between consecutive sessions,
e.g. baseline 1 to baseline 2,
baseline 2 to 24h post, etc.

In [None]:
# For each pair of consecutive sessions, filter the pairwise cross-registered
# cells with `filter_cross_registered_place_cells`, using the per-condition
# place cells. Results are keyed by the first session of the pair.
pc_cells_per_sessions = OrderedDict()

for condition, next_condition in zip(conditions[:-1], conditions[1:]):
    session_pair = [condition, next_condition]
    pair_assignments = cr_reg_cells_two_ses[condition]
    pc_cells_per_sessions[condition] = celltrack.filter_cross_registered_place_cells(
        pair_assignments, session_pair, pc_cells
    )


Store the average firing rate maps of only the cross-registered place cells between consecutive sessions.

In [None]:
# Extract, for every pair of consecutive sessions, the average firing rate
# maps of the cross-registered place cells in both sessions of the pair.

def _load_avr_firing_rate_maps(h5_path):
    """Read the 'avr_firing_rate_maps' dataset of `h5_path` into a DataFrame.

    The file is opened read-only inside a context manager and the dataset is
    copied into memory, so no HDF5 handle stays open after the call.
    """
    with h5py.File(h5_path, 'r') as h5_file:
        return pd.DataFrame(h5_file['avr_firing_rate_maps'][()])


fr_rate_maps_to_compare = OrderedDict()          # maps in their original row order
fr_rate_maps_to_compare_indexed = OrderedDict()  # same maps for now; entries are replaced by re-sorted maps in the plotting cell

# NOTE: `output_folder` is not defined in the visible part of this notebook;
# it must exist in the kernel beforehand.
for j, condition in enumerate(conditions[:-1]):
    next_condition = conditions[j+1]
    pair_key = f'{condition}___{next_condition}'

    # Average firing rate maps of all cells in both sessions.
    avr_fr_rate_map_init = _load_avr_firing_rate_maps(f'{output_folder}/{condition}/output_info.h5')
    avr_fr_rate_map_end = _load_avr_firing_rate_maps(f'{output_folder}/{next_condition}/output_info.h5')

    # Keep only the cross-registered place cells: the first column holds the
    # cell indices in the first session, the second column those in the next.
    cr_reg_cells_values_init = pc_cells_per_sessions[condition].iloc[:, 0].values
    avr_fr_rate_map_init = avr_fr_rate_map_init.iloc[cr_reg_cells_values_init]

    cr_reg_cells_values_end = pc_cells_per_sessions[condition].iloc[:, 1].values
    avr_fr_rate_map_end = avr_fr_rate_map_end.iloc[cr_reg_cells_values_end]

    # Two separate lists holding the same frames, so replacing an entry of the
    # indexed list later does not touch the original list.
    fr_rate_maps_to_compare[pair_key] = [avr_fr_rate_map_init, avr_fr_rate_map_end]
    fr_rate_maps_to_compare_indexed[pair_key] = [avr_fr_rate_map_init, avr_fr_rate_map_end]



In [None]:
# Plot the place-cell coding diagrams of the cross-registered place cells for
# every pair of consecutive sessions.
#   top row:    each session sorted by the position of its own peak activity
#   bottom row: both sessions sorted by the peak positions of the first session
n_sessions_to_compare = 2
sel_fontize = 20

# Tick positions/labels of the bottom-row x axis.
custom_ticks = [0, 50, 100, 150]
custom_labels = ['0', '50', '100', '150']

# Loop-invariant output folder; created once up front.
save_folder = f'{output_folder}/cross_registration'
os.makedirs(save_folder, exist_ok=True)

for j, condition in enumerate(conditions[:-1]):
    sessions = f'{condition}___{conditions[j+1]}'

    # Nothing to plot when there are no cross-registered place cells for this pair.
    if fr_rate_maps_to_compare[sessions][0].empty:
        continue

    fig, ax = plt.subplots(2, n_sessions_to_compare, figsize=(20,20))
    fig.suptitle(f'{animal}\ncross registered place cells\n comparison\n{sessions}\n Up Original sorting\nBelow sorted to {condition}',fontsize=sel_fontize+5)

    # Top row: sort the rows of each session by the column of maximum activity.
    for i in range(n_sessions_to_compare):
        place_cell_activity = fr_rate_maps_to_compare[sessions][i]
        peak_order = place_cell_activity.apply(np.argmax, axis=1).sort_values().index
        place_cell_activity = place_cell_activity.reindex(index=peak_order)

        sns.heatmap(place_cell_activity, ax=ax[0,i], cmap='viridis', cbar=False)#,vmax=4)
        if i == 0:
            ax[0,i].set_ylabel('cross registered place cells', fontsize=sel_fontize)
        else:
            ax[0,i].set_ylabel('')
        ax[0,i].set_xlabel('')
        ax[0,i].set_yticks([])
        ax[0,i].set_yticklabels([])
        ax[0,i].set_title(conditions[j+i], fontsize=sel_fontize)
        ax[0,i].set_xticks([])
        ax[0,i].set_xticklabels([])

    # Bottom row: row order derived from the first session only. The index is
    # reset so that row k refers to the same cross-registered cell in both sessions.
    reference_activity = fr_rate_maps_to_compare[sessions][0].reset_index(drop=True)
    reference_order = reference_activity.apply(np.argmax, axis=1).sort_values().index

    for i in range(n_sessions_to_compare):
        place_cell_activity = fr_rate_maps_to_compare[sessions][i].reset_index(drop=True)
        place_cell_activity = place_cell_activity.reindex(index=reference_order)
        # Keep the re-sorted maps for the later per-cell comparisons.
        fr_rate_maps_to_compare_indexed[sessions][i] = place_cell_activity

        sns.heatmap(place_cell_activity, ax=ax[1,i], cmap='viridis', cbar=False)
        if i == 0:
            ax[1,i].set_ylabel('cross registered place cells', fontsize=sel_fontize)
        else:
            ax[1,i].set_ylabel('')
        ax[1,i].set_xlabel('Belt(cm)', fontsize=sel_fontize)
        ax[1,i].set_yticks([])
        ax[1,i].set_yticklabels([])
        ax[1,i].set_xticks(custom_ticks)
        ax[1,i].set_xticklabels(custom_labels, fontsize=sel_fontize)

    fig.tight_layout()

    fig.savefig(f'{save_folder}/{animal}_{sessions}_place_cells_comparison.png', format='png', dpi=300, bbox_inches='tight')
    fig.savefig(f'{save_folder}/{animal}_{sessions}_place_cells_comparison.pdf', format='pdf', dpi=300, bbox_inches='tight')

In [None]:
# Sankey diagram of how cells that are present in all sessions move between the
# categories place cell (PC), non-place cell (nPC) and silent (sil) from one
# condition to the next.

# Three nodes per condition, in the fixed order PC, nPC, sil.
labels = []
colors = []
for condition in conditions:
    labels.extend([f"PC {condition}", f"nPC {condition}", f"sil {condition}"])
    colors.extend(["red", "blue", "grey"])

# Cell-type dictionaries in node order, and the link colour used for links
# leaving each of them (PC: red, nPC: blue, sil: black; all semi-transparent).
cell_types_in_node_order = [pc_cells, non_pc_cells, silent_cells]
link_color_by_source = ["rgba(255, 0, 0, 0.4)", "rgba(0, 0, 255, 0.4)", "rgba(0, 0, 0, 0.4)"]

# Node index of category k in condition i is 3*i + k. Links are emitted grouped
# by target category, e.g. sources 0, 1, 2, 0, 1, 2, ... / targets 3, 3, 3, 4, 4, 4, ...
sources = []
targets = []
values = []
link_colors = []
df = cr_reg_cells_df

for i_condition in range(len(conditions)-1):
    source_condition = conditions[i_condition]
    target_condition = conditions[i_condition+1]

    for target_offset, target_type in enumerate(cell_types_in_node_order):
        for source_offset, source_type in enumerate(cell_types_in_node_order):
            # Link value: length of the result returned by cell_movement for
            # this source/target category combination.
            n_cells = len(celltrack.cell_movement(df, source_condition, target_condition, source_type=source_type, target_type=target_type))
            sources.append(3*i_condition + source_offset)
            targets.append(3*(i_condition+1) + target_offset)
            values.append(n_cells)
            link_colors.append(link_color_by_source[source_offset])


fig = go.Figure(data=[go.Sankey(
    node=dict(
        pad=15,
        thickness=20,
        line=dict(color="black", width=0.5),
        label=labels,
        color=colors
    ),
    link=dict(
        source=sources,  # indices refer to positions in `labels`
        target=targets,
        value=values,
        color=link_colors
    ))])

fig.update_layout(title_text=f'{animal}\n\n{conditions}', font_size=10)

save_folder = f'D:/sd_project_pbox/results/{animal}/cross_registration'
os.makedirs(save_folder, exist_ok=True)

# NOTE(review): `{conditions}` is formatted as the whole container, so the file
# name contains its printed representation - confirm this is intended.
fig.write_html(f'{save_folder}/{conditions}.html')


#pio.write_image(fig, F'{output_folder}/{animal}_sankey_diagram.pdf')



Plotting firing rate maps of place cells in circular plots

In [None]:
# For every cross-registered place cell of every pair of consecutive sessions,
# draw the firing rate maps of both sessions as concentric rings on a polar
# plot, and record the angle between the two peak positions and the absolute
# difference between the two peak values.
angle_between_sessions = OrderedDict()
dif_max_between_sessions = OrderedDict()

# One angle per bin around the circle.
# NOTE: `num_bins` is not defined in the visible part of this notebook; it must
# exist in the kernel beforehand (presumably the number of columns of the
# firing rate maps - TODO confirm).
theta = np.linspace(0, 2*np.pi, num_bins, endpoint=False)

# Loop-invariant output folder; created here so this cell does not depend on
# another cell having created it.
save_folder = f'{output_folder}/cross_registration'
os.makedirs(save_folder, exist_ok=True)

for i, con in enumerate(conditions[:-1]):

    sessions = f'{con}___{conditions[i+1]}'

    angle_between_sessions[sessions] = []
    dif_max_between_sessions[sessions] = []

    cell_n = len(fr_rate_maps_to_compare_indexed[sessions][0])

    for cell in range(cell_n):

        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1, projection='polar')
        ax.set_title(f'{animal}\nComparing between\n{sessions}\nCell Index {cell}')
        ax.set_theta_direction(-1)       # clockwise
        ax.set_theta_zero_location('N')  # zero angle at the top
        # Angular ticks labelled as belt positions.
        circle_ticks = [0, 120, 240]                     # degrees
        circle_labels = ['0/150 cm', '50 cm', '100 cm']
        ax.set_xticks(np.radians(circle_ticks))
        ax.set_xticklabels(circle_labels)

        max_indices = []  # bin index of the maximum, per session
        max_values = []   # value of the maximum, per session

        for ses in range(n_sessions_to_compare):
            # Session `ses` is drawn on the ring of radius ses + 1.
            event_radii = np.ones_like(theta) * (ses + 1)
            cyclic_data = fr_rate_maps_to_compare_indexed[sessions][ses].iloc[cell]
            # Firing rate map as coloured points along the ring.
            sc = ax.scatter(theta, event_radii, c=cyclic_data, cmap='viridis', s=200, edgecolors='face')

            max_idx = np.argmax(cyclic_data)
            max_indices.append(max_idx)
            max_values.append(max(cyclic_data))

            # Line from the centre to the bin of maximum activity.
            ax.plot([0, theta[max_idx]], [0, event_radii[max_idx]], label=f'Session {ses + 1} Max', linestyle='-', linewidth=3)

        angle_between_lines = celltrack.find_angle_between_sessions(theta=theta, max_indices=max_indices)

        angle_between_sessions[sessions].append(angle_between_lines)
        dif = max_values[1] - max_values[0]
        dif_max_between_sessions[sessions].append(abs(dif))

        # Hide the radial axis and add a colour bar (for the last session drawn).
        ax.yaxis.set_visible(False)
        fig.colorbar(sc, ax=ax, orientation='vertical', pad=0.1)
        fig.savefig(f'{save_folder}/{animal}_{sessions}_Cell_{cell}_circular_plot.png', format='png', dpi=300, bbox_inches='tight')
        # One figure is created per cell; close it once saved so figures do not
        # accumulate in memory. The plots are therefore only written to disk.
        plt.close(fig)
        #plt.show()


In [None]:
# Bar plot of the mean angle drift of each pair of consecutive sessions
# (error bars: standard deviation), with pairwise independent t-tests between
# the session pairs.
data = angle_between_sessions

means = {k: np.mean(v) for k, v in data.items()}
stds = {k: np.std(v) for k, v in data.items()}

fig, ax = plt.subplots()

conditions_to_compare = list(data.keys())
mean_values = [means[cond] for cond in conditions_to_compare]
std_values = [stds[cond] for cond in conditions_to_compare]

bars = ax.bar(conditions_to_compare, mean_values, yerr=std_values, capsize=5, color=['blue', 'orange', 'green', 'red', 'purple'])
ax.set_ylabel('Angle drift')
ax.set_title(f'{animal}\nAverage drift between sessions')

# Pairwise t-tests. Both the statistic and the p-value are stored per pair so
# that the summary printed at the end reports the values of the right pair.
t_stats = {}
p_values = {}
for cond1, cond2 in combinations(conditions_to_compare, 2):
    t_stat, p_value = stats.ttest_ind(data[cond1], data[cond2])
    t_stats[(cond1, cond2)] = t_stat
    p_values[(cond1, cond2)] = p_value

plt.xticks(rotation=45)

# Significance brackets, stacked upwards starting above the tallest bar.
starting_y = max(mean_values) + max(std_values) + 1
line_height = 10                       # height of each bracket
asterisk_y_offset = 1.5 * line_height  # vertical step after a '*' bracket
question_y_offset = 2.5 * line_height  # vertical step after a '?' bracket

for (cond1, cond2), p_value in p_values.items():
    if p_value < 0.05:
        symbol = '*'
        y_offset = asterisk_y_offset
    elif p_value < 0.1:
        symbol = '?'
        y_offset = question_y_offset
    else:
        continue  # no bracket for p-values >= 0.1 (or NaN)

    idx1, idx2 = conditions_to_compare.index(cond1), conditions_to_compare.index(cond2)
    x1 = bars[idx1].get_x() + bars[idx1].get_width() / 2
    x2 = bars[idx2].get_x() + bars[idx2].get_width() / 2
    y = starting_y
    h = line_height
    ax.plot([x1, x1, x2, x2], [y, y + h, y + h, y], color='black')
    ax.text((x1 + x2) / 2, y + h, symbol, ha='center', va='bottom', color='black', fontsize=12)
    starting_y += y_offset  # move up so the next bracket does not overlap

save_folder = f'{output_folder}/cross_registration'
os.makedirs(save_folder, exist_ok=True)
plt.savefig(f'{save_folder}/{animal}_drift_conditions.png', format='png', dpi=300, bbox_inches='tight')
plt.show()


# Print the test results of all comparisons.
for (cond1, cond2), p_value in p_values.items():
    print(f'{cond1} vs {cond2} T-statistic: {t_stats[(cond1, cond2)]:.5f}, p-value: {p_value:.5f}')

In [None]:
# # Calculate mean and standard deviation for each condition
# data=angle_between_sessions

# means = {k: np.mean(v) for k, v in data.items()}
# stds = {k: np.std(v) for k, v in data.items()}

# # Plotting the bar plot with error bars
# fig, ax = plt.subplots()

# conditions_to_compare = list(data.keys())
# mean_values = [means[cond] for cond in conditions_to_compare]
# std_values = [stds[cond] for cond in conditions_to_compare]

# ax.bar(conditions_to_compare, mean_values, yerr=std_values, capsize=5, color=['blue', 'orange'])
# ax.set_ylabel('Angle drift')
# ax.set_title('Average drift between sessions')

# # Perform a t-test to check for statistical significance
# condition1_values = data[conditions_to_compare[0]]
# condition2_values = data[conditions_to_compare[1]]
# t_stat, p_value = stats.ttest_ind(condition1_values, condition2_values)

# # Display the p-value on the plot
# plt.figtext(0.5, -0.1, f'p-value: {p_value:.5f}', ha='center', fontsize=12)

# plt.show()

# # Print the p-value
# print(f'T-statistic: {t_stat:.5f}, p-value: {p_value:.5f}')

In [None]:
# data=dif_max_between_sessions

# means = {k: np.mean(v) for k, v in data.items()}
# stds = {k: np.std(v) for k, v in data.items()}

# # Plotting the bar plot with error bars
# fig, ax = plt.subplots()

# conditions_to_compare = list(data.keys())
# mean_values = [means[cond] for cond in conditions_to_compare]
# std_values = [stds[cond] for cond in conditions_to_compare]

# ax.bar(conditions_to_compare, mean_values, yerr=std_values, capsize=5, color=['blue', 'orange'])
# ax.set_ylabel('Max difference')
# ax.set_title('Max value difference between sessions')

# # Perform a t-test to check for statistical significance
# condition1_values = data[conditions_to_compare[0]]
# condition2_values = data[conditions_to_compare[1]]
# t_stat, p_value = stats.ttest_ind(condition1_values, condition2_values)

# # Display the p-value on the plot
# plt.figtext(0.5, -0.1, f'p-value: {p_value:.5f}', ha='center', fontsize=12)

# plt.show()

# # Print the p-value
# print(f'T-statistic: {t_stat:.5f}, p-value: {p_value:.5f}')