## Anatomical organization of the network

In this notebook, we first look at the anatomical distribution of head direction neurons across fish, to assess whether the anatomical location of cells is consistent, across fish, with their position in the topology of the network activity.

By doing so, we define a unified phase space that replaces the arbitrary, fish-specific orientation of cells in phase space: by orienting all fish in the same way with respect to the anatomy, we ensure that a given network phase always corresponds to a specific location of activity in the aHB. This rotated phase is what will be used in all subsequent analyses.

**TODO**
 - [ ] redefine the centering of the coordinates after morphing is done

In [None]:
%matplotlib widget
from pathlib import Path

import matplotlib.gridspec as gridspec
import numpy as np
import pandas as pd
from lotr import A_FISH, FIGURES_LOCATION, LotrExperiment, dataset_folders
from lotr.notebook_utils import print_source
from lotr.pca import pca_and_phase
from lotr.plotting import (
    COLS,
    add_anatomy_scalebar,
    add_cbar,
    add_phase_cbar,
    add_scalebar,
    boxplot,
    color_stack,
    dark_col,
    despine,
    get_circle_xy,
)
from lotr.rpca_calculation import get_zero_mean_weights, reorient_pcs
from matplotlib import pyplot as plt
from scipy.optimize import curve_fit

# from lotr.utils import # reduce_to_pi, get_rot_matrix, get_vect_angle

## Anatomical organization of the network

First, let's see whether the organization of cells in PC space matches their anatomical location.

In [None]:
exp = LotrExperiment(A_FISH)
phase_cmap = COLS["phase"]

# PCA over time of the head-direction neuron traces; `pca_and_phase` also returns
# an angle for every neuron in PC space.
hdn_traces = exp.traces[exp.pca_t_slice, exp.hdn_indexes]
pca_scores, angles, _, _ = pca_and_phase(hdn_traces.T)
coords = exp.coords_um[exp.hdn_indexes, :]

# Three panels, all colored by the PC-space angle of each neuron.
f, (ax_pcs, ax_coords, ax_rois) = plt.subplots(1, 3, figsize=(6, 2.5))

# Panel 1: neurons in the plane of the first two PCs.
ax_pcs.scatter(pca_scores[:, 0], pca_scores[:, 1], c=angles, cmap=phase_cmap)
add_scalebar(ax_pcs, xlabel="PC1", ylabel="PC2", xlen=30, ylen=30)
ax_pcs.axis("equal")

# Panel 2: the same neurons at their anatomical coordinates (columns 1 and 2).
sc = ax_coords.scatter(coords[:, 1], coords[:, 2], c=angles, cmap=phase_cmap, s=15)
ax_coords.set(xlim=exp.lr_extent_um, ylim=exp.pa_extent_um)
ax_coords.axis("off")
add_phase_cbar(
    sc, ax_coords, inset_loc=(-0.0, 0.15, 0.25, 0.04), orientation="horizontal"
)

# Panel 3: max projection of the ROI masks colored by the same angles.
colored_rois = exp.color_rois_by(angles, color_scheme=phase_cmap)
ax_rois.imshow(colored_rois.max(0), extent=exp.plane_ext_um, origin="lower")
add_anatomy_scalebar(ax_rois)

plt.tight_layout()

f.savefig(FIGURES_LOCATION / "anatomical distribution_onefish.pdf")

Yay! There seems to be an interesting anatomical localization, as expected from the raw activity after bouts.

## Registration across fish

Now, if we want to compare anatomical distributions between different fish in the same figure, we need to match the ROI coordinates in PC space from one fish to the other. Luckily, they are arranged in a circle! It should not be too hard to match the phases across fish with the following steps:
 1. Compute PCs across time
 2. Center them to have 0 mean
 3. Find a rotation so that the ROIs on the left of the "rotated PC" (**rPC**) circle (0 phase) are the left ROIs, and the ROIs on top of the PC circle are the rostral ROIs
 4. Define a new "rPC" angle for all neurons in this new space, and use it to define a phase that is now in the same space across all fish

In [None]:
# Compute PCA in time and fit circle. This is the raw, fish-specific PC phase,
# before any re-orientation to the anatomy:
pca_scores, phase, _, _ = pca_and_phase(exp.traces[exp.pca_t_slice, exp.hdn_indexes].T)

