In [None]:
import json 
%load_ext autoreload
%autoreload 2

RESULTS_PATH = "results.json"  # JSON file with the per-epoch training records analysed below

# JSON text is UTF-8 (RFC 8259); pass the encoding explicitly so loading does
# not depend on the platform's default locale encoding.
with open(RESULTS_PATH, "r", encoding="utf-8") as results_file:
    all_data = json.load(results_file)


In [None]:
epoch_num = 30  # epoch number to analyze

# Records logged at `epoch_num` (one per run that reached it). The filter
# builds a new list and nothing is mutated, so `all_data` needs no copy.
epoch_data = [record for record in all_data if record["epoch"] == epoch_num]

# The value shown is the joint loss stored under "all_metrics", so label it as such.
print(f"-- joint loss (all_metrics.avg_loss) at epoch {epoch_num}\n\n")
for record in epoch_data:
    print(record["all_metrics"]["avg_loss"])

In [None]:
FINAL_EPOCH = 30  # last epoch of a complete training run

# Keep only records that carry metrics for BOTH heads. Each key must be tested
# explicitly: `"a" and "b" in d` parses as `"a" and ("b" in d)` and checks "b" only.
dual_head_runs = [
    item for item in all_data
    if "rank_head_metrics" in item and "keep_head_metrics" in item
]

# Group the (chronologically ordered) records into one list per training run.
# A run is complete when it reaches FINAL_EPOCH. An epoch that does not increase
# over the previous record means a new run began before the current one
# finished; that partial run is discarded rather than merged into the next run.
full_epoch_runs = []  # one list per complete run, records in epoch order
batch = []
for item in dual_head_runs:
    if batch and item["epoch"] <= batch[-1]["epoch"]:
        print(f"Discarding incomplete run ending at epoch {batch[-1]['epoch']}")
        batch = []
    batch.append(item)
    if item["epoch"] == FINAL_EPOCH:
        full_epoch_runs.append(batch)
        print([record["epoch"] for record in batch])
        batch = []  # start collecting the next run

if batch:
    print(f"Discarding incomplete trailing run ending at epoch {batch[-1]['epoch']}")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Create one figure + text summary for every complete run in `full_epoch_runs`.


def _plot_head_comparison(ax, epochs, rank_values, keep_values, title, ylabel):
    """Draw Rank-head vs Keep-head curves for a single metric on `ax`.

    Args:
        ax: matplotlib Axes to draw on.
        epochs: x values (epoch numbers), one per record.
        rank_values: Rank-head metric values aligned with `epochs`.
        keep_values: Keep-head metric values aligned with `epochs`.
        title: axes title.
        ylabel: y-axis label.

    The y-axis is fixed to [0, 1]; this assumes the metric is a score in
    that range (true for F1 and precision).
    """
    ax.plot(epochs, rank_values, 'g-', linewidth=2, marker='o', markersize=3, label='Rank Head')
    ax.plot(epochs, keep_values, 'r-', linewidth=2, marker='o', markersize=3, label='Keep Head')
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_ylim([0, 1])


def plot_run_metrics(run_idx, run_data):
    """Plot joint loss, F1 and precision for one training run and print a summary.

    Args:
        run_idx: 1-based run number used in titles and the summary.
        run_data: non-empty list of per-epoch records for one run, in epoch
            order. Each record must contain "epoch", "all_metrics",
            "rank_head_metrics" and "keep_head_metrics".
    """
    epochs = [item["epoch"] for item in run_data]
    # Derive these from the data instead of assuming every run has 30 epochs.
    num_epochs = len(epochs)
    final_epoch = epochs[-1]

    # All metrics
    all_loss = [item["all_metrics"]["avg_loss"] for item in run_data]

    # Rank head metrics
    rank_f1 = [item["rank_head_metrics"]["avg_f1_score"] for item in run_data]
    rank_precision = [item["rank_head_metrics"]["avg_precision"] for item in run_data]

    # Keep head metrics
    keep_f1 = [item["keep_head_metrics"]["avg_f1_score"] for item in run_data]
    keep_precision = [item["keep_head_metrics"]["avg_precision"] for item in run_data]

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    fig.suptitle(f'Training Run {run_idx} - Metrics Over {num_epochs} Epochs', fontsize=16, fontweight='bold')

    # Plot 1: All Metrics Loss (Joint Loss)
    axes[0].plot(epochs, all_loss, 'b-', linewidth=2, marker='o', markersize=3)
    axes[0].set_title('Joint Loss Over Time')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].grid(True, alpha=0.3)

    # Plots 2 and 3: Rank vs Keep head comparisons
    _plot_head_comparison(axes[1], epochs, rank_f1, keep_f1, 'F1 Score Comparison', 'F1 Score')
    _plot_head_comparison(axes[2], epochs, rank_precision, keep_precision, 'Precision Comparison', 'Precision')

    plt.tight_layout()
    plt.show()
    plt.close(fig)  # release the figure so figures don't pile up across loop iterations

    separator = '=' * 80
    print(f"\n{separator}")
    print(f"Run {run_idx} Summary (Epoch {final_epoch}):")
    print(separator)
    print(f"Rank Head - F1: {rank_f1[-1]:.4f}, Precision: {rank_precision[-1]:.4f}")
    print(f"Keep Head - F1: {keep_f1[-1]:.4f}, Precision: {keep_precision[-1]:.4f}")
    print(f"Performance Gap - F1: {rank_f1[-1] - keep_f1[-1]:.4f}, Precision: {rank_precision[-1] - keep_precision[-1]:.4f}")
    print(f"{separator}\n")


for run_idx, run_data in enumerate(full_epoch_runs, 1):
    plot_run_metrics(run_idx, run_data)