In [9]:
from src.viz import *
import json
import numpy as np
import os
import json
import numpy as np
import matplotlib.pyplot as plt

## Plotting Function Definitions

In [10]:
RESULTS_DIR = "results"


def _load_json_per_experiment(base_dir, filename):
    """Load ``filename`` from every sub-directory found below ``base_dir``.

    Args:
        base_dir: Root directory that is walked recursively.
        filename: Name of the JSON file looked up inside each sub-directory.

    Returns:
        dict mapping the sub-directory's own name (not its full path) to the
        parsed JSON content. If two sub-directories at different depths share
        a name, the one visited later overwrites the earlier one.
    """
    loaded = {}
    for root, dirs, _ in os.walk(base_dir):
        # os.walk yields directories in arbitrary (filesystem) order. Sorting
        # in place fixes both the traversal order and the order of the
        # returned dict, so legends/colours are reproducible across machines.
        dirs.sort()
        for d in dirs:
            json_path = os.path.join(root, d, filename)
            if not os.path.exists(json_path):
                continue
            try:
                with open(json_path, "r") as f:
                    loaded[d] = json.load(f)
            except json.JSONDecodeError:
                # Best effort: report the broken file and keep going.
                print(f"Failed to load {json_path}")
    return loaded


def load_all_log_histories(base_dir=RESULTS_DIR):
    """Walk through results folders and load log_history.json for each experiment.

    Returns:
        dict mapping experiment folder name to the parsed log_history.json.
    """
    return _load_json_per_experiment(base_dir, "log_history.json")


def load_all_metrics_histories(base_dir=RESULTS_DIR):
    """Walk through results folders and load metrics.json for each experiment.

    Returns:
        dict mapping experiment folder name to the parsed metrics.json.
    """
    return _load_json_per_experiment(base_dir, "metrics.json")

def extract_metrics(log_history):
    """Split a list of log entries into per-metric series.

    Each entry of ``log_history`` is inspected for the keys ``loss``,
    ``eval_loss``, ``eval_accuracy``, ``time_per_step`` and ``grad_norm``.

    Returns a dict with these lists:
        train_loss / train_steps       -- from entries carrying ``loss``
        eval_loss / eval_steps         -- from entries carrying ``eval_loss``
        eval_acc / eval_acc_steps      -- from entries carrying ``eval_accuracy``
        cumulative_time                -- running total of ``time_per_step``
        grad_norms                     -- ``grad_norm`` values in log order

    When an entry has no ``step`` key, the 1-based position within the
    corresponding step list is used instead.
    """
    # (key in the log entry, name of the value list, name of the step list)
    stepped_series = (
        ("loss", "train_loss", "train_steps"),
        ("eval_loss", "eval_loss", "eval_steps"),
        ("eval_accuracy", "eval_acc", "eval_acc_steps"),
    )

    collected = {}
    for _, values_key, steps_key in stepped_series:
        collected[values_key] = []
        collected[steps_key] = []
    collected["cumulative_time"] = []
    collected["grad_norms"] = []  # flat list: one grad_norm value per logged entry

    elapsed = 0.0
    for record in log_history:
        for source_key, values_key, steps_key in stepped_series:
            if source_key not in record:
                continue
            collected[values_key].append(record[source_key])
            fallback_step = len(collected[steps_key]) + 1
            collected[steps_key].append(record.get("step", fallback_step))

        if "time_per_step" in record:
            # time_per_step may be a scalar or a sequence; np.sum handles both.
            elapsed = elapsed + np.sum(record["time_per_step"])
            collected["cumulative_time"].append(elapsed)

        if "grad_norm" in record:
            collected["grad_norms"].append(record["grad_norm"])

    return collected

