disclosure: lots of AI-generated boilerplate code scattered throughout

# baseline results

In [None]:
import json
import numpy as np
    
def evaluate_routing(small_file, large_file, small_proportion, n_trials=10, seed=42):
    """Score random routing between a small and a large model.

    Both files are JSONL with one example per line; each example carries a
    ``results`` list of generations, and each generation has ``test_results``
    (pairs whose first element is ``'pass'`` on success) and optionally a
    truthy ``code_error``.

    For every trial a random ``small_proportion`` of the examples is served
    from ``small_file`` and the rest from ``large_file``.

    Args:
        small_file: Path to the small model's outputs.
        large_file: Path to the large model's outputs (same example order).
        small_proportion: Fraction of examples routed to the small model.
        n_trials: Number of random routings to draw.
        seed: Seed for the NumPy random generator.

    Returns:
        ``(pass_rates, perfect_rates)``: two lists with one percentage per
        trial. ``pass_rates`` is passed test cases over executed test cases
        (generations with a code error are skipped). ``perfect_rates`` is
        generations passing every test over *all* generations, so code
        errors count as non-perfect.
    """
    with open(small_file) as f:
        small_outputs = [json.loads(line) for line in f]
    with open(large_file) as f:
        large_outputs = [json.loads(line) for line in f]

    n_examples = len(small_outputs)
    n_to_small = int(small_proportion * n_examples)

    pass_rates = []
    perfect_rates = []
    rng = np.random.default_rng(seed)

    for _ in range(n_trials):
        small_set = set(rng.choice(n_examples, n_to_small, replace=False))

        total_test_cases = 0
        passed_test_cases = 0
        total_solutions = 0
        perfect_solutions = 0

        for i in range(n_examples):
            source = small_outputs if i in small_set else large_outputs
            results = source[i]['results']
            # Count every generation here instead of relying on the length of
            # whichever `results` list happened to be visited last.
            total_solutions += len(results)

            for r in results:
                if r.get('code_error'):
                    continue

                outcomes = [outcome for outcome, _detail in r['test_results']]
                n_passed = sum(outcome == 'pass' for outcome in outcomes)

                # individual cases
                total_test_cases += len(outcomes)
                passed_test_cases += n_passed

                # perfect solutions
                perfect_solutions += n_passed == len(outcomes)

        # Guard the denominators so empty inputs yield 0% instead of crashing.
        pass_rates.append(
            passed_test_cases / total_test_cases * 100 if total_test_cases else 0.0
        )
        perfect_rates.append(
            perfect_solutions / total_solutions * 100 if total_solutions else 0.0
        )

    return pass_rates, perfect_rates

# Interior routing proportions; the sweep below also adds the 0% and 100% endpoints.
props = [0.2, 0.4, 0.6, 0.8]

# (small model outputs, large model outputs) for every pairing under study.
model_pairs = [
    (f'outputs/{small_tag}_test_outputs.jsonl', f'outputs/{large_tag}_test_outputs.jsonl')
    for small_tag, large_tag in (('1b', '8b'), ('8b', '70b'), ('1b', '70b'))
]

for small_file, large_file in model_pairs:
    print(f"\n{small_file.rsplit('/', 1)[-1]} vs {large_file.rsplit('/', 1)[-1]}")
    for prop in [0., *props, 1.]:
        pass_rates, perfect_rates = evaluate_routing(small_file, large_file, prop)
        pass_mean, pass_std = np.mean(pass_rates), np.std(pass_rates)
        perfect_mean, perfect_std = np.mean(perfect_rates), np.std(perfect_rates)
        print(f"\n{prop*100}% routed to small model:")
        print(f"Pass rate: {pass_mean:.1f}% ± {pass_std:.1f}%")
        print(f"Perfect rate: {perfect_mean:.1f}% ± {perfect_std:.1f}%")

# compute soft labels from paired data

In [34]:
import json
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
from generate import format_prompt
from matplotlib.gridspec import GridSpec

