# Evaluations

This notebook is used to: 

1. ~~produce *baselines* for the GaussianProxy models (data vs data FID)~~ -> moved to `scripts/metrics_null_test.py`
2. load metrics computed on generated data and plot them against the baselines

---

# Imports

In [None]:
import concurrent.futures
import json
import math
import pickle
import sys
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
from pprint import pprint

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import seaborn as sns
import torch
from IPython.display import HTML
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from torchvision.transforms import Compose, ConvertImageDtype, RandomHorizontalFlip, RandomVerticalFlip
from tqdm.notebook import tqdm

sys.path.insert(0, "..")
from GaussianProxy.utils.data import RandomRotationSquareSymmetry

In [None]:
# Disable gradient tracking globally (this notebook only loads and plots evaluation metrics)
torch.set_grad_enabled(False)

In [None]:
# Use seaborn's "paper" plotting context for the figures of this notebook
sns.set_theme(context="paper")

# Dataset

In [None]:
from my_conf.dataset.biotine_png_128_hard_aug_inference import dataset

In [None]:
assert dataset.dataset_params is not None
database_path = Path(dataset.path)
print(f"Using dataset {dataset.name} from {database_path}")
# One non-hidden sub-directory per time, ordered with the dataset's own sorting function
subdirs: list[Path] = sorted(
    (entry for entry in database_path.iterdir() if entry.is_dir() and not entry.name.startswith(".")),
    key=dataset.dataset_params.sorting_func,
)
print(f"Found {len(subdirs)} times: {subdirs}")


def is_flip_or_rotation(t):
    """Tell whether transform ``t`` is a horizontal/vertical flip or a square-symmetry rotation."""
    return isinstance(t, (RandomHorizontalFlip, RandomVerticalFlip, RandomRotationSquareSymmetry))


# Flip / rotation transforms of the dataset pipeline (only needed by the "with augmentations" variant below)
flips_rot = list(filter(is_flip_or_rotation, dataset.transforms.transforms))

# with or without augmentations:
# transforms = Compose(flips_rot + [ConvertImageDtype(torch.uint8)])
transforms = Compose([ConvertImageDtype(torch.uint8)])

print(f"Using transforms:\n{transforms}")


def count_elements(subdir: Path):
    """Return ``(subdir.name, count)`` for one class directory.

    ``count`` is the number of entries directly inside ``subdir`` whose name matches
    ``*.<dataset.dataset_params.file_extension>`` (read from the module-level ``dataset``).
    """
    pattern = f"*.{dataset.dataset_params.file_extension}"
    nb_matches = sum(1 for _ in subdir.glob(pattern))
    return subdir.name, nb_matches


# Count the files of every time directory concurrently (one task per directory)
nb_elems_per_class = {}
with concurrent.futures.ThreadPoolExecutor() as pool:
    pending = [pool.submit(count_elements, subdir) for subdir in subdirs]
    progress = tqdm(concurrent.futures.as_completed(pending), total=len(subdirs), desc="Counting elements per time")
    for finished in progress:
        dir_name, nb_files = finished.result()
        nb_elems_per_class[dir_name] = nb_files
# Extra entry holding the total over all classes
nb_elems_per_class["all_classes"] = sum(nb_elems_per_class.values())

print(f"Number of elements per class: {nb_elems_per_class}")

In [None]:
# Number of experiment repetitions; used below to index the repetitions stored in the loaded metric files
nb_repeats = 10

# FID

## Load train vs train (null test) FIDs

In [None]:
# Null-test (train vs train) metrics file of the current dataset
null_test_path = Path("evaluations") / dataset.name / "eval_metrics.json"
assert null_test_path.exists(), f"File {null_test_path} does not exist"

In [None]:
with null_test_path.open("r") as metrics_file:
    train_vs_train_eval_metrics = json.load(metrics_file)

# Class names are the keys of the first repetition, ordered by their integer value
class_names = sorted(train_vs_train_eval_metrics["exp_rep_0"], key=int)
# For every class: one FID per experiment repetition, in the order the repetitions appear in the file
fid_scores_by_class_train = {
    class_name: [
        exp_rep[class_name]["frechet_inception_distance"] for exp_rep in train_vs_train_eval_metrics.values()
    ]
    for class_name in class_names
}

print("FID scores by class for train vs train:")
pprint(fid_scores_by_class_train)

## Load gen vs train FIDs

### Option 1: from `inference.py` script

In [None]:
experiments_root = Path("/", "projects", "static2dynamic", "Thomas", "experiments", "GaussianProxy")
# Layout: <experiments_root>/<experiment>/inferences/<inference>/all_procs_metrics_dict.pkl
saved_FIDs_generation_path = (
    experiments_root
    / "biotine_all_paired_new_jz_MANUAL_WEIGHTS_DOWNLOAD_FROM_JZ_11-02-2025_14h31"  # <- change here
    / "inferences"
    / "MetricsComputation_100_diffsteps_no_SNR_leading_bf16_fixed"  # <- change here
    / "all_procs_metrics_dict.pkl"
)
assert saved_FIDs_generation_path.exists(), f"File {saved_FIDs_generation_path} does not exist"
# Names are read back from the path itself: the experiment directory sits three levels above the file...
experiment_name = saved_FIDs_generation_path.parents[2].name
# ...and the inference directory is the file's direct parent
inference_name = saved_FIDs_generation_path.parent.name

