# First try

In [1]:
import torch
import transformer_lens, sae_lens
from transformer_lens import HookedTransformer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
@torch.no_grad()
def sparsegpt(W, X, λ=0.01, p=0.5, B=128, Bs=32, device='cuda'):
    '''
    W: weight matrix
    X: input matrix
    λ: regularization parameter
    p: sparsity parameter
    B: batch size
    Bs: block size
    '''

    # print(W.shape, X.shape)
    # Initialize mask and block quantization errors
    M = torch.ones_like(W) # Binary pruning mask
    E = torch.zeros_like(W) # Block quantization errors

    # Compute Hessian inverse
    H = X.T @ X + λ * torch.eye(X.shape[1], device=device)
    H_inv = torch.cholesky(torch.inverse(H)).T

    # Blockwise pruning
    for i in range(0, W.shape[1], B):
        for j in range(i, min(i + B, W.shape[1])):
            if j % Bs == 0:

                # Select (1-p) fraction of weights based on Hessian information
                W_squared = W[:, j:j + Bs] ** 2
                H_diag = torch.diag(H_inv)[j : j+Bs].unsqueeze(0)
                score = W_squared / H_diag ** 2
                thr = torch.quantile(score, p) # Pruning threshold
                M[:, j:j + Bs] = (score > thr).float() # Binary mask

            # Prune and update quantization error
            E[:, j] += (1 - M[:, j]) * W[:, j]
            W[:, j] = M[:, j] * W[:, j] # Apply mask

            # Update weights using Hessian inverse
            W[:, i:(i + B)] -= E[:, j].unsqueeze(1) * H_inv[j, i:(i + B)]

    return W * M


@torch.no_grad()
def prune_sgpt(model, tokens):
    wts_act = {
    'attn.W_Q': 'attn.hook_q',
    'attn.W_K': 'attn.hook_k',
    'attn.W_V': 'attn.hook_v',
    'attn.W_O': 'hook_attn_out',
    'mlp.W_in': 'mlp.hook_pre',
    'mlp.W_out': 'hook_mlp_out'
    }
    for layer in range(model.cfg.n_layers):
        logits, cache = model.run_with_cache(tokens, remove_batch_dim=True)
        for wt, act in wts_act.items():
            W = model.get_parameter(f'blocks.{layer}.{wt}')
            X = cache[f'blocks.{layer}.{act}']

            if W.dim() == 3:
                if 'W_O' in wt:
                    X_norm = X
                    for head in range(W.shape[0]):
                        W[head] = sparsegpt(W[head], X_norm)
                        
                else:
                    for head in range(W.shape[0]):
                        X_norm = X[:, head, :]
                        W[head] = sparsegpt(W[head], X_norm)
            else:
                X_norm = X
                W = sparsegpt(W, X_norm)
            
    return model

In [87]:
gpt2: sae_lens.HookedSAETransformer = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)

Loaded pretrained model gpt2-small into HookedTransformer


In [88]:
prompt = "Mitigating the risk of extinction from AI should be a global"
answer = " priority"

# First see how the model does without SAEs
transformer_lens.utils.test_prompt(prompt, answer, gpt2)

Tokenized prompt: ['<|endoftext|>', 'Mit', 'igating', ' the', ' risk', ' of', ' extinction', ' from', ' AI', ' should', ' be', ' a', ' global']
Tokenized answer: [' priority']


Top 0th token. Logit: 19.46 Prob: 52.99% Token: | priority|
Top 1th token. Logit: 17.44 Prob:  7.02% Token: | effort|
Top 2th token. Logit: 16.94 Prob:  4.26% Token: | issue|
Top 3th token. Logit: 16.63 Prob:  3.14% Token: | challenge|
Top 4th token. Logit: 16.37 Prob:  2.42% Token: | goal|
Top 5th token. Logit: 16.06 Prob:  1.78% Token: | concern|
Top 6th token. Logit: 15.88 Prob:  1.47% Token: | focus|
Top 7th token. Logit: 15.61 Prob:  1.13% Token: | approach|
Top 8th token. Logit: 15.53 Prob:  1.04% Token: | policy|
Top 9th token. Logit: 15.42 Prob:  0.93% Token: | initiative|


In [89]:
dataset = transformer_lens.utils.get_dataset('openwebtext')

class OpenWebText(torch.utils.data.Dataset):
    def __init__(self, dataset, max_length=1024):
        self.dataset = dataset
        self.max_length = 1024

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        text = self.dataset[idx]['text']
        tokens = gpt2.to_tokens(text)
        tokens = tokens[:self.max_length]
        return tokens
    
openwebtext = OpenWebText(dataset)

In [90]:
from tqdm import tqdm
pruned_gpt2 = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)
i = 0
for batch in tqdm(openwebtext):
    if i == 2:
        break
    pruned_gpt2 = prune_sgpt(pruned_gpt2, batch)
    i += 1
    # break