def analyze_model_comparison(small, large, output_file, *, brevity=False):
    """Compare paired generations, write soft labels, and plot a win histogram.

    ``small`` and ``large`` are JSONL files read in lockstep. Each line
    provides ``item``, a ``results`` list (entries with ``pass`` and an
    optional ``code_error``) and a parallel ``num_tokens`` list. For each
    prompt the number of generation pairs won by the small model is counted,
    and ``{'prompt', 'target'}`` is written to ``output_file`` with
    ``target = wins / number_of_compared_pairs``.

    Args:
        small: Path to the small model's outputs.
        large: Path to the large model's outputs.
        output_file: Destination JSONL for the soft labels.
        brevity: When False (default) a win means strictly more passed tests.
            When True, code errors lose outright and ties on passed tests are
            broken in favour of the generation with fewer (or equal) tokens.

    Returns:
        ``(fig, win_freq, reason_freq)``: the matplotlib figure, a mapping
        from wins-per-prompt to prompt count, and a mapping from the deciding
        reason to pair count.
    """
    win_freq = defaultdict(int)
    reason_freq = defaultdict(int)

    def get_win_brevity(gen1, gen2, tok1, tok2):
        failed1 = bool(gen1.get('code_error'))
        failed2 = bool(gen2.get('code_error'))
        if failed1 or failed2:
            # A generation that failed to run cannot win; if both failed the
            # small model is simply not credited with a win.
            return 'code_error', not failed1
        if gen1['pass'] != gen2['pass']:
            return 'test_cases', gen1['pass'] > gen2['pass']
        return 'tokens', tok1 <= tok2

    def get_win_standard(gen1, gen2, tok1, tok2):
        return 'test_cases', gen1['pass'] > gen2['pass']

    get_win = get_win_brevity if brevity else get_win_standard

    with open(small) as f1, open(large) as f2, open(output_file, 'w') as f_out:
        for line1, line2 in zip(f1, f2):
            o1, o2 = json.loads(line1), json.loads(line2)
            prompt = format_prompt(o1['item'])

            wins = 0
            n_pairs = 0
            for pair in zip(o1['results'], o2['results'], o1['num_tokens'], o2['num_tokens']):
                reason, is_win = get_win(*pair)
                wins += is_win
                n_pairs += 1
                reason_freq[reason] += 1

            win_freq[wins] += 1
            # Normalise by the pairs actually compared rather than assuming 10.
            target = wins / n_pairs if n_pairs else 0.0
            json.dump({'prompt': prompt, 'target': target}, f_out)
            f_out.write('\n')

    plt.style.use('seaborn-v0_8-paper')
    sns.set_palette("muted")

    fig = plt.figure(figsize=(8, 5))
    gs = GridSpec(1, 1, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])

    bar_color = "#4878D0"
    edge_color = "#2F4858"

    # With no input lines there is nothing to unpack or draw.
    if win_freq:
        k, v = zip(*sorted(win_freq.items(), key=lambda x: x[0]))
        sns.barplot(x=list(k), y=list(v), ax=ax1, color=bar_color, edgecolor=edge_color)
    ax1.set_xlabel('Number of Wins (out of 10)', fontsize=10)
    ax1.set_ylabel('Frequency', fontsize=10)
    ax1.grid(True, axis='y', linestyle='--', alpha=0.7)
    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)

    plt.tight_layout()
    return fig, win_freq, reason_freq


with brevity quality function

In [None]:
# Produce the soft-label file and win histogram for each small/large pairing.
# NOTE(review): confirm which win comparator was active when these files were
# generated; the call below does not select one explicitly.
for small_tag, large_tag in (('1b', '70b'), ('8b', '70b'), ('1b', '8b')):
    fig, wins, reasons = analyze_model_comparison(
        f'outputs/{small_tag}_train_outputs.jsonl',
        f'outputs/{large_tag}_train_outputs.jsonl',
        f'outputs/{small_tag}_{large_tag}_wins.jsonl',
    )
    plt.savefig(f'figures/{small_tag}_{large_tag}_analysis.png')

with standard quality function

In [None]:
# Same three pairings, reading from and writing to data/ with the "_simpl" suffix.
for small_tag, large_tag in (('1b', '70b'), ('8b', '70b'), ('1b', '8b')):
    fig, wins, reasons = analyze_model_comparison(
        f'data/{small_tag}_train_outputs.jsonl',
        f'data/{large_tag}_train_outputs.jsonl',
        f'data/{small_tag}_{large_tag}_wins_simpl.jsonl',
    )
    plt.savefig(f'figures/{small_tag}_{large_tag}_analysis_simpl.png')

