disclosure: lots of AI-generated boilerplate code scattered throughout

# baseline results

In [None]:
import json
import numpy as np
    
def evaluate_routing(small_file, large_file, small_proportion, n_trials=10, seed=42):
    """Simulate random routing between two models and score the mixture.

    Both files are JSONL; line ``i`` of each describes the same example and
    carries a ``'results'`` list with one entry per generation. A generation
    either has a truthy ``'code_error'`` or a ``'test_results'`` list of
    ``(outcome, detail)`` pairs where ``outcome == 'pass'`` marks a passed test.

    Args:
        small_file: Path to the small model's JSONL outputs.
        large_file: Path to the large model's JSONL outputs.
        small_proportion: Fraction of examples (0..1) sent to the small model.
        n_trials: Number of independent random routings to draw.
        seed: Seed for the NumPy random generator.

    Returns:
        ``(pass_rates, perfect_rates)``: two lists with one percentage per
        trial. The pass rate is over individual test cases of generations
        that ran without a code error; the perfect rate is over *all*
        generations of the routed mixture (code-error generations count as
        not perfect). A rate is ``0.0`` when its denominator is zero.
    """
    with open(small_file) as f:
        small_outputs = [json.loads(line) for line in f]
    with open(large_file) as f:
        large_outputs = [json.loads(line) for line in f]

    n_examples = len(small_outputs)
    n_to_small = int(small_proportion * n_examples)

    pass_rates = []
    perfect_rates = []
    rng = np.random.default_rng(seed)

    for _ in range(n_trials):
        # Indices of the examples answered by the small model in this trial.
        small_set = set(rng.choice(n_examples, n_to_small, replace=False))

        total_test_cases = 0
        passed_test_cases = 0
        perfect_solutions = 0
        total_generations = 0

        for i in range(n_examples):
            source = small_outputs if i in small_set else large_outputs
            results = source[i]['results']
            # Count generations explicitly rather than relying on the length
            # of whichever example happened to be processed last.
            total_generations += len(results)

            for r in results:
                if r.get('code_error'):
                    continue

                outcomes = [outcome for outcome, _detail in r['test_results']]
                total_test_cases += len(outcomes)
                passed_test_cases += sum(outcome == 'pass' for outcome in outcomes)

                if all(outcome == 'pass' for outcome in outcomes):
                    perfect_solutions += 1

        pass_rates.append(
            passed_test_cases / total_test_cases * 100 if total_test_cases else 0.0
        )
        perfect_rates.append(
            perfect_solutions / total_generations * 100 if total_generations else 0.0
        )

    return pass_rates, perfect_rates

# NOTE(review): `props` is not referenced by the sweep below, which iterates
# over its own literal list that also includes the 0% and 100% endpoints.
props = [0.2, 0.4, 0.6, 0.8]
# (small-model outputs, large-model outputs) pairs to compare on the test set.
model_pairs = [
    ('outputs/1b_test_outputs.jsonl', 'outputs/8b_test_outputs.jsonl'),
    ('outputs/8b_test_outputs.jsonl', 'outputs/70b_test_outputs.jsonl'),
    ('outputs/1b_test_outputs.jsonl', 'outputs/70b_test_outputs.jsonl')
]

# Random-routing baseline: for each pair and each routing proportion, report
# the mean and standard deviation (np.std) of the per-trial rates returned by
# evaluate_routing, called with its default n_trials and seed.
for small_file, large_file in model_pairs:
    # Header line uses just the file names, without the directory part.
    print(f"\n{small_file.split('/')[-1]} vs {large_file.split('/')[-1]}")
    for prop in [0., 0.2, 0.4, 0.6, 0.8, 1.]:
        pass_rates, perfect_rates = evaluate_routing(small_file, large_file, prop)
        print(f"\n{prop*100}% routed to small model:")
        print(f"Pass rate: {np.mean(pass_rates):.1f}% ± {np.std(pass_rates):.1f}%")
        print(f"Perfect rate: {np.mean(perfect_rates):.1f}% ± {np.std(perfect_rates):.1f}%")

# compute winrates

In [None]:
import json
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
from sample import format_prompt
from matplotlib.gridspec import GridSpec