In [None]:
# Load the metrics computed on generated data vs training data
# (pickle executes arbitrary code on load: only open files produced by our own inference runs)
with saved_FIDs_generation_path.open("rb") as metrics_file:
    fid_scores_by_class_gen = pickle.load(metrics_file)
pprint(fid_scores_by_class_gen)

### Option 2: load a custom file, e.g. a CSV exported from wandb logs

In [None]:
import pandas as pd

# Load the metric values from a CSV file exported from wandb
df = pd.read_csv("tmp_downloaded_eval_values/wandb_export_2024-12-05T16_56_51.199+01_00.csv")
df

In [None]:
# Keep only the rows logged at training step 120000
is_selected_step = df["Step"] == 120000
df = df.loc[is_selected_step]
df

In [None]:
# Select the FID columns, take the first remaining row and relabel its entries with the class names
fid_columns = [name for name in df.columns if name.endswith("frechet_inception_distance")]
col_names = ["0.0003", "0.001", "0.003", "0.01", "0.03", "0.1", "0.3", "1.0", "all_classes"]
first_row = df[fid_columns].iloc[0]
first_row.index = col_names  # raises if the number of FID columns differs from the number of names
fid_scores_by_class_gen = first_row.to_dict()
fid_scores_by_class_gen

In [None]:
# Wrap every raw value as {"frechet_inception_distance": value}, the per-class layout read by the plotting cells
fid_scores_by_class_gen = {
    class_key: {"frechet_inception_distance": fid_value} for class_key, fid_value in fid_scores_by_class_gen.items()
}
pprint(fid_scores_by_class_gen)

## Plot

In [None]:
plt.figure(figsize=(10, 6), dpi=300)
class_labels = [f"{name}\n({nb_elems_per_class[name]})" for name in class_names]

# Null test: distribution of the true-vs-true FIDs, one box per class
plt.boxplot(
    [fid_scores_by_class_train[name] for name in class_names],
    tick_labels=class_labels,
    showfliers=True,
    flierprops={"marker": "x", "markersize": 3},
    label="true data vs true data",
)

# Generated data: one point per class, NaN for classes without a score
y_gen = [
    fid_scores_by_class_gen[name]["frechet_inception_distance"] if name in fid_scores_by_class_gen else np.nan
    for name in class_names
]
# Boxes are drawn at x = 1..n, so the points use the same positions
plt.scatter(x=range(1, len(class_names) + 1), y=y_gen, label="generated data vs true data")

plt.xlabel("Class Name (total number of class elements)")
plt.ylabel("FID Score")
plt.title(f"Dataset: {dataset.name} | Experiment: {experiment_name}")
plt.suptitle("Intra-class FID score")
plt.xticks(rotation=45, ha="right")
plt.grid(axis="x")
plt.legend()
plt.tight_layout()
# plt.ylim(2, 90.5)
plt.figtext(0, 0.01, f"Inference strategy: {inference_name}", fontsize=8)

figure_path = f"evaluations/{dataset.name}/intra_class_fid_score_{experiment_name}_{inference_name}.png"
plt.savefig(figure_path)
print(f"Figure saved to {figure_path}")
plt.show()

In [None]:
fig = go.Figure()

# train: one box per class, built from all the null-test repetitions of that class
corresp_x_values = []
y_vals = []
for cl_name in class_names:
    class_scores = fid_scores_by_class_train[cl_name]
    # Repeat the label once per score actually loaded, so x and y always have the same length
    # (a fixed repetition count would silently misalign them if the file holds a different number)
    corresp_x_values += [cl_name] * len(class_scores)
    y_vals += class_scores
fig.add_trace(
    go.Box(
        y=y_vals,
        x=corresp_x_values,
        name="train vs train",
    )
)

# gen: a single value per class
corresp_x_values = list(fid_scores_by_class_gen.keys())
y_vals = [fid_scores_by_class_gen[cl_name]["frechet_inception_distance"] for cl_name in corresp_x_values]
fig.add_trace(
    go.Box(
        y=y_vals,
        x=corresp_x_values,
        name="gen vs train",
    )
)

fig.update_layout(
    yaxis_title="FID Score",
    boxmode="group",
    xaxis_title="Class Name (total number of class elements)",
    title="Intra-class FID score",
    xaxis=dict(
        tickmode="array",
        tickvals=list(range(len(class_names))),
        ticktext=[f"{class_name}<br>({nb_elems_per_class[class_name]})" for class_name in class_names],
        tickangle=-45,
    ),
    height=700,
)
fig.show()

