In [1]:
%load_ext blackcellmagic
%load_ext autoreload
%autoreload 2

# Imports

In [21]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from ipywidgets import interact
from matplotlib.gridspec import GridSpec

from plotters import _get_data
from figures import (
    mean_agg_rule_tr_indivs,
    mean_agg_rule_tr_group,
    mean_agg_rule_tr_indiv_nb,
)

from natsort import natsorted

from plotters import (
    _boxplots_one_variable_with_stats,
    _get_num_data_points,
    _update_order,
    _get_variables_ylims_and_offsets,
)
from mlxtend.evaluate import permutation_test
from stats import MEAN_STATS_KWARGS
from stats import PAIRS_OF_GROUPS
from constants import (
    COLORS,
    GENOTYPE_GROUP_GENOTYPE_ORDER,
    all_variables_names_enhanced,
    GENOTYPE_GROUP_ORDER,
    FOCAL_NB_GENOTYPE_ORDER
)
from plotters import _boxplot_axes_one_variable

from trajectorytools.plot import polar_histogram, plot_polar_histogram

# Constants

In [3]:
animal_uid_col = "trial_uid_id"
video_uid_col = "trial_uid"

animal_info_cols = [
    "genotype",
    "fish_id_exp",
    "identity",
    "dpf",
    "size_cm",
]

video_info_cols = [
    "trial_uid",
    "experiment_type",
    "gene",
    "founder",
    "replicate",
    "genotype_group",
]

individual_variables = ["normed_distance_to_origin", "speed"]
group_variables = [
    "mean_distance_to_center_of_group",
    "polarization_order_parameter",
    "rotation_order_parameter",
]
individual_nb_variables = ["nb_angle", "nb_distance"]

individual_variables_stats = [("normed_distance_to_origin", "mean"), ("speed", "mean")]
group_variables_stats = [
    ("mean_distance_to_center_of_group", "mean"),
    ("polarization_order_parameter", "mean"),
    ("rotation_order_parameter", "mean"),
]
individual_nb_variables_stats = [
    ("nb_angle", "ratio_in_front"),
    # ("nb_angle", "ratio_in_back"),
    ("nb_distance", "mean"),
]

boxplot_kwargs = {
    "x": "genotype_group_genotype",
    "palette": COLORS,
    "whis": 100,
}
boxplot_kwargs_group = {
    "x": "genotype_group",
    "palette": COLORS,
    "whis": 100,
}
boxplot_kwargs_indiv_nb = {
    "x": "focal_nb_genotype",
    "whis": 100,
}
stats_kwargs = {
    "test_func": permutation_test,
    "test_func_kwargs": MEAN_STATS_KWARGS,
}

data_filters = [
    lambda x: ~x["gene"].str.contains("srrm"),
    # lambda x: x["experiment_type"] != 1,
]
# data_filters = []

# Load data

In [4]:
data_path_indiv_vars = "/home/pacorofe/Dropbox (CRG ADV)/ZFISH_MICs/_BSocial/2022_ANALYSIS_social/generated_tables/tr_indiv_vars_bl.pkl"
per_indiv_stasts_kwargs = {
    "groupby": [
        "trial_uid",
        "identity",
        "genotype_group",
        "genotype",
        "line",
        "line_replicate",
        "experiment_type",
    ],
    "agg_rule": mean_agg_rule_tr_indivs,
}
data_indiv, data_indiv_stats = _get_data(
    data_path_indiv_vars,
    data_filters=data_filters,
    per_indiv_stats_kwargs=per_indiv_stasts_kwargs,
)

Getting data
Normalizing replicates stats with respect to HET_HET
  (data.genotype_group == normalizing_genotype_group)
  (data.genotype_group == normalizing_genotype_group)
Filtering data
Filtering data
original shape (26493056, 76)
(22209532, 76)
Groupping data


In [5]:
data_path_group_vars = "/home/pacorofe/Dropbox (CRG ADV)/ZFISH_MICs/_BSocial/2022_ANALYSIS_social/generated_tables/tr_group_vars_bl.pkl"

per_group_stasts_kwargs = {
    "groupby": [
        "trial_uid",
        "genotype_group",
        "line",
        "line_replicate",
        "experiment_type",
    ],
    "agg_rule": mean_agg_rule_tr_group,
}
data_group, data_group_stats = _get_data(
    data_path_group_vars,
    data_filters=data_filters,
    per_indiv_stats_kwargs=per_group_stasts_kwargs,
)

Getting data
Normalizing replicates stats with respect to HET_HET
  (data.genotype_group == normalizing_genotype_group)
  (data.genotype_group == normalizing_genotype_group)
Filtering data
Filtering data
original shape (13246528, 18)
(11104766, 18)
Groupping data


In [6]:
data_path_indiv_nb_vars = "/home/pacorofe/Dropbox (CRG ADV)/ZFISH_MICs/_BSocial/2022_ANALYSIS_social/generated_tables/tr_indiv_nb_vars_bl.pkl"