def analyze_model_comparison(small, large, output_file):
    """Compare paired generations of two models, write win-rate targets, and plot.

    The two input files are read in lockstep, one JSON object per line. Each
    object provides ``'item'`` (passed to ``format_prompt``), ``'results'``
    (one dict per generation, with ``'pass'`` and optionally ``'code_error'``)
    and ``'num_tokens'`` (one value per generation).

    For every line pair, the generations are compared position by position
    and the number of comparisons won by the first (``small``) model is
    counted. One JSON line ``{'prompt': ..., 'target': wins / 10}`` is written
    to ``output_file`` per input line pair.

    Args:
        small: Path to the first model's JSONL outputs.
        large: Path to the second model's JSONL outputs.
        output_file: Path of the JSONL file to (over)write with targets.

    Returns:
        ``(fig, win_freq, reason_freq)``: the matplotlib figure with the two
        bar charts, a mapping from win count to number of prompts, and a
        mapping from deciding factor to number of generation pairs.
    """
    win_freq = defaultdict(int)
    reason_freq = defaultdict(int)

    def get_win_reason(gen1, gen2, tok1, tok2):
        """Return ``(deciding_factor, first_model_wins)`` for one generation pair."""
        err1, err2 = gen1.get('code_error'), gen2.get('code_error')
        if not err1 and not err2:
            if gen1['pass'] != gen2['pass']:
                return 'test_cases', gen1['pass'] > gen2['pass']
            # Same 'pass' value: fewer tokens wins, ties go to the first model.
            return 'tokens', tok1 <= tok2
        # At least one side has a code error here. The original code had a
        # `raise Exception('both failed')` guarded by "both code_error values
        # are None", which can never hold at this point, so it was dead code
        # and has been removed. When both sides errored the first model is
        # counted as losing, exactly as before.
        return 'code_error', not err1

    with open(small) as f1, open(large) as f2, open(output_file, 'w') as f_out:
        for line1, line2 in zip(f1, f2):
            o1, o2 = json.loads(line1), json.loads(line2)
            prompt = format_prompt(o1['item'])

            wins = 0
            for pair in zip(o1['results'], o2['results'], o1['num_tokens'], o2['num_tokens']):
                reason, is_win = get_win_reason(*pair)
                wins += is_win
                reason_freq[reason] += 1

            win_freq[wins] += 1
            # NOTE(review): the divisor assumes 10 generation pairs per prompt
            # (consistent with the "out of 10" axis label below) - confirm.
            json.dump({'prompt': prompt, 'target': wins / 10}, f_out)
            f_out.write('\n')

    plt.style.use('seaborn-v0_8-paper')
    sns.set_palette("muted")

    fig = plt.figure(figsize=(12, 5))
    gs = GridSpec(1, 2, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])

    bar_color = "#4878D0"
    edge_color = "#2F4858"

    # Left panel: how many prompts had 0, 1, 2, ... wins for the first model.
    k, v = zip(*sorted(win_freq.items(), key=lambda x: x[0]))
    sns.barplot(x=list(k), y=list(v), ax=ax1, color=bar_color, edgecolor=edge_color)
    ax1.set_title('Distribution of Model Wins', fontsize=11, pad=15)
    ax1.set_xlabel('Number of Wins (out of 10)', fontsize=10)
    ax1.set_ylabel('Frequency', fontsize=10)
    ax1.grid(True, axis='y', linestyle='--', alpha=0.7)
    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)

    # Right panel: deciding factors, most frequent first.
    reasons, counts = zip(*sorted(reason_freq.items(), key=lambda x: x[1], reverse=True))
    reason_labels = {'tokens': 'Less Tokens', 'test_cases': 'More Test Cases', 'code_error': 'Code Error'}
    formatted_reasons = [reason_labels[r] for r in reasons]

    sns.barplot(x=formatted_reasons, y=counts, ax=ax2, color=bar_color, edgecolor=edge_color)
    ax2.set_title('Determining Factor by Frequency', fontsize=11, pad=15)
    ax2.set_xlabel('Determining Factor', fontsize=10)
    ax2.set_ylabel('Frequency', fontsize=10)
    ax2.grid(True, axis='y', linestyle='--', alpha=0.7)
    ax2.spines['top'].set_visible(False)
    ax2.spines['right'].set_visible(False)

    plt.tight_layout()
    return fig, win_freq, reason_freq

