## Reorient temporal PCA to match anatomy
In this notebook, we look at the anatomical distribution of head direction neurons across fish, to assess whether the anatomical location of cells is consistent with their position in the network activity topology. 

#### TODO
 - [ ] redefine the centering of the coordinates after morphing is done
 - [ ] discuss the way of finding the two axes; it might be suboptimal/not work for all cases

In [None]:
%matplotlib widget
from pathlib import Path
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import matplotlib.gridspec as gridspec
from scipy.optimize import curve_fit

from lotr import LotrExperiment, DATASET_LOCATION
from lotr.notebook_utils import print_source

from lotr.plotting import despine, add_scalebar, get_circle_xy, color_stack, add_cbar, dark_col
from lotr.pca import pca_and_phase
from lotr.rpca_calculation import get_normalized_coords, reorient_pcs
# from lotr.utils import # reduce_to_pi, get_rot_matrix, get_vect_angle

In [None]:
# List all experiments.
# Path.glob yields matches in an arbitrary, filesystem-dependent order, so the
# folders are sorted to keep the order of the experiments (and of anything
# pooled across them) the same on every machine and every run.
master_path = Path(DATASET_LOCATION)
file_list = sorted(f.parent for f in master_path.glob("*/*[0-9]_f*/selected.h5"))

### Showcase analysis for a single fish

In [None]:
# Example experiment (alternative: "210926_f0" / "210926_f0_gainmod"):
path = master_path / "210314_f1" / "210314_f1_natmov"
exp = LotrExperiment(path)

# Compute PCA in time and fit circle:
traces = exp.traces[exp.pca_t_slice, exp.hdn_indexes]
pcaed, phase, pca, circle_params = pca_and_phase(traces.T)

# Center PCs on center of the fit circle:
centered_pcs = pcaed[:, :2] - circle_params[:2]

f, ax = plt.subplots(figsize=(2.5, 2.5))

ax.scatter(centered_pcs[:, 0], centered_pcs[:, 1])
# Fit circle, drawn around the origin since the PCs were centered on it:
ax.plot(*get_circle_xy((0, 0, circle_params[2])), c=(0.2,)*3)
ax.axis("equal")

# Add the scalebar to this axes explicitly and only then show the figure:
# with non-interactive backends plt.show() finalizes the figure, so anything
# drawn afterwards on the "current" axes would end up on a new, empty figure.
add_scalebar(ax, xlabel="PC1", ylabel="PC2", xlen=30, ylen=30)
plt.show()

### Normalize coordinates

In [None]:
# Show the implementation of the normalization helper applied in the next cell:
print_source(get_normalized_coords)

In [None]:
# Anatomical coordinates of the selected (hdn_indexes) neurons only:
coords = exp.coords[exp.hdn_indexes, :]
# Same coordinates, relative to the centroid of the selected neurons:
centered_coords = coords - np.mean(coords, axis=0)
# Normalized coordinates, as returned by get_normalized_coords:
w_coords = get_normalized_coords(coords)

In [None]:
# One panel per column of w_coords: the same PC scatter, colored by that
# normalized anatomical coordinate.
axis_labels = ["pos. z", "pos. left-right", "pos. rostr.-caud."]

f, axs = plt.subplots(1, 3, figsize=(7.5, 2.5))

for i, (ax, axis_label) in enumerate(zip(axs, axis_labels)):
    sc = ax.scatter(
        centered_pcs[:, 0],
        centered_pcs[:, 1],
        c=w_coords[:, i],
        cmap="viridis",
    )
    # Fit circle around the origin:
    ax.plot(*get_circle_xy((0, 0, circle_params[2])), c=(0.2,)*3)
    ax.axis("equal")

    # Small colorbar, shifted horizontally for each panel:
    cbar_pos = (0.35 + 0.28*i, 0.7, 0.015, 0.18)
    add_cbar(cbar_pos, sc, ticks=[], label=axis_label, titlesize=9)

    add_scalebar(ax, xlabel="PC1", ylabel="PC2", xlen=30, ylen=30, disable_axis=False)
    despine(ax, "all")

#### Fit to match anatomy 

The following function is then used to reorient the PC axes to match anatomy:

