In [2]:
from transformer_lens import HookedTransformer
import transformer_lens.utils as utils

# Pythia 160m

In [12]:
model = HookedTransformer.from_pretrained("EleutherAI/pythia-160m")
print(f"Layers: {model.cfg.n_layers}")
print(f"Heads: {model.cfg.n_heads}")
print(f"Hidden size: {model.cfg.d_model}")
print(f"Params: {sum(p.numel() for p in model.parameters())/1e6:.1f}M")

Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer
Layers: 12
Heads: 12
Hidden size: 768
Params: 162.3M


In [6]:
prompt = "A screen reader is"
output = model.generate(prompt, max_new_tokens=10, temperature=0)
print(f"Input: {prompt}")
print(f"Output: {output}")

# Cache internal states
logits, cache = model.run_with_cache(prompt)
print(f"\nCached {len(cache)} different activation points!")

  0%|          | 0/10 [00:00<?, ?it/s]

Input: A screen reader is
Output: A screen reader is a software application that reads aloud the text on a

Cached 579 different activation points!


In [7]:
test_prompts = [
    "A screen reader is",
    "WCAG stands for", 
    "A skip link is",
    "The purpose of alt text is",
    "ARIA stands for",
    "A focus indicator is",
    "Keyboard navigation allows",
    "Color contrast is important because",
    "Semantic HTML helps",
    "Captions are used for",
]

for prompt in test_prompts:
    output = model.generate(prompt, max_new_tokens=10, temperature=0)
    print(f"{prompt:30} → {output[len(prompt):]}")

  0%|          | 0/10 [00:00<?, ?it/s]

A screen reader is             →  a software application that reads aloud the text on a


  0%|          | 0/10 [00:00<?, ?it/s]

WCAG stands for                →  the World Council of Churches. The WCC


  0%|          | 0/10 [00:00<?, ?it/s]

A skip link is                 →  a link that is used to skip a section of


  0%|          | 0/10 [00:00<?, ?it/s]

The purpose of alt text is     →  to provide a brief description of the image.



  0%|          | 0/10 [00:00<?, ?it/s]

ARIA stands for                →  “A Rational Approach to Information and Automation


  0%|          | 0/10 [00:00<?, ?it/s]

A focus indicator is           →  a device that is used to indicate the focus of


  0%|          | 0/10 [00:00<?, ?it/s]

Keyboard navigation allows     →  you to navigate through the pages of this website.


  0%|          | 0/10 [00:00<?, ?it/s]

Color contrast is important because →  it is a key factor in determining the quality of


  0%|          | 0/10 [00:00<?, ?it/s]

Semantic HTML helps            →  you create semantic HTML documents.

Semantic


  0%|          | 0/10 [00:00<?, ?it/s]

Captions are used for          →  the purpose of identifying the subject of a photograph.


In [8]:
code_prompts = [
    "The following code is not accessible because it doesn't have what? <img src='photo.jpg'>",
    "A <div> with onclick is not accessible because",
    "The accessibility problem with <a href='#'></a> is",
    "<input type='text'> needs a",
    "A button that only says 'Click here' is bad because",
]
for prompt in code_prompts:
    output = model.generate(prompt, max_new_tokens=20, temperature=0)
    print(f"{prompt:30} → {output[len(prompt):]}")

  0%|          | 0/20 [00:00<?, ?it/s]

The following code is not accessible because it doesn't have what? <img src='photo.jpg'> → 

The following code is not accessible because it doesn't have what? <img src='photo


  0%|          | 0/20 [00:00<?, ?it/s]

A <div> with onclick is not accessible because →  it is not in the same document as the <a> that is clicked.

A <


  0%|          | 0/20 [00:00<?, ?it/s]

The accessibility problem with <a href='#'></a> is →  that it is not a valid HTML tag.

The accessibility problem with <a href='#'


  0%|          | 0/20 [00:00<?, ?it/s]

<input type='text'> needs a    →  value

<input type='text' value=''>

<input type='text'


  0%|          | 0/20 [00:00<?, ?it/s]

