In [16]:
import os
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

def load_all_versions(log_dir):
    """Load every ``metrics.csv`` found in the ``version_*`` sub-directories of ``log_dir``.

    Parameters
    ----------
    log_dir : str or path-like
        Directory whose immediate children are scanned. Only children that are
        directories and whose name starts with ``"version_"`` are considered.

    Returns
    -------
    pandas.DataFrame
        The rows of all discovered ``metrics.csv`` files stacked under a fresh
        integer index, with an extra ``version`` column holding the name of the
        directory each row was read from. An empty DataFrame is returned when
        no metrics file is found.

    Raises
    ------
    OSError
        Propagated from ``os.listdir`` when ``log_dir`` cannot be listed
        (missing path, not a directory, ...).
    """
    version_frames = []

    # Sorted (lexicographically) so the row order does not depend on the
    # arbitrary order in which the filesystem returns directory entries.
    for version_dir in sorted(os.listdir(log_dir)):
        version_path = os.path.join(log_dir, version_dir)

        # Only "version_*" folders are experiment runs; skip everything else.
        if not (version_dir.startswith("version_") and os.path.isdir(version_path)):
            continue

        metrics_file = os.path.join(version_path, "metrics.csv")
        if not os.path.isfile(metrics_file):
            continue

        version_df = pd.read_csv(metrics_file)
        version_df["version"] = version_dir  # Remember which run each row belongs to
        version_frames.append(version_df)

    if not version_frames:
        return pd.DataFrame()

    # Concatenate once at the end: repeated concat inside the loop copies all
    # previously loaded rows on every iteration.
    return pd.concat(version_frames, ignore_index=True)

def process_metrics(df):
    """Average the logged step losses within each epoch.

    Rows whose ``train_loss_step`` is NaN are ignored. The remaining rows are
    grouped on ``epoch`` and the group means are returned as a two-column frame
    (``epoch``, ``train_loss_epoch``). The input frame is not modified.
    """
    has_step_loss = df["train_loss_step"].notna()
    per_epoch_mean = df[has_step_loss].groupby("epoch")["train_loss_step"].mean()
    return per_epoch_mean.rename("train_loss_epoch").reset_index()

def save_versions_to_csv(df, csv_dir):
    """Write ``df`` to ``<csv_dir>/metrics_versions.csv`` without the index column.

    ``csv_dir`` (including missing parents) is created first if necessary.
    """
    os.makedirs(csv_dir, exist_ok=True)
    target_file = os.path.join(csv_dir, "metrics_versions.csv")
    df.to_csv(target_file, index=False)

def create_dir(plot_dir):
    """Create ``plot_dir`` and any missing parents; do nothing if the directory already exists."""
    os.makedirs(name=plot_dir, exist_ok=True)

def plot_loss(metrics_df, plot_dir='plots/', title="Avg. Training Loss per Epoch - All Versions", filename="avg_loss_all_versions"):
    """Plot mean training loss per epoch and save the figure as a PNG.

    Parameters
    ----------
    metrics_df : pandas.DataFrame
        Frame with an ``epoch`` column (x axis) and a ``train_loss_epoch``
        column (y axis).
    plot_dir : str
        Directory the image is written to; created if it does not exist.
    title : str
        Title drawn above the plot.
    filename : str
        File name without extension; ``.png`` is appended.

    Returns
    -------
    None
        The figure is written to ``<plot_dir>/<filename>.png`` and then closed.
    """
    create_dir(plot_dir)

    # Use an explicit figure/axes pair instead of pyplot's implicit "current
    # figure" so nothing else in the kernel can be drawn on or closed by accident.
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        sns.lineplot(data=metrics_df, x="epoch", y="train_loss_epoch", marker="o", ax=ax)
        ax.set_title(title)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        fig.savefig(os.path.join(plot_dir, f"{filename}.png"))
    finally:
        # Always release the figure, even if plotting or saving fails, so
        # figures do not pile up in memory when this runs in a loop.
        plt.close(fig)

# Root folder that is scanned for model directories: "model_saves" inside the
# current working directory at the time this cell runs.
current_path = os.getcwd()
log_dir = os.path.join(current_path, "model_saves")

In [17]:
# For every model in model_saves, go into its directory, load the metrics of all
# of its versions, then write the per-epoch summary CSV and loss plot next to them.
for model_dir in sorted(os.listdir(log_dir)):
    model_path = os.path.join(log_dir, model_dir)

    # Stray files next to the model folders cannot be listed; skip them instead
    # of letting os.listdir() raise inside load_all_versions.
    if not os.path.isdir(model_path):
        continue

    print(f"Processing {model_dir}")
    metrics_df = load_all_versions(model_path)

    # Some runs keep their version_* folders one level deeper. Try each known
    # sub-folder in turn until one of them actually yields metrics.
    for subdir in ("logs", "versions"):
        if not metrics_df.empty:
            break
        candidate_path = os.path.join(model_path, subdir)
        if os.path.isdir(candidate_path):
            metrics_df = load_all_versions(candidate_path)

    if not metrics_df.empty:
        metrics_df = process_metrics(metrics_df)
        save_versions_to_csv(metrics_df, model_path)
        plot_loss(metrics_df, plot_dir=model_path)
        print(f"Processed {model_dir}")
    else:
        print(f"No metrics found for {model_dir}")

Processing dino
No metrics found for dino
Processing dino_gated_vit_audio_cnn_image_20250305
Processed dino_gated_vit_audio_cnn_image_20250305
Processing dino_unimodal_audio_20250310_161006
Processed dino_unimodal_audio_20250310_161006
Processing dino_unimodal_audio_clean_20250311_114502
Processed dino_unimodal_audio_clean_20250311_114502
Processing dino_unimodal_audio_clean_20250311_134951
Processed dino_unimodal_audio_clean_20250311_134951
Processing dino_unimodal_audio_clean_20250318_210013
No metrics found for dino_unimodal_audio_clean_20250318_210013
Processing dino_unimodal_audio_clean_20250318_210159
Processed dino_unimodal_audio_clean_20250318_210159
Processing dino_unimodal_audio_clean_20250318_212401
Processed dino_unimodal_audio_clean_20250318_212401
Processing dino_unimodal_audio_clean_augmentationsv2_20250311_161214
Processed dino_unimodal_audio_clean_augmentationsv2_20250311_161214
Processing dino_unimodal_audio_clean_mobilevit_20250312_091158
Processed dino_unimodal_audi