In [None]:
%matplotlib widget
from pathlib import Path
from tqdm import tqdm
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import flammkuchen as fl

import seaborn as sns
sns.set(style="ticks", palette="deep")
cols = sns.color_palette()

from lotr.experiment_class import LotrExperiment
from bouter.utilities import crop

from lotr.plotting import add_cbar, color_plot, despine, add_scalebar
from lotr.pca import get_fictive_trajectory, pca_and_phase, \
                     linear_regression, fit_phase_neurons
from lotr.utils import zscore, interpolate, reduce_to_pi, linear_regression, get_rot_matrix

from tqdm import tqdm

In [None]:
# Root folder holding one sub-folder per fish; each fish folder with a
# `selected.h5` file is an experiment to analyse.
master_path = Path("/Users/luigipetrucco/Desktop/all_source_data/full_ring")
# Sorted so that the order (and thus which fish `file_list[:1]` picks) does not
# depend on the filesystem's directory-listing order.
file_list = sorted(f.parent for f in master_path.glob("*/*[0-9]_f*/selected.h5"))

In [None]:
# Relate each head-direction neuron's position on the PCA ring to its
# anatomical position. The analysis is run twice per fish: once with the real
# neuron-to-coordinate assignment (`axs_good`), and once with the coordinates
# randomly permuted across neurons (`axs_shuf`) as a control.
SHUFFLE_SEED = 0  # fixed seed so the shuffled control is reproducible
COORDS_FLOOR_PERCENTILE = 2  # per-axis percentile subtracted before clipping at 0

rng = np.random.default_rng(SHUFFLE_SEED)

fig, axs_good = plt.subplots(2, 2, figsize=(8., 6.))
fig2, axs_shuf = plt.subplots(2, 2, figsize=(8., 6.))
fig.suptitle("HDN neurons: real coordinates")
fig2.suptitle("HDN neurons: shuffled coordinates (control)")

for f in file_list[:1]:
    exp = LotrExperiment(f)
    traces = exp.traces
    hdn_indexes = exp.hdn_indexes
    t_slice = exp.pca_t_slice

    # The PCA depends only on the traces, not on the coordinate permutation,
    # so it is computed once per fish and shared by both conditions.
    pcaed, phase, pca, hf_c = pca_and_phase(traces[t_slice, hdn_indexes].T)
    centered_pcs = pcaed[:, :2] - hf_c[:2]

    # Identity for the real data; seeded in-place permutation for the control.
    for axs, shuf_f in zip([axs_good, axs_shuf], [lambda x: x, rng.shuffle]):
        sequence = np.arange(len(hdn_indexes))
        shuf_f(sequence)

        coords = exp.coords[hdn_indexes, :][sequence, :]
        de_meaned_coords = coords - np.mean(coords, axis=0)

        # Per anatomical axis: subtract a low percentile, clip negatives to 0,
        # normalise to unit sum over neurons, then de-mean.
        norm_coords = coords - np.percentile(coords, COORDS_FLOOR_PERCENTILE, axis=0)
        norm_coords[norm_coords < 0] = 0
        norm_coords = norm_coords / np.sum(norm_coords, 0)
        norm_coords = norm_coords - np.mean(norm_coords, 0)

        # Column k of avg_vects is the sum over neurons of the (PC1, PC2)
        # position weighted by anatomical axis k; avg_angles holds its angle.
        avg_vects = np.einsum('ij,ik->jk', centered_pcs, norm_coords)
        avg_angles = np.angle(avg_vects[0, :] + 1j * avg_vects[1, :])
        # Circular mean of the angles of anatomical axes 1 and 2.
        mean_angle = np.angle(np.sum(np.cos(avg_angles[1:]))
                              + 1j * np.sum(np.sin(avg_angles[1:])))

        # Reflect PC2 depending on the sign of the wrapped angle from axis 1 to
        # axis 2, so that the ring orientation is consistent across fish.
        invert_mat = np.array([[1, 0], [0, -np.sign(reduce_to_pi(avg_angles[2] - avg_angles[1]))]])

        # Rotate the mean anatomical direction onto 0, apply the reflection,
        # then rotate by 3/4 pi. Assumes get_rot_matrix(theta) returns a 2x2
        # rotation matrix -- TODO confirm against lotr.utils.
        rotated_pcs = (get_rot_matrix(3 * np.pi / 4) @ invert_mat
                       @ get_rot_matrix(-mean_angle) @ centered_pcs.T).T
        rotated_phase = np.angle(rotated_pcs[:, 0] + 1j * rotated_pcs[:, 1])

        # Top-left: aligned PC plane, coloured by anatomical angle (axes 1-2).
        axs[0, 0].scatter(rotated_pcs[:, 0], rotated_pcs[:, 1],
                          c=np.angle(norm_coords[:, 1] + 1j * norm_coords[:, 2]), s=8,
                          lw=0, alpha=0.8, cmap="twilight")
        axs[0, 0].axis("equal")

        # Top-right: anatomical plane, coloured by phase on the aligned ring.
        axs[0, 1].scatter(de_meaned_coords[:, 1], de_meaned_coords[:, 2],
                          c=rotated_phase, s=8, lw=0, alpha=0.8,
                          cmap="twilight")
        axs[0, 1].axis("equal")

        # Bottom: ring phase against each anatomical coordinate.
        axs[1, 0].scatter(rotated_phase, de_meaned_coords[:, 1], c="k", s=5, alpha=0.3, lw=0)
        axs[1, 1].scatter(rotated_phase, de_meaned_coords[:, 2], c="k", s=5, alpha=0.3, lw=0)