per_group_stasts_kwargs = {
    "groupby": [
        "trial_uid",
        "identity",
        "identity_nb",
        "genotype_group",
        "genotype",
        "genotype_nb",
        "line",
        "line_replicate",
        "experiment_type",
        'focal_nb_genotype'
    ],
    "agg_rule": mean_agg_rule_tr_indiv_nb,
}
data_indiv_nb, data_indiv_nb_stats = _get_data(
    data_path_indiv_nb_vars,
    data_filters=data_filters,
    per_indiv_stats_kwargs=per_group_stasts_kwargs,
)

Getting data
Normalizing replicates stats with respect to HET_HET
  (data.genotype_group == normalizing_genotype_group)
  (data.genotype_group == normalizing_genotype_group)
Filtering data
Filtering data
original shape (26492917, 37)
(22209394, 37)
Groupping data


# Transform data

In [7]:
data_indiv_per_video = data_indiv.groupby("trial_uid", as_index=False)
data_with_normed_positions = []
for idx, data_indiv_ in data_indiv_per_video:
    data_indiv_["s_x_normed"] = data_indiv_["s_x"] / data_indiv_["s_x"].abs().max()
    data_indiv_["s_y_normed"] = data_indiv_["s_y"] / data_indiv_["s_y"].abs().max()
    data_with_normed_positions.append(data_indiv_)
data_indiv = pd.concat(data_with_normed_positions)

data_indiv["line_exp"] = (
    data_indiv["line"] + "_" + data_indiv["experiment_type"].astype(str)
)
data_group["line_exp"] = (
    data_group["line"] + "_" + data_group["experiment_type"].astype(str)
)
data_indiv_nb["line_exp"] = (
    data_indiv_nb["line"] + "_" + data_indiv_nb["experiment_type"].astype(str)
)
data_indiv_stats["line_exp"] = (
    data_indiv_stats["line"] + "_" + data_indiv_stats["experiment_type"].astype(str)
)
data_group_stats["line_exp"] = (
    data_group_stats["line"] + "_" + data_group_stats["experiment_type"].astype(str)
)
data_indiv_nb_stats["line_exp"] = (
    data_indiv_nb_stats["line"]
    + "_"
    + data_indiv_nb_stats["experiment_type"].astype(str)
)

# Variables range

In [8]:
variables_ranges = {}

for dataset in [data_indiv, data_indiv_nb, data_group]:
    for col in dataset.columns:
        if col in all_variables_names_enhanced:
            variables_ranges[col] = {
                "min": np.nanmin(dataset[col]),
                "max": np.nanmax(dataset[col]),
            }

# Functions to prepare type of figures

In [9]:
def _prepare_animal_indiv_vars_fig(num_variables):
    fig = plt.figure(constrained_layout=True, figsize=(30, 10))
    num_cols = num_variables * 3
    num_rows = num_variables
    gs = GridSpec(num_rows, num_cols, figure=fig)

    ax_trajectories = fig.add_subplot(gs[:num_rows, :num_rows])
    axs_variables = []
    axs_distributions = []
    for i in range(num_variables):
        axs_variables.append(
            fig.add_subplot(gs[i : i + 1, num_rows : num_cols - 1])
        )
        axs_distributions.append(
            fig.add_subplot(gs[i : i + 1, num_cols - 1 :])
        )
    return fig, ax_trajectories, axs_variables, axs_distributions


def _prepare_video_group_fig(num_variables):
    fig = plt.figure(constrained_layout=True, figsize=(30, 10))
    num_cols = num_variables * 3
    num_rows = num_variables
    gs = GridSpec(num_rows, num_cols, figure=fig)

    ax_order_params = fig.add_subplot(gs[:num_rows, :num_rows])
    axs_variables = []
    axs_distributions = []
    for i in range(num_variables):
        axs_variables.append(
            fig.add_subplot(gs[i : i + 1, num_rows : num_cols - 1])
        )
        axs_distributions.append(
            fig.add_subplot(gs[i : i + 1, num_cols - 1 :])
        )
    return fig, ax_order_params, axs_variables, axs_distributions


def _prepare_video_indiv_nb_fig(num_variables):
    fig = plt.figure(constrained_layout=True, figsize=(30, 10))
    num_cols = num_variables * 3
    num_rows = num_variables
    gs = GridSpec(num_rows, num_cols, figure=fig)

    ax_relative_position = fig.add_subplot(gs[:num_rows, :num_rows])
    axs_variables = []
    axs_distributions = []
    for i in range(num_variables):
        axs_variables.append(
            fig.add_subplot(gs[i : i + 1, num_rows : num_cols - 1])
        )
        axs_distributions.append(
            fig.add_subplot(gs[i : i + 1, num_cols - 1 :])
        )
    return fig, ax_relative_position, axs_variables, axs_distributions


