In [None]:
%matplotlib widget

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

import lotr.plotting as pltltr
from lotr import LotrExperiment, dataset_folders
from lotr.data_preprocessing.anatomy import anatomical_angle_remapping
from lotr.result_logging import ResultsLogger
from lotr.utils import circular_corr

# Project results logger. It is not referenced in the visible cells — TODO confirm it is still needed.
logger = ResultsLogger()

# Shared project color definitions. COLS["phase"] is the colormap used below for angle values.
COLS = pltltr.COLS

In [None]:
def _fish_table(path):
    """Build the per-cell table of a single fish.

    Each row pairs one row of ``exp.rpc_scores`` / ``exp.rpc_angles`` with the
    morphed coordinates of the matching entry of ``exp.hdn_indexes``. The two
    are assumed to share the same cell order (TODO confirm).

    Columns:
    - ``rpc1_proj``, ``rpc2_proj``: rPC scores divided by their mean radius.
    - ``lr_pos``, ``ap_pos``: centered positions.
    - ``rpc_angle``.
    - ``anatomical_phase``: the remapped anatomical phase.
    - ``exp_id``: the fish folder name.
    """
    exp = LotrExperiment(path)

    # Identity permutation: cells keep their original order, so nothing is
    # shuffled here. Swap in a random permutation to shuffle coordinate
    # identities as a control.
    cell_order = np.arange(len(exp.hdn_indexes))

    # Keep coordinate columns 1 and onward. The second indexing step returns a
    # new array, so the in-place centering below never modifies `exp`.
    coords = exp.morphed_coords_um[exp.hdn_indexes, 1:][cell_order, :]
    coords[:, 1] += 16  # arbitrary centering of ap_pos (to be replaced with morphing)

    # "Anatomical phase": angular position of each cell around the anatomy,
    # then remapped to stretch angles and avoid undersampling around the midline.
    anatomical_phase = anatomical_angle_remapping(
        np.angle(-coords[:, 1] + 1j * -coords[:, 0])
    )

    # Mean distance from the origin in the (rPC1, rPC2) plane. It normalizes
    # the projections so that every fish has a comparable scale.
    rpc1, rpc2 = exp.rpc_scores[:, 0], exp.rpc_scores[:, 1]
    mean_radius = np.mean(np.sqrt(rpc1**2 + rpc2**2))

    return pd.DataFrame(
        {
            "rpc1_proj": rpc1 / mean_radius,
            "rpc2_proj": rpc2 / mean_radius,
            "lr_pos": coords[:, 0],
            "ap_pos": coords[:, 1],
            "rpc_angle": exp.rpc_angles,
            "anatomical_phase": anatomical_phase,
            "exp_id": path.name,
        }
    )


# Build one table per fish and stack them. Each per-fish table has its own
# 0-based index; reset_index() keeps it as an "index" column of the pooled frame.
results_list = [_fish_table(path) for path in dataset_folders]
pooled_df = pd.concat(results_list, axis=0).reset_index()

In [None]:
# Per-fish circular correlation between anatomical phase and rPC angle.
# - Fish follow the first-appearance order of `exp_id` in pooled_df.
# - Values are negated so that the ascending argsort in `order` ranks fish
#   from the highest to the lowest correlation.
rhos = [
    -circular_corr(fish_df["anatomical_phase"].values, fish_df["rpc_angle"].values)
    for fish_df in (
        pooled_df[pooled_df["exp_id"] == exp_id]
        for exp_id in pooled_df["exp_id"].unique()
    )
]
order = np.argsort(rhos)

In [None]:
# Label code of every fish, in `dataset_folders` order.
# The plotting cell indexes this list with positions from `order`. That assumes
# pooled_df["exp_id"].unique() follows the same order, which holds when folder
# names are unique and every fish contributes rows — TODO confirm.
exp_codes = [LotrExperiment(path).exp_code for path in dataset_folders]

In [None]:
# Figure layout. Each "panel" is a run of `p_per_ax` adjacent gridspec columns:
#   [0] cells in rPC space
#   [1] cells at their anatomical position
#   [2] spacer
#   [3] rPC angle vs. left-right position
#   [4] spacer
#   [5] rPC angle vs. posterior-anterior position
#   [6] spacer
# `n_cols` panels share one figure row:
# - Panel 0 shows all fish pooled.
# - Panels 1-2 stay empty, leaving room for labels, the scale bar and the colorbar.
# - Single fish, sorted by `order`, fill the panels from `first_fish_panel` onward.
p_per_ax = 7
n_cols = 2
first_fish_panel = 3
n_panels = len(rhos) + first_fish_panel
n_rows = (len(rhos) + 2) // n_cols  # the figure has n_rows + 1 rows of panels

f, all_axs = plt.subplots(
    n_rows + 1,
    p_per_ax * n_cols,
    figsize=(9, 12),
    gridspec_kw=dict(
        left=0.1,
        right=0.99,
        bottom=0.01,
        top=0.95,
        wspace=0.1,
        height_ratios=[1] * (n_rows + 1),
        width_ratios=[0.8, 0.8, 0.2, 0.8, 0.2, 0.8, 0.15] * n_cols,
    ),
)
s = 10  # marker size of the rPC-space and anatomy scatters
ap = 50  # minimum half-extent (μm) of the anatomy axes along ap_pos (y)
lr = 70  # minimum half-extent (μm) of the anatomy axes along lr_pos (x)