def plot_f1_per_size_curves(metrics, save_path="results/f1_per_size_curves.png"):
    """Plot F1 against dataset size, one line per strategy, and save it.

    Args:
        metrics: dict[run_name, log_history] where log_history is either a
            single dict or a list of dicts. Entries are expected to carry
            'dataset_size' and 'eval_f1'. The strategy label is the part of
            ``run_name`` before the first underscore.
        save_path: Output image path; its parent directory is created if
            it does not exist.

    Raises:
        ValueError: if no (dataset_size, eval_f1) pair could be extracted.
    """
    records = []
    for name, log_history in metrics.items():
        strategy = name.split('_')[0]
        if isinstance(log_history, list):
            eval_f1s = [d.get('eval_f1') for d in log_history if 'eval_f1' in d]
            sizes = [d.get('dataset_size') for d in log_history if 'dataset_size' in d]
            # A run usually reports its dataset size once but may log several
            # evaluations; reuse that single size for every F1 value.
            if len(sizes) == 1 and len(eval_f1s) > 1:
                sizes = sizes * len(eval_f1s)
        else:
            eval_f1s = [log_history.get('eval_f1')]
            sizes = [log_history.get('dataset_size')]

        for size, f1 in zip(sizes, eval_f1s):
            if size is not None and f1 is not None:
                records.append({"strategy": strategy, "size": size, "f1": f1})

    if not records:
        # Fail before touching matplotlib so no figure is left open and the
        # error names the real problem instead of a missing column.
        raise ValueError("No entries with both 'dataset_size' and 'eval_f1' found in metrics")

    # Theme first, so the figure and axes are created under it.
    sns.set_theme(style="whitegrid")
    fig, ax = plt.subplots(figsize=(8, 6), dpi=300)

    df = pd.DataFrame(records)
    sns.lineplot(data=df, x="size", y="f1", hue="strategy", marker="o", linewidth=2, ax=ax)

    ax.set_xlabel("Dataset Size")
    ax.set_ylabel("F1 Score")
    ax.set_title("F1 vs Dataset Size by Strategy")
    ax.legend(title="Strategy")
    fig.tight_layout()

    # dirname is '' for a bare file name, and os.makedirs('') raises.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)

def plot_efficiency_bubble(metrics, save_path="results/efficiency_bubble.png"):
    """Bubble plot: dataset size (x) vs F1 (y), bubble area scaled by trainable parameters.

    Args:
        metrics: {run_name: {'dataset_size': int, 'trainable_params': int,
            'eval_f1': float}}. Runs that are not dicts or lack one of these
            keys are reported and skipped.
        save_path: Output image path; its parent directory is created if
            it does not exist.

    Bubble area is proportional to ``trainable_params`` relative to the
    largest run, with a lower bound so very small models stay visible.
    """
    required_keys = ("dataset_size", "eval_f1", "trainable_params")
    min_area, max_area = 40.0, 1200.0  # marker areas in points^2

    runs = {}
    for name, run_metrics in metrics.items():
        if isinstance(run_metrics, dict) and all(run_metrics.get(k) is not None for k in required_keys):
            runs[name] = run_metrics
        else:
            print(f"Skipping {name}: missing one of {required_keys}")

    largest = max((m["trainable_params"] for m in runs.values()), default=0)

    # Theme (and its palette) must be active before the axes are created.
    sns.set_theme(style="whitegrid", context="talk", palette="Set2")
    fig, ax = plt.subplots(figsize=(12, 6), dpi=300)

    for name, run_metrics in runs.items():
        params = run_metrics["trainable_params"]
        if largest > 0:
            area = max(min_area, max_area * params / largest)
        else:
            area = min_area
        ax.scatter(
            x=run_metrics["dataset_size"],
            y=run_metrics["eval_f1"],
            s=area,  # previously omitted, so every bubble had the same size
            alpha=0.7,
            label=name,
            edgecolors='black',
            linewidth=0.5
        )
        # Print the parameter count (in thousands) just right of the bubble.
        ax.text(
            run_metrics["dataset_size"] * 1.02,
            run_metrics["eval_f1"],
            f"{params // 1_000}k",
            fontsize=8
        )

    ax.set_xlabel("Dataset Size")
    ax.set_ylabel("Macro F1")
    ax.set_title("Compute–Performance Trade-off (Bubble ∝ Trainable Parameters)")
    ax.legend(title="Strategy", bbox_to_anchor=(1.05, 1), loc="upper left")
    fig.tight_layout()

    # dirname is '' for a bare file name, and os.makedirs('') raises.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)

