In [1]:
import numpy as np
import os

from utils import set_params
from utils import load_pickle, extract_used_data
from utils import mergeAB, align_track, z_score_on
from utils import plot3d_lines_to_html
from utils import pca_fit
from utils import JPCA

from utils.config import Params

In [2]:
def jpca_on(data: dict, params: Params, ana_tt: list[str], ana_bt: list[str]):
    """Fit jPCA on the selected firing-rate blocks and return 3-D line specs.

    Every ``(row, col)`` index produced by ``params.ana_index_grid(ana_tt, ana_bt)``
    whose ``data["z_scored_firing"][index]`` is not ``None`` is selected. The
    selected blocks are concatenated column-wise (``np.hstack``), passed through
    the ``denoise`` matrix returned by ``pca_fit(..., params.pca_n_components)``,
    and used to fit a ``JPCA`` model. Each block is then projected on its own
    with ``n_planes=2`` and its first three output columns become one 3-D line.

    Args:
        data: session dict. Must contain ``"z_scored_firing"`` (indexable by the
            grid indices; entries are arrays or ``None``) and
            ``"aligned_zones_id"`` (sequence whose items expose ``[0]``/``[1]``
            as segment start/end).
        params: analysis parameters; ``params.tt[row]`` / ``params.bt[col]``
            provide the names used in the line labels.
        ana_tt, ana_bt: selectors forwarded to ``params.ana_index_grid``.

    Returns:
        ``(xlines, ylines, zlines, lineLabels, lineColors, segment_colors)``:
        per-line coordinate arrays for jPCA axes 1-3, a ``"<tt> <bt>"`` label
        and a hex color per line, and ``(start, end, color)`` tuples for each
        aligned zone.

    Raises:
        ValueError: if no selected block is available.
    """
    # Select the available blocks once; the same (index, block) pairs drive
    # both the jPCA fit and the per-condition projection below.
    selected = [
        (index, data["z_scored_firing"][index])
        for index in params.ana_index_grid(ana_tt, ana_bt)
        if data["z_scored_firing"][index] is not None
    ]
    if not selected:
        raise ValueError('No data to jpca')

    # fr_sum: [num_neurons, (len_track * num_types)]
    fr_sum = np.hstack([fr for _, fr in selected])

    # `denoise` comes from PCA on the concatenated data; presumably a projection
    # onto the top `params.pca_n_components` PCs — confirm against pca_fit.
    _, denoise = pca_fit(fr_sum, params.pca_n_components)
    fr_sum = denoise @ fr_sum

    jpca = JPCA()
    jpca.fit(fr_sum.T)

    xlines, ylines, zlines = [], [], []
    lineLabels, lineColors = [], []

    for (row, col), fr in selected:
        # Project each condition with the same denoise matrix used for the fit.
        vecs = jpca.transform((denoise @ fr).T, n_planes=2)

        xlines.append(vecs[:, 0])
        ylines.append(vecs[:, 1])
        zlines.append(vecs[:, 2])

        tt_name = params.tt[row]
        lineLabels.append(f"{tt_name} {params.bt[col]}")

        # Color lines by trial-type family: position (orange), pattern (blue),
        # anything else (grey).
        if "position" in tt_name:
            color = "#ea8428"
        elif "pattern" in tt_name:
            color = "#288eea"
        else:
            color = "#808080"
        lineColors.append(color)

    # Zones 0, 2 and 4 get distinct highlight colors; all other zones share a
    # neutral one.
    zone_highlight = {0: "#ffd700", 2: "#4d03fa", 4: "#66cd33"}
    segment_colors = [
        (zone[0], zone[1], zone_highlight.get(i, "#90acc1"))
        for i, zone in enumerate(data["aligned_zones_id"])
    ]

    return xlines, ylines, zlines, lineLabels, lineColors, segment_colors



In [3]:
# Analysis parameters shared by every cell below.
# Paths are relative to this notebook; sub_directory picks the dataset.
data_path = "../../data/"
results_path = "../../results/"
sub_directory = 'flexible_shift'

params = set_params(
    tt_preset='merge',
    bt_preset='basic',
    data_path=data_path,
    results_path=results_path,
    sub_directory=sub_directory,
    # Forwarded to pca_fit as the number of components used to build the
    # denoising matrix applied before jPCA.
    pca_n_components=10,
)


In [4]:

# For every session: preprocess, fit jPCA on the 'correct' condition and save
# one interactive 3-D HTML figure.
# NOTE(review): load_pickle presumably unpickles from disk — pickle can execute
# arbitrary code on load, so only point data_path at trusted files.
data_dict = load_pickle(os.path.join(params.data_path, params.sub_directory))

# Loop-invariant plot settings.
out_dir = os.path.join(params.results_path, params.sub_directory, "jpca_align_track")
# Create the output folder up front rather than relying on
# plot3d_lines_to_html to create missing directories.
os.makedirs(out_dir, exist_ok=True)
axis_labels = dict(xaxis_title="JPCA1", yaxis_title="JPCA2", zaxis_title="JPCA3")

for key, value in data_dict.items():

    data = extract_used_data(value)
    mergeAB(data)

    align_track(data, params)
    # z-score each trial-type group in its own call (origin, pattern_*,
    # position_*) rather than a single call over all trial types (ana_tt=['*']).
    for tt_group in (['origin'], ['pattern_*'], ['position_*']):
        z_score_on(data, params, ana_tt=tt_group, ana_bt=['*'])

    x_lines, y_lines, zlines, line_labels, line_colors, segment_colors = jpca_on(
        data, params, ana_tt=['*'], ana_bt=['correct'])

    # plot
    out_path = os.path.join(out_dir, f"{key}_jpca_3D.html")
    plot3d_lines_to_html(x_lines, y_lines, zlines,
                         line_labels=line_labels,
                         line_colors=line_colors,
                         axis_labels=axis_labels,
                         segment_per_line=segment_colors,
                         out_html=out_path,
                         line_width=3,
                         start_marker_size=3,
                         end_marker_size=3)
    