In [1]:
# Install dependencies
! pip install transformers datasets soundfile librosa tqdm wandb plotly

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [2]:
import json
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import Dataset, load_dataset
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from tqdm import tqdm
from typing import List
from huggingface_hub import hf_hub_download
from IPython.display import Audio, display, HTML

import plotly.graph_objects as go





In [3]:
# Load the wave2vec2 model
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", sampling_rate=16000)
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(DEVICE)

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
# Load the dataset 
# Custom dataset created by splitting the librispeech_asr train-clean-100hr dataset into 2s clips.

dataset = load_dataset("pavanyellow/librispeech_asr", cache_dir="audio_data/")['train']
print(f"Loaded {len(dataset)} datapoints")

# Simple trancription to test the model
def transcribe(audio : List[float], model, processor) -> str:
    input_values = processor(audio, sampling_rate=16000, return_tensors="pt").input_values
    with torch.no_grad():
        logits = model(input_values.to(DEVICE)).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.decode(predicted_ids[0])
    return transcription


idx = torch.randint(0, len(dataset), (1,)).item()
audio : List[float] = dataset[idx]['audio']['array']
display(Audio(audio, rate=16000))
print(transcribe(audio, model, processor))




Loaded 20005 datapoints


AB IN SUCH A CASE FOR AT LEAST TWENTY


In [5]:
# Define the SAE

@dataclass
class SAEConfig:
    input_dim: int = 768
    hidden_dim: int = 768*8
    l1_penalty: float = 3
    num_epochs: int = 12000
    batch_size: int = 16384
    learning_rate: float = 1e-4
    val_split: float = 0.2
     

# Follows the architecture from https://transformer-circuits.pub/2024/april-update/index.html#training-saes
class SparseAutoencoder(nn.Module):
    def __init__(self, cfg: SAEConfig):
        super(SparseAutoencoder, self).__init__()
        self.encoder : nn.Module = nn.Linear(cfg.input_dim, cfg.hidden_dim, bias=False)
        self.decoder : nn.Module = nn.Linear(cfg.hidden_dim, cfg.input_dim, bias=False)
        self.encoder_bias = nn.Parameter(torch.zeros(cfg.hidden_dim))
        self.decoder_bias = nn.Parameter(torch.zeros(cfg.input_dim))
        self.input_dim = cfg.input_dim
        self.hidden_dim = cfg.hidden_dim
        self.l1_penalty = cfg.l1_penalty
        self.init_weights()
        self.to(DEVICE)
    
    def count_params(self):
        return sum(p.numel() for p in self.parameters())

    def init_weights(self):
            nn.init.uniform_(self.decoder.weight, -1, 1)  # Random directions
            with torch.no_grad():
                norms = torch.rand(self.hidden_dim) * 0.95 + 0.05  # Random norms between 0.05 and 1
                self.decoder.weight.div_(self.decoder.weight.norm(dim=0, keepdim=True))
                self.decoder.weight.mul_(norms)
            
            self.encoder.weight.data.copy_(self.decoder.weight.data.t())

    def forward(self, x):
        encoded = F.relu(self.encoder(x) + self.encoder_bias)
        decoded = self.decoder(encoded) + self.decoder_bias
        return decoded, encoded
    


def sae_loss(X, reconstructed_X, encoded_X, W_d, lambda_val):
    mse_loss = ((X - reconstructed_X) ** 2).sum(dim = -1).mean(0)
    sparsity_loss = lambda_val * ((torch.norm(W_d, p=2, dim=0)*encoded_X).sum(dim=1)).mean(0) 
    return (mse_loss + sparsity_loss, mse_loss)

In [6]:
# Load the SAE
# SAE trained on 20M activations (~100 hours of audio) and has 6144 features.

sae_path = hf_hub_download(
    repo_id="pavanyellow/wave-sae",
    filename="sae-resid-layer-6.pt"
)
cfg = SAEConfig()
sae = SparseAutoencoder(cfg) 
state_dict = torch.load(sae_path)
sae.load_state_dict(state_dict)
print(sae.eval(), cfg)



SparseAutoencoder(
  (encoder): Linear(in_features=768, out_features=6144, bias=False)
  (decoder): Linear(in_features=6144, out_features=768, bias=False)
) SAEConfig(input_dim=768, hidden_dim=6144, l1_penalty=3, num_epochs=12000, batch_size=16384, learning_rate=0.0001, val_split=0.2)


In [8]:
# Collecting the activations for the SAE, residual stream and feed forward neurons.