def plot_loss_curves(all_logs, normalize=True, save_path="results/plots/loss_curves.png"):
    """Plot training (solid) and evaluation (dashed) loss for every run.

    Args:
        all_logs: dict[run_name, log_history] as accepted by ``extract_metrics``.
        normalize: If True, steps are divided by the run's last training step
            so runs of different length share one 0..1 axis. If False, raw
            step numbers are plotted.
        save_path: Output image path; its parent directory is created if
            it does not exist.
    """
    # The palette has to be set before the axes exist, because the colour
    # cycle is fixed when the axes are created.
    sns.set_palette("Paired")
    fig, ax = plt.subplots(figsize=(20, 13), dpi=300)

    for name, log_history in all_logs.items():
        m = extract_metrics(log_history)
        x_train = np.array(m["train_steps"])
        x_eval = np.array(m["eval_steps"])

        if normalize:
            # Both series use the training length as reference so they stay
            # aligned; fall back to the eval steps if no train loss was logged.
            if x_train.size:
                last_step = x_train.max()
            elif x_eval.size:
                last_step = x_eval.max()
            else:
                last_step = 0
            if last_step > 0:
                x_train = x_train / last_step
                x_eval = x_eval / last_step

        if x_train.size:
            sns.lineplot(x=x_train, y=m["train_loss"], label=f"{name} train", linewidth=2, ax=ax)
        if x_eval.size:
            sns.lineplot(x=x_eval, y=m["eval_loss"], label=f"{name} eval", linewidth=2, linestyle="--", ax=ax)

    ax.set_xlabel("Normalized Training Progress" if normalize else "Step")
    ax.set_ylabel("Loss")
    ax.set_title("Training and Evaluation Loss")
    # Only build a legend when something was drawn (avoids a matplotlib warning).
    if ax.get_legend_handles_labels()[0]:
        ax.legend()
    ax.grid(True)

    # dirname is '' for a bare file name, and os.makedirs('') raises.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)

def plot_accuracy_curves(all_logs, normalize=True, save_path="results/plots/accuracy_curves.png"):
    """Plot evaluation accuracy over training for every run.

    Args:
        all_logs: dict[run_name, log_history] as accepted by ``extract_metrics``.
        normalize: If True, steps are divided by the run's last evaluation
            step so all runs share a 0..1 axis; otherwise raw steps are used.
        save_path: Output image path; its parent directory is created if
            it does not exist.
    """
    fig, ax = plt.subplots(figsize=(12, 8), dpi=300)

    for name, log_history in all_logs.items():
        m = extract_metrics(log_history)
        if len(m["eval_acc"]) == 0:
            continue  # run never logged eval_accuracy
        x_eval = np.array(m["eval_acc_steps"])
        if normalize:
            last_step = x_eval.max()
            # Guard against a zero last step, which would divide by zero.
            if last_step > 0:
                x_eval = x_eval / last_step
        ax.plot(x_eval, m["eval_acc"], label=name, linewidth=2)

    ax.set_xlabel("Normalized Training Progress" if normalize else "Step")
    ax.set_ylabel("Eval Accuracy")
    ax.set_title("Evaluation Accuracy")
    # Only build a legend when something was drawn (avoids a matplotlib warning).
    if ax.get_legend_handles_labels()[0]:
        ax.legend()
    ax.grid(True)

    # dirname is '' for a bare file name, and os.makedirs('') raises.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)