# Pin the cyclic colormap to the full [-π, π] range. Without vmin/vmax, each
# scatter is normalized to its own data range, so colors would neither match
# the colorbar nor be comparable across fish.
phase_kw = dict(cmap=COLS["phase"], vmin=-np.pi, vmax=np.pi)

axs_flat = all_axs.flatten()
exp_ids = pooled_df["exp_id"].unique()  # same order `rhos`/`order` were computed in

# Titles and axis labels go on the pooled (top-left) panel only.
for j, (title, pos_label) in enumerate(
    [("Frontal axis", "Right-left pos. (μm)"), ("Sagittal axis", "Post.-ant. pos. (μm)")]
):
    label_ax = all_axs[0, 3 + j * 2]
    label_ax.set(title=title)
    label_ax.set_xlabel("rPC angle", fontsize=6)
    label_ax.set_ylabel(pos_label, fontsize=6)
    label_ax.tick_params(labelsize=6)

for n in range(len(order) + 1):
    is_pooled = n == len(order)
    if is_pooled:
        # Last iteration: pooled data from all fish, drawn faintly in panel 0.
        sel_data = pooled_df
        axs = axs_flat[:p_per_ax]
        alpha = 0.1
        exp_code = "all fish"
    else:
        sel = exp_ids[order[n]]
        sel_data = pooled_df[pooled_df["exp_id"] == sel]

        i_plot = (n + first_fish_panel) * p_per_ax
        axs = axs_flat[i_plot : i_plot + p_per_ax]
        alpha = 1
        exp_code = exp_codes[order[n]]

    # Cells in the (-rPC2, rPC1) plane, colored by rPC angle:
    axs[0].scatter(
        -sel_data["rpc2_proj"],
        sel_data["rpc1_proj"],
        lw=0,
        c=sel_data["rpc_angle"],
        s=s,
        alpha=alpha,
        rasterized=True,
        **phase_kw,
    )
    # The same cells at their anatomical position, with the same coloring:
    axs[1].scatter(
        sel_data["lr_pos"],
        sel_data["ap_pos"],
        lw=0,
        c=sel_data["rpc_angle"],
        s=s,
        alpha=alpha,
        rasterized=True,
        **phase_kw,
    )

    rho = circular_corr(
        sel_data["anatomical_phase"].values, sel_data["rpc_angle"].values
    )

    axs[0].text(
        -1.55,
        0,
        f"{exp_code}\nr: {rho:.2f}",
        va="center",
        ha="center",
        rotation=90,
        fontsize=6,
    )
    for ax in axs[:2]:
        ax.axis("equal")

    # Invisible points force the anatomy axes to span at least ±lr x ±ap,
    # so all fish are drawn on a common scale.
    axs[1].scatter([-lr, lr, -lr, lr], [ap, ap, -ap, -ap], s=0)
    pltltr.despine(axs[1], "all")
    pltltr.despine(axs[0], "all")

    for j, pos_col in enumerate(["lr_pos", "ap_pos"]):
        pos_ax = axs[3 + j * 2]
        pos_ax.scatter(
            sel_data["rpc_angle"],
            sel_data[pos_col],
            color="C0",
            s=5,
            lw=0,
            alpha=alpha,
            label="_nolegend_",
            rasterized=True,
        )
        pos_ax.set(
            xticks=(-np.pi / 2, 0, np.pi / 2),
            xlim=(-np.pi - 0.2, np.pi + 0.2),
            ylim=[(-60, 60), (-30, 20)][j],
            alpha=alpha,
        )
        if not is_pooled:
            # Hide tick labels without swapping the tick formatter. Passing
            # yticklabels=[] on an axis without fixed ticks makes matplotlib warn.
            pos_ax.tick_params(labelbottom=False, labelleft=False)
        pltltr.despine(pos_ax)

    for spacer_idx in (2, 4, 6):
        axs[spacer_idx].axis("off")

# Blank the reserved panels 1-2.
for ax in axs_flat[p_per_ax : first_fish_panel * p_per_ax]:
    ax.axis("off")
# Blank any trailing panel left unused, which happens with an even number of fish.
for ax in axs_flat[n_panels * p_per_ax :]:
    ax.axis("off")

# Empty, fully opaque scatter whose only role is to provide the colormap and
# norm for the colorbar.
sc = all_axs[0, 1].scatter([np.nan], [np.nan], c=[np.nan], **phase_kw)

pltltr.add_anatomy_scalebar(all_axs[0, 1], pos=(-50, -50), length=30, fontsize=6)

pltltr.add_cbar(
    sc,
    all_axs[0, 1],
    inset_loc=(-0.1, -0.2, 0.3, 0.08),
    ticks=[-np.pi + 0.01, 0, np.pi - 0.01],
    ticklabels=["-π", "0", "π"],
    title="rPC angle",
    orientation="horizontal",
    titlesize=6,
    labelsize=6,
)

In [None]:
# Export the figure built above through the project helper. This presumably
# saves the current matplotlib figure — confirm against pltltr.savefig.
fig_name, fig_folder = "anatomy_all_fish", "S3b"
pltltr.savefig(fig_name, folder=fig_folder)