# Speculative Decoding Baseline Testing

This notebook provides an interactive interface for running speculative decoding benchmarks across multiple model configurations and datasets.

## Setup

First, ensure all dependencies are installed.

In [None]:
# Install dependencies (run once)
!pip install torch transformers accelerate datasets numpy

## Sync Helper Files to Colab

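If you're running this notebook in Google Colab, the local helper modules it imports (`baseline_test_utils.py` and `load_datasets.py`) must be present in the working directory. Run the next cell to download missing files from GitHub, or upload them manually via Colab's file browser.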

In [None]:
# Sync helper modules (for Colab users). The import cell below needs both
# baseline_test_utils.py and load_datasets.py next to this notebook.
import os
import urllib.request

try:
    import google.colab  # noqa: F401 -- only importable inside Colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

GITHUB_RAW_BASE = "https://raw.githubusercontent.com/tsurbs/SpecDec/main"
# NOTE(review): assumes load_datasets.py sits at the repo root next to
# baseline_test_utils.py -- confirm against the repository layout.
REQUIRED_FILES = ['baseline_test_utils.py', 'load_datasets.py']

if IN_COLAB:
    for filename in REQUIRED_FILES:
        if os.path.exists(filename):
            continue  # already uploaded/downloaded; don't overwrite local edits
        try:
            urllib.request.urlretrieve(f"{GITHUB_RAW_BASE}/{filename}", filename)
            print(f"Downloaded {filename} from GitHub")
        except Exception as e:
            # Drop any partially written file so the import fails clearly
            # instead of loading a truncated module.
            if os.path.exists(filename):
                os.remove(filename)
            print(f"Download of {filename} failed: {e}")
            print("Manually upload the file via Colab's file browser")
else:
    print("Using local files")

In [None]:
import json
import warnings

import torch

# Suppress every warning category for the rest of the session. This is set
# before the project imports below, so warnings raised while importing them are
# hidden too.
# NOTE(review): this also hides deprecation/runtime warnings that may matter.
warnings.filterwarnings('ignore')

# Project-local helper modules. The module name is `baseline_test_utils`;
# `from baseline_test_utils.py import ...` would instead look for a submodule
# called `py` inside a `baseline_test_utils` package and fail.
from baseline_test_utils import SpeculativeDecodingTester
from load_datasets import load_pile_samples, load_stack_samples

print(f"CUDA: {torch.cuda.is_available()}")

## Configuration

In [None]:
# Model pairs, one row per configuration: (display name, verifier checkpoint, draft checkpoint).
# Row order is preserved and determines the order of the all-models sweep.
_MODEL_PAIRS = [
    ('GPT-2',  'gpt2-large',            'distilgpt2'),
    ('Qwen',   'Qwen/Qwen2.5-7B',       'Qwen/Qwen2.5-0.5B'),
    ('Pythia', 'EleutherAI/pythia-12b', 'EleutherAI/pythia-70m'),
]
MODEL_CONFIGS = {
    name: {'verifier': verifier, 'draft': draft}
    for name, verifier, draft in _MODEL_PAIRS
}

# Benchmark knobs shared by every run below
TEST_PARAMS = dict(
    max_new_tokens=100,
    gamma=5,
    num_nl_samples=100,
    num_code_samples=100,
)

CODE_LANGUAGES = 'python c go rust'.split()

print(f"Models: {list(MODEL_CONFIGS.keys())}, Gamma: {TEST_PARAMS['gamma']}, Languages: {CODE_LANGUAGES}")

## Load Test Data

In [None]:
# Natural-language prompts drawn from The Pile
num_nl_samples = TEST_PARAMS['num_nl_samples']
nl_prompts = load_pile_samples(num_nl_samples)
print(f"Loaded {len(nl_prompts)} NL samples")

In [None]:
# Code prompts drawn from The Stack for the languages listed in CODE_LANGUAGES.
# NOTE(review): whether num_code_samples is per language or in total is decided
# inside load_stack_samples -- check there before interpreting the count.
num_code_samples = TEST_PARAMS['num_code_samples']
code_prompts = load_stack_samples(CODE_LANGUAGES, num_code_samples)
print(f"Loaded {len(code_prompts)} code samples")

In [None]:
# Combine all prompts into one benchmark set (NL prompts first, then code prompts)
all_prompts = nl_prompts + code_prompts
# Later cells index all_prompts[0]; fail here with a clear message if loading returned nothing.
assert all_prompts, "No prompts loaded - check load_pile_samples / load_stack_samples"
print(f"Total: {len(all_prompts)} prompts ({len(nl_prompts)} NL + {len(code_prompts)} Code)")

## Test a Single Model

In [None]:
# Select model to test: any key of MODEL_CONFIGS (currently 'GPT-2', 'Qwen', 'Pythia')
MODEL_TO_TEST = 'GPT-2'

# Validate against MODEL_CONFIGS itself so the list of options can never drift from the config.
if MODEL_TO_TEST not in MODEL_CONFIGS:
    raise KeyError(f"Unknown MODEL_TO_TEST {MODEL_TO_TEST!r}; choose one of {list(MODEL_CONFIGS)}")

config = MODEL_CONFIGS[MODEL_TO_TEST]
print(f"Testing {MODEL_TO_TEST}: {config['verifier']} + {config['draft']}")

In [None]:
# Build a tester for the verifier/draft checkpoint pair selected above
checkpoints = {
    'verifier_checkpoint': config['verifier'],
    'draft_checkpoint': config['draft'],
}
tester = SpeculativeDecodingTester(**checkpoints)

In [None]:
# Sanity check: one short, verbose generation on the first prompt before the full run
QUICK_TEST_MAX_NEW_TOKENS = 50
first_prompt_text = all_prompts[0]['text']
result = tester.run_single_test(
    prompt=first_prompt_text,
    max_new_tokens=QUICK_TEST_MAX_NEW_TOKENS,
    gamma=TEST_PARAMS['gamma'],
    verbose=True,
)

In [None]:
# Full benchmark of the selected model over every NL + code prompt
print(f"Running full benchmark on {MODEL_TO_TEST}...")
benchmark_kwargs = dict(
    max_new_tokens=TEST_PARAMS['max_new_tokens'],
    gamma=TEST_PARAMS['gamma'],
)
results = tester.run_benchmark_suite(prompts=all_prompts, **benchmark_kwargs)

In [None]:
# Save results for this model; the file name comes from the model key, e.g. 'GPT-2' -> results_gpt-2.json
output_file = "results_{}.json".format(MODEL_TO_TEST.lower())

def make_serializable(obj):
    """Recursively convert NumPy values inside ``obj`` into JSON-serializable builtins.

    Walks dicts, lists and tuples. Any NumPy scalar (``np.generic``: floats,
    ints, ``np.bool_``, ...) becomes the matching Python scalar via ``.item()``.
    Any ``np.ndarray`` becomes nested lists via ``.tolist()``. Tuples come back
    as lists, which is how ``json`` encodes them anyway. All other values are
    returned unchanged, so genuinely unsupported types still raise in
    ``json.dump``/``json.dumps``.

    Args:
        obj: Arbitrary, possibly nested, result structure.

    Returns:
        The same structure with NumPy scalars/arrays replaced by builtin Python
        values (and tuples replaced by lists).
    """
    import numpy as np

    if isinstance(obj, dict):
        # JSON object keys must be str/int/float/bool/None, so unwrap NumPy scalar keys as well.
        return {
            (k.item() if isinstance(k, np.generic) else k): make_serializable(v)
            for k, v in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        return [make_serializable(i) for i in obj]
    if isinstance(obj, np.ndarray):
        # tolist() yields builtins for numeric dtypes; recurse to also cover object arrays.
        return make_serializable(obj.tolist())
    if isinstance(obj, np.generic):
        # Covers np.floating, np.integer and np.bool_ (json cannot encode np.bool_ directly).
        return obj.item()
    return obj

# Serialize to a string BEFORE opening the file: json.dump streams into an
# already-truncated file, so a conversion error would leave an empty/corrupt
# JSON file (and destroy any previous results at this path).
results_json = json.dumps(make_serializable(results), indent=2)
with open(output_file, 'w') as f:
    f.write(results_json)

print(f"Results saved to: {output_file}")

## Run All Model Configurations

In [None]:
import gc  # used to collect freed model objects before emptying the CUDA cache

all_results = {}
failed_models = {}  # model name -> repr of the error that stopped its run

# Release any tester still bound from the single-model section, so its weights
# are not kept alive while the next pair is constructed.
tester = None

# Loop-local names (model_config / model_results) deliberately differ from the
# single-model section's `config` / `results`, so re-running those cells later
# doesn't silently pick up the last model from this sweep.
for model_name, model_config in MODEL_CONFIGS.items():
    print(f"Testing {model_name}: {model_config['verifier']} + {model_config['draft']}")

    try:
        tester = SpeculativeDecodingTester(
            verifier_checkpoint=model_config['verifier'],
            draft_checkpoint=model_config['draft'],
        )

        model_results = tester.run_benchmark_suite(
            prompts=all_prompts,
            max_new_tokens=TEST_PARAMS['max_new_tokens'],
            gamma=TEST_PARAMS['gamma'],
        )

        all_results[model_name] = model_results

        # Print summary
        for ptype, metrics in model_results['summary'].items():
            print(f"{ptype}: Acc={metrics['avg_acceptance_rate']}, Speedup={metrics['avg_speedup']}x")

        print(f"{model_name} complete")

    except Exception as e:
        # Best-effort sweep: record the failure (as a string, so no traceback keeps
        # the models alive) and move on to the next pair.
        failed_models[model_name] = repr(e)
        print(f"Error testing {model_name}: {e}")

    finally:
        # Always free the models, including after a failure such as an OOM;
        # otherwise the leftover weights make every later pair fail too.
        tester = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

print(f"All tests complete: {len(all_results)}/{len(MODEL_CONFIGS)} succeeded")
if failed_models:
    print(f"Failed: {list(failed_models)}")

## Results Analysis

In [None]:
# Save all results. Serialize to a string first so a conversion error cannot
# leave a truncated/empty JSON file (or wipe a previous run's results).
all_results_file = "baseline_results_all.json"
all_results_json = json.dumps(make_serializable(all_results), indent=2)
with open(all_results_file, 'w') as f:
    f.write(all_results_json)
print(f"Results saved to {all_results_file}")