# Build win-rate target files and analysis figures for each model pair on the
# training outputs. analyze_model_comparison creates a new figure on each
# call, so the argument-less plt.savefig that follows writes that figure.

# 1B vs 70B
fig, wins, reasons = analyze_model_comparison(
    'outputs/1b_train_outputs.jsonl',
    'outputs/70b_train_outputs.jsonl',
    'outputs/1b_70b_wins.jsonl'
)
plt.savefig('figures/1b_70b_analysis.png')

# 8B vs 70B
fig, wins, reasons = analyze_model_comparison(
    'outputs/8b_train_outputs.jsonl',
    'outputs/70b_train_outputs.jsonl',
    'outputs/8b_70b_wins.jsonl'
)
plt.savefig('figures/8b_70b_analysis.png')

# 1B vs 8B
fig, wins, reasons = analyze_model_comparison(
    'outputs/1b_train_outputs.jsonl',
    'outputs/8b_train_outputs.jsonl',
    'outputs/1b_8b_wins.jsonl'
)
plt.savefig('figures/1b_8b_analysis.png')

## validation set threshold tuning

In [58]:
import json
from generate import format_prompt
from datasets import load_dataset

# Load the MBPP dataset from the Hugging Face hub.
dataset = load_dataset('google-research-datasets/mbpp')

# Write one JSON line per test item: the formatted prompt plus a constant
# target of 0. NOTE(review): the target looks like a placeholder so these
# files share the schema of the win-rate files - confirm against RouterDataset.
with open('data/test_prompts.jsonl', 'w') as f:
    for item in dataset['test']:
        json.dump({'prompt': format_prompt(item), 'target': 0}, f)
        f.write('\n')

# Same layout for the validation split.
with open('data/val_prompts.jsonl', 'w') as f:
    for item in dataset['validation']:
        json.dump({'prompt': format_prompt(item), 'target': 0}, f)
        f.write('\n')

In [8]:
import torch
from router import Router, RouterDataset
from tqdm import tqdm
from torch.utils.data import DataLoader

# Pick the inference device. Apple's MPS backend stays the first choice (it
# was previously hard-coded), with CUDA and then CPU as fallbacks so the
# notebook also runs on machines without MPS.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

def load_model(from_checkpoint):
    """Build a Router and load its head weights from a checkpoint's EMA state.

    Only the four head tensors listed below are copied from the ``'ema'``
    entry of the checkpoint; every other parameter keeps whatever value
    ``Router(hidden_size=384)`` gave it.

    Args:
        from_checkpoint: Path of the checkpoint file to read.

    Returns:
        The router in eval mode, moved to the module-level ``device``.
    """
    router = Router(hidden_size=384)
    ema_state = torch.load(from_checkpoint, map_location='cpu')['ema']

    # (checkpoint key, destination parameter) pairs for the head layers.
    head_targets = (
        ('ema_model.head.0.weight', router.head[0].weight),
        ('ema_model.head.0.bias', router.head[0].bias),
        ('ema_model.head.3.weight', router.head[3].weight),
        ('ema_model.head.3.bias', router.head[3].bias),
    )
    for checkpoint_key, destination in head_targets:
        destination.data.copy_(ema_state[checkpoint_key])

    router = router.eval()
    return router.to(device)

@torch.no_grad()
def get_preds(from_checkpoint, split):
    """Score every prompt of a split with the router stored in a checkpoint.

    Args:
        from_checkpoint: Checkpoint path handed to ``load_model``.
        split: Split name; prompts are read from ``data/{split}_prompts.jsonl``.

    Returns:
        A list with the sigmoid of the router's output for each batch,
        concatenated in dataset order (the loader does not shuffle).
    """
    prompt_data = RouterDataset(input=f'data/{split}_prompts.jsonl')
    batches = DataLoader(prompt_data, batch_size=8, shuffle=False)
    router = load_model(from_checkpoint)

    scores = []
    for batch in tqdm(batches):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        batch_scores = torch.sigmoid(router.forward(input_ids, attention_mask))
        scores.extend(batch_scores.tolist())
    return scores

