In [None]:
"""Workbook to format regularization tests data."""
# pylint: disable=import-error, redefined-outer-name, use-dict-literal, too-many-lines

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import annotations

import itertools
from pathlib import Path

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import pandas as pd

In [None]:
# Root of the paper outputs; input data and generated figures live in subfolders.
base_dir = Path.home().joinpath("Projects", "epiclass", "output", "paper")
base_data_dir, base_fig_dir = base_dir / "data", base_dir / "figures"
paper_dir = base_dir  # alias of base_dir

In [None]:
# Folder holding one sub-directory per regularization-test run.
# Fail fast here so later cells don't silently iterate over a missing folder.
reg_data_dir = base_data_dir.joinpath(
    "training_results",
    "dfreeze_v2",
    "hg38_100kb_all_none",
    "harmonized_donor_sex_1l_3000n",
    "regularization_tests",
)
if not reg_data_dir.exists():
    raise FileNotFoundError(f"Directory {reg_data_dir} does not exist.")

In [None]:
# Comet.ml run metadata; later cells join on its "experimentKey" column and
# read its "hparams/*" and "val_*" columns.
comet_run_metadata = pd.read_csv(
    base_data_dir.joinpath(
        "training_results", "all_results_cometml_filtered_oversampling-fixed.csv"
    )
)

## Acquire and merge weight data from the regularization runs

In [None]:
# Collect the weight statistics of every regularization run into one table
# (one row per run folder). Each run folder is expected to contain
# `split0/*weights_description.csv` and `split0/EpiLaP/<experiment key>`.
data_frames = []

# Sorted so the row order of reg_data_df does not depend on the filesystem.
for folder in sorted(reg_data_dir.iterdir()):
    if not folder.is_dir():
        continue

    split_folder = folder / "split0"
    weights_data_path = next(split_folder.glob("*weights_description.csv"), None)
    if weights_data_path is None:
        raise FileNotFoundError(
            f"No '*weights_description.csv' file found in {split_folder}"
        )

    # Discard the file's first row (presumably its own header) and name the
    # two columns explicitly.
    weights_df = pd.read_csv(weights_data_path, skiprows=1, names=["metric", "value"])

    # The first entry under EpiLaP/ is taken as the Comet experiment key.
    # NOTE(review): assumes EpiLaP/ holds a single entry — confirm.
    exp_folder = split_folder / "EpiLaP"
    exp_path = next(exp_folder.glob("*"), None)
    if exp_path is None:
        raise FileNotFoundError(f"No experiment entry found in {exp_folder}")
    exp_key = exp_path.name

    # Pivot to a single row whose columns are the metric names, then tag the run.
    weights_df = weights_df.set_index("metric").T
    weights_df["experimentKey"] = exp_key
    weights_df["folder_name"] = folder.name

    data_frames.append(weights_df)

if not data_frames:
    raise FileNotFoundError(f"No run folders found in {reg_data_dir}")

# Combine all runs; drop the leftover "metric" name on the column index.
reg_data_df = pd.concat(data_frames, ignore_index=True)
reg_data_df.columns.name = None

In [None]:
# Left-join so every run folder is kept, even one without a Comet metadata match.
reg_runs_df = reg_data_df.merge(comet_run_metadata, how="left", on="experimentKey")
# Metadata stores the keep probability; derive the equivalent dropout rate.
keep_prob = reg_runs_df["hparams/keep_prob"]
reg_runs_df["hparams/dropout"] = 1 - keep_prob
del keep_prob

In [None]:
# Summary table: run identifiers, regularization hyperparameters and validation
# metrics, followed by every weight statistic read from the description files.
# The identifier columns added while building reg_data_df are excluded from the
# second part. The previous `[:-1]` slice only dropped "folder_name", so
# "experimentKey" appeared twice in the summary.
summary_df = reg_runs_df[
    [
        "experimentKey",
        "folder_name",
        "hparams/dropout",
        "hparams/l1_scale",
        "hparams/l2_scale",
        "val_Accuracy",
        "val_F1Score",
    ]
    + [
        col
        for col in reg_data_df.columns
        if col not in {"experimentKey", "folder_name"}
    ]
]
assert summary_df.columns.is_unique, "summary_df has duplicated columns"

reg_runs_df.to_csv(reg_data_dir / "weights_detail.csv", index=False)
summary_df.to_csv(reg_data_dir / "weights_detail_summary.csv", index=False)

## Weight distribution figure

In [None]:
# For every run without L2 regularization, collect its (dropout, L1 scale) pair
# and the path of the PNG stored in its split0/ folder, then lay the images
# out on a dropout x L1 grid.
hyperparams = []
images = []

# Sorted so the collection order does not depend on the filesystem.
for folder in sorted(reg_data_dir.iterdir()):
    if not folder.is_dir():
        continue

    # Hyperparameter values of this run (one row per folder in reg_runs_df).
    sub_df = reg_runs_df[reg_runs_df["folder_name"] == folder.name]
    dropout = sub_df["hparams/dropout"].values[0]
    l1_scale = sub_df["hparams/l1_scale"].values[0]
    l2_scale = sub_df["hparams/l2_scale"].values[0]
    # Only runs without L2 regularization are shown in this figure.
    if l2_scale > 0:
        continue

    split_folder = folder / "split0"
    png_path = next(split_folder.glob("*.png"), None)
    if png_path is None:
        raise FileNotFoundError(f"No PNG image found in {split_folder}")

    hyperparams.append((dropout, l1_scale))
    images.append(png_path)

if not images:
    raise ValueError(f"No runs without L2 regularization found in {reg_data_dir}")

hyperparams_df = pd.DataFrame(hyperparams, columns=["dropout", "l1_scale"])

# Grid rows are keyed by the 2-decimal dropout label. Deduplicate AFTER
# formatting (keeping numeric order): values differing only beyond the 2nd
# decimal would otherwise yield duplicate labels and an empty extra row.
unique_dropouts = list(
    dict.fromkeys(f"{val:.2f}" for val in sorted(hyperparams_df["dropout"].unique()))
)
unique_l1_scales = sorted(hyperparams_df["l1_scale"].unique())

# squeeze=False keeps `axes` 2-D even when the grid has a single row or column.
fig, axes = plt.subplots(
    len(unique_dropouts), len(unique_l1_scales), figsize=(15, 5), squeeze=False
)

# Place each image in the cell matching its (dropout, L1) pair.
for (dropout, l1_scale), image_path in zip(hyperparams, images):
    dropout = f"{dropout:.2f}"
    img = mpimg.imread(image_path)
    ax = axes[unique_dropouts.index(dropout), unique_l1_scales.index(l1_scale)]
    ax.imshow(img, aspect=0.9)
    ax.set_title(f"(D,L1): ({dropout}, {l1_scale:g})")

# Hide ticks and frames on every cell, including cells left empty.
for i, j in itertools.product(range(len(unique_dropouts)), range(len(unique_l1_scales))):
    axes[i, j].axis("off")

plt.tight_layout()
plt.savefig(reg_data_dir / "regularization_tests.png", dpi=400)
plt.savefig(reg_data_dir / "regularization_tests.svg", dpi=400)