In [1]:
from src.viz import *
import json
import numpy as np
import os
import json
import numpy as np
import matplotlib.pyplot as plt

## Plotting Function Definitions

In [None]:
RESULTS_DIR = "results"
def load_all_log_histories(base_dir=RESULTS_DIR):
    """Walk through results folders and load log_history.json for each experiment."""
    all_logs = {}
    for root, dirs, _ in os.walk(base_dir):
        for d in dirs:
            folder_path = os.path.join(root, d)
            log_file = os.path.join(folder_path, "log_history.json")
            if os.path.exists(log_file):
                try:
                    with open(log_file, "r") as f:
                        log_history = json.load(f)
                        all_logs[d] = log_history
                except json.JSONDecodeError:
                    print(f"Failed to load {log_file}")
    return all_logs


def load_all_metrics_histories(base_dir=RESULTS_DIR):
    """Walk through results folders and load metrics.json for each experiment."""
    all_logs = {}
    for root, dirs, _ in os.walk(base_dir):
        for d in dirs:
            folder_path = os.path.join(root, d)
            log_file = os.path.join(folder_path, "metrics.json")
            if os.path.exists(log_file):
                try:
                    with open(log_file, "r") as f:
                        log_history = json.load(f)
                        all_logs[d] = log_history
                except json.JSONDecodeError:
                    print(f"Failed to load {log_file}")
    return all_logs

def extract_metrics(log_history):
    """Extract train_loss, eval_loss, eval_acc, cumulative_time, grad norms, steps."""
    metrics = {
        "train_loss": [],
        "train_steps": [],
        "eval_loss": [],
        "eval_steps": [],
        "eval_acc": [],
        "eval_acc_steps": [],
        "cumulative_time": [],
        "grad_norms": [],  # {module_name: list of mean abs grad per step}
    }

    cumulative_time = 0.0
    for entry in log_history:
        if "loss" in entry:
            metrics["train_loss"].append(entry["loss"])
            metrics["train_steps"].append(entry.get("step", len(metrics["train_steps"])+1))
        if "eval_loss" in entry:
            metrics["eval_loss"].append(entry["eval_loss"])
            metrics["eval_steps"].append(entry.get("step", len(metrics["eval_steps"])+1))
        if "eval_accuracy" in entry:
            metrics["eval_acc"].append(entry["eval_accuracy"])
            metrics["eval_acc_steps"].append(entry.get("step", len(metrics["eval_acc_steps"])+1))
        if "time_per_step" in entry:
            cumulative_time += np.sum(entry["time_per_step"])
            metrics["cumulative_time"].append(cumulative_time)
        if "grad_norm" in entry:
            v = entry["grad_norm"]
            metrics["grad_norms"].append(v)
    return metrics

def plot_f1_per_size_curves(metrics, save_path="results/f1_per_size_curves.png"):
    """
    metrics: dict[strategy_name, log_history_dict or list]
        Each log_history should contain 'dataset_size' and 'eval_f1' (scalar or list).
    """
    plt.figure(figsize=(8,6), dpi=300)
    sns.set_theme(style="whitegrid")

    data = []
    for name, log_history in metrics.items():
        # If log_history is a list of dicts (from Trainer logs)
        if isinstance(log_history, list):
            eval_f1s = [d.get('eval_f1') for d in log_history if 'eval_f1' in d]
            sizes = [d.get('dataset_size') for d in log_history if 'dataset_size' in d]
            if not sizes and 'dataset_size' in log_history[0]:
                sizes = [log_history[0]['dataset_size']] * len(eval_f1s)
        else:
            eval_f1s = [log_history.get('eval_f1')]
            sizes = [log_history.get('dataset_size')]

        for s, f1 in zip(sizes, eval_f1s):
            strat = name.split('_')[0]
            if s is not None and f1 is not None:
                data.append({"strategy": strat, "size": s, "f1": f1})

    df = pd.DataFrame(data)
    sns.lineplot(data=df, x="size", y="f1", hue="strategy", marker="o", linewidth=2)
    
    plt.xlabel("Dataset Size")
    plt.ylabel("F1 Score")
    plt.title("F1 vs Dataset Size by Strategy")
    plt.legend(title="Strategy")
    plt.tight_layout()
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close()

