## Anatomical organization of the network

In this notebook, we first look at the anatomical distribution of head direction neurons across fish, to assess whether a cell's anatomical location is consistently related to its position in the topology of network activity.

This gives us a unified phase space that replaces the arbitrary, per-fish orientation of cells in PC space: by orienting all fish in the same way with respect to the anatomy, we ensure that a given network phase always corresponds to a specific location of activity in the aHB. This rotated phase is what will actually be used in all subsequent analyses.

**TODO**
 - [ ] redefine the centering of the coordinates after morphing is done

In [None]:
%matplotlib widget
from pathlib import Path

import lotr.plotting as ltplt
import matplotlib.gridspec as gridspec
import numpy as np
import pandas as pd
from lotr import A_FISH, FIGURES_LOCATION, LotrExperiment, dataset_folders
from lotr.anatomy import anatomical_angle_remapping
from lotr.notebook_utils import print_source
from lotr.pca import pca_and_phase
from lotr.rpca_calculation import get_zero_mean_weights, reorient_pcs
from lotr.utils import get_rot_matrix, get_vect_angle, reduce_to_pi, zscore
from matplotlib import cm
from matplotlib import pyplot as plt
from scipy.optimize import curve_fit

COLS = ltplt.COLS
fig_location = FIGURES_LOCATION / "2 - anatomical_organization"
fig_location.mkdir(exist_ok=True)

## Is there an anatomical organization?

First, let's see whether the organization of cells in PC space matches their anatomical location.

In [None]:
# Compute PCA in time and angles for each neuron:
exp = LotrExperiment(A_FISH)

pca_scores, angles, _, circle_params = pca_and_phase(
    exp.traces[exp.pca_t_slice, exp.hdn_indexes].T
)

# We will work with the z projection, ignoring depth:
coords = exp.coords_um[exp.hdn_indexes, 1:]

In [None]:
# Now plot phase over the anatomy:
f, axs = plt.subplots(1, 3, figsize=(6, 2.5))

axs[0].scatter(pca_scores[:, 0], pca_scores[:, 1], c=angles, cmap=COLS["phase"])
ltplt.add_scalebar(axs[0], xlabel="PC1", ylabel="PC2", xlen=30, ylen=30)
axs[0].axis("equal")

# Show coordinates:
sc = axs[1].scatter(coords[:, 0], coords[:, 1], c=angles, cmap=COLS["phase"], s=15)

axs[1].set(xlim=exp.lr_extent_um, ylim=exp.pa_extent_um)
axs[1].axis("off")
ltplt.add_phase_cbar(
    sc, axs[1], inset_loc=(-0.0, 0.15, 0.25, 0.04), orientation="horizontal"
)

# Show ROI map:
colored_rois = exp.color_rois_by(angles, color_scheme=COLS["phase"])
axs[2].imshow(colored_rois.max(0), extent=exp.plane_ext_um, origin="lower")
ltplt.add_anatomy_scalebar(axs[2])

f.savefig(fig_location / "anatomical_distribution_onefish.pdf")

Yay! There seems to be an interesting anatomical localization, as expected from the raw activity after bouts.

## Rotate PC space to match anatomy

Now, if we want to compare anatomical distributions between different fish in the same figure, we need to somehow match the ROI coordinates in PC space from one fish to the other. Luckily, they are arranged in a circle! It should not be too hard to match the phases across fish with the following steps:
 1. Compute PCs across time
 2. Center them to have 0 mean
 3. Find a rotation so that the ROIs on the left of the "rotated PC" (**rPC**) circle (0 phase) are the left ROIs, and the ROIs on top of the rPC circle are the rostral ROIs
 4. Define a new "rPC" angle for all neurons in this new space, and use them to define a phase which is now in the same space across all fish

In [None]:
# 1. compute PCs:
exp = LotrExperiment(A_FISH)
pca_scores, angles, _, circle_params = pca_and_phase(
    exp.traces[exp.pca_t_slice, exp.hdn_indexes].T
)

# 2. center on 0:
centered_pca_scores = pca_scores[:, :2] - circle_params[:2]

### Find PC rotation

Step 3 will be slightly trickier. Our strategy is to find the vector in the circle that points toward the location of the frontal-most ROIs. To do so, we will average ROI angles (vectors) over the circle, similarly to what we did for the network phase, but now weighting by anatomical location along the antero-posterior and left-right axes. As weights, we will use coordinates after an offset subtraction that makes them 0-mean (the function also removes possible outliers).
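
As a minimal sketch of this weighted averaging, the toy example below places simulated cells on a circle and recovers the direction along which a made-up anatomical coordinate is largest. All names (`pcs`, `position`, `true_direction`) are hypothetical and the data are synthetic; the outlier removal mentioned above is skipped here.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic population: 200 cells evenly spread on a unit circle in "PC space".
n_cells = 200
theta = np.linspace(-np.pi, np.pi, n_cells, endpoint=False)
pcs = np.column_stack([np.cos(theta), np.sin(theta)])  # shape (n_cells, 2)

# Hypothetical anatomical coordinate: largest for cells near PC angle 2 rad,
# plus noise. It stands in for, e.g., the antero-posterior position.
true_direction = 2.0
position = np.cos(theta - true_direction) + rng.normal(0, 0.3, n_cells)

# Weights are the zero-mean coordinates (no outlier removal in this sketch).
weights = position - position.mean()

# Weighted vector sum over the population, then its angle in PC space.
avg_vect = weights @ pcs  # shape (2,)
est_direction = np.arctan2(avg_vect[1], avg_vect[0])

print(f"true: {true_direction:.2f} rad, estimated: {est_direction:.2f} rad")
```

The estimated angle should land close to `true_direction`, because cells whose coordinate is above the mean pull the summed vector toward their side of the circle, while cells below the mean push it away from theirs.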

In [None]:
# get 0-mean coords (exclude z axis already from here)
w_coords = get_zero_mean_weights(exp.coords[exp.hdn_indexes, 1:])

# Compute the PC vector averages across the population, using the coordinates
# along each anatomical axis as weights.
# Since the z axis was dropped above, the result is a 2 x 2 matrix
# (PC dimension x anatomical axis) containing the average vector for each of
# the 2 anatomical axes used as weights:
avg_vects = np.einsum("ij,ik->jk", centered_pca_scores, w_coords)

# get angle from vectors for the 2 anatomical coordinates:
avg_angles = get_vect_angle(avg_vects)

# and we take the mean between left-right and anterior-posterior axes angles:
mean_angle = np.angle(np.sum(np.cos(avg_angles)) + 1j * np.sum(np.sin(avg_angles)))

# Since PC signs can be arbitrary, we also need to find whether left and right
# were flipped. This we decide based on the difference in sign from the angle
# between the lateral and sagittal axes fit in PC space:
s = np.sign(reduce_to_pi(avg_angles[1] - avg_angles[0]))
invert_mat = np.array([[1, 0], [0, s]])

# At this point, we simply need to rotate the coordinates so that the
# mean angle between vector pointing forward and vector pointing rightward is placed
# at (1/4)*pi (angle E):
FINAL_TH_SHIFT = -(1 / 4) * np.pi
rotated_pca_scores = (
    get_rot_matrix(FINAL_TH_SHIFT)
    @ get_rot_matrix(-mean_angle * s)
    @ invert_mat
    @ centered_pca_scores.T
).T

In [None]:
seg_len = 100
col_schemes = ["BrBG_r", "RdBu"]
f, axs = plt.subplots(2, 2, figsize=(5, 5))

radius = np.max(circle_params[2])

for i, coords in enumerate([centered_pca_scores, rotated_pca_scores]):
    for j, col in enumerate(col_schemes):
        ax = axs[i, j]
        sc = ax.scatter(
            coords[:, 0],
            coords[:, 1],
            c=exp.coords[exp.hdn_indexes, j + 1],
            cmap=col,
            s=10,
        )
        ax.plot(*ltplt.get_circle_xy((0, 0, radius)), lw=1, c=".2", zorder=-100)

        # get line color from cmap (cm.get_cmap was removed in matplotlib 3.9):
        line_col = plt.get_cmap(col).get_over()

        avg_vects = np.einsum("ij,ik->jk", coords, w_coords)
        # get angle from vectors for the 2 anatomical coordinates:
        avg_angles = get_vect_angle(avg_vects)

        ax.plot(
            [0, np.cos(avg_angles[j]) * radius],
            [0, np.sin(avg_angles[j]) * radius],
            lw=1.5,
            ls="--",
            dashes=(7, 3),
            c=line_col,
        )
        ax_lab = ["l. - r.", "p. - a."][j]
        ltplt.add_cbar(
            sc, ax, (0.98, 0.9, 0.04, 0.15), ticks=[], label=ax_lab, titlesize=9,
        )
        l = 120
        ax.set(xlim=(-l, l), ylim=(-l, l))
        ltplt.despine(ax, "all")

for ax in axs[1, :]:
    for j, col in enumerate(col_schemes):
        l_col = plt.get_cmap(col).get_over()
        ax.plot(
            [0, np.sin(j * np.pi / 2) * radius],
            [0, -np.cos(j * np.pi / 2) * radius],
            c=l_col,
            lw=1.5,
            alpha=0.3,
            zorder=-100,
        )
    # Draw the fish outline once per axis:
    ltplt.add_fish(ax, head_offset=(0, 0), c=".7", scale=radius, angle=-np.pi / 2)


axs[0, 0].set(ylabel="before")
axs[1, 0].set(ylabel="after")
blen = 40
ltplt.add_scalebar(axs[0, 0], xlabel="PC1", ylabel="PC2", xlen=blen, ylen=blen)
ltplt.add_scalebar(axs[1, 0], xlabel="rPC1", ylabel="rPC2", xlen=blen, ylen=blen)

f.savefig(fig_location / "anatomical_fit_onefish.pdf")

In the actual analysis, the `reorient_pcs(pcs, coords)` function is used to reorient the PC axes to match anatomy, doing exactly the steps we did above:

In [None]:
reorient_with_funct = reorient_pcs(centered_pca_scores, w_coords)

# Important check that demo and lotr function are consistent:
assert np.allclose(rotated_pca_scores, reorient_with_funct)

### Define convention for rPC angles and phase

We now need a convention to define neuron angles, and therefore network phase. Our definition will be the following:
 - Caudal neurons will have angle 0; network phase 0 will correspond to caudal activity;
 - Angle will **increase moving clockwise in the anatomy** and **decrease moving counterclockwise in the anatomy**.
 
Therefore, if there is anatomical organization, starting from the caudal 0:
 - Left neurons will have (mostly) positive angles (if topology is true)
 - Right neurons will have (mostly) negative angles
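
The direction in which the angle grows can be checked with a purely geometric toy example, assuming angles are computed as `np.arctan2(y, -x)` on centered scores. The helper `angle_flipped_x` is hypothetical and not part of `lotr`. The sketch only shows that flipping the x axis makes the angle grow clockwise in the (x, y) plane; which side of that plane ends up rostral or caudal depends on the rotation applied and on how the scores are plotted.

```python
import numpy as np


def angle_flipped_x(x, y):
    """Angle of the point (x, y), measured after flipping the x axis."""
    return np.arctan2(y, -x)


# Four reference points on the unit circle, given as (x, y):
points = {"-x": (-1.0, 0.0), "+y": (0.0, 1.0), "+x": (1.0, 0.0), "-y": (0.0, -1.0)}
for name, (x, y) in points.items():
    print(f"{name}: {angle_flipped_x(x, y):+.2f} rad")

# Sweep counterclockwise in the standard (x, y) plane, staying in the upper
# half circle to avoid the wrap-around point of the flipped angle:
standard_angle = np.linspace(0.1, 3.0, 50)
flipped = angle_flipped_x(np.cos(standard_angle), np.sin(standard_angle))

# The flipped angle decreases along a counterclockwise sweep,
# i.e. it increases when moving clockwise.
print(np.all(np.diff(flipped) < 0))
```

In this convention the point on the negative x axis has angle 0, and the point on the positive y axis has angle π/2.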

In [None]:
pca_scores, angles, _, circle_params = pca_and_phase(
    exp.traces[exp.pca_t_slice, exp.hdn_indexes].T
)

In [None]:
# The rotated PCs are already centered on 0, so:
rpc_angles = np.arctan2(rotated_pca_scores[:, 1], -rotated_pca_scores[:, 0])

In [None]:
# Now plot phase over the anatomy:
f, axs = plt.subplots(1, 2, figsize=(5, 2.5))

# Inverting scatter order and sign we rotate the plot by 90° so that
# frontal ROIs are on top:
sc = axs[0].scatter(
    -rotated_pca_scores[:, 1],
    rotated_pca_scores[:, 0],
    c=rpc_angles,
    cmap=COLS["phase"],
    vmax=np.pi,
    vmin=-np.pi,
    s=15,
)
ltplt.add_scalebar(
    axs[0],
    xlabel="rPC2",
    ylabel="rPC1",
    xlen=-45,
    ylen=45,
    text_params=dict(fontsize=8),
)

l = 130
axs[0].set(xlim=(-l, l), ylim=(-l, l))

ltplt.add_cbar(
    sc,
    axs[0],
    inset_loc=(1, 1, 0.3, 0.06),
    ticks=[-np.pi + 0.1, 0, np.pi - 0.1],
    ticklabels=["-π", "0", "π"],
    title="rPC angle",
    orientation="horizontal",
    titlesize=8,
)
ltplt.add_fish(axs[0], head_offset=(0, 0), c=".7", scale=80, angle=0)

colored_rois = exp.color_rois_by(rpc_angles, color_scheme=COLS["phase"])
axs[1].imshow(colored_rois.max(0), extent=exp.plane_ext_um, origin="lower")
ltplt.add_anatomy_scalebar(axs[1], pos=(110, 40), length=30)

f.savefig(fig_location / "registered_angles_onefish.pdf")

From now on, all these operations will be performed under the hood when accessing the experiment attributes `LotrExperiment.rpc_scores` and `LotrExperiment.rpc_angles`.

In [None]:
assert np.allclose(rotated_pca_scores, exp.rpc_scores, rtol=0.001)
assert np.allclose(rpc_angles, exp.rpc_angles, rtol=0.001)

## Registration of all fish on one PC space

Now, we can transform coordinates from all fish so that they are all oriented in the same way w/r/t the anatomy. We scale each fish's rPC scores by the mean radius of its circle, but we do not project individual cells onto the unit circle, so that we can better appreciate the dispersion of colors around the circle:

In [None]:
results_list = []

np.random.seed(50)

# Loop over all fish, and compute values for data and after shuffling coords identity:
for path in dataset_folders:
    for group, shuf_f in zip(["data", "shuf"], [lambda x: x, np.random.shuffle]):
        exp = LotrExperiment(path)
        sequence = np.arange(len(exp.hdn_indexes))
        shuf_f(sequence)

        # center coords (will be replaced with morphing):
        coords = exp.coords_um[exp.hdn_indexes, 1:][sequence, :]  # shuffle if necessary
        coords = coords - np.mean(coords, 0)

        # And, for plotting, compute an "anatomical phase", aka position around the anatomy:
        anatomical_phase = np.angle(-coords[:, 1] + 1j * -coords[:, 0])

        # Stretch angle to avoid undersampling around midline
        anatomical_phase = anatomical_angle_remapping(anatomical_phase)

        r = np.mean(np.sqrt(exp.rpc_scores[:, 0] ** 2 + exp.rpc_scores[:, 1] ** 2))
        results_list.append(
            pd.DataFrame(
                dict(
                    rpc1_proj=exp.rpc_scores[:, 0] / r,
                    rpc2_proj=exp.rpc_scores[:, 1] / r,
                    lr_pos=coords[:, 0],
                    ap_pos=coords[:, 1],
                    rpc_angle=exp.rpc_angles,
                    anatomical_phase=anatomical_phase,
                    exp_id=path.name,
                    group=group,
                )
            )
        )

pooled_df = pd.concat(results_list, axis=0).reset_index()

In [None]:
f, axs = plt.subplots(2, 2, figsize=(5, 5))
s = 8

# Get unshuffled data:
pooled_data = pooled_df[pooled_df["group"] == "data"]

for j, color_by in enumerate(["rpc_angle", "anatomical_phase"]):
    axs[j, 0].scatter(
        -pooled_data["rpc2_proj"],
        pooled_data["rpc1_proj"],
        lw=0,
        c=pooled_data[color_by],
        cmap=COLS["phase"],
        s=s,
        alpha=0.5,
    )
    axs[j, 0].set_ylabel(f"col. by {color_by}")

    axs[j, 1].scatter(
        pooled_data["lr_pos"],
        pooled_data["ap_pos"],
        lw=0,
        c=pooled_data[color_by],
        cmap=COLS["phase"],
        s=s,
        alpha=0.5,
    )

    ltplt.add_anatomy_scalebar(axs[j, 1], pos=(110, 40), length=30)
    ltplt.despine(axs[j, 1], "all")
    ltplt.despine(axs[j, 0], "all")

    # The x axis shows -rPC2 and the y axis rPC1, as in the single-fish plot:
    ltplt.add_scalebar(
        axs[j, 0],
        xlabel="rPC2",
        ylabel="rPC1",
        xlen=-0.4,
        ylen=0.4,
        disable_axis=False,
    )

plt.tight_layout()

f.savefig(fig_location / "pooled_angles_all_fish.pdf")

In [None]:
n_rois = len(pooled_df[pooled_df["group"] == "data"])

# create the function we want to fit
def my_cos(x, amplitude, phase, offset, freq):
    return np.cos(x * freq + phase) * amplitude + offset


# Perform the fit.
# We will fit on the first half of the data, and then measure the squared
# residuals over the left-out half:
initial_guesses = dict(lr_pos=(80, np.pi, 0, 1), ap_pos=(80, np.pi / 2, 0, 1))

fit_data_slice = slice(0, n_rois // 2)
test_slice = slice(n_rois // 2, None)

fit_results = []
for group in ["data", "shuf"]:
    group_data = pooled_df[pooled_df["group"] == group]

    for ax in ["lr_pos", "ap_pos"]:

        fit = curve_fit(
            my_cos,
            group_data["rpc_angle"].values[fit_data_slice],
            group_data[ax].values[fit_data_slice],
            p0=initial_guesses[ax],
        )

        fit_results.append(
            dict(
                group=group,
                ax=ax,
                amp=fit[0][0],
                ph=fit[0][1],
                off=fit[0][2],
                freq=fit[0][3],
            )
        )

fit_results = pd.DataFrame(fit_results)
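
To illustrate this kind of fit and its held-out evaluation on data where the answer is known, here is a minimal sketch on simulated cells. Everything in it is an assumption made for illustration: the names (`cos_model`, `position`, `position_shuf`), the noise level, and the amplitude are invented, and the rows are independent draws, whereas the pooled data above are ordered by fish.

```python
import numpy as np
from scipy.optimize import curve_fit


def cos_model(x, amplitude, phase, offset, freq):
    """Cosine with free amplitude, phase, offset and frequency."""
    return np.cos(x * freq + phase) * amplitude + offset


rng = np.random.default_rng(1)

# Simulated cells: one angle per cell, and a position that follows a cosine
# of that angle plus Gaussian noise.
n = 400
angle = rng.uniform(-np.pi, np.pi, n)
position = cos_model(angle, 80.0, np.pi, 0.0, 1.0) + rng.normal(0, 20.0, n)

# Shuffled control: same positions, but cell identity is scrambled.
position_shuf = rng.permutation(position)

# Fit on the first half, evaluate on the second half.
train, test = slice(0, n // 2), slice(n // 2, None)
p0 = (80, np.pi, 0, 1)

mse = {}
for label, y in [("data", position), ("shuf", position_shuf)]:
    params, _ = curve_fit(cos_model, angle[train], y[train], p0=p0, maxfev=10000)
    predicted = cos_model(angle[test], *params)
    mse[label] = np.mean((predicted - y[test]) ** 2)

print(mse)
```

When position really depends on angle, the held-out squared error of the fit should be much smaller than for the shuffled control, where the best the cosine can do is roughly predict the mean.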

In [None]:
n_bins = 40
cols = [(0.8,) * 3, (0.8, 0.4, 0.3)]

ax_labels = dict(lr_pos="r-l pos. (μm)", ap_pos="p-a pos. (μm)")

fig = plt.figure(figsize=(5, 3.0))

gs = gridspec.GridSpec(4, 6, figure=fig)
axs_scat = [fig.add_subplot(gs[1:, :3]), fig.add_subplot(gs[1:, 3:])]
axs_comp = [fig.add_subplot(gs[0, 1:3]), fig.add_subplot(gs[0, 4:])]
# ax2 = ax1

x_array = np.arange(-np.pi, np.pi, 0.1)

res = []

for j, ax in enumerate(["lr_pos", "ap_pos"]):
    residuals = []
    for i, group in enumerate(["shuf", "data"]):
        group_data = pooled_df[pooled_df["group"] == group]

        axs_scat[j].scatter(
            group_data["rpc_angle"],
            group_data[ax],
            color=cols[i],
            s=5,
            alpha=0.8,
            lw=0,
            label=group,
        )

        fit_params = fit_results[
            (fit_results["group"] == group) & (fit_results["ax"] == ax)
        ]
        y = my_cos(x_array, *list(fit_params.iloc[0, 2:].values))
        axs_scat[j].plot(x_array, y, color=ltplt.dark_col(cols[i]), label="_nolegend_")

        axs_scat[j].set(xlabel="rPC phase (rad)", ylabel=ax_labels[ax])
        ltplt.despine(axs_scat[j])

        # Predict positions for the held-out half of the data:
        predicted = my_cos(
            group_data["rpc_angle"].values[test_slice],
            *list(fit_params.iloc[0, 2:].values)
        )

        residuals.append((predicted - group_data[ax].values[test_slice]) ** 2)

    ltplt.boxplot(residuals, cols=cols, ax=axs_comp[j], widths=0.6, ec=(0.3,) * 3)

    axs_comp[j].set_yticklabels(["shuf", "data"])
    axs_comp[j].tick_params(axis="both", which="both", labelsize=8)
    axs_comp[j].xaxis.tick_top()
    axs_comp[j].xaxis.set_label_position("top")
    axs_comp[j].set_xlabel("residuals (cross. val.)", fontsize=8)

    [axs_comp[j].axes.spines[s].set_visible(False) for s in ["bottom", "right"]]


axs_scat[0].set_ylim(-120, 140)
axs_scat[1].set_ylim(-50, 80)
axs_scat[1].legend(
    frameon=False,
    fontsize=10,
    markerscale=2,  # bbox_to_anchor=(0.8, 0.9),
    handletextpad=-0.3,
)

plt.tight_layout()
fig.savefig(fig_location / "fit_anatomical_distribution.pdf")

## Redefine network phase

Now that we have a standard alignment of phase across fish, we can redefine our network phase starting from the rPC angles, to ensure that for all fish:
 - network phase 0 -> rostral activation,
 - phase pi/2 -> right activation, 
 - phase -pi/2 -> left activation, 
 - phase pi/-pi -> back activation
 
Positive phase changes mean cw rotation of network activity, and negative phase changes mean ccw rotation.
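
Because the phase is defined on a circle, its changes have to be computed across the ±π boundary. A minimal sketch of one way to do this with `np.unwrap`, on a simulated phase trace (the helper `wrap_to_pi` and all values are hypothetical):

```python
import numpy as np


def wrap_to_pi(x):
    """Wrap angles to the interval [-pi, pi)."""
    return (x + np.pi) % (2 * np.pi) - np.pi


# Simulated phase trace: steady drift of +0.3 rad per sample, wrapped.
true_step = 0.3
n_samples = 50
phase = wrap_to_pi(np.arange(n_samples) * true_step)

# Naive differences contain spurious jumps of about -2*pi at each wrap:
naive_steps = np.diff(phase)

# Unwrapping removes those jumps before differencing:
steps = np.diff(np.unwrap(phase))
total_rotation = steps.sum()

print(naive_steps.min(), steps.min(), total_rotation)
```

With the sign convention stated above, positive entries of `steps` would correspond to cw rotation and negative entries to ccw rotation. This works as long as the true change between consecutive samples stays below π.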

In [None]:
exp = LotrExperiment(A_FISH)

# From our previous definition of network phase:
norm_activity = get_zero_mean_weights(exp.traces[:, exp.hdn_indexes].T).T
avg_vects = np.einsum("ij,ik->jk", norm_activity.T, exp.rpc_scores)

phase = np.arctan2(avg_vects[:, 1], -avg_vects[:, 0])

In [None]:
d = 7
f, ax = plt.subplots(figsize=(3, 4))
for n, i in enumerate([100, 2000, 8000, 1200]):
    c = ltplt.get_default_phase_col(phase[i])
    ax.scatter(
        exp.rpc_angles, zscore(exp.traces[i, exp.hdn_indexes]) + n * d, color=c, s=15
    )
    ax.plot(
        [phase[i], phase[i]], np.array([-1, 1]) * (d / 2) + n * d, c=ltplt.dark_col(c)
    )

ltplt.add_scalebar(ax, xlabel="", ylabel="ΔF", xlen=0, ypos=d * 2.5, disable_axis=False)
ltplt.despine(ax, ["left", "right", "top"])
ax.set(
    xlabel="neuron angle/network phase",
    xlim=(-np.pi - 0.5, np.pi + 0.5),
    xticks=([-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi]),
    xticklabels=["$-π$", r"$\dfrac{-π}{2}$", 0, r"$\dfrac{π}{2}$", "$π$"],
)
plt.tight_layout()

# Save before showing, so the saved figure includes the final layout:
f.savefig(fig_location / "network_phases.pdf")
plt.show()

In the subsequent analyses, we will be using the `LotrExperiment.network_phase` property for this quantification.

In [None]:
assert np.allclose(phase, exp.network_phase, rtol=0.001)
