# Pythia 160m

In [1]:
from transformer_lens import HookedTransformer
import transformer_lens.utils as utils
# Load the Pythia 160m checkpoint wrapped as a HookedTransformer.
MODEL_NAME = "EleutherAI/pythia-160m"
model = HookedTransformer.from_pretrained(MODEL_NAME)

`torch_dtype` is deprecated! Use `dtype` instead!


Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer


In [3]:
# Single-prompt smoke test: generate a 10-token continuation and show it
# alongside the prompt.
prompt = "A screen reader is"
output = model.generate(prompt, temperature=0, max_new_tokens=10)
print(f"Input: {prompt}", f"Output: {output}", sep="\n")

# Run the same prompt again with caching so the intermediate activations
# are available for inspection; report how many entries were captured.
logits, cache = model.run_with_cache(prompt)
print(f"\nCached {len(cache)} different activation points!")

  0%|          | 0/10 [00:00<?, ?it/s]

Input: A screen reader is
Output: A screen reader is a device that allows a user to view a screen

Cached 219 different activation points!


In [6]:
# Probe the model with short web-accessibility prompts and print only the
# generated continuation next to each prompt.
test_prompts = [
    "A screen reader is",
    "WCAG stands for",
    "A skip link is",
    "The purpose of alt text is",
]

for prompt in test_prompts:
    # verbose=False suppresses the per-call tqdm progress bar, so the cell
    # output is one clean line per prompt.
    output = model.generate(prompt, max_new_tokens=10, temperature=0, verbose=False)
    # The returned text starts with the prompt (see the recorded single-prompt
    # output), so slice it off and show just the continuation.
    print(f"{prompt:30} → {output[len(prompt):]}")

  0%|          | 0/10 [00:00<?, ?it/s]

A screen reader is             →  a device that allows a user to view a screen


  0%|          | 0/10 [00:00<?, ?it/s]

