# Superposition Distillation: The "Spoon Feed" Protocol

This notebook implements knowledge distillation via hidden state matching.

**The Key Insight**: Force a small 1.58-bit model to reconstruct the high-dimensional
hidden states of a large teacher. Because ternary weights are sparse/orthogonal,
they can encode MORE features via superposition than their dimension suggests.

**Protocol**:
1. **Phase 1**: Cache teacher hidden states (then delete teacher to free VRAM)
2. **Phase 2**: Train student to project UP to teacher's dimension space
3. **Phase 3**: Train the LM head (weight-tied to the token embeddings) to decode the new representations

In [None]:
# @title Setup: Install Dependencies
!pip install -q torch transformers datasets bitsandbytes accelerate tqdm

In [None]:
# @title Configuration
# @markdown Adjust these based on your hardware

# Hugging Face model id of the teacher. Its tokenizer, vocabulary size and
# hidden size are reused by the later cells for the student and the projector.
TEACHER_MODEL = "Qwen/Qwen2.5-Coder-1.5B-Instruct"  # @param ["Qwen/Qwen2.5-Coder-1.5B-Instruct", "Qwen/Qwen2.5-Coder-7B-Instruct", "TinyLlama/TinyLlama-1.1B-Chat-v1.0"]
# Hugging Face dataset id; Phase 1 streams its "train" split.
DATASET = "roneneldan/TinyStories"  # @param ["roneneldan/TinyStories", "Magpie-Align/Magpie-Reasoning-V2-250K-CoT-Deepseek-R1-Llama-70B"]

# Data settings
# Number of text samples whose teacher hidden states are cached in Phase 1.
NUM_CACHE_SAMPLES = 500  # @param {type:"slider", min:100, max:2000, step:100}
# Every sample is truncated or padded to exactly this many tokens.
MAX_SEQ_LENGTH = 128  # @param {type:"slider", min:64, max:512, step:64}

# Student settings
# STUDENT_DIM is split across STUDENT_HEADS attention heads using integer
# division (head_dim = STUDENT_DIM // STUDENT_HEADS), so it should be a
# multiple of STUDENT_HEADS.
STUDENT_DIM = 768  # @param {type:"integer"}
STUDENT_LAYERS = 12  # @param {type:"integer"}
STUDENT_HEADS = 12  # @param {type:"integer"}

# Training settings
# Optimizer steps for Phase 2 (hidden-state matching) and Phase 3 (LM head).
SUPERPOSITION_STEPS = 500  # @param {type:"slider", min:100, max:2000, step:100}
LM_HEAD_STEPS = 200  # @param {type:"slider", min:50, max:500, step:50}
# Number of cached samples drawn (with replacement) per training step.
BATCH_SIZE = 16  # @param {type:"slider", min:4, max:64, step:4}
# Learning rate for Phase 2 only; Phase 3 uses its own hard-coded rate.
LEARNING_RATE = 1e-3  # @param {type:"number"}

In [None]:
# @title Core Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional
from tqdm.auto import tqdm
import gc
import math

# Pick the compute device once; the later cells move models and batches to it.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# When an accelerator is present, report its name and total memory (in GB).
if device == "cuda":
    gpu_name = torch.cuda.get_device_name()
    print(f"GPU: {gpu_name}")
    total_vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"VRAM: {total_vram_gb:.1f} GB")

# =============================================================================
# BitNet Implementation
# =============================================================================

class STESign(torch.autograd.Function):
    """Ternary quantizer with a straight-through gradient.

    Forward rounds each element to the nearest integer and clips the result
    to [-1, 1], so outputs are in {-1, 0, +1}. Backward returns the incoming
    gradient unchanged, as if the quantizer were the identity.
    """

    @staticmethod
    def forward(ctx, x):
        rounded = x.round()
        return rounded.clamp(min=-1, max=1)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: the non-differentiable rounding is ignored.
        return grad_output


class STERound(torch.autograd.Function):
    """Round-to-nearest with a straight-through gradient.

    Forward applies ``torch.round`` element-wise; backward passes the
    upstream gradient through untouched.
    """

    @staticmethod
    def forward(ctx, x):
        return x.round()

    @staticmethod
    def backward(ctx, grad_output):
        # Treat rounding as the identity for gradient purposes.
        return grad_output