## compute winrates (with data augmentation)

challenge: our scores aren't continuous and it's hard to grid search over the discrete space we've created

solution: introduce a heuristic code quality score

In [67]:
def quality(gen, num_tokens, max_tokens=512, num_tests=3):
    """Return a heuristic quality score for one generation.

    The score blends correctness and brevity:
    ``0.8 * (passed / num_tests) + 0.2 * (1 - num_tokens / max_tokens)``.
    A generation with a truthy ``code_error`` scores 0.0 regardless of length.

    Args:
        gen: Generation record; reads ``code_error`` (optional) and ``pass``.
        num_tokens: Number of tokens in the generation.
        max_tokens: Token budget used to normalise the brevity term.
        num_tests: Number of test cases per problem (previously fixed at 3).

    Returns:
        The score as a float. The brevity term is not clamped, so it goes
        negative if ``num_tokens`` exceeds ``max_tokens``.
    """
    if gen.get('code_error'):
        return 0.0
    test_score = gen['pass'] / num_tests
    token_score = 1 - (num_tokens / max_tokens)
    return (0.8 * test_score) + (0.2 * token_score)

In [None]:
import json
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from generate import format_prompt
from matplotlib.gridspec import GridSpec
from collections import Counter

def find_optimal_t(quality_gaps_by_prompt):
    """Choose the tolerance ``t`` that makes per-prompt soft labels most spread out.

    For a candidate ``t`` each prompt's label is the fraction of its quality
    gaps that are ``>= -t``. The spread of a candidate is the sum of absolute
    differences over all unordered pairs of prompt labels. 100 evenly spaced
    candidates in [0, 1] are tried and the first one with maximal spread is
    returned.
    """
    candidates = np.linspace(0, 1, 100)

    def soft_labels(t):
        return [np.mean([gap >= -t for gap in gaps]) for gaps in quality_gaps_by_prompt]

    def spread(t):
        labels = soft_labels(t)
        return sum(
            abs(first - second)
            for idx, first in enumerate(labels)
            for second in labels[idx + 1:]
        )

    best_index = np.argmax([spread(t) for t in candidates])
    return candidates[best_index]

def analyze_model_comparison(small, large, output_file, title='', augmented=False):
    """Score paired generations with ``quality``, write soft labels, plot wins.

    ``small`` and ``large`` are JSONL files read in lockstep; each line
    provides ``item``, ``results`` and ``num_tokens``. Every generation pair
    is scored with ``quality`` and the small model wins a pair when its score
    is at least the large model's.

    Args:
        small: Path to the small model's outputs.
        large: Path to the large model's outputs.
        output_file: Destination JSONL of ``{'prompt', 'target'}`` records.
        title: Figure title.
        augmented: When False, ``target`` is ``wins / len(results)``. When
            True, a tolerance ``t`` is chosen with ``find_optimal_t`` and
            ``target`` becomes the fraction of pairs whose quality gap is
            ``>= -t``; the histogram then counts wins under that tolerance.

    Returns:
        The matplotlib figure holding the wins-per-prompt histogram.
    """
    bar_color = "#4878D0"
    edge_color = "#2F4858"

    prompts = []
    targets = []
    wins_per_prompt = []
    quality_gaps_by_prompt = []

    with open(small) as f1, open(large) as f2:
        for line1, line2 in zip(f1, f2):
            o1, o2 = json.loads(line1), json.loads(line2)

            wins = 0
            prompt_gaps = []
            for gen1, gen2, tok1, tok2 in zip(o1['results'], o2['results'], o1['num_tokens'], o2['num_tokens']):
                score1 = quality(gen1, tok1)
                score2 = quality(gen2, tok2)
                wins += (score1 >= score2)
                prompt_gaps.append(score1 - score2)

            prompts.append(format_prompt(o1['item']))
            targets.append(wins / len(o1['results']))
            wins_per_prompt.append(wins)
            quality_gaps_by_prompt.append(prompt_gaps)

    figure_title = title
    if augmented:
        figure_title = title + ', augmented'
        optimal_t = find_optimal_t(quality_gaps_by_prompt)
        targets = [np.mean([gap >= -optimal_t for gap in gaps]) for gaps in quality_gaps_by_prompt]
        # Histogram integer win counts under the tolerance. The previous
        # `10 * labels` repeated the Python list ten times instead of scaling it.
        wins_per_prompt = [int(sum(gap >= -optimal_t for gap in gaps)) for gaps in quality_gaps_by_prompt]

    # Write the label file once, after the final targets are known.
    with open(output_file, 'w') as f_out:
        for prompt, target in zip(prompts, targets):
            json.dump({'prompt': prompt, 'target': target}, f_out)
            f_out.write('\n')

    plt.style.use('seaborn-v0_8-paper')
    sns.set_palette("muted")

    fig = plt.figure(figsize=(8, 5))
    gs = GridSpec(1, 1, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    fig.suptitle(figure_title, fontsize=12)

    win_freq = Counter(wins_per_prompt)
    # With no input lines there is nothing to unpack or draw.
    if win_freq:
        k, v = zip(*sorted(win_freq.items(), key=lambda x: x[0]))
        sns.barplot(x=list(k), y=list(v), ax=ax1, color=bar_color, edgecolor=edge_color)
    ax1.set_xlabel('Number of Wins (out of 10)', fontsize=10)
    ax1.set_ylabel('Frequency', fontsize=10)
    ax1.grid(True, axis='y', linestyle='--', alpha=0.7)
    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)

    plt.tight_layout()
    return fig

