In [5]:
import torch
import os
import pandas as pd

# Base directory containing the per-k-mer checkpoint folders.
# NOTE(review): absolute, machine-specific path; adjust when running elsewhere.
base_dir = r"E:\results_msBERT_replication"

# The kmer folders to process
kmer_folders = ['checkpoints_k3', 'checkpoints_k4', 'checkpoints_k5', 'checkpoints_k6']

# Column layout of the result table. Passing it explicitly keeps the
# DataFrame well-formed (and sortable) even when no checkpoint was found.
LOSS_COLUMNS = ['kmer_model', 'epoch', 'train_loss', 'val_loss']


def epoch_from_filename(filename):
    """Return the epoch number encoded in a checkpoint file name, or None.

    A name qualifies when it starts with ``checkpoint_epoch_``, ends with
    ``.pth`` and the token between the last underscore and the first
    following dot is purely numeric (e.g. ``checkpoint_epoch_12.pth`` -> 12).
    Anything else (e.g. ``checkpoint_epoch_best.pth``) yields None instead
    of raising, so such files can simply be skipped.
    """
    if not (filename.startswith('checkpoint_epoch_') and filename.endswith('.pth')):
        return None
    token = filename.split('_')[-1].split('.')[0]
    return int(token) if token.isdecimal() else None


def as_float(value):
    """Convert a scalar-like loss to a Python float, or return None.

    Accepts anything ``float()`` can convert (ints, floats, NumPy scalars,
    single-element tensors). ``None``, booleans and non-scalar values
    yield None.
    """
    if value is None or isinstance(value, bool):
        return None
    try:
        return float(value)
    except (TypeError, ValueError, RuntimeError):
        # Multi-element arrays/tensors and unrelated objects land here.
        return None


def loss_for_epoch(losses, epoch):
    """Return the loss recorded for ``epoch``, or None when unavailable.

    ``losses`` may be a list/tuple indexed by epoch number, in which case
    the entry at position ``epoch`` is used (None if the sequence is too
    short), or a single scalar, which is taken as the loss of this epoch.
    """
    if isinstance(losses, (list, tuple)):
        if 0 <= epoch < len(losses):
            return as_float(losses[epoch])
        return None
    return as_float(losses)


# Prepare data collection: one record per (k-mer model, epoch)
data = []

for kmer in kmer_folders:
    folder_path = os.path.join(base_dir, kmer)

    # A missing run should not abort the collection for the other models.
    if not os.path.isdir(folder_path):
        print(f"Warning: checkpoint folder not found, skipping: {folder_path}")
        continue

    # Parse each file name once and keep (epoch, file name) pairs so the
    # same epoch number drives both the ordering and the loss lookup.
    checkpoints = []
    for ckpt_name in os.listdir(folder_path):
        parsed_epoch = epoch_from_filename(ckpt_name)
        if parsed_epoch is not None:
            checkpoints.append((parsed_epoch, ckpt_name))
    checkpoints.sort()

    for epoch_num, ckpt_file in checkpoints:
        ckpt_path = os.path.join(folder_path, ckpt_file)

        # NOTE(review): weights_only=False performs full unpickling, which can
        # execute arbitrary code from the file. This keeps the previous
        # behaviour and is presumably fine for locally produced checkpoints;
        # switch to weights_only=True if they only hold tensors and plain
        # Python containers. Passing the flag explicitly also avoids the
        # FutureWarning about the changing default.
        checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=False)

        data.append({
            'kmer_model': kmer,
            'epoch': epoch_num,
            'train_loss': loss_for_epoch(checkpoint.get('train_losses'), epoch_num),
            'val_loss': loss_for_epoch(checkpoint.get('val_losses'), epoch_num),
        })

# Convert to DataFrame (explicit columns so an empty result is still valid)
df = pd.DataFrame(data, columns=LOSS_COLUMNS)

# Sort nicely by model and epoch
df = df.sort_values(['kmer_model', 'epoch']).reset_index(drop=True)

# Save as CSV, but never overwrite an existing file with an empty table
output_csv = os.path.join(base_dir, 'all_kmer_losses.csv')
if df.empty:
    print("No checkpoints found; CSV not written.")
else:
    df.to_csv(output_csv, index=False)
    print(f"Saved combined losses CSV to:\n{output_csv}")

# Print some rows to check
print(df.head(10))


  checkpoint = torch.load(ckpt_path, map_location='cpu')


Saved combined losses CSV to:
E:\results_msBERT_replication\all_kmer_losses.csv
       kmer_model  epoch  train_loss  val_loss
0  checkpoints_k3      0    0.569749  0.509271
1  checkpoints_k3      1    0.511882  0.501003
2  checkpoints_k3      2    0.503876  0.484983
3  checkpoints_k3      3    0.498610  0.483134
4  checkpoints_k3      4    0.493705  0.478425
5  checkpoints_k3      5    0.491025  0.474061
6  checkpoints_k3      6    0.488256  0.512219
7  checkpoints_k3      7    0.483946  0.469574
8  checkpoints_k3      8    0.482185  0.485143
9  checkpoints_k3      9    0.478302  0.496562
