Disclosure: much of this notebook was generated by Claude for the sake of pretty figures and quick analysis.

# examine outputs

In [None]:
import json
import numpy as np
    
def evaluate_routing(small_file, large_file, small_proportion, n_trials= 10, seed=42):
    """Simulate random routing between two models and score the mixture.

    Both files are read as JSON Lines. Record ``i`` of each file is expected
    to carry a ``'results'`` list with one entry per sampled solution. An
    entry with a truthy ``'code_error'`` contributes no test cases; otherwise
    its ``'test_results'`` is iterated as pairs whose first element is the
    status string (``'pass'`` counts as a pass) and whose second element is
    ignored.

    For each trial, ``int(small_proportion * n_examples)`` example indices are
    drawn without replacement and served from ``small_file``; the remaining
    examples are served from ``large_file``.

    Args:
        small_file: Path to the JSONL outputs of the small model.
        large_file: Path to the JSONL outputs of the large model.
        small_proportion: Fraction of examples routed to the small model.
        n_trials: Number of independent random routings to evaluate.
        seed: Seed for the NumPy random generator.

    Returns:
        ``(pass_rates, perfect_rates)``: two lists with one percentage per
        trial. ``pass_rates`` is passed test cases over executed test cases;
        ``perfect_rates`` is solutions passing every test case over all
        solutions (errored solutions count in the denominator). A rate whose
        denominator is zero is reported as ``0.0``.
    """
    with open(small_file) as f:
        small_outputs = [json.loads(line) for line in f]
    with open(large_file) as f:
        large_outputs = [json.loads(line) for line in f]

    # The number of examples is taken from the small model's file.
    n_examples = len(small_outputs)
    n_to_small = int(small_proportion * n_examples)

    pass_rates = []
    perfect_rates = []
    rng = np.random.default_rng(seed)

    for _ in range(n_trials):
        small_set = set(rng.choice(n_examples, n_to_small, replace=False))

        total_test_cases = 0
        passed_test_cases = 0
        perfect_solutions = 0
        # Counted while iterating so the denominator stays correct even when
        # examples do not all have the same number of sampled solutions.
        total_solutions = 0

        for i in range(n_examples):
            chosen = small_outputs if i in small_set else large_outputs
            results = chosen[i]['results']
            total_solutions += len(results)

            for r in results:
                # Solutions that failed to run have no test results to count.
                if r.get('code_error'):
                    continue

                statuses = [status for status, _ in r['test_results']]

                # individual cases
                total_test_cases += len(statuses)
                passed_test_cases += sum(status == 'pass' for status in statuses)

                # perfect solutions
                if all(status == 'pass' for status in statuses):
                    perfect_solutions += 1

        pass_rates.append(
            passed_test_cases / total_test_cases * 100 if total_test_cases else 0.0
        )
        perfect_rates.append(
            perfect_solutions / total_solutions * 100 if total_solutions else 0.0
        )

    return pass_rates, perfect_rates

props = [0.2, 0.4, 0.6, 0.8]
model_pairs = [
    ('outputs/1b_test_outputs.jsonl', 'outputs/8b_test_outputs.jsonl'),
    ('outputs/8b_test_outputs.jsonl', 'outputs/70b_test_outputs.jsonl'),
    ('outputs/1b_test_outputs.jsonl', 'outputs/70b_test_outputs.jsonl'),
]

# For every (small, large) pairing, sweep the share of examples routed to the
# small model from 0% to 100% and report mean ± std over the random trials.
for small_file, large_file in model_pairs:
    print(f"\n{small_file.split('/')[-1]} vs {large_file.split('/')[-1]}")
    # The interior proportions are bracketed by the all-large / all-small extremes.
    for prop in [0., *props, 1.]:
        pass_rates, perfect_rates = evaluate_routing(small_file, large_file, prop)
        print(
            f"\n{prop*100}% routed to small model:",
            f"Pass rate: {np.mean(pass_rates):.1f}% ± {np.std(pass_rates):.1f}%",
            f"Perfect rate: {np.mean(perfect_rates):.1f}% ± {np.std(perfect_rates):.1f}%",
            sep="\n",
        )

# compute winrates

In [None]:
import json
from collections import defaultdict
import matplotlib.pyplot as plt
from sample import format_prompt

def compare(gen1, gen2, tok1, tok2):
    """Return whether generation 1 beats generation 2.

    Decision order:
      1. If either generation has a truthy ``'code_error'``, generation 1
         wins only when it is the one without the error. When both errored,
         generation 1 does not get the win.
      2. Otherwise the generation with the larger ``'pass'`` value wins.
      3. On equal ``'pass'`` values, generation 1 wins when it used no more
         tokens than generation 2 (ties go to generation 1).

    Args:
        gen1: Result dict of the first generation.
        gen2: Result dict of the second generation.
        tok1: Token count of the first generation.
        tok2: Token count of the second generation.

    Returns:
        True if generation 1 is considered the winner, False otherwise.
    """
    failed1 = bool(gen1.get('code_error'))
    failed2 = bool(gen2.get('code_error'))

    # The earlier "both failed" raise could never trigger: it tested for both
    # errors being None, a case the no-error path had already handled.
    if failed1 or failed2:
        return not failed1

    if gen1['pass'] != gen2['pass']:
        return gen1['pass'] > gen2['pass']
    return tok1 <= tok2