## Plot comparison of inference experiments

In [None]:
# Names of the inference experiment directories to compare; each one replaces the
# inference directory component of `saved_FIDs_generation_path` when its metrics are loaded
list_experiment_names = [
    "MetricsComputation_100_diffsteps_no_SNR_leading_adapt_half_aug",
    "MetricsComputation_100_diffsteps_with_SNR_trailing_adapt_half_aug",
    "MetricsComputation_100_diffsteps_no_SNR_leading_f32",
    "MetricsComputation_100_diffsteps_no_SNR_leading_bf16_fixed",
]

In [None]:
experiment_results = {}

# Sibling experiments share the "inferences" directory and the metrics file name of the reference path
metrics_file_name = saved_FIDs_generation_path.name
inferences_dir = saved_FIDs_generation_path.parent.parent

for exp_name in list_experiment_names:
    this_exp_path = inferences_dir / exp_name / metrics_file_name
    assert this_exp_path.exists(), f"File {this_exp_path} does not exist"

    # Load the metrics of this experiment (generated data vs training data);
    # note that this rebinds the notebook-level `fid_scores_by_class_gen`
    with this_exp_path.open("rb") as f:
        fid_scores_by_class_gen = pickle.load(f)
    # Keep only the FID of each time
    experiment_results[exp_name] = {
        time: metrics["frechet_inception_distance"] for time, metrics in fid_scores_by_class_gen.items()
    }

experiment_results

In [None]:
plt.figure(figsize=(10, 6), dpi=300)
class_labels = [f"{class_name}\n({nb_elems_per_class[class_name]})" for class_name in class_names]
# null-test: distribution of the true-vs-true FIDs, one box per class
plt.boxplot(
    [fid_scores_by_class_train[class_name] for class_name in class_names],
    tick_labels=class_labels,
    showfliers=True,
    flierprops=dict(marker="x", markersize=3),
    label="true data vs true data",
)
# gen values: one scatter series per inference experiment
for exp_name, exp_res in experiment_results.items():
    # Look each class up in THIS experiment's results (not in whatever `fid_scores_by_class_gen`
    # currently holds); classes without a score are plotted as NaN. Classes are visited in order.
    y_gen = [exp_res.get(cl_name, np.nan) for cl_name in class_names]
    plt.scatter(
        x=range(1, len(class_names) + 1),
        y=y_gen,
        label=f"generated data vs true data - {exp_name}",
        marker="x",
    )
plt.xlabel("Class Name (total number of class elements)")
plt.ylabel("FID Score")
plt.title(f"Dataset: {dataset.name} | Experiment: {experiment_name}")
plt.suptitle("Intra-class FID score")
plt.xticks(rotation=45, ha="right")
plt.grid(axis="x")
plt.legend()
plt.tight_layout()
# Dedicated file name: this figure shows several inference experiments, so it must not overwrite
# the single-experiment figure saved as "..._{experiment_name}_{inference_name}.png"
comparison_fig_path = f"evaluations/{dataset.name}/intra_class_fid_score_{experiment_name}_inference_comparison.png"
plt.savefig(comparison_fig_path)
print(f"Figure saved to {comparison_fig_path}")
plt.show()

# Compare reconstructions from different scheduler configs to the original data

In [None]:
# tmp: compute L2
from PIL import Image

base_path = Path(
    "/projects/static2dynamic/Thomas/experiments/GaussianProxy/biotine_all_paired_new_jz_MANUAL_WEIGHTS_DOWNLOAD_FROM_JZ_11-02-2025_14h31/inferences"
)

experiment_1_path = (
    base_path / "InversionRegenerationOnly_test_scheduler_100_diffsteps_M_13_fld_3" / "regeneration_-1_1 raw.png"
)
experiment_2_path = (
    base_path / "InversionRegenerationOnly_100_diffsteps_no_SNR_leading_f32" / "regeneration_-1_1 raw.png"
)
reference_path = (
    base_path / "InversionRegenerationOnly_test_scheduler_100_diffsteps_M_13_fld_3" / "starting_samples_-1_1 raw.png"
)

# Images are kept as loaded (integer pixel arrays) for later display
experiment_1 = np.array(Image.open(experiment_1_path))
experiment_2 = np.array(Image.open(experiment_2_path))
reference = np.array(Image.open(reference_path))


def _l2_distance(img_a: np.ndarray, img_b: np.ndarray) -> float:
    """Return the L2 norm of the pixel-wise difference of two images.

    The images are cast to float64 before subtracting: subtracting unsigned integer
    arrays directly wraps around modulo 256 and yields a wrong distance.
    """
    return float(np.linalg.norm(img_a.astype(np.float64) - img_b.astype(np.float64)))