# (small tag, large tag, figure title) for each pairing, in processing order.
comparisons = [
    ('1b', '8b', '1B vs. 8B'),
    ('8b', '70b', '8B vs. 70B'),
    ('1b', '70b', '1B vs. 70B'),
]

# First pass writes the plain quality labels, second pass the augmented ones.
for augmented, suffix in ((False, '_quality'), (True, '_quality_augmented')):
    for small_tag, large_tag, title in comparisons:
        small_file = f'{small_tag}_train_outputs.jsonl'
        large_file = f'{large_tag}_train_outputs.jsonl'
        name = f'{small_tag}_{large_tag}'
        fig = analyze_model_comparison(
            f'data/{small_file}',
            f'data/{large_file}',
            f'data/{name}{suffix}.jsonl',
            title=title,
            augmented=augmented,
        )
        plt.savefig(f'figures/{name}{suffix}.png')

## validation set threshold tuning

In [None]:
import torch
from router import Router, RouterDataset
from tqdm import tqdm
from torch.utils.data import DataLoader

# Prefer Apple-silicon MPS (the original setting), then CUDA, and fall back to
# CPU so the cells below still run on machines without an accelerator.
if getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

def load_model(from_checkpoint):
    """Build a ``Router`` and load its head from a checkpoint's EMA weights.

    Only the four head tensors (weight and bias of ``head[0]`` and
    ``head[3]``) are copied, from the ``ema_model.head.*`` entries under the
    checkpoint's ``'ema'`` key. The model is returned in eval mode on
    ``device``.
    """
    model = Router(hidden_size=384)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ema_state = torch.load(from_checkpoint, map_location='cpu')['ema']

    for layer_idx in (0, 3):
        layer = model.head[layer_idx]
        for attr in ('weight', 'bias'):
            source = ema_state[f'ema_model.head.{layer_idx}.{attr}']
            getattr(layer, attr).data.copy_(source)

    model = model.eval()
    return model.to(device)

@torch.no_grad()
def get_preds(from_checkpoint, split):
    """Return the router's sigmoid scores for every prompt in a split.

    Prompts are read from ``data/{split}_prompts.jsonl`` in file order
    (batches of 8, no shuffling) and scored by the checkpointed router.
    The scores are returned as a plain Python list.
    """
    prompts = RouterDataset(input=f'data/{split}_prompts.jsonl')
    batches = DataLoader(prompts, batch_size=8, shuffle=False)
    router = load_model(from_checkpoint)

    scores = []
    for batch in tqdm(batches):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        logits = router.forward(input_ids, attention_mask)
        # NOTE(review): assumes logits hold one value per prompt; confirm against Router.
        scores += torch.sigmoid(logits).tolist()
    return scores