class BitLinear(nn.Module):
    """Linear layer that quantizes weights and activations on every forward.

    The trainable parameter stays in full precision. In ``forward`` the
    weights are mapped to {-1, 0, +1} via ``STESign`` and the activations are
    rounded to integer-valued levels via ``STERound`` (tensor dtypes are not
    changed); both use straight-through gradients.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        bias: When true, a zero-initialized full-precision bias is added
            after the rescaled matmul.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = False):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.bias = None
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, x):
        # One scale for the whole weight tensor: its mean absolute value.
        weight_scale = self.weight.abs().mean() + 1e-8
        ternary_weight = STESign.apply(self.weight / weight_scale)

        # One scale per vector along the last dim: its absolute maximum.
        act_scale = x.abs().amax(dim=-1, keepdim=True) + 1e-8
        scaled_act = torch.clamp(x / act_scale * 127.0, -128, 127)
        rounded_act = STERound.apply(scaled_act)

        # Matmul on the quantized operands, then undo both scalings.
        out = F.linear(rounded_act, ternary_weight, None)
        out = out * (weight_scale * act_scale / 127.0)

        return out if self.bias is None else out + self.bias


class RMSNorm(nn.Module):
    """RMS normalization over the last dimension with a learned gain.

    Args:
        dim: Size of the last dimension; also the length of the gain vector.
        eps: Added to the mean square before the square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        denom = torch.sqrt(mean_square + self.eps)
        normalized = x / denom
        return normalized * self.weight


# =============================================================================
# Student Model
# =============================================================================

class BitNetBlock(nn.Module):
    """Pre-norm decoder block: causal self-attention then a SwiGLU MLP.

    Every projection is a ``BitLinear``; both sub-layers are wrapped in a
    residual connection and preceded by an ``RMSNorm``.

    Args:
        hidden_dim: Width of the residual stream.
        num_heads: Number of attention heads.
        mlp_ratio: MLP inner width as a multiple of ``hidden_dim``.
    """

    def __init__(self, hidden_dim: int, num_heads: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        # Integer division: the head reshape in forward() only works when
        # hidden_dim is a multiple of num_heads.
        self.head_dim = hidden_dim // num_heads

        # Attention projections, registered in q, k, v, o order.
        for proj_name in ("q_proj", "k_proj", "v_proj", "o_proj"):
            setattr(self, proj_name, BitLinear(hidden_dim, hidden_dim))

        # SwiGLU MLP: gate and up expand, down contracts.
        inner_dim = int(hidden_dim * mlp_ratio)
        self.gate_proj = BitLinear(hidden_dim, inner_dim)
        self.up_proj = BitLinear(hidden_dim, inner_dim)
        self.down_proj = BitLinear(inner_dim, hidden_dim)

        # One norm in front of each sub-layer.
        self.input_norm = RMSNorm(hidden_dim)
        self.post_attn_norm = RMSNorm(hidden_dim)

    def forward(self, x):
        batch, seq_len, _ = x.shape

        def split_heads(t):
            # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
            return t.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # ---- Attention sub-layer ----
        normed = self.input_norm(x)
        q = split_heads(self.q_proj(normed))
        k = split_heads(self.k_proj(normed))
        v = split_heads(self.v_proj(normed))

        scores = torch.matmul(q, k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))

        # Strictly-upper-triangular mask hides later positions from each query.
        future = torch.triu(
            torch.ones(seq_len, seq_len, device=normed.device), diagonal=1
        ).bool()
        weights = F.softmax(scores.masked_fill(future, float('-inf')), dim=-1)

        context = torch.matmul(weights, v)
        context = context.transpose(1, 2).contiguous().view(batch, seq_len, self.hidden_dim)
        x = x + self.o_proj(context)

        # ---- MLP sub-layer ----
        normed = self.post_attn_norm(x)
        gated = F.silu(self.gate_proj(normed)) * self.up_proj(normed)
        return x + self.down_proj(gated)