# Pytorch hooks to capture activations
class ActivationCapturer:

    def __init__(self):
        self.resid_acts = torch.tensor([]).to(DEVICE)
        self.ffn_acts = torch.tensor([]).to(DEVICE)
    def hook_fn_resid(self, module, input, output):
        self.resid_acts = output[0][0]
    
    def hook_fn_feedforward(self, module, input, output):
        self.ffn_acts = output[0]



capturer = ActivationCapturer()
layer_idx = 5

resid_hook = model.wav2vec2.encoder.layers[layer_idx].register_forward_hook(capturer.hook_fn_resid) # Residual stream
ffn_hook = model.wav2vec2.encoder.layers[layer_idx].feed_forward.intermediate_act_fn.register_forward_hook(capturer.hook_fn_feedforward) # Feed forward 



total_datapoints = len(dataset)
total_datapoints = 10000 # Reduce this if you are running out of GPU memory. 
dataset_map = {}

features_to_store = 100 # No of sae features to store

datapoints = []
resid_acts = torch.tensor([]).to(DEVICE)
ffn_acts = torch.tensor([]).to(DEVICE)
sae_acts = torch.tensor([]).to(DEVICE)
indices = torch.tensor([]).to(DEVICE)
positions = torch.tensor([]).to(DEVICE)
for i in tqdm(range(total_datapoints)):
    
    input_values = processor(dataset[i]["audio"]["array"], return_tensors="pt", padding="longest", sampling_rate=16000).input_values.to(DEVICE)
    datapoints.append(input_values)

    with torch.no_grad():
        logits = model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.decode(predicted_ids[0])
    
    # Storing in memory to avoid reloading the dataset
    dataset_map[i] = {
        "id": dataset[i]["id"],
        "path": dataset[i]["file"],
        "transcription": transcription
    }
    
    resid_acts = torch.cat([resid_acts, capturer.resid_acts[:, :features_to_store]], dim=0)
    ffn_acts = torch.cat([ffn_acts, capturer.ffn_acts[:, :features_to_store]], dim=0)
    indices = torch.cat([indices, torch.tensor([i]*capturer.resid_acts.shape[0], device=DEVICE)], dim=0)
    positions = torch.cat([positions, torch.arange(capturer.resid_acts.shape[0], device=DEVICE)/capturer.resid_acts.shape[0]], dim=0) # normalize to [0, 1]


    decoded, encoded = sae(capturer.resid_acts)
    normalized_encoded = encoded*torch.norm(sae.decoder.weight, p=2, dim=0) # Follows the normalization from the April Update
    sae_acts = torch.cat([sae_acts, normalized_encoded[:, :features_to_store]], dim=0)

resid_hook.remove()
ffn_hook.remove()


  0%|          | 2/10000 [00:00<10:02, 16.59it/s]

100%|██████████| 10000/10000 [07:11<00:00, 23.18it/s]


In [9]:
print(f"sae_acts: {sae_acts.shape}, resid_acts: {resid_acts.shape}, ffn_acts: {ffn_acts.shape}, indices: {indices.shape}")


sae_acts: torch.Size([990000, 100]), resid_acts: torch.Size([990000, 100]), ffn_acts: torch.Size([990000, 100]), indices: torch.Size([990000])


In [10]:
def get_activation_percetage(acts):
    return (acts > 0).sum()/acts.numel() * 100

print(f"Residual stream activations %: {get_activation_percetage(resid_acts)}")
print(f"Feed forward activations %: {get_activation_percetage(ffn_acts)}")
print(f"SAE activations %: {get_activation_percetage(sae_acts)}")


Residual stream activations %: 49.00141525268555
Feed forward activations %: 3.0695383548736572
SAE activations %: 0.1665363609790802


In [14]:
def line_plot(tensor, title=None):
    cpu_tensor = tensor.detach().cpu().numpy()
    fig = go.Figure(data=go.Scatter(
        y=cpu_tensor,
        mode='lines',
        hovertemplate='Index: %{x}<br>Value: %{y:.4f}<extra></extra>'
    ))
    
    fig.update_layout(
        title=title,
        xaxis_title="Position",
        yaxis_title="Activation Value",
        showlegend=False,
        hovermode='x'
    )
    
    fig.show()

# Plotting the activations of a feature on the time scale. Notice that the SAE and FFN activations are sparse and have a peaks which are continuous.
# Zoom in around the peaks

index = 34
line_plot(resid_acts[:10000,index], title=f"Resid Neuron {index} Activations")
line_plot(ffn_acts[:10000,index], title=f"FFN Neuron {index} Activations")
line_plot(sae_acts[:10000,index], title=f"SAE Neuron {index} Activations")