In [None]:
import json
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def analyze_thresholds(from_checkpoint, small_outputs, large_outputs, title=''):
    """Derive routing thresholds from validation scores and report the trade-off.

    Router scores for the ``'val'`` split are computed with ``get_preds``; the
    20th/40th/60th/80th percentiles of those scores become candidate
    thresholds. A prompt is routed to the small model when its score is below
    the threshold. For each threshold the function prints the threshold, the
    performance gap relative to always using the large model, and the
    fraction of prompts routed to the small model. It also draws a histogram
    of the scores with the thresholds marked.

    Only the first generation's ``'pass'`` value of each line is used, and
    rates are normalised by ``3 * number_of_lines``
    (presumably 3 test cases per problem - TODO confirm).

    Args:
        from_checkpoint: Router checkpoint path.
        small_outputs: JSONL outputs of the small model.
        large_outputs: JSONL outputs of the large model.
        title: Title of the histogram.

    Returns:
        ``(figure, preds)``: the current matplotlib figure and the score array.
    """
    preds = np.array(get_preds(from_checkpoint, split='val'))

    def first_pass_values(path):
        # 'pass' entry of the first generation on every line of the file.
        with open(path, 'r') as handle:
            return np.array([json.loads(row)['results'][0]['pass'] for row in handle])

    small_results = first_pass_values(small_outputs)
    large_results = first_pass_values(large_outputs)

    normaliser = 3 * len(large_results)
    large_pass = np.sum(large_results) / normaliser

    percentiles = [20, 40, 60, 80]
    percentile_thresholds = np.percentile(preds, percentiles)

    # Histogram of scores with one dashed line per candidate threshold.
    plt.style.use('seaborn-v0_8-paper')
    sns.set_palette("muted")
    plt.figure(figsize=(10, 6))
    plt.hist(preds, bins=30, color="#4878D0", alpha=0.6, edgecolor='white')
    for cutoff in percentile_thresholds:
        plt.axvline(x=cutoff, color="#2F4858", linestyle='--', alpha=0.5)
    plt.title(title)
    plt.xlabel('Model Confidence Score', fontsize=10)
    plt.ylabel('Frequency', fontsize=10)
    plt.grid(True, axis='y', linestyle='--', alpha=0.7)
    for side in ('top', 'right'):
        plt.gca().spines[side].set_visible(False)
    plt.tight_layout()

    for p, t in zip(percentiles, percentile_thresholds):
        route_small = preds < t
        routed_results = np.where(route_small, small_results, large_results)
        mixed_pass = np.sum(routed_results) / normaliser
        cost_adv = np.sum(route_small) / len(preds)
        perf_gap = mixed_pass - large_pass
        print(f"{p}th percentile:")
        print(f"  t: {t:.6f}")
        print(f"  perf_gap: {perf_gap:.3f}")
        print(f"  cost_adv: {cost_adv:.3f}")

    return plt.gcf(), preds

# Tune thresholds on the validation split for each model pair and save the
# score histogram that analyze_thresholds leaves as the current figure.

# 1B vs 8B
fig, preds = analyze_thresholds(
    'checkpoints/1b_8b_aug=3_epoch=10.pt',
    'data/1b_val_outputs.jsonl',
    'data/8b_val_outputs.jsonl',
    title='1B vs. 8B'
)
plt.savefig('figures/1b_8b_threshold.png', bbox_inches='tight', dpi=300)

# 8B vs 70B
fig, preds = analyze_thresholds(
    'checkpoints/8b_70b_aug=3_epoch=10.pt',
    'data/8b_val_outputs.jsonl',
    'data/70b_val_outputs.jsonl',
    title='8B vs. 70B'
)
plt.savefig('figures/8b_70b_threshold.png', bbox_inches='tight', dpi=300)

# 1B vs 70B
fig, preds = analyze_thresholds(
    'checkpoints/1b_70b_aug=3_epoch=10.pt',
    'data/1b_val_outputs.jsonl',
    'data/70b_val_outputs.jsonl',
    title='1B vs. 70B'
)
plt.savefig('figures/1b_70b_threshold.png', bbox_inches='tight', dpi=300)