file1 = 'outputs/1b_train_outputs.jsonl'
file2 = 'outputs/70b_train_outputs.jsonl'
output = 'outputs/1b_70b_wins.jsonl'

# Histogram: number of wins for file1's model -> how many examples had that count.
win_freq = defaultdict(int)
with open(file1) as f1, open(file2) as f2, open(output, 'w') as f_out:
    # Both files are walked in lockstep, one JSON record per line.
    for line1, line2 in zip(f1, f2):
        o1 = json.loads(line1)
        o2 = json.loads(line2)
        prompt = format_prompt(o1['item'])

        # Pair up generations position by position and count file1's wins.
        wins = 0
        for gen1, gen2, tok1, tok2 in zip(
            o1['results'], o2['results'], o1['num_tokens'], o2['num_tokens']
        ):
            wins += compare(gen1, gen2, tok1, tok2)

        win_freq[wins] += 1
        # The target is the win count scaled by a fixed divisor of 10.
        f_out.write(json.dumps({'prompt': prompt, 'target': wins / 10}))
        f_out.write('\n')

# Dict keys are unique, so sorting the items orders them by win count.
k, v = zip(*sorted(win_freq.items()))
plt.bar(list(k), list(v))
plt.tight_layout()
plt.savefig('figures/1b_70b_wins.png')

In [None]:
import json
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
from sample import format_prompt
from matplotlib.gridspec import GridSpec

def analyze_model_comparison(small, large, output_file):
    """Compare two models' generations, write win-rate targets, and plot.

    ``small`` and ``large`` are read in lockstep as JSON Lines. For every
    pair of records, the generations in ``'results'`` (with the matching
    ``'num_tokens'`` entries) are compared position by position. Per record,
    one JSON line ``{'prompt': ..., 'target': wins / 10}`` is written to
    ``output_file``, where the prompt comes from
    ``format_prompt(record['item'])`` of the ``small`` record.

    Args:
        small: Path to the JSONL outputs of the first (small) model.
        large: Path to the JSONL outputs of the second (large) model.
        output_file: Path of the JSONL file to write; it is overwritten.

    Returns:
        ``(fig, win_freq, reason_freq)``: the two-panel matplotlib figure, a
        mapping of win count to number of examples, and a mapping of deciding
        factor (``'test_cases'``, ``'tokens'`` or ``'code_error'``) to the
        number of comparisons it decided.
    """
    win_freq = defaultdict(int)
    reason_freq = defaultdict(int)

    def get_win_reason(gen1, gen2, tok1, tok2):
        """Return ``(reason, is_win)`` for the first generation of the pair."""
        first_failed = bool(gen1.get('code_error'))
        second_failed = bool(gen2.get('code_error'))

        # A code error on either side decides the comparison. When both
        # errored, the first generation is not credited with a win. The
        # earlier "both failed" raise was unreachable and has been dropped.
        if first_failed or second_failed:
            return 'code_error', not first_failed

        if gen1['pass'] != gen2['pass']:
            return 'test_cases', gen1['pass'] > gen2['pass']
        # Equal pass counts: fewer (or equal) tokens wins.
        return 'tokens', tok1 <= tok2

    with open(small) as f1, open(large) as f2, open(output_file, 'w') as f_out:
        for line1, line2 in zip(f1, f2):
            o1, o2 = json.loads(line1), json.loads(line2)
            prompt = format_prompt(o1['item'])

            wins = 0
            for pair in zip(o1['results'], o2['results'], o1['num_tokens'], o2['num_tokens']):
                reason, is_win = get_win_reason(*pair)
                wins += is_win
                reason_freq[reason] += 1

            win_freq[wins] += 1
            # NOTE: the fixed divisor matches the "out of 10" axis label below.
            json.dump({'prompt': prompt, 'target': wins / 10}, f_out)
            f_out.write('\n')

    plt.style.use('seaborn-v0_8-paper')
    sns.set_palette("muted")

    fig = plt.figure(figsize=(12, 5))
    gs = GridSpec(1, 2, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])

    bar_color = "#4878D0"
    edge_color = "#2F4858"

    # Left panel: distribution of per-example win counts, ordered by count.
    k, v = zip(*sorted(win_freq.items(), key=lambda x: x[0]))
    sns.barplot(x=list(k), y=list(v), ax=ax1, color=bar_color, edgecolor=edge_color)
    ax1.set_title('Distribution of Model Wins', fontsize=11, pad=15)
    ax1.set_xlabel('Number of Wins (out of 10)', fontsize=10)
    ax1.set_ylabel('Frequency', fontsize=10)
    ax1.grid(True, axis='y', linestyle='--', alpha=0.7)
    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)

    # Right panel: deciding factors, most frequent first.
    reasons, counts = zip(*sorted(reason_freq.items(), key=lambda x: x[1], reverse=True))
    reason_labels = {'tokens': 'Less Tokens', 'test_cases': 'More Test Cases', 'code_error': 'Code Error'}
    formatted_reasons = [reason_labels[r] for r in reasons]

    sns.barplot(x=formatted_reasons, y=counts, ax=ax2, color=bar_color, edgecolor=edge_color)
    ax2.set_title('Determining Factor by Frequency', fontsize=11, pad=15)
    ax2.set_xlabel('Determining Factor', fontsize=10)
    ax2.set_ylabel('Frequency', fontsize=10)
    ax2.grid(True, axis='y', linestyle='--', alpha=0.7)
    ax2.spines['top'].set_visible(False)
    ax2.spines['right'].set_visible(False)

    plt.tight_layout()
    return fig, win_freq, reason_freq