def _prepare_partition_indiv_vars_summary_fig(
    num_variables, num_genotype_groups
):
    assert num_variables == 2
    fig = plt.figure(constrained_layout=True, figsize=(30, 10))
    num_columns_position_hist = np.ceil(
        (num_genotype_groups + 1) / num_variables
    ).astype(int)
    num_cols = 4 + num_columns_position_hist
    num_rows = num_variables
    gs = GridSpec(num_rows, num_cols, figure=fig)

    axs_order_params_dist = []
    for row in range(num_rows):
        for col in range(num_columns_position_hist):
            axs_order_params_dist.append(
                fig.add_subplot(gs[row : row + 1, col : col + 1])
            )
    # axs_variables = []
    axs_distributions = []
    axs_boxplots_raw = []
    axs_boxplots_diff = []
    axs_boxplots_standardized = []
    for i in range(num_variables):
        # axs_variables.append(fig.add_subplot(gs[i:i+1, num_rows:num_cols-4]))
        axs_distributions.append(
            fig.add_subplot(gs[i : i + 1, num_cols - 4 : num_cols - 3])
        )
        axs_boxplots_raw.append(
            fig.add_subplot(gs[i : i + 1, num_cols - 3 : num_cols - 2])
        )
        axs_boxplots_diff.append(
            fig.add_subplot(gs[i : i + 1, num_cols - 2 : num_cols - 1])
        )
        axs_boxplots_standardized.append(
            fig.add_subplot(gs[i : i + 1, num_cols - 1 :])
        )
    return (
        fig,
        axs_order_params_dist,
        axs_distributions,
        axs_boxplots_raw,
        axs_boxplots_diff,
        axs_boxplots_standardized,
    )


def _prepare_partition_group_vars_summary_fig(
    num_variables, num_genotype_groups
):
    assert num_variables == 3
    fig = plt.figure(constrained_layout=True, figsize=(30, 10))
    num_columns_position_hist = np.ceil(
        (num_genotype_groups + 1) / num_variables
    ).astype(int)
    num_cols = 4 + num_columns_position_hist
    num_rows = num_variables
    gs = GridSpec(num_rows, num_cols, figure=fig)

    axs_positions_dist = []
    for row in range(num_rows):
        for col in range(num_columns_position_hist):
            axs_positions_dist.append(
                fig.add_subplot(gs[row : row + 1, col : col + 1])
            )
    # axs_variables = []
    axs_distributions = []
    axs_boxplots_raw = []
    axs_boxplots_diff = []
    axs_boxplots_standardized = []
    for i in range(num_variables):
        # axs_variables.append(fig.add_subplot(gs[i:i+1, num_rows:num_cols-4]))
        axs_distributions.append(
            fig.add_subplot(gs[i : i + 1, num_cols - 4 : num_cols - 3])
        )
        axs_boxplots_raw.append(
            fig.add_subplot(gs[i : i + 1, num_cols - 3 : num_cols - 2])
        )
        axs_boxplots_diff.append(
            fig.add_subplot(gs[i : i + 1, num_cols - 2 : num_cols - 1])
        )
        axs_boxplots_standardized.append(
            fig.add_subplot(gs[i : i + 1, num_cols - 1 :])
        )
    return (
        fig,
        axs_positions_dist,
        axs_distributions,
        axs_boxplots_raw,
        axs_boxplots_diff,
        axs_boxplots_standardized,
    )
    
def _prepare_partition_indiv_nb_vars_summary_fig(
    num_variables, num_focal_nb_genotype_groups
):
    assert num_variables == 2
    fig = plt.figure(constrained_layout=True, figsize=(30, 10))
    num_columns_position_hist = np.ceil(
        num_focal_nb_genotype_groups / num_variables
    ).astype(int)
    num_cols = 4 + num_columns_position_hist
    num_rows = num_variables
    gs = GridSpec(num_rows, num_cols, figure=fig)

    axs_positions_dist = []
    for row in range(num_rows):
        for col in range(num_columns_position_hist):
            axs_positions_dist.append(
                fig.add_subplot(gs[row : row + 1, col : col + 1], polar=True)
            )
    # axs_variables = []
    axs_distributions = []
    axs_boxplots_raw = []
    axs_boxplots_diff = []
    axs_boxplots_standardized = []
    for i in range(num_variables):
        # axs_variables.append(fig.add_subplot(gs[i:i+1, num_rows:num_cols-4]))
        axs_distributions.append(
            fig.add_subplot(gs[i : i + 1, num_cols - 4 : num_cols - 3])
        )
        axs_boxplots_raw.append(
            fig.add_subplot(gs[i : i + 1, num_cols - 3 : num_cols - 2])
        )
        # axs_boxplots_diff.append(
        #     fig.add_subplot(gs[i : i + 1, num_cols - 2 : num_cols - 1])
        # )
        # axs_boxplots_standardized.append(
        #     fig.add_subplot(gs[i : i + 1, num_cols - 1 :])
        # )
    return (
        fig,
        axs_positions_dist,
        axs_distributions,
        axs_boxplots_raw
        # axs_boxplots_diff,
        # axs_boxplots_standardized,
    )
    

# String info 

In [10]:
def _get_info(data, info_cols):
    info_str = ""
    for info_col in info_cols:
        assert info_col in data.columns
        infos = data[info_col].unique()
        assert len(set(infos)) == 1, set(infos)
        info = str(infos[0])
        info_str += f"{info_col}: {info} - "
    info_str = info_str[:-3]
    return info_str


def get_animal_info_str(
    animal_data, info_cols=animal_info_cols, video_info_cols=video_info_cols
):
    video_info = _get_info(animal_data, video_info_cols)
    animal_info = _get_info(animal_data, info_cols)
    return f"{video_info} \n {animal_info}"