class BitNetStudent(nn.Module):
    """BitNet student for superposition distillation.

    Token embedding -> stack of ``BitNetBlock`` -> final ``RMSNorm``. There is
    no separate output head: logits are computed with the embedding matrix
    itself (weight tying).

    Args:
        vocab_size: Number of rows in the embedding table.
        hidden_dim: Width of the residual stream.
        num_layers: Number of ``BitNetBlock`` layers.
        num_heads: Attention heads per block.
    """

    def __init__(self, vocab_size: int, hidden_dim: int = 768,
                 num_layers: int = 12, num_heads: int = 12):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size

        self.embed_tokens = nn.Embedding(vocab_size, hidden_dim)
        self.layers = nn.ModuleList([BitNetBlock(hidden_dim, num_heads) for _ in range(num_layers)])
        self.norm = RMSNorm(hidden_dim)

        nn.init.normal_(self.embed_tokens.weight, std=0.02)

    def forward(self, input_ids, return_hidden_states: bool = True):
        """Run the model on a 2-D batch of token ids.

        Args:
            input_ids: Integer tensor of shape (batch, seq_len).
            return_hidden_states: When true, return the final normalized
                hidden states of shape (batch, seq_len, hidden_dim);
                otherwise return logits of shape (batch, seq_len, vocab_size).
        """
        x = self.embed_tokens(input_ids)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)

        if return_hidden_states:
            return x
        # Tied LM head: reuse the embedding matrix as the output projection.
        return F.linear(x, self.embed_tokens.weight)

    def generate(self, input_ids, max_new_tokens=50, temperature=0.7):
        """Autoregressively extend ``input_ids`` by ``max_new_tokens`` tokens.

        The whole sequence is re-encoded at every step (no KV cache) and
        generation never stops early. The model is switched to eval mode and
        left there.

        Args:
            input_ids: Integer tensor of shape (batch, seq_len).
            max_new_tokens: Number of tokens to append.
            temperature: Softmax temperature for sampling. Values <= 0 select
                greedy (argmax) decoding instead of dividing by zero.

        Returns:
            Tensor of shape (batch, seq_len + max_new_tokens).
        """
        self.eval()
        # Sampling is not differentiable, so never build an autograd graph
        # here, even if the caller forgot to wrap the call in no_grad().
        with torch.no_grad():
            for _ in range(max_new_tokens):
                logits = self.forward(input_ids, return_hidden_states=False)
                next_logits = logits[:, -1, :]
                if temperature <= 0:
                    # Dividing by a zero temperature would yield inf/NaN
                    # probabilities and make torch.multinomial fail.
                    next_token = next_logits.argmax(dim=-1, keepdim=True)
                else:
                    probs = F.softmax(next_logits / temperature, dim=-1)
                    next_token = torch.multinomial(probs, 1)
                input_ids = torch.cat([input_ids, next_token], dim=-1)
        return input_ids


# =============================================================================
# Superposition Projector
# =============================================================================