def plot_cumulative_time(all_logs, normalize=True, save_path="results/time_curves.png"):
    """Plot cumulative training time over training for every run.

    Args:
        all_logs: dict[run_name, log_history] as accepted by ``extract_metrics``.
        normalize: If True, the x axis is divided by its last value so all
            runs share a 0..1 axis; otherwise raw values are used.
        save_path: Output image path; its parent directory is created if
            it does not exist.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    for name, log_history in all_logs.items():
        m = extract_metrics(log_history)
        y_time = np.array(m["cumulative_time"])
        if y_time.size == 0:
            continue  # run logged no time_per_step

        x_steps = np.array(m["train_steps"])
        # train_steps comes from entries with "loss" and cumulative_time from
        # entries with "time_per_step"; they need not be the same entries.
        # When the lengths disagree, index the time series by its own order
        # instead of letting matplotlib raise on mismatched shapes.
        if x_steps.size != y_time.size:
            x_steps = np.arange(1, y_time.size + 1)

        if normalize:
            last_step = x_steps.max()
            if last_step > 0:
                x_steps = x_steps / last_step
        ax.plot(x_steps, y_time, label=name, linewidth=2)

    ax.set_xlabel("Normalized Training Progress" if normalize else "Step")
    ax.set_ylabel("Cumulative Training Time (s)")
    ax.set_title("Cumulative Training Time")
    # Only build a legend when something was drawn (avoids a matplotlib warning).
    if ax.get_legend_handles_labels()[0]:
        ax.legend()
    ax.grid(True)

    # dirname is '' for a bare file name, and os.makedirs('') raises.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)


## Produce Plots

In [11]:
# Load the final metrics.json of every experiment below RESULTS_DIR.
results = load_all_metrics_histories(RESULTS_DIR)
print(results)

# Every figure of this cell goes below RESULTS_DIR/plots; build the path once
# instead of repeating the 'results' literal for each call.
plots_dir = os.path.join(RESULTS_DIR, "plots")
os.makedirs(plots_dir, exist_ok=True)

# Make F1 vs size plot
# (the file name says "acc" although the figure shows F1; kept so existing
# references to the file keep working)
plot_f1_per_size_curves(results, save_path=os.path.join(plots_dir, "acc_vs_size.png"))

# Make efficiency tradeoff plot
plot_efficiency_bubble(results, save_path=os.path.join(plots_dir, "efficiency_tradeoff.png"))

{'full_size10000': {'strategy': 'full', 'dataset_size': 10000, 'trainable_params': 66955010, 'train_time': 2109.243427991867, 'train_runtime_from_trainer': 2108.6784, 'train_samples': 10000, 'eval_loss': 0.25121867656707764, 'eval_accuracy': 0.9046, 'eval_f1': 0.904591280253631, 'eval_steps_used': 156, 'logging_steps_used': 39}, 'full_size25000': {'strategy': 'full', 'dataset_size': 25000, 'trainable_params': 66955010, 'train_time': 3408.66047000885, 'train_runtime_from_trainer': 3408.3631, 'train_samples': 25000, 'eval_loss': 0.3417150676250458, 'eval_accuracy': 0.91636, 'eval_f1': 0.916356267615916, 'eval_steps_used': 390, 'logging_steps_used': 97}, 'full_size5000': {'strategy': 'full', 'dataset_size': 5000, 'trainable_params': 66955010, 'train_time': 1666.4126119613647, 'train_runtime_from_trainer': 1665.9797, 'train_samples': 5000, 'eval_loss': 0.37619733810424805, 'eval_accuracy': 0.89464, 'eval_f1': 0.8946335235863073, 'eval_steps_used': 78, 'logging_steps_used': 19}, 'lora_size1

  plt.tight_layout()
  plt.savefig(save_path, bbox_inches="tight")


In [12]:
# Load log_history.json for every experiment folder below the default base_dir (RESULTS_DIR).
all_logs = load_all_log_histories()
# Both calls use their default save_path, i.e. files under results/plots/.
plot_loss_curves(all_logs)
plot_accuracy_curves(all_logs)
# Disabled calls, kept for reference:
# plot_cumulative_time(all_logs)
# NOTE(review): plot_gradients is not defined in the visible part of this notebook —
# confirm it exists (e.g. via src.viz) before re-enabling.
# plot_gradients(all_logs)