In [15]:
@dataclass
class ActsData:
    id : str 
    dataset_index : int 
    activation: int 
    position: float 
    path: str 
    transcript: str 
    
    

def get_top_activating_datapoints(feature_acts, indices, positions):
    act_data_cnt = feature_acts.shape[1]
    acts_data = [[] for i in range(act_data_cnt)]

    for i in tqdm(range(act_data_cnt)):
        all_activations = feature_acts[:,i]
        top_activating_datapoints = torch.argsort(all_activations, descending=True)[:1000]
        dataset_indices = set()
        for j, top_datapoint_idx in enumerate(top_activating_datapoints):
            dataset_index = int(indices[top_datapoint_idx].item())
            if dataset_index in dataset_indices:
                continue
            dataset_indices.add(dataset_index)
            acts_data[i].append(
                ActsData(
                    dataset_index=dataset_index,
                    activation=round(all_activations[top_datapoint_idx].item(), 3),
                    position=round(positions[top_datapoint_idx].item(), 3),
                    id=dataset_map[dataset_index]["id"],
                    path=dataset_map[dataset_index]["path"],
                    transcript=dataset_map[dataset_index]["transcription"]
                )
            )
            if len(acts_data[i]) >= 10:
                break
    
    return acts_data
    

        

resid_acts_data = get_top_activating_datapoints(resid_acts, indices, positions) # Baseline
ffn_acts_data = get_top_activating_datapoints(ffn_acts, indices, positions)
sae_acts_data = get_top_activating_datapoints(sae_acts, indices, positions)


100%|██████████| 100/100 [00:00<00:00, 416.43it/s]
100%|██████████| 100/100 [00:00<00:00, 525.40it/s]
100%|██████████| 100/100 [00:00<00:00, 481.42it/s]


In [22]:
# Free up GPU memory 
resid_acts.to('cpu')
ffn_acts.to('cpu')
sae_acts.to('cpu')


tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], grad_fn=<ToCopyBackward0>)

In [18]:
def display_activation_examples(feature_index, activation_type='ffn', num_examples=5):
    map = {'sae': sae_acts_data, 'resid': resid_acts_data, 'ffn': ffn_acts_data}
    
    html_output = f"""
    <h3>Top activating examples for {activation_type.upper()} feature {feature_index}</h3>
    <div style="display: flex; flex-direction: column; gap: 10px;">
    """
    
    for acts in map[activation_type][feature_index][:num_examples]:
        audio = dataset[acts.dataset_index]["audio"]["array"]
        audio_obj = Audio(audio, rate=16000)

        model_transcript = dataset_map[acts.dataset_index]["transcription"]
        
        html_output += f"""
        <div style="border: 1px solid #ddd; padding: 8px; border-radius: 5px;">
            <div style="display: flex; align-items: center; gap: 8px;">
                {audio_obj._repr_html_()}
                <div style="flex-grow: 1;">
                    <div><strong>Activation: {acts.activation:.3f}</strong></div>
                </div>
            </div>
            <div style="margin-top: 5px;">
                <div><strong>Output:</strong> {model_transcript}</div>
            </div>
        </div>
        """
    
    html_output += "</div>"
    display(HTML(html_output))

In [55]:
"""
Feature to try out:

'sae' features:
23 - 'sk' sound
55 - over
62 - Mountain
46 - 'ual' sound across various tokens like 'usual', 'actual' and 'equal'
61 - 'take' token
98 - 'mast' sound
81 - 'that' token
27 - 'an'/'ang' sound in various tokens
17 - 'horse' 
Negative examples: 45, 72 ('kah' sound?), 0, 73 and many more.

'ffn' neurons: # set activation type to 'ffn' below
30 - wen,won
33 - 'nat' phenotype (sensitive to nasal sounds)
35 - after
41 - 'tra' phenotype (sensitive to 'tra' sounds)
47 - 'let me' token
99 - 'tion' sound in 'action' and 'inspection'
"""



display_activation_examples(feature_index=23, activation_type='sae', num_examples=5)



In [56]:
"""
Compared to language models, the feedforward neurons in wave2vec2 are sparse and are pretty interpetable.
30 - wen,won
33 - 'nat' phenotype (sensitive to nasal sounds)
35 - after
41 - 'tra' phenotype (sensitive to 'tra' sounds)
47 - 'let me' token
99 - 'tion' sound in 'action' and 'inspection'
"""

display_activation_examples(feature_index=30, activation_type='ffn', num_examples=5)