# Coordinates of the head-direction neurons only, so they match `phase` element-wise.
# They are centered, as in the rPC version of this figure further down, so that the
# scalebar (placed at ypos=-60) lands next to the data:
hdn_coords = exp.coords[exp.hdn_indexes, :]
hdn_coords = hdn_coords - np.mean(hdn_coords, axis=0)

# And plot it over the anatomy. `squeeze=False` keeps `axs` 2-D, shape (1, 3), so
# that the `axs[j, i]` indexing below also works with a single row:
f, axs = plt.subplots(1, 3, figsize=(8, 4), squeeze=False)
for j, (color_scheme, color_val, lab) in enumerate(
    zip([COLS["phase"]], [phase], ["col. by PC phase"])
):
    # Left: PC space; middle: centered anatomical coordinates (columns 1 and 2).
    for i, scatter_vals in enumerate([pca_scores, hdn_coords[:, 1:]]):
        axs[j, i].scatter(
            scatter_vals[:, 0], scatter_vals[:, 1], c=color_val, cmap=color_scheme
        )

        if i == 1:
            add_scalebar(
                axs[j, i],
                xlabel="lf.-rt.",
                ylabel="post.-ant.",
                xlen=60,
                ylen=60,
                ypos=-60,
            )
        else:
            add_scalebar(
                axs[j, i],
                xlabel="PC1",
                ylabel="PC2",
                xlen=45,
                ylen=45,
                disable_axis=False,
            )
        axs[j, i].axis("equal")
        despine(axs[j, i], "all")

    # Right: max projection of the ROI stack. HDN ROIs are colored by the same values;
    # all other ROIs get NaN:
    n_phases = np.full(exp.n_rois, np.nan)
    n_phases[exp.hdn_indexes] = color_val
    colored = color_stack(exp.rois_stack, variable=n_phases, color_scheme=color_scheme)

    proj = colored.max(0)

    axs[j, 2].imshow(proj)
    add_scalebar(
        axs[j, 2], xlabel="lf.-rt.", ylabel="post.-ant.", xlen=60, ypos=120, ylen=-60
    )

    axs[j, 0].set_ylabel(lab)

# `path` is only bound by the all-fish loop further down; title with the fish shown here:
plt.suptitle(Path(A_FISH).name)

In [None]:
# And, for plotting, computing an "anatomical phase", aka position around the anatomy.
# `w_coords` is otherwise only defined in the "Normalize coordinates" section below,
# so compute it here from the head-direction neuron coordinates. This lets the cell
# run when the notebook is executed top to bottom:
w_coords = get_zero_mean_weights(exp.coords[exp.hdn_indexes, :])
anatomical_phase = np.angle(-w_coords[:, 1] + 1j * w_coords[:, 2])

### Normalize coordinates

In [None]:
# Print the source of `get_zero_mean_weights`, which is applied to the neuron
# coordinates in the next cell:
print_source(get_zero_mean_weights)

In [None]:
# PCA in time, circle fit, and centering of the PC scores on the fitted circle center,
# exactly as in the all-fish loop. `centered_pcs` and `circle_params` are used by the
# following cells and were otherwise never computed for this single fish:
traces = exp.traces[exp.pca_t_slice, exp.hdn_indexes]
pca_scores, phase, pca, circle_params = pca_and_phase(traces.T)
centered_pcs = pca_scores[:, :2] - circle_params[:2]

# Head-direction neuron coordinates: mean-centered copy, plus the weights returned by
# `get_zero_mean_weights` (used below to color and to reorient the PCs):
coords = exp.coords[exp.hdn_indexes, :]
centered_coords = coords - np.mean(coords, axis=0)
w_coords = get_zero_mean_weights(coords)

In [None]:
f, axs = plt.subplots(1, 3, figsize=(7.5, 2.5))

# Colorbar label for each column of `w_coords` (one panel per coordinate):
cbar_labels = ["pos. s-i", "pos. l-r", "pos. a-p"]

for dim, (ax, cbar_label) in enumerate(zip(axs, cbar_labels)):
    # Neurons in centered PC space, colored by one anatomical coordinate weight:
    sc = ax.scatter(
        centered_pcs[:, 0], centered_pcs[:, 1], c=w_coords[:, dim], cmap="viridis"
    )
    # Circle of radius `circle_params[2]` around the origin, for reference:
    ax.plot(*get_circle_xy((0, 0, circle_params[2])), c=(0.2,) * 3)
    ax.axis("equal")

    add_cbar(
        sc,
        ax,
        (0.98, 0.9, 0.04, 0.15),
        ticks=[],
        label=cbar_label,
        titlesize=9,
    )
    add_scalebar(ax, xlabel="PC1", ylabel="PC2", xlen=30, ylen=30, disable_axis=False)
    despine(ax, "all")

