# Test env

In [5]:
# Test kernel
print("Hello world")

Hello world


In [6]:
# Test numpy
import numpy as np

a = np.array([1, 2, 3])
print(a)

[1 2 3]


In [7]:
# Test torch
import torch

x = torch.rand(5, 3)
print(x)

tensor([[0.7885, 0.3010, 0.8253],
        [0.8887, 0.7543, 0.4657],
        [0.7037, 0.7670, 0.9824],
        [0.6595, 0.8532, 0.6683],
        [0.1593, 0.9059, 0.1298]])


In [8]:
# Test torch CUDA usage - run nvidia_smi in parallel
from tqdm import tqdm

print(torch.cuda.is_available())
print(torch.cuda.device_count())
print(torch.cuda.current_device())  # Should be 0
print(torch.cuda.get_device_name(0))  # Should be RTX 4090

# Do something to keep the GPU busy
a = torch.rand(10000, 10000).cuda()
b = torch.rand(10000, 10000).cuda()

for i in tqdm(range(1000)):
    a = torch.matmul(a, b)

True
1
0
NVIDIA GeForce RTX 4090


100%|██████████| 1000/1000 [00:00<00:00, 199017.98it/s]


In [9]:
print("Checkpoint")

Checkpoint


In [15]:
# Transformer lens loading

from transformer_lens import HookedTransformer
from transformer_lens.utils import tokenize_and_concatenate
from datasets import load_dataset

# Get model
model = HookedTransformer.from_pretrained("gelu-1l")
model = model.cuda()
model = model.to(torch.float16)
print(f"Model device: {next(model.parameters()).device}")

# Get and tokenize data
data = load_dataset("NeelNanda/c4-code-20k", split="train")
tokenized_data = tokenize_and_concatenate(data, model.tokenizer, max_length=128).shuffle(42)
tokens = tokenized_data["tokens"]
tokens = tokens.cuda()
tokens = tokens.to(torch.int32)
print(f"Tokens shape: {tokens.shape}, dtype: {tokens.dtype}, device: {tokens.device}")

Loaded pretrained model gelu-1l into HookedTransformer
Moving model to device:  cuda
Changing model dtype to torch.float16
Model device: cuda:0
Tokens shape: torch.Size([215402, 128]), dtype: torch.int32, device: cuda:0


In [16]:
# Run model
tokens_batched = tokens.split(32)
out = []
for batch in tqdm(tokens_batched[:100]):
    out.append(model(batch).detach().cpu())  # Move to CPU for memory

out = torch.cat(out)
print(f"Output shape: {out.shape}, dtype: {out.dtype}, device: {out.device}")

  0%|          | 0/100 [00:00<?, ?it/s]

100%|██████████| 100/100 [00:09<00:00, 10.37it/s]


Output shape: torch.Size([3200, 128, 48262]), dtype: torch.float16, device: cpu