def plot_efficiency_bubble(metrics, save_path="results/efficiency_bubble.png"):
    """
    Bubble plot: dataset size (x) vs F1 (y), bubble size ∝ trainable parameters.
    Expects metrics = {strategy_name: {'dataset_size': int, 'trainable_params': int, 'eval_f1': float}}.
    """
    plt.figure(figsize=(12, 6), dpi=300)
    sns.set_theme(style="whitegrid", context="talk", palette="Set2")
    for name, log_history in metrics.items():
        plt.scatter(
            x=log_history['dataset_size'],
            y=log_history['eval_f1'],
            alpha=0.7,
            label=name,
            edgecolors='black',
            linewidth=0.5
        )
        plt.text(
            log_history['dataset_size'] * 1.02,
            log_history['eval_f1'],
            f"{log_history['trainable_params'] // 1_000}k",
            fontsize=8
        )

    plt.xlabel("Dataset Size")
    plt.ylabel("Macro F1")
    plt.title("Compute–Performance Trade-off (Bubble ∝ Trainable Parameters)")
    plt.legend(title="Strategy", bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.tight_layout()

    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close()

def plot_loss_curves(all_logs, normalize=True, save_path="results/plots/loss_curves.png"):
    plt.figure(figsize=(20,13), dpi=300)
    sns.set_palette("Paired")
    for name, log_history in all_logs.items():
        m = extract_metrics(log_history)
        x_train = np.array(m["train_steps"])
        x_eval = np.array(m["eval_steps"])
        x_train_norm = x_train / max(x_train)
        x_eval_norm = x_eval / max(x_train)
        sns.lineplot(x=x_train_norm, y=m["train_loss"], label=f"{name} train", linewidth=2)
        sns.lineplot(x=x_eval_norm, y=m["eval_loss"], label=f"{name} eval", linewidth=2, linestyle="--")
    plt.xlabel("Normalized Training Progress" if normalize else "Step")
    plt.ylabel("Loss")
    plt.title("Training and Evaluation Loss")
    plt.legend()
    plt.grid(True)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close()

def plot_accuracy_curves(all_logs, normalize=True, save_path="results/plots/accuracy_curves.png"):
    plt.figure(figsize=(12,8), dpi=300)
    for name, log_history in all_logs.items():
        m = extract_metrics(log_history)
        x_eval = np.array(m["eval_acc_steps"])
        if normalize and len(x_eval) > 0:
            x_eval = x_eval / max(x_eval)
        if len(m["eval_acc"]) > 0:
            plt.plot(x_eval, m["eval_acc"], label=name, linewidth=2)
    plt.xlabel("Normalized Training Progress" if normalize else "Step")
    plt.ylabel("Eval Accuracy")
    plt.title("Evaluation Accuracy")
    plt.legend()
    plt.grid(True)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close()

In [None]:
from typing import Dict, List, Any

sns.set_theme(style="whitegrid")
def build_summary_table(results_data: Dict[str, List[Dict[str,Any]]],
                        trainable_params_map: Dict[str,int]=None) -> pd.DataFrame:
    """
    results_data: dict name -> list of dicts (log_history)
    trainable_params_map: optional mapping name->int used when log_history doesn't contain trainable_params
    """
    rows = []
    for name, history in results_data.items():
        trainable_params = None
        eval_f1 = max([entry['eval_f1'] for entry in history if 'eval_f1' in entry]) 
        total_flos = max([entry['total_flos'] for entry in history if 'total_flos' in entry])
        train_time = max([entry['train_runtime'] for entry in history if 'train_runtime' in entry])
        
        # fallback: try to infer dataset size or strategy from name e.g., 'lora_size25000'
        dataset_size = None
        parts = name.split("_")
        for p in parts:
            if p.startswith("size"):
                try:
                    dataset_size = int(p.replace("size",""))
                except:
                    pass
        # try provided mapping for trainable params
        if (trainable_params is None) and trainable_params_map and name in trainable_params_map:
            trainable_params = int(trainable_params_map[name])

        row = {
            "name": name,
            "strategy": parts[0] if parts else name,
            "dataset_size": dataset_size,
            "trainable_params": trainable_params,
            "eval_f1": eval_f1,
            "total_flos": total_flos,
            "train_time": train_time
        }
        rows.append(row)
    df = pd.DataFrame(rows)
    return df

def compute_f1_petaflop(df: pd.DataFrame) -> pd.DataFrame:
    df = df.copy()
    df["petaflops"] = df["total_flos"] / 1e15
    # avoid divide-by-zero or NaN
    df["f1_per_petaflop"] = df.apply(
        lambda r: (r["eval_f1"] / r["petaflops"]) if pd.notna(r["eval_f1"]) and pd.notna(r["petaflops"]) and r["petaflops"]>0 else None,
        axis=1
    )
    # optionally f1_per_hour if train_time available
    df["f1_per_hour"] = df.apply(
        lambda r: (r["eval_f1"] / (r["train_time"]/3600.0)) if pd.notna(r["eval_f1"]) and pd.notna(r["train_time"]) and r["train_time"]>0 else None,
        axis=1
    )
    df["f1_per_mcompute"] = df.apply(
        lambda r: (r["eval_f1"] / (r["trainable_params"]/ 1e6)) if pd.notna(r["eval_f1"]) and pd.notna(r["trainable_params"]) and r["trainable_params"]>0 else None,
        axis=1
    )
    return df

def plot_f1_per_petaflop(df: pd.DataFrame, out_path=os.path.join('results/plots',"f1_per_petaflop_simple.png")):
    dfp = df.dropna(subset=["f1_per_petaflop"]).copy()
    if dfp.empty:
        print("No data to plot F1 per PetaFLOP.")
        return
    dfp = dfp.sort_values("f1_per_petaflop", ascending=False)
    plt.figure(figsize=(10,5), dpi=150)
    ax = sns.barplot(data=dfp, x="name", y="f1_per_petaflop", hue="strategy")
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
    ax.set_ylabel("Macro F1 per PetaFLOP")
    ax.set_title("Macro F1 per PetaFLOP")
    plt.tight_layout()
    plt.savefig(out_path, bbox_inches="tight")
    plt.close()
    print("Saved", out_path)

def plot_f1_per_mparams(df: pd.DataFrame, out_path=os.path.join('results/plots',"f1_per_mparams.png")):
    dfp = df.dropna(subset=["f1_per_mcompute"]).copy()
    if dfp.empty:
        print("No data to plot F1 per mparams.")
        return
    dfp = dfp.sort_values("f1_per_mcompute", ascending=False)
    plt.figure(figsize=(10,5), dpi=300)
    ax = sns.barplot(data=dfp, x="name", y="f1_per_mcompute", hue="strategy")
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
    ax.set_ylabel("Macro F1 per Million Params")
    ax.set_title("Parameter Effiency")
    plt.tight_layout()
    plt.savefig(out_path, bbox_inches="tight")
    plt.close()
    print("Saved", out_path)

## Produce Plots

In [None]:
all_logs = load_all_log_histories()

# optional map if trainable_params missing in logs
trainable_params_map = {
    "full_size10000": 67700000,
    "full_size25000": 67700000,
    "full_size5000": 67700000,
    "lora_size10000": 739586,
    "lora_size25000": 739586,
    "lora_size5000": 739586
}

df = build_summary_table(all_logs, trainable_params_map)
df = compute_f1_petaflop(df)
print(df[["name","strategy","dataset_size","trainable_params","eval_f1","total_flos","f1_per_petaflop","f1_per_hour","f1_per_mcompute"]])
plot_f1_per_mparams(df)

             name strategy  dataset_size  trainable_params   eval_f1  \
0  full_size10000     full         10000          67700000  0.904591   
1  full_size25000     full         25000          67700000  0.916356   
2   full_size5000     full          5000          67700000  0.894634   
3  lora_size10000     lora         10000            739586  0.882311   
4  lora_size25000     lora         25000            739586  0.896599   
5   lora_size5000     lora          5000            739586  0.868918   

     total_flos  f1_per_petaflop  f1_per_hour  f1_per_mcompute  
0  1.987011e+15         0.455252     1.544346         0.013362  
1  4.967527e+15         0.184469     0.967879         0.013536  
2  9.935055e+14         0.900482     1.933205         0.013215  
3  2.021091e+15         0.436552     1.548938         1.192980  
4  5.052728e+15         0.177449     1.037948         1.212299  
5  1.010546e+15         0.859850     1.850905         1.174871  
Saved results/plots/f1_per_mparams.png


  ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")


In [11]:
results = load_all_metrics_histories('results')
print(results)
# Make F1 vs size plot
os.makedirs(os.path.join('results', "plots"), exist_ok=True)
plot_f1_per_size_curves(results, save_path=os.path.join('results', "plots", "acc_vs_size.png"))

# Make efficiency tradeoff plot
plot_efficiency_bubble(results, save_path=os.path.join('results', "plots", "efficiency_tradeoff.png"))

{'full_size10000': {'strategy': 'full', 'dataset_size': 10000, 'trainable_params': 66955010, 'train_time': 2109.243427991867, 'train_runtime_from_trainer': 2108.6784, 'train_samples': 10000, 'eval_loss': 0.25121867656707764, 'eval_accuracy': 0.9046, 'eval_f1': 0.904591280253631, 'eval_steps_used': 156, 'logging_steps_used': 39}, 'full_size25000': {'strategy': 'full', 'dataset_size': 25000, 'trainable_params': 66955010, 'train_time': 3408.66047000885, 'train_runtime_from_trainer': 3408.3631, 'train_samples': 25000, 'eval_loss': 0.3417150676250458, 'eval_accuracy': 0.91636, 'eval_f1': 0.916356267615916, 'eval_steps_used': 390, 'logging_steps_used': 97}, 'full_size5000': {'strategy': 'full', 'dataset_size': 5000, 'trainable_params': 66955010, 'train_time': 1666.4126119613647, 'train_runtime_from_trainer': 1665.9797, 'train_samples': 5000, 'eval_loss': 0.37619733810424805, 'eval_accuracy': 0.89464, 'eval_f1': 0.8946335235863073, 'eval_steps_used': 78, 'logging_steps_used': 19}, 'lora_size1

  plt.tight_layout()
  plt.savefig(save_path, bbox_inches="tight")


In [5]:
all_logs = load_all_log_histories()
print(all_logs)
plot_loss_curves(all_logs)
plot_accuracy_curves(all_logs)

{'full_size10000': [{'loss': 0.5034, 'grad_norm': 5.185615539550781, 'learning_rate': 4.79765708200213e-05, 'epoch': 0.12460063897763578, 'step': 39}, {'loss': 0.3352, 'grad_norm': 2.747762680053711, 'learning_rate': 4.589989350372737e-05, 'epoch': 0.24920127795527156, 'step': 78}, {'loss': 0.3127, 'grad_norm': 7.535058975219727, 'learning_rate': 4.3823216187433444e-05, 'epoch': 0.3738019169329074, 'step': 117}, {'loss': 0.3382, 'grad_norm': 2.5630247592926025, 'learning_rate': 4.174653887113951e-05, 'epoch': 0.4984025559105431, 'step': 156}, {'eval_loss': 0.28028789162635803, 'eval_accuracy': 0.89412, 'eval_f1': 0.8941009847728258, 'eval_runtime': 200.8735, 'eval_samples_per_second': 124.456, 'eval_steps_per_second': 3.893, 'epoch': 0.4984025559105431, 'step': 156}, {'loss': 0.2866, 'grad_norm': 2.7168471813201904, 'learning_rate': 3.966986155484558e-05, 'epoch': 0.6230031948881789, 'step': 195}, {'loss': 0.2887, 'grad_norm': 3.5869998931884766, 'learning_rate': 3.759318423855165e-05,