WCAG stands for                →  the International Commission on Geographic Names (ICCG


  0%|          | 0/10 [00:00<?, ?it/s]

A skip link is                 →  a link to a page that is not a link


  0%|          | 0/10 [00:00<?, ?it/s]

The purpose of alt text is     →  to provide a means for the user to communicate with


In [7]:
# Run this between models
import gc

import torch  # needed for the cache-flush below; no visible earlier cell imports it

# Drop every notebook-level reference to the model. `cache` is included because
# the ActivationCache returned by run_with_cache keeps a reference to the model
# it came from, so `del model` alone would leave the weights alive.
# globals().pop(...) keeps the cell safe to re-run when a name is already gone.
for _name in ("model", "logits", "cache"):
    globals().pop(_name, None)
gc.collect()

# Return cached allocator memory on whichever accelerator backend is present.
# Guarded so the cell also runs on machines without Apple MPS.
if torch.backends.mps.is_available():
    torch.mps.empty_cache()
elif torch.cuda.is_available():
    torch.cuda.empty_cache()
print("Memory cleared")

Memory cleared


# Pythia 410m

In [None]:
# Load the Pythia 410m checkpoint wrapped as a HookedTransformer.
pretrained_name = "pythia-410m"
model = HookedTransformer.from_pretrained(pretrained_name)

Loaded pretrained model pythia-410m into HookedTransformer
Loaded: Pythia 410m
Layers: 24
Heads: 16
Hidden size: 1024
Params: 405.3M


In [None]:
# Single-prompt smoke test for this model size.
prompt = "A screen reader is"

# Generation settings kept together so they are easy to spot and tweak.
generation_kwargs = {"max_new_tokens": 10, "temperature": 0}
output = model.generate(prompt, **generation_kwargs)
print(f"Input: {prompt}")
print(f"Output: {output}")

# Forward pass with caching: keeps the logits plus the cached activations,
# then reports how many entries the cache holds.
logits, cache = model.run_with_cache(prompt)
print(f"\nCached {len(cache)} different activation points!")

  0%|          | 0/1 [00:00<?, ?it/s]

Input: The capital of France is
Output: The capital of France is the

Cached 435 different activation points!


In [None]:
# Probe the model with short web-accessibility prompts and print only the
# generated continuation next to each prompt.
test_prompts = [
    "A screen reader is",
    "WCAG stands for",
    "A skip link is",
    "The purpose of alt text is",
]

for prompt in test_prompts:
    # verbose=False suppresses the per-call tqdm progress bar, so the cell
    # output is one clean line per prompt.
    output = model.generate(prompt, max_new_tokens=10, temperature=0, verbose=False)
    # Slice off the leading prompt text so only the continuation is shown.
    print(f"{prompt:30} → {output[len(prompt):]}")

  0%|          | 0/3 [00:00<?, ?it/s]

The capital of France is       →  the capital of


  0%|          | 0/3 [00:00<?, ?it/s]

Paris is the capital of        →  France, and


  0%|          | 0/3 [00:00<?, ?it/s]

France is a country in         →  which the right


  0%|          | 0/3 [00:00<?, ?it/s]

The Eiffel Tower is in         →  Paris, France


In [None]:
# Run this between models
import gc

import torch

# Drop every notebook-level reference to the model. `cache` is included because
# the ActivationCache returned by run_with_cache keeps a reference to the model
# it came from, so `del model` alone would leave the weights alive.
# globals().pop(...) keeps the cell safe to re-run when a name is already gone.
for _name in ("model", "logits", "cache"):
    globals().pop(_name, None)
gc.collect()

# Return cached allocator memory on whichever accelerator backend is present.
# Guarded so the cell also runs on machines without Apple MPS.
if torch.backends.mps.is_available():
    torch.mps.empty_cache()
elif torch.cuda.is_available():
    torch.cuda.empty_cache()
print("Memory cleared")

Memory cleared


# Pythia 1b

In [None]:
from transformer_lens import HookedTransformer
import transformer_lens.utils as utils

# The preceding cleanup cell deletes `model`, so it has to be re-created here;
# otherwise the cells below fail with NameError on a fresh Restart & Run All.
# The checkpoint name matches the "Loaded pretrained model pythia-1b" line in
# this cell's recorded output.
model = HookedTransformer.from_pretrained("pythia-1b")

Loaded pretrained model pythia-1b into HookedTransformer
Loaded: Pythia 1b
Layers: 16
Heads: 8
Hidden size: 2048
Params: 1011.7M


In [None]:
# Single-prompt smoke test for this model size.
prompt = "A screen reader is"
output = model.generate(
    prompt,
    max_new_tokens=10,
    temperature=0,
)
for line in (f"Input: {prompt}", f"Output: {output}"):
    print(line)

# Second pass over the same prompt, this time keeping the cached activations
# so they can be inspected; print how many entries were stored.
logits, cache = model.run_with_cache(prompt)
print(f"\nCached {len(cache)} different activation points!")

  0%|          | 0/1 [00:00<?, ?it/s]

Input: The capital of France is
Output: The capital of France is the

Cached 291 different activation points!


In [None]:
# Probe the model with short web-accessibility prompts and print only the
# generated continuation next to each prompt.
test_prompts = [
    "A screen reader is",
    "WCAG stands for",
    "A skip link is",
    "The purpose of alt text is",
]

for prompt in test_prompts:
    # verbose=False suppresses the per-call tqdm progress bar, so the cell
    # output is one clean line per prompt.
    output = model.generate(prompt, max_new_tokens=10, temperature=0, verbose=False)
    # Slice off the leading prompt text so only the continuation is shown.
    print(f"{prompt:30} → {output[len(prompt):]}")

  0%|          | 0/3 [00:00<?, ?it/s]

The capital of France is       →  a city of


  0%|          | 0/3 [00:00<?, ?it/s]

Paris is the capital of        →  France, and


  0%|          | 0/3 [00:00<?, ?it/s]

France is a country in         →  which the French


  0%|          | 0/3 [00:00<?, ?it/s]

The Eiffel Tower is in         →  the midst of


In [None]:
# Run this between models
import gc

import torch

# Drop every notebook-level reference to the model. `cache` is included because
# the ActivationCache returned by run_with_cache keeps a reference to the model
# it came from, so `del model` alone would leave the weights alive.
# globals().pop(...) keeps the cell safe to re-run when a name is already gone.
for _name in ("model", "logits", "cache"):
    globals().pop(_name, None)
gc.collect()

# Return cached allocator memory on whichever accelerator backend is present.
# Guarded so the cell also runs on machines without Apple MPS.
if torch.backends.mps.is_available():
    torch.mps.empty_cache()
elif torch.cuda.is_available():
    torch.cuda.empty_cache()
print("Memory cleared")

Memory cleared


# Pythia 2.8b

In [None]:
from transformer_lens import HookedTransformer
import transformer_lens.utils as utils

# The preceding cleanup cell deletes `model`, so it has to be re-created here;
# otherwise the cells below fail with NameError on a fresh Restart & Run All.
# The checkpoint name matches the "Loaded pretrained model pythia-2.8b" line in
# this cell's recorded output.
model = HookedTransformer.from_pretrained("pythia-2.8b")

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/5.68G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/396 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

Loaded pretrained model pythia-2.8b into HookedTransformer
Loaded: Pythia 2.8b
Layers: 32
Heads: 32
Hidden size: 2560
Params: 2774.9M


In [None]:
# Single-prompt smoke test for this model size.
prompt = "A screen reader is"
output = model.generate(prompt, temperature=0, max_new_tokens=10)
print(f"Input: {prompt}")
print(f"Output: {output}")

# Run the prompt once more with caching enabled, then unpack the pair into the
# logits and the activation cache and report the cache size.
cached_run = model.run_with_cache(prompt)
logits, cache = cached_run
print(f"\nCached {len(cache)} different activation points!")

  0%|          | 0/1 [00:00<?, ?it/s]

Input: The capital of France is
Output: The capital of France is Paris

Cached 579 different activation points!


In [None]:
# Probe the model with short web-accessibility prompts and print only the
# generated continuation next to each prompt.
test_prompts = [
    "A screen reader is",
    "WCAG stands for",
    "A skip link is",
    "The purpose of alt text is",
]

for prompt in test_prompts:
    # verbose=False suppresses the per-call tqdm progress bar, so the cell
    # output is one clean line per prompt.
    output = model.generate(prompt, max_new_tokens=10, temperature=0, verbose=False)
    # Slice off the leading prompt text so only the continuation is shown.
    print(f"{prompt:30} → {output[len(prompt):]}")

  0%|          | 0/3 [00:00<?, ?it/s]

The capital of France is       →  a city of


  0%|          | 0/3 [00:00<?, ?it/s]

Paris is the capital of        →  France and the


  0%|          | 0/3 [00:00<?, ?it/s]

France is a country in         →  Europe, located


  0%|          | 0/3 [00:00<?, ?it/s]

The Eiffel Tower is in         →  the news again


In [None]:
# Run this between models
import gc

import torch

# Drop every notebook-level reference to the model. `cache` is included because
# the ActivationCache returned by run_with_cache keeps a reference to the model
# it came from, so `del model` alone would leave the weights alive.
# globals().pop(...) keeps the cell safe to re-run when a name is already gone.
for _name in ("model", "logits", "cache"):
    globals().pop(_name, None)
gc.collect()

# Return cached allocator memory on whichever accelerator backend is present.
# Guarded so the cell also runs on machines without Apple MPS.
if torch.backends.mps.is_available():
    torch.mps.empty_cache()
elif torch.cuda.is_available():
    torch.cuda.empty_cache()
print("Memory cleared")

Memory cleared
