## Reorient temporal PCA to match anatomy
In this notebook, we look at the anatomical distribution of head-direction neurons across fish, to assess whether the anatomical location of a cell is consistent with its position in the topology of the network activity.

#### TODO
 - [ ] redefine the centering of the coordinates after morphing is done
 - [ ] discuss the method used to find the two axes; it might be suboptimal or not work in all cases

In [None]:
%matplotlib widget
from pathlib import Path
import numpy as np
from matplotlib import pyplot as plt

from lotr import LotrExperiment, DATASET_LOCATION

from lotr.plotting import despine, add_scalebar, get_circle_xy, color_stack
from lotr.pca import pca_and_phase
from lotr.utils import reduce_to_pi, get_rot_matrix, get_vect_angle

# import seaborn as sns
# sns.set(style="ticks", palette="deep")
# cols = sns.color_palette()

In [None]:
# List all experiments. Sort them so that the processing (and plotting) order is
# deterministic, instead of depending on the directory order returned by glob:
master_path = Path(DATASET_LOCATION)
file_list = sorted(f.parent for f in master_path.glob("*/*[0-9]_f*/selected.h5"))

### Showcase analysis for a single fish

In [None]:
path = master_path / "210926_f0" / "210926_f0_gainmod"  # "210314_f1" / "210314_f1_natmov"
exp = LotrExperiment(path)

# Compute PCA in time on the head-direction neurons and fit a circle:
traces = exp.traces[exp.pca_t_slice, exp.hdn_indexes]
pcaed, phase, pca, circle_params = pca_and_phase(traces.T)

# Center the first two PCs on the center of the fitted circle
# (circle_params[:2] is used as center and circle_params[2] as radius below):
centered_pcs = pcaed[:, :2] - circle_params[:2]

f, ax = plt.subplots(figsize=(2.5, 2.5))

ax.scatter(centered_pcs[:, 0], centered_pcs[:, 1])
ax.plot(*get_circle_xy((0, 0, circle_params[2])), c=(0.2,) * 3)
ax.axis("equal")
# Decorate this specific axis *before* showing the figure: with a blocking
# backend, plt.show() returns only after the figure is closed, so a scalebar
# added afterwards would never be displayed.
add_scalebar(ax, xlabel="PC1", ylabel="PC2", xlen=30, ylen=30)
plt.show()

In [None]:
# !!! This will change with registered coordinates
def get_normalized_coords(coords, low_percentile=2):
    """Center ROI coordinates and convert them into zero-mean per-axis weights.

    Parameters
    ----------
    coords : np.ndarray
        Array of coordinates with one row per ROI and one column per axis.
        It is not modified.
    low_percentile : float, optional
        Per-axis percentile used as the zero point of the weights. Values
        below it are clipped to 0 to limit the influence of outliers.
        Defaults to 2, the value that was previously hard-coded.

    Returns
    -------
    centered_coords : np.ndarray
        ``coords`` minus their per-axis mean (coarse registration for plotting).
    norm_coords : np.ndarray
        Per-axis weights. Each column is shifted so that ``low_percentile`` maps
        to 0, clipped at 0, normalized to sum to 1, and finally mean-subtracted,
        so that it sums to 0. An axis with no spread above the low percentile
        gives a column of zeros instead of NaNs.
    """
    # Subtract the mean to coarsely register the coordinates for plotting:
    centered_coords = coords - np.mean(coords, axis=0)

    # For the PCA reorientation, convert coordinates into weights. First shift
    # each axis so that its low percentile is 0, and clip what falls below it:
    norm_coords = coords - np.percentile(coords, low_percentile, axis=0)
    norm_coords[norm_coords < 0] = 0

    # Normalize every axis to sum to 1. An axis whose weights are all 0
    # (e.g. constant coordinates) would otherwise give 0 / 0 = NaN:
    col_sums = np.sum(norm_coords, axis=0)
    col_sums = np.where(col_sums == 0, 1, col_sums)
    norm_coords = norm_coords / col_sums

    # Remove the mean, so that the weights of each axis sum to 0:
    norm_coords = norm_coords - np.mean(norm_coords, axis=0)

    return centered_coords, norm_coords

# Centered coordinates and anatomical weights of the head-direction neurons:
centered_coords, w_coords = get_normalized_coords(coords=exp.coords[exp.hdn_indexes, :])

In [None]:
f, axs = plt.subplots(1, 3, figsize=(7.5, 2.5))