import json
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def analyze_thresholds(from_checkpoint, small_outputs, large_outputs, title=''):
    """Plot validation router scores and report routing at percentile thresholds.

    The router scores for the ``val`` split are histogrammed, with the 20th,
    40th, 60th and 80th percentiles drawn as vertical lines. For each of those
    thresholds two figures are printed: the change in pass rate relative to
    always using the large model (``perf_gap``) and the fraction of prompts
    sent to the small model (``cost_adv``). Only the first generation of each
    example is used, and pass counts are normalised by 3 tests per example.

    Returns:
        ``(figure, preds)``: the current matplotlib figure and the score array.
    """
    def first_generation_passes(path):
        with open(path, 'r') as f:
            return np.array([json.loads(line)['results'][0]['pass'] for line in f])

    preds = np.array(get_preds(from_checkpoint, split='val'))
    small_results = first_generation_passes(small_outputs)
    large_results = first_generation_passes(large_outputs)

    max_passes = 3 * len(large_results)
    large_pass = np.sum(large_results) / max_passes

    percentiles = [20, 40, 60, 80]
    percentile_thresholds = np.percentile(preds, percentiles)

    plt.style.use('seaborn-v0_8-paper')
    sns.set_palette("muted")
    plt.figure(figsize=(10, 6))
    plt.hist(preds, bins=30, color="#4878D0", alpha=0.6, edgecolor='white')
    for cutoff in percentile_thresholds:
        plt.axvline(x=cutoff, color="#2F4858", linestyle='--', alpha=0.5)
    plt.title(title)
    plt.xlabel('Model Confidence Score', fontsize=10)
    plt.ylabel('Frequency', fontsize=10)
    plt.grid(True, axis='y', linestyle='--', alpha=0.7)
    for side in ['top', 'right']:
        plt.gca().spines[side].set_visible(False)
    plt.tight_layout()

    for percentile, cutoff in zip(percentiles, percentile_thresholds):
        # NOTE(review): prompts scoring *below* the cutoff go to the small
        # model; confirm this matches the polarity of the router's target.
        route_small = preds < cutoff
        routed_passes = np.where(route_small, small_results, large_results)
        mixed_pass = np.sum(routed_passes) / max_passes
        cost_adv = np.sum(route_small) / len(preds)
        perf_gap = mixed_pass - large_pass
        print(f"{percentile}th percentile:")
        print(f"  t: {cutoff:.6f}")
        print(f"  perf_gap: {perf_gap:.3f}")
        print(f"  cost_adv: {cost_adv:.3f}")

    return plt.gcf(), preds

with brevity quality function

In [None]:
# Tune thresholds on the validation split for each "aug=3" router checkpoint.
for small_tag, large_tag in (('1b', '8b'), ('8b', '70b'), ('1b', '70b')):
    fig, preds = analyze_thresholds(
        f'checkpoints/{small_tag}_{large_tag}_aug=3_epoch=10.pt',
        f'data/{small_tag}_val_outputs.jsonl',
        f'data/{large_tag}_val_outputs.jsonl',
        title=f'{small_tag.upper()} vs. {large_tag.upper()}',
    )
    plt.savefig(f'figures/{small_tag}_{large_tag}_threshold.png', bbox_inches='tight', dpi=300)

In [None]:
# Tune thresholds on the validation split for each "simpl" router checkpoint.
for small_tag, large_tag in (('1b', '8b'), ('8b', '70b'), ('1b', '70b')):
    fig, preds = analyze_thresholds(
        f'checkpoints/{small_tag}_{large_tag}_simpl_epoch=10.pt',
        f'data/{small_tag}_val_outputs.jsonl',
        f'data/{large_tag}_val_outputs.jsonl',
        title=f'{small_tag.upper()} vs. {large_tag.upper()}',
    )
    plt.savefig(f'figures/{small_tag}_{large_tag}_threshold_simpl.png', bbox_inches='tight', dpi=300)

# validation-tuned thresholds

with brevity quality function

| Pair      | 20%      | 40%      | 60%      | 80%       |
| --------- | -------- | -------- | -------- | --------- |
| 1B vs 8B  | 0.0826   | 0.086494 | 0.089773 | 0.096273  |
| 8B vs 70B | 0.745153 | 0.758044 | 0.766146 | 0.7747814 |
| 1B vs 70B | 0.218676 | 0.219519 | 0.220657 | 0.222232  |

with standard quality function