print(f"L2 between {experiment_1_path.parent.name} and reference: {_l2_distance(experiment_1, reference)}")
print(f"L2 between {experiment_2_path.parent.name} and reference: {_l2_distance(experiment_2, reference)}")
print(
    f"L2 between {experiment_1_path.parent.name} and {experiment_2_path.parent.name}: {_l2_distance(experiment_1, experiment_2)}"
)

In [None]:
from PIL import Image


# Function to prepare image for display
def prepare_image(img_array: np.ndarray | Image.Image, max_size: int = 128) -> np.ndarray:
    """Return a thumbnail of the image as a numpy array.

    Args:
        img_array: image given either as a numpy array (cast to uint8 before conversion)
            or as a PIL image.
        max_size: maximum width and height of the thumbnail, in pixels.

    Returns:
        The resized image as a numpy array. The input image is left unmodified.
    """
    if isinstance(img_array, np.ndarray):
        img = Image.fromarray(img_array.astype("uint8"))
    else:
        # thumbnail() resizes in place: work on a copy so the caller's image is not shrunk
        img = img_array.copy()
    img.thumbnail((max_size, max_size), Image.LANCZOS)
    return np.array(img)


# Function to place images as nodes
def add_image(ax: plt.Axes, img: np.ndarray | Image.Image, xy: tuple, zoom: float = 1.0) -> None:
    """Draw a thumbnail of ``img`` on ``ax`` at data position ``xy``, inside a white frame.

    ``zoom`` is forwarded to the ``OffsetImage`` wrapping the thumbnail.
    """
    thumbnail_box = OffsetImage(prepare_image(img), zoom=zoom)
    white_frame = {"edgecolor": "white", "facecolor": "white"}
    ax.add_artist(AnnotationBbox(thumbnail_box, xy, frameon=True, pad=0.1, bboxprops=white_frame))


# Get experiment names from paths
exp1_name = experiment_1_path.parent.name
exp2_name = experiment_2_path.parent.name
ref_name = "reference"

# Calculate L2 distances
# Work on float copies: subtracting unsigned integer pixel arrays wraps around modulo 256
exp1_float, exp2_float, ref_float = (
    np.asarray(img, dtype=np.float64) for img in (experiment_1, experiment_2, reference)
)
l2_no_zero_vs_ref = np.linalg.norm(exp1_float - ref_float)
l2_zero_vs_ref = np.linalg.norm(exp2_float - ref_float)
l2_no_zero_vs_zero = np.linalg.norm(exp1_float - exp2_float)

# Normalize distances to fit within plotting area (largest distance becomes 1)
max_l2 = max(l2_no_zero_vs_ref, l2_zero_vs_ref, l2_no_zero_vs_zero)
# If the three images are identical every distance is 0: keep them at 0 instead of dividing by zero
scale_factor = 1.0 / max_l2 if max_l2 > 0 else 1.0
scaled_dist_1 = l2_no_zero_vs_ref * scale_factor
scaled_dist_2 = l2_zero_vs_ref * scale_factor
scaled_dist_3 = l2_no_zero_vs_zero * scale_factor

# Position nodes using proportional distances
# Place first two nodes on the x axis, separated by their scaled distance
pos_1 = np.array([0, 0])  # no_zero_SNR_leading
pos_2 = np.array([scaled_dist_3, 0])  # zero_SNR_trailing


# Function to calculate the position of the third point
def get_third_point(pos_1: np.ndarray, pos_2: np.ndarray, d1: float, d2: float) -> np.ndarray:
    """Find the position of a third point given its distances to two known points.

    Args:
        pos_1: 2D position of the first point.
        pos_2: 2D position of the second point.
        d1: distance between the third point and ``pos_1``.
        d2: distance between the third point and ``pos_2``.

    Returns:
        The 2D position of the third point, chosen on the counter-clockwise side of the
        direction ``pos_1 -> pos_2`` (positive y when that direction is the x axis).
        When the three distances violate the triangle inequality the angle is clamped,
        so the point lies on the line through ``pos_1`` and ``pos_2`` and only the
        distance ``d1`` is honoured exactly.
    """
    origin = np.asarray(pos_1, dtype=float)
    baseline = np.asarray(pos_2, dtype=float) - origin
    base_len = float(np.linalg.norm(baseline))

    if d1 == 0:
        # The third point coincides with pos_1 (also avoids a division by zero below)
        return origin.copy()
    if base_len == 0:
        # pos_1 and pos_2 coincide: no direction is defined, place the point along +x
        return origin + np.array([d1, 0.0])

    # Law of cosines for the angle AT pos_1, i.e. the angle opposite to the side of length d2
    cos_angle = (d1**2 + base_len**2 - d2**2) / (2 * d1 * base_len)
    cos_angle = min(1.0, max(-1.0, cos_angle))  # Ensure within valid range
    angle = math.acos(cos_angle)

    # Rotate the unit vector pos_1 -> pos_2 by `angle` and walk a distance d1 from pos_1
    along = baseline / base_len
    normal = np.array([-along[1], along[0]])
    return origin + d1 * (math.cos(angle) * along + math.sin(angle) * normal)