# One panel per anatomical axis: PC space colored by the weight along that axis.
axis_titles = ["pos. z", "pos. left-right", "pos. rostr.-caud."]
for axis_idx, (ax, title) in enumerate(zip(axs, axis_titles)):
    ax.scatter(centered_pcs[:, 0], centered_pcs[:, 1],
               c=w_coords[:, axis_idx], cmap="viridis")
    ax.plot(*get_circle_xy((0, 0, circle_params[2])), c=(0.2,) * 3)
    ax.axis("equal")
    ax.set_title(title)
    add_scalebar(ax, xlabel="PC1", ylabel="PC2", xlen=30, ylen=30, disable_axis=False)
    despine(ax, "all")

#### Fit to match anatomy 

In [None]:
def reorient_pcs(centered_pcs, w_coords, final_th_shift=3 * np.pi / 4):
    """Rotate (and possibly mirror) the PC plane so that it matches the anatomy.

    Parameters
    ----------
    centered_pcs : np.ndarray
        (n_neurons, 2) first two PCs, centered on the fitted circle.
    w_coords : np.ndarray
        (n_neurons, n_axes) anatomical weights, e.g. from
        ``get_normalized_coords``. Columns 1 (left-right) and 2
        (rostro-caudal) are used.
    final_th_shift : float, optional
        Arbitrary final rotation (radians) applied for plotting purposes.
        Defaults to 3*pi/4, the value that was previously hard-coded.

    Returns
    -------
    np.ndarray
        (n_neurons, 2) reoriented PCs.
    """
    # Sum of the PC vectors across the population, weighting each neuron by
    # its coordinate along each anatomical axis. With centered_pcs (n, 2) and
    # w_coords (n, n_axes), the result has shape (2, n_axes): column k is the
    # weighted PC vector for anatomical axis k.
    avg_vects = np.einsum('ij,ik->jk', centered_pcs, w_coords)

    # Angle of the weighted vector of each anatomical axis.
    # NOTE(review): assumes get_vect_angle reads the two PC components along
    # axis 0 and returns one angle per column — confirm in lotr.utils.
    avg_angles = get_vect_angle(avg_vects)

    # Circular mean of the left-right and rostro-caudal angles:
    lat_sag_angles = avg_angles[1:]
    mean_angle = np.angle(np.sum(np.cos(lat_sag_angles))
                          + 1j * np.sum(np.sin(lat_sag_angles)))

    # Since PC signs can be arbitrary, left and right may be flipped. Decide based
    # on the sign of the (wrapped) angle difference between the rostro-caudal and
    # left-right axes. np.sign(0) == 0 would produce a singular matrix that
    # collapses PC2, so a zero difference is treated as "no flip".
    s = -np.sign(reduce_to_pi(avg_angles[2] - avg_angles[1]))
    if s == 0:
        s = 1
    invert_mat = np.array([[1, 0], [0, s]])

    # Combine all transformations: rotate the mean angle to 0, mirror if needed,
    # then apply the arbitrary final rotation so that the most rostral ROIs end
    # up in the upper part of the plot.
    transform = get_rot_matrix(final_th_shift) @ invert_mat @ get_rot_matrix(-mean_angle)
    return (transform @ centered_pcs.T).T

# Reorient the centered PCs to match the anatomical axes:
rotated_pcs = reorient_pcs(centered_pcs=centered_pcs, w_coords=w_coords)

In [None]:
f, ax = plt.subplots(figsize=(2.5, 2.5))

# Rotated PCs, colored by the rostro-caudal anatomical weight (column 2):
ax.scatter(rotated_pcs[:, 0], rotated_pcs[:, 1], c=w_coords[:, 2])
ax.plot(*get_circle_xy((0, 0, circle_params[2])), c=(0.2,) * 3)
ax.axis("equal")
# Add the scalebar to this specific axis before showing the figure, so that it
# is drawn even with a blocking backend:
add_scalebar(ax, xlabel="PC1 (lat.)", ylabel="PC2 (sag.)", xlen=60, ylen=60)
plt.show()

In [None]:
# We can now calculate a phase for each neuron from their position in this rotated space:
rotated_pc_phase = np.angle(rotated_pcs[:, 0] + 1j * rotated_pcs[:, 1])

# And, for plotting, computing an "anatomical phase", aka position around the anatomy:
# NOTE(review): this uses the normalized weights (w_coords), while the all-fish
# loop below computes the anatomical phase from centered_coords — confirm which
# one is intended.
anatomical_phase = np.angle(w_coords[:, 1] + 1j * w_coords[:, 2])

