In [1]:
import re
import os
from collections import defaultdict

import wandb
import numpy as np
from omegaconf import OmegaConf as om

In [2]:
# W&B API client used below to query runs.
api = wandb.Api()
# Project path handed to `api.runs(...)` when fetching the runs to analyze.
project_name = "mosaic-ml/mpt-over-sampling"

In [3]:
# Run-group selection. Keys are W&B run group names; only runs whose group is a
# key here are aggregated. Values are the short labels printed in the first
# column of the results table.

# Online vs offline comparisons
groups = {
    "mpt1.5-3b-params-52b-tokens-2.5x-epoch-raw-proportion": "raw-mix-10T-setting",
    "mpt1.5-3b-params-52b-tokens-2.5x-epoch-mpt1.5-proportion": "mpt1.5-mix-10T-setting",
    "mpt1.5-3b-params-52b-tokens-1.4x-epoch-raw-proportion": "raw-mix-5T-setting",
}

# Alternative selection (disabled): 2.5 vs 1.4 epoch comparisons.
# Uncomment to replace the mapping above.
# groups = {
#     "mpt1.5-3b-params-52b-tokens-1.4x-epoch-raw-proportion": "raw-mix-5T-setting",
#     "mpt1.5-3b-params-52b-tokens-2.5x-epoch-raw-proportion": "raw-mix-10T-setting",
# }

In [4]:
# When True, metrics are read from "normalized_model_gauntlet" instead of
# "model_gauntlet" (see the metric-name construction below).
normalized = True
# Filter passed to `api.runs`; left empty here, so run selection is done
# client-side in the collection loop (by run name and group).
filters = {}

In [5]:
# Fetch the project's runs using the filter above; iterated over in the collection loop.
runs = api.runs(project_name, filters=filters)

In [6]:
# Gauntlet categories to report, in the column order of the printed table.
categories = [
    "category-average",
    "world_knowledge",
    "commonsense_reasoning",
    "language_understanding",
    "symbolic_problem_solving",
    "reading_comprehension",
]

# Metric keys looked up in each run's history. These are logged metric names,
# not filesystem paths, so they are joined with an explicit "/" rather than
# os.path.join (which would use "\\" on Windows and never match).
gauntlet_prefix = "normalized_model_gauntlet" if normalized else "model_gauntlet"
metric_names = ["/".join(("metrics", gauntlet_prefix, category)) for category in categories]

In [7]:
# Collect, for every selected group, the final-step gauntlet metrics of each
# matching eval run. Result shape: {group name: [[metric, ...], ...]} with one
# inner list per run, ordered like `metric_names`.
raw_group_results = defaultdict(list)
for run in runs:
    group = run.group
    name = run.name

    # Only eval runs that belong to one of the selected groups are considered.
    if "eval" not in name or group not in groups:
        continue

    # Runs whose name contains "maxified" are excluded from the comparison.
    if "maxified" in name:
        continue

    # Skip runs named as normalized when raw (non-normalized) metrics are requested.
    if "normalized" in name and not normalized:
        continue

    try:
        history = run.history(pandas=False)
        # The last history row is treated as the final evaluation step.
        final_step = history[-1]
        metrics = [final_step[metric_name] for metric_name in metric_names]
    except Exception as exc:
        # Best-effort: report the run and the cause, then move on. Catching
        # Exception (not a bare except) lets KeyboardInterrupt/SystemExit
        # still stop the cell.
        print(f"Failed for run: {name} ({type(exc).__name__}: {exc})")
    else:
        raw_group_results[group].append(metrics)

In [8]:
# Average each metric column over all runs collected for a group, giving one
# mean value per metric per group.
group_results = {}
for group_name, per_run_metrics in raw_group_results.items():
    group_results[group_name] = np.mean(per_run_metrics, axis=0)

In [9]:
# Print one markdown table row per group: the group's short label followed by
# each averaged metric formatted to four decimal places.
for group_name, mean_metrics in group_results.items():
    cells = " | ".join(f"{value:.4f}" for value in mean_metrics.tolist())
    print(f"| {groups[group_name]} | {cells} |")

| raw-mix-5T-setting | 0.2097 | 0.2126 | 0.2307 | 0.2397 | 0.0937 | 0.2717 |
| mpt1.5-mix-10T-setting | 0.1666 | 0.1715 | 0.1974 | 0.1947 | 0.1004 | 0.1688 |
| raw-mix-10T-setting | 0.1982 | 0.1957 | 0.2156 | 0.2462 | 0.1060 | 0.2277 |