# Place the reference node from its scaled distances to the two experiment nodes
pos_3 = get_third_point(pos_1, pos_2, scaled_dist_1, scaled_dist_2)

# Node positions, keyed by the names derived from the paths
positions = {exp1_name: pos_1, exp2_name: pos_2, ref_name: pos_3}

fig, ax = plt.subplots(figsize=(12, 10))

# One image per vertex
for node_image, node_name in ((experiment_1, exp1_name), (experiment_2, exp2_name), (reference, ref_name)):
    add_image(ax, node_image, positions[node_name])

# One edge per pair of nodes, labelled with the (unscaled) L2 distance at its midpoint
edges = [
    (exp1_name, ref_name, l2_no_zero_vs_ref),
    (exp2_name, ref_name, l2_zero_vs_ref),
    (exp1_name, exp2_name, l2_no_zero_vs_zero),
]
for start_name, end_name, dist in edges:
    start, end = positions[start_name], positions[end_name]
    ax.plot([start[0], end[0]], [start[1], end[1]], "k-", alpha=0.7)

    midpoint = ((start[0] + end[0]) / 2, (start[1] + end[1]) / 2)
    ax.annotate(
        f"L2 = {dist:.2e}",
        midpoint,
        ha="center",
        va="center",
        bbox=dict(boxstyle="round,pad=0.3", fc="white", alpha=0.8),
    )

# Node names: the reference label goes above its node when the node has a positive y,
# every other label goes below its node
offset = 0.15
for name, pos in positions.items():
    label_above = name == ref_name and pos[1] > 0
    if label_above:
        ax.text(pos[0], pos[1] + offset, name, ha="center", va="bottom", fontsize=10)
    else:
        ax.text(pos[0], pos[1] - offset, name, ha="center", va="top", fontsize=10)

# Fit the view around the nodes with a fixed margin, and hide the axes
node_xs, node_ys = zip(*positions.values())
x_margin, y_margin = 0.5, 0.5
ax.set_xlim(min(node_xs) - x_margin, max(node_xs) + x_margin)
ax.set_ylim(min(node_ys) - y_margin, max(node_ys) + y_margin)
ax.axis("off")

plt.title("Image Comparison Triangle with Proportional L2 Distances")
plt.tight_layout()
plt.show()

# Compare histograms of true vs generated data along time

(sanity check)

## Choose experiment

In [None]:
# The generated data is read from the directory holding the metrics file
gen_data_path = saved_FIDs_generation_path.parent
print(f"Using generated data at {gen_data_path}")
# <project>/<one intermediate directory>/<inference experiment>
print("Project:", gen_data_path.parent.parent.name)
print("Inference experiment:", gen_data_path.name)

## Helper funcs

In [None]:
def find_files_in_subdir(subdir: Path, extension: str) -> list[Path]:
    """List the files located directly in ``subdir`` whose suffix is ``.<extension>``.

    The search is not recursive: nested directories are reported on stdout and skipped.
    A message is printed and an empty list returned when ``subdir`` does not exist.
    """
    if not subdir.exists():
        print(f"Subdirectory {subdir} does not exist.")
        return []

    wanted_suffix = "." + extension
    matches = []
    for child in subdir.iterdir():
        if child.is_file():
            if child.suffix == wanted_suffix:
                matches.append(child)
        elif child.is_dir():
            print(f"Ignoring sub-subdirectory: {child}")
    return matches


def parallel_file_search(base_path: Path, extension: str, subdirs: list[str]) -> list[Path]:
    """Collect the files with the given extension found in each ``base_path / subdir``.

    Every sub-directory is listed by ``find_files_in_subdir`` in a worker process; results
    are gathered as the workers finish, so the order of the returned files is not fixed.
    """
    found = []
    with ProcessPoolExecutor() as pool:
        jobs = [pool.submit(find_files_in_subdir, base_path / name, extension) for name in subdirs]

        for job in tqdm(as_completed(jobs), total=len(jobs), desc="Searching directories"):
            found += job.result()

    return found


# Times available both as a class name and as a directory of generated data
gen_time_dirs = {entry.name for entry in gen_data_path.iterdir() if entry.is_dir()}
common_times_to_use = set(class_names) & gen_time_dirs
print(f"Using common times: {common_times_to_use} from {class_names}")