A button that only says 'Click here' is bad because →  it's not clear what the button does.

A button that says 'Click here' is


In [9]:
import torch

correct = "A screen reader is software that reads text aloud for blind users."
wrong = "A screen reader is a device for viewing screens."

def get_perplexity(model, text):
    tokens = model.to_tokens(text)
    logits = model(tokens)
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    
    # Get log prob of each actual next token
    token_log_probs = log_probs[0, :-1, :].gather(1, tokens[0, 1:].unsqueeze(1)).squeeze()
    
    # Perplexity = exp(-mean(log_probs))
    return torch.exp(-token_log_probs.mean()).item()

print(f"Correct: {get_perplexity(model, correct)}")
print(f"Wrong: {get_perplexity(model, wrong)}")

Correct: 13.648499488830566
Wrong: 54.90721893310547


In [13]:
# Run this between models
import torch
import gc
del model
gc.collect()
torch.mps.empty_cache()
print("Memory cleared")

Memory cleared


# Pythia 410m

In [14]:
model = HookedTransformer.from_pretrained("pythia-410m")
print(f"Layers: {model.cfg.n_layers}")
print(f"Heads: {model.cfg.n_heads}")
print(f"Hidden size: {model.cfg.d_model}")
print(f"Params: {sum(p.numel() for p in model.parameters())/1e6:.1f}M")

Loaded pretrained model pythia-410m into HookedTransformer
Layers: 24
Heads: 16
Hidden size: 1024
Params: 405.3M


In [15]:
prompt = "A screen reader is"
output = model.generate(prompt, max_new_tokens=10, temperature=0)
print(f"Input: {prompt}")
print(f"Output: {output}")

# Cache internal states
logits, cache = model.run_with_cache(prompt)
print(f"\nCached {len(cache)} different activation points!")

  0%|          | 0/10 [00:00<?, ?it/s]

Input: A screen reader is
Output: A screen reader is a device that allows a user to view a screen

Cached 435 different activation points!


In [16]:
test_prompts = [
    "A screen reader is",
    "WCAG stands for", 
    "A skip link is",
    "The purpose of alt text is",
    "ARIA stands for",
    "A focus indicator is",
    "Keyboard navigation allows",
    "Color contrast is important because",
    "Semantic HTML helps",
    "Captions are used for",
]

for prompt in test_prompts:
    output = model.generate(prompt, max_new_tokens=10, temperature=0)
    print(f"{prompt:30} → {output[len(prompt):]}")

  0%|          | 0/10 [00:00<?, ?it/s]

A screen reader is             →  a device that allows a user to view a screen


  0%|          | 0/10 [00:00<?, ?it/s]

WCAG stands for                →  the World Confederation of Agricultural and Food Industries.


  0%|          | 0/10 [00:00<?, ?it/s]

A skip link is                 →  a link that is not part of the main page


  0%|          | 0/10 [00:00<?, ?it/s]

The purpose of alt text is     →  to provide a way for users to customize the text


  0%|          | 0/10 [00:00<?, ?it/s]

ARIA stands for                →  the acronym for the acronym for


  0%|          | 0/10 [00:00<?, ?it/s]

A focus indicator is           →  a device that is used to indicate the presence of


  0%|          | 0/10 [00:00<?, ?it/s]

Keyboard navigation allows     →  you to navigate through the various sections of the site


  0%|          | 0/10 [00:00<?, ?it/s]

Color contrast is important because →  it can be used to distinguish between different colors.


  0%|          | 0/10 [00:00<?, ?it/s]

Semantic HTML helps            →  you create a rich, interactive website.




  0%|          | 0/10 [00:00<?, ?it/s]

Captions are used for          →  the purpose of providing a clear and concise description of


