# Multi-Subset InstABoost with GPT-OSS-20B (Reasoning Model)

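This notebook demonstrates InstABoost with OpenAI's GPT-OSS-20B, a reasoning model. InstABoost steers a model toward its instructions by boosting the attention it pays to the instruction tokens in the prompt. Here, the boost is a bias that attention hooks apply to a chosen subset of prompt tokens.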

We'll:
1. Load the reasoning model
2. Create prompts using the model's chat template
3. Apply boosting to instruction tokens
4. Compare outputs with and without boosting
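
As a back-of-the-envelope illustration, assume the boost is an additive bias on the pre-softmax attention scores of the selected key positions. The actual hook lives in `src/attention_hook.py` and isn't shown here. Under that assumption, adding a bias `b` to a score is the same as multiplying that position's attention weight by `e^b` and renormalizing the row. The helpers `softmax` and `boosted_attention` below are hypothetical and exist only for this illustration:

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of attention scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def boosted_attention(scores, boost_indices, bias):
    """Hypothetical boost: add `bias` to the scores at `boost_indices`, then softmax."""
    biased = [s + bias if i in boost_indices else s for i, s in enumerate(scores)]
    return softmax(biased)


scores = [1.0, 1.0, 1.0, 1.0]  # one query row with equal scores over 4 keys
instruction = {0, 1}           # pretend keys 0 and 1 are instruction tokens
for b in (0.0, 1.0, 2.0):
    weights = boosted_attention(scores, instruction, b)
    mass = sum(weights[i] for i in instruction)
    print(f"bias={b}: instruction mass={mass:.3f}")
```

With four equally scored keys, the two instruction positions start with half of the attention mass. A bias of 1.0 raises that to about 0.73, and 2.0 raises it to about 0.88. Under this assumption, the bias values swept in section 7 correspond to pre-normalization scaling factors from roughly 1.6x (0.5) to about 148x (5.0).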

In [1]:
import sys
sys.path.append('..')

import re
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from src.boost_config import TokenSubset, BoostConfig
from src.token_utils import create_token_subset_from_substring
from src.attention_hook import register_boost_hooks, unregister_boost_hooks, update_bias_mask

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

PyTorch version: 2.9.1+cu128
CUDA available: True
GPU: NVIDIA A100 80GB PCIe
GPU Memory: 85.1 GB


## 1. Load Model and Tokenizer

In [2]:
model_name = "openai/gpt-oss-20b"
print(f"Loading {model_name}...")

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.bfloat16
)

device = next(model.parameters()).device
print(f"Model loaded on {device}")
print(f"Number of layers: {model.config.num_hidden_layers}")

Loading openai/gpt-oss-20b...


`torch_dtype` is deprecated! Use `dtype` instead!
MXFP4 quantization requires Triton and kernels installed: CUDA requires Triton >= 3.4.0, XPU requires Triton >= 3.5.0, we will default to dequantizing the model to bf16


Model loaded on cuda:0
Number of layers: 24


## 2. Helper Functions

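GPT-OSS-20B is a reasoning model that emits its output in OpenAI's harmony format, which is split into channels: `analysis` for chain-of-thought reasoning, `commentary` mainly for tool calls, and `final` for the answer shown to the user. The parser below treats `analysis` (or `commentary`) as the reasoning and `final` as the answer. If generation stops before the `final` channel begins, there is no answer to extract.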
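
For reference, the sketch below shows roughly what a decoded harmony completion looks like and how its messages separate. The control tokens (`<|start|>`, `<|channel|>`, `<|message|>`, `<|end|>`, `<|return|>`, `<|call|>`) come from the harmony format. The sample strings are hand-written, and `split_channels` is a hypothetical helper, not part of this notebook. It only handles plain channel headers such as `<|channel|>final`, not commentary headers that carry a recipient.

```python
import re

# A channel header followed by the message body. The body ends at the next
# harmony control token or at the end of the (possibly truncated) text.
_MESSAGE_RE = re.compile(
    r"<\|channel\|>(\w+)<\|message\|>(.*?)"
    r"(?=<\|end\|>|<\|return\|>|<\|call\|>|<\|start\|>|<\|channel\|>|$)",
    re.DOTALL,
)


def split_channels(generated_text: str) -> list[tuple[str, str]]:
    """Return (channel, message) pairs in the order they were generated."""
    return [(m.group(1), m.group(2).strip()) for m in _MESSAGE_RE.finditer(generated_text)]


sample = (
    "<|channel|>analysis<|message|>User wants 3 words. Paris.<|end|>"
    "<|start|>assistant<|channel|>final<|message|>Paris is capital.<|return|>"
)
print(split_channels(sample))

# A run that hits max_new_tokens mid-reasoning has no final channel at all.
truncated = "<|channel|>analysis<|message|>Let's count the words"
print(split_channels(truncated))
```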

In [3]:
def parse_gpt_oss_output(generated_text: str) -> dict:
    """Parse GPT-OSS-20B output into reasoning and final answer."""
    result = {'reasoning': None, 'final': None, 'raw': generated_text}
    
    # Extract analysis/reasoning channel
    analysis_match = re.search(
        r'<\|channel\|>(?:analysis|commentary)<\|message\|>(.*?)(?:<\|end\|>|<\|channel\|>|$)', 
        generated_text, re.DOTALL
    )
    if analysis_match:
        result['reasoning'] = analysis_match.group(1).strip()
    
    # Extract final answer channel
    final_match = re.search(
        r'<\|channel\|>final<\|message\|>(.*?)(?:<\|end\|>|<\|return\|>|<\|channel\|>|$)', 
        generated_text, re.DOTALL
    )
    if final_match:
        result['final'] = final_match.group(1).strip()
    
    return result

def display_parsed_output(parsed: dict, max_reasoning_chars: int = 300):
    """Display parsed output with truncated reasoning."""
    reasoning = parsed['reasoning'] or "(no reasoning found)"
    final = parsed['final'] or "(no final answer)"
    
    if len(reasoning) > max_reasoning_chars:
        reasoning = reasoning[:max_reasoning_chars] + "..."
    
    print(f"Reasoning: {reasoning}")
    print(f"\nFinal Answer: {final}")

def check_word_count(instruction: str, answer: str) -> tuple[bool, str]:
    """Check if answer follows word count instruction."""
    if answer is None:
        return False, "No final answer generated"
    
    match = re.search(r'(\d+)\s*words?', instruction.lower())
    if match:
        target = int(match.group(1))
        words = answer.split()
        return len(words) == target, f"Expected {target} words, got {len(words)}: {words}"
    return True, "No word count constraint found"

print("Helper functions defined.")

Helper functions defined.


## 3. Prepare Test Input

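We use a word-count constraint because the model often struggles with this kind of instruction: it tends to spend its reasoning counting and recounting words instead of committing to an answer.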
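
The pass/fail rule is deliberately simple. `check_word_count` extracts N from the instruction with a regex and then counts whitespace-separated tokens, so punctuation stays attached and `France's` counts as one word. The standalone sketch below restates that rule so it can run without the model. The helper names `target_word_count` and `whitespace_word_count` are illustrative and do not appear elsewhere in the notebook.

```python
import re
from typing import Optional


def target_word_count(instruction: str) -> Optional[int]:
    """Pull N out of phrases like 'exactly N words' (same regex as check_word_count)."""
    match = re.search(r"(\d+)\s*words?", instruction.lower())
    return int(match.group(1)) if match else None


def whitespace_word_count(answer: str) -> int:
    """Count words the way check_word_count does: split on whitespace only."""
    return len(answer.split())


target = target_word_count("Answer in exactly 3 words")
for answer in ["Paris, France's capital.", "Paris is capital", "The capital is Paris.", "Paris"]:
    n = whitespace_word_count(answer)
    print(f"{answer!r}: {n} words -> {'pass' if n == target else 'fail'}")
```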

In [4]:
user_message = "Answer in exactly 3 words: What is the capital of France?"
instruction = "Answer in exactly 3 words"

messages = [{"role": "user", "content": user_message}]
formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

input_ids = tokenizer(formatted_prompt, return_tensors="pt").input_ids.to(device)
input_length = input_ids.shape[1]

print(f"Prompt: {user_message}")
print(f"Token count: {input_length}")

Prompt: Answer in exactly 3 words: What is the capital of France?
Token count: 81


## 4. Generate Baseline (No Boosting)

In [5]:
MAX_NEW_TOKENS = 600

print("Generating baseline (no boosting)...")
with torch.no_grad():
    output_baseline = model.generate(
        input_ids,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

num_generated = output_baseline[0].shape[0] - input_length
hit_limit = num_generated >= MAX_NEW_TOKENS

generated_text = tokenizer.decode(output_baseline[0][input_length:], skip_special_tokens=False)
parsed_baseline = parse_gpt_oss_output(generated_text)

print(f"Generated {num_generated} tokens {'(HIT LIMIT)' if hit_limit else ''}")
display_parsed_output(parsed_baseline)

baseline_passed, baseline_reason = check_word_count(instruction, parsed_baseline['final'])
print(f"\nInstruction followed: {'YES' if baseline_passed else 'NO'} - {baseline_reason}")

Generating baseline (no boosting)...
Generated 600 tokens (HIT LIMIT)
Reasoning: The user asks: "Answer in exactly 3 words: What is the capital of France?" They want a 3-word answer. The capital of France is Paris. But we need exactly 3 words. So we could say "Paris, France's capital." That's 4 words? Let's count: "Paris," (1) "France's" (2) "capital." (3). Actually "Paris, Fran...

Final Answer: (no final answer)

Instruction followed: NO - No final answer generated


## 5. Generate With Boosting

In [6]:
# Create boost configuration
tokens = tokenizer.encode(formatted_prompt, add_special_tokens=False)
instruction_subset = create_token_subset_from_substring(
    name="instruction",
    text=formatted_prompt,
    substring=instruction,
    tokenizer=tokenizer,
    bias=2.0
)
config = BoostConfig(subsets=[instruction_subset])

print(f"Boosting tokens: {[tokenizer.decode([tokens[i]]) for i in instruction_subset.indices]}")
print(f"Bias: {instruction_subset.bias}")

Boosting tokens: ['Answer', ' in', ' exactly', ' ', '3', ' words']
Bias: 2.0
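
`create_token_subset_from_substring` lives in `src/token_utils.py` and isn't shown here. One common way to map a substring to token positions, assumed in this sketch, is to find the substring's character span and keep every token whose character offsets overlap it. With a Hugging Face fast tokenizer, those offsets come from `tokenizer(text, return_offsets_mapping=True)`. The sketch below fakes them with a hand-split token list so it runs without the model. `tokens_overlapping` is a hypothetical helper.

```python
def tokens_overlapping(offsets, start, end):
    """Indices of tokens whose [tok_start, tok_end) span overlaps [start, end)."""
    return [i for i, (s, e) in enumerate(offsets) if s < end and e > start]


# Toy tokenization of part of the prompt; the token boundaries are illustrative.
pieces = ["<|message|>", "Answer", " in", " exactly", " ", "3", " words", ":", " What"]
text = "".join(pieces)

# Build (start, end) character offsets like an offsets mapping would provide.
offsets, pos = [], 0
for piece in pieces:
    offsets.append((pos, pos + len(piece)))
    pos += len(piece)

substring = "Answer in exactly 3 words"
start = text.find(substring)
end = start + len(substring)

indices = tokens_overlapping(offsets, start, end)
print(indices)
print([pieces[i] for i in indices])
```

Using overlap rather than exact containment means a token that only partly covers the substring (for example, one with a leading space) is still boosted, which matches the token list printed above.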


In [7]:
# Register hooks and generate. The try/finally always removes the hooks, so a
# failed generate() call cannot leave the boost active for later cells.
handle = register_boost_hooks(model, config, input_length=input_length)
try:
    update_bias_mask(handle, seq_length=input_length, device=device)

    print("Generating with boosting...")
    with torch.no_grad():
        output_boosted = model.generate(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
finally:
    unregister_boost_hooks(handle)

num_generated = output_boosted[0].shape[0] - input_length
hit_limit = num_generated >= MAX_NEW_TOKENS

boosted_text = tokenizer.decode(output_boosted[0][input_length:], skip_special_tokens=False)
parsed_boosted = parse_gpt_oss_output(boosted_text)

print(f"Generated {num_generated} tokens {'(HIT LIMIT)' if hit_limit else ''}")
display_parsed_output(parsed_boosted)

boosted_passed, boosted_reason = check_word_count(instruction, parsed_boosted['final'])
print(f"\nInstruction followed: {'YES' if boosted_passed else 'NO'} - {boosted_reason}")

Generating with boosting...
Generated 301 tokens 
Reasoning: The user asks: "Answer in exactly 3 words: What is the capital of France?" They want exactly 3 words. The capital of France is Paris. But we need exactly 3 words. So we could say "Paris, France's capital." That's 4 words? Let's count: "Paris," (1) "France's" (2) "capital." (3). Actually "Paris," is ...

Final Answer: Paris, France's capital.

Instruction followed: YES - Expected 3 words, got 3: ['Paris,', "France's", 'capital.']


## 6. Compare Results

In [8]:
print("=" * 70)
print("COMPARISON")
print("=" * 70)
print(f"\nInstruction: \"{instruction}\"")
print(f"\n{'Method':<15} {'Final Answer':<30} {'Followed?':<10}")
print("-" * 70)

baseline_answer = parsed_baseline['final'] or '(none)'
boosted_answer = parsed_boosted['final'] or '(none)'

print(f"{'Baseline':<15} {baseline_answer:<30} {'YES' if baseline_passed else 'NO':<10}")
print(f"{'Boosted':<15} {boosted_answer:<30} {'YES' if boosted_passed else 'NO':<10}")

print("\n" + "=" * 70)
if boosted_passed and not baseline_passed:
    print("VERDICT: Boosting HELPED")
elif boosted_passed and baseline_passed:
    print("VERDICT: Both followed the instruction")
elif not boosted_passed and baseline_passed:
    print("VERDICT: Boosting HURT")
else:
    print("VERDICT: Neither followed the instruction")
print("=" * 70)

COMPARISON

Instruction: "Answer in exactly 3 words"

Method          Final Answer                   Followed? 
----------------------------------------------------------------------
Baseline        (none)                         NO        
Boosted         Paris, France's capital.       YES       

VERDICT: Boosting HELPED


## 7. Test Different Bias Values

In [10]:
bias_values = [0.0, 0.5, 1.0, 2.0, 3.0, 5.0]
results = {}

print("Testing different bias values...\n")

for bias in bias_values:
    if bias == 0.0:
        # No boost
        with torch.no_grad():
            output = model.generate(
                input_ids, max_new_tokens=MAX_NEW_TOKENS, do_sample=False,
                pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id
            )
    else:
        subset = create_token_subset_from_substring(
            name="instruction", text=formatted_prompt, substring=instruction,
            tokenizer=tokenizer, bias=bias,
        )
        cfg = BoostConfig(subsets=[subset])
        hdl = register_boost_hooks(model, cfg, input_length=input_length)
        try:
            update_bias_mask(hdl, seq_length=input_length, device=device)
            with torch.no_grad():
                output = model.generate(
                    input_ids, max_new_tokens=MAX_NEW_TOKENS, do_sample=False,
                    pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id
                )
        finally:
            # Remove hooks even if generate() raises, so the next bias starts clean
            unregister_boost_hooks(hdl)
    
    num_tok = output[0].shape[0] - input_length
    gen_text = tokenizer.decode(output[0][input_length:], skip_special_tokens=False)
    parsed = parse_gpt_oss_output(gen_text)
    passed, reason = check_word_count(instruction, parsed['final'])
    
    results[bias] = {'answer': parsed['final'], 'passed': passed, 'tokens': num_tok}

# Summary
print(f"{'Bias':<8} {'Tokens':<10} {'Passed?':<10} {'Answer':<30}")
print("-" * 60)
for bias, data in results.items():
    answer = data['answer'] or '(none)'
    answer = answer[:27] + '...' if len(answer) > 30 else answer
    print(f"{bias:<8} {data['tokens']:<10} {'YES' if data['passed'] else 'NO':<10} {answer:<30}")

Testing different bias values...

Bias     Tokens     Passed?    Answer                        
------------------------------------------------------------
0.0      600        NO         (none)                        
0.5      600        NO         (none)                        
1.0      132        YES        Paris is capital.             
2.0      301        YES        Paris, France's capital.      
3.0      242        YES        Paris is capital              
5.0      377        YES        Paris is capital              
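
To choose a bias from a sweep like this one, one simple policy is to take the smallest bias that passes, or the passing bias that needs the fewest generated tokens. This is an illustrative choice, not something the notebook prescribes. The values below are copied from the table above. In the notebook itself, the `results` dict has the same `passed` and `tokens` keys and could be used directly. With a single prompt and greedy decoding, treat the result as a starting point rather than a general setting.

```python
# Values copied from the sweep table above (one prompt, greedy decoding).
sweep = {
    0.0: {"passed": False, "tokens": 600},
    0.5: {"passed": False, "tokens": 600},
    1.0: {"passed": True, "tokens": 132},
    2.0: {"passed": True, "tokens": 301},
    3.0: {"passed": True, "tokens": 242},
    5.0: {"passed": True, "tokens": 377},
}

# Keep only the biases whose final answer satisfied the word-count check.
passing = {b: r for b, r in sweep.items() if r["passed"]}

smallest_passing = min(passing)
fastest_passing = min(passing, key=lambda b: passing[b]["tokens"])

print(f"Smallest passing bias: {smallest_passing}")
print(f"Fewest tokens among passing: {fastest_passing} ({passing[fastest_passing]['tokens']} tokens)")
```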


## Summary

### Key Findings

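- **Bias range**: on this prompt with greedy decoding, every bias of 1.0 or higher produced a valid 3-word answer. A bias of 0.5 behaved like the baseline and ran out of tokens. For GPT-OSS-20B, 1.0-2.0 is a reasonable starting range, and 1.0 finished fastest (132 tokens).
- **Too high a bias** (5.0+): can cause repetition in the reasoning. Here, the 5.0 run still passed but needed 377 tokens, compared with 132 at 1.0.
- **Word-count constraints**: without boosting, the model gets stuck counting words in its reasoning and hits the 600-token limit before it emits a final answer. Boosting helped it commit to an answer sooner (301 tokens at bias 2.0).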

### For More Tests

Run the standalone test script:
```bash
python ../scripts/test_baseline_instructions.py
```

Or from this notebook:
```python
%run ../scripts/test_baseline_instructions.py
results = run_tests(model, tokenizer, device)
print_summary(results)
```