def get_image_RGB_histogram(
    image_path: Path, bins: int, expected_shape: tuple[int, int, int], expected_dtype: np.dtype
):
    """Compute one histogram per channel for the first three channels of an image file.

    The image is loaded as a numpy array, and its shape and dtype are checked against
    ``expected_shape`` and ``expected_dtype``. Each histogram has ``bins`` bins over the
    value range [0, 255].

    Returns:
        The histograms of channels 0, 1 and 2, i.e. ``(r_hist, g_hist, b_hist)``.
    """
    pixels = np.array(Image.open(image_path))
    assert pixels.shape == expected_shape, (
        f"Expected shape {expected_shape}, but got {pixels.shape} for image {image_path}"
    )
    assert pixels.dtype == expected_dtype, (
        f"Expected dtype {expected_dtype}, but got {pixels.dtype} for image {image_path}"
    )
    # np.histogram returns (counts, bin_edges): only the counts are kept
    r_hist, g_hist, b_hist = (
        np.histogram(pixels[:, :, channel], bins=bins, range=(0, 255))[0] for channel in range(3)
    )
    return r_hist, g_hist, b_hist

## Find generated files

run this on each new inference experiment

In [None]:
# Generated patches: list the files of every common time directory
print(f"Using generated dataset at {gen_data_path}")
print("Searching for generated files...")
all_gen_patches = parallel_file_search(
    base_path=gen_data_path,
    extension=dataset.dataset_params.file_extension,
    subdirs=common_times_to_use,
)
print(f"Found {len(all_gen_patches)} generated patches")

## Compute and save true data histograms

run this only once to compute & save the true data histograms

In [None]:
# True patches
# removesuffix() drops exactly this suffix; rstrip() would instead strip any trailing run of
# the characters "_hard_augmented" and could eat into the dataset name itself
not_augmented_ds_path = database_path.with_name(database_path.name.removesuffix("_hard_augmented"))
print(f"Using true, not augmented dataset at {not_augmented_ds_path}")
print("Searching for true files...")
# NOTE(review): the search below runs on `database_path`, not on the `not_augmented_ds_path`
# announced above -- confirm which dataset is intended
all_true_patches = parallel_file_search(database_path, dataset.dataset_params.file_extension, common_times_to_use)
print(f"Found {len(all_true_patches)} true patches")

assert len(all_gen_patches) == len(all_true_patches) // 2, "Number of generated patches should be half of the true ones"

# Get & save true patches histograms: one 256-bin accumulator per time and per channel
true_tot_r_hist = {str(time): np.zeros(256) for time in common_times_to_use}
true_tot_g_hist = {str(time): np.zeros(256) for time in common_times_to_use}
true_tot_b_hist = {str(time): np.zeros(256) for time in common_times_to_use}

futures = {}
with ProcessPoolExecutor() as executor:
    for image_path in all_true_patches:
        # the time of a patch is the name of its parent directory
        time = image_path.parent.name
        assert time in common_times_to_use, f"Time {time} not in common times to use"
        # NOTE(review): true patches are checked against shape (255, 255, 3) while generated ones
        # are checked against (128, 128, 3) -- confirm this is the expected true patch size
        futures[executor.submit(get_image_RGB_histogram, image_path, 256, (255, 255, 3), np.dtype(np.uint8))] = time

    for future in tqdm(as_completed(futures), total=len(futures), desc="Computing true histograms"):
        time = futures[future]
        r_hist, g_hist, b_hist = future.result()
        # sum histograms
        true_tot_r_hist[time] += r_hist
        true_tot_g_hist[time] += g_hist
        true_tot_b_hist[time] += b_hist

# Pickle save histograms under 'evaluations/<dataset_name>/histograms/'
true_hist_dir = Path("evaluations", dataset.name, "histograms")
true_hist_dir.mkdir(parents=True, exist_ok=True)
for channel, channel_hist in (("r", true_tot_r_hist), ("g", true_tot_g_hist), ("b", true_tot_b_hist)):
    with open(true_hist_dir / f"true_tot_{channel}_hist.pkl", "wb") as f:
        pickle.dump(channel_hist, f)
print(f"Saved true histograms to evaluations/{dataset.name}/histograms/")

## Compute and save generated data histograms

run this on each new inference experiment

In [None]:
# Get & save generated patches histograms: one 256-bin accumulator per time and per channel
gen_tot_r_hist, gen_tot_g_hist, gen_tot_b_hist = (
    {str(time): np.zeros(256) for time in common_times_to_use} for _ in range(3)
)

with ProcessPoolExecutor() as pool:
    # Maps every submitted job to the time of its patch (the name of the patch's parent directory)
    futures = {}
    for patch_path in all_gen_patches:
        time = patch_path.parent.name
        assert time in common_times_to_use, f"Time {time} not in common times to use"
        job = pool.submit(get_image_RGB_histogram, patch_path, 256, (128, 128, 3), np.dtype(np.uint8))
        futures[job] = time

    for job in tqdm(as_completed(futures), total=len(futures), desc="Computing gen histograms"):
        time = futures[job]
        # add the (r, g, b) histograms of this patch to the totals of its time
        for channel_total, patch_hist in zip((gen_tot_r_hist, gen_tot_g_hist, gen_tot_b_hist), job.result()):
            channel_total[time] += patch_hist

