Main notebook containing the relevant analysis steps; it is run for each ensemble.

The script `notebook_per_ensemble.py` automatically copies this notebook to an ensemble directory and executes it for newly trained ensembles using papermill.

**Warning:** You can lose your work! Don't edit the automatically created copies of this notebook within an ensemble directory; they will be overwritten on a rerun. Create a copy instead.

**Warning:** This notebook is not intended for standalone use. It is automatically copied to an ensemble directory and executed for newly trained ensembles using papermill. Adapt mindfully.


In [None]:
import logging

import matplotlib as mpl
import matplotlib.pyplot as plt

from flyvis import EnsembleView
from flyvis.analysis.moving_bar_responses import plot_angular_tuning
from flyvis.analysis.visualization.plt_utils import add_cluster_marker, get_marker

# Silence log output: with no argument, logging.disable() suppresses all
# records of severity CRITICAL and below, i.e. every standard level.
logging.disable()


# Render figures at high resolution.
mpl.rcParams["figure.dpi"] = 300

# IPython magics: reload all (non-excluded) modules before executing code, so
# edits to flyvis are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Default ensemble to analyse. Presumably this is the papermill "parameters"
# cell (cell tags are not visible in this export), so `notebook_per_ensemble.py`
# would override the value when executing the notebook. The trailing type
# comment is kept in case papermill's parameter inspection reads it.
ensemble_name = "flow/0001"  # type: str

In [None]:
# Settings forwarded to EnsembleView via `best_checkpoint_fn_kwargs`, which
# controls how each model's best checkpoint is chosen.
validation_subdir, loss_file_name = "validation", "epe"

In [None]:
# Keyword arguments handed to the ensemble's best-checkpoint selection; the
# validation sub-directory and loss file name come from the previous cell.
best_checkpoint_kwargs = dict(
    validation_subdir=validation_subdir,
    loss_file_name=loss_file_name,
)
ensemble = EnsembleView(ensemble_name, best_checkpoint_fn_kwargs=best_checkpoint_kwargs)

In [None]:
# Show the experiment description taken from the first model's config; a
# config without a `description` attribute falls back to an empty string.
experiment_config = ensemble[0].dir.config
experiment_description = getattr(experiment_config, "description", "")
print(f"Description of experiment: {experiment_description}")

# Task performance

## Training and validation losses

In [None]:
# Plot the ensemble's training losses.
fig, ax = ensemble.training_loss()

In [None]:
# Plot the ensemble's validation losses.
fig, ax = ensemble.validation_loss()

In [None]:
# Histogram of the models' task errors.
fig, ax = ensemble.task_error_histogram()

## Learned parameter marginals

In [None]:
# Marginal distribution of the learned node parameter "bias" across models.
fig, axes = ensemble.node_parameters("bias")

In [None]:
# Marginal distribution of the learned node parameter "time_const" across models.
fig, axes = ensemble.node_parameters("time_const")

In [None]:
# Marginal distribution of the learned edge parameter "syn_strength" across models.
fig, axes = ensemble.edge_parameters("syn_strength")

## Dead or alive

In [None]:
# Dead-or-alive overview of the ensemble; the call returns the figure, axis,
# colorbar and (presumably) the underlying matrix that is drawn.
fig, ax, cbar, matrix = ensemble.dead_or_alive()

## Contrast selectivity and flash response indices (FRI)

#### 20% best task-performing models

In [None]:
# Flash response indices restricted to the best 20% of models by task
# performance; the `ratio` context presumably restores the full ensemble on exit.
with ensemble.ratio(best=0.2):
    ensemble.flash_response_index()

#### All (100%) models

In [None]:
# Flash response indices across all models of the ensemble.
fig, ax = ensemble.flash_response_index()

## Motion selectivity and direction selectivity index (DSI)

#### 20% best task-performing models

In [None]:
# Direction selectivity indices restricted to the best 20% of models by task
# performance; the `ratio` context presumably restores the full ensemble on exit.
with ensemble.ratio(best=0.2):
    ensemble.direction_selectivity_index()

#### All (100%) models

In [None]:
# Direction selectivity indices across all models of the ensemble.
# This call is the cell's last top-level expression, so IPython would echo any
# return value as text below the figure; the trailing semicolon suppresses that
# echo, matching the other plotting cells, which assign their results.
ensemble.direction_selectivity_index();

## Clustering of models based on responses to naturalistic stimuli

#### T4c

In [None]:
# Cluster the models by their T4c responses and plot the resulting embedding,
# passing each model's task error and its associated colours to the plot.
task_error = ensemble.task_error()
t4c_clustering = ensemble.clustering("T4c")
embeddingplot = t4c_clustering.plot(
    task_error=task_error.values,
    colors=task_error.colors,
)

In [None]:
# Responses of all models to moving-edge stimuli; later selected per cluster
# with `r.sel(network_id=...)`.
r = ensemble.moving_edge_responses()

In [None]:
# Mapping from cluster id to the indices of the models in that T4c cluster.
cluster_indices = ensemble.cluster_indices("T4c")

In [None]:
# Per-model colours derived from task error. Reuse `task_error`, which was
# already computed for the T4c embedding plot, instead of calling
# `ensemble.task_error()` a second time.
colors = task_error.colors

In [None]:
# One polar subplot per T4c cluster showing the models' angular tuning.
# `squeeze=False` always yields a 2-D array of axes; taking row 0 gives a 1-D
# array even when there is only a single cluster (where the default would
# return a bare Axes that cannot be indexed).
fig, axes = plt.subplots(
    1,
    len(cluster_indices),
    subplot_kw={"projection": "polar"},
    figsize=[2, 1],
    squeeze=False,
)
axes = axes[0]
# The per-model z-order does not depend on the cluster, so compute it once.
model_zorder = ensemble.zorder()
# Pair each cluster with its own axis rather than using the cluster id as an
# axis index, which would break unless ids are exactly 0..n_clusters-1.
for cluster_ax, (cluster_id, indices) in zip(axes, cluster_indices.items()):
    plot_angular_tuning(
        r.sel(network_id=indices),
        "T4c",
        intensity=1,
        colors=colors[indices],
        zorder=model_zorder[indices],
        # Overlay the ground-truth tuning only on cluster 0's axis.
        groundtruth=cluster_id == 0,
        fig=fig,
        ax=cluster_ax,
    )
    add_cluster_marker(fig, cluster_ax, marker=get_marker(cluster_id))