In [None]:
from IPython.core.getipython import get_ipython
from matplotlib import pyplot as plt
import numpy as np
import sys
import h5py
import os
import pandas as pd
import seaborn as sns
import plotly.graph_objects as go
sys.path.append("..")
from placecode import utils as ut
from placecode.from_caiman import *

# Enable automatic reloading of edited project modules (e.g. placecode) when running under IPython/Jupyter.
try:
    _running_in_ipython = bool(__IPYTHON__)  # builtin that only exists inside IPython
except NameError:
    _running_in_ipython = False
if _running_in_ipython:
    _shell = get_ipython()
    _shell.run_line_magic('load_ext', 'autoreload')
    _shell.run_line_magic('autoreload', '2')
from datetime import datetime
import scipy
from scipy.ndimage import gaussian_filter1d  # smooth signal strength maps

# Global plotting theme: large fonts for figures, white background with grid lines.
# `sns.set` is only a deprecated alias of `sns.set_theme`.
sns.set_theme(style="whitegrid", font_scale=3)

In [None]:
# Figure saving configuration. When enabled, every figure below is written to `output_folder`,
# with a timestamp suffix (datetime_str) that is shared by all figures of this run.
save_figs = True
if save_figs:
    file_extension = ".pdf"
    output_folder = ut.open_dir("Choose folder to save figures")
    print(f"Saving figures as {file_extension} is turned on. Saving figures to {output_folder}")
    now = datetime.now()
    # Format a single timestamp; calling datetime.now() separately for each field could mix values
    # across a second/minute/day boundary.
    datetime_str = now.strftime("%Y%m%d-%H%M%S")

In [None]:
# TODO: plot a heatmap of all persistent cells (those that made it through the analysis): each column is a condition, each row shows the same cell's spatial component.

In [None]:
# Standard deviation (in array elements along the last axis, i.e. bins) of the 1D Gaussian kernel
# used to smooth the signal strength maps (firing rate maps) before plotting.
sigma = 2

## Open (hdf5) files

In [None]:
# Ask for hdf5 files (one per recording/condition) until the user presses Cancel;
# ut.open_file returns "." when Cancel is pressed.
files_list = []
while True:
    fpath = ut.open_file("Open hdf5 file, or press Cancel to finish")
    if fpath == ".":  # user pressed cancel
        break
    files_list.append(fpath)
# Fail early with a clear message instead of a cryptic error in later cells (e.g. max() of an empty sequence).
assert len(files_list) > 0, "No hdf5 file selected; at least one recording is needed."

In [None]:
# Load the per-recording data from every selected hdf5 file. Each list below receives one entry per
# file, in files_list order, so index i of every list refers to the same recording/condition.
Y_list = []  # NOTE(review): not filled in this cell
A_list = []  # spatial components of each recording as scipy.sparse.csc_matrix
dims_list = []  # hf.attrs["resolution"] of each file (Cn entry in workspace) # TODO: A_sparse always have lower resolution, probably from cropping... should I define that as dims?
templates = []  # NOTE(review): never filled in this cell, so the template-cropping cell below is skipped
p_vals = []  # "p_values_tuned" per recording; NaN entries mark cells that did not fulfill the analysis criteria (see later cells)
conditions = []  # hf.attrs["condition"] per recording
tv_angles = []  # "tuned_vector_angles" per recording
tv_lengths = []  # "tuned_vector_lengths" per recording
ssm_zs = []  # "ssm_z" per recording
ssm_event_masks = []  # "ssm_events_mask" per recording; later cells index it as (n_cells, n_rounds, n_bins)
mouse_ids = []  # hf.attrs["mouse_id"] per recording; checked to be identical in the next cell
for fpath in files_list:
    with h5py.File(fpath, "r") as hf:
        mouse_id = hf.attrs["mouse_id"]
        resolution = hf.attrs["resolution"][()]
        n_components = hf.attrs["n_units"]  # read but only used by the commented-out ut.read_spatial call below
        condition = hf.attrs["condition"]
        ps = hf["p_values_tuned"][()]
        # CSC sparse-matrix components of the spatial footprints
        A_data = hf["A_data"][()]
        A_indices = hf["A_indices"][()]
        A_indptr = hf["A_indptr"][()]
        A_shape = hf["A_shape"][()]
        tv_a = hf["tuned_vector_angles"][()]
        tv_l = hf["tuned_vector_lengths"][()]
        ssm_z = hf["ssm_z"][()]
        ssm_event_mask = hf["ssm_events_mask"][()]
        #spatial = ut.read_spatial(A_data, A_indices, A_indptr, A_shape, n_components, resolution, unflatten=False)
        spatial = scipy.sparse.csc_matrix((A_data, A_indices, A_indptr), shape=A_shape)
        dims_list.append(resolution)
        A_list.append(spatial)  # NOTE(review): an earlier note said "need to swap: (n_units, n_pixels) -> (n_pixels, n_units)", but no transpose is done; register_multisession expects (# pixels X # components) - confirm A_shape already has that orientation
        p_vals.append(ps)
        conditions.append(condition)
        tv_angles.append(tv_a)
        tv_lengths.append(tv_l)
        ssm_zs.append(ssm_z)
        mouse_ids.append(mouse_id)
        ssm_event_masks.append(ssm_event_mask)

In [None]:
# All loaded recordings must belong to the same mouse; its id is used in figure titles and file names.
assert all(other_id == mouse_ids[0] for other_id in mouse_ids[1:])
mouse_id = mouse_ids[0]

In [None]:
# Convert per-recording tuned vector data into 2D numpy arrays of shape (max_len, n_conditions).
# Recordings have different numbers of units, so each column is padded with np.nan up to the
# longest recording. max_len is taken over ALL lists padded below, so the three arrays share one shape.
max_len = max(len(arr) for arrs in (tv_angles, tv_lengths, p_vals) for arr in arrs)


def convert_to_np(list_of_arrs, pad_len=None):
    """
    Given a list of 1D arrays, return a float 2D array with one column per input array.

    Each column is padded with np.nan to `pad_len` rows (default: length of the longest input array).
    Raises ValueError if an input array is longer than `pad_len`.
    """
    if pad_len is None:
        pad_len = max((len(arr) for arr in list_of_arrs), default=0)
    padded = np.full((len(list_of_arrs), pad_len), np.nan)
    for i_arr, arr in enumerate(list_of_arrs):
        padded[i_arr, :len(arr)] = arr
    return padded.T


tv_angles_padded = convert_to_np(tv_angles, max_len)
tv_lengths_padded = convert_to_np(tv_lengths, max_len)
p_vals_padded = convert_to_np(p_vals, max_len)


In [None]:
# Crop each session template (full FOV) symmetrically to the FOV dimensions of the spatial
# components (dims_list[0]). Only runs if templates were loaded.
if len(templates) > 0:
    templates_cropped = []
    cropped_shape = dims_list[0]
    for template in templates:
        FOV_shape = template.shape

        # pixels to remove from each side along axis 0 ...
        x_crop_onesided = (FOV_shape[0] - cropped_shape[0]) // 2
        assert 2*x_crop_onesided == FOV_shape[0] - cropped_shape[0]

        # ... and along axis 1
        y_crop_onesided = (FOV_shape[1] - cropped_shape[1]) // 2
        assert 2*y_crop_onesided == FOV_shape[1] - cropped_shape[1]

        # Crop axis 0 by the axis-0 margin and axis 1 by the axis-1 margin, so that the result has
        # exactly cropped_shape. Explicit end indices (instead of negative ones) also handle a zero
        # margin, where `[0:-0]` would return an empty array.
        template_cropped = template[x_crop_onesided:FOV_shape[0] - x_crop_onesided,
                                    y_crop_onesided:FOV_shape[1] - y_crop_onesided]
        templates_cropped.append(template_cropped)

## Use `register_multisession()`

The function `register_multisession()` takes 3 arguments (only `A` and `dims` are passed below, since no session templates are loaded in this notebook):
- `A`: A list of ndarrays or scipy.sparse.csc matrices with (# pixels X # component ROIs) for each session
- `dims`: Dimensions of the FOV, needed to restore spatial components to a 2D image
- `templates`: List of ndarray matrices of size `dims`, template image of each session

In [None]:
# Register the spatial components of all sessions. Every session is assumed to share the FOV
# dimensions of the first file.
fov_dims = dims_list[0]
spatial_union, assignments, matchings = register_multisession(A=A_list, dims=fov_dims)

The function returns 3 variables for further analysis:
- `spatial_union`: csc_matrix (# pixels X # total distinct components), the union of all ROIs across all sessions aligned to the FOV of the last session.
- `assignments`: ndarray (# total distinct components X # sessions). `assignments[i,j]=k` means that component `k` from session `j` has been identified as component `i` from the union of all components, otherwise it takes a `NaN` value. Note that for each `i` there is at least one session index `j` where `assignments[i,j]!=NaN`.
- `matchings`: list of (# sessions) lists. Saves `spatial_union` indices of individual components in each session. `matchings[j][k] = i` means that component `k` from session `j` is represented by component `i` in the union of all components `spatial_union`. In other words `assignments[matchings[j][k], j] = k`.

## Create various subgroups

### Filter conditions

In [None]:
n_conditions = len(conditions)
# keep only union components matched in at least one session (drop rows that are entirely np.nan)
assignments_filtered = assignments[np.any(~np.isnan(assignments), axis=1)]

### Take only omnipresent cells
(omnipresent cell = cell that could be identified in all recordings)

In [None]:
# omnipresent rows: a matched component index in every session (no np.nan anywhere in the row);
# cast to int so the values can be used as indices
omnipresent_rows = np.all(~np.isnan(assignments_filtered), axis=1)
assignments_omnipresent = assignments_filtered[omnipresent_rows].astype(np.int16)

### Match (pair) values for same cell from different conditions (recordings) 

In [None]:
# For each omnipresent unit (row) and condition (column), look up that unit's value in the
# condition's padded array: X_paired[i, j] = X_padded[assignments_omnipresent[i, j], j]
cond_cols = np.arange(len(conditions))
tv_lengths_paired = tv_lengths_padded[assignments_omnipresent, cond_cols]
tv_angles_paired = tv_angles_padded[assignments_omnipresent, cond_cols]
p_vals_paired = p_vals_padded[assignments_omnipresent, cond_cols]

# check that np.nans (coming from analysis where cells did not fulfill criteria to be included) match for all variables
has_all_p = ~np.isnan(p_vals_paired).any(axis=1)
assert (has_all_p == ~np.isnan(tv_angles_paired).any(axis=1)).all()
assert (has_all_p == ~np.isnan(tv_lengths_paired).any(axis=1)).all()

### Get persistent cells
(persistent cell = omnipresent cell that fulfilled the requirements for inclusion in the place coding analysis in every condition)

In [None]:
# boolean mask over omnipresent units: True where the unit has a (non-NaN) p-value in every condition
i_persistent = np.all(~np.isnan(p_vals_paired), axis=1)

In [None]:
# Restrict the paired arrays to persistent units; rows stay aligned across all four arrays.
(tv_lengths_persistent,
 tv_angles_persistent,
 p_vals_persistent,
 assignments_persistent) = (arr[i_persistent] for arr in (tv_lengths_paired,
                                                          tv_angles_paired,
                                                          p_vals_paired,
                                                          assignments_omnipresent))

assert (tv_lengths_persistent.shape == tv_angles_persistent.shape
        == p_vals_persistent.shape == assignments_persistent.shape)

* `assignments_persistent` contains one row per persistent cell, i.e. a cell that fulfilled the analysis criteria (minimum number of events...) in all included conditions. In each row, each column contains the original cell index in the recording of the corresponding `conditions` entry. For example, `assignments_persistent[0][0]==8` means the first persistent cell is cell 8 (indexing starts at 0) in the baseline recording; the same cell might be cell 253 in the second condition (`assignments_persistent[0][1]==253`).
* `tv_lengths_persistent`, `tv_angles_persistent`, `p_vals_persistent` contain the tuning vector lengths, angles, and the p value, each row one neuron tracked over the conditions (that fulfilled analysis criteria). The rows and columns match those of `assignments_persistent` (i.e. the same cell, same condition is in the same row and column)

### Get persistent cells that are initially place coding (ipc) and not initially place coding (nipc)

In [None]:
# Boolean mask (not integer indices) of persistent cells that are place coding in the first condition
# (p <= 0.05). A mask is required so that `~i_ipc` selects the complementary (non-initially place
# coding) cells; `~` on an integer index array is a bitwise NOT (i -> -i-1) and would pick wrong rows.
i_ipc = p_vals_persistent[:, 0] <= 0.05

In [None]:
# Split persistent cells into initially place coding (ipc) and not initially place coding (nipc).
# Build an explicit boolean mask from i_ipc: this works whether i_ipc holds integer row indices
# (np.where output) or is already a boolean mask. Negating integer indices with `~` would be a
# bitwise NOT (i -> -i-1), i.e. wrong rows selected through negative indexing.
ipc_mask = np.zeros(len(p_vals_persistent), dtype=bool)
ipc_mask[i_ipc] = True
nipc_mask = ~ipc_mask

tv_lengths_ipc = tv_lengths_persistent[ipc_mask]
tv_angles_ipc = tv_angles_persistent[ipc_mask]
p_vals_ipc = p_vals_persistent[ipc_mask]
assignments_ipc = assignments_persistent[ipc_mask]

tv_lengths_nipc = tv_lengths_persistent[nipc_mask]
tv_angles_nipc = tv_angles_persistent[nipc_mask]
p_vals_nipc = p_vals_persistent[nipc_mask]
assignments_nipc = assignments_persistent[nipc_mask]

# Analysis

## Calculate the event rate of each cell per condition (events per round, averaged over rounds)


In [None]:
# create dataframe for seaborn
# columns: condition (bl, 30min, 60min...); cell type (npc, pc, low activity (la)); event rate; baseline cell id
# This includes all cells, not just omnipresent/persistent cells!
col_event_rates = []
col_cell_types = []
col_conds = []
col_cell_idxs = []

for i_cond in range(len(ssm_event_masks)):
    p_vals_cond = p_vals[i_cond]
    n_units_cond = len(p_vals_cond)

    # Sorting the union rows by this session's component index puts the row of component k at
    # position k (rows where the component is absent in this session are NaN and sort to the end).
    order = np.argsort(assignments[:, i_cond])
    # Sanity check for the assumption above: each component of this session appears exactly once.
    assert np.array_equal(assignments[order[:n_units_cond], i_cond], np.arange(n_units_cond))
    # index of each of this session's components in the first condition (baseline); NaN if not found there
    bl_ids = assignments[order[:n_units_cond], 0]

    # p-value categories; NaN p-values (cells not included in the analysis) only fall into "la"
    cell_type_masks = (
        ("npc", p_vals_cond > 0.05),
        ("pc", p_vals_cond <= 0.05),
        ("la", np.isnan(p_vals_cond)),
    )
    for cell_type, mask in cell_type_masks:
        # shape of ssm_event_masks[i_cond][mask]: (n_masked_cells, n_rounds, n_bins)
        # sum up events for each round (i.e. sum up bins, axis=2), calculate average over rounds (i.e. over axis=1)
        event_rate = np.mean(np.sum(ssm_event_masks[i_cond][mask], axis=2), axis=1)
        col_event_rates.extend(event_rate)
        col_conds.extend([conditions[i_cond]] * len(event_rate))
        col_cell_types.extend([cell_type] * len(event_rate))
        col_cell_idxs.extend(bl_ids[mask])


df_event_rates = pd.DataFrame({"condition": col_conds, "cell_type": col_cell_types, "event_rate": col_event_rates, "cell_bl_id": col_cell_idxs})  # cell_bl_id must be float because of the np.NaNs


In [None]:
# One event-rate histogram per condition, split by cell type (log-scaled counts).
# squeeze=False keeps `axs` 2D so indexing also works with a single condition.
fig, axs = plt.subplots(1, len(conditions), sharey=True, sharex=True, figsize=(24, 10), squeeze=False)
fig.suptitle(mouse_id)
for i_cond, cond in enumerate(conditions):
    df_cond = df_event_rates[df_event_rates["condition"] == cond]
    ax = axs[0, i_cond]
    sns.histplot(
        df_cond,
        x="event_rate", hue="cell_type",
        multiple="layer",
        edgecolor=".3",
        linewidth=.5,
        log_scale=(False, True),
        ax=ax
    )
    # number of low activity ("la") cells in this condition
    n_la = int((df_cond["cell_type"] == "la").sum())
    ax.set_title(f"{cond}, n_la={n_la}")
if save_figs:
    out_fpath = os.path.join(output_folder, f"pca_{mouse_id}_hist_cell_types_{datetime_str}{file_extension}")
    fig.savefig(out_fpath)
    print(f"Saved to {out_fpath}")
plt.show()

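### Event rate of each unit across the three conditions
Match the event rate of each unit (identified by its baseline index) across the three conditions, then scatter-plot the event rates of conditions 1 and 2 relative to baseline. Units missing from a condition can be handled in two ways:
1. only keep persistent units
2. replace the missing value with 0 (ansatz: cells that were not identified were not firing) — this is what the code below does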

In [None]:
if len(conditions) == 3:
    def get_event_rate(grp, cond):
        """Event rate of the single row of `grp` for condition `cond`, or 0 if there is no such row."""
        rates = grp.loc[grp["condition"] == cond, "event_rate"]
        if len(rates) == 0:
            return 0  # assume event rate is 0 if cell was not identified for specified condition
        assert len(rates) == 1  # assert unique event rate for condition and cell id
        return rates.iloc[0]

    cids, xs, ys, zs = [], [], [], []
    # One row per baseline cell id; coordinates are the event rates in cond[0], cond[1], cond[2]
    # (e.g. bl, 30min, 60min), with 0 filled in for a missing condition.
    for cell_id, grp in df_event_rates.groupby("cell_bl_id"):
        x_rate, y_rate, z_rate = (get_event_rate(grp, cond) for cond in conditions[:3])
        xs.append(x_rate)
        ys.append(y_rate)
        zs.append(z_rate)
        cids.append(cell_id)

    df_matched_event_rates = pd.DataFrame({"cell_bl_id": cids, "event_rate_0": xs, "event_rate_1": ys, "event_rate_2": zs})

In [None]:
if len(conditions) == 3:
    # only cells with non-zero baseline event rate, so the ratios are defined
    df_matched_event_rates_filt = df_matched_event_rates[df_matched_event_rates["event_rate_0"] != 0]
    baseline_rate = df_matched_event_rates_filt["event_rate_0"]
    ratio_1_to_0 = df_matched_event_rates_filt["event_rate_1"] / baseline_rate
    ratio_2_to_0 = df_matched_event_rates_filt["event_rate_2"] / baseline_rate

    fig, ax = plt.subplots(figsize=(12, 12))
    ax.scatter(ratio_1_to_0, ratio_2_to_0)
    fig.suptitle("Event rate ratios")
    ax.set_xlabel(f'{conditions[1]}/{conditions[0]}')
    ax.set_ylabel(f'{conditions[2]}/{conditions[0]}')
    if save_figs:
        out_fpath = os.path.join(output_folder, f"pca_{mouse_id}_event_rate_ratios_{datetime_str}{file_extension}")
        plt.savefig(out_fpath)
        print(f"Saved to {out_fpath}")
    plt.show()
    

## Check movement between place-coding, non-place-coding and low activity cells (omnipresent cells)
Low activity (la) cells: cells that were not included in PC analysis (minimum event number criterion not fulfilled)

In [None]:
# silent cells appear as np.nan in p_vals. Make sure they return FALSE for both PC and nPC conditions
assert not(np.nan > 0.05)
assert not(np.nan <= 0.05)
assert np.isnan(np.nan)

# Sankey diagram of how omnipresent cells move between categories from one condition to the next.
# Node order per condition: PC (p <= 0.05), nPC (p > 0.05), la (low activity, NaN p-value),
# so the node of category c in condition i has index 3*i + c.
labels = []
colors = []
for condition in conditions:
    labels.extend([f"PC {condition}", f"nPC {condition}", f"la {condition}"])
    colors.extend(["red", "blue", "grey"])
# link colors are chosen by SOURCE category: PC -> light red, nPC -> light blue, la -> "light black"
source_link_colors = ["rgba(255, 0, 0, 0.4)", "rgba(0, 0, 255, 0.4)", "rgba(0, 0, 0, 0.4)"]

sources = []
targets = []
values = []
link_colors = []
for i_condition in range(n_conditions - 1):  # last condition does not have output
    p_from = p_vals_paired[:, i_condition]
    p_to = p_vals_paired[:, i_condition + 1]
    masks_from = [p_from <= 0.05, p_from > 0.05, np.isnan(p_from)]
    masks_to = [p_to <= 0.05, p_to > 0.05, np.isnan(p_to)]
    # for each target category, add one link from each source category
    for i_to, mask_to in enumerate(masks_to):
        for i_from, mask_from in enumerate(masks_from):
            sources.append(3 * i_condition + i_from)
            targets.append(3 * (i_condition + 1) + i_to)
            values.append(np.sum(np.logical_and(mask_from, mask_to)))
            link_colors.append(source_link_colors[i_from])


fig = go.Figure(data=[go.Sankey(
    node=dict(
        pad=15,
        thickness=20,
        line=dict(color="black", width=0.5),
        label=labels,
        color=colors
    ),
    link=dict(
        source=sources,  # indices correspond to labels, eg A1, A2, A1, B1, ...
        target=targets,
        value=values,
        color=link_colors
    ))])

fig.update_layout(title_text="Place coding (PC) - non-place coding (nPC) - low activity (la)", font_size=10)

if save_figs:
    out_fpath = os.path.join(output_folder, f"pca_{mouse_id}_sankey_{datetime_str}.html")
    out_fpath_original_ext = os.path.join(output_folder, f"pca_{mouse_id}_sankey_{datetime_str}{file_extension}")
    fig.write_html(out_fpath)
    fig.write_image(out_fpath_original_ext)  # requires kaleido package
    print(f"Saved to {out_fpath} and {out_fpath_original_ext}")
fig.show()


## % of place cells at each time point

In [None]:
# Fraction of place cells (p <= 0.05) among ALL cells of each condition; cells with NaN p-value
# (low activity) are counted in the denominator but never as place cells.
place_cell_ratio = np.array([np.sum(p_cond <= 0.05) / len(p_cond) for p_cond in p_vals])


In [None]:
# Left: % place cells per condition. Right: total number of cells per condition.
n_cells_per_cond = [len(p_cond) for p_cond in p_vals]
fig, axs = plt.subplots(1, 2, figsize=(16, 8))
panels = (
    (place_cell_ratio * 100., "% place cells"),
    (n_cells_per_cond, "# cells (total)"),
)
for ax, (y_values, y_label) in zip(axs, panels):
    ax.plot(y_values)
    ax.set_xticks(range(len(conditions)), conditions)
    ax.set_ylabel(y_label)
if save_figs:
    out_fpath = os.path.join(output_folder, f"pca_{mouse_id}_percent_{datetime_str}{file_extension}")
    plt.savefig(out_fpath)
    print(f"Saved to {out_fpath}")
plt.show()

## Check persistent cells behaviour

In [None]:
# Sankey diagram of how persistent cells move between PC (p <= 0.05) and nPC (p > 0.05) from one
# condition to the next. Node of category c (0: PC, 1: nPC) in condition i has index 2*i + c.
labels = []
colors = []
for condition in conditions:
    labels.extend([f"PC {condition}", f"nPC {condition}"])  # for each condition, check categories PC and not-PC
    colors.extend(["red", "blue"])
# link colors are chosen by SOURCE category: PC -> light red, nPC -> light blue
source_link_colors = ["rgba(255, 0, 0, 0.4)", "rgba(0, 0, 255, 0.4)"]

sources = []
targets = []
values = []
link_colors = []
for i_condition in range(len(conditions) - 1):  # last condition does not have output
    p_from = p_vals_persistent[:, i_condition]
    p_to = p_vals_persistent[:, i_condition + 1]
    masks_from = [p_from <= 0.05, p_from > 0.05]
    masks_to = [p_to <= 0.05, p_to > 0.05]
    # for each target category, add one link from each source category
    for i_to, mask_to in enumerate(masks_to):
        for i_from, mask_from in enumerate(masks_from):
            sources.append(2 * i_condition + i_from)
            targets.append(2 * (i_condition + 1) + i_to)
            values.append(np.sum(np.logical_and(mask_from, mask_to)))
            link_colors.append(source_link_colors[i_from])

fig = go.Figure(data=[go.Sankey(
    node=dict(
        pad=15,
        thickness=20,
        line=dict(color="black", width=0.5),
        label=labels,
        color=colors
    ),
    link=dict(
        source=sources,  # indices correspond to labels, eg A1, A2, A1, B1, ...
        target=targets,
        value=values,
        color=link_colors
    ))])

fig.update_layout(title_text="Place coding (PC) - non-place coding (nPC) of persistent cells", font_size=10)
# Save next to the other figures instead of a hardcoded, machine-specific absolute path.
if save_figs:
    out_fpath = os.path.join(output_folder, f"pca_{mouse_id}_sankey_persistent_{datetime_str}.html")
    fig.write_html(out_fpath)
    print(f"Saved to {out_fpath}")
fig.show()

### Check tuning vector direction change/stability for initially place-coding cells

In [None]:
fig, ax = plt.subplots(subplot_kw={'projection': 'polar'}, figsize=(12,12))
# One trace per initially place-coding cell: angle = tuning vector angle (-pi to pi) in each
# condition, radius = condition number (1, 2, ...), so the trace shows how the direction drifts.
for cell_angles in tv_angles_ipc:
    condition_radii = np.arange(1, len(cell_angles) + 1)
    ax.plot(cell_angles, condition_radii, linewidth=0.3, marker='o')
if save_figs:
    out_fpath = os.path.join(output_folder, f"pca_{mouse_id}_directions_{datetime_str}{file_extension}")
    plt.savefig(out_fpath)
    print(f"Saved to {out_fpath}")
plt.show()

## Plot signal strength maps (ssm) of initially place-coding cells

In [None]:
# signal strength maps of the initially place-coding cells, one array per condition (rows aligned across conditions)
ssm_z_ipc = [ssm_zs[i_cond][assignments_ipc[:, i_cond]] for i_cond in range(n_conditions)]
# average over rounds
ssm_z_ipc_avg = [np.average(ssm_z_ipc[i_cond], axis=1) for i_cond in range(n_conditions)]
# Order the cells of EACH condition by that condition's own map (not by the baseline order):
# 1. find the bin of the maximum of the round-averaged ssm for each cell
# 2. sort cell indices by that bin in ascending order
idx_onset_sorted = [np.argsort(np.argmax(ssm_avg, axis=1)) for ssm_avg in ssm_z_ipc_avg]
ssm_z_ipc_avg_sorted = [ssm_avg[idx] for ssm_avg, idx in zip(ssm_z_ipc_avg, idx_onset_sorted)]
# smooth every cell's map (row) along the last axis; one 2D array per condition
ssm_z_ipc_avg_sorted_smooth = [gaussian_filter1d(ssm_sorted, sigma) for ssm_sorted in ssm_z_ipc_avg_sorted]

In [None]:
# Saving this particular figure is off by default; set to True to save it together with the other figures.
save_ssm_ipc_fig = False

# squeeze=False keeps `axs` 2D so indexing also works with a single condition
fig, axs = plt.subplots(1, n_conditions, figsize=(18, 8), squeeze=False)
fig.suptitle(f"{mouse_id} initial PC")
for i_cond in range(n_conditions):
    ax = axs[0, i_cond]
    ax.imshow(ssm_z_ipc_avg_sorted_smooth[i_cond], aspect="auto", cmap="jet")
    ax.set_title(conditions[i_cond])
if save_figs and save_ssm_ipc_fig:
    fig_fpath = os.path.join(output_folder, f"pca_{mouse_id}_ssm_ipc_gsig={sigma}_"+datetime_str+file_extension)
    fig.savefig(fig_fpath)
    print(f"Saved to {fig_fpath}")
plt.show()

## Plot ssm of all persistent cells

In [None]:
# signal strength maps of all persistent cells, one array per condition (rows aligned across conditions)
ssm_z_persistent = [ssm_zs[i_cond][assignments_persistent[:, i_cond]] for i_cond in range(n_conditions)]
# average over rounds
ssm_z_persistent_avg = [np.average(ssm_z_persistent[i_cond], axis=1) for i_cond in range(n_conditions)]
# Order the cells of EACH condition by that condition's own map (not by the baseline order):
# 1. find the bin of the maximum of the round-averaged ssm for each cell
# 2. sort cell indices by that bin in ascending order
idx_onset_sorted = [np.argsort(np.argmax(ssm_avg, axis=1)) for ssm_avg in ssm_z_persistent_avg]
ssm_z_persistent_avg_sorted = [ssm_avg[idx] for ssm_avg, idx in zip(ssm_z_persistent_avg, idx_onset_sorted)]

# smooth every cell's map (row) along the last axis; one 2D array per condition
ssm_z_persistent_avg_sorted_smooth = [gaussian_filter1d(ssm_sorted, sigma) for ssm_sorted in ssm_z_persistent_avg_sorted]

In [None]:
# squeeze=False keeps `axs` 2D so indexing also works with a single condition
fig, axs = plt.subplots(1, n_conditions, figsize=(18, 8), squeeze=False)
fig.suptitle(f"{mouse_id} persistent")
for i_cond in range(n_conditions):
    ax = axs[0, i_cond]
    ax.imshow(ssm_z_persistent_avg_sorted_smooth[i_cond], aspect="auto", cmap="jet")
    ax.set_title(conditions[i_cond])
if save_figs:
    fig_fpath = os.path.join(output_folder, f"pca_{mouse_id}_ssm_persistent_gsig={sigma}_"+datetime_str+file_extension)
    fig.savefig(fig_fpath)
    print(f"Saved to {fig_fpath}")
plt.show()

# Save multisession tracking results

In [None]:
#spatial_union, assignments, matchings