# ETHOS Demo: Text Generation

This notebook demonstrates how to load a trained ETHOS model and generate text.

**License**: This code is licensed under the AGPLv3. For commercial use, contact wryanmedford@gmail.com.

In [None]:
import sys
sys.path.append('..')

import torch
import torch.nn.functional as F
from model import CompressedMoEModel
from train import Config
import tiktoken
import numpy as np

## Load Model

In [None]:
# Read the model hyper-parameters from the default YAML config.
config = Config('../configs/default.yaml')

# Pick the GPU when one is visible, otherwise fall back to the CPU,
# then build the model on that device.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = CompressedMoEModel(config).to(device)

# Restore trained weights (point this path at your own checkpoint).
checkpoint_path = '../checkpoints/latest/pytorch_model.bin'
try:
    # NOTE(review): depending on the installed PyTorch version, torch.load
    # may unpickle arbitrary objects -- only load checkpoints you trust.
    state_dict = torch.load(checkpoint_path, map_location=device)
except FileNotFoundError:
    # A missing checkpoint is tolerated: the demo continues with the
    # freshly constructed (untrained) weights.
    print("No checkpoint found, using random initialization")
else:
    model.load_state_dict(state_dict)
    print("Model loaded successfully!")

# Put the model in evaluation mode before generating.
model.eval()

# Tokenizer shared by the cells below.
enc = tiktoken.get_encoding("cl100k_base")

## Text Generation Function