# And plot it over the anatomy.
# Rows: colored by rotated-PC phase / by anatomical phase.
# Columns: rotated PC plane / anatomical coordinates (axes 1 and 2) / ROI stack projection.
f, axs = plt.subplots(2, 3, figsize=(8, 4))
for j, (color_scheme, lab) in enumerate(zip([rotated_pc_phase, anatomical_phase],
                                         ["color by phase", "color by anatomy"])):
    for i, scatter_vals in enumerate([rotated_pcs, centered_coords[:, 1:]]):
        # The anatomical panel (i == 1) is drawn with its x coordinate negated (mirrored):
        x = scatter_vals[:, 0] if i == 0 else -scatter_vals[:, 0]
        axs[j, i].scatter(x, scatter_vals[:, 1],
                      c=color_scheme, cmap="twilight")
        if i == 1:
            add_scalebar(axs[j, i], xlabel="lf.-rt.", ylabel="post.-ant.", xlen=60, ylen=60,
                         ypos=-60)
        else:
            add_scalebar(axs[j, i], xlabel="rPC1", ylabel="rPC2", xlen=45, ylen=45,
                        disable_axis=False)
        axs[j, i].axis("equal")
        despine(axs[j, i], "all")
    # Per-ROI values: the current color variable for HD neurons, NaN for all other ROIs.
    n_phases = np.full(exp.n_rois, np.nan)
    n_phases[exp.hdn_indexes] = color_scheme
    colored = color_stack(exp.rois_stack, variable=n_phases, color_scheme="twilight")

    # Collapse the colored stack along axis 2: every plane overwrites `proj`
    # wherever its values are > 0, so later planes win.
    # NOTE(review): the mask `this_plane > 0` is element-wise over the 4 channels,
    # so a pixel covered in several planes can mix channels from different planes;
    # a per-pixel mask such as `(this_plane > 0).any(-1)` may be intended —
    # confirm against color_stack's output.
    proj = np.zeros(colored.shape[:2] + (4,), dtype=np.uint8)
    for i in range(colored.shape[2]):
        this_plane = colored[:, :, i, :]
        proj[this_plane > 0] = this_plane[this_plane > 0]
    # Crop with hard-coded bounds and flip both image axes (negative-step slices):
    axs[j, 2].imshow(proj[-50:120:-1, -20:20:-1, :])
    add_scalebar(axs[j, 2], xlabel="lf.-rt.", ylabel="post.-ant.", xlen=60, ypos=120, ylen=-60)

    axs[j, 0].set_ylabel(lab)


In [None]:
# Maximum projection of the ROI stack along axis 2:
rois_proj = np.max(exp.rois_stack, axis=2)

In [None]:
# One value per ROI: the rotated-PC phase for HD neurons, NaN everywhere else.
n_phases = np.empty(exp.n_rois)
n_phases.fill(np.nan)
n_phases[exp.hdn_indexes] = rotated_pc_phase

In [None]:
# Color every ROI in the stack by its phase (NaN for non-HD ROIs):
colored = color_stack(exp.rois_stack, color_scheme="twilight", variable=n_phases)

In [None]:
# Collapse the colored stack along axis 2: each plane in turn copies its values
# into `proj` wherever they are > 0, so later planes overwrite earlier ones.
proj = np.zeros((*colored.shape[:2], 4), dtype=np.uint8)
for plane_idx in range(colored.shape[2]):
    plane = colored[:, :, plane_idx, :]
    nonzero = plane > 0
    proj[nonzero] = plane[nonzero]

In [None]:
# Show the collapsed, phase-colored ROI projection:
plt.figure(figsize=(3, 2))
plt.imshow(proj)

# Perform analysis over all fish
Here we run the same analysis over all fish:

In [None]:
from dataclasses import dataclass


@dataclass
class RotationResults:
    """Per-fish output of the PC reorientation analysis.

    Every field holds one entry (row) per head-direction neuron.
    """

    rotated_pcs: np.ndarray  # first two PCs after reorientation to the anatomy
    centered_coords: np.ndarray  # mean-centered anatomical coordinates
    rotated_pc_phase: np.ndarray  # angle (rad) of each neuron in the rotated PC plane
    anatomical_phase: np.ndarray  # angle (rad) of each neuron around the anatomy

In [None]:
results_dict = {}

# Seeded generator, so that the shuffled control is reproducible across runs:
rng = np.random.default_rng(0)