| Pair      | 20%      | 40%      | 60%      | 80%      |
| --------- | -------- | -------- | -------- | -------- |
| 1B vs 8B  | 0.657548 | 0.662825 | 0.667725 | 0.671890 |
| 8B vs 70B | 0.877196 | 0.887329 | 0.893095 | 0.899906 |
| 1B vs 70B | 0.642317 | 0.667306 | 0.681851 | 0.699679 |

## test set eval

In [5]:
# NOTE(review): this name shadows the builtin `eval`; it is kept because the
# cells below call it by this name.
def eval(from_checkpoint, small_outputs, large_outputs, thresholds):
    """Print test-split routing metrics for each fixed threshold.

    For every threshold, prompts whose router score is below it are served by
    the small model and the rest by the large model. Two figures are printed:
    ``perf_gap``, the pass-rate change versus always using the large model
    (in percentage points), and ``cost_adv``, the fraction routed to the small
    model. Only the first generation of each example is used, and pass counts
    are normalised by 3 tests per example.
    """
    def first_generation_passes(path):
        with open(path, 'r') as f:
            return np.array([json.loads(line)['results'][0]['pass'] for line in f])

    preds = np.array(get_preds(from_checkpoint, split='test'))
    small_results = first_generation_passes(small_outputs)
    large_results = first_generation_passes(large_outputs)

    max_passes = 3 * len(large_results)
    large_pass = np.sum(large_results) / max_passes

    for threshold in thresholds:
        route_small = preds < threshold
        routed_passes = np.where(route_small, small_results, large_results)
        mixed_pass = np.sum(routed_passes) / max_passes
        cost_adv = np.sum(route_small) / len(preds)
        perf_gap = mixed_pass - large_pass
        print(f"with threshold {threshold}")
        print(f"  t: {threshold:.3f}")
        print(f"  perf_gap: {100*perf_gap:.1f}")
        print(f"  cost_adv: {cost_adv:.3f}")

eval with brevity quality function

In [None]:
# Evaluate each "aug=3" router on the test split at its validation-tuned thresholds.
for (small_tag, large_tag), thresholds in (
    (('1b', '8b'), [0.0826, 0.086494, 0.089773, 0.096273]),
    (('8b', '70b'), [0.745153, 0.758044, 0.766146, 0.7747814]),
    (('1b', '70b'), [0.218676, 0.219519, 0.220657, 0.222232]),
):
    eval(
        f'checkpoints/{small_tag}_{large_tag}_aug=3_epoch=10.pt',
        f'data/{small_tag}_test_outputs.jsonl',
        f'data/{large_tag}_test_outputs.jsonl',
        thresholds,
    )

eval with standard quality function

In [None]:
# Evaluate each "simpl" router on the test split at its validation-tuned thresholds.
for (small_tag, large_tag), thresholds in (
    (('1b', '8b'), [0.657548, 0.662825, 0.667725, 0.671890]),
    (('8b', '70b'), [0.877196, 0.887329, 0.893095, 0.899906]),
    (('1b', '70b'), [0.642317, 0.667306, 0.681851, 0.699679]),
):
    eval(
        f'checkpoints/{small_tag}_{large_tag}_simpl_epoch=10.pt',
        f'data/{small_tag}_test_outputs.jsonl',
        f'data/{large_tag}_test_outputs.jsonl',
        thresholds,
    )

# test set values

with brevity quality function (each cell: perf_gap in percentage points, with cost_adv in parentheses)

| Pair      | 20%          | 40%          | 60%           | 80%           |
| --------- | ------------ | ------------ | ------------- | ------------- |
| 1B vs 8B  | -4.9 (0.260) | -7.9 (0.490) | -10.9 (0.596) | -15.7 (0.772) |
| 8B vs 70B | -0.7 (0.204) | 1.5 (0.392)  | 4.9 (0.544)   | 7.3 (0.754)   |
| 1B vs 70B | -1.7 (0.230) | -3.1 (0.372) | -6.3 (0.624)  | -11.1 (0.848) |

with standard quality function

| Pair      | 20%           | 40%           | 60%           | 80%           |
| --------- | ------------- | ------------- | ------------- | ------------- |
| 1B vs 8B  | -4.1 (0.158)  | -10.6 (0.354) | -15.3 (0.574) | -18.7 (0.782) |
| 8B vs 70B | -0.2 (0.214)  | 0.9 (0.376)   | 3.5 (0.490)   | 6.9 (0.722)   |
| 1B vs 70B | -5.9 (0.216)  | -9.4 (0.392)  | -10.6 (0.514) | -10.7 (0.748) |

