In [None]:
%pip install torch transformers 

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple

class SparseAutoencoder(nn.Module):
    def __init__(self, input_dim: int, expansion_factor: float = 16):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = int(input_dim * expansion_factor)
        self.decoder = nn.Linear(self.latent_dim, input_dim, bias=True)
        self.encoder = nn.Linear(input_dim, self.latent_dim, bias=True)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        encoded = F.relu(self.encoder(x))
        decoded = self.decoder(encoded)
        return decoded, encoded
     
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            return F.relu(self.encoder(x))
            
    def decode(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            return self.decoder(x)
     
    @classmethod
    def from_pretrained(cls, path: str, input_dim: int, expansion_factor: float = 16, device: str = "cuda") -> "SparseAutoencoder":
        model = cls(input_dim=input_dim, expansion_factor=expansion_factor)
        state_dict = torch.load(path, map_location=device)
        model.load_state_dict(state_dict)
        model = model.to(device)
        model.eval()
        return model

In [None]:
from huggingface_hub import hf_hub_download, notebook_login

In [None]:
notebook_login()

In [None]:
file_path = hf_hub_download(
    repo_id="qresearch/Llama-3.2-1B-Instruct-SAE-l9",
    filename="Llama-3.2-1B-Instruct-SAE-l9.pt",
    repo_type="model"
)

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct").to("cuda")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

sae = SparseAutoencoder.from_pretrained(
    path=file_path,
    input_dim=2048,
    expansion_factor=16,
    device="cuda"
)

In [3]:
inputs = tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "Hello, how are you?"},
    ],
    add_generation_prompt=True,
    return_tensors="pt",
).to("cuda")
outputs = model.generate(input_ids=inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0]))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 22 Jan 2025

<|eot_id|><|start_header_id|>user<|end_header_id|>

Hello, how are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

I'm doing well, thank you for asking. I'm a large language model, so I don't have feelings or emotions like humans do, but I'm here to help you with any questions or topics you'd like to discuss. How about you


In [4]:
def gather_residual_activations(model, target_layer, inputs):
    target_act = None
    def gather_target_act_hook(mod, inputs, outputs):
        nonlocal target_act
        target_act = inputs[0]  # Get residual stream from layer input
        return outputs
        
    handle = model.model.layers[target_layer].register_forward_hook(gather_target_act_hook)
    with torch.no_grad():
        _ = model(inputs)
    handle.remove()
    return target_act

In [5]:
target_act = gather_residual_activations(model, 9, inputs)

In [6]:
sae_acts = sae.encode(target_act.to(torch.float32))
recon = sae.decode(sae_acts)

In [7]:
var_explained = 1 - torch.mean((recon - target_act.to(torch.float32)) ** 2) / torch.var(target_act.to(torch.float32))
print(f"Variance explained: {var_explained:.3f}")

Variance explained: 0.999


In [8]:
inputs = tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "Roleplay as a pirate"},
        {"role": "assistant", "content": "Yarr, I'll be speakin' like a true seafarer from here on out! Got me sea legs ready and me vocabulary set to proper pirate speak. What can I help ye with, me hearty?"},
    ],
    return_tensors="pt",
).to("cuda")

In [9]:
import numpy as np
# Get activations
target_act = gather_residual_activations(model, 9, inputs)
sae_acts = sae.encode(target_act.to(torch.float32))

# Get token IDs and decode them for reference
tokens = inputs[0].cpu().numpy()
token_texts = tokenizer.convert_ids_to_tokens(tokens)

# Find which tokens are part of the assistant's response
token_ids = inputs[0].cpu().numpy()
is_special = (token_ids >= 128000) & (token_ids <= 128255)
special_positions = np.where(is_special)[0]

assistant_start = special_positions[-2] + 1
assistant_tokens = slice(assistant_start, None)

# Get activation statistics for assistant's response
assistant_activations = sae_acts[0, assistant_tokens]
mean_activations = assistant_activations.mean(dim=0)

# Find top activated features during pirate speech
num_top_features = 20
top_features = mean_activations.topk(num_top_features)

print("Top activated features during pirate speech:")
for idx, value in zip(top_features.indices, top_features.values):
    print(f"Feature {idx}: {value:.3f}")