# Each row: (small-model outputs, large-model outputs, wins file to write, figure to save).
comparisons = [
    (
        'outputs/1b_train_outputs.jsonl',
        'outputs/70b_train_outputs.jsonl',
        'outputs/1b_70b_wins.jsonl',
        'figures/1b_70b_analysis.png',
    ),
    (
        'outputs/8b_train_outputs.jsonl',
        'outputs/70b_train_outputs.jsonl',
        'outputs/8b_70b_wins.jsonl',
        'figures/8b_70b_analysis.png',
    ),
    (
        'outputs/1b_train_outputs.jsonl',
        'outputs/8b_train_outputs.jsonl',
        'outputs/1b_8b_wins.jsonl',
        'figures/1b_8b_analysis.png',
    ),
]

# Run the three pairwise analyses in order, saving the figure after each one.
for small_path, large_path, wins_path, figure_path in comparisons:
    fig, wins, reasons = analyze_model_comparison(small_path, large_path, wins_path)
    plt.savefig(figure_path)

## inference

In [20]:
import torch
from router import Router

model = Router(hidden_size=384)
# Load the checkpoint onto the CPU and keep only the entry stored under 'ema'.
checkpoint = torch.load('checkpoints/1b_8b_epoch=10.pt', map_location='cpu')
weights = checkpoint['ema']

# Checkpoint key -> parameter of the router head that receives its values.
head = model.head
weight_map = {
    'ema_model.head.0.weight': head[0].weight,
    'ema_model.head.0.bias': head[0].bias,
    'ema_model.head.3.weight': head[3].weight,
    'ema_model.head.3.bias': head[3].bias,
}

# Copy each stored tensor into the matching parameter in place.
for ema_key in weight_map:
    target_param = weight_map[ema_key]
    target_param.data.copy_(weights[ema_key])
    print(f'loaded {ema_key}')

loaded ema_model.head.0.weight
loaded ema_model.head.0.bias
loaded ema_model.head.3.weight
loaded ema_model.head.3.bias


In [10]:
# Print every parameter tensor of the model, one after another, for a visual check.
for tensor in model.parameters():
    print(tensor)

Parameter containing:
tensor([[ 1.1034e-03, -3.6097e-04, -1.9512e-03,  ..., -5.6000e-03,
         -6.8550e-03,  3.3997e-02],
        [ 1.1765e-02,  1.6769e-02, -1.3268e-02,  ...,  1.4679e-02,
         -3.7575e-03,  4.2969e-02],
        [ 2.1271e-02,  1.5610e-02, -1.2688e-02,  ...,  1.9821e-02,
         -1.6876e-02,  3.2959e-02],
        ...,
        [ 4.1008e-05, -4.1733e-03, -1.3351e-04,  ..., -1.5106e-02,
         -5.4779e-03,  3.0121e-02],
        [-6.1836e-03, -3.7823e-03, -6.9580e-03,  ..., -6.7368e-03,
         -5.2834e-03,  3.0716e-02],
        [-5.6953e-03, -4.5681e-04,  1.1883e-03,  ..., -6.3858e-03,
         -8.9340e-03,  3.2684e-02]])
Parameter containing:
tensor([0.7490, 0.8237, 0.7305, 0.7642, 0.7256, 0.6836, 0.7178, 0.7358, 0.7598,
        0.7578, 0.6855, 1.0371, 0.7280, 0.7002, 0.7441, 0.7314, 0.7251, 0.8076,
        0.7773, 0.7197, 0.7280, 0.7212, 0.6797, 0.7769, 0.7095, 0.7192, 0.6606,
        0.7988, 0.7466, 0.7651, 0.7490, 0.7417, 0.7075, 0.4795, 0.7549, 0.7065,
    