In [None]:
# Show the implementation of the reorientation helper applied in the next cell:
print_source(reorient_pcs)

In [None]:
# Reorient the circle-centered PCs using the normalized anatomical coordinates:
rotated_pcs = reorient_pcs(centered_pcs, w_coords)

In [None]:
f, ax = plt.subplots(figsize=(2.5, 2.5))

# Rotated PCs, colored by the third column of the normalized coordinates:
sc = ax.scatter(rotated_pcs[:, 0], rotated_pcs[:, 1], c=w_coords[:, 2])
ax.plot(*get_circle_xy((0, 0, circle_params[2])), c=(0.2,)*3)
ax.axis("equal")

# Scalebar and colorbar are added before showing the figure: with
# non-interactive backends plt.show() finalizes the figure, and decorations
# added afterwards would land on a new, empty figure.
add_scalebar(ax, xlabel="rPC1", ylabel="rPC2", xlen=60, ylen=60)
add_cbar((0.87, 0.7, 0.04, 0.18), sc, ticks=[], label="sagitt. ax.", titlesize=9)

plt.show()

In [None]:
# Debug leftover disabled: this cell used to inspect the shape of `proj`, but
# `proj` is only created by the plotting cell below, so evaluating it here
# raised a NameError on a fresh kernel (Restart & Run All).

In [None]:
# Phase of each neuron, from its position in the rotated PC space:
rotated_pc_phase = np.angle(rotated_pcs[:, 0] + 1j * rotated_pcs[:, 1])

# "Anatomical phase" (position around the anatomy), used for plotting only and
# computed here from the normalized coordinates:
anatomical_phase = np.angle(-w_coords[:, 1] + 1j * w_coords[:, 2])

# One row per coloring scheme; columns: rPC space, centered anatomy, ROI stack.
f, axs = plt.subplots(2, 3, figsize=(8, 4))
row_specs = [
    (rotated_pc_phase, "col. by rPC phase"),
    (anatomical_phase, "col. by anatomy phase"),
]

for (pc_ax, anat_ax, stack_ax), (color_scheme, lab) in zip(axs, row_specs):
    # Neurons in rotated PC space:
    pc_ax.scatter(rotated_pcs[:, 0], rotated_pcs[:, 1],
                  c=color_scheme, cmap="twilight")
    add_scalebar(pc_ax, xlabel="rPC1", ylabel="rPC2", xlen=45, ylen=45,
                 disable_axis=False)
    pc_ax.axis("equal")
    despine(pc_ax, "all")

    # Neurons in centered anatomical space (second coordinate sign-flipped on x):
    anat_ax.scatter(-centered_coords[:, 1], centered_coords[:, 2],
                    c=color_scheme, cmap="twilight")
    add_scalebar(anat_ax, xlabel="lf.-rt.", ylabel="post.-ant.", xlen=60, ylen=60,
                 ypos=-60)
    anat_ax.axis("equal")
    despine(anat_ax, "all")

    # Per-ROI values: NaN everywhere except for the selected neurons.
    n_phases = np.full(exp.n_rois, np.nan)
    n_phases[exp.hdn_indexes] = color_scheme
    colored = color_stack(exp.rois_stack, variable=n_phases, color_scheme="twilight")

    # Maximum projection along the first axis, cropped and transposed for display:
    proj = colored.max(0)
    stack_ax.imshow(proj[20:-20, 120:-50, :].swapaxes(0, 1))
    add_scalebar(stack_ax, xlabel="lf.-rt.", ylabel="post.-ant.", xlen=60, ypos=120, ylen=-60)

    pc_ax.set_ylabel(lab)

plt.suptitle(path.name)

# Perform analysis over all fish
Here we run the same analysis over all fish:

In [None]:
# Fixed seed for the shuffled control, so that re-running the notebook
# reproduces the same "shuf" group:
SHUFFLE_SEED = 0
rng = np.random.default_rng(SHUFFLE_SEED)

results_list = []

