In [1]:
import numpy as np
import os

from utils import set_params
from utils import load_pickle, extract_used_data
from utils import mergeAB, align_track, z_score_on
from utils import plot3d_lines_to_html
from utils import umap_fit

from utils.config import Params

In [2]:

def umap_on(data: dict, params: Params, ana_tt: list[str], ana_bt: list[str]) -> tuple:
    """Fit one UMAP model on the selected firing blocks and embed each block as a 3D line.

    Every non-``None`` entry of ``data["z_scored_firing"]`` addressed by
    ``params.ana_index_grid(ana_tt, ana_bt)`` is stacked column-wise with
    ``np.hstack``; the transpose of that matrix is handed to ``umap_fit``.
    Each block is then passed through ``model.transform`` on its own and the
    first three embedding columns become one line.

    Args:
        data: Must provide ``"z_scored_firing"`` (indexable by the grid
            indices; entries may be ``None``) and ``"aligned_zones_id"``
            (indexable by position; each entry exposes a start at ``[0]`` and
            an end at ``[1]``).
        params: Supplies ``ana_index_grid``, the ``umap_*`` settings forwarded
            to ``umap_fit``, and the ``tt`` / ``bt`` sequences used for labels.
        ana_tt: First selector argument forwarded to ``params.ana_index_grid``.
        ana_bt: Second selector argument forwarded to ``params.ana_index_grid``.

    Returns:
        ``(xlines, ylines, zlines, lineLabels, lineColors, segment_colors)``.
        The first five are parallel lists with one entry per embedded block;
        ``segment_colors`` is a list of ``(start, end, color)`` tuples, one per
        entry of ``data["aligned_zones_id"]``.

    Raises:
        ValueError: If no selected block holds data, or if the embedding has
            fewer than three columns (the z coordinate could not be read).
    """
    # Walk the index grid once so the blocks used for fitting and the blocks
    # embedded afterwards cannot drift apart.
    selected = []
    for index in params.ana_index_grid(ana_tt, ana_bt):
        fr = data["z_scored_firing"][index]
        if fr is not None:
            selected.append((index, fr))

    if not selected:
        raise ValueError('No data to umap')

    # fr_sum: [num_neurons, (len_track * num_types)]
    fr_sum = np.hstack([fr for _, fr in selected])

    model = umap_fit(fr_sum.T,
                     n_neighbors=params.umap_n_neighbors,
                     min_dist=params.umap_min_dist,
                     n_components=params.umap_n_components,
                     metric=params.umap_metric,
                     )

    xlines, ylines, zlines = [], [], []
    lineLabels, lineColors = [], []

    for index, fr in selected:
        emb = model.transform(fr.T)

        # Columns 0..2 are read below; fail with a clear message instead of an
        # opaque IndexError when the model was configured with fewer components.
        n_dims = np.shape(emb)[1]
        if n_dims < 3:
            raise ValueError(
                f"umap_on needs at least 3 embedding dimensions, got {n_dims}; "
                "set umap_n_components >= 3"
            )

        xlines.append(emb[:, 0])
        ylines.append(emb[:, 1])
        zlines.append(emb[:, 2])

        row, col = index
        lineLabels.append(f"{params.tt[row]} {params.bt[col]}")

        # Line color depends only on the row of the grid index.
        if row >= 4:
            color = "#ea8428"
        elif row >= 1:
            color = "#288eea"
        else:
            color = "#808080"
        lineColors.append(color)

    # Zones 0, 2 and 4 get their own color; every other zone shares the fallback.
    zone_palette = {0: "#ffd700", 2: "#4d03fa", 4: "#66cd33"}
    default_zone_color = "#90acc1"

    segment_colors = []
    zones = data["aligned_zones_id"]
    for i in range(len(zones)):
        start = zones[i][0]
        end = zones[i][1]
        segment_colors.append((start, end, zone_palette.get(i, default_zone_color)))

    return xlines, ylines, zlines, lineLabels, lineColors, segment_colors

In [3]:
# Analysis configuration: input/output locations plus the settings handed to set_params.

sub_directory = "flexible_shift"
results_path = "../../results/"
data_path = "../../data/"

params = set_params(
    tt_preset="merge",
    bt_preset="basic",
    data_path=data_path,
    results_path=results_path,
    sub_directory=sub_directory,
    umap_n_neighbors=50,
    umap_min_dist=0.5,
    len_pos_average=10,
)


In [4]:

# data_dict = load_pickle(os.path.join(params.data_path, params.sub_directory))

# for key, value in data_dict.items():
    
#     data = data_dict[key]
#     data = extract_used_data(data)
#     mergeAB(data)
    
#     align_track(data, params)
#     # z_score_on(data, params, ana_tt=['*'], ana_bt=['*'])
#     for trial in params.tt:
#         for behavior in params.bt:
#             z_score_on(data, params, ana_tt=[trial], ana_bt=[behavior])
#     # z_score_on(data, params, ana_tt=['origin'], ana_bt=['*'])
#     # z_score_on(data, params, ana_tt=['pattern_*'], ana_bt=['*'])
#     # z_score_on(data, params, ana_tt=['position_*'], ana_bt=['*'])
#     x_lines, y_lines, z_lines, line_labels, line_colors, segment_colors = umap_on(data, params, ana_tt=['*'], ana_bt=['correct'])

#     # plot
#     out_path = os.path.join(params.results_path, params.sub_directory, "umap_align_track", f"{key}_umap.html")
#     axis_labels = dict(xaxis_title="UMAP1", yaxis_title="UMAP2", zaxis_title="UMAP3")
#     plot3d_lines_to_html(x_lines, y_lines, z_lines, line_labels = line_labels, line_colors = line_colors,
#                          segment_per_line = segment_colors, 
#                          out_html = out_path, 
#                          axis_labels = axis_labels,
#                          start_marker_size=4,
#                          end_marker_size=4)
    

In [5]:
# Single-recording run: load one pickle, preprocess it, embed it with umap_on and plot the lines.
# The path is assembled from the configuration cell's data_path / sub_directory rather than
# repeated as a literal, so changing the configuration is picked up here as well.
session_file = "RDP02-PFCsep.pkl"
data = load_pickle(os.path.join(data_path, sub_directory, session_file))

data = extract_used_data(data)
mergeAB(data)

align_track(data, params, is_gaussian=False)

# z_score_on is called once per (trial, behavior) pair.
# Alternative groupings kept for reference:
# z_score_on(data, params, ana_tt=['*'], ana_bt=['*'])
for trial in params.tt:
    for behavior in params.bt:
        z_score_on(data, params, ana_tt=[trial], ana_bt=[behavior])
# z_score_on(data, params, ana_tt=['origin'], ana_bt=['*'])
# z_score_on(data, params, ana_tt=['pattern_*'], ana_bt=['*'])
# z_score_on(data, params, ana_tt=['position_*'], ana_bt=['*'])

x_lines, y_lines, z_lines, line_labels, line_colors, segment_colors = umap_on(
    data, params, ana_tt=['*'], ana_bt=['correct']
)

axis_labels = dict(xaxis_title="UMAP1", yaxis_title="UMAP2", zaxis_title="UMAP3")
plot3d_lines_to_html(x_lines, y_lines, z_lines,
                     line_labels=line_labels,
                     line_colors=line_colors,
                     segment_per_line=segment_colors,
                     axis_labels=axis_labels,
                     start_marker_size=4,
                     end_marker_size=4)