#### Fit to match anatomy 

The following function is then used to reorient the PC axes to match anatomy:

In [None]:
# Print the source of `reorient_pcs`, the function used in the next cell to reorient
# the PC axes so that they match the anatomy:
print_source(reorient_pcs)

In [None]:
# Reorient the centered PC scores using the anatomical weights of each neuron; the
# result defines the "rPC" space used from here on.
# NOTE(review): `centered_pcs` is assumed to be the PC scores centered on the fitted
# circle (as in the all-fish loop below); make sure it has been computed for this
# fish before running this cell.
rotated_pcs = reorient_pcs(centered_pcs, w_coords)

In [None]:
f, ax = plt.subplots(figsize=(2.5, 2.5))

# Neurons in the rotated PC space, colored by their a-p weight (column 2 of `w_coords`):
sc = ax.scatter(rotated_pcs[:, 0], rotated_pcs[:, 1], c=w_coords[:, 2])
# Reference circle around the origin, with radius `circle_params[2]`:
ax.plot(*get_circle_xy((0, 0, circle_params[2])), c=(0.2,) * 3)
ax.axis("equal")

# `add_scalebar` takes the target axes as its first argument, as in every other call
# in this notebook; it was previously called without it:
add_scalebar(ax, xlabel="rPC1", ylabel="rPC2", xlen=60, ylen=60)

add_cbar(sc, ax, (0.98, 0.95, 0.04, 0.15), ticks=[], label="pos. a-p", titlesize=9)

# Show the figure only once all artists (scalebar, colorbar) have been added:
plt.show()

In [None]:
# We can now calculate a phase for each neuron from their position in this rotated space:
rotated_pc_phase = np.angle(rotated_pcs[:, 0] + 1j * rotated_pcs[:, 1])

# And, for plotting, computing an "anatomical phase", aka position around the anatomy.
# The l-r axis is negated, as for the `lr_pos` column in the all-fish loop below.
# NOTE(review): here the phase is computed from `w_coords`, while the all-fish loop
# uses `centered_coords`; confirm which of the two is intended.
anatomical_phase = np.angle(-w_coords[:, 1] + 1j * w_coords[:, 2])

# And plot it over the anatomy. Row 0 is colored by rPC phase, row 1 by anatomical phase:
f, axs = plt.subplots(2, 3, figsize=(8, 4))
for j, (color_scheme, color_val, lab) in enumerate(
    zip(
        [COLS["phase"], "twilight"],
        [rotated_pc_phase, anatomical_phase],
        ["col. by rPC phase", "col. by anatomy phase"],
    )
):
    # Left: rotated PC space; middle: centered anatomical coordinates (columns 1 and 2).
    for i, scatter_vals in enumerate([rotated_pcs, centered_coords[:, 1:]]):
        axs[j, i].scatter(
            scatter_vals[:, 0], scatter_vals[:, 1], c=color_val, cmap=color_scheme
        )

        if i == 1:
            add_scalebar(
                axs[j, i],
                xlabel="lf.-rt.",
                ylabel="post.-ant.",
                xlen=60,
                ylen=60,
                ypos=-60,
            )
        else:
            add_scalebar(
                axs[j, i],
                xlabel="rPC1",
                ylabel="rPC2",
                xlen=45,
                ylen=45,
                disable_axis=False,
            )
        axs[j, i].axis("equal")
        despine(axs[j, i], "all")

    # Right: max projection of the ROI stack. HDN ROIs are colored by the same values;
    # all other ROIs get NaN:
    n_phases = np.full(exp.n_rois, np.nan)
    n_phases[exp.hdn_indexes] = color_val
    colored = color_stack(exp.rois_stack, variable=n_phases, color_scheme=color_scheme)

    proj = colored.max(0)

    # Hard-coded crop; presumably tuned to this specific fish's field of view.
    axs[j, 2].imshow(proj[80:260, 20:-20, :])
    add_scalebar(
        axs[j, 2], xlabel="lf.-rt.", ylabel="post.-ant.", xlen=60, ypos=120, ylen=-60
    )

    axs[j, 0].set_ylabel(lab)

# `path` is only bound by the all-fish loop further down; title with the fish shown here:
plt.suptitle(Path(A_FISH).name)