# Loop over all fish, and compute values for data and after shuffling coords identity:
for path in file_list:
    # Loading the experiment and running the temporal PCA do not depend on the
    # group, so they are done once per fish instead of once per group.
    exp = LotrExperiment(path)

    # Compute PCA in time and fit circle:
    traces = exp.traces[exp.pca_t_slice, exp.hdn_indexes]
    pcaed, phase, pca, circle_params = pca_and_phase(traces.T)

    for group in ["data", "shuf"]:
        # Order in which coordinates are assigned to neurons: identity for the
        # data, a random permutation for the shuffled control.
        sequence = np.arange(len(exp.hdn_indexes))
        if group == "shuf":
            rng.shuffle(sequence)

        # Center PCs on the fit circle (fresh array for every group):
        centered_pcs = pcaed[:, :2] - circle_params[:2]

        # center coords:
        coords = exp.coords[exp.hdn_indexes, :][sequence, :]  # shuffle if necessary
        centered_coords = coords - np.mean(coords, axis=0)
        w_coords = get_normalized_coords(coords)

        # rotate pcs:
        rotated_pcs = reorient_pcs(centered_pcs, w_coords)

        # We can now calculate a phase for each neuron from their position in this rotated space:
        rotated_pc_phase = np.angle(rotated_pcs[:, 0] + 1j * rotated_pcs[:, 1])

        # And, for plotting, computing an "anatomical phase", aka position around the anatomy.
        # NOTE(review): the single-fish cell computes this phase from w_coords
        # instead of centered_coords - confirm which one is intended.
        anatomical_phase = np.angle(-centered_coords[:, 1] + 1j * centered_coords[:, 2])

        results_list.append(pd.DataFrame(dict(rpc1_proj=rotated_pcs[:, 0],
                                              rpc2_proj=rotated_pcs[:, 1],
                                              lr_pos=-centered_coords[:, 1],  # invert for scatterplots
                                              ap_pos=centered_coords[:, 2],
                                              rpc_phase=rotated_pc_phase,
                                              anatomical_phase=anatomical_phase,
                                              exp_id=path.name,
                                              group=group)))

# One row per neuron and group; reset_index keeps the per-fish row number as an "index" column.
pooled_df = pd.concat(results_list, axis=0).reset_index()

In [None]:
f, axs = plt.subplots(2, 2, figsize=(5, 5))
s = 8  # marker size

# Only the non-shuffled group is shown here:
pooled_data = pooled_df[pooled_df["group"] == "data"]

# Rows: coloring variable; columns: rotated PC space and anatomical space.
row_specs = [("rpc_phase", "color by phase"),
             ("anatomical_phase", "color by anatomy")]
column_keys = [("rpc1_proj", "rpc2_proj"), ("lr_pos", "ap_pos")]

for row_axs, (color_k, lab) in zip(axs, row_specs):
    for ax, (x_key, y_key) in zip(row_axs, column_keys):
        ax.scatter(pooled_data[x_key], pooled_data[y_key],
                   lw=0, c=pooled_data[color_k], cmap="twilight",
                   s=s, alpha=0.5)

    row_axs[0].set_ylabel(lab)

# Scalebars and despining, once all the points are drawn:
for pc_ax, anat_ax in axs:
    add_scalebar(anat_ax, xlabel="lf.-rt.", ylabel="post.-ant.",
                 xlen=80, ylen=80, disable_axis=False)
    despine(anat_ax, "all")
    despine(pc_ax, "all")

    add_scalebar(pc_ax, xlabel="rPC1", ylabel="rPC2",
                 xlen=40, ylen=40, disable_axis=False)

plt.tight_layout()

In [None]:
# Number of rows (neurons pooled over fish) in the non-shuffled group:
n_rois = len(pooled_df[pooled_df["group"] == "data"])


def my_cos(x, amplitude, phase, offset, freq):
    """Cosine model fit to position as a function of rPC phase.

    Returns ``amplitude * cos(freq * x + phase) + offset``.
    """
    return np.cos(x*freq + phase) * amplitude + offset


# Starting point of the optimization for each anatomical axis,
# in the order (amplitude, phase, offset, freq):
initial_guesses = {
    "lr_pos": (80, np.pi, 0, 1),
    "ap_pos": (80, np.pi/2, 0, 1),
}