In [17]:
code_prompts = [
    "The following code is not accessible because it doesn't have what? <img src='photo.jpg'>",
    "A <div> with onclick is not accessible because",
    "The accessibility problem with <a href='#'></a> is",
    "<input type='text'> needs a",
    "A button that only says 'Click here' is bad because",
]
for prompt in code_prompts:
    output = model.generate(prompt, max_new_tokens=20, temperature=0)
    print(f"{prompt:30} → {output[len(prompt):]}")

  0%|          | 0/20 [00:00<?, ?it/s]

The following code is not accessible because it doesn't have what? <img src='photo.jpg'> → 

The following code is not accessible because it doesn't have what? <img src='photo


  0%|          | 0/20 [00:00<?, ?it/s]

A <div> with onclick is not accessible because →  it is not a valid HTML element.

A <div> with onclick is not accessible because


  0%|          | 0/20 [00:00<?, ?it/s]

The accessibility problem with <a href='#'></a> is →  that it is not clear what the user is looking for.

The <a href='#'


  0%|          | 0/20 [00:00<?, ?it/s]

<input type='text'> needs a    →  <input type='text'>

<input type='submit'>

</form>



  0%|          | 0/20 [00:00<?, ?it/s]

A button that only says 'Click here' is bad because →  it's not a link.

A button that only says 'Click here' is bad because


In [18]:
import torch

correct = "A screen reader is software that reads text aloud for blind users."
wrong = "A screen reader is a device for viewing screens."

def get_perplexity(model, text):
    tokens = model.to_tokens(text)
    logits = model(tokens)
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    
    # Get log prob of each actual next token
    token_log_probs = log_probs[0, :-1, :].gather(1, tokens[0, 1:].unsqueeze(1)).squeeze()
    
    # Perplexity = exp(-mean(log_probs))
    return torch.exp(-token_log_probs.mean()).item()

print(f"Correct: {get_perplexity(model, correct)}")
print(f"Wrong: {get_perplexity(model, wrong)}")

Correct: 40.05073165893555
Wrong: 32.803348541259766


In [19]:
# Run this between models
import torch
import gc
del model
gc.collect()
torch.mps.empty_cache()
print("Memory cleared")

Memory cleared


# Pythia 1b

In [20]:
model = HookedTransformer.from_pretrained("pythia-1b")
print(f"Layers: {model.cfg.n_layers}")
print(f"Heads: {model.cfg.n_heads}")
print(f"Hidden size: {model.cfg.d_model}")
print(f"Params: {sum(p.numel() for p in model.parameters())/1e6:.1f}M")

Loaded pretrained model pythia-1b into HookedTransformer
Layers: 16
Heads: 8
Hidden size: 2048
Params: 1011.7M


In [21]:
prompt = "A screen reader is"
output = model.generate(prompt, max_new_tokens=10, temperature=0)
print(f"Input: {prompt}")
print(f"Output: {output}")

# Cache internal states
logits, cache = model.run_with_cache(prompt)
print(f"\nCached {len(cache)} different activation points!")

  0%|          | 0/10 [00:00<?, ?it/s]

Input: A screen reader is
Output: A screen reader is a program that reads text from a computer screen and

Cached 291 different activation points!


In [22]:
test_prompts = [
    "A screen reader is",
    "WCAG stands for", 
    "A skip link is",
    "The purpose of alt text is",
    "ARIA stands for",
    "A focus indicator is",
    "Keyboard navigation allows",
    "Color contrast is important because",
    "Semantic HTML helps",
    "Captions are used for",
]

for prompt in test_prompts:
    output = model.generate(prompt, max_new_tokens=10, temperature=0)
    print(f"{prompt:30} → {output[len(prompt):]}")

  0%|          | 0/10 [00:00<?, ?it/s]

A screen reader is             →  a program that reads text from a computer screen and


  0%|          | 0/10 [00:00<?, ?it/s]

WCAG stands for                →  what?

The World Council for Children’


  0%|          | 0/10 [00:00<?, ?it/s]

A skip link is                 →  a type of data link used in a computer network


  0%|          | 0/10 [00:00<?, ?it/s]

The purpose of alt text is     →  to provide a way to add text to an image


  0%|          | 0/10 [00:00<?, ?it/s]

