In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
print(f"CUDA available: {torch.cuda.is_available()}")
device = torch.device("cuda:3")

CUDA available: True


In [3]:
print("\n[1/2] Loading HumanEval (code reasoning)...")
code_dataset = load_dataset("openai_humaneval", split="test")

print("\n[2/2] Loading GSM8K (math reasoning)...")
math_dataset = load_dataset("gsm8k", "main", split="test")


[1/2] Loading HumanEval (code reasoning)...

[2/2] Loading GSM8K (math reasoning)...


In [4]:
N_SAMPLES = 20
code_prompts = [code_dataset[i]["prompt"] for i in range(min(N_SAMPLES, len(code_dataset)))]
math_prompts = [math_dataset[i]["question"] for i in range(min(N_SAMPLES, len(math_dataset)))]

print("\n" + "="*60)
print("Example CODE prompt:")
print(code_prompts[0][:150] + "...")
print("\n" + "="*60)
print("Example MATH prompt:")
print(math_prompts[0][:150] + "...")
print("="*60)


Example CODE prompt:
from typing import List


def has_close_elements(numbers: List[float], threshold: float) -> bool:
    """ Check if in given list of numbers, are any t...

Example MATH prompt:
Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the rem...


In [5]:
# Try Qwen2.5-0.5B-Instruct first (smallest reasoning model)
# Fallback to Phi-2 if needed
MODEL_OPTIONS = [
    "Qwen/Qwen2.5-0.5B-Instruct",  # 500MB, great reasoning
    "microsoft/phi-2",              # 2.7GB, excellent for math/code
    "gpt2",                         # 500MB, fallback
]

In [11]:
model_name = MODEL_OPTIONS[1]
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
            model_name,
            dtype=torch.float16,
            trust_remote_code=True
        )
model = model.to(device)
model.eval()

if hasattr(model, 'model') and hasattr(model.model, 'layers'):
    n_layers = len(model.model.layers)
    hidden_size = model.config.hidden_size
elif hasattr(model, 'transformer'):
    n_layers = len(model.transformer.h)
    hidden_size = model.config.n_embd if hasattr(model.config, 'n_embd') else model.config.hidden_size
else:
    n_layers = model.config.num_hidden_layers
    hidden_size = model.config.hidden_size

print(f"\nModel Architecture:")
print(f"  Name: {model_name}")
print(f"  Layers: {n_layers}")
print(f"  Hidden size: {hidden_size}")
print(f"  Parameters: ~{sum(p.numel() for p in model.parameters()) / 1e6:.0f}M")

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 22.02it/s]



Model Architecture:
  Name: microsoft/phi-2
  Layers: 32
  Hidden size: 2560
  Parameters: ~2780M


In [None]:
def collect_activations(prompts, model, tokenizer, max_length=256, batch_size=4):
    """
    Collect activations from all layers.
    
    Args:
        prompts: List of text prompts
        model: The language model
        tokenizer: Tokenizer
        max_length: Maximum sequence length
        batch_size: Batch size for processing
    
    Returns:
        Dictionary mapping layer indices to activation arrays
    """
    activations = {i: [] for i in range(n_layers)}
    
    print(f"Collecting activations from {n_layers} layers...")
    
    # Process in batches
    for batch_start in tqdm(range(0, len(prompts), batch_size)):
        batch_prompts = prompts[batch_start:batch_start + batch_size]
        
        # Tokenize
        inputs = tokenizer(
            batch_prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        
        # Forward pass
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        
        # Extract activations
        hidden_states = outputs.hidden_states
        
        # Get sequence lengths (last non-padding token)
        seq_lengths = inputs['attention_mask'].sum(dim=1) - 1
        
        for layer_idx in range(hidden_states.shape[0]):
            # hidden_states[0] is embedding, hidden_states[1] is layer 0, etc.
            layer_output = hidden_states[layer_idx + 1]
            
            # Get last token activation for each sequence in batch
            for batch_idx in range(layer_output.shape[0]):
                last_token_pos = seq_lengths[batch_idx]
                last_token_activation = layer_output[batch_idx, last_token_pos, :].cpu().numpy()
                activations[layer_idx].append(last_token_activation)
    
    # Convert lists to arrays
    for layer_idx in activations:
        activations[layer_idx] = np.stack(activations[layer_idx])
    
    return outputs

In [24]:
print("\n" + "="*70)
print("Collecting CODE activations...")
print("="*70)

code_activations = collect_activations(
    code_prompts,
    model,
    tokenizer,
    max_length=256,
    batch_size=4
)

math_activations = collect_activations(
    math_prompts,
    model,
    tokenizer,
    max_length=256,
    batch_size=4
)

print(f"✓ Collected code activations")
print(f"  Shape per layer: {code_activations[0].shape}")

print(f"✓ Collected math activations")
print(f"  Shape per layer: {math_activations[0].shape}")


Collecting CODE activations...
Collecting activations from 32 layers...


100%|██████████| 5/5 [00:00<00:00, 21.58it/s]


Collecting activations from 32 layers...


100%|██████████| 5/5 [00:00<00:00, 21.10it/s]

✓ Collected code activations
  Shape per layer: torch.Size([4, 168, 51200])
✓ Collected math activations
  Shape per layer: torch.Size([4, 61, 51200])





In [39]:
code_activations.logits.shape

torch.Size([4, 168, 51200])

In [40]:
math_activations.logits.shape

torch.Size([4, 61, 51200])

In [34]:
code_activations.hidden_states[0].shape

torch.Size([4, 168, 2560])

In [35]:
math_activations.hidden_states[0].shape

torch.Size([4, 61, 2560])

In [21]:
math_activations[0]

array([[ 0.02519 ,  0.1805  ,  0.2485  , ..., -0.01395 , -0.0681  ,
         0.1164  ],
       [-0.0792  , -0.09314 ,  0.405   , ...,  0.095   , -0.1986  ,
        -0.181   ],
       [-0.0823  , -0.004875,  0.2698  , ...,  0.02847 ,  0.02025 ,
         0.0785  ],
       ...,
       [-0.013016,  0.321   ,  0.05008 , ..., -0.03964 , -0.09143 ,
        -0.2032  ],
       [-0.1818  ,  0.1556  ,  0.6733  , ...,  0.12134 ,  0.0448  ,
        -0.00711 ],
       [-0.005325,  0.10626 , -0.05026 , ..., -0.05283 , -0.1613  ,
        -0.1427  ]], shape=(20, 2560), dtype=float16)

In [43]:
activations = {i: [] for i in range(n_layers)}

prompts = code_prompts
max_length=256
batch_size=4
print(f"Collecting activations from {n_layers} layers...")

# Process in batches
for batch_start in tqdm(range(0, len(prompts), batch_size)):
    batch_prompts = prompts[batch_start:batch_start + batch_size]
    
    # Tokenize
    inputs = tokenizer(
        batch_prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    # Forward pass
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    
    # Extract activations
    hidden_states = outputs.hidden_states
    
    # Get sequence lengths (last non-padding token)
    seq_lengths = inputs['attention_mask'].sum(dim=1) - 1
    
    for layer_idx in range(len(hidden_states)-1):
        # hidden_states[0] is embedding, hidden_states[1] is layer 0, etc.
        layer_output = hidden_states[layer_idx + 1]
        print(layer_output.shape)
        
        # Get last token activation for each sequence in batch
        for batch_idx in range(layer_output.shape[0]):
            last_token_pos = seq_lengths[batch_idx]
            last_token_activation = layer_output[batch_idx, last_token_pos, :].cpu().numpy()
            activations[layer_idx].append(last_token_activation)

# Convert lists to arrays
for layer_idx in activations:
    activations[layer_idx] = np.stack(activations[layer_idx])

Collecting activations from 32 layers...


 40%|████      | 2/5 [00:00<00:00, 17.73it/s]

torch.Size([4, 127, 2560])
torch.Size([4, 127, 2560])
torch.Size([4, 127, 2560])
torch.Size([4, 127, 2560])
torch.Size([4, 127, 2560])
torch.Size([4, 127, 2560])
torch.Size([4, 127, 2560])
torch.Size([4, 127, 2560])
torch.Size([4, 127, 2560])
torch.Size([4, 127, 2560])
torch.Size([4, 127, 2560])
torch.Size([4, 127, 2560])
torch.Size([4, 127, 2560])
torch.Size([4, 127, 2560])
torch.Size([4, 127, 2560])
torch.Size([4, 127, 2560])
torch.Size([4, 127, 2560])
torch.Size([4, 127, 2560])
torch.Size([4, 127, 2560])
torch.Size([4, 127, 2560])
torch.Size([4, 127, 2560])
torch.Size([4, 127, 2560])
torch.Size([4, 127, 2560])
torch.Size([4, 127, 2560])
torch.Size([4, 127, 2560])
torch.Size([4, 127, 2560])
torch.Size([4, 127, 2560])
torch.Size([4, 127, 2560])
torch.Size([4, 127, 2560])
torch.Size([4, 127, 2560])
torch.Size([4, 127, 2560])
torch.Size([4, 127, 2560])
torch.Size([4, 124, 2560])
torch.Size([4, 124, 2560])
torch.Size([4, 124, 2560])
torch.Size([4, 124, 2560])
torch.Size([4, 124, 2560])
t

 80%|████████  | 4/5 [00:00<00:00, 17.86it/s]

torch.Size([4, 114, 2560])
torch.Size([4, 114, 2560])
torch.Size([4, 114, 2560])
torch.Size([4, 114, 2560])
torch.Size([4, 114, 2560])
torch.Size([4, 114, 2560])
torch.Size([4, 114, 2560])
torch.Size([4, 114, 2560])
torch.Size([4, 114, 2560])
torch.Size([4, 114, 2560])
torch.Size([4, 114, 2560])
torch.Size([4, 114, 2560])
torch.Size([4, 114, 2560])
torch.Size([4, 114, 2560])
torch.Size([4, 114, 2560])
torch.Size([4, 114, 2560])
torch.Size([4, 114, 2560])
torch.Size([4, 114, 2560])
torch.Size([4, 114, 2560])
torch.Size([4, 114, 2560])
torch.Size([4, 114, 2560])
torch.Size([4, 114, 2560])
torch.Size([4, 114, 2560])
torch.Size([4, 114, 2560])
torch.Size([4, 114, 2560])
torch.Size([4, 114, 2560])
torch.Size([4, 114, 2560])
torch.Size([4, 114, 2560])
torch.Size([4, 114, 2560])
torch.Size([4, 114, 2560])
torch.Size([4, 114, 2560])
torch.Size([4, 114, 2560])


100%|██████████| 5/5 [00:00<00:00, 18.50it/s]

torch.Size([4, 168, 2560])
torch.Size([4, 168, 2560])
torch.Size([4, 168, 2560])
torch.Size([4, 168, 2560])
torch.Size([4, 168, 2560])
torch.Size([4, 168, 2560])
torch.Size([4, 168, 2560])
torch.Size([4, 168, 2560])
torch.Size([4, 168, 2560])
torch.Size([4, 168, 2560])
torch.Size([4, 168, 2560])
torch.Size([4, 168, 2560])
torch.Size([4, 168, 2560])
torch.Size([4, 168, 2560])
torch.Size([4, 168, 2560])
torch.Size([4, 168, 2560])
torch.Size([4, 168, 2560])
torch.Size([4, 168, 2560])
torch.Size([4, 168, 2560])
torch.Size([4, 168, 2560])
torch.Size([4, 168, 2560])
torch.Size([4, 168, 2560])
torch.Size([4, 168, 2560])
torch.Size([4, 168, 2560])
torch.Size([4, 168, 2560])
torch.Size([4, 168, 2560])
torch.Size([4, 168, 2560])
torch.Size([4, 168, 2560])
torch.Size([4, 168, 2560])
torch.Size([4, 168, 2560])
torch.Size([4, 168, 2560])
torch.Size([4, 168, 2560])