def get_video_info_str(video_data, video_info_cols=video_info_cols):
    video_str_info = _get_info(video_data, video_info_cols)

    animals_uid_ids = video_data[animal_uid_col].unique()
    for animal_uid in animals_uid_ids:
        animal_data = video_data[video_data[animal_uid_col] == animal_uid]
        animal_str_info = _get_info(animal_data, animal_info_cols)
        video_str_info += f"\n {animal_str_info}"
    return video_str_info


def get_focal_nb_info(animal_nb_data, animal_info_cols=animal_info_cols):
    focal_info_cols = animal_info_cols
    nb_info_cols = [f"{col}_nb" for col in animal_info_cols]
    focal_info_str = _get_info(animal_nb_data, focal_info_cols)
    focal_info_str = f"focal: {focal_info_str}"
    nb_info_str = _get_info(animal_nb_data, nb_info_cols)
    nb_info_str = f"neighbour: {nb_info_str}"
    return f"{focal_info_str} \n {nb_info_str}"

# Summary plotters

In [11]:
def _plot_animal_indiv_vars_summary(data, variables, hue=None):
    (
        fig,
        ax_trajectories,
        axs_variables,
        axs_distributions,
    ) = _prepare_animal_indiv_vars_fig(len(variables))
    plot_trajectory(data, ax=ax_trajectories, hue=hue)
    for variable, ax_time, ax_dist in zip(variables, axs_variables, axs_distributions):
        plot_variable_along_time(data, variable, ax=ax_time, hue=hue)
        plot_variable_1d_distribution(data, variable, ax=ax_dist, hue=hue)
    return fig


def _plot_video_indiv_vars_summary(data, variables, hue=None):
    (
        fig,
        ax_trajectories,
        axs_variables,
        axs_distributions,
    ) = _prepare_animal_indiv_vars_fig(len(variables))
    plot_trajectory(data, ax=ax_trajectories, hue=hue)
    for variable, ax_time, ax_dist in zip(variables, axs_variables, axs_distributions):
        plot_variable_along_time(data, variable, ax=ax_time, hue=hue)
        plot_variable_1d_distribution(data, variable, ax=ax_dist, hue=hue)
    return fig


def _plot_group_variables_summary(data, variables):
    fig, ax_order_params, axs_variables, axs_distributions = _prepare_video_group_fig(
        len(variables)
    )
    plot_order_parameter_dist(data, ax=ax_order_params)
    for variable, ax_time, ax_dist in zip(variables, axs_variables, axs_distributions):
        plot_variable_along_time(data, variable, ax=ax_time)
        plot_variable_1d_distribution(data, variable, ax=ax_dist)
    return fig


def _plot_video_indiv_nb_variables_summary(data, variables):
    (
        fig,
        ax_order_params,
        axs_variables,
        axs_distributions,
    ) = _prepare_video_indiv_nb_fig(len(variables))
    plot_relative_position_dist(data, ax=ax_order_params)
    for variable, ax_time, ax_dist in zip(variables, axs_variables, axs_distributions):
        plot_variable_along_time(
            data, variable, ax=ax_time, hue="genotype_nb", units="identity_nb"
        )
        plot_variable_1d_distribution(data, variable, ax=ax_dist, hue="genotype_nb")
    return fig


def _plot_partition_indiv_vars_summary(
    data, data_stats, variables, variables_stats, hue=None
):
    num_genotype_groups = len(data["genotype_group"].unique())
    (
        fig,
        axs_positions_dist,
        axs_distributions,
        axs_boxplots_raw,
        axs_boxplots_diff,
        axs_boxplots_standardized,
    ) = _prepare_partition_indiv_vars_summary_fig(
        len(variables_stats), num_genotype_groups
    )
    plot_positions_dist_per_genotype_group(data, axs=axs_positions_dist)
    num_data_points = _get_num_data_points(data_stats, boxplot_kwargs)
    for i, (variable, ax_dist) in enumerate(zip(variables, axs_distributions)):
        if i == 0:
            legend = True
        else:
            legend = False
        plot_variable_1d_distribution(
            data, variable, ax=ax_dist, hue=hue, legend=legend, how="v"
        )
    for i, (
        variable,
        ax_boxplot,
        ax_boxplot_diff,
        ax_boxplot_standardized,
    ) in enumerate(
        zip(
            variables_stats,
            axs_boxplots_raw,
            axs_boxplots_diff,
            axs_boxplots_standardized,
        )
    ):
        if i == 0:
            legend = True
        else:
            legend = False
        variables_ylims, variables_y_offsets = _get_variables_ylims_and_offsets(
            data_stats
        )
        _update_order(data, boxplot_kwargs, GENOTYPE_GROUP_GENOTYPE_ORDER)
        _boxplots_one_variable_with_stats(
            ax_boxplot,
            data_stats,
            variable,
            num_data_points=num_data_points,
            boxplot_kwargs=boxplot_kwargs,
            stats_kwargs=stats_kwargs,
            pairs_of_groups_for_stats=PAIRS_OF_GROUPS,
            variable_ylim=None,
            variable_y_offset=variables_y_offsets[variable],
        )
        _boxplots_one_variable_with_stats(
            ax_boxplot_diff,
            data_stats,
            (f"{variable[0]}_diff", variable[1]),
            num_data_points=num_data_points,
            boxplot_kwargs=boxplot_kwargs,
            stats_kwargs=stats_kwargs,
            pairs_of_groups_for_stats=PAIRS_OF_GROUPS,
            variable_ylim=None,
            variable_y_offset=variables_y_offsets[(f"{variable[0]}_diff", variable[1])],
        )
        _boxplots_one_variable_with_stats(
            ax_boxplot_standardized,
            data_stats,
            (f"{variable[0]}_standardized", variable[1]),
            num_data_points=num_data_points,
            boxplot_kwargs=boxplot_kwargs,
            stats_kwargs=stats_kwargs,
            pairs_of_groups_for_stats=PAIRS_OF_GROUPS,
            variable_ylim=None,
            variable_y_offset=variables_y_offsets[(f"{variable[0]}_standardized", variable[1])],
        )
    return fig