class Decompressor(nn.Module):
    """Small MLP that maps student hidden states to the teacher's width.

    Layout: Linear(student_dim -> 2*student_dim), GELU, LayerNorm,
    Linear(2*student_dim -> teacher_dim), applied to the last dimension.
    """

    def __init__(self, student_dim: int, teacher_dim: int):
        super().__init__()
        inner_dim = student_dim * 2
        stages = [
            nn.Linear(student_dim, inner_dim),
            nn.GELU(),
            nn.LayerNorm(inner_dim),
            nn.Linear(inner_dim, teacher_dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        projected = self.net(x)
        return projected


# Confirmation message: reaching this line means the class definitions above ran without raising.
print("Core classes defined successfully!")

In [None]:
# @title Phase 1: Load Teacher & Cache Hidden States

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from datasets import load_dataset

print("=" * 60)
print("PHASE 1: Caching Teacher Hidden States")
print("=" * 60)

# Load tokenizer
print(f"\nLoading tokenizer: {TEACHER_MODEL}")
tokenizer = AutoTokenizer.from_pretrained(TEACHER_MODEL, trust_remote_code=True)
if tokenizer.pad_token is None:
    # Fixed-length padding below needs a pad token; reuse EOS when none exists.
    tokenizer.pad_token = tokenizer.eos_token

# Load teacher in 4-bit
print("Loading teacher model (4-bit)...")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
)

try:
    teacher = AutoModelForCausalLM.from_pretrained(
        TEACHER_MODEL,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
except Exception as e:
    # Deliberate best-effort fallback: any failure of the 4-bit path
    # is reported and the model is loaded in half precision instead.
    print(f"4-bit failed: {e}")
    print("Trying FP16...")
    teacher = AutoModelForCausalLM.from_pretrained(
        TEACHER_MODEL,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )

teacher.eval()
# These two values size the student (vocab) and the projector (hidden dim).
TEACHER_DIM = teacher.config.hidden_size
VOCAB_SIZE = teacher.config.vocab_size

print(f"Teacher hidden dim: {TEACHER_DIM}")
print(f"Vocab size: {VOCAB_SIZE}")

# Load dataset
print(f"\nLoading dataset: {DATASET}")
dataset = load_dataset(DATASET, split="train", streaming=True)

# Cache hidden states
print(f"\nCaching {NUM_CACHE_SAMPLES} samples...")
cached_inputs = []   # one CPU tensor of token ids per sample
cached_targets = []  # matching last-layer teacher hidden states, on CPU

count = 0
for item in tqdm(dataset, total=NUM_CACHE_SAMPLES):
    if count >= NUM_CACHE_SAMPLES:
        break

    # Get text: plain 'text' records, or instruction(+response) records.
    if 'text' in item:
        text = item['text']
    elif 'instruction' in item:
        text = item['instruction']
        if 'response' in item:
            text = f"{item['instruction']}\n{item['response']}"
    else:
        continue

    # Skip empty or very short records.
    if not text or len(text) < 10:
        continue

    # Tokenize to a fixed length so cached samples can be stacked later.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        max_length=MAX_SEQ_LENGTH,
        truncation=True,
        padding='max_length',
    ).to(device)

    # Get teacher hidden states. The attention mask is passed so that real
    # tokens never attend to padding (this matters for left-padding
    # tokenizers); states at padded positions themselves remain arbitrary.
    with torch.no_grad():
        outputs = teacher(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            output_hidden_states=True,
        )
        hidden = outputs.hidden_states[-1].cpu()  # Last layer
    # Drop the device-side outputs (logits + every layer's hidden states)
    # right away so they do not outlive the loop.
    del outputs

    cached_inputs.append(inputs.input_ids.cpu())
    cached_targets.append(hidden)
    count += 1

if not cached_inputs:
    raise RuntimeError(
        f"No usable samples were cached from dataset {DATASET!r}; "
        "expected records with a 'text' or 'instruction' field of at least 10 characters."
    )

print(f"\nCached {len(cached_inputs)} samples")
print(f"Input shape: {cached_inputs[0].shape}")
print(f"Target shape: {cached_targets[0].shape}")

# Free teacher memory. Collect garbage first so the teacher's tensors are
# actually released, then hand the cached blocks back to the driver.
print("\nUnloading teacher...")
del teacher
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
print("Teacher unloaded! VRAM freed.")

In [None]:
# @title Create Student Model

print("\nCreating student model...")

# Student first, projector second (same construction order as before, so
# seeded initialization is unchanged).
student_kwargs = dict(
    vocab_size=VOCAB_SIZE,
    hidden_dim=STUDENT_DIM,
    num_layers=STUDENT_LAYERS,
    num_heads=STUDENT_HEADS,
)
student = BitNetStudent(**student_kwargs)
student = student.to(device)

# Create projector: lifts student hidden states to the teacher's width.
projector = Decompressor(STUDENT_DIM, TEACHER_DIM)
projector = projector.to(device)

# Element counts over all parameters of each module.
total_params = sum(param.numel() for param in student.parameters())
proj_params = sum(param.numel() for param in projector.parameters())

print(f"Student parameters: {total_params / 1e6:.1f}M")
print(f"Projector parameters: {proj_params / 1e6:.1f}M")
print(f"Student hidden dim: {STUDENT_DIM}")
print(f"Teacher hidden dim: {TEACHER_DIM}")

In [None]:
# @title Phase 2: Superposition Training

print("\n" + "=" * 60)
print("PHASE 2: Superposition Training")
print("=" * 60)
print("Training student to project UP to teacher's hidden space...")

# Optimizer for both student and projector
optimizer = torch.optim.AdamW(
    list(student.parameters()) + list(projector.parameters()),
    lr=LEARNING_RATE,
    weight_decay=0.0,  # No weight decay for BitNet!
)

student.train()
projector.train()

num_samples = len(cached_inputs)
losses = []

# Samples were padded to a fixed length in Phase 1. The teacher's hidden
# states at padded positions carry no signal, so those positions are
# excluded from the loss below.
pad_token_id = tokenizer.pad_token_id

progress = tqdm(range(SUPERPOSITION_STEPS), desc="Superposition")
for step in progress:
    # Random batch, sampled with replacement from the cache.
    indices = torch.randint(0, num_samples, (BATCH_SIZE,)).tolist()

    # Get batch
    batch_in = torch.stack([cached_inputs[i] for i in indices])
    batch_target = torch.stack([cached_targets[i] for i in indices])

    # Each cached sample keeps its own leading batch dim of 1; drop it.
    if batch_in.dim() == 3 and batch_in.size(1) == 1:
        batch_in = batch_in.squeeze(1)
    if batch_target.dim() == 4 and batch_target.size(1) == 1:
        batch_target = batch_target.squeeze(1)

    batch_in = batch_in.to(device)
    batch_target = batch_target.to(device).float()

    # Forward: student -> projector
    student_hidden = student(batch_in, return_hidden_states=True)
    projected = projector(student_hidden)

    # Boolean mask over flattened (batch * seq) positions: True = real token.
    if pad_token_id is None:
        keep = torch.ones(batch_in.numel(), dtype=torch.bool, device=batch_in.device)
    else:
        keep = (batch_in != pad_token_id).reshape(-1)
    if not keep.any():
        # Degenerate all-padding batch: fall back to every position rather
        # than averaging over an empty set (which would give NaN).
        keep = torch.ones_like(keep)

    # Cosine similarity loss (more stable than MSE)
    proj_flat = projected.reshape(-1, TEACHER_DIM)[keep]
    target_flat = batch_target.reshape(-1, TEACHER_DIM)[keep]

    # We want cosine similarity = 1 (same direction)
    ones = torch.ones(proj_flat.size(0), device=device)
    loss = F.cosine_embedding_loss(proj_flat, target_flat, ones)

    # Backward
    optimizer.zero_grad()
    loss.backward()
    # Only the student's gradients are clipped; the projector's are not.
    torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0)
    optimizer.step()

    losses.append(loss.item())

    if step % 10 == 0:
        avg_loss = sum(losses[-10:]) / len(losses[-10:])
        progress.set_postfix({"loss": f"{avg_loss:.4f}"})