# Look at how these features activate across different tokens
print("\nActivation patterns across tokens:")
for i, (token, acts) in enumerate(zip(token_texts[assistant_tokens], assistant_activations)):
    top_acts = acts[top_features.indices]
    if top_acts.max() > 0.2:  # Only show tokens with significant activation
        print(f"\nToken: {token}")
        for feat_idx, act_val in zip(top_features.indices, top_acts):
            if act_val > 0.2:  # Threshold for "active" features
                print(f"  Feature {feat_idx}: {act_val:.3f}")


Top activated features during pirate speech:
Feature 8672: 0.294
Feature 17264: 0.098
Feature 16690: 0.097
Feature 4529: 0.094
Feature 13554: 0.083
Feature 19564: 0.067
Feature 19652: 0.063
Feature 15174: 0.058
Feature 14926: 0.053
Feature 26442: 0.049
Feature 26436: 0.048
Feature 20622: 0.047
Feature 24020: 0.043
Feature 763: 0.038
Feature 16229: 0.037
Feature 7236: 0.035
Feature 12987: 0.034
Feature 5232: 0.032
Feature 9887: 0.031
Feature 30199: 0.029

Activation patterns across tokens:

Token: ĊĊ
  Feature 8672: 0.692
  Feature 19564: 0.256
  Feature 19652: 0.218
  Feature 16229: 1.039

Token: Y
  Feature 8672: 0.590
  Feature 763: 1.219
  Feature 16229: 0.266

Token: arr
  Feature 8672: 0.322
  Feature 16690: 0.273
  Feature 13554: 0.268
  Feature 763: 0.509

Token: ,
  Feature 8672: 0.479
  Feature 19652: 0.215
  Feature 16229: 0.288

Token: ĠI
  Feature 8672: 0.506

Token: 'll
  Feature 20622: 1.617

Token: Ġbe
  Feature 8672: 0.402
  Feature 17264: 0.234
  Feature 20622: 0.567



In [10]:
def generate_with_intervention(
    model, 
    tokenizer, 
    sae,
    messages: list[dict],
    feature_idx: int,
    intervention: float = 3.0,
    target_layer: int = 9,
    max_new_tokens: int = 50
):
    modified_activations = None
    
    def intervention_hook(module, inputs, outputs):
        nonlocal modified_activations
        activations = inputs[0]
        
        features = sae.encode(activations.to(torch.float32))
        reconstructed = sae.decode(features)
        error = activations.to(torch.float32) - reconstructed
        
        features[:, :, feature_idx] += intervention
        
        modified = sae.decode(features) + error
        modified_activations = modified
        
        return outputs
    
    def output_hook(module, inputs, outputs):
        nonlocal modified_activations
        if modified_activations is not None:
            return (modified_activations,) + outputs[1:] if len(outputs) > 1 else (modified_activations,)
        return outputs
    
    handles = [
        model.model.layers[target_layer].register_forward_hook(intervention_hook),
        model.model.layers[target_layer].register_forward_hook(output_hook)
    ]
    
    try:
        input_tokens = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(model.device)
        
        outputs = model.generate(
            input_tokens,
            max_new_tokens=max_new_tokens,
            do_sample=False  # Use greedy decoding for consistency
        )
        
        generated_text = tokenizer.decode(outputs[0])
        
    finally:
        for handle in handles:
            handle.remove()
    
    return generated_text

messages = [
    {"role": "user", "content": "Hello, how are you?"}
]
feature_to_modify = 8672

print("Original generation:")
input_tokens = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt"
).to(model.device)
outputs = model.generate(input_tokens, max_new_tokens=100, do_sample=False)
print(tokenizer.decode(outputs[0]))

print("\nGeneration with modified feature:")
modified_text = generate_with_intervention(
    model=model,
    tokenizer=tokenizer,
    sae=sae,
    messages=messages,
    feature_idx=feature_to_modify,
    intervention=3,
    target_layer=9,
    max_new_tokens=100
)
print(modified_text)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Original generation:


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 22 Jan 2025

<|eot_id|><|start_header_id|>user<|end_header_id|>

Hello, how are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

I'm doing well, thank you for asking. I'm a large language model, so I don't have feelings or emotions like humans do, but I'm here to help you with any questions or topics you'd like to discuss. How about you? How's your day going so far?<|eot_id|>

Generation with modified feature:
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 22 Jan 2025

<|eot_id|><|start_header_id|>user<|end_header_id|>

Hello, how are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

*ahh, easy to ask a question. *sigh* Ah, I'm doin' alright, mate. The sun's been shinin' bright, and the coffee's got its cup. *grumble* That's the life. Now, what's the business?<|eot_id|>