# Pickle save histograms under '<gen_data_path>/histograms/'
Path(gen_data_path, "histograms").mkdir(exist_ok=True)
with open(f"{gen_data_path}/histograms/gen_tot_r_hist.pkl", "wb") as f:
    pickle.dump(gen_tot_r_hist, f)
with open(f"{gen_data_path}/histograms/gen_tot_g_hist.pkl", "wb") as f:
    pickle.dump(gen_tot_g_hist, f)
with open(f"{gen_data_path}/histograms/gen_tot_b_hist.pkl", "wb") as f:
    pickle.dump(gen_tot_b_hist, f)
print(f"Saved generated histograms to {gen_data_path}/histograms/")

In [None]:
def create_histogram_animation(y_lims: tuple[float, float, float]) -> animation.FuncAnimation:
    """Build an animation comparing true and generated RGB histograms, one frame per time.

    The pickled histograms are read from the module-level directories ``true_histogram_path``
    and ``gen_histogram_path``, which must be defined before this function is called.
    Each frame is a 3x3 grid: rows are the R, G and B channels; columns are the true data,
    the generated data and their difference (true - generated). Every histogram is
    normalized to sum to 1 before being drawn.

    Args:
        y_lims: y-axis limit of the R, G and B rows respectively. The true and generated
            columns span ``[0, lim]``; the difference column spans ``[-lim, lim]`` so that
            negative differences stay visible.

    Returns:
        The animation, with one frame per time found in the generated histograms
        (times are ordered by their integer value).
    """
    # (file key, row label, bar color) of each channel, in row order
    channels = (("r", "R Channel", "red"), ("g", "G Channel", "green"), ("b", "B Channel", "blue"))

    def load_channel_hists(directory: Path, prefix: str) -> dict:
        """Load ``<prefix>_tot_<channel>_hist.pkl`` for every channel, keyed by channel."""
        loaded = {}
        for key, _, _ in channels:
            with open(directory / f"{prefix}_tot_{key}_hist.pkl", "rb") as f:
                loaded[key] = pickle.load(f)
        return loaded

    true_hists = load_channel_hists(true_histogram_path, "true")
    gen_hists = load_channel_hists(gen_histogram_path, "gen")

    # Get all time points
    times = sorted(gen_hists["r"], key=int)

    # Create figure and axes
    fig, axes = plt.subplots(3, 3, figsize=(30, 10), dpi=300)
    fig.suptitle("RGB Histograms: True vs Generated", fontsize=16)
    # Single text artist for the current time, updated at every frame
    # (instead of removing and re-creating a text while iterating over fig.texts)
    time_text = fig.text(0.5, 0.01, "", ha="center", fontsize=12)

    # Define x-axis (pixel values)
    x = np.arange(256)

    # Normalize histograms for better visualization
    def normalize_hist(hist: np.ndarray) -> np.ndarray:
        total = np.sum(hist)
        return hist / total if total > 0 else hist

    # Animation update function
    def update(frame: int) -> list:
        time = times[frame]

        # Clear all axes
        for ax in axes.flat:
            ax.clear()

        # Set titles for columns
        axes[0, 0].set_title("True Data")
        axes[0, 1].set_title("Generated Data")
        axes[0, 2].set_title("Difference")

        for row, (key, row_label, color) in enumerate(channels):
            true_norm = normalize_hist(true_hists[key][time])
            gen_norm = normalize_hist(gen_hists[key][time])

            axes[row, 0].bar(x, true_norm, color=color, alpha=0.7)
            axes[row, 1].bar(x, gen_norm, color=color, alpha=0.7)
            axes[row, 2].bar(x, true_norm - gen_norm, color=color, alpha=0.7)
            axes[row, 0].set_ylabel(row_label)

            # Adjust y-axis limits for better visualization
            axes[row, 0].set_ylim(0, y_lims[row])
            axes[row, 1].set_ylim(0, y_lims[row])
            # The difference is signed: a symmetric range keeps the negative bars in view
            axes[row, 2].set_ylim(-y_lims[row], y_lims[row])

        # Set x-axis labels
        for ax in axes[2, :]:
            ax.set_xlabel("Pixel Value")

        # Update the text indicating the current time point
        time_text.set_text(f"Current time point: {time}")

        fig.tight_layout(rect=(0, 0.03, 1, 0.95))  # Make room for title and text

        return list(axes.flat)

    # Create animation
    anim = animation.FuncAnimation(fig, update, frames=len(times), interval=1500, blit=False)

    return anim


# Directories holding the pickled histograms read by create_histogram_animation
true_histogram_path = Path("evaluations") / dataset.name / "histograms"
gen_histogram_path = gen_data_path / "histograms"

# Build the animation (y-axis limits for the R, G and B rows)
anim = create_histogram_animation((0.1, 0.05, 0.25))

# Save the animation next to the generated histograms, using the ffmpeg writer
save_path = gen_histogram_path / "histogram_animation.mp4"
anim.save(save_path, writer="ffmpeg", fps=1, dpi=300)
print(f"Animation saved to {save_path}")