# Average over however many of the last 50 losses actually exist, instead of
# always dividing by 50.
recent_losses = losses[-50:]
print(f"\nFinal superposition loss: {sum(recent_losses) / max(len(recent_losses), 1):.4f}")

# Plot loss
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.xlabel('Step')
plt.ylabel('Cosine Embedding Loss')
plt.title('Superposition Training Loss')
plt.grid(True)
plt.show()

In [None]:
# @title Phase 3: LM Head Training (Speech Therapy)

# Imported here as well so this cell does not depend on the Phase 2 cell
# having imported it.
import matplotlib.pyplot as plt

print("\n" + "=" * 60)
print("PHASE 3: LM Head Training")
print("=" * 60)
print("Re-training output head to decode new representations...")

# Freeze everything except embeddings (which are tied to LM head)
for param in student.parameters():
    param.requires_grad = False

# Unfreeze embedding (tied to LM head). Because of the tying, this also
# updates the input embeddings, not only the output projection.
student.embed_tokens.weight.requires_grad = True

head_optimizer = torch.optim.AdamW(
    [student.embed_tokens.weight],
    lr=5e-3,  # Higher LR for head training
)

student.train()
lm_losses = []

num_samples = len(cached_inputs)
# Cached samples were padded to a fixed length in Phase 1; padding labels
# are excluded from the loss via ignore_index.
pad_token_id = tokenizer.pad_token_id
IGNORE_INDEX = -100