def _plot_partition_group_vars_summary(
    data, data_stats, variables, variables_stats, hue=None
):
    num_genotype_groups = len(data["genotype_group"].unique())
    (
        fig,
        axs_order_params_dist,
        axs_distributions,
        axs_boxplots_raw,
        axs_boxplots_diff,
        axs_boxplots_standardized,
    ) = _prepare_partition_group_vars_summary_fig(
        len(variables_stats), num_genotype_groups
    )
    plot_order_parameter_dist_per_genotype_group(data, axs_order_params_dist)
    num_data_points = _get_num_data_points(data_stats, boxplot_kwargs_group)
    for i, (variable, ax_dist) in enumerate(zip(variables, axs_distributions)):
        if i == 0:
            legend = True
        else:
            legend = False
        plot_variable_1d_distribution(
            data, variable, ax=ax_dist, hue=hue, legend=legend, how="v"
        )
    for i, (
        variable,
        ax_boxplot,
        ax_boxplot_diff,
        ax_boxplot_standardized,
    ) in enumerate(
        zip(
            variables_stats,
            axs_boxplots_raw,
            axs_boxplots_diff,
            axs_boxplots_standardized,
        )
    ):
        if i == 0:
            legend = True
        else:
            legend = False
        variables_ylims, variables_y_offsets = _get_variables_ylims_and_offsets(
            data_stats
        )
        _update_order(data, boxplot_kwargs_group, GENOTYPE_GROUP_ORDER)
        _boxplots_one_variable_with_stats(
            ax_boxplot,
            data_stats,
            variable,
            num_data_points=num_data_points,
            boxplot_kwargs=boxplot_kwargs_group,
            stats_kwargs=stats_kwargs,
            pairs_of_groups_for_stats=PAIRS_OF_GROUPS,
            variable_ylim=None,
            variable_y_offset=variables_y_offsets[variable],
        )
        _boxplots_one_variable_with_stats(
            ax_boxplot_diff,
            data_stats,
            (f"{variable[0]}_diff", variable[1]),
            num_data_points=num_data_points,
            boxplot_kwargs=boxplot_kwargs_group,
            stats_kwargs=stats_kwargs,
            pairs_of_groups_for_stats=PAIRS_OF_GROUPS,
            variable_ylim=None,
            variable_y_offset=variables_y_offsets[(f"{variable[0]}_diff", variable[1])],
        )
        _boxplots_one_variable_with_stats(
            ax_boxplot_standardized,
            data_stats,
            (f"{variable[0]}_standardized", variable[1]),
            num_data_points=num_data_points,
            boxplot_kwargs=boxplot_kwargs_group,
            stats_kwargs=stats_kwargs,
            pairs_of_groups_for_stats=PAIRS_OF_GROUPS,
            variable_ylim=None,
            variable_y_offset=variables_y_offsets[(f"{variable[0]}_standardized", variable[1])],
        )
    return fig