# Loop over all fish, and compute values for the data and for a control in which
# the identity of the coordinates is shuffled across neurons:
for path in file_list:
    # Neither the experiment nor its PCA depends on the shuffling, so load and
    # compute them once per fish instead of once per condition:
    exp = LotrExperiment(path)

    # Compute PCA in time, fit circle and center:
    traces = exp.traces[exp.pca_t_slice, exp.hdn_indexes]
    pcaed, phase, pca, circle_params = pca_and_phase(traces.T)
    centered_pcs = pcaed[:, :2] - circle_params[:2]

    hdn_coords = exp.coords[exp.hdn_indexes, :]
    n_hdn = len(exp.hdn_indexes)
    # Order of the coordinates for each condition: identity for the data,
    # a random permutation for the shuffled control.
    sequences = {"data": np.arange(n_hdn), "shuf": rng.permutation(n_hdn)}

    fish_dict = {}
    for k, sequence in sequences.items():
        # Center coords (reordered according to the condition):
        centered_coords, w_coords = get_normalized_coords(hdn_coords[sequence, :])

        # Rotate PCs:
        rotated_pcs = reorient_pcs(centered_pcs, w_coords)

        # Phase of each neuron from its position in the rotated space:
        rotated_pc_phase = np.angle(rotated_pcs[:, 0] + 1j * rotated_pcs[:, 1])

        # And, for plotting, an "anatomical phase", aka position around the anatomy.
        # NOTE(review): the single-fish analysis computes this from w_coords
        # instead of centered_coords — confirm which one is intended.
        anatomical_phase = np.angle(centered_coords[:, 1] + 1j * centered_coords[:, 2])

        fish_dict[k] = RotationResults(rotated_pcs=rotated_pcs,
                                       centered_coords=centered_coords,
                                       rotated_pc_phase=rotated_pc_phase,
                                       anatomical_phase=anatomical_phase)

    results_dict[path.name] = fish_dict

In [None]:
f, axs = plt.subplots(2, 2, figsize=(5, 5))

# Marker size of the scatters (`s` was previously an undefined name here);
# 5 matches the size used in the phase-vs-position plot.
marker_size = 5
row_labels = ["color by phase", "color by anatomy"]

# Overlay all fish. Rows: colored by rotated-PC phase / by anatomical phase.
# Columns: rotated PC plane / anatomical coordinates (axes 1 and 2).
# NOTE(review): unlike the single-fish figure, the anatomical x coordinate is
# not negated here, although both are labelled "lf.-rt." — confirm orientation.
for k, res in results_dict.items():
    result = res["data"]
    for j, color_scheme in enumerate([result.rotated_pc_phase, result.anatomical_phase]):
        for i, scatter_vals in enumerate([result.rotated_pcs, result.centered_coords[:, 1:]]):
            axs[j, i].scatter(scatter_vals[:, 0], scatter_vals[:, 1], lw=0,
                              c=color_scheme, cmap="twilight", s=marker_size, alpha=0.5)

for j, lab in enumerate(row_labels):
    axs[j, 0].set_ylabel(lab)

for j in range(2):
    add_scalebar(axs[j, 1], xlabel="lf.-rt.", ylabel="post.-ant.",
                 xlen=80, ylen=80, disable_axis=False)
    despine(axs[j, 1], "all")
    despine(axs[j, 0], "all")
# Both rows of column 0 show the rotated PCs, so both get the "rPC" labels:
add_scalebar(axs[0, 0], xlabel="rPC1", ylabel="rPC2",
             xlen=30, ylen=30, disable_axis=False)
add_scalebar(axs[1, 0], xlabel="rPC1", ylabel="rPC2",
             xlen=30, ylen=30, disable_axis=False)

plt.tight_layout()

In [None]:
f_scat, axs_scat = plt.subplots(1, 2, figsize=(6, 2.5), sharex=True)

# Shuffled controls are drawn first (gray), then the data (red) on top.
conditions = (("shuf", (0.8,) * 3), ("data", (0.8, 0.3, 0.2)))
position_labels = ("lateral pos. (um)", "sagittal pos. (um)")
for condition, color in conditions:
    for fish_results in results_dict.values():
        result = fish_results[condition]

        # Panel 0: coordinate column 1 (lateral); panel 1: column 2 (sagittal).
        for ax, coord_col, y_label in zip(axs_scat, (1, 2), position_labels):
            ax.scatter(result.rotated_pc_phase, result.centered_coords[:, coord_col],
                       color=color, s=5, alpha=0.8, lw=0)
            ax.set(xlabel="rPC phase (rad)", ylabel=y_label)
            despine(ax)

axs_scat[0].set_ylim(-120, 120)
axs_scat[1].set_ylim(-50, 60)

plt.tight_layout()