ARIA stands for                →  “Artificial Replacement of a Human Being”.


  0%|          | 0/10 [00:00<?, ?it/s]

A focus indicator is           →  a device that is used to indicate the focus of


  0%|          | 0/10 [00:00<?, ?it/s]

Keyboard navigation allows     →  you to navigate through the pages of this website.


  0%|          | 0/10 [00:00<?, ?it/s]

Color contrast is important because →  it is the basis for many of the visual effects


  0%|          | 0/10 [00:00<?, ?it/s]

Semantic HTML helps            →  you to create a website that is easy to read


  0%|          | 0/10 [00:00<?, ?it/s]

Captions are used for          →  a variety of purposes, including providing information to a


In [23]:
code_prompts = [
    "The following code is not accessible because it doesn't have what? <img src='photo.jpg'>",
    "A <div> with onclick is not accessible because",
    "The accessibility problem with <a href='#'></a> is",
    "<input type='text'> needs a",
    "A button that only says 'Click here' is bad because",
]
for prompt in code_prompts:
    output = model.generate(prompt, max_new_tokens=20, temperature=0)
    print(f"{prompt:30} → {output[len(prompt):]}")

  0%|          | 0/20 [00:00<?, ?it/s]

The following code is not accessible because it doesn't have what? <img src='photo.jpg'> → 

The following code is not accessible because it doesn't have what? <img src='photo


  0%|          | 0/20 [00:00<?, ?it/s]

A <div> with onclick is not accessible because →  it is not a <span> element.

A <div> with onclick is not accessible


  0%|          | 0/20 [00:00<?, ?it/s]

The accessibility problem with <a href='#'></a> is →  that it is not a valid HTML tag.

The <a href='#'></a>


  0%|          | 0/20 [00:00<?, ?it/s]

<input type='text'> needs a    →  value
<input type='text'> needs a value
<input type='text'> needs a


  0%|          | 0/20 [00:00<?, ?it/s]

A button that only says 'Click here' is bad because →  it's a clickable link.

A button that only says 'Click here' is bad


In [24]:
import torch

correct = "A screen reader is software that reads text aloud for blind users."
wrong = "A screen reader is a device for viewing screens."

def get_perplexity(model, text):
    tokens = model.to_tokens(text)
    logits = model(tokens)
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    
    # Get log prob of each actual next token
    token_log_probs = log_probs[0, :-1, :].gather(1, tokens[0, 1:].unsqueeze(1)).squeeze()
    
    # Perplexity = exp(-mean(log_probs))
    return torch.exp(-token_log_probs.mean()).item()

print(f"Correct: {get_perplexity(model, correct)}")
print(f"Wrong: {get_perplexity(model, wrong)}")

Correct: 18.804874420166016
Wrong: 42.179691314697266


In [25]:
# Run this between models
import torch
import gc
del model
gc.collect()
torch.mps.empty_cache()
print("Memory cleared")

Memory cleared


# Pythia 2.8b

In [27]:
model = HookedTransformer.from_pretrained("pythia-2.8b")
print(f"Layers: {model.cfg.n_layers}")
print(f"Heads: {model.cfg.n_heads}")
print(f"Hidden size: {model.cfg.d_model}")
print(f"Params: {sum(p.numel() for p in model.parameters())/1e6:.1f}M")

Loaded pretrained model pythia-2.8b into HookedTransformer
Layers: 32
Heads: 32
Hidden size: 2560
Params: 2774.9M


In [28]:
prompt = "A screen reader is"
output = model.generate(prompt, max_new_tokens=10, temperature=0)
print(f"Input: {prompt}")
print(f"Output: {output}")

# Cache internal states
logits, cache = model.run_with_cache(prompt)
print(f"\nCached {len(cache)} different activation points!")

  0%|          | 0/10 [00:00<?, ?it/s]

Input: A screen reader is
Output: A screen reader is a software application that reads aloud the text on a

Cached 579 different activation points!