def _plot_partition_indiv_nb_summary(
    data, data_stats, variables, variables_stats, hue=None
):
    num_focal_nb_genotype_groups = len(data["focal_nb_genotype"].unique())
    (
        fig,
        axs_polar_plots,
        axs_distributions,
        axs_boxplots_raw,
        # axs_boxplots_diff,
        # axs_boxplots_standardized,
    ) = _prepare_partition_indiv_nb_vars_summary_fig(
        len(variables_stats), num_focal_nb_genotype_groups
    )
    plot_polar_dist_relative_positions(data, axs_polar_plots)
    num_data_points = _get_num_data_points(data_stats, boxplot_kwargs_indiv_nb)
    for i, (variable, ax_dist) in enumerate(zip(variables, axs_distributions)):
        if i == 0:
            legend = True
        else:
            legend = False
        plot_variable_1d_distribution(
            data, variable, ax=ax_dist, hue=hue, legend=legend, how="v"
        )
    for i, (
        variable,
        ax_boxplot,
        # ax_boxplot_diff,
        # ax_boxplot_standardized,
    ) in enumerate(
        zip(
            variables_stats,
            axs_boxplots_raw,
            # axs_boxplots_diff,
            # axs_boxplots_standardized,
        )
    ):
        if i == 0:
            legend = True
        else:
            legend = False
        variables_ylims, variables_y_offsets = _get_variables_ylims_and_offsets(
            data_stats
        )
        _update_order(data, boxplot_kwargs_indiv_nb, FOCAL_NB_GENOTYPE_ORDER)
        _boxplots_one_variable_with_stats(
            ax_boxplot,
            data_stats,
            variable,
            num_data_points=num_data_points,
            boxplot_kwargs=boxplot_kwargs_indiv_nb,
            stats_kwargs=stats_kwargs,
            pairs_of_groups_for_stats=PAIRS_OF_GROUPS,
            variable_ylim=None,
            variable_y_offset=variables_y_offsets[variable],
        )
        # _boxplots_one_variable_with_stats(
        #     ax_boxplot_diff,
        #     data_stats,
        #     (f"{variable[0]}_diff", variable[1]),
        #     num_data_points=num_data_points,
        #     boxplot_kwargs=boxplot_kwargs_indiv_nb,
        #     stats_kwargs=stats_kwargs,
        #     pairs_of_groups_for_stats=PAIRS_OF_GROUPS,
        #     variable_ylim=None,
        #     variable_y_offset=variables_y_offsets[(f"{variable[0]}_diff", variable[1])],
        # )
        # _boxplots_one_variable_with_stats(
        #     ax_boxplot_standardized,
        #     data_stats,
        #     (f"{variable[0]}_standardized", variable[1]),
        #     num_data_points=num_data_points,
        #     boxplot_kwargs=boxplot_kwargs_indiv_nb,
        #     stats_kwargs=stats_kwargs,
        #     pairs_of_groups_for_stats=PAIRS_OF_GROUPS,
        #     variable_ylim=None,
        #     variable_y_offset=variables_y_offsets[(f"{variable[0]}_standardized", variable[1])],
        # )
    return fig

# Ax plotters

In [25]:
def plot_order_parameter_dist(data, ax=None):
    x_var = "rotation_order_parameter"
    y_var = "polarization_order_parameter"
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    assert x_var in data
    assert y_var in data
    sns.histplot(
        ax=ax, data=data, x=x_var, y=y_var, bins=(10, 10), binrange=((0, 1), (0, 1))
    )
    ax.set_aspect("equal")
    ax.set_ylabel(y_var)
    ax.set_xlabel(x_var)
    sns.despine(ax=ax)


def plot_relative_position_dist(data, ax=None):
    x_var = "nb_position_x"
    y_var = "nb_position_y"
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    assert x_var in data
    assert y_var in data
    sns.histplot(
        ax=ax,
        data=data,
        x=x_var,
        y=y_var,
        cbar=False,
        binrange=((-5, 5), (-5, 5)),
        bins=(50, 50),
    )
    ax.set_aspect("equal")
    ax.set_ylabel(y_var)
    ax.set_xlabel(x_var)
    ax.axhline(0, c="k", ls=":")
    ax.axvline(0, c="k", ls=":")
    focal_nb_str = f"focal: {data['genotype'].unique()[0]} - neighbour: {data['genotype_nb'].unique()[0]}"
    ax.set_title(focal_nb_str)
    sns.despine(ax=ax)


def plot_trajectory(
    data, ax=None, hue=None, show_trajectories=True, x_var="s_x", y_var="s_y"
):
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    assert x_var in data
    assert y_var in data
    if hue is not None:
        cbar = False
    else:
        cbar = True
    sns.histplot(ax=ax, data=data, x=x_var, y=y_var, cbar=cbar, hue=hue, bins=(20, 20))
    if show_trajectories:
        sns.lineplot(
            ax=ax,
            data=data,
            x=x_var,
            y=y_var,
            sort=False,
            hue=hue,
            alpha=0.5,
            units=hue,
            estimator=None,
        )
    ax.set_aspect("equal")
    ax.set_ylabel(y_var)
    ax.set_xlabel(x_var)
    sns.despine(ax=ax)


def plot_variable_along_time(
    data, variable, ax=None, hue=None, units=None, estimator=None, legend=True
):
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(30, 10))
    assert "frame" in data
    assert variable in data
    if units is None:
        units = hue
    _boxplot_axes_one_variable(ax, data, variable, how="h", add_text=True)
    if estimator is None:
        sns.lineplot(
            ax=ax,
            data=data,
            x="frame",
            y=variable,
            alpha=0.5,
            hue=hue,
            units=units,
            legend=legend,
            estimator=estimator,
        )
    else:
        if hue == "genotype_group_genotype":
            sns.lineplot(
                ax=ax,
                data=data,
                x="frame",
                y=variable,
                alpha=0.25,
                hue=hue,
                estimator=estimator,
                ci=None,
                legend=legend,
                palette=COLORS,
            )
        else:
            sns.lineplot(
                ax=ax,
                data=data,
                x="frame",
                y=variable,
                alpha=0.25,
                hue=hue,
                estimator=estimator,
                ci=None,
                legend=legend,
            )
    sns.despine(ax=ax)


