In [2]:
import os
import torch
import pandas as pd

# Root folder that holds one sub-folder per ablation run
base_dir = "results_ablation_studies"

# Keep only the sub-directories whose name begins with "A"
# (e.g., A1_Transformer_Priors, A2_GAT_Priors, etc.)
model_folders = [
    entry
    for entry in os.listdir(base_dir)
    if entry.startswith("A") and os.path.isdir(os.path.join(base_dir, entry))
]

def _epoch_from_name(filename):
    """Return the epoch number encoded in a checkpoint file name.

    A name qualifies when it starts with ``checkpoint_epoch_`` and ends with
    ``.pth``. The epoch is the text between the last ``_`` and the first
    following ``.``. Returns ``None`` when the name does not qualify or that
    text is not an integer (e.g. ``checkpoint_epoch_best.pth``), so callers
    can skip the file instead of crashing on ``ValueError``.
    """
    if not (filename.startswith("checkpoint_epoch_") and filename.endswith(".pth")):
        return None
    try:
        return int(filename.split('_')[-1].split('.')[0])
    except ValueError:
        return None


def _loss_at(losses, epoch):
    """Pick the loss for ``epoch`` out of a value stored in a checkpoint.

    * list/tuple: return ``losses[epoch]`` when that index exists, else ``None``
      (presumably the sequence is a per-epoch history -- TODO confirm against
      the training script).
    * plain ``int``/``float``: return the value itself.
    * anything else: ``None``.
    """
    if isinstance(losses, (list, tuple)):
        # The lower bound keeps a negative epoch from silently indexing from the end.
        return losses[epoch] if 0 <= epoch < len(losses) else None
    if isinstance(losses, (int, float)):
        return losses
    return None


# Data collection: one record per (model, checkpoint file)
all_data = []

for model_name in model_folders:
    ckpt_dir = os.path.join(base_dir, model_name, "checkpoints")
    if not os.path.isdir(ckpt_dir):
        continue  # Skip if no checkpoints folder

    # Parse every file name exactly once; names without a numeric epoch are dropped.
    epoch_by_file = {}
    for candidate in os.listdir(ckpt_dir):
        parsed_epoch = _epoch_from_name(candidate)
        if parsed_epoch is not None:
            epoch_by_file[candidate] = parsed_epoch

    # Checkpoint file names ordered by ascending epoch
    checkpoints = sorted(epoch_by_file, key=epoch_by_file.get)

    for ckpt_file in checkpoints:
        epoch_num = epoch_by_file[ckpt_file]
        ckpt_path = os.path.join(ckpt_dir, ckpt_file)

        try:
            # NOTE(review): torch.load unpickles the file, which can execute
            # arbitrary code -- only point this at checkpoints you trust.
            # Consider weights_only=True if the installed PyTorch supports it
            # and the checkpoints load under it -- TODO confirm.
            checkpoint = torch.load(ckpt_path, map_location='cpu')
        except Exception as e:
            print(f"Error loading {ckpt_path}: {e}")
            continue

        # .get() below needs a mapping; report and skip anything else instead of
        # letting an AttributeError abort the whole scan.
        if not isinstance(checkpoint, dict):
            print(f"Skipping {ckpt_path}: expected a dict, got {type(checkpoint).__name__}")
            continue

        train_losses = checkpoint.get('train_losses', [])
        val_losses = checkpoint.get('val_losses', [])

        train_loss = _loss_at(train_losses, epoch_num)
        val_loss = _loss_at(val_losses, epoch_num)

        all_data.append({
            'model': model_name,
            'epoch': epoch_num,
            'train_loss': train_loss,
            'val_loss': val_loss
        })

# Create DataFrame and sort.
# The columns are named explicitly so the frame is well-formed even when no
# checkpoint was collected: an empty record list would otherwise give a frame
# with no 'model'/'epoch' columns and sort_values would raise KeyError.
df = pd.DataFrame(all_data, columns=['model', 'epoch', 'train_loss', 'val_loss'])
df = df.sort_values(['model', 'epoch']).reset_index(drop=True)

if df.empty:
    # Make the "nothing found" case visible rather than failing or staying silent.
    print(f"⚠️ No checkpoint losses found under {base_dir}; the CSV will contain only the header.")

# Save CSV next to the ablation folders
output_csv = os.path.join(base_dir, 'ablation_all_losses.csv')
df.to_csv(output_csv, index=False)

print(f"✅ Saved losses to:\n{output_csv}")
print(df.head(10))


✅ Saved losses to:
results_ablation_studies/ablation_all_losses.csv
                   model  epoch  train_loss  val_loss
0  A1_Transformer_Priors      0    0.554252  0.474011
1  A1_Transformer_Priors      1    0.443766  0.392110
2  A1_Transformer_Priors      2    0.365065  0.307081
3  A1_Transformer_Priors      3    0.312617  0.275970
4  A1_Transformer_Priors      4    0.274475  0.274747
5  A1_Transformer_Priors      5    0.251992  0.218781
6  A1_Transformer_Priors      6    0.230154  0.208527
7  A1_Transformer_Priors      7    0.211188  0.178917
8  A1_Transformer_Priors      8    0.197184  0.164341
9  A1_Transformer_Priors      9    0.189866  0.153031
