In [None]:
%matplotlib notebook

In [None]:
import numpy as np
import pandas as pd
from pathlib import Path
from utilities import *

from matplotlib import pyplot as plt
import seaborn as sns
sns.set(style="ticks", palette="deep")
cols = sns.color_palette()

In [None]:
def get_summary_df(trial_stats_table, n_trials_exclude=10):
    """Summarise one fish's trial-wise bout statistics per spatial period.

    Parameters
    ----------
    trial_stats_table : pandas.DataFrame
        One row per trial. Must contain a ``spatial_period`` column (used for grouping)
        and a ``bout_n`` column (number of bouts in the trial). Rows are assumed to be
        in presentation order, since the first ``n_trials_exclude`` rows are dropped.
        The table is not modified.
    n_trials_exclude : int, optional
        Number of initial habituation trials removed before computing the statistics
        (default 10).

    Returns
    -------
    pandas.DataFrame
        Indexed by spatial period. Holds the median of every numeric column over the
        retained trials, plus ``swimmed_fract``: the fraction of retained trials with
        at least one bout.
    """
    # Drop the habituation trials by position (explicit .iloc, independent of index labels)
    kept = trial_stats_table.iloc[n_trials_exclude:]

    # numeric_only=True: non-numeric columns would make median() raise in pandas >= 2.0
    table = kept.groupby("spatial_period").median(numeric_only=True)

    # Fraction of trials with at least one bout: the mean of a boolean series per period
    swam = kept["bout_n"] > 0
    table["swimmed_fract"] = swam.groupby(kept["spatial_period"]).mean()
    return table

# Load all experiments:

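Edit the folder names in `folder_dict_names` (next cell) so that they point to the folders containing the individual fish data. Paths are relative to the notebook's working directory, and each fish must have its own subfolder whose name ends in `_f<digit>`.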

In [None]:
# Specify the names of the subdirectories with the ablation and control data:
folder_dict_names = dict(ntr="ntr ablated random spatial frequency",
                         cnt="control random spatial frequency")

# Statistics exported to Excel and plotted for each group:
SUMMARY_PARAMS = ["bout_n", "first_bout_latency", "swimmed_fract"]


def build_param_summary(aggregate, paths, param):
    """Collect one statistic from every fish into a single table.

    Parameters
    ----------
    aggregate : list of pandas.DataFrame
        Per-fish summary tables (output of ``get_summary_df``), indexed by spatial period.
    paths : list of pathlib.Path
        Fish folders, in the same order as ``aggregate``; folder names become column names.
    param : str
        Column of the per-fish tables to collect.

    Returns
    -------
    pandas.DataFrame
        One row per spatial period and one column per fish. Spatial periods missing
        for a fish are NaN (pd.concat aligns on the index).
    """
    return pd.concat([aggr[param].rename(path.name) for aggr, path in zip(aggregate, paths)],
                     axis=1)


def plot_param_summary(summary, param, group, saving_path):
    """Plot every fish as a thin line plus the median and interquartile range, save as PDF.

    Parameters
    ----------
    summary : pandas.DataFrame
        Output of ``build_param_summary`` (spatial periods x fish).
    param, group : str
        Used for the y label and the output file name ``{param}_{group}.pdf``.
    saving_path : pathlib.Path
        Folder where the PDF is written.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig, ax = plt.subplots(figsize=(4, 3))
    ax.plot(summary, linewidth=0.5)  # one thin line per fish, x = spatial period (index)

    # nanpercentile: a fish may lack a spatial period, or have a NaN median latency
    # (no bouts), which would otherwise turn the whole population statistic into NaN.
    quart1, median, quart3 = np.nanpercentile(summary.values, [25, 50, 75], axis=1)

    ax.errorbar(summary.index, median, yerr=[median - quart1, quart3 - median],
                linewidth=2, color="k")
    ax.set_ylabel(param)
    ax.set_xlabel("Spatial period (mm)")
    sns.despine(ax=ax)
    fig.tight_layout()
    fig.savefig(str(saving_path / f"{param}_{group}.pdf"), format="pdf")
    return fig


for group in ["ntr", "cnt"]:
    print("Analysing ", group)

    master_path = Path(r"./{}".format(folder_dict_names[group]))
    # Sorted so that fish (and Excel column) order does not depend on the file system:
    paths = sorted(master_path.glob('*_f[0-9]'))
    if not paths:
        # Without this check pd.concat would later fail with "No objects to concatenate"
        raise FileNotFoundError(
            "No fish folders matching '*_f[0-9]' in {}".format(master_path.resolve()))

    exps = [Experiment(path) for path in paths]
    # Animal genotypes; not used further in this cell, kept for inspection.
    genotypes = [e["general"]["animal"]["comments"] for e in exps]

    # List of trial-wise bout statistics for all fish:
    trial_stats = [get_exp_stats(exp, get_spatial_period=True) for exp in exps]

    # Exclude initial 10 trials and calculate median across spatial periods for each fish:
    aggregate = [get_summary_df(s) for s in trial_stats]

    # Figures are saved next to the data; change this to put them elsewhere.
    figure_saving_path = master_path
    figure_saving_path.mkdir(parents=True, exist_ok=True)

    # For each statistic: save the per-fish summary to Excel and plot it.
    for param in SUMMARY_PARAMS:
        summary = build_param_summary(aggregate, paths, param)
        summary.to_excel(str(master_path / "{}_{}_summary.xlsx".format(param, group)))
        plot_param_summary(summary, param, group, figure_saving_path)