# Perform analysis over all fish
Here we run the same analysis over all fish:

In [None]:
results_list = []

np.random.seed(50)

# Loop over all fish. For each one, compute values for the data and for a control in
# which the identity of the coordinates is shuffled across neurons. The shuffle breaks
# the correspondence between each neuron's PC-space position and its anatomical position.
for path in dataset_folders:
    # Load each experiment only once; both the "data" and "shuf" passes reuse it.
    # (Previously it was re-created for every group.)
    exp = LotrExperiment(path)

    for group, shuf_f in zip(["data", "shuf"], [lambda x: x, np.random.shuffle]):
        # Neuron order: identity for "data", shuffled in place for "shuf":
        sequence = np.arange(len(exp.hdn_indexes))
        shuf_f(sequence)

        # Compute PCA in time, fit circle and center on the circle center:
        traces = exp.traces[exp.pca_t_slice, exp.hdn_indexes]
        pcaed, phase, pca, circle_params = pca_and_phase(traces.T)
        centered_pcs = pcaed[:, :2] - circle_params[:2]

        # Center coords (reordered by `sequence`, i.e. shuffled for the control):
        coords = exp.coords[exp.hdn_indexes, :][sequence, :]
        centered_coords = coords - np.mean(coords, axis=0)
        w_coords = get_zero_mean_weights(coords)

        # rotate pcs:
        rotated_pcs = reorient_pcs(centered_pcs, w_coords)

        # We can now calculate a phase for each neuron from their position in this rotated space:
        rotated_pc_phase = np.angle(rotated_pcs[:, 0] + 1j * rotated_pcs[:, 1])

        # And, for plotting, computing an "anatomical phase", aka position around the anatomy:
        anatomical_phase = np.angle(-centered_coords[:, 1] + 1j * centered_coords[:, 2])

        results_list.append(
            pd.DataFrame(
                dict(
                    rpc1_proj=rotated_pcs[:, 0],
                    rpc2_proj=rotated_pcs[:, 1],
                    lr_pos=-centered_coords[:, 1],  # invert for scatterplots
                    ap_pos=centered_coords[:, 2],
                    rpc_phase=rotated_pc_phase,
                    anatomical_phase=anatomical_phase,
                    exp_id=path.name,
                    group=group,
                )
            )
        )

# Rows are stacked fish by fish, alternating "data" and "shuf" blocks for each fish:
pooled_df = pd.concat(results_list, axis=0).reset_index()

In [None]:
f, axs = plt.subplots(2, 2, figsize=(5, 5))
s = 8

pooled_data = pooled_df[pooled_df["group"] == "data"]

# One row per coloring: (colormap, column used for color, row label).
row_specs = [
    (COLS["phase"], "rpc_phase", "color by phase"),
    ("twilight", "anatomical_phase", "color by anatomy"),
]
# Left column: rotated PC space; right column: anatomical space.
scatter_keys = [("rpc1_proj", "rpc2_proj"), ("lr_pos", "ap_pos")]

for row_axs, (cmap, color_key, row_label) in zip(axs, row_specs):
    for ax, (x_key, y_key) in zip(row_axs, scatter_keys):
        ax.scatter(
            pooled_data[x_key],
            pooled_data[y_key],
            lw=0,
            c=pooled_data[color_key],
            cmap=cmap,
            s=s,
            alpha=0.5,
        )
    row_axs[0].set_ylabel(row_label)

# Decorations are added after all scatters, row by row:
for pc_ax, anat_ax in axs:
    add_scalebar(
        anat_ax,
        xlabel="lf.-rt.",
        ylabel="post.-ant.",
        xlen=80,
        ylen=80,
        disable_axis=False,
    )
    despine(anat_ax, "all")
    despine(pc_ax, "all")

    add_scalebar(
        pc_ax, xlabel="rPC1", ylabel="rPC2", xlen=40, ylen=40, disable_axis=False
    )

plt.tight_layout()

In [None]:
# Number of pooled rows (neurons across all fish) in the "data" group:
n_rois = int((pooled_df["group"] == "data").sum())

# Cosine model fitted to position vs. rPC phase.
def my_cos(x, amplitude, phase, offset, freq):
    """Return ``offset + amplitude * cos(freq * x + phase)``, evaluated element-wise.

    Parameters are in the order expected by ``scipy.optimize.curve_fit``: the
    independent variable ``x`` first, then the fitted parameters.
    """
    angle = freq * x + phase
    return offset + amplitude * np.cos(angle)