# Display in notebook
HTML(anim.to_jshtml())

# Experiment: Different splits vs only augs vs diff splits no augs

## Load aug FIDs

In [None]:
# Metrics of the "repetitions with augmentations" test
eval_augs_metrics = json.loads(
    Path(
        "/workspaces/biocomp/tboyer/sources/GaussianProxy/notebooks/evaluations/BBBC021_196_docetaxel/eval_metrics_TEST_REPS_WITH_AUGS.json"
    ).read_text()
)
eval_augs_metrics

In [None]:
# For every class: the FID of each repetition (repetitions are keyed "0", "1", ... in this file)
fid_scores_by_class_train_augs_only: dict[str, list[float]] = {
    class_name: [
        eval_augs_metrics[str(rep)][class_name]["frechet_inception_distance"] for rep in range(nb_repeats)
    ]
    for class_name in class_names
}

fid_scores_by_class_train_augs_only

## Load no augs FIDs

In [None]:
# Metrics of the "repetitions without augmentations" test
eval_no_augs_metrics = json.loads(
    Path(
        "/workspaces/biocomp/tboyer/sources/GaussianProxy/notebooks/evaluations/BBBC021_196_docetaxel/eval_metrics_TEST_REPS_NO_AUGS.json"
    ).read_text()
)
eval_no_augs_metrics

In [None]:
# For every class: the FID of each repetition (repetitions are keyed "exp_rep_<idx>" in this file)
fid_scores_by_class_train_no_augs: dict[str, list[float]] = {
    class_name: [
        eval_no_augs_metrics[f"exp_rep_{rep}"][class_name]["frechet_inception_distance"] for rep in range(nb_repeats)
    ]
    for class_name in class_names
}

fid_scores_by_class_train_no_augs

## Load hard augmentations

In [None]:
# Metrics of the "hard augmentations" test
eval_hard_augs_metrics = json.loads(
    Path(
        "/workspaces/biocomp/tboyer/sources/GaussianProxy/notebooks/evaluations/BBBC021_196_docetaxel/eval_metrics_TEST_HARD_AUGS.json"
    ).read_text()
)
eval_hard_augs_metrics

In [None]:
# For every class: the FID of each repetition (repetitions are keyed "repeat_<idx>" in this file)
fid_scores_by_class_train_hard_augs: dict[str, list[float]] = {
    class_name: [
        eval_hard_augs_metrics[f"repeat_{rep}"][class_name]["frechet_inception_distance"] for rep in range(nb_repeats)
    ]
    for class_name in class_names
}

fid_scores_by_class_train_hard_augs

## Plot

In [None]:
plt.figure(figsize=(12, 6), dpi=300)

n_classes = len(class_names)
positions1 = np.arange(1, n_classes + 1) - 0.25
positions2 = np.arange(1, n_classes + 1)
positions3 = np.arange(1, n_classes + 1) + 0.25
positions4 = positions2  # values are very different so ok to "overlap" on the x axis

class_labels = [f"{class_name}\n({nb_elems_per_class[class_name]})" for class_name in class_names]
bar_width = 0.1
text_offset_y = 0.1
text_offset_x = 0.12

# One entry per experiment variant: (x positions, FIDs by class, color, legend label)
groups = [
    (positions1, fid_scores_by_class_train, "blue", "diff splits   | o-t-f augs"),
    (positions2, fid_scores_by_class_train_augs_only, "orange", "same split | o-t-f augs"),
    (positions3, fid_scores_by_class_train_no_augs, "green", "diff splits   | no augs"),
    (positions4, fid_scores_by_class_train_hard_augs, "red", "diff splits   | hard augs"),
]

for i, class_name in enumerate(class_names):
    for group_positions, scores_by_class, color, legend_label in groups:
        scores = scores_by_class[class_name]
        x_pos = group_positions[i]

        # All repetitions of this class, stacked at the group's x position;
        # only the first class carries the legend label so each group appears once in the legend
        plt.scatter(
            np.full_like(scores, x_pos),
            scores,
            alpha=0.3,
            color=color,
            s=20,
            label=legend_label if i == 0 else "",
        )

        # Median of the repetitions: a short horizontal bar plus its value as text
        median = np.median(scores)
        plt.hlines(median, x_pos - bar_width, x_pos + bar_width, colors=color, alpha=0.8, linewidth=2)
        plt.text(
            x_pos + text_offset_x,
            median + text_offset_y,
            f"{median:.1f}",
            color=color,
            ha="center",
            va="bottom",
            alpha=0.8,
            fontsize=6,
        )

plt.xlabel("Class Name (total number of class elements)")
plt.ylabel("FID Score")
plt.title("Intra-class true data vs true data FID score")
plt.xticks(range(1, n_classes + 1), class_labels, rotation=45, ha="right")
plt.grid(axis="y", linestyle="--", alpha=0.7)
plt.legend()
plt.tight_layout()
plt.show()