# Bl1-bl2-30min-60min place coding cell evolution

In [None]:
from IPython.core.getipython import get_ipython
from matplotlib import pyplot as plt
import numpy as np
import sys
import h5py
import os
import pandas as pd
import seaborn as sns
import plotly.graph_objects as go
sys.path.append("..")
from placecode import utils as ut
from placecode.from_caiman import *

# Enable IPython's autoreload extension (%autoreload 2) so edited modules (e.g. placecode) are
# reloaded without restarting the kernel. Outside IPython the name __IPYTHON__ is undefined; the
# resulting NameError is caught and the step is skipped.
try:
    if __IPYTHON__:
        get_ipython().run_line_magic('load_ext', 'autoreload')
        get_ipython().run_line_magic('autoreload', '2')
except NameError:
    pass
from datetime import datetime
import scipy
from scipy.ndimage import gaussian_filter1d  # smooth signal strength maps

# Global seaborn plot style for all figures: large fonts (scale 3) on a white grid background.
sns.set(font_scale=3)
sns.set_style("whitegrid")

In [None]:
# Set save_figs = True to export figures; the user is then asked for an output folder and a
# timestamp string is prepared for the file names.
save_figs = False
if save_figs:
    file_extension = ".pdf"
    output_folder = ut.open_dir("Choose folder to save figures")
    print(f"Saving figures as {file_extension} is turned on. Saving figures to {output_folder}")
    now = datetime.now()
    # e.g. "20240131-142501". Every field comes from the single `now` snapshot; calling
    # datetime.now() once per field could mix values across a second/minute/day boundary.
    datetime_str = now.strftime("%Y%m%d-%H%M%S")

## Open (hdf5) files

In [None]:
def extract_data(files_list, dict_mouse_data):
    """Read all recordings (one hdf5 file per condition) of one mouse into ``dict_mouse_data``.

    Each path in ``files_list`` is opened read-only with h5py. From every file these items are read:
    - the attributes ``mouse_id``, ``resolution`` and ``condition``;
    - the datasets ``p_values_tuned``, ``tuned_vector_angles``, ``tuned_vector_lengths``,
      ``ssm_z`` and ``ssm_events_mask``;
    - the spatial footprints, rebuilt as a ``scipy.sparse.csc_matrix`` from the stored
      ``A_data``/``A_indices``/``A_indptr``/``A_shape`` datasets.
    Values are collected into one list per quantity, in the order of ``files_list``.

    Parameters
    ----------
    files_list : sequence of str
        Paths to the hdf5 files, one per recording. All files must carry the same ``mouse_id``.
    dict_mouse_data : dict
        Updated in place. ``dict_mouse_data[mouse_id]`` is set (replacing any previous entry) to
        a dict with keys ``Y_list``, ``A_list``, ``dims_list``, ``templates``, ``p_vals``,
        ``conditions``, ``tv_angles``, ``tv_lengths``, ``ssm_zs`` and ``ssm_event_masks``.

    Returns
    -------
    dict
        The per-mouse entry that was stored. Callers may ignore it.

    Raises
    ------
    ValueError
        If ``files_list`` is empty or the files belong to different mice.
    """
    if len(files_list) == 0:
        raise ValueError("files_list is empty: expected at least one hdf5 file")

    # Import the submodule explicitly: ``import scipy`` alone does not load ``scipy.sparse`` on
    # older SciPy versions.
    from scipy import sparse

    Y_list = []  # TODO: never filled; kept so the stored entry keeps its keys
    A_list = []  # spatial footprints, one sparse matrix per file
    dims_list = []  # "resolution" attribute per file # TODO: A_sparse always has a lower resolution, probably from cropping... should that be used as dims?
    templates = []  # TODO: caiman does not save templates to the hdf5 files; they still need to be added manually
    p_vals = []
    conditions = []
    tv_angles = []
    tv_lengths = []
    ssm_zs = []
    ssm_event_masks = []
    mouse_ids = []
    for fpath in files_list:
        with h5py.File(fpath, "r") as hf:
            mouse_ids.append(hf.attrs["mouse_id"])
            dims_list.append(hf.attrs["resolution"][()])
            conditions.append(hf.attrs["condition"])
            p_vals.append(hf["p_values_tuned"][()])
            tv_angles.append(hf["tuned_vector_angles"][()])
            tv_lengths.append(hf["tuned_vector_lengths"][()])
            ssm_zs.append(hf["ssm_z"][()])
            ssm_event_masks.append(hf["ssm_events_mask"][()])
            # Rebuild the sparse footprint matrix from its stored CSC components.
            # NOTE(review): an earlier comment said (n_units, n_pixels) must be swapped to
            # (n_pixels, n_units), but no transpose is applied -- confirm the orientation of A_shape.
            A_list.append(sparse.csc_matrix(
                (hf["A_data"][()], hf["A_indices"][()], hf["A_indptr"][()]),
                shape=hf["A_shape"][()],
            ))

    # All files must come from the same mouse; raise instead of assert so the check survives -O.
    mouse_id = mouse_ids[0]
    other_ids = [m_id for m_id in mouse_ids[1:] if m_id != mouse_id]
    if other_ids:
        raise ValueError(
            f"All files must belong to the same mouse: first file has mouse_id {mouse_id!r}, "
            f"other files have {other_ids!r}"
        )

    entry = {"Y_list": Y_list, "A_list": A_list, "dims_list": dims_list, "templates": templates, "p_vals": p_vals, "conditions": conditions, "tv_angles": tv_angles, "tv_lengths": tv_lengths, "ssm_zs": ssm_zs, "ssm_event_masks": ssm_event_masks}
    dict_mouse_data[mouse_id] = entry
    return entry

In [None]:
# Interactively choose one hdf5 file per condition for each mouse and load them with extract_data.
# Cancelling the first dialog of a new mouse ends the selection; cancelling part-way through a
# mouse is an error, because each mouse needs one file per condition.
dict_mouse_data = dict()  # mouse_id -> per-mouse data, filled by extract_data()
conditions = ["bl1", "bl2", "30min", "60min"]  # expected recordings per mouse, in selection order
i_mouse = 1
next_mouse = True
while next_mouse:
    files_list = []
    for cond in conditions:
        fpath = ut.open_file(f"Mouse #{i_mouse}: Open hdf5 file for time point {cond}")
        if fpath == ".":  # user pressed cancel
            next_mouse = False
            break
        files_list.append(fpath)
    if len(files_list) == len(conditions):
        extract_data(files_list, dict_mouse_data)
    elif len(files_list) > 0:  # no files at all for the next mouse is not an error
        raise ValueError(f"Not enough files chosen! Expected {len(conditions)}, received {len(files_list)}")
    i_mouse += 1

In [None]:
def convert_to_np(list_of_arrs):
    """Stack 1D arrays of possibly different lengths as the columns of a 2D float array.

    Shorter columns are padded at the end with ``np.nan`` up to the length of the longest
    array, so column ``j`` holds ``list_of_arrs[j]``.

    Parameters
    ----------
    list_of_arrs : sequence of 1D array-likes

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(max_len, len(list_of_arrs))``, or ``(0, 0)`` for an empty input.
    """
    if len(list_of_arrs) == 0:
        return np.empty((0, 0))
    # Pad to the longest array of *this* list instead of relying on an externally defined length.
    max_len = max(len(arr) for arr in list_of_arrs)
    padded = np.full((len(list_of_arrs), max_len), np.nan)
    for i_col, arr in enumerate(list_of_arrs):
        padded[i_col, :len(arr)] = arr
    return padded.T

# Store NaN-padded 2D versions (rows = units, columns = recordings/conditions) of the tuned-vector
# angles, tuned-vector lengths and p-values of every mouse. The recordings contain different
# numbers of units, so shorter columns are padded with np.nan.
for mouse_id, mouse_entry in dict_mouse_data.items():
    tv_angles = mouse_entry["tv_angles"]
    tv_lengths = mouse_entry["tv_lengths"]
    p_vals = mouse_entry["p_vals"]

    # Longest recording; kept as a module-level variable in case convert_to_np relies on a
    # global `max_len`.
    max_len = max(map(len, tv_angles))

    mouse_entry.update({
        "tv_angles_padded": convert_to_np(tv_angles),
        "tv_lengths_padded": convert_to_np(tv_lengths),
        "p_vals_padded": convert_to_np(p_vals),
    })

    

In [None]:
def _crop_to_center(image, target_shape):
    """Return the centred ``target_shape[0] x target_shape[1]`` window of the 2D ``image``.

    Axis 0 of ``image`` is cropped to ``target_shape[0]`` and axis 1 to ``target_shape[1]``.
    The same number of rows (columns) is removed on both sides.

    Raises
    ------
    ValueError
        If ``image`` is smaller than ``target_shape`` along an axis, or the size difference
        along an axis is odd (a symmetric crop is impossible).
    """
    starts = []
    for axis in (0, 1):
        excess = image.shape[axis] - int(target_shape[axis])
        if excess < 0 or excess % 2 != 0:
            raise ValueError(
                f"cannot centre-crop axis {axis} of size {image.shape[axis]} to {int(target_shape[axis])}"
            )
        starts.append(excess // 2)
    # Use explicit stop indices: a slice such as [k:-k] is empty when k == 0 (nothing to crop).
    return image[starts[0]:starts[0] + int(target_shape[0]),
                 starts[1]:starts[1] + int(target_shape[1])]


for mouse_id in dict_mouse_data.keys():
    templates = dict_mouse_data[mouse_id]["templates"]
    dims_list = dict_mouse_data[mouse_id]["dims_list"]
    if len(templates) > 0:
        # Every template is cropped to the first session's resolution, i.e. the dims that are
        # later passed to register_multisession.
        dict_mouse_data[mouse_id]["templates_cropped"] = [
            _crop_to_center(template, dims_list[0]) for template in templates
        ]
    # TODO: use templates for multisession registration
    

## Use `register_multisession()`

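The function `register_multisession()` takes 3 main arguments. `templates` is optional and is not passed below, because the templates are not stored in the hdf5 files yet: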
- `A`: A list of ndarrays or scipy.sparse.csc matrices with (# pixels X # component ROIs) for each session
- `dims`: Dimensions of the FOV, needed to restore spatial components to a 2D image
- `templates`: List of ndarray matrices of size `dims`, template image of each session

In [None]:
# Register the sessions of each mouse onto one common set of components. The first session's
# resolution is used as the FOV dims; no templates are passed.
n_mice = len(dict_mouse_data.keys())
for i_mouse, (mouse_id, mouse_entry) in enumerate(dict_mouse_data.items()):
    print(f"Working on {mouse_id}, mouse #{i_mouse+1}/{n_mice}")
    A_list = mouse_entry["A_list"]
    dims_list = mouse_entry["dims_list"]
    registration = register_multisession(A=A_list, dims=dims_list[0])
    spatial_union, assignments, matchings = registration
    mouse_entry["spatial_union"] = spatial_union
    mouse_entry["assignments"] = assignments
    mouse_entry["matchings"] = matchings


The function returns 3 variables for further analysis:
- `spatial_union`: csc_matrix (# pixels X # total distinct components), the union of all ROIs across all sessions aligned to the FOV of the last session.
- `assignments`: ndarray (# total distinct components X # sessions). `assignments[i,j]=k` means that component `k` from session `j` has been identified as component `i` from the union of all components, otherwise it takes a `NaN` value. Note that for each `i` there is at least one session index `j` where `assignments[i,j]!=NaN`.
- `matchings`: list of (# sessions) lists. It stores the `spatial_union` indices of the individual components in each session: `matchings[j][k] = i` means that component `k` from session `j` is represented by component `i` in the union of all components `spatial_union`. In other words, `assignments[matchings[j][k], j] = k`.

## Plot the matching

In [None]:
# TODO: extract as function that takes lists or something.
# Goal: be able to use it for plotting various scenarios: plot all cells, plot cell categories (red=PC, ...)
#   plot stable baseline cells as red, all rest as grey

In [None]:
# For every mouse and condition, draw the spatial footprints of all registered components.
# A component keeps the same colour in every condition, so a cell matched across sessions shows
# up in the same colour in each panel.


def _component_mask(A, i_cell, dims):
    """Return a boolean ``dims``-shaped mask of the pixels where column ``i_cell`` of ``A`` is > 0.

    ``A`` is a sparse (n_pixels, n_components) matrix. The column is reshaped in NumPy's
    default (C) order, as the previous hard-coded ``reshape((512, 512))`` did.
    """
    return np.asarray(A[:, i_cell].todense()).reshape(dims) > 0


# squeeze=False keeps `axs` 2D even when only one mouse is loaded (indexed as axs[i_id, i_condition]).
fig, axs = plt.subplots(n_mice, len(conditions), figsize=(24, 24), squeeze=False)

use_continuous_cmap = False
if use_continuous_cmap:
    cm = plt.get_cmap('gist_rainbow')
    colors_arr = cm(np.linspace(0, 1, 30))
    i_shuffled_colors = np.arange(len(colors_arr))  # shuffle colors so neighbouring components differ
    np.random.shuffle(i_shuffled_colors)
    colors_arr = colors_arr[i_shuffled_colors]
else:
    cm = plt.get_cmap("tab20")
    colors_arr = cm(np.linspace(0, 1, 20))

for i_id, mouse_id in enumerate(dict_mouse_data.keys()):
    print(mouse_id)
    mouse_entry = dict_mouse_data[mouse_id]
    assignments = mouse_entry["assignments"]  # (n_independent_components, n_conditions)
    A_list = mouse_entry["A_list"]
    n_conditions = assignments.shape[1]
    dims = tuple(int(d) for d in mouse_entry["dims_list"][0])  # FOV shape, e.g. (512, 512)
    # One RGB image per condition: shape (*dims, 3, n_conditions).
    frames = np.zeros(dims + (3, n_conditions))
    # Each assignment row is one component across all conditions; paint its pixels in one colour.
    for i_component, idxs_component in enumerate(assignments):
        rgba = colors_arr[i_component % len(colors_arr)]  # cycle over the colours
        for i_condition, i_cell in enumerate(idxs_component):
            if np.isnan(i_cell):  # component was not found in this condition
                continue
            # Densify the footprint once per cell, then set r, g, b in a single assignment.
            mask = _component_mask(A_list[i_condition], int(i_cell), dims)
            frames[..., i_condition][mask] = rgba[:3]
    for i_condition, condition in enumerate(conditions):
        ax = axs[i_id, i_condition]
        ax.title.set_text(f"{mouse_id} - {condition}")
        ax.imshow(frames[:, :, :, i_condition])
        ax.set_axis_off()
plt.tight_layout()
plt.show()

## Get "stable baseline place coding cells"
i.e. cells that were detected and place coding (p ≤ 0.05) in both baseline recordings (bl1 and bl2).

In [None]:
def _keep_place_coding(rows, p_vals_session, col, alpha=0.05):
    """Sort assignment ``rows`` by column ``col``, then keep rows whose cell in that session has p <= alpha.

    Cells with a NaN p-value are dropped, because NaN <= alpha is False.
    """
    rows = rows[np.argsort(rows[:, col])]
    is_place_coding = p_vals_session[rows[:, col].astype(np.int32)] <= alpha
    return rows[is_place_coding]


# For every mouse, keep the assignment rows (registered components) that were detected in both
# baselines (columns 0 and 1) and place coding in both. The result is stored as
# "assignments_stable_bl_pc".
for mouse_id, mouse_entry in dict_mouse_data.items():
    print(mouse_id)
    p_vals = mouse_entry["p_vals"]
    assignments = mouse_entry["assignments"]

    # Rows without a NaN in either baseline column.
    detected_in_both_bl = ~np.isnan(assignments[:, :2]).any(axis=1)
    assignments_stable_bl_pc = assignments[detected_in_both_bl]
    print(f"Number of cells with p-value in both bl: {len(assignments_stable_bl_pc)}")

    assignments_stable_bl_pc = _keep_place_coding(assignments_stable_bl_pc, p_vals[0], col=0)
    print(f"Number of bl1 pc cells: {len(assignments_stable_bl_pc)}")

    assignments_stable_bl_pc = _keep_place_coding(assignments_stable_bl_pc, p_vals[1], col=1)
    print(f"Number of bl1+bl2 pc cells: {len(assignments_stable_bl_pc)}")

    # Sanity checks: no NaN left in the baselines, and every kept cell is place coding in both.
    bl1_idx = assignments_stable_bl_pc[:, 0]
    bl2_idx = assignments_stable_bl_pc[:, 1]
    assert not np.isnan(bl1_idx).any()
    assert not np.isnan(bl2_idx).any()
    assert (p_vals[0][bl1_idx.astype(np.int32)] <= 0.05).all()
    assert (p_vals[1][bl2_idx.astype(np.int32)] <= 0.05).all()

    mouse_entry["assignments_stable_bl_pc"] = assignments_stable_bl_pc
    print()

### Plot stable baseline cells

## Pool mice

In [None]:
# Print the number of stable baseline place-coding cells of every mouse.
for mouse_id, mouse_entry in dict_mouse_data.items():
    assignments_stable_bl_pc = mouse_entry["assignments_stable_bl_pc"]
    print(len(assignments_stable_bl_pc))

In [None]:
# Inspect the stable baseline assignments bound by the last iteration of the loop above (i.e. the
# last mouse in dict_mouse_data); np.squeeze removes length-1 axes.
np.squeeze(assignments_stable_bl_pc)

In [None]:
# NOTE(review): placeholder loop. It only looks up each mouse's assignments and prints an empty
# line; nothing is computed yet.
for mouse_id in dict_mouse_data.keys():
    assignments = dict_mouse_data[mouse_id]["assignments"] 
    print()

## Sankey plot of the stable baseline place-coding cells

In [None]:
# Inspect column 2 of mouse "WEZ8917"'s stable baseline place-cell assignments: the matched cell
# indices in the third recording (NaN where the cell was not found). Presumably the "30min"
# recording, given the order in which the files are selected -- confirm.
dict_mouse_data["WEZ8917"]["assignments_stable_bl_pc"][:, 2]

In [None]:
# Inspect column 0 of mouse "WEZ8917"'s stable baseline place-cell assignments: the matched cell
# indices in the first recording (presumably "bl1", given the file selection order -- confirm).
dict_mouse_data["WEZ8917"]["assignments_stable_bl_pc"][:, 0]

In [None]:
# Scratch check of padding by slice assignment: the first len(b) entries of the length-10 zero
# array `a` are overwritten with `b`; the remaining entries stay 0.
a = np.zeros(10)
b = np.array([1, 2, 3])
a[:len(b)] = b

In [None]:
# Sankey diagram: how the stable baseline place-coding cells of one mouse move between categories
# from one recording (condition) to the next. A "cell" is one row of assignments_stable_bl_pc.
# Categories per condition:
#   PC         place coding:      assignment exists and 0 <= p <= p_threshold
#   nPC        not place coding:  assignment exists and p > p_threshold
#   lowA       low activity:      assignment exists but its p-value is NaN
#   invisible                     no assignment in that condition (encoded as p = -1 below)
sankey_mouse_id = "WEZ8917"  # TODO: make the plot for every mouse / for pooled place cells
p_threshold = 0.05
class_names = ["PC", "nPC", "lowA", "invisible"]
class_node_colors = ["red", "blue", "slategrey", "black"]
# Link colour follows the SOURCE category: PC red, nPC blue, lowA light grey, invisible white.
class_link_colors = ["rgba(255, 0, 0, 0.4)", "rgba(0, 0, 255, 0.4)", "rgba(220,220,220, 0.4)", "rgba(255, 255, 255, 0.4)"]
n_classes = len(class_names)


def _session_p_values(assignment_col, p_vals_session):
    """Return one p-value per cell (row) for one session, aligned row-by-row with ``assignment_col``.

    Rows whose assignment is NaN (cell not found in that session) get -1. This keeps them
    distinguishable from a NaN p-value (cell found, but low activity).
    """
    p = np.full(len(assignment_col), -1.0)
    found = ~np.isnan(assignment_col)
    p[found] = p_vals_session[assignment_col[found].astype(np.int32)]
    return p


def _classify_cells(p, threshold):
    """Map per-cell p-values to class indices into ``class_names``: 0=PC, 1=nPC, 2=lowA, 3=invisible."""
    classes = np.full(len(p), 3)  # invisible (p == -1) unless overwritten below
    classes[np.isnan(p)] = 2
    has_p = p >= 0  # False for NaN and for the -1 "invisible" marker, so classes stay exclusive
    classes[has_p & (p <= threshold)] = 0
    classes[has_p & (p > threshold)] = 1
    return classes


stable_assignments = dict_mouse_data[sankey_mouse_id]["assignments_stable_bl_pc"]  # (n_cells, n_conditions)
mouse_p_vals = dict_mouse_data[sankey_mouse_id]["p_vals"]
n_conditions = len(conditions)
n_cells = len(stable_assignments)

# Nodes: n_classes per condition, conditions laid out from left to right.
labels = []
colors = []
xs = []  # location of boxes
ys = []
for i_condition, condition in enumerate(conditions):
    labels.extend(f"{name} {condition}" for name in class_names)
    colors.extend(class_node_colors)
    xs.extend([0.2*i_condition]*n_classes)
    ys.extend([0.2*i for i in range(n_classes)])

# Class of every cell in every condition. Column i of the assignments belongs to session i, whose
# p-values are mouse_p_vals[i]; rows stay aligned, so row r is the same cell in every condition.
cell_classes = [
    _classify_cells(_session_p_values(stable_assignments[:, i_condition], mouse_p_vals[i_condition]), p_threshold)
    for i_condition in range(n_conditions)
]

# Links: for every transition i -> i+1, one link per (source class, target class) pair.
# The node index of class c in condition i is n_classes*i + c.
sources = []
targets = []
values = []
link_colors = []
for i_condition in range(n_conditions - 1):  # last condition has no outgoing links
    src_classes = cell_classes[i_condition]
    tgt_classes = cell_classes[i_condition + 1]
    for tgt in range(n_classes):
        for src in range(n_classes):
            sources.append(n_classes*i_condition + src)
            targets.append(n_classes*(i_condition + 1) + tgt)
            values.append(int(np.sum((src_classes == src) & (tgt_classes == tgt))))
            link_colors.append(class_link_colors[src])
    # Each cell is in exactly one class per condition, so the outflow must equal the cell count.
    assert sum(values[-n_classes**2:]) == n_cells

fig = go.Figure(data=[go.Sankey(
    arrangement="freeform",
    node=dict(
        pad=10,
        line=dict(color="black", width=0.5),
        label=labels,
        color=colors,
        x=xs,
        y=ys,
    ),
    link=dict(
        source=sources,  # node indices into `labels`
        target=targets,
        value=values,
        color=link_colors,
    ),
)])

fig.update_layout(title_text="PC-nPC-LA-IN", font_size=10)
fig.show()


In [None]:
# TODO: the connection numbers are not good. Add assert that input equals output flow?

In [None]:
# List every Sankey node label together with its (x, y) position.
for i_node, x_pos in enumerate(xs):
    print(f"{labels[i_node]}: {x_pos} {ys[i_node]}")

In [None]:
# Toy Sankey example with hard-coded nodes A-G at manually chosen (x, y) positions and
# arrangement="snap". It does not use any recording data.
fig = go.Figure(go.Sankey(
    arrangement = "snap",
    node = {
        "label": ["A", "B", "C", "D", "E", "F", "G"],
        "x": [0.2, 0.2, 0.5, 0.7, 0.3, 0.5, 0.5],
        "y": [0.7, 0.5, 0.2, 0.4, 0.2, 0.3, 1.0],
        'pad':10},  # 10 Pixels
    # link i goes from node source[i] to node target[i] (indices into "label") with width value[i]
    link = {
        "source": [0, 0, 1, 2, 5, 4, 3, 5, 1, 4],
        "target": [5, 3, 4, 3, 0, 2, 2, 3, 6, 6],
        "value": [1, 2, 1, 1, 1, 1, 1, 2, 2, 3]}))

fig.show()

In [None]:
# TODO: despite low quantity, make the shankey plot for pooled place cells