In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Load content into your drive
!rm -rf SpecDec
!git clone https://github.com/tsurbs/SpecDec.git

# TODO: Replace the following line with the actual path to the drive folder where you want to put the code
%cd /content/drive/MyDrive

In [None]:
!pip install -q transformers accelerate torch datasets matplotlib seaborn pandas
import torch
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
import os
import gc
import time
import sys

In [None]:
# TODO: Set your model directory here
MODEL_DIR = ""

# Draft models to benchmark. "Base" is the stock Hugging Face id; the rest are
# local state-dict checkpoints for the same architecture, stored under MODEL_DIR.
DRAFT_MODEL_PATHS = dict(
    Base="EleutherAI/pythia-70m",
    Full_FT=os.path.join(MODEL_DIR, "pythia-70m_full_finetune_best.pt"),
    High_Res=os.path.join(MODEL_DIR, "pythia-70m_high_resource_best.pt"),
    Med_Res=os.path.join(MODEL_DIR, "pythia-70m_medium_resource_best.pt"),
    Low_Res=os.path.join(MODEL_DIR, "pythia-70m_low_resource_best.pt"),
)
VERIFIER_CHECKPOINT = "EleutherAI/pythia-12b"

# Module-level caches so the heavy verifier/tokenizer are loaded only once per kernel.
GLOBAL_VERIFIER_CACHE = GLOBAL_TOKENIZER_CACHE = None

device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Running on: {device}")

In [None]:
# TODO: Set your project directory here
_project_dir = ''
# Guarded so re-running this cell does not append a duplicate entry every time.
if _project_dir not in sys.path:
    sys.path.append(_project_dir)
from baseline_benchmarking_utils import SpeculativeDecodingTester