progress = tqdm(range(LM_HEAD_STEPS), desc="LM Head")
for step in progress:
    indices = torch.randint(0, num_samples, (BATCH_SIZE,)).tolist()

    batch_in = torch.stack([cached_inputs[i] for i in indices])
    # Each cached sample keeps its own leading batch dim of 1; drop it.
    if batch_in.dim() == 3 and batch_in.size(1) == 1:
        batch_in = batch_in.squeeze(1)
    batch_in = batch_in.to(device)

    # Forward with logits
    logits = student(batch_in, return_hidden_states=False)

    # Next token prediction loss: position t predicts token t + 1.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = batch_in[:, 1:].contiguous()
    if pad_token_id is not None:
        shift_labels = shift_labels.masked_fill(shift_labels == pad_token_id, IGNORE_INDEX)

    loss = F.cross_entropy(
        shift_logits.view(-1, VOCAB_SIZE),
        shift_labels.view(-1),
        ignore_index=IGNORE_INDEX,
    )

    head_optimizer.zero_grad()
    loss.backward()
    head_optimizer.step()

    lm_losses.append(loss.item())

    if step % 10 == 0:
        avg_loss = sum(lm_losses[-10:]) / len(lm_losses[-10:])
        progress.set_postfix({"loss": f"{avg_loss:.4f}"})

# Unfreeze all for future training
for param in student.parameters():
    param.requires_grad = True

# Average over however many of the last 20 losses actually exist, instead of
# always dividing by 20.
recent_lm_losses = lm_losses[-20:]
print(f"\nFinal LM loss: {sum(recent_lm_losses) / max(len(recent_lm_losses), 1):.4f}")

# Plot
plt.figure(figsize=(10, 4))
plt.plot(lm_losses)
plt.xlabel('Step')
plt.ylabel('Cross Entropy Loss')
plt.title('LM Head Training Loss')
plt.grid(True)
plt.show()

In [None]:
# @title Save Model

print("Saving model...")

# Hyperparameters needed to rebuild the student and projector before their
# state dicts can be loaded back.
checkpoint_config = {
    'vocab_size': VOCAB_SIZE,
    'hidden_dim': STUDENT_DIM,
    'num_layers': STUDENT_LAYERS,
    'num_heads': STUDENT_HEADS,
    'teacher_dim': TEACHER_DIM,
}

# Single file holding both sets of weights plus the config above.
checkpoint = {
    'student': student.state_dict(),
    'projector': projector.state_dict(),
    'config': checkpoint_config,
}
torch.save(checkpoint, 'superposition_student.pt')

print("Saved to 'superposition_student.pt'")

In [None]:
# @title Test Generation

banner_rule = "=" * 60
print("\n" + banner_rule)
print("GENERATION TEST")
print(banner_rule)

student.eval()

# A story opener, a factual completion and a code stub.
prompts = [
    "Once upon a time",
    "The capital of France is",
    "def fibonacci(n):",
]

for prompt in prompts:
    print(f"\nPrompt: {prompt}")

    # Encode the prompt as a single-row batch on the training device.
    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)

    # Sample 30 continuation tokens without tracking gradients.
    with torch.no_grad():
        output_ids = student.generate(input_ids, max_new_tokens=30, temperature=0.7)

    # Decode prompt + continuation back to text.
    first_sequence = output_ids[0]
    output_text = tokenizer.decode(first_sequence, skip_special_tokens=True)
    print(f"Output: {output_text}")
    print("-" * 40)

## Summary

This notebook implements **Superposition Distillation**:

1. **Phase 1**: Cache teacher's hidden states (then delete teacher)
2. **Phase 2**: Train student to project UP to teacher's dimension
3. **Phase 3**: Train the LM head (weight-tied to the token embeddings) to decode new representations

**Key insights**:
- 1.58-bit weights are naturally sparse/orthogonal
- Sparse vectors can encode MORE features via superposition
- Cosine loss is more stable than MSE for hidden state matching
- LM head needs retraining after superposition training