In [None]:
@torch.no_grad()
def generate(
    model,
    prompt,
    max_new_tokens=100,
    temperature=0.8,
    top_k=50,
    top_p=0.9,
    repetition_penalty=1.0,
    seed=None
):
    """Generate a continuation of ``prompt`` one token at a time.

    The function relies on the notebook-level ``enc`` (tokenizer) and
    ``device`` globals. On every step the whole sequence so far is fed
    through ``model`` again and the logits at the last position are used
    to choose the next token.

    Args:
        model: Callable that maps a ``(1, seq_len)`` tensor of token ids to
            a tensor indexable as ``outputs[0, -1, :]`` (next-token logits).
        prompt: Text to continue. It must encode to at least one token.
        max_new_tokens: Upper bound on the number of tokens to generate.
        temperature: Logits are divided by this value before sampling.
            ``0`` selects the highest-scoring token deterministically
            (greedy decoding); in that case ``top_k``/``top_p`` are unused.
        top_k: Keep only the ``top_k`` highest logits. ``None`` or a value
            ``<= 0`` disables the filter; values larger than the vocabulary
            are clamped to the vocabulary size.
        top_p: Nucleus threshold. ``None`` or a value ``>= 1.0`` disables
            the filter.
        repetition_penalty: Penalty applied to the logits of tokens that
            were already generated in this call (prompt tokens are not
            penalised). ``1.0`` disables it.
        seed: If given, seeds both the torch and numpy global RNGs.

    Returns:
        The decoded text of the prompt followed by the generated tokens.
        If the end-of-text token is produced it is included and generation
        stops.

    Raises:
        ValueError: If ``temperature`` is negative or ``prompt`` encodes to
            zero tokens.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")

    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)

    # Encode prompt
    prompt_ids = enc.encode(prompt, allowed_special="all")
    if not prompt_ids:
        # An empty sequence has no "last position" to read logits from.
        raise ValueError("prompt must encode to at least one token")
    input_ids = torch.tensor([prompt_ids], dtype=torch.long).to(device)

    # bfloat16 autocast is only enabled on CUDA; elsewhere the context
    # manager is a no-op (and, unlike torch.cuda.amp.autocast, silent).
    device_type = input_ids.device.type
    use_amp = device_type == "cuda"
    greedy = temperature == 0

    generated_tokens = []

    # NOTE(review): the sequence grows without truncation; confirm that
    # prompt length + max_new_tokens fits the model's context window.
    for _ in range(max_new_tokens):
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16, enabled=use_amp):
            outputs = model(input_ids)

        # Logits for the last position, promoted to float32 so that the
        # filtering and softmax below are not done in reduced precision.
        next_token_logits = outputs[0, -1, :].float()

        # Apply repetition penalty to every already-generated token at once:
        # negative logits are multiplied, non-negative ones divided.
        if repetition_penalty != 1.0 and generated_tokens:
            seen = torch.tensor(
                list(set(generated_tokens)),
                dtype=torch.long,
                device=next_token_logits.device,
            )
            seen_logits = next_token_logits[seen]
            next_token_logits[seen] = torch.where(
                seen_logits < 0,
                seen_logits * repetition_penalty,
                seen_logits / repetition_penalty,
            )

        if greedy:
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        else:
            # Apply temperature
            if temperature != 1.0:
                next_token_logits = next_token_logits / temperature

            # Apply top-k filtering (k clamped so topk never exceeds the vocab)
            if top_k is not None and top_k > 0:
                k = min(top_k, next_token_logits.size(-1))
                kth_value = torch.topk(next_token_logits, k).values[-1]
                next_token_logits[next_token_logits < kth_value] = float('-inf')

            # Apply top-p filtering
            if top_p is not None and top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                # Shift the mask right by one so the first token that crosses
                # the threshold is kept, and never drop the top token.
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
                sorted_indices_to_remove[0] = False

                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                next_token_logits[indices_to_remove] = float('-inf')

            # Sample
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

        token_id = next_token.item()
        generated_tokens.append(token_id)

        # Append to input
        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

        # Stop if we hit the end token
        if token_id == enc.eot_token:
            break

    # Decode prompt + continuation
    return enc.decode(input_ids[0].tolist())

## Generate Text

In [None]:
# A few fixed prompts to exercise the generator.
prompts = [
    "The future of artificial intelligence is",
    "In a world where technology",
    "The most important scientific discovery",
]

divider = "-" * 50

for prompt in prompts:
    print(f"\nPrompt: '{prompt}'")
    print(divider)

    # Every prompt uses the same sampling settings and the same seed.
    output = generate(model, prompt, max_new_tokens=50, temperature=0.8, top_k=50, top_p=0.9, seed=42)

    print(output)

## Interactive Generation

In [None]:
# Edit `your_prompt` and re-run this cell to try your own text.
# No seed is passed, so the output is not pinned to a fixed seed here.
your_prompt = "Once upon a time"

output = generate(model, your_prompt, max_new_tokens=100, temperature=0.8, top_k=50, top_p=0.9)

print(output)

## Analyze Expert Usage

In [None]:
# Analyze which experts are being used
@torch.no_grad()
def analyze_expert_usage(model, text, layer_idx=2):
    """Print and return the expert routing of one layer for ``text``.

    The text is tokenized with the notebook-level ``enc`` and embedded with
    ``model.tok_embeddings``. The layers before ``layer_idx`` are then run
    with a causal attention mask, and the hidden states entering layer
    ``layer_idx`` are passed to that layer's ``mlp.router``.

    Args:
        model: Model exposing ``tok_embeddings`` and an iterable ``layers``
            whose elements have an ``mlp`` attribute.
        text: Input text to route.
        layer_idx: Index into ``model.layers`` of the layer to inspect.

    Returns:
        The ``(scores, indices)`` pair returned by ``layer.mlp.router``, or
        ``None`` when ``layer_idx`` is out of range or that layer's ``mlp``
        has no ``router`` attribute.
    """
    # Encode text
    token_ids = enc.encode(text, allowed_special="all")
    input_ids = torch.tensor([token_ids], dtype=torch.long).to(device)

    # Forward pass through layers
    h = model.tok_embeddings(input_ids)

    # The causal mask (-inf strictly above the diagonal) and the position
    # ids depend only on the sequence length, so build them once up front.
    seq_len = h.shape[1]
    attention_mask = torch.triu(
        torch.full((seq_len, seq_len), float("-inf"), device=device, dtype=h.dtype),
        diagonal=1,
    )
    position_ids = torch.arange(0, seq_len, dtype=torch.long, device=device)

    for i, layer in enumerate(model.layers):
        if i == layer_idx:
            if not hasattr(layer.mlp, 'router'):
                # Nothing to analyze here; running later layers is pointless.
                return None

            # Get routing decisions (reshape also handles non-contiguous h)
            x_flat = h.reshape(-1, config.d_model)
            scores, indices = layer.mlp.router(x_flat)

            # Count how often each expert id occurs and rank by frequency.
            expert_ids, counts = torch.unique(indices, return_counts=True)
            ranking = torch.argsort(counts, descending=True)[:5]
            top_experts = expert_ids[ranking].tolist()

            # Analyze
            print(f"Layer {i} Expert Usage:")
            print(f"  Shape: scores={scores.shape}, indices={indices.shape}")
            print(f"  Unique experts used: {len(expert_ids)}")
            print(f"  Top 5 most used experts: {top_experts}")

            return scores, indices

        # Regular forward
        h = layer(h, attention_mask, position_ids)

    # layer_idx did not match any layer.
    return None

# Analyze a sample text. analyze_expert_usage returns None when the chosen
# layer does not exist or has no router, so check before unpacking instead
# of failing with an opaque "cannot unpack NoneType" error.
routing = analyze_expert_usage(model, "The future of AI is bright", layer_idx=3)
if routing is None:
    scores, indices = None, None
    print("No routing information available for layer 3")
else:
    scores, indices = routing