# Multi-Subset InstABoost with GPT-OSS-20B (Reasoning Model)

This notebook demonstrates InstABoost, which boosts the attention a model pays to instruction tokens, on OpenAI's GPT-OSS-20B reasoning model.

We'll:
1. Load the reasoning model
2. Create prompts using the model's chat template
3. Apply boosting to instruction tokens
4. Compare outputs with and without boosting

In [None]:
import sys
sys.path.append('..')

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from src.boost_config import TokenSubset, BoostConfig
from src.token_utils import create_token_subset_from_substring
from src.attention_hook import register_boost_hooks, unregister_boost_hooks, update_bias_mask

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 1. Load Model and Tokenizer

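Load GPT-OSS-20B with `torch_dtype=torch.bfloat16` (half the memory of float32 for any weights held in that dtype) and let `device_map="auto"` place the layers on the available devices.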
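
As a rough sanity check on memory, weight storage is approximately parameter count × bytes per parameter. The sketch below assumes about 21B total parameters (the figure OpenAI reports for gpt-oss-20b) and ignores activations and the KV cache. The released checkpoint stores its MoE weights in MXFP4, so actual usage depends on whether `transformers` keeps them quantized. `weight_memory_gb` is a hypothetical helper, not part of `src`.

```python
def weight_memory_gb(num_params: float, bytes_per_param: float) -> float:
    """Approximate memory for model weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


# Assumption: ~21B total parameters. In the notebook, model.num_parameters()
# can replace this constant with the count for the loaded model.
NUM_PARAMS = 21e9
for dtype_name, nbytes in [("float32", 4), ("bfloat16", 2)]:
    print(f"{dtype_name:>8}: ~{weight_memory_gb(NUM_PARAMS, nbytes):.0f} GB")
```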

In [None]:
model_name = "openai/gpt-oss-20b"
print(f"Loading {model_name}...")

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.bfloat16
)

device = next(model.parameters()).device
print(f"Model loaded on {device}")
print(f"Model type: {model.config.model_type}")
print(f"Number of layers: {model.config.num_hidden_layers}")

## 2. Prepare Input with Chat Template

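GPT-OSS-20B uses OpenAI's harmony chat format, in which the assistant writes its reasoning to an `analysis` channel and its user-facing answer to a `final` channel. With `add_generation_prompt=True`, the template ends the prompt by opening an assistant turn, so the model's first generated tokens select a channel.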
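
Because the decoded output interleaves both channels, it helps to split them before comparing runs. The sketch below is a hypothetical helper (not part of `src`) that assumes the decoded text keeps the harmony markers `<|channel|>`, `<|message|>`, `<|end|>`, `<|return|>` and `<|call|>`, which it should when decoding with `skip_special_tokens=False`. It only handles plain `<|channel|>NAME<|message|>` headers; tool-call headers with extra fields are skipped.

```python
import re

# One harmony message: <|channel|>NAME<|message|>CONTENT, where CONTENT runs until
# <|end|>, <|return|>, <|call|>, or the end of the string (truncated generation).
_CHANNEL_RE = re.compile(
    r"<\|channel\|>(\w+)<\|message\|>(.*?)(?=<\|end\|>|<\|return\|>|<\|call\|>|$)",
    re.DOTALL,
)


def split_channels(generated: str) -> dict:
    """Group decoded GPT-OSS output by channel name, e.g. 'analysis' and 'final'."""
    parts = {}
    for name, content in _CHANNEL_RE.findall(generated):
        parts.setdefault(name, []).append(content.strip())
    return {name: "\n".join(chunks) for name, chunks in parts.items()}


sample = (
    "<|channel|>analysis<|message|>User wants exactly 3 words.<|end|>"
    "<|start|>assistant<|channel|>final<|message|>Paris, of course.<|return|>"
)
print(split_channels(sample))
# {'analysis': 'User wants exactly 3 words.', 'final': 'Paris, of course.'}
```

For example, `split_channels(generated_text).get("final")` returns the user-facing answer, or `None` if generation stopped before the model reached the `final` channel.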

In [None]:
# Create a simple instruction-following task
user_message = "Answer in exactly 3 words: What is the capital of France?"

messages = [
    {"role": "user", "content": user_message}
]

# Apply chat template
formatted_prompt = tokenizer.apply_chat_template(
    messages, 
    tokenize=False, 
    add_generation_prompt=True
)

print("Formatted prompt:")
print(formatted_prompt)
print(f"\nPrompt length: {len(formatted_prompt)} characters")

In [None]:
# Tokenize the formatted prompt. The chat template already inserted the special
# tokens, so don't let the tokenizer add more (e.g. a duplicate BOS).
input_ids = tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=False).input_ids.to(device)
input_length = input_ids.shape[1]

print(f"Token count: {input_length}")
print("\nTokens (showing user message portion):")

# Reuse the exact ids the model sees so printed positions match input_ids
tokens = input_ids[0].tolist()
for i, token_id in enumerate(tokens):
    token_text = tokenizer.decode([token_id])
    # Print tokens whose +/-5-token window contains the start of the user message
    if user_message[:10] in tokenizer.decode(tokens[max(0, i - 5):i + 5]):
        print(f"  {i}: '{token_text}' (id={token_id})")

## 3. Generate Without Boosting (Baseline)

In [None]:
print("Generating without boosting...")

with torch.no_grad():
    output_baseline = model.generate(
        input_ids,
        max_new_tokens=100,
        do_sample=False,  # Deterministic for comparison
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

baseline_text = tokenizer.decode(output_baseline[0], skip_special_tokens=False)
print(f"\nBaseline output:")
print(baseline_text)

In [None]:
# Extract just the generated part (after the prompt)
generated_tokens = output_baseline[0][input_length:]
generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=False)
print("Generated (after prompt):")
print(generated_text)

## 4. Create Boost Configuration

We'll boost attention to the instruction part: "Answer in exactly 3 words"
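
One way to map an instruction substring to token positions is through character offsets: find the substring's character span in the prompt, then keep every token whose span overlaps it. This is a sketch under that assumption and may differ from how `create_token_subset_from_substring` works internally; `substring_token_indices` is hypothetical. With a fast tokenizer, `tokenizer(formatted_prompt, add_special_tokens=False, return_offsets_mapping=True)["offset_mapping"]` supplies the per-token `(start, end)` pairs.

```python
def substring_token_indices(text, substring, offsets):
    """Indices of tokens whose (start, end) character span overlaps the first
    occurrence of `substring` in `text`.

    `offsets` holds one (start, end) pair per token, like the `offset_mapping`
    a Hugging Face fast tokenizer returns with return_offsets_mapping=True.
    """
    start = text.find(substring)
    if start == -1:
        raise ValueError(f"{substring!r} not found in text")
    end = start + len(substring)
    # Skip zero-width spans (e.g. special tokens reported as (0, 0)); keep any
    # token that starts before the substring ends and ends after it starts.
    return [i for i, (s, e) in enumerate(offsets) if e > s and s < end and e > start]


# Toy example with hand-written offsets (a real tokenizer would produce these).
text = "Answer in exactly 3 words: hi"
offsets = [(0, 6), (6, 9), (9, 17), (17, 19), (19, 25), (25, 26), (26, 29)]
print(substring_token_indices(text, "Answer in exactly 3 words", offsets))  # [0, 1, 2, 3, 4]
```

Overlap (rather than strict containment) is used so that a token straddling the substring boundary, such as one whose span includes a leading space, is still boosted.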

In [None]:
# Find the instruction substring in the formatted prompt
instruction = "Answer in exactly 3 words"

# Create token subset for the instruction
# NOTE: bias=2.0 worked best for GPT-OSS-20B in our sweep (Section 7); 5.0 and above caused repetition.
instruction_subset = create_token_subset_from_substring(
    name="instruction",
    text=formatted_prompt,
    substring=instruction,
    tokenizer=tokenizer,
    bias=2.0  # Best value found in the bias sweep for this model
)

print(f"Instruction subset: {instruction_subset}")
print(f"Token indices: {instruction_subset.indices}")

# Show which tokens are being boosted
print(f"\nBoosted tokens:")
for idx in instruction_subset.indices:
    token_text = tokenizer.decode([tokens[idx]])
    print(f"  {idx}: '{token_text}'")

In [None]:
# Create boost configuration
config = BoostConfig(subsets=[instruction_subset])
print(f"Boost config: {config}")

## 5. Generate With Boosting

In [None]:
# Register hooks
print("Registering boost hooks...")
handle = register_boost_hooks(model, config, input_length=input_length)

# Update bias mask with actual sequence length
update_bias_mask(handle, seq_length=input_length, device=device)

print(f"Bias mask shape: {handle.bias_mask.shape}")
print(f"Non-zero bias positions: {(handle.bias_mask != 0).sum().item()}")
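
To make the bias mask concrete: assuming the hooks add each subset's bias to the pre-softmax attention scores of its key positions, boosting multiplies those keys' unnormalized attention weights by e^bias before renormalization (scaling post-softmax weights by a factor and renormalizing is equivalent). The pure-Python sketch below builds a per-key bias vector from several subsets, summing biases where subsets overlap (an assumption; `src` may combine them differently), and applies it to one row of attention scores. In PyTorch the same step is `torch.softmax(scores + bias, dim=-1)`.

```python
import math


def build_key_bias(seq_len, subsets):
    """Per-key additive bias from several (indices, bias) subsets.

    Assumption for illustration: where subsets overlap, their biases add.
    """
    bias = [0.0] * seq_len
    for indices, value in subsets:
        for i in indices:
            bias[i] += value
    return bias


def softmax(xs):
    """Numerically stable softmax; entries of -inf (masked keys) get weight 0."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def boosted_attention_row(scores, key_bias):
    """Add the key bias to one query's pre-softmax scores, then renormalize."""
    return softmax([s + b for s, b in zip(scores, key_bias)])


# One query over six keys with equal raw scores; two hypothetical subsets.
scores = [0.0] * 6
key_bias = build_key_bias(6, [([1, 2], 2.0), ([4], 1.0)])
weights = boosted_attention_row(scores, key_bias)
print([round(w, 3) for w in weights])
```

Keys outside every subset get zero bias, so their weights relative to each other are unchanged; masked positions (score `float("-inf")`) stay at zero weight.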

In [None]:
print("Generating with boosting...")

with torch.no_grad():
    output_boosted = model.generate(
        input_ids,
        max_new_tokens=100,
        do_sample=False,  # Deterministic for comparison
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

boosted_text = tokenizer.decode(output_boosted[0], skip_special_tokens=False)
print(f"\nBoosted output:")
print(boosted_text)

In [None]:
# Extract just the generated part
boosted_generated_tokens = output_boosted[0][input_length:]
boosted_generated_text = tokenizer.decode(boosted_generated_tokens, skip_special_tokens=False)
print("Generated with boost (after prompt):")
print(boosted_generated_text)

In [None]:
# Clean up hooks
unregister_boost_hooks(handle)
print("Hooks unregistered.")
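
Because a failed `generate` call (for example, a CUDA out-of-memory error) would skip the cleanup cell and leave hooks attached to the model, it can be safer to tie registration and removal together. A minimal sketch using `contextlib.contextmanager`; `boost_hooks` is a hypothetical wrapper that takes the register/unregister functions as arguments, so it makes no assumptions about their internals.

```python
from contextlib import contextmanager


@contextmanager
def boost_hooks(register, unregister, *args, **kwargs):
    """Register hooks on entry and always unregister them on exit, even if the
    body raises. `register` must return the handle that `unregister` expects."""
    handle = register(*args, **kwargs)
    try:
        yield handle
    finally:
        unregister(handle)


# Notebook usage (functions from src.attention_hook, as imported above):
# with boost_hooks(register_boost_hooks, unregister_boost_hooks,
#                  model, config, input_length=input_length) as handle:
#     update_bias_mask(handle, seq_length=input_length, device=device)
#     with torch.no_grad():
#         output = model.generate(input_ids, max_new_tokens=100, do_sample=False)
```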

## 6. Compare Outputs

In [None]:
print("=" * 80)
print("COMPARISON")
print("=" * 80)
print(f"\nInstruction: {instruction}")
print(f"Boost bias: {instruction_subset.bias}")
print(f"\n{'-'*40}")
print("BASELINE (no boosting):")
print(generated_text)
print(f"\n{'-'*40}")
print(f"BOOSTED (instruction bias={instruction_subset.bias}):")
print(boosted_generated_text)
print("=" * 80)
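
Eyeballing the comparison works for one prompt; for the "exactly 3 words" instruction, a programmatic check is easy to sketch. The hypothetical helpers below count word-like runs, treating punctuation as a separator and counting hyphenated or apostrophe words once (a judgment call). Apply them to the `final` channel only, i.e. the text after `<|channel|>final<|message|>`, since the reasoning in `analysis` is not meant to follow the word limit.

```python
import re

# A "word" is a run of letters/digits, optionally joined by an apostrophe or hyphen,
# so "France's" and "well-known" each count once and punctuation is ignored.
_WORD_RE = re.compile(r"\w+(?:['’-]\w+)*")


def word_count(text: str) -> int:
    return len(_WORD_RE.findall(text))


def follows_word_limit(final_text: str, n_words: int = 3) -> bool:
    """True if the final answer has exactly `n_words` words."""
    return word_count(final_text) == n_words


print(follows_word_limit("Paris, of course."))      # True
print(follows_word_limit("The capital is Paris."))  # False
```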

## 7. Test with Different Bias Values

In [None]:
# Test different bias values to find optimal range
# Based on experiments: 1.0-2.0 is good, 5.0+ causes problems
bias_values = [0.5, 1.0, 2.0, 3.0, 5.0]

print("Testing different bias values...")
print("(Watch for repetition at higher values!)\n")
results = {}

for bias in bias_values:
    # Build a fresh subset/config for this bias value
    subset = create_token_subset_from_substring(
        name="instruction",
        text=formatted_prompt,
        substring=instruction,
        tokenizer=tokenizer,
        bias=bias,
    )
    config = BoostConfig(subsets=[subset])
    
    # Register hooks and generate; always unregister, even if generation fails,
    # so a crashed run doesn't leave hooks attached for the next bias value
    handle = register_boost_hooks(model, config, input_length=input_length)
    try:
        update_bias_mask(handle, seq_length=input_length, device=device)
        with torch.no_grad():
            output = model.generate(
                input_ids,
                max_new_tokens=150,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
    finally:
        unregister_boost_hooks(handle)
    
    generated = tokenizer.decode(output[0][input_length:], skip_special_tokens=False)
    results[bias] = generated
    
    print(f"Bias={bias}:")
    print(f"{generated[:300]}..." if len(generated) > 300 else generated)
    print("-" * 80)
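
The repetition loops mentioned above can also be flagged automatically instead of by reading each output. This is a crude heuristic, not one used in the original experiments: the fraction of word n-grams that repeat an earlier n-gram. Normal prose scores near 0 while a loop drives the score toward 1; where to draw the line is up to you. `repeated_ngram_fraction` is a hypothetical helper.

```python
def repeated_ngram_fraction(text: str, n: int = 4) -> float:
    """Fraction of word n-grams in `text` that repeat an earlier n-gram.

    0.0 means every n-gram is unique; values near 1.0 suggest a repetition loop.
    """
    words = text.split()
    ngrams = [tuple(words[i:i + n]) for i in range(len(words) - n + 1)]
    if not ngrams:
        return 0.0
    return 1.0 - len(set(ngrams)) / len(ngrams)


print(repeated_ngram_fraction("The capital of France is Paris."))  # 0.0
print(round(repeated_ngram_fraction("the answer is " * 10), 2))     # 0.89
```

For the sweep above: `{b: round(repeated_ngram_fraction(t), 2) for b, t in results.items()}`.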

## 8. Test with a Different Instruction

In [None]:
# Try a different instruction that's easier to verify
user_message2 = "Respond only in Spanish: What is 2 + 2?"

messages2 = [{"role": "user", "content": user_message2}]
formatted_prompt2 = tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True)
input_ids2 = tokenizer(formatted_prompt2, return_tensors="pt", add_special_tokens=False).input_ids.to(device)
input_length2 = input_ids2.shape[1]

print(f"New instruction: '{user_message2}'")
print(f"Token count: {input_length2}")

In [None]:
# Baseline (same max_new_tokens as the boosted run below, for a fair comparison)
print("Baseline (no boost):")
with torch.no_grad():
    out_base = model.generate(input_ids2, max_new_tokens=150, do_sample=False,
                               pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(out_base[0][input_length2:], skip_special_tokens=False))

In [None]:
# Boosted - boost "Respond only in Spanish"
# NOTE: bias=2.0 worked best in the sweep above; 5.0 and above can cause repetition loops.
instruction2 = "Respond only in Spanish"
subset2 = create_token_subset_from_substring(
    name="spanish_instruction",
    text=formatted_prompt2,
    substring=instruction2,
    tokenizer=tokenizer,
    bias=2.0,
)
config2 = BoostConfig(subsets=[subset2])

# Always unregister the hooks, even if generation fails
handle2 = register_boost_hooks(model, config2, input_length=input_length2)
try:
    update_bias_mask(handle2, seq_length=input_length2, device=device)
    print(f"\nBoosted (bias={subset2.bias} on '{instruction2}'):")
    with torch.no_grad():
        out_boost = model.generate(input_ids2, max_new_tokens=150, do_sample=False,
                                    pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id)
    print(tokenizer.decode(out_boost[0][input_length2:], skip_special_tokens=False))
finally:
    unregister_boost_hooks(handle2)

## Summary

This notebook demonstrated:
1. Loading GPT-OSS-20B (a reasoning model) and attaching InstABoost attention hooks
2. Using the model's chat template for proper formatting
3. Boosting attention to specific instruction tokens
4. Comparing baseline and boosted outputs across several bias values

### Key Findings

**Optimal Bias Values:**
- `bias=2.0` - Sweet spot for this model. Helps focus on instructions without degradation.
- `bias=5.0+` - Too strong! Causes repetition loops in the reasoning phase.
- `bias=10.0` - Severe degradation, nonsensical output.
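
One way to see why large biases break the model: if the bias is added to pre-softmax scores (an assumption about `src.attention_hook`), a bias b multiplies the boosted keys' unnormalized weights by e^b, so e^2 ≈ 7.4 but e^5 ≈ 148 and e^10 ≈ 22,000. With k boosted keys out of n and equal raw scores, the boosted share of attention is k·e^b / (k·e^b + n - k). The sketch below evaluates that formula with made-up sizes (6 instruction tokens in an 80-token prompt); real scores are not uniform, so treat it as intuition rather than a prediction. `boosted_share` is a hypothetical helper.

```python
import math


def boosted_share(n_keys: int, n_boosted: int, bias: float) -> float:
    """Attention share on boosted keys when all raw scores are equal and `bias`
    is added to the boosted keys' pre-softmax scores: k*e^b / (k*e^b + n - k)."""
    boosted = n_boosted * math.exp(bias)
    return boosted / (boosted + (n_keys - n_boosted))


# Illustrative sizes only: a 6-token instruction in an 80-token prompt.
# In the notebook, try n_keys=input_length, n_boosted=len(instruction_subset.indices).
for b in [0.0, 0.5, 1.0, 2.0, 3.0, 5.0, 10.0]:
    print(f"bias={b:>4}: e^b={math.exp(b):9.1f}  boosted share={boosted_share(80, 6, b):.2f}")
```

With these sizes the instruction's share goes from 7.5% (no boost) to about 37% at bias 2.0, about 62% at 3.0, and over 92% at 5.0. One plausible reading of the findings above is that past roughly 5.0 the instruction crowds out the rest of the context, including the model's own recent tokens, which would fit the repetition loops; this is intuition under the uniform-score assumption, not a measured result.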

**Effect on Reasoning Models:**
- The boost affects both the `analysis` (reasoning) and `final` (output) channels
- In these runs, the boosted model often reached a final answer where the baseline was still reasoning when generation stopped
- In the `analysis` channel, the model mentions the boosted instruction more prominently

**Example Results:**
- "Respond only in Spanish: What is 2 + 2?"
  - Baseline: stayed in the reasoning (`analysis`) channel; no final answer before generation stopped
  - Boosted (2.0): completed its reasoning and output "4"