baseline

| Pair      | 0%  | 20%  | 40%  | 60%   | 80%   | 100%  |
| --------- | --- | ---- | ---- | ----- | ----- | ----- |
| 1B vs 8B  | 0.0 | -3.1 | -7.0 | -11.1 | -14.3 | -18.5 |
| 8B vs 70B | 0.0 | 1.0  | 1.3  | 2.3   | 3.1   | 3.9   |
| 1B vs 70B | 0.0 | -2.4 | -6.2 | -9.3  | -11.5 | -14.6 |

In [25]:
import numpy as np
import matplotlib.pyplot as plt

def plot_all_data(data):
    """Draw one panel per model pair comparing baseline and router results.

    Args:
        data: Mapping from a panel title to a dict with ``baseline_x``,
            ``baseline_y``, ``baseline_std``, ``test_x`` and ``test_y``
            sequences. Up to three entries are drawn, in insertion order.

    Returns:
        The matplotlib figure. Each panel shows the baseline curve with a
        shaded band of one ``baseline_std`` either side, plus the test curve.

    Note:
        This updates ``plt.rcParams`` and the active style globally.
    """
    plt.style.use('seaborn-v0_8-paper')
    plt.rcParams.update({
        'font.size': 20,
        'axes.labelsize': 20,
        'axes.titlesize': 20,
        'legend.fontsize': 20,
        'xtick.labelsize': 20,
        'ytick.labelsize': 20,
    })

    fig, axes = plt.subplots(1, 3, figsize=(20, 5))
    baseline_color = '#4878D0'
    test_color = '#EE854A'

    for ax, (pair, pair_data) in zip(axes, data.items()):
        baseline_y = np.array(pair_data['baseline_y'])
        baseline_std = np.array(pair_data['baseline_std'])

        ax.fill_between(pair_data['baseline_x'],
                        baseline_y - baseline_std,
                        baseline_y + baseline_std,
                        color=baseline_color, alpha=0.2)

        ax.plot(pair_data['baseline_x'], pair_data['baseline_y'], 'o-',
                color=baseline_color, alpha=0.7, label='Baseline')
        ax.plot(pair_data['test_x'], pair_data['test_y'], 's-',
                color=test_color, alpha=0.9, label='Test')

        ax.set_title(pair)
        ax.grid(True, linestyle='--', alpha=0.7)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

    # Label the left-most panel directly. The old `get_position().x0 < 0.1`
    # check never fired because the default left margin is 0.125.
    axes[0].set_ylabel('Performance Gap (%)')
    axes[0].legend()
    plt.tight_layout()
    return fig

In [None]:
# Hand-copied results for the comparison figure, one entry per model pair.
#   baseline_x / baseline_y: fraction routed to the small model and the
#       performance gap under random routing (the "baseline" table above).
#   baseline_std: spread of the baseline (presumably the std over random
#       trials printed by the baseline cell; confirm).
#   test_x / test_y: cost_adv and perf_gap of the router trained with the
#       brevity quality function (first test-set table above).
data = {
    '1B-8B': {
        'baseline_x': [0, 0.2, 0.4, 0.6, 0.8, 1.0],
        'baseline_y': [0.0, -3.1, -7.0, -11.1, -14.3, -18.5],
        'baseline_std': [0.0, 0.6, 0.7, 1.6, 0.9, 0.0],
        'test_x': [0.260, 0.490, 0.596, 0.772],
        'test_y': [-4.9, -7.9, -10.9, -15.7]
    },
    '8B-70B': {
        'baseline_x': [0, 0.2, 0.4, 0.6, 0.8, 1.0],
        'baseline_y': [0.0, 1.0, 1.3, 2.3, 3.1, 3.9],
        'baseline_std': [0.0, 0.8, 0.6, 0.6, 0.7, 0.0],
        'test_x': [0.204, 0.392, 0.544, 0.754],
        'test_y': [-0.7, 1.5, 4.9, 7.3]
    },
    '1B-70B': {
        'baseline_x': [0, 0.2, 0.4, 0.6, 0.8, 1.0],
        'baseline_y': [0.0, -2.4, -6.2, -9.3, -11.5, -14.6],
        'baseline_std': [0.0, 0.7, 0.9, 1.3, 1.0, 0.0],
        'test_x': [0.230, 0.372, 0.624, 0.848],
        'test_y': [-1.7, -3.1, -6.3, -11.1]
    }
}