Validation-tuned percentile thresholds (label-not-augmented)

| Pair      | 20%      | 40%      | 60%      | 80%       |
| --------- | -------- | -------- | -------- | --------- |
| 1B vs 8B  | 0.0826   | 0.086494 | 0.089773 | 0.096273  |
| 8B vs 70B | 0.745153 | 0.758044 | 0.766146 | 0.7747814 |
| 1B vs 70B | 0.218676 | 0.219519 | 0.220657 | 0.222232  |


## test set eval

In [None]:
def eval(from_checkpoint, small_outputs, large_outputs, thresholds):
    """Report the routing trade-off on the test split for fixed thresholds.

    Router scores for the ``'test'`` split come from ``get_preds``. For each
    threshold, prompts scoring below it are routed to the small model and the
    function prints the threshold, the performance gap relative to always
    using the large model, and the fraction of prompts routed to the small
    model. Nothing is returned.

    Only the first generation's ``'pass'`` value of each line is used, and
    rates are normalised by ``3 * number_of_lines``
    (presumably 3 test cases per problem - TODO confirm).

    NOTE: this function's name shadows Python's builtin ``eval`` for the
    rest of the notebook session.

    Args:
        from_checkpoint: Router checkpoint path.
        small_outputs: JSONL outputs of the small model.
        large_outputs: JSONL outputs of the large model.
        thresholds: Iterable of score thresholds to evaluate.
    """
    preds = np.array(get_preds(from_checkpoint, split='test'))

    def first_pass_values(path):
        # 'pass' entry of the first generation on every line of the file.
        with open(path, 'r') as handle:
            return np.array([json.loads(row)['results'][0]['pass'] for row in handle])

    small_results = first_pass_values(small_outputs)
    large_results = first_pass_values(large_outputs)

    normaliser = 3 * len(large_results)
    large_pass = np.sum(large_results) / normaliser

    for t in thresholds:
        route_small = preds < t
        routed_results = np.where(route_small, small_results, large_results)
        mixed_pass = np.sum(routed_results) / normaliser
        cost_adv = np.sum(route_small) / len(preds)
        perf_gap = mixed_pass - large_pass
        print(f"with threshold {t}")
        print(f"  t: {t:.3f}")
        print(f"  perf_gap: {perf_gap:.3f}")
        print(f"  cost_adv: {cost_adv:.3f}")

# Evaluate each router on the test split. The threshold lists are the
# validation-tuned values from the table in the previous section.

# 1B vs 8B
eval(
    'checkpoints/1b_8b_aug=3_epoch=10.pt',
    'data/1b_test_outputs.jsonl',
    'data/8b_test_outputs.jsonl',
    [0.0826, 0.086494, 0.089773, 0.096273]
)

# 8B vs 70B
eval(
    'checkpoints/8b_70b_aug=3_epoch=10.pt',
    'data/8b_test_outputs.jsonl',
    'data/70b_test_outputs.jsonl',
    [0.745153, 0.758044, 0.766146, 0.7747814]
)

# 1B vs 70B
eval(
    'checkpoints/1b_70b_aug=3_epoch=10.pt',
    'data/1b_test_outputs.jsonl',
    'data/70b_test_outputs.jsonl',
    [0.218676, 0.219519, 0.220657, 0.222232]
)

# test set values

**perf gap vs. cost advantage for each setting (values in parentheses are the measured cost advantage, i.e. the actual fraction of test prompts routed to the small model)**

| Pair      | 20%          | 40%          | 60%           | 80%           |
| --------- | ------------ | ------------ | ------------- | ------------- |
| 1B vs 8B  | -4.9 (0.260) | -7.9 (0.490) | -10.9 (0.596) | -15.7 (0.772) |
| 8B vs 70B | -0.7 (0.204) | 1.5 (0.392)  | 4.9 (0.544)   | 7.3 (0.754)   |
| 1B vs 70B | -1.7 (0.230) | -3.1 (0.372) | -6.3 (0.624)  | -11.1 (0.848) |

# baseline

**perf gap vs. cost advantage**

