In [1]:
!pip install transformer-lens datasets tqdm


Collecting transformer-lens
  Downloading transformer_lens-2.9.1-py3-none-any.whl.metadata (12 kB)
Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting tqdm
  Downloading tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.7/57.7 kB[0m [31m122.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting accelerate>=0.23.0 (from transformer-lens)
  Downloading accelerate-1.1.1-py3-none-any.whl.metadata (19 kB)
Collecting beartype<0.15.0,>=0.14.1 (from transformer-lens)
  Downloading beartype-0.14.1-py3-none-any.whl.metadata (28 kB)
Collecting better-abc<0.0.4,>=0.0.3 (from transformer-lens)
  Downloading better_abc-0.0.3-py3-none-any.whl.metadata (1.4 kB)
Collecting einops>=0.6.0 (from transformer-lens)
  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Collecting fancy-einsum>=0.0.3 (from transformer-lens)
  Downloading fancy_einsum-0.0.3-py3-none-any.whl.metadata (1.2 kB)
Collecting ja

In [2]:
import os
# list all files and dirs currently in volume
os.listdir("volume")


['gemma_2b_activations.pt',
 'gemma_2b_it_activations.pt',
 'gemma_9b_activations.pt',
 'gemma_9b_it_activations.pt',
 'untitled.txt']

In [3]:
from transformer_lens import HookedTransformer
from datasets import load_dataset
from tqdm import tqdm
import torch
import os

os.environ["HF_TOKEN"] = "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"

torch.set_grad_enabled(False)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# download models and move to GPU with bf16
gemma_2b_it = HookedTransformer.from_pretrained(
    "gemma-2b-it",
    torch_dtype=torch.bfloat16, 
).to(device)

# prep data
dataset = load_dataset("allenai/c4", "en", split="train", streaming=True)
dataset_iter = iter(dataset)
first_500_points = [next(dataset_iter)["text"] for _ in range(500)]

del dataset, dataset_iter
tokens = gemma_2b_it.to_tokens(
    first_500_points, prepend_bos=True, padding_side="left", move_to_device=True
).detach()[:, :2048]
del first_500_points

# prepare hooks
gemma_2b_activations = []
gemma_2b_it_activations = []

print(tokens.shape)

def gemma_2b_hook(activations, hook):
    act = activations[..., -1, :].detach().cpu()  # Keep on GPU
    gemma_2b_activations.append(act)

def gemma_2b_it_hook(activations, hook):
    act = activations[..., -1, :].detach().cpu()  # Keep on GPU
    gemma_2b_it_activations.append(act)

# Run inference
BATCH_SIZE = 5
for i in tqdm(range(0, len(tokens), BATCH_SIZE)):
    batch = tokens[i:i + BATCH_SIZE]
    ans = gemma_2b_it.run_with_hooks(
        batch,
        fwd_hooks=[("blocks.12.hook_resid_post", gemma_2b_it_hook)],
    )
    del ans
    torch.cuda.empty_cache()

del gemma_2b_it
torch.cuda.empty_cache()

# Load second model to GPU with bf16
gemma_2b = HookedTransformer.from_pretrained(
    "gemma-2b",
    torch_dtype=torch.bfloat16
).to(device)

for i in tqdm(range(0, len(tokens), BATCH_SIZE)):
    batch = tokens[i:i + BATCH_SIZE]
    ans = gemma_2b.run_with_hooks(
        batch,
        fwd_hooks=[("blocks.12.hook_resid_post", gemma_2b_hook)],
    )
    del ans
    torch.cuda.empty_cache()
# concatenate activations (still on GPU)
gemma_2b_activations = torch.cat(gemma_2b_activations, dim=0)
gemma_2b_it_activations = torch.cat(gemma_2b_it_activations, dim=0)

del tokens, gemma_2b
torch.cuda.empty_cache()

# Move to CPU only when saving
torch.save(gemma_2b_activations, "volume/gemma_2b_activations.pt")
torch.save(gemma_2b_it_activations, "volume/gemma_2b_it_activations.pt")


  from .autonotebook import tqdm as notebook_tqdm
Downloading shards: 100%|██████████| 2/2 [02:02<00:00, 61.20s/it] 
`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  7.28it/s]


Loaded pretrained model gemma-2b-it into HookedTransformer
Moving model to device:  cuda
torch.Size([500, 2048])


100%|██████████| 100/100 [00:34<00:00,  2.92it/s]
Downloading shards: 100%|██████████| 2/2 [01:59<00:00, 59.91s/it] 
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  5.97it/s]


Loaded pretrained model gemma-2b into HookedTransformer
Moving model to device:  cuda


100%|██████████| 100/100 [00:34<00:00,  2.94it/s]


In [4]:
# reset gpu memory
torch.cuda.empty_cache()
# print used memory
print(torch.cuda.memory_summary(device="cuda", abbreviated=False))


|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  64377 KiB |  12181 MiB |  37340 GiB |  37340 GiB |
|       from large pool |  64377 KiB |  12142 MiB |  37320 GiB |  37320 GiB |
|       from small pool |      0 KiB |     38 MiB |     19 GiB |     19 GiB |
|---------------------------------------------------------------------------|
| Active memory         |  64377 KiB |  12181 MiB |  37340 GiB |  37340 GiB |
|       from large pool |  64377 KiB |  12142 MiB |  37320 GiB |  37320 GiB |
|       from small pool |      0 KiB |     38 MiB |     19 GiB |     19 GiB |
|---------------------------------------------------------------