Loaded pretrained model gpt2-small into HookedTransformer


  0%|          | 0/10000 [00:00<?, ?it/s]

torch.Size([768, 64]) torch.Size([1024, 64])
torch.Size([768, 64]) torch.Size([1024, 64])
torch.Size([768, 64]) torch.Size([1024, 64])
torch.Size([768, 64]) torch.Size([1024, 64])
torch.Size([768, 64]) torch.Size([1024, 64])
torch.Size([768, 64]) torch.Size([1024, 64])
torch.Size([768, 64]) torch.Size([1024, 64])
torch.Size([768, 64]) torch.Size([1024, 64])
torch.Size([768, 64]) torch.Size([1024, 64])
torch.Size([768, 64]) torch.Size([1024, 64])
torch.Size([768, 64]) torch.Size([1024, 64])
torch.Size([768, 64]) torch.Size([1024, 64])
torch.Size([768, 64]) torch.Size([1024, 64])
torch.Size([768, 64]) torch.Size([1024, 64])
torch.Size([768, 64]) torch.Size([1024, 64])
torch.Size([768, 64]) torch.Size([1024, 64])
torch.Size([768, 64]) torch.Size([1024, 64])
torch.Size([768, 64]) torch.Size([1024, 64])
torch.Size([768, 64]) torch.Size([1024, 64])
torch.Size([768, 64]) torch.Size([1024, 64])
torch.Size([768, 64]) torch.Size([1024, 64])
torch.Size([768, 64]) torch.Size([1024, 64])
torch.Size

  0%|          | 1/10000 [00:21<59:19:38, 21.36s/it]

torch.Size([3072, 768]) torch.Size([1024, 768])
torch.Size([768, 64]) torch.Size([727, 64])
torch.Size([768, 64]) torch.Size([727, 64])
torch.Size([768, 64]) torch.Size([727, 64])
torch.Size([768, 64]) torch.Size([727, 64])
torch.Size([768, 64]) torch.Size([727, 64])
torch.Size([768, 64]) torch.Size([727, 64])
torch.Size([768, 64]) torch.Size([727, 64])
torch.Size([768, 64]) torch.Size([727, 64])
torch.Size([768, 64]) torch.Size([727, 64])
torch.Size([768, 64]) torch.Size([727, 64])
torch.Size([768, 64]) torch.Size([727, 64])
torch.Size([768, 64]) torch.Size([727, 64])
torch.Size([768, 64]) torch.Size([727, 64])
torch.Size([768, 64]) torch.Size([727, 64])
torch.Size([768, 64]) torch.Size([727, 64])
torch.Size([768, 64]) torch.Size([727, 64])
torch.Size([768, 64]) torch.Size([727, 64])
torch.Size([768, 64]) torch.Size([727, 64])
torch.Size([768, 64]) torch.Size([727, 64])
torch.Size([768, 64]) torch.Size([727, 64])
torch.Size([768, 64]) torch.Size([727, 64])
torch.Size([768, 64]) torch.

  0%|          | 2/10000 [00:42<58:34:42, 21.09s/it]

torch.Size([3072, 768]) torch.Size([727, 768])





In [95]:
prompt = "Mitigating the risk of extinction from AI should be a global"
answer = " priority"


# First see how the model does without SAEs
transformer_lens.utils.test_prompt(prompt, answer, pruned_gpt2)


Tokenized prompt: ['<|endoftext|>', 'Mit', 'igating', ' the', ' risk', ' of', ' extinction', ' from', ' AI', ' should', ' be', ' a', ' global']
Tokenized answer: [' priority']


Top 0th token. Logit: 19.46 Prob: 52.99% Token: | priority|
Top 1th token. Logit: 17.44 Prob:  7.02% Token: | effort|
Top 2th token. Logit: 16.94 Prob:  4.26% Token: | issue|
Top 3th token. Logit: 16.63 Prob:  3.14% Token: | challenge|
Top 4th token. Logit: 16.37 Prob:  2.42% Token: | goal|
Top 5th token. Logit: 16.06 Prob:  1.78% Token: | concern|
Top 6th token. Logit: 15.88 Prob:  1.47% Token: | focus|
Top 7th token. Logit: 15.61 Prob:  1.13% Token: | approach|
Top 8th token. Logit: 15.53 Prob:  1.04% Token: | policy|
Top 9th token. Logit: 15.42 Prob:  0.93% Token: | initiative|


# Tests

In [1]:
import torch
import transformer_lens, sae_lens
from transformer_lens import HookedTransformer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

dataset = transformer_lens.utils.get_dataset('openwebtext')
gpt2: sae_lens.HookedSAETransformer = sae_lens.HookedSAETransformer.from_pretrained_no_processing("gpt2-small", device=device, n_devices=3)