class SpeculativeBenchmarker(SpeculativeDecodingTester):
    """
    A specialized benchmarker for our SpecDec project.

    Differences from the base class:
    1. Caching: keeps the verifier (Pythia-12B) and its tokenizer in the module-level
       GLOBAL_VERIFIER_CACHE / GLOBAL_TOKENIZER_CACHE so later instances reuse them.
    2. Loading: builds the Pythia-70M draft architecture and, unless the path is the
       stock "EleutherAI/pythia-70m" id, overwrites it with a local .pt state dict.
    3. Metrics: `speculative_decoding` also returns the number of accepted draft
       tokens per step ("acceptance sequence"), used for burstiness analysis.

    The baseline (`standard_autoregressive_generation`) is inherited from
    SpeculativeDecodingTester, whose implementation is not shown in this notebook.
    """

    def __init__(self, verifier_checkpoint, draft_weights_path, device=None):
        """
        Args:
            verifier_checkpoint: Hugging Face id of the verifier; also used for the tokenizer.
            draft_weights_path: "EleutherAI/pythia-70m" for the stock draft model,
                otherwise a path to a state-dict checkpoint for that architecture.
            device: Device for the draft model and inputs; defaults to CUDA when available.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Initializing Benchmarker on {self.device}")

        self._load_or_get_verifier(verifier_checkpoint)
        self._load_draft_model(draft_weights_path)

        self.verifier_model.eval()
        self.draft_model.eval()
        print("Init successful!")
        print()

    def _load_or_get_verifier(self, checkpoint_name):
        """
        Load the tokenizer and verifier, or reuse the module-level cached copies.

        The caches are only checked for None, so once populated, `checkpoint_name`
        is not consulted again for the rest of the kernel session.
        """
        global GLOBAL_VERIFIER_CACHE, GLOBAL_TOKENIZER_CACHE

        # Load Tokenizer
        if GLOBAL_TOKENIZER_CACHE is None:
            print(f"Loading Tokenizer: {checkpoint_name}")
            self.tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            GLOBAL_TOKENIZER_CACHE = self.tokenizer
        else:
            print("Using Cached Tokenizer.")
            self.tokenizer = GLOBAL_TOKENIZER_CACHE

        # Load Verifier Model (fp16, placed by accelerate's device_map="auto")
        if GLOBAL_VERIFIER_CACHE is None:
            print(f"Loading Verifier (Heavy): {checkpoint_name}")
            self.verifier_model = AutoModelForCausalLM.from_pretrained(
                checkpoint_name,
                device_map="auto",
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True
            )
            GLOBAL_VERIFIER_CACHE = self.verifier_model
        else:
            print("Using Cached Verifier from memory.")
            self.verifier_model = GLOBAL_VERIFIER_CACHE

    def _load_draft_model(self, weights_path):
        """
        Build the fp16 Pythia-70M draft model and optionally load custom weights.

        Args:
            weights_path: "EleutherAI/pythia-70m" keeps the pretrained weights; any
                other value is passed to `torch.load` and must be a state dict
                matching the Pythia-70M architecture.

        Raises:
            Whatever `torch.load` / `load_state_dict` raise, after logging it.
        """
        print("Empty GPU Cache before Draft Load.")
        gc.collect()
        torch.cuda.empty_cache()

        print("Loading Draft Architecture: EleutherAI/pythia-70m")
        self.draft_model = AutoModelForCausalLM.from_pretrained(
            "EleutherAI/pythia-70m",
            torch_dtype=torch.float16
        )

        # Overwrite with custom weights if provided
        if weights_path == "EleutherAI/pythia-70m":
            print("Using Standard Pretrained Weights (Base).")
        else:
            print(f"Overwriting weights from: {weights_path}")
            try:
                # NOTE(review): torch.load unpickles; only load checkpoints you trust.
                state_dict = torch.load(weights_path, map_location='cpu')
                self.draft_model.load_state_dict(state_dict)
                print("Weights loaded.")
            except Exception as e:
                print(f"CRITICAL ERROR loading weights: {e}")
                raise

        self.draft_model.to(self.device)

    @staticmethod
    def _sync_device():
        """Wait for queued CUDA work so wall-clock timings include all kernels."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    @torch.no_grad()
    def speculative_decoding(self, input_ids: torch.Tensor, max_new_tokens: int, gamma: int = 5):
        """
        Greedy speculative decoding that also records per-step acceptance counts.

        Each step works as follows:
        - The draft model greedily proposes up to `gamma` tokens, never more than
          are still needed.
        - The verifier scores prefix + proposal in a single forward pass.
        - The longest prefix of the proposal that matches the verifier's argmax
          tokens is accepted.
        - One verifier token is then appended: the correction at the first
          mismatch, or the "bonus" token when the whole proposal was accepted.

        NOTE(review): neither model reuses a KV cache across steps, so each step
        re-encodes the full prefix.

        Args:
            input_ids: Prompt token ids; only row 0 is used (batch size 1 assumed).
            max_new_tokens: Number of tokens to generate beyond the prompt.
            gamma: Maximum number of draft tokens proposed per step.

        Returns:
            (output_ids, latency_seconds, acceptance_rate, acceptance_counts):
            - output_ids is trimmed to prompt_len + max_new_tokens columns.
            - acceptance_rate = accepted / proposed draft tokens (0 if none proposed).
            - acceptance_counts holds the accepted draft tokens for each step.
        """
        self._sync_device()
        start_time = time.time()

        total_draft_tokens = 0
        accepted_draft_tokens = 0
        acceptance_counts = []  # Tracks how many draft tokens were accepted per step

        curr_input_ids = input_ids.clone()
        target_length = input_ids.shape[1] + max(max_new_tokens, 0)

        while curr_input_ids.shape[1] < target_length:
            prefix_len = curr_input_ids.shape[1]
            # Drafting past the requested length would be wasted work.
            n_draft = min(gamma, target_length - prefix_len)

            # 1. Draft
            draft_outputs = self.draft_model.generate(
                curr_input_ids,
                attention_mask=torch.ones_like(curr_input_ids),
                max_new_tokens=n_draft,
                do_sample=False,
                pad_token_id=self.tokenizer.eos_token_id
            )
            # May be shorter than n_draft if the draft emits EOS early.
            draft_tokens = draft_outputs[0, prefix_len:]
            n_proposed = draft_tokens.shape[0]

            # 2. Verify: logits at position p predict token p+1, so this slice gives
            # one prediction per draft position plus one for the token after them.
            logits = self.verifier_model(
                torch.cat([curr_input_ids, draft_tokens.unsqueeze(0)], dim=1)
            ).logits
            # With device_map="auto" the logits may live on another device.
            predicted_tokens = torch.argmax(logits[0, prefix_len - 1:], dim=-1).to(draft_tokens.device)

            # 3. Accept the longest matching prefix (cumprod zeroes everything after
            # the first mismatch; one host sync instead of one per token).
            matches = (draft_tokens == predicted_tokens[:n_proposed]).int()
            n_matches = int(matches.cumprod(dim=0).sum().item())

            # METRIC CAPTURE
            acceptance_counts.append(n_matches)
            total_draft_tokens += n_proposed
            accepted_draft_tokens += n_matches

            # 4/5. Update with the accepted tokens plus one verifier token
            # (the correction at the mismatch, or the bonus token after a full accept).
            new_tokens = torch.cat([draft_tokens[:n_matches], predicted_tokens[n_matches:n_matches + 1]])
            curr_input_ids = torch.cat([curr_input_ids, new_tokens.unsqueeze(0)], dim=1)

        # The extra verifier token can overshoot by one; trim so the output length
        # matches the requested generation length.
        curr_input_ids = curr_input_ids[:, :target_length]

        self._sync_device()
        latency = time.time() - start_time
        acc_rate = accepted_draft_tokens / total_draft_tokens if total_draft_tokens > 0 else 0

        return curr_input_ids, latency, acc_rate, acceptance_counts

    def run_single_test(self, prompt: str, max_new_tokens: int = 100, gamma: int = 5):
        """
        Time baseline vs speculative decoding on one prompt.

        The prompt is tokenized with truncation to 1024 tokens. The baseline runs
        first, then speculative decoding on the same input ids.

        Returns:
            dict with keys prompt_len, tokens_gen (the requested max_new_tokens),
            baseline_time, speculative_time, speedup (baseline / speculative, 0 if
            the speculative time is 0), acceptance_rate and acceptance_sequence.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
        input_ids = inputs.input_ids.to(self.device)

        # 1. Baseline Run
        _, base_time = self.standard_autoregressive_generation(input_ids, max_new_tokens)

        # 2. Speculative Run
        _, spec_time, acc_rate, acc_counts = self.speculative_decoding(
            input_ids, max_new_tokens, gamma=gamma
        )

        speedup = base_time / spec_time if spec_time > 0 else 0

        return {
            "prompt_len": input_ids.shape[1],
            "tokens_gen": max_new_tokens,
            "baseline_time": base_time,
            "speculative_time": spec_time,
            "speedup": speedup,
            "acceptance_rate": acc_rate,
            "acceptance_sequence": acc_counts
        }

In [None]:
from huggingface_hub import login
from google.colab import userdata
login(token=userdata.get("HF_TOKEN"))  # token comes from Colab's secret store, never hard-coded

In [None]:
from datasets import load_dataset

# Prompts sampled per language.
NUM_SAMPLES_PER_LANG = 50
# NOTE(review): PROMPT_LENGTH and SEED are not referenced by the other cells shown;
# prompts are truncated by character count inside the loaders instead — confirm intent.
PROMPT_LENGTH = 128
SEED = 10623

# Languages per bucket; bucket names presumably mirror the High/Med/Low resource
# draft checkpoints — confirm.
LANG_BUCKETS = dict(
    High_Res=["python", "cpp", "java"],
    Med_Res=["lua", "haskell", "ruby"],
    Low_Res=["zig", "elixir", "ocaml"],
)

In [None]:
def get_stack_prompts(languages, num_per_lang, min_chars=500, max_chars=1000):
    """Stream code files from The Stack (dedup) and turn them into prompts.

    For each language:
    - Files shorter than `min_chars` characters are skipped.
    - The first `num_per_lang` remaining files are kept.
    - Each kept file is truncated to its first `max_chars` characters.

    A language whose stream fails is reported and skipped (best effort), so the
    other languages are still collected.

    Args:
        languages: Directory names under `data/` in the dataset (e.g. "python").
        num_per_lang: Prompts to keep per language; 0 or less keeps none.
        min_chars: Minimum file length, in characters, for a file to be used.
        max_chars: Number of leading characters kept as the prompt text.

    Returns:
        A list of dicts with keys "language", "text" and "source".
    """
    prompts = []
    if num_per_lang <= 0:
        # The loop below appends before checking the count, so bail out early.
        return prompts

    for lang in languages:
        print(f"Streaming {lang}...")
        try:
            ds = load_dataset(
                "bigcode/the-stack-dedup",
                data_dir=f"data/{lang}",
                split="train",
                streaming=True
            )

            count = 0
            for sample in ds:
                content = sample['content']
                if len(content) < min_chars:
                    continue

                prompts.append({
                    "language": lang,
                    "text": content[:max_chars],
                    "source": "The Stack"
                })

                count += 1
                if count >= num_per_lang:
                    break
        except Exception as e:
            # Best effort: one gated/missing language should not abort ingestion.
            print(f"  ! Error loading {lang}: {e}")

    return prompts

In [None]:
def get_nl_prompts(num_samples, min_chars=500, max_chars=1000):
    """Stream natural-language prompts from the uncopyrighted Pile subset.

    Texts shorter than `min_chars` characters are skipped. The first
    `num_samples` remaining texts are kept, each truncated to `max_chars`
    characters. Unlike `get_stack_prompts`, loading errors propagate.

    Args:
        num_samples: Number of prompts to return; 0 or less returns none.
        min_chars: Minimum text length, in characters, for a sample to be used.
        max_chars: Number of leading characters kept as the prompt text.

    Returns:
        A list of dicts with "language" == "English", "text", and
        "source" == "The Pile".
    """
    print("Streaming NL (The Pile)...")
    prompts = []
    if num_samples <= 0:
        # The loop appends before checking the count, so bail out early.
        return prompts

    ds = load_dataset("monology/pile-uncopyrighted", split="train", streaming=True)
    for sample in ds:
        text = sample['text']
        if len(text) < min_chars:
            continue

        prompts.append({
            "language": "English",
            "text": text[:max_chars],
            "source": "The Pile"
        })
        if len(prompts) >= num_samples:
            break
    return prompts

In [None]:
DATASET = {}

print("Starting Data Ingestion")

# Code prompts: one prompt list per language bucket.
for bucket_label in LANG_BUCKETS:
    print(f"Loading {bucket_label} subset")
    DATASET[bucket_label] = get_stack_prompts(LANG_BUCKETS[bucket_label], NUM_SAMPLES_PER_LANG)

# Natural-language control set: 3 * NUM_SAMPLES_PER_LANG prompts, the size of a
# three-language bucket.
print("Loading Natural_Language subset")
DATASET["Natural_Language"] = get_nl_prompts(num_samples=3 * NUM_SAMPLES_PER_LANG)

print("Data Ingestion Complete")
for bucket_label, bucket_prompts in DATASET.items():
    print(f"{bucket_label}: {len(bucket_prompts)} prompts")

In [None]:
from tqdm.notebook import tqdm
import pandas as pd

def run_experiment_suite(gamma=5, max_new_tokens=64, debug=False):
    """
    Benchmark every draft model in DRAFT_MODEL_PATHS on every bucket in DATASET.

    Failures are handled as follows:
    - A model whose benchmarker fails to build is reported and skipped.
    - A prompt whose test raises is reported and skipped.

    After each model, the benchmarker is dropped and Python/CUDA caches are
    cleared before the next model loads.

    Args:
        gamma (int): Draft tokens proposed per speculative step.
        max_new_tokens (int): Tokens generated per prompt.
        debug (bool): If True, only runs 1 prompt per bucket per model.

    Returns:
        pd.DataFrame: One row per successful run, with columns Model, Bucket,
        Language, Gamma, Speedup, Alpha, Base_Time, Spec_Time, Burst_Sequence.
    """
    rows = []

    for model_name, weights_path in DRAFT_MODEL_PATHS.items():
        print(f"BENCHMARKING MODEL: {model_name}")

        try:
            bench = SpeculativeBenchmarker(
                verifier_checkpoint=VERIFIER_CHECKPOINT,
                draft_weights_path=weights_path,
                device=device,
            )
        except Exception as load_err:
            print(f"SKIPPING {model_name} due to load error: {load_err}")
            continue

        for bucket_name, bucket_prompts in DATASET.items():
            print(f"Processing Bucket: {bucket_name} ({len(bucket_prompts)} prompts)")

            # Debug mode keeps only the first prompt of each bucket.
            selected = bucket_prompts[:1] if debug else bucket_prompts
            progress = tqdm(
                enumerate(selected),
                total=len(selected),
                desc=f"{model_name} on {bucket_name}",
                leave=False,
            )

            for idx, item in progress:
                try:
                    result = bench.run_single_test(
                        prompt=item['text'],
                        max_new_tokens=max_new_tokens,
                        gamma=gamma,
                    )
                    rows.append(dict(
                        Model=model_name,
                        Bucket=bucket_name,
                        Language=item['language'],
                        Gamma=gamma,
                        Speedup=result['speedup'],
                        Alpha=result['acceptance_rate'],
                        Base_Time=result['baseline_time'],
                        Spec_Time=result['speculative_time'],
                        Burst_Sequence=result['acceptance_sequence'],
                    ))
                except Exception as prompt_err:
                    print(f"    Error on prompt {idx}: {prompt_err}")

        del bench
        gc.collect()
        torch.cuda.empty_cache()

    return pd.DataFrame(rows)

def save_results(df, filename="specdec_results.csv"):
    """Write `df` to MODEL_DIR/filename as CSV (index omitted) and return that path."""
    out_path = os.path.join(MODEL_DIR, filename)
    df.to_csv(out_path, index=False)
    print("\nResults saved to: " + out_path)
    return out_path

In [None]:
import datetime

GEN_LENGTH = 128
FIXED_GAMMA = 5

print("STARTING EXPERIMENT A")
print(f"Timestamp: {datetime.datetime.now()}")
print(f"Generation Length: {GEN_LENGTH}")
print(f"Gamma: {FIXED_GAMMA}")

df_main = run_experiment_suite(
    gamma=FIXED_GAMMA,
    max_new_tokens=GEN_LENGTH,
    debug=False
)

if df_main.empty:
    # No successful runs: don't clobber an earlier results file with an empty CSV.
    main_csv_path = None
    print("No successful runs; nothing saved. Check the errors above.")
else:
    main_csv_path = save_results(df_main, filename="specdec_main_results.csv")

print("EXPERIMENT A COMPLETE")
print(f"Total Runs: {len(df_main)}")
if not df_main.empty:
    # Guarded: an empty DataFrame has no 'Speedup' column (KeyError).
    print(f"Average Speedup: {df_main['Speedup'].mean():.2f}x")

In [None]:
GAMMAS = [1, 2, 4, 8, 16]
TARGET_MODEL = "Full_FT"
TARGET_BUCKET = "High_Res"
GAMMA_SWEEP_GEN_LENGTH = 128  # tokens generated per prompt at every gamma

print("STARTING EXPERIMENT B (GAMMA SWEEP)")
print(f"Model: {TARGET_MODEL}")
print(f"Gammas: {GAMMAS}")

gamma_results = []
sweep_failed = False
benchmarker = None

try:
    benchmarker = SpeculativeBenchmarker(
        verifier_checkpoint=VERIFIER_CHECKPOINT,
        draft_weights_path=DRAFT_MODEL_PATHS[TARGET_MODEL],
        device=device
    )

    prompts = DATASET[TARGET_BUCKET]

    for g in GAMMAS:
        print(f"Testing Gamma = {g}")

        for i, prompt_data in tqdm(enumerate(prompts), total=len(prompts), desc=f"Gamma {g}"):
            try:
                metrics = benchmarker.run_single_test(
                    prompt=prompt_data['text'],
                    max_new_tokens=GAMMA_SWEEP_GEN_LENGTH,
                    gamma=g
                )

                # Same columns as the Experiment A rows, so both CSVs share a schema.
                gamma_results.append({
                    "Model": TARGET_MODEL,
                    "Bucket": TARGET_BUCKET,
                    "Language": prompt_data['language'],
                    "Gamma": g,
                    "Speedup": metrics['speedup'],
                    "Alpha": metrics['acceptance_rate'],
                    "Base_Time": metrics['baseline_time'],
                    "Spec_Time": metrics['speculative_time'],
                    "Burst_Sequence": metrics['acceptance_sequence']
                })
            except Exception as e:
                print(f"Error on prompt {i}: {e}")

except Exception as e:
    sweep_failed = True
    print(f"CRITICAL FAILURE IN EXP B: {e}")
finally:
    # Release the draft model even if the sweep failed part-way.
    benchmarker = None
    gc.collect()
    torch.cuda.empty_cache()

# Persist whatever was measured, even after a partial failure.
if gamma_results:
    df_gamma = pd.DataFrame(gamma_results)
    gamma_csv_path = save_results(df_gamma, filename="specdec_gamma_sweep.csv")

if sweep_failed:
    print(f"EXPERIMENT B INCOMPLETE: {len(gamma_results)} runs saved")
else:
    print("EXPERIMENT B COMPLETE")

In [None]:
def load_exp_a_data(filename="specdec_main_results.csv"):
    """Load Experiment A results from MODEL_DIR and restore list-valued columns.

    A CSV round-trip stores each `Burst_Sequence` list as its string repr
    (e.g. "[1, 0, 5]"). Every such string cell is parsed back into a list;
    non-string cells (e.g. NaN) are left as-is, and an empty file with a header
    yields an empty frame.

    Args:
        filename: CSV file name inside MODEL_DIR.

    Returns:
        pd.DataFrame with `Burst_Sequence` holding Python lists.
    """
    import ast  # local import: do not rely on a later cell having imported it

    path = os.path.join(MODEL_DIR, filename)
    print(f"Loading data from {path}...")
    df = pd.read_csv(path)

    # Clean up lists (read from string back to list), cell by cell.
    if 'Burst_Sequence' in df.columns:
        df['Burst_Sequence'] = df['Burst_Sequence'].apply(
            lambda v: ast.literal_eval(v) if isinstance(v, str) else v
        )

    return df

In [None]:
def plot_overall_speedup(df):
    """
    Grouped bar chart of speedup per bucket, one bar per draft model.

    Bars show the mean speedup with +/- 1 standard-deviation error bars.
    - Models follow MODEL_ORDER.
    - Buckets follow BUCKET_ORDER; buckets absent from it are appended in data order.
    - Only models and buckets present in `df` are drawn.

    Saves the figure to MODEL_DIR/fig_expA_speedup.png.

    Args:
        df: Experiment A results with "Bucket", "Speedup" and "Model" columns.
    """
    plt.figure(figsize=(12, 7))

    # Filter to ensure we only plot models/buckets that exist in the CSV
    present_models = set(df['Model'].unique())
    available_models = [m for m in MODEL_ORDER if m in present_models]
    present_buckets = list(df['Bucket'].unique())
    bucket_order = [b for b in BUCKET_ORDER if b in present_buckets]
    bucket_order += [b for b in present_buckets if b not in BUCKET_ORDER]

    sns.barplot(
        data=df,
        x="Bucket",
        y="Speedup",
        hue="Model",
        order=bucket_order,
        hue_order=available_models,
        palette=CUSTOM_PALETTE,
        errorbar="sd",
        capsize=0.05,
        edgecolor="white",
        linewidth=1
    )

    # Reference Line
    plt.axhline(1.0, color='red', linestyle='--', linewidth=2, label="Baseline (1.0x)")

    # Styling
    plt.title("Speculative Decoding Speedup by Domain", fontweight='bold', pad=20)
    plt.ylabel("Speedup Factor (x)", fontweight='bold')
    plt.xlabel("")
    plt.ylim(0, None)  # Start at 0 so bar heights stay proportional to the speedup.
    plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)

    # Save
    save_path = os.path.join(MODEL_DIR, "fig_expA_speedup.png")
    plt.savefig(save_path)
    print(f"Saved Speedup Chart to {save_path}")
    plt.show()

In [None]:
def plot_language_disparity(df):
    """
    Compare Base vs Full_FT acceptance rate (Alpha) for every language in `df`.

    All buckets are included, including the "English" prompts from the
    Natural_Language bucket, so the x-axis is labelled "Language". Saves the
    figure to MODEL_DIR/fig_expA_fairness.png.

    Args:
        df: Experiment A results with "Model", "Language" and "Alpha" columns.
    """
    # Filter for just Base and Full_FT
    subset = df[df['Model'].isin(["Base", "Full_FT"])].copy()
    if subset.empty:
        print("No Base/Full_FT rows found; skipping fairness chart.")
        return

    plt.figure(figsize=(14, 6))

    sns.barplot(
        data=subset,
        x="Language",
        y="Alpha",
        hue="Model",
        hue_order=["Base", "Full_FT"],
        palette={"Base": "#95a5a6", "Full_FT": "#2ecc71"},
        alpha=0.9
    )

    plt.title("Impact of Finetuning on Token Acceptance Rate per Language", fontweight='bold', pad=20)
    plt.ylabel("Acceptance Rate (Alpha)")
    plt.xlabel("Language")
    plt.xticks(rotation=45, ha='right')
    plt.legend(title="Model Configuration")

    save_path = os.path.join(MODEL_DIR, "fig_expA_fairness.png")
    plt.savefig(save_path)
    print(f"Saved Fairness Chart to {save_path}")
    plt.show()

In [None]:
def plot_burstiness_distribution(df, target_bucket="High_Res", models_to_compare=("Base", "Full_FT")):
    """
    Histogram of how many draft tokens were accepted per speculative step.

    Answers "are we often accepting the full draft, or mostly 0 tokens?" by pooling
    every per-step count in `Burst_Sequence` for the chosen models on one bucket.
    Each model is normalised separately (`common_norm=False`), so models with
    different step counts stay comparable. Saves the figure to
    MODEL_DIR/fig_expA_burstiness.png.

    Args:
        df: Experiment A results whose "Burst_Sequence" cells are lists of ints.
        target_bucket: Bucket to plot (default "High_Res", for clarity).
        models_to_compare: Models to overlay; colours come from CUSTOM_PALETTE.
    """
    models = list(models_to_compare)
    subset = df[
        (df['Bucket'] == target_bucket) &
        (df['Model'].isin(models))
    ]

    # One row per speculative step instead of one list per prompt.
    df_burst = (
        subset[['Model', 'Burst_Sequence']]
        .explode('Burst_Sequence')
        .dropna(subset=['Burst_Sequence'])
        .rename(columns={'Burst_Sequence': 'Accepted Tokens'})
    )
    if df_burst.empty:
        print(f"No burst data for bucket {target_bucket}; skipping burstiness chart.")
        return
    df_burst['Accepted Tokens'] = df_burst['Accepted Tokens'].astype(int)

    present = set(df_burst['Model'].unique())

    plt.figure(figsize=(10, 6))

    sns.histplot(
        data=df_burst,
        x="Accepted Tokens",
        hue="Model",
        hue_order=[m for m in models if m in present],
        multiple="dodge",
        palette=CUSTOM_PALETTE,
        shrink=0.8,
        discrete=True,
        stat="density",  # Normalize so we can compare distributions
        common_norm=False
    )

    plt.title(f"Distribution of Accepted Tokens (Bucket: {target_bucket})", fontweight='bold')
    plt.xlabel("Number of Tokens Accepted per Step")
    plt.ylabel("Probability Density")

    save_path = os.path.join(MODEL_DIR, "fig_expA_burstiness.png")
    plt.savefig(save_path)
    print(f"Saved Burstiness Chart to {save_path}")
    plt.show()

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import ast
import os

# Global styling shared by every chart below (theme first, then rcParam overrides).
sns.set_theme(style="whitegrid", context="talk")
plt.rcParams.update({
    'figure.dpi': 300,
    'savefig.bbox': 'tight',
    'font.family': 'sans-serif',
})

# Display order for draft models (legend) and buckets (x-axis).
MODEL_ORDER = ["Base", "Low_Res", "Med_Res", "High_Res", "Full_FT"]
BUCKET_ORDER = ["High_Res", "Med_Res", "Low_Res", "Natural_Language"]

# One fixed colour per draft model, shared across all figures.
CUSTOM_PALETTE = dict(
    Base="#95a5a6",
    Full_FT="#2ecc71",
    High_Res="#3498db",
    Med_Res="#9b59b6",
    Low_Res="#e67e22",
)

print("Generating graphs...")
df_main = load_exp_a_data()
for plot_fn in (plot_overall_speedup, plot_language_disparity, plot_burstiness_distribution):
    plot_fn(df_main)

In [None]:
def load_gamma_data(filename="specdec_gamma_sweep.csv"):
    """Load the Experiment B (gamma sweep) CSV from MODEL_DIR.

    String cells in `Burst_Sequence` (list reprs such as "[0, 3]") are parsed
    back into lists cell by cell; non-string cells are left as-is, and an empty
    file with a header yields an empty frame.

    Args:
        filename: CSV file name inside MODEL_DIR.

    Returns:
        pd.DataFrame, or None (after printing a message) if the file does not exist.
    """
    import ast  # local import: do not rely on another cell having imported it

    path = os.path.join(MODEL_DIR, filename)
    if not os.path.exists(path):
        print(f"File not found: {path}")
        return None

    print(f"Loading data from {path}...")
    df = pd.read_csv(path)

    # Parse list strings back to objects
    if 'Burst_Sequence' in df.columns:
        df['Burst_Sequence'] = df['Burst_Sequence'].apply(
            lambda v: ast.literal_eval(v) if isinstance(v, str) else v
        )

    return df

In [None]:
def plot_gamma_speedup(df):
    """
    Line plot of speedup versus speculation length gamma (band = 1 SD).

    The gamma with the highest mean speedup is annotated. Saves the figure to
    MODEL_DIR/fig_expB_gamma_curve.png.

    Args:
        df: Gamma-sweep results with at least "Gamma" and "Speedup" columns.
    """
    if df.empty:
        print("No gamma-sweep rows; skipping gamma curve.")
        return

    plt.figure(figsize=(10, 6))

    # Calculate mean speedup per Gamma to find the peak
    means = df.groupby("Gamma")["Speedup"].mean()
    peak_gamma = means.idxmax()
    peak_val = means.max()

    # Line Plot
    sns.lineplot(
        data=df,
        x="Gamma",
        y="Speedup",
        marker="o",
        markersize=10,
        linewidth=3,
        color="#2ecc71",  # Matching 'Full_FT' green from Exp A
        errorbar="sd"
    )

    # Annotate the Peak. "\\gamma" keeps the backslash for mathtext; a bare "\g"
    # is an invalid escape sequence (SyntaxWarning on Python 3.12+), while "\n"
    # here is a real newline.
    plt.annotate(
        f'Optimal $\\gamma={peak_gamma}$\n({peak_val:.2f}x Speedup)',
        xy=(peak_gamma, peak_val),
        xytext=(peak_gamma, peak_val + 0.1),
        ha='center',
        arrowprops=dict(facecolor='black', shrink=0.05),
        fontsize=12,
        fontweight='bold'
    )

    # Formatting (raw strings so "\gamma" reaches mathtext unchanged)
    plt.axhline(1.0, color='red', linestyle='--', label="Baseline (1.0x)")
    plt.title(r"Impact of Speculation Length ($\gamma$) on Inference Speed", fontweight='bold', pad=20)
    plt.ylabel("Speedup Factor (x)", fontweight='bold')
    plt.xlabel(r"Speculation Length ($\gamma$)", fontweight='bold')

    # Force X-axis to show only the Gammas we tested
    tested_gammas = sorted(df['Gamma'].unique())
    plt.xticks(tested_gammas)

    plt.legend(loc='lower right')
    plt.tight_layout()

    save_path = os.path.join(MODEL_DIR, "fig_expB_gamma_curve.png")
    plt.savefig(save_path)
    print(f"Saved Gamma Curve to {save_path}")
    plt.show()

In [None]:
def plot_burst_heatmap(df):
    """
    Heatmap of P(accepted tokens = k) for each tested gamma.

    For every gamma, all per-step counts in `Burst_Sequence` are pooled and
    histogrammed over k = 0..gamma (unit-width bins, density=True, so each row
    sums to 1). Cells for k > gamma are empty (NaN after the pivot), and values
    below 0.01 are left unlabelled to reduce clutter. Saves the figure to
    MODEL_DIR/fig_expB_burst_heatmap.png.

    Args:
        df: Gamma-sweep results with "Gamma" and list-valued "Burst_Sequence".
    """
    heatmap_data = []

    # Process data
    unique_gammas = sorted(df['Gamma'].unique())
    for g in unique_gammas:
        subset = df[df['Gamma'] == g]
        all_counts = []
        # Flatten the lists
        for seq in subset['Burst_Sequence']:
            all_counts.extend(seq)

        # Calculate density
        # We bin from 0 to g+1 because you can accept 0 to g tokens
        counts, _ = np.histogram(all_counts, bins=range(0, g + 2), density=True)

        for k, prob in enumerate(counts):
            heatmap_data.append({
                "Gamma": g,
                "Accepted Tokens": k,
                "Probability": prob
            })

    df_heat = pd.DataFrame(heatmap_data)

    # Pivot for heatmap format
    df_pivot = df_heat.pivot(index="Gamma", columns="Accepted Tokens", values="Probability")

    # Create a custom label matrix: only show text if probability >= 0.01 to
    # avoid cluttering with "0.00" (NaN compares False, so it stays blank).
    # Series.map per column avoids DataFrame.applymap, deprecated in pandas 2.1.
    labels = df_pivot.apply(lambda col: col.map(lambda v: f"{v:.2f}" if v >= 0.01 else ""))

    plt.figure(figsize=(14, 7))

    sns.heatmap(
        df_pivot,
        annot=labels,  # Pass the clean labels
        fmt="",  # No formatting needed since labels are strings
        cmap="Greens",
        cbar_kws={'label': 'Probability Density'},
        linewidths=.5,
        linecolor='whitesmoke',
        annot_kws={"size": 11}  # Smaller font for readability
    )

    plt.title("Efficiency Heatmap: Accepted Tokens vs. Gamma Configuration", fontweight='bold', pad=20)
    plt.ylabel(r"Gamma Setting ($\gamma$)", fontweight='bold')
    plt.xlabel("Number of Tokens Accepted in one Step", fontweight='bold')
    plt.yticks(rotation=0)

    save_path = os.path.join(MODEL_DIR, "fig_expB_burst_heatmap.png")
    plt.savefig(save_path)
    print(f"Saved Burst Heatmap to {save_path}")
    plt.show()

In [None]:
# Plot the gamma sweep only once its CSV exists (load_gamma_data returns None otherwise).
df_gamma = load_gamma_data()

if df_gamma is None:
    print("Gamma data not found yet.")
else:
    plot_gamma_speedup(df_gamma)
    plot_burst_heatmap(df_gamma)