In [10]:
import torch
import transformer_lens, sae_lens
from transformer_lens import HookedTransformer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [49]:
def sparsegpt(W, X, λ=0.01, p=0.5, B=128, Bs=32, device='cuda'):
    '''
    W: weight matrix
    X: input matrix
    λ: regularization parameter
    p: sparsity parameter
    B: batch size
    Bs: block size
    '''

    # print(W.shape, X.shape)
    # Initialize mask and block quantization errors
    M = torch.ones_like(W) # Binary pruning mask
    E = torch.zeros_like(W) # Block quantization errors

    # Compute Hessian inverse
    H = X.T @ X + λ * torch.eye(X.shape[1], device=device)
    H_inv = torch.cholesky(torch.inverse(H)).T

    W = W.detach().clone()  # Prevent gradient tracking for pruning

    # Blockwise pruning
    for i in range(0, W.shape[1], B):
        for j in range(i, min(i + B, W.shape[1])):
            if j % Bs == 0:
                # Select (1-p) fraction of weights based on Hessian information
                W_squared = W[:, j:j + Bs] ** 2
                H_diag = torch.diag(H_inv)[j : j+Bs].unsqueeze(0)
                score = W_squared / H_diag ** 2
                thr = torch.quantile(score, p)  # Pruning threshold
                M[:, j:j + Bs] = (score > thr).float()  # Binary mask

            with torch.no_grad():
                # Prune and update quantization error
                E[:, j] += (1 - M[:, j]) * W[:, j]
                W[:, j] = M[:, j] * W[:, j]  # Apply mask

            # Update weights using Hessian inverse
            W[:, i:(i + B)] = W[:, i:(i + B)] - (E[:, j].unsqueeze(1) * H_inv[j, i:(i + B)]).clone()

    return W * M



def prune_sgpt(model, tokens):
    wts_act = {
        'attn.W_Q': 'attn.hook_q',
        'attn.W_K': 'attn.hook_k',
        'attn.W_V': 'attn.hook_v',
        'attn.W_O': 'hook_attn_out',
        'mlp.W_in': 'mlp.hook_pre',
        'mlp.W_out': 'hook_mlp_out'
    }
    
    for layer in range(model.cfg.n_layers):
        logits, cache = model.run_with_cache(tokens, remove_batch_dim=True)
        
        for wt, act in wts_act.items():
            W = model.get_parameter(f'blocks.{layer}.{wt}')
            X = cache[f'blocks.{layer}.{act}']

            # Detach and clone W to prevent gradient tracking issues
            W = W.detach().clone()

            if W.dim() == 3:
                if 'W_O' in wt:
                    X_norm = X
                    pruned_W = torch.stack([sparsegpt(W[head], X_norm) for head in range(W.shape[0])])
                else:
                    pruned_W = torch.stack([sparsegpt(W[head], X[:, head, :]) for head in range(W.shape[0])])
            else:
                X_norm = X
                pruned_W = sparsegpt(W, X_norm)
            
            with torch.no_grad():
                param = model.get_parameter(f'blocks.{layer}.{wt}')
                param.copy_(pruned_W)  # Update weights without breaking graph

    return model

In [50]:
gpt2: sae_lens.HookedSAETransformer = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)

Loaded pretrained model gpt2-small into HookedTransformer


In [51]:
prompt = "Mitigating the risk of extinction from AI should be a global"
answer = " priority"

# First see how the model does without SAEs
transformer_lens.utils.test_prompt(prompt, answer, gpt2)

Tokenized prompt: ['<|endoftext|>', 'Mit', 'igating', ' the', ' risk', ' of', ' extinction', ' from', ' AI', ' should', ' be', ' a', ' global']
Tokenized answer: [' priority']


Top 0th token. Logit: 19.46 Prob: 52.99% Token: | priority|
Top 1th token. Logit: 17.44 Prob:  7.02% Token: | effort|
Top 2th token. Logit: 16.94 Prob:  4.26% Token: | issue|
Top 3th token. Logit: 16.63 Prob:  3.14% Token: | challenge|
Top 4th token. Logit: 16.37 Prob:  2.42% Token: | goal|
Top 5th token. Logit: 16.06 Prob:  1.78% Token: | concern|
Top 6th token. Logit: 15.88 Prob:  1.47% Token: | focus|
Top 7th token. Logit: 15.61 Prob:  1.13% Token: | approach|
Top 8th token. Logit: 15.53 Prob:  1.04% Token: | policy|
Top 9th token. Logit: 15.42 Prob:  0.93% Token: | initiative|


In [52]:
dataset = transformer_lens.utils.get_dataset('openwebtext')

class OpenWebText(torch.utils.data.Dataset):
    def __init__(self, dataset, max_length=1024):
        self.dataset = dataset
        self.max_length = 1024

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        text = self.dataset[idx]['text']
        tokens = gpt2.to_tokens(text)
        tokens = tokens[:self.max_length]
        return tokens
    
openwebtext = OpenWebText(dataset)

In [54]:
from tqdm import tqdm
pruned_gpt2 = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)
i = 0
for batch in tqdm(openwebtext):
    if i == 2:
        break
    pruned_gpt2 = prune_sgpt(pruned_gpt2, batch)
    i += 1
    # break

Loaded pretrained model gpt2-small into HookedTransformer


  0%|          | 2/10000 [00:49<69:05:47, 24.88s/it]


In [55]:
prompt = "Mitigating the risk of extinction from AI should be a global"
answer = " priority"

# First see how the model does without SAEs
transformer_lens.utils.test_prompt(prompt, answer, pruned_gpt2)


Tokenized prompt: ['<|endoftext|>', 'Mit', 'igating', ' the', ' risk', ' of', ' extinction', ' from', ' AI', ' should', ' be', ' a', ' global']
Tokenized answer: [' priority']


Top 0th token. Logit:   nan Prob:   nan% Token: |!|
Top 1th token. Logit:   nan Prob:   nan% Token: |"|
Top 2th token. Logit:   nan Prob:   nan% Token: |#|
Top 3th token. Logit:   nan Prob:   nan% Token: |$|
Top 4th token. Logit:   nan Prob:   nan% Token: |%|
Top 5th token. Logit:   nan Prob:   nan% Token: |&|
Top 6th token. Logit:   nan Prob:   nan% Token: |'|
Top 7th token. Logit:   nan Prob:   nan% Token: |(|
Top 8th token. Logit:   nan Prob:   nan% Token: |)|
Top 9th token. Logit:   nan Prob:   nan% Token: |*|