In [29]:
test_prompts = [
    "A screen reader is",
    "WCAG stands for", 
    "A skip link is",
    "The purpose of alt text is",
    "ARIA stands for",
    "A focus indicator is",
    "Keyboard navigation allows",
    "Color contrast is important because",
    "Semantic HTML helps",
    "Captions are used for",
]

for prompt in test_prompts:
    output = model.generate(prompt, max_new_tokens=10, temperature=0)
    print(f"{prompt:30} → {output[len(prompt):]}")

  0%|          | 0/10 [00:00<?, ?it/s]

A screen reader is             →  a software application that reads aloud the text on a


  0%|          | 0/10 [00:00<?, ?it/s]

WCAG stands for                →  the World Council of Churches. The WCC


  0%|          | 0/10 [00:00<?, ?it/s]

A skip link is                 →  a link that is used to skip a section of


  0%|          | 0/10 [00:00<?, ?it/s]

The purpose of alt text is     →  to provide a brief description of the image.



  0%|          | 0/10 [00:00<?, ?it/s]

ARIA stands for                →  “A Rational Approach to Information and Automation


  0%|          | 0/10 [00:00<?, ?it/s]

A focus indicator is           →  a device that is used to indicate the focus of


  0%|          | 0/10 [00:00<?, ?it/s]

Keyboard navigation allows     →  you to navigate through the pages of this website.


  0%|          | 0/10 [00:00<?, ?it/s]

Color contrast is important because →  it is a key factor in determining the quality of


  0%|          | 0/10 [00:00<?, ?it/s]

Semantic HTML helps            →  you create semantic HTML documents.

Semantic


  0%|          | 0/10 [00:00<?, ?it/s]

Captions are used for          →  the purpose of identifying the subject of a photograph.


In [30]:
code_prompts = [
    "A <div> with onclick is not accessible because",
    "The accessibility problem with <a href='#'></a> is",
    "<input type='text'> needs a",
    "A button that only says 'Click here' is bad because",
]
for prompt in code_prompts:
    output = model.generate(prompt, max_new_tokens=20, temperature=0)
    print(f"{prompt:30} → {output[len(prompt):]}")

  0%|          | 0/20 [00:00<?, ?it/s]

A <div> with onclick is not accessible because →  it is not in the same document as the <a> that is clicked.

A <


  0%|          | 0/20 [00:00<?, ?it/s]

The accessibility problem with <a href='#'></a> is →  that it is not a valid HTML tag.

The accessibility problem with <a href='#'


  0%|          | 0/20 [00:00<?, ?it/s]

<input type='text'> needs a    →  value

<input type='text' value=''>

<input type='text'


  0%|          | 0/20 [00:00<?, ?it/s]

A button that only says 'Click here' is bad because →  it's not clear what the button does.

A button that says 'Click here' is


In [31]:
import circuitsvis as cv

prompt = "A screen reader is"
tokens = model.to_str_tokens(prompt)
logits, cache = model.run_with_cache(prompt)

# Visualize attention for a specific layer (start with the last layer)
layer = model.cfg.n_layers - 1  # last layer
attention = cache["pattern", layer]  # shape: [heads, seq, seq]

cv.attention.attention_patterns(tokens=tokens, attention=attention[0])

In [32]:
import torch

correct = "A screen reader is software that reads text aloud for blind users."
wrong = "A screen reader is a device for viewing screens."

def get_perplexity(model, text):
    tokens = model.to_tokens(text)
    logits = model(tokens)
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    
    # Get log prob of each actual next token
    token_log_probs = log_probs[0, :-1, :].gather(1, tokens[0, 1:].unsqueeze(1)).squeeze()
    
    # Perplexity = exp(-mean(log_probs))
    return torch.exp(-token_log_probs.mean()).item()

print(f"Correct: {get_perplexity(model, correct)}")
print(f"Wrong: {get_perplexity(model, wrong)}")

Correct: 13.648499488830566
Wrong: 54.90721893310547


In [33]:
# Run this between models
import torch
import gc
del model
gc.collect()
torch.mps.empty_cache()
print("Memory cleared")

Memory cleared