# The fit uses the first half of the rows; the second half is left out
# and used afterwards to evaluate the fit:
fit_data_slice = slice(0, n_rois // 2)
test_slice = slice(n_rois // 2, None)

fit_records = []
for group in ["data", "shuf"]:
    group_data = pooled_df[pooled_df["group"] == group]
    train_phase = group_data["rpc_phase"].values[fit_data_slice]

    for pos_key in ["lr_pos", "ap_pos"]:
        train_pos = group_data[pos_key].values[fit_data_slice]
        popt, _ = curve_fit(my_cos, train_phase, train_pos,
                            p0=initial_guesses[pos_key])
        amp, ph, off, freq = popt

        # Key order fixes the column order of the resulting table:
        fit_records.append(dict(group=group,
                                ax=pos_key,
                                amp=amp,
                                ph=ph,
                                off=off,
                                freq=freq))

fit_results = pd.DataFrame(fit_records)

In [None]:
n_bins = 40  # not used in this cell
res = []  # not used in this cell

cols = {"shuf": (0.8, 0.8, 0.8),
        "data": (0.8, 0.4, 0.3)}
ax_labels = {"lr_pos": "l-r pos. (μm)",
             "ap_pos": "a-p pos. (μm)"}
group_order = ["shuf", "data"]
dark_gray = (0.3, 0.3, 0.3)

fig = plt.figure(figsize=(6, 3.5))

# Layout: a thin row of boxplots on top, two large scatter panels below.
gs = gridspec.GridSpec(4, 6, figure=fig)
axs_scat = [fig.add_subplot(gs[1:, col_sl])
            for col_sl in (slice(None, 3), slice(3, None))]
axs_comp = [fig.add_subplot(gs[0, col_sl])
            for col_sl in (slice(1, 3), slice(4, None))]

# Phases over which the fit curves are drawn:
x_array = np.arange(-np.pi, np.pi, 0.1)

for scat_ax, comp_ax, pos_key in zip(axs_scat, axs_comp, ["lr_pos", "ap_pos"]):
    sq_residuals = []

    for group in group_order:
        group_data = pooled_df[pooled_df["group"] == group]
        phases = group_data["rpc_phase"]
        positions = group_data[pos_key]

        scat_ax.scatter(phases, positions,
                        color=cols[group], s=5, alpha=0.8, lw=0, label=group)

        # Fit parameters for this group and axis: every column after the
        # first two ("group", "ax") of the matching row of fit_results.
        is_this_fit = (fit_results["group"] == group) & (fit_results["ax"] == pos_key)
        cos_params = list(fit_results[is_this_fit].iloc[0, 2:].values)

        scat_ax.plot(x_array, my_cos(x_array, *cos_params),
                     color=dark_col(cols[group]), label="_nolegend_")

        scat_ax.set(xlabel="rPC phase (rad)", ylabel=ax_labels[pos_key])
        despine(scat_ax)

        # Squared difference between fit prediction and position on the test rows:
        predicted = my_cos(phases.values[test_slice], *cos_params)
        sq_residuals.append((predicted - positions.values[test_slice])**2)

    # Horizontal boxplots comparing the two groups:
    bplot = comp_ax.boxplot(sq_residuals, notch=False, showfliers=False, vert=False,
                            patch_artist=True, showcaps=False, widths=0.6)
    for box, median, group in zip(bplot["boxes"], bplot["medians"], group_order):
        box.set(fc=cols[group], lw=1, ec=cols[group])
        median.set(color=dark_gray)
    for whisker in bplot["whiskers"]:
        whisker.set(lw=1, color=dark_gray)

    comp_ax.set_yticklabels(group_order)
    comp_ax.tick_params(axis='both', which='both', labelsize=8)
    comp_ax.xaxis.tick_top()
    comp_ax.xaxis.set_label_position('top')
    comp_ax.set_xlabel("residuals (cross. val.)")
    for side in ["bottom", "right"]:
        comp_ax.spines[side].set_visible(False)

axs_scat[0].set_ylim(-120, 140)
axs_scat[1].set_ylim(-50, 80)
axs_scat[1].legend(frameon=False, fontsize=10, markerscale=2,
                   handletextpad=-0.3)

plt.tight_layout()