# Render the three panels and save the figure at print resolution.
fig = plot_all_data(data)
plt.savefig('figures/all_perf_comparison.png', bbox_inches='tight', dpi=300) 

In [None]:
# Same layout as the previous figure's data. The baseline_* values are
# unchanged; test_x / test_y are cost_adv and perf_gap of the router trained
# with the standard quality function (second test-set table above).
data = {
    '1B-8B': {
        'baseline_x': [0, 0.2, 0.4, 0.6, 0.8, 1.0],
        'baseline_y': [0.0, -3.1, -7.0, -11.1, -14.3, -18.5],
        'baseline_std': [0.0, 0.6, 0.7, 1.6, 0.9, 0.0],
        'test_x': [0.158, 0.354, 0.574, 0.782],
        'test_y': [-4.1, -10.6, -15.3, -18.7]
    },
    '8B-70B': {
        'baseline_x': [0, 0.2, 0.4, 0.6, 0.8, 1.0],
        'baseline_y': [0.0, 1.0, 1.3, 2.3, 3.1, 3.9],
        'baseline_std': [0.0, 0.8, 0.6, 0.6, 0.7, 0.0],
        'test_x': [0.214, 0.376, 0.490, 0.722],
        'test_y': [-0.2, 0.9, 3.5, 6.9]
    },
    '1B-70B': {
        'baseline_x': [0, 0.2, 0.4, 0.6, 0.8, 1.0],
        'baseline_y': [0.0, -2.4, -6.2, -9.3, -11.5, -14.6],
        'baseline_std': [0.0, 0.7, 0.9, 1.3, 1.0, 0.0],
        'test_x': [0.216, 0.392, 0.514, 0.748],
        'test_y': [-5.9, -9.4, -10.6, -10.7]
    }
}

# Render the three panels and save the figure at print resolution.
fig = plot_all_data(data)
plt.savefig('figures/all_perf_comparison_simpl.png', bbox_inches='tight', dpi=300)

# just a plot of the winrates for everything

In [None]:
import json
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from collections import Counter

def extract_winrates(filename):
    """Return the ``target`` value of every record in a JSONL file, in file order."""
    with open(filename) as handle:
        return [json.loads(record)['target'] for record in handle]

def plot_model_winrates(file1, file2, file3, output_file):
    """Overlay kernel-density curves of the soft labels in three JSONL files.

    The curves are labelled '1B-70B', '8B-70B' and '1B-8B' for ``file1``,
    ``file2`` and ``file3`` respectively. The figure is saved to
    ``output_file`` and also returned. This updates ``plt.rcParams`` and the
    active style globally.
    """
    plt.style.use('seaborn-v0_8-paper')
    sns.set_palette("muted")
    font_keys = (
        'font.size',
        'axes.labelsize',
        'axes.titlesize',
        'legend.fontsize',
        'xtick.labelsize',
        'ytick.labelsize',
    )
    plt.rcParams.update({key: 20 for key in font_keys})

    # Load everything first so a missing file fails before a figure is created.
    labelled_winrates = [
        (extract_winrates(path), label)
        for path, label in zip((file1, file2, file3), ('1B-70B', '8B-70B', '1B-8B'))
    ]

    fig, ax = plt.subplots(figsize=(10, 6))
    for winrates, label in labelled_winrates:
        sns.kdeplot(data=winrates, label=label, ax=ax)

    ax.set_ylabel('', fontsize=10)
    ax.grid(True, linestyle='--', alpha=0.7)
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    ax.legend()

    plt.tight_layout()
    plt.savefig(output_file)
    return fig

# Plot the label distributions twice: first the "_simpl" files, then the unsuffixed ones.
for suffix in ('_simpl', ''):
    files = [
        f'data/{pair_tag}_wins{suffix}.jsonl'
        for pair_tag in ('1b_70b', '8b_70b', '1b_8b')
    ]
    _ = plot_model_winrates(*files, f'figures/winrates{suffix}.png')