def plot_variable_1d_distribution(
    data, variable, ax=None, hue=None, legend=None, how="h"
):
    bin_range = variables_ranges[variable]["min"], variables_ranges[variable]["max"]
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(30, 10))
    assert variable in data
    _boxplot_axes_one_variable(ax, data, variable, how=how)
    if how == "h":
        x_var = None
        y_var = variable
    elif how == "v":
        x_var = variable
        y_var = None
    if hue == "genotype_group_genotype":
        sns.histplot(
            ax=ax,
            data=data,
            x=x_var,
            y=y_var,
            stat="probability",
            common_norm=False,
            element="poly",
            fill=False,
            alpha=0.5,
            hue=hue,
            bins=100,
            binrange=bin_range,
            legend=legend,
            palette=COLORS,
        )
    else:
        sns.histplot(
            ax=ax,
            data=data,
            x=x_var,
            y=y_var,
            stat="probability",
            common_norm=False,
            element="poly",
            fill=False,
            alpha=0.5,
            hue=hue,
            bins=100,
            binrange=bin_range,
            legend=legend,
        )
    sns.despine(ax=ax)


def plot_positions_dist_per_genotype_group(data, axs):
    genotype_groups = data["genotype_group"].unique()
    for i, (genotype_group, ax) in enumerate(zip(genotype_groups, axs)):
        sub_data = data[data.genotype_group == genotype_group]
        plot_trajectory(
            sub_data,
            ax,
            show_trajectories=False,
            x_var="s_x_normed",
            y_var="s_y_normed",
        )
        ax.set_title(genotype_group)
    plot_trajectory(
        data, axs[i+1], show_trajectories=False, x_var="s_x_normed", y_var="s_y_normed"
    )
    axs[i+1].set_title("all")


def plot_order_parameter_dist_per_genotype_group(data, axs):
    genotype_groups = data["genotype_group"].unique()
    for i, (genotype_group, ax) in enumerate(zip(genotype_groups, axs)):
        sub_data = data[data.genotype_group == genotype_group]
        plot_order_parameter_dist(sub_data, ax)
        ax.set_title(genotype_group)
    plot_order_parameter_dist(data, axs[i+1])
    axs[i+1].set_title("all")
    if i+1 < len(axs)-1:
        # is not the last axes
        [ax.set_visible(False) for ax in axs[i+2:]]
    
    
def plot_polar_dist_relative_positions(data, axs_polar_plots):
    focal_nb_genotypes = natsorted(data["focal_nb_genotype"].unique())
    pos_hists = {focal_nb_genotype: [] for focal_nb_genotype in focal_nb_genotypes}
    for trial_uid_id in data.trial_uid_id.unique():
        data_focal = data[data.trial_uid_id == trial_uid_id]
        pos_hist, r_edges, theta_edges = polar_histogram(
            data_focal.nb_distance.values,
            data_focal.nb_angle.values,
            density=True,
            range_r=4,
            bins=(10, 12),
        )
        assert len(data_focal.focal_nb_genotype.unique()) == 1
        pos_hists[data_focal.focal_nb_genotype.unique()[0]].append(pos_hist)

    pos_hists_arrs = {
        focal_nb_genotype: np.asarray(pos_hists[focal_nb_genotype])
        for focal_nb_genotype in focal_nb_genotypes
    }

    # Plot polar histogram/maps for relative neighbor positions, turning and acceleration
    vmin = 0
    vmax = 0
    for i, (focal_nb_genotype, pos_hist) in enumerate(pos_hists_arrs.items()):
        mean_pos_hist = np.mean(pos_hist, axis=0)
        vmin = np.min([vmin, np.min(mean_pos_hist)])
        vmax = np.max([vmax, np.max(mean_pos_hist)])

    for i, (focal_nb_genotype, pos_hist) in enumerate(pos_hists_arrs.items()):
        ax = axs_polar_plots[i]
        mean_pos_hist = np.mean(pos_hist, axis=0)
        plot_polar_histogram(
            mean_pos_hist,
            r_edges,
            theta_edges,
            ax,
            vmin=vmin,
            vmax=vmax,
            symmetric_color_limits=False,
        )
        ax.set_title(focal_nb_genotype)
    

# Single animal summary

In [13]:
possible_animals_uid = data_indiv[animal_uid_col].unique()


@interact(animal_uid=possible_animals_uid, y=1.0)
def summary_animal(animal_uid, save=False):
    assert animal_uid_col in data_indiv.columns
    if animal_uid in possible_animals_uid:
        animal_data = data_indiv[data_indiv[animal_uid_col] == animal_uid]
        animal_info_str = get_animal_info_str(animal_data)

        fig = _plot_animal_indiv_vars_summary(animal_data, individual_variables)
        fig.suptitle(animal_info_str)
        if save:
            fig.savefig(f"{animal_uid}.png")
            fig.savefig(f"{animal_uid}.pdf")

    else:
        print(f"Animal {animal_uid} does not exist")
        print("Possible animals are")
        print(possible_animals_uid)