# Perform the fit.
# We fit on the first half of each group's rows and then measure the squared prediction
# residuals on the held-out second half. Rows are stacked fish by fish, so this split
# mostly separates fish rather than neurons. NOTE(review): confirm this is intended.
initial_guesses = dict(lr_pos=(80, np.pi, 0, 1), ap_pos=(80, np.pi / 2, 0, 1))

fit_data_slice = slice(0, n_rois // 2)
test_slice = slice(n_rois // 2, None)

fit_results = []
for group in ["data", "shuf"]:
    group_data = pooled_df[pooled_df["group"] == group]
    # rPC phase is the independent variable for both fits:
    fit_phases = group_data["rpc_phase"].values[fit_data_slice]

    for ax in ["lr_pos", "ap_pos"]:
        # `curve_fit` returns (popt, pcov); popt is in `my_cos` parameter order.
        (amp, ph, off, freq), _ = curve_fit(
            my_cos,
            fit_phases,
            group_data[ax].values[fit_data_slice],
            p0=initial_guesses[ax],
        )

        # Keep this key order: the parameters may be read back by column position
        # and passed positionally to `my_cos`.
        fit_results.append(
            dict(group=group, ax=ax, amp=amp, ph=ph, off=off, freq=freq)
        )

fit_results = pd.DataFrame(fit_results)

In [None]:
n_bins = 40
# Colors for the "shuf" and "data" groups, in the order they are iterated below:
cols = [(0.8,) * 3, (0.8, 0.4, 0.3)]

ax_labels = dict(lr_pos="l-r pos. (μm)", ap_pos="a-p pos. (μm)")

# `fit_results` columns holding the `my_cos` parameters, in `my_cos` argument order.
# Selecting them by name rather than by column position keeps this cell correct even
# if `fit_results` gains or reorders columns:
param_names = ["amp", "ph", "off", "freq"]

fig = plt.figure(figsize=(6, 3.5))

# Top row: boxplots of held-out residuals; bottom 3/4: position vs. rPC phase.
gs = gridspec.GridSpec(4, 6, figure=fig)
axs_scat = [fig.add_subplot(gs[1:, :3]), fig.add_subplot(gs[1:, 3:])]
axs_comp = [fig.add_subplot(gs[0, 1:3]), fig.add_subplot(gs[0, 4:])]

x_array = np.arange(-np.pi, np.pi, 0.1)

for j, ax in enumerate(["lr_pos", "ap_pos"]):
    residuals = []
    for i, group in enumerate(["shuf", "data"]):
        group_data = pooled_df[pooled_df["group"] == group]

        axs_scat[j].scatter(
            group_data["rpc_phase"],
            group_data[ax],
            color=cols[i],
            s=5,
            alpha=0.8,
            lw=0,
            label=group,
        )

        # Fitted cosine for this group and anatomical axis:
        fit_params = fit_results[
            (fit_results["group"] == group) & (fit_results["ax"] == ax)
        ]
        params = fit_params[param_names].iloc[0].to_numpy(dtype=float)
        axs_scat[j].plot(
            x_array,
            my_cos(x_array, *params),
            color=dark_col(cols[i]),
            label="_nolegend_",
        )

        # Cross-validation: squared prediction error on the rows left out of the fit:
        predicted = my_cos(group_data["rpc_phase"].values[test_slice], *params)
        residuals.append((predicted - group_data[ax].values[test_slice]) ** 2)

    axs_scat[j].set(xlabel="rPC phase (rad)", ylabel=ax_labels[ax])
    despine(axs_scat[j])

    boxplot(residuals, cols=cols, ax=axs_comp[j], widths=0.6, ec=(0.3,) * 3)

    axs_comp[j].set_yticklabels(["shuf", "data"])
    axs_comp[j].tick_params(axis="both", which="both", labelsize=8)
    axs_comp[j].xaxis.tick_top()
    axs_comp[j].xaxis.set_label_position("top")
    axs_comp[j].set_xlabel("residuals (cross. val.)")

    for spine in ["bottom", "right"]:
        axs_comp[j].spines[spine].set_visible(False)


axs_scat[0].set_ylim(-120, 140)
axs_scat[1].set_ylim(-50, 80)
axs_scat[1].legend(
    frameon=False,
    fontsize=10,
    markerscale=2,
    handletextpad=-0.3,
)

plt.tight_layout()