| Pair      | 0%  | 20%  | 40%  | 60%   | 80%   | 100%  |
| --------- | --- | ---- | ---- | ----- | ----- | ----- |
| 1B vs 8B  | 0.0 | -3.1 | -7.0 | -11.1 | -14.3 | -18.5 |
| 8B vs 70B | 0.0 | 1.0  | 1.3  | 2.3   | 3.1   | 3.9   |
| 1B vs 70B | 0.0 | -2.4 | -6.2 | -9.3  | -11.5 | -14.6 |

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def plot_all_data(data):
    """Plot baseline and router results side by side for each model pair.

    Args:
        data: Mapping from a pair title to a dict with the keys
            ``'baseline_x'``, ``'baseline_y'``, ``'baseline_std'``,
            ``'test_x'`` and ``'test_y'`` (equal-length sequences within the
            baseline group and within the test group). Entries are drawn in
            iteration order onto a fixed row of three subplots.

    Returns:
        The matplotlib figure containing the three panels.
    """
    plt.style.use('seaborn-v0_8-paper')
    fig, axes = plt.subplots(1, 3, figsize=(20, 5))

    baseline_color = '#4878D0'
    test_color = '#EE854A'

    for panel_idx, (ax, (pair, pair_data)) in enumerate(zip(axes, data.items())):
        baseline_x = pair_data['baseline_x']
        baseline_y = np.array(pair_data['baseline_y'])
        baseline_std = np.array(pair_data['baseline_std'])

        # Shaded band: baseline mean +/- one standard deviation.
        ax.fill_between(baseline_x,
                        baseline_y - baseline_std,
                        baseline_y + baseline_std,
                        color=baseline_color, alpha=0.2)

        ax.plot(baseline_x, pair_data['baseline_y'], 'o-',
                color=baseline_color, alpha=0.7, label='Baseline')
        ax.plot(pair_data['test_x'], pair_data['test_y'], 's-',
                color=test_color, alpha=0.9, label='Test')

        ax.set_xlabel('Cost Advantage')
        # Label the y-axis on the leftmost panel only. The previous check
        # (`ax.get_position().x0 < 0.1`) never fired because the first
        # subplot starts at matplotlib's default left margin of 0.125, so no
        # panel ever received the label.
        if panel_idx == 0:
            ax.set_ylabel('Performance Gap (%)')
        ax.set_title(pair)
        ax.grid(True, linestyle='--', alpha=0.7)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

    axes[0].legend()
    plt.tight_layout()
    return fig

# Plot inputs per model pair. The baseline_y and test_x/test_y values match
# the "baseline" and "test set values" tables above (test_x is the value in
# parentheses). NOTE(review): baseline_std does not appear in those tables;
# presumably copied from the baseline cell's printed output - confirm.
data = {
    '1B vs. 8B': {
        'baseline_x': [0, 0.2, 0.4, 0.6, 0.8, 1.0],
        'baseline_y': [0.0, -3.1, -7.0, -11.1, -14.3, -18.5],
        'baseline_std': [0.0, 0.6, 0.7, 1.6, 0.9, 0.0],
        'test_x': [0.260, 0.490, 0.596, 0.772],
        'test_y': [-4.9, -7.9, -10.9, -15.7]
    },
    '8B vs. 70B': {
        'baseline_x': [0, 0.2, 0.4, 0.6, 0.8, 1.0],
        'baseline_y': [0.0, 1.0, 1.3, 2.3, 3.1, 3.9],
        'baseline_std': [0.0, 0.8, 0.6, 0.6, 0.7, 0.0],
        'test_x': [0.204, 0.392, 0.544, 0.754],
        'test_y': [-0.7, 1.5, 4.9, 7.3]
    },
    '1B vs. 70B': {
        'baseline_x': [0, 0.2, 0.4, 0.6, 0.8, 1.0],
        'baseline_y': [0.0, -2.4, -6.2, -9.3, -11.5, -14.6],
        'baseline_std': [0.0, 0.7, 0.9, 1.3, 1.0, 0.0],
        'test_x': [0.230, 0.372, 0.624, 0.848],
        'test_y': [-1.7, -3.1, -6.3, -11.1]
    }
}

# Render the three-panel comparison and save it at print resolution.
fig = plot_all_data(data)
plt.savefig('figures/all_perf_comparison.png', bbox_inches='tight', dpi=300)