interactive(children=(Dropdown(description='animal_uid', options=('ap1g1_1_1_1_1_0.0', 'ap1g1_1_1_1_1_1.0', 'a…

# Video summary

In [14]:
possible_video_uids_in_group_vars = data_group[video_uid_col].unique()
possible_video_uids_in_indiv_vars = data_indiv[video_uid_col].unique()
possible_video_uids_in_indiv_nb_vars = data_indiv_nb[video_uid_col].unique()
possible_video_uids = (
    set(possible_video_uids_in_group_vars)
    & set(possible_video_uids_in_indiv_vars)
    & set(possible_video_uids_in_indiv_nb_vars)
)
possible_video_uids = natsorted(possible_video_uids)


@interact(video_uid=possible_video_uids, save=False)
def summary_video(video_uid, save=False):
    assert video_uid_col in data_group.columns
    if video_uid in possible_video_uids:
        video_group_data = data_group[data_group[video_uid_col] == video_uid]
        video_indiv_data = data_indiv[data_indiv[video_uid_col] == video_uid]
        video_indiv_nb_data = data_indiv_nb[data_indiv_nb[video_uid_col] == video_uid]

        print(video_group_data.shape)
        print(video_indiv_data.shape)

        video_info_str = get_video_info_str(video_indiv_data)
        print(video_info_str)
        fig = _plot_video_indiv_vars_summary(
            video_indiv_data, individual_variables, hue="identity"
        )
        fig.suptitle(video_info_str)
        if save:
            fig.savefig(f"{video_uid}_indiv.png")
            fig.savefig(f"{video_uid}_indiv.pdf")

        fig = _plot_group_variables_summary(video_group_data, group_variables)
        fig.suptitle(video_info_str)
        if save:
            fig.savefig(f"{video_uid}_group.png")
            fig.savefig(f"{video_uid}_group.pdf")

        for animal_uid in video_indiv_nb_data[animal_uid_col].unique():
            animal_nb_data = video_indiv_nb_data[
                video_indiv_nb_data[animal_uid_col] == animal_uid
            ]
            fig = _plot_video_indiv_nb_variables_summary(
                animal_nb_data, individual_nb_variables
            )
            focal_nb_info_str = get_focal_nb_info(animal_nb_data)
            fig.suptitle(focal_nb_info_str)
            if save:
                fig.savefig(f"{animal_uid}_indiv_nb.png")
                fig.savefig(f"{animal_uid}_indiv_nb.pdf")

    else:
        print(f"Video {video_uid} does not exist")
        print("Possible videos are")
        print(possible_video_uids)

interactive(children=(Dropdown(description='video_uid', options=('ap1g1_1_1_1_1', 'ap1g1_1_1_1_2', 'ap1g1_1_1_…

# Line or Line replicate summary

In [26]:
partition_col = "line_exp"
if partition_col == "line_replicate":
    extra_info_cols = ["replicate"]
else:
    extra_info_cols = []
partition_info_cols = (
    [
        "experiment_type",
        "gene",
        "founder",
    ]
    + [partition_col]
    + extra_info_cols
)


def get_line_replicate_info_str(data, info_cols=partition_info_cols):
    return _get_info(data, info_cols)


possible_partition_uids = natsorted(data_group[partition_col].unique())


@interact(partition_uid=possible_partition_uids, save=False)
def summary_video(partition_uid, save=False):
    assert partition_col in data_group.columns
    if partition_uid in possible_partition_uids:
        indiv_data = data_indiv[data_indiv[partition_col] == partition_uid]
        indv_data_stats = data_indiv_stats[
            data_indiv_stats[partition_col] == partition_uid
        ]
        group_data = data_group[data_group[partition_col] == partition_uid]
        group_data_stats = data_group_stats[
            data_group_stats[partition_col] == partition_uid
        ]
        indiv_nb_data = data_indiv_nb[data_indiv_nb[partition_col] == partition_uid]
        indiv_nb_data_stats = data_indiv_nb_stats[
            data_indiv_nb_stats[partition_col] == partition_uid
        ]

        line_replicate_info_str = get_line_replicate_info_str(indiv_data)
        fig = _plot_partition_indiv_vars_summary(
            indiv_data,
            indv_data_stats,
            individual_variables,
            individual_variables_stats,
            hue="genotype_group_genotype",
        )
        fig.suptitle(line_replicate_info_str)
        if save:
            fig.savefig(f"{partition_uid}_indiv.png")
            fig.savefig(f"{partition_uid}_indiv.pdf")

        fig = _plot_partition_group_vars_summary(
            group_data, group_data_stats, group_variables, group_variables_stats, hue="genotype_group"
        )
        fig.suptitle(line_replicate_info_str)
        if save:
            fig.savefig(f"{partition_uid}_group.png")
            fig.savefig(f"{partition_uid}_group.pdf")
            
        fig = _plot_partition_indiv_nb_summary(indiv_nb_data, indiv_nb_data_stats, individual_nb_variables, individual_nb_variables_stats, hue="focal_nb_genotype")
        fig.suptitle(line_replicate_info_str)
        if save:
            fig.savefig(f"{partition_uid}_indiv_nb.png")
            fig.savefig(f"{partition_uid}_indiv_nb.pdf")



    else:
        print(f"Video {partition_uid} does not exist")
        print("Possible partition_uids are")
        print(possible_partition_uids)

interactive(children=(Dropdown(description='partition_uid', options=('ap1g1_1_1', 'apbb1_1_1', 'asap1b_7_1', '…