class OpenWebText(torch.utils.data.Dataset):
    def __init__(self, dataset, max_length=1024):
        self.dataset = dataset
        self.max_length = 1024

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        text = self.dataset[idx]['text']
        tokens = gpt2.to_tokens(text)
        tokens = tokens[:self.max_length]
        return tokens
    
openwebtext = OpenWebText(dataset)

from tqdm import tqdm
count = 0
data = []
for batch in tqdm(openwebtext):
    if batch.shape[1] == 1024:
        data.append(batch)
        count += 1
data_tensor = torch.cat(data, dim=0)
print(count, data_tensor.shape)



Loaded pretrained model gpt2-small into HookedTransformer


100%|██████████| 10000/10000 [00:34<00:00, 290.75it/s]

3383 torch.Size([3383, 1024])





In [2]:
logits, cache = gpt2.run_with_cache(data_tensor[:8, :], remove_batch_dim=False)

# Please check if the implementation is right:

In [1]:
import math, time, einops, sae_lens, transformer_lens, torch
from tqdm import tqdm

class SparseGPT:
    def __init__(self, W):
        self.dev = W.device
        self.W = W.clone()
        self.rows = W.shape[0]
        self.cols = W.shape[1]
        self.H = torch.zeros((self.cols, self.cols), device=self.dev)
        self.n_samples = 0

    def add_batch(self, X): 
        tmp = X.shape[0]
        self.H *= self.n_samples/(self.n_samples + tmp)
        self.n_samples += tmp
        X = math.sqrt(2/self.n_samples) * X.float()
        self.H += X.T @ X 

    
    def faster_prune(self, W, sparsity=0.25, blocksize=128, percdamp=0.01):
        W = W.clone()
        W = W.float()
        tick = time.time()

        H = self.H
        del self.H
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        Losses = torch.zeros(self.rows, device=self.dev)

        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.cols, device=self.dev)
        H[diag, diag] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        H = torch.linalg.cholesky(H, upper=True)
        Hinv = H

        mask = None

        for i1 in range(0, self.cols, blocksize):
            i2 = min(i1 + blocksize, self.cols)
            count = i2 - i1

            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Losses1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            if mask is not None:
                mask1 = mask[:, i1:i2]
            else:  
                tmp = W1 ** 2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2
                thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
                mask1 = tmp <= thresh

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                q = w.clone()
                q[mask1[:, i]] = 0

                Q1[:, i] = q
                Losses1[:, i] = (w - q) ** 2 / d ** 2

                err1 = (w - q) / d 
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            W[:, i1:i2] = Q1
            Losses += torch.sum(Losses1, 1) / 2

            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

        torch.cuda.synchronize()
        self.W = W.reshape((self.rows, self.cols)).to(self.dev)
        
    def free(self):
        self.H = None
        torch.cuda.empty_cache()


@torch.no_grad()
def prune_sparsegpt(model, tokens):
    print("Starting pruning ...")
    wts_act = {
    'attn.W_Q': 'attn.hook_q',
    'attn.W_K': 'attn.hook_k',
    'attn.W_V': 'attn.hook_v',
    'attn.W_O': 'hook_attn_out',
    'mlp.W_in': 'mlp.hook_pre',
    'mlp.W_out': 'hook_mlp_out'
    }

    logits, cache = model.run_with_cache(tokens, remove_batch_dim=True)
    for layer in range(model.cfg.n_layers):
        layer_cache = {k: v for k, v in cache.items() if f'blocks.{layer}.' in k}

        for wt, act in wts_act.items():
            W = model.get_parameter(f'blocks.{layer}.{wt}')
            X = layer_cache[f'blocks.{layer}.{act}']


            if W.dim() == 2:
                sparsegpt_object = SparseGPT(W)
                sparsegpt_object.add_batch(X)
                sparsegpt_object.faster_prune(W)
                W.copy_(sparsegpt_object.W)
                sparsegpt_object.free()

            else:
                if 'W_O' in wt:
                    for head in range(W.shape[0]):
                        sparsegpt_object = SparseGPT(W[head])
                        sparsegpt_object.add_batch(X)
                        sparsegpt_object.faster_prune(W[head])
                        W[head].copy_(sparsegpt_object.W)
                        sparsegpt_object.free()


                else:
                    for head in range(W.shape[0]):
                        sparsegpt_object = SparseGPT(W[head])
                        sparsegpt_object.add_batch(X[:, head, :])
                        sparsegpt_object.faster_prune(W[head])
                        W[head].copy_(sparsegpt_object.W)
                        sparsegpt_object.free()

            del sparsegpt_object
            torch.cuda.empty_cache()

    del layer_cache
    torch.cuda.empty_cache()
    


pruned_gpt2 = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device="cuda")
dataset = transformer_lens.utils.get_dataset('openwebtext')

class OpenWebText(torch.utils.data.Dataset):
    def __init__(self, dataset, max_length=1024):
        self.dataset = dataset
        self.max_length = 1024

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        text = self.dataset[idx]['text']
        tokens = pruned_gpt2.to_tokens(text)
        tokens = tokens[:self.max_length]
        return tokens
    
openwebtext = OpenWebText(dataset)

data = []
for batch in tqdm(openwebtext):
    if batch.shape[1] == 1024:
        data.append(batch)

data_tensor = torch.cat(data, dim=0)[:8, :]
prune_sparsegpt(pruned_gpt2, data_tensor)

Loaded pretrained model gpt2-small into HookedTransformer


100%|██████████| 10000/10000 [00:34<00:00, 286.78it/s]


Starting pruning ...


In [None]:
gpt2: sae_lens.HookedSAETransformer = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device="cuda")
torch.abs(pruned_gpt2.W_Q - gpt2.W_Q).mean().item()

Loaded pretrained model gpt2-small into HookedTransformer


0.003115949220955372

In [3]:
prompt = "Mitigating the risk of extinction from AI should be a global"
answer = " priority"


# First see how the model does without SAEs
transformer_lens.utils.test_prompt(prompt, answer, pruned_gpt2)

Tokenized prompt: ['<|endoftext|>', 'Mit', 'igating', ' the', ' risk', ' of', ' extinction', ' from', ' AI', ' should', ' be', ' a', ' global']
Tokenized answer: [' priority']


Top 0th token. Logit: 15.88 Prob: 12.87% Token: | issue|
Top 1th token. Logit: 15.45 Prob:  8.34% Token: | concern|
Top 2th token. Logit: 14.68 Prob:  3.85% Token: | problem|
Top 3th token. Logit: 14.63 Prob:  3.66% Token: | priority|
Top 4th token. Logit: 13.89 Prob:  1.75% Token: | cause|
Top 5th token. Logit: 13.85 Prob:  1.68% Token: | one|
Top 6th token. Logit: 13.82 Prob:  1.63% Token: | phenomenon|
Top 7th token. Logit: 13.65 Prob:  1.37% Token: | trend|
Top 8th token. Logit: 13.64 Prob:  1.36% Token: | challenge|
Top 9th token. Logit: 13.39 Prob:  1.06% Token: | policy|


In [4]:
prompt = "Mary and John went to the park to play. Mary gave the ball to"
answer = " John"

transformer_lens.utils.test_prompt(prompt, answer, gpt2)
transformer_lens.utils.test_prompt(prompt, answer, pruned_gpt2)


Tokenized prompt: ['<|endoftext|>', 'Mary', ' and', ' John', ' went', ' to', ' the', ' park', ' to', ' play', '.', ' Mary', ' gave', ' the', ' ball', ' to']
Tokenized answer: [' John']


Top 0th token. Logit: 17.57 Prob: 56.01% Token: | John|
Top 1th token. Logit: 15.73 Prob:  8.86% Token: | Mary|
Top 2th token. Logit: 15.53 Prob:  7.25% Token: | the|
Top 3th token. Logit: 15.41 Prob:  6.43% Token: | her|
Top 4th token. Logit: 14.15 Prob:  1.82% Token: | a|
Top 5th token. Logit: 14.01 Prob:  1.60% Token: | them|
Top 6th token. Logit: 13.39 Prob:  0.86% Token: | his|
Top 7th token. Logit: 13.25 Prob:  0.74% Token: | Jesus|
Top 8th token. Logit: 12.88 Prob:  0.51% Token: | Joseph|
Top 9th token. Logit: 12.85 Prob:  0.50% Token: | me|


Tokenized prompt: ['<|endoftext|>', 'Mary', ' and', ' John', ' went', ' to', ' the', ' park', ' to', ' play', '.', ' Mary', ' gave', ' the', ' ball', ' to']
Tokenized answer: [' John']


Top 0th token. Logit: 13.59 Prob: 26.24% Token: | the|
Top 1th token. Logit: 12.10 Prob:  5.88% Token: | a|
Top 2th token. Logit: 11.29 Prob:  2.63% Token: | his|
Top 3th token. Logit: 10.64 Prob:  1.37% Token: | John|
Top 4th token. Logit: 10.53 Prob:  1.23% Token: | be|
Top 5th token. Logit: 10.34 Prob:  1.01% Token: | him|
Top 6th token. Logit: 10.29 Prob:  0.97% Token: | me|
Top 7th token. Logit: 10.09 Prob:  0.79% Token: | an|
Top 8th token. Logit: 10.04 Prob:  0.75% Token: | some|
Top 9th token. Logit:  9.86 Prob:  0.63% Token: | her|