In [None]:
# Add labelled scale bars to both figures: the PC plane sits in the top-left
# panel, the anatomical plane in the top-right one.
for axs in (axs_good, axs_shuf):
    for (row, col), (xlabel, ylabel) in {(0, 0): ("PC1", "PC2"),
                                         (0, 1): ("lef.-rig.", "pos.-ant.")}.items():
        add_scalebar(axs[row, col], xlabel=xlabel, ylabel=ylabel)

In [None]:
# Grab the top-left (PC-plane) panel of the shuffled-control figure.
ax = axs_shuf[0][0]

In [None]:
# NOTE(review): the labelling loop over [axs_good, axs_shuf] already adds this
# scale bar to axs[0, 1]; running this cell draws another one on the same axis.
add_scalebar(axs[0][1], ylabel="pos.-ant.", xlabel="lef.-rig.")

In [None]:
# NOTE(review): exact repeat of an earlier call; each run adds one more scale
# bar on top of the one already drawn on axs[0, 1].
add_scalebar(axs[0][1], ylabel="pos.-ant.", xlabel="lef.-rig.")

In [None]:
# Distance between the first two y-tick locations of `ax`.
np.diff(ax.yaxis.get_ticklocs()[:2])[0]

In [None]:
# Manual L-shaped scale bar on the raw PC1/PC2 plane. `pc_ax` is not defined
# anywhere above, so create a dedicated axis showing the same pcaed[:, :2]
# coordinates that the bar position is computed from.
fig_pc, pc_ax = plt.subplots(1, 1, figsize=(3., 3.))
pc_ax.scatter(pcaed[:, 0], pcaed[:, 1], c="k", s=8, lw=0, alpha=0.8)
pc_ax.axis("equal")

b_len = 3  # scale bar length, in PC units
# Place the bar's corner one bar-length below/left of the data's lower-left extent.
bar_pos_x, bar_pos_y = pcaed[:, :2].min(0) - b_len
pc_ax.plot([bar_pos_x, bar_pos_x, bar_pos_x + b_len],
           [bar_pos_y + b_len, bar_pos_y, bar_pos_y], lw=0.5, c=(0.3,) * 3)
pc_ax.text(bar_pos_x, bar_pos_y + b_len / 2, "PC2", ha="right", va="center",
           rotation='vertical', fontsize=8)
pc_ax.text(bar_pos_x + b_len / 2, bar_pos_y, "PC1", ha="center", va="top", fontsize=8)
pc_ax.axis("off")

In [None]:


# Standalone check of add_scalebar on synthetic data: 100 samples of Gaussian
# noise at DEMO_FS Hz (2 s). Distinct names are used so that this cell does not
# clobber `f` (the fish path used in later titles) or `ax`.
DEMO_FS = 50  # sampling rate of the synthetic trace, Hz
demo_rng = np.random.default_rng(0)  # seeded so the demo figure is reproducible

fig_demo, ax_demo = plt.subplots(1, 1, figsize=(9, 3))
ax_demo.plot(np.arange(100) / DEMO_FS, demo_rng.standard_normal(100))

add_scalebar(ax_demo, y_len=2, x_units="s")

In [None]:
# Debug: pair each row of the 2x2 `axs` grid with one PC array (row 0 with the
# centered PCs, row 1 with the rotated PCs). zip() takes the two sequences as
# separate arguments; passing them wrapped in one list yields 1-tuples.
for ax_row, pcs in zip(axs, [centered_pcs, rotated_pcs]):
    print(ax_row)

In [None]:
# Sanity check on the centered, unrotated PC plane: neurons coloured by
# anatomical angle (axes 1-2), plus lines from the origin along the directions
# of anatomical axes 1 and 2 and along their circular mean. hf_c[2] sets the
# line length -- presumably the fitted ring radius; TODO confirm in pca_and_phase.
fig, ax = plt.subplots(1, 1, figsize=(3., 3.))

ax.scatter(centered_pcs[:, 0], centered_pcs[:, 1],
           c=np.angle(norm_coords[:, 1] + 1j * norm_coords[:, 2]), cmap="twilight")

radius = hf_c[2]
for angle in (avg_angles[1], avg_angles[2]):
    ax.plot([0, np.cos(angle) * radius], [0, np.sin(angle) * radius])

mean_angle = np.angle(np.sum(np.cos(avg_angles[1:])) + 1j * np.sum(np.sin(avg_angles[1:])))
ax.plot([0, np.cos(mean_angle) * radius], [0, np.sin(mean_angle) * radius])

# The analysis loop only processes file_list[:1], so that is the fish shown here.
ax.set_title(f"{file_list[0].name} - {reduce_to_pi(avg_angles[2] - avg_angles[1]):.2f}")
ax.axis("equal")

In [None]:
plt.figure(figsize=(3, 3))