In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("..")

import random
import json
import torch
import gc
from dataclasses import dataclass
from collections import defaultdict
import plotly.graph_objects as go
import plotly.express as px
import plotly.io as pio
import einops
from typing import List, Tuple, Literal, Union
from jaxtyping import Float, Int, Bool
from torch import Tensor
from colorama import Fore
import textwrap
import gc
import copy
import re
import numpy as np
from functools import partial
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
import transformer_lens.utils as tl_utils
from utils.sae_utils import load_sae
import os
from tqdm import tqdm

from utils.activation_cache import _get_activations_fixed_seq_len
from utils.hf_models.model_factory import construct_model_base
from utils.utils import model_alias_to_model_name

torch.set_grad_enabled(False)

In [2]:
from utils.hf_models.gemma_model import format_instruction_gemma_chat
from utils.hf_models.llama3_model import format_instruction_llama3_chat

model_alias = 'meta-llama/Llama-3.1-8B'

# Pick the chat-template formatter matching the model family.
if 'gemma' in model_alias.lower():
    format_instructions_chat_fn = partial(format_instruction_gemma_chat, output=None, system=None, include_trailing_whitespace=True)
elif 'llama' in model_alias.lower():
    format_instructions_chat_fn = partial(format_instruction_llama3_chat, output=None, system=None, include_trailing_whitespace=True)
else:
    # No template for other families. The formatting cells only call this when
    # `model_is_chat_model[model_alias]` is true, so None is safe for base models.
    format_instructions_chat_fn = None

In [3]:
def clear():
    """Run Python garbage collection, then release PyTorch's cached CUDA memory."""
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

In [None]:

# Resolve the checkpoint name for the configured alias and load it with
# `from_pretrained_no_processing` (TransformerLens' loader without default weight processing).
model_path = model_alias_to_model_name[model_alias]
model = HookedTransformer.from_pretrained_no_processing(model_path)
tokenizer = model.tokenizer
# Left padding right-aligns the prompts, so every prompt's last token sits at the final position.
tokenizer.padding_side = 'left'
# Sanitize the alias ('org/name' -> 'org_name'); later cells use this form as a dict key
# (e.g. `latents_by_model` and the branches of `load_sae_`).
# NOTE(review): this rebinds `model_alias`. Re-running this cell without re-running the config
# cell looks up the sanitized alias in `model_alias_to_model_name`, which presumably raises KeyError.
model_alias = model_alias.replace('/', '_')


In [23]:
from mech_interp_utils import layer_sparisity_widths, model_alias_to_sae_repo_id

def load_sae_(model_alias, layer, width='16k', repo_id="google/gemma-scope-2b-pt-res"):
    """
    Load the SAE for a given model and layer through `load_sae(repo_id, sae_id)`.

    Args:
        model_alias: sanitized model alias (e.g. 'meta-llama_Llama-3.1-8B'). It selects the
            SAE naming scheme and, for most models, the SAE repository.
        layer: the layer index used by the caller. SAE ids are built from ``layer - 1``.
            NOTE(review): presumably this is because activations are cached at ``resid_pre`` of
            ``layer``, i.e. after block ``layer - 1``. Confirm this against the SAE release.
        width: SAE dictionary width used in Gemma-Scope style ids (e.g. '16k').
        repo_id: SAE repository. Only the Llama-3.1-8B branch uses it; every other branch
            replaces it with ``model_alias_to_sae_repo_id[model_alias]``.

    Returns:
        Whatever `load_sae` returns (an SAE object exposing ``encode``).
    """
    # Honour the caller's width instead of always using '16k' (the default is unchanged).
    sae_width = width
    if model_alias == 'gemma-2b-it':
        assert layer == 13, "Layer 13 is the only layer for gemma-2b-it"
        sae_id = "gemma-2b-it-res-jb"
        repo_id = model_alias_to_sae_repo_id[model_alias]
    elif model_alias == 'meta-llama_Llama-3.1-8B':
        # Presumably "l<layer>r_8x" = residual stream of that layer, 8x expansion; `repo_id` comes from the caller.
        sae_id = f"l{layer-1}r_8x"
    elif model_alias == 'gemma-2-9b-it':
        assert layer in [10, 21, 32], "Layer 10, 21, 32 are the only layers for gemma-2-9b-it"
        sae_sparsity = layer_sparisity_widths[model_alias][layer-1][sae_width]
        sae_id = f"layer_{layer-1}/width_{sae_width}/average_l0_{str(sae_sparsity)}"
        repo_id = model_alias_to_sae_repo_id[model_alias]
    else:
        sae_sparsity = layer_sparisity_widths[model_alias][layer-1][sae_width]
        sae_id = f"layer_{layer-1}/width_{sae_width}/average_l0_{str(sae_sparsity)}"
        repo_id = model_alias_to_sae_repo_id[model_alias]

    return load_sae(repo_id, sae_id)

def get_activations_hook_fn(activation, hook, layer, cache: Float[Tensor, "pos layer d_model"], n_samples, positions: List[int]):
    """Forward hook that adds `activation[:, positions, :]` into `cache` in place.

    `get_cache` passes a slice of its activation buffer as `cache`, so the in-place add writes
    into that buffer. `hook`, `layer` and `n_samples` exist only to fit the hook signature and
    the `partial` binding; they are unused. The activation is returned unchanged, so the
    forward pass is not altered.
    """
    selected = activation[:, positions, :]
    cache.add_(selected)
    return activation

def get_cache(
    model: HookedTransformer,
    tokenizer,
    prompts: List[str],
    act_name='resid_pre',
    cache_layers=None,
    positions=None,
    batch_size=32
) -> Tuple[Tensor, Tensor]:
    """
    Run `model` over `prompts` in batches and cache one hook's activations for each requested layer.

    Args:
        model: the HookedTransformer to run.
        tokenizer: tokenizer used to batch-encode `prompts` with padding. The notebook sets left
            padding, so the final positions line up across prompts.
        prompts: input strings.
        act_name: TransformerLens activation name, resolved per layer with
            `tl_utils.get_act_name` (e.g. 'resid_pre'). 'post' caches `d_mlp`-wide activations;
            anything else is assumed to be `d_model`-wide.
        cache_layers: layers to cache. Defaults to every layer.
        positions: token positions to keep. Defaults to every position of the padded batch.
        batch_size: number of prompts per forward pass.

    Returns:
        A tuple ``(toks, activations)``:
        - ``toks``: the padded input ids from the tokenizer, shape (n_prompts, seq_len).
        - ``activations``: a buffer of shape (n_prompts, len(cache_layers), len(positions), act_dim),
          allocated on ``model.cfg.device``.
    """
    if cache_layers is None:
        cache_layers = list(range(0, model.cfg.n_layers))

    act_dim = model.cfg.d_mlp if act_name == 'post' else model.cfg.d_model

    toks = tokenizer(prompts, padding=True, return_tensors="pt").input_ids

    if positions is None:
        positions = list(range(0, toks.size(-1)))

    n_prompts = len(prompts)
    # Allocate on the model's own device instead of a hard-coded 'cuda'.
    activations: Float[Tensor, "batch_size layer seq_len d_model"] = torch.zeros(
        (n_prompts, len(cache_layers), len(positions), act_dim),
        dtype=model.cfg.dtype,
        device=model.cfg.device,
    )

    for batch_start in tqdm(range(0, n_prompts, batch_size)):
        batch_end = min(n_prompts, batch_start + batch_size)

        fwd_hooks = [
            (
                tl_utils.get_act_name(act_name, layer),
                # The slice is a view, so the hook's in-place add writes directly into `activations`.
                partial(
                    get_activations_hook_fn,
                    layer=layer,
                    cache=activations[batch_start:batch_end, layer_idx, :, :],
                    n_samples=n_prompts,
                    positions=positions,
                ),
            )
            for layer_idx, layer in enumerate(cache_layers)
        ]

        # Only the hooked activations are needed: return_type=None skips computing the logits.
        model.run_with_hooks(toks[batch_start:batch_end], fwd_hooks=fwd_hooks, return_type=None)

    return toks, activations

@dataclass
class PerTokenLatentActivations:
    """Activations of one SAE latent on one prompt, one value per token.

    `token_strs` and `acts` are paired position by position when rendered.
    """
    prompt: str
    token_ids: List[int]
    token_strs: List[str]
    acts: List[float]

    def __str__(self):
        """Join the tokens, wrapping those with positive activation in red ANSI codes with their value."""
        pieces = []
        for token, act in zip(self.token_strs, self.acts):
            if act > 0:
                pieces.append(f"\033[91m{token} (+{act:.2f})\033[0m")
            else:
                pieces.append(token)
        return ''.join(pieces)


def get_latent_activations(model, model_alias, tokenizer, prompts, latents: List[Tuple[int, int]]) -> List[List[PerTokenLatentActivations]]:
    """
    Compute per-token SAE latent activations for every prompt.

    Args:
        model, tokenizer: forwarded to `get_cache`, which caches 'resid_pre' for each needed layer.
        model_alias: selects the SAE via `load_sae_` and `model_alias_to_sae_repo_id`.
        prompts: input strings.
        latents: ``(layer, latent_idx)`` pairs.

    Returns:
        One list per entry of `latents`, in the same order. Each list holds one
        `PerTokenLatentActivations` per prompt. Only non-padding tokens are kept, and the first
        non-padding token is dropped because it is treated as BOS.
    """
    layers_to_cache = sorted({layer for layer, _ in latents})

    toks, activations = get_cache(model, tokenizer, prompts, cache_layers=layers_to_cache)
    seq_len = toks.size(-1)

    # With left padding the real tokens are at the end of each row; record where they start.
    prompt_starts = []
    prompt_tok_ids = []
    prompt_tok_strs = []
    for batch_idx in range(len(prompts)):
        n_toks = int((toks[batch_idx] != tokenizer.pad_token_id).sum())
        n_toks = max(n_toks - 1, 0)  # exclude the bos token
        # Use an explicit start index: a `-n_toks:` slice would select the whole row when n_toks == 0.
        start = seq_len - n_toks
        tok_ids = toks[batch_idx, start:].tolist()
        prompt_starts.append(start)
        prompt_tok_ids.append(tok_ids)
        prompt_tok_strs.append([tokenizer.decode(tok) for tok in tok_ids])

    latents_activations = [[] for _ in latents]

    for layer_idx, layer in enumerate(layers_to_cache):
        sae = load_sae_(model_alias, layer, repo_id=model_alias_to_sae_repo_id[model_alias])
        latent_activations_at_layer = sae.encode(activations[:, layer_idx, :, :])

        for batch_idx, prompt in enumerate(prompts):
            start = prompt_starts[batch_idx]
            for latent_idx_in_list, (layer_, latent_idx) in enumerate(latents):
                if layer_ != layer:
                    continue
                acts = latent_activations_at_layer[batch_idx, start:, latent_idx].tolist()
                latents_activations[latent_idx_in_list].append(
                    PerTokenLatentActivations(
                        prompt=prompt,
                        token_ids=prompt_tok_ids[batch_idx],
                        token_strs=prompt_tok_strs[batch_idx],
                        acts=acts,
                    )
                )

        # Release this layer's SAE and its encodings before loading the next one.
        del sae, latent_activations_at_layer

    return latents_activations

In [17]:
def create_prompts(data, start_time: str, end_time: str, prompt_type="song_performers", max_num_items=None):
    """
    Build "The song '<name>'" prompts for tracks released inside a time window.

    Args:
        data: list of artist records. Each record has a 'tracks' list whose items carry 'name'
            and 'album_release_date', a '-'-separated date whose first field is the year and
            whose optional second field is the month.
        start_time, end_time: inclusive bounds as 'YYYY-MM' (or 'YYYY'). Either may be None or ''
            to leave that side unbounded. A year-only start means January of that year; a
            year-only end means December of that year.
        prompt_type: only "song_performers" is supported.
        max_num_items: if set, caps both the number of tracks sampled per artist and the final
            number of tracks, using `random.sample` in each case.

    Returns:
        ``(prompts, tracks)``: the prompt strings and the selected track records, aligned by index.

    Raises:
        ValueError: if `prompt_type` is not supported.
    """
    if prompt_type != "song_performers":
        raise ValueError(f"Prompt type {prompt_type} not supported")

    def parse_year_month(date_str: str, default_month: int):
        parts = date_str.split('-')
        month = int(parts[1]) if len(parts) > 1 else default_month
        return int(parts[0]), month

    start_bound = parse_year_month(start_time, 1) if start_time else None
    end_bound = parse_year_month(end_time, 12) if end_time else None

    def filter_by_time_range(element):
        if start_bound is None and end_bound is None:
            return True
        # Release dates without a month are treated as January.
        release = parse_year_month(element['album_release_date'], 1)
        # Lexicographic (year, month) comparison: inclusive on both ends.
        if start_bound is not None and release < start_bound:
            return False
        if end_bound is not None and release > end_bound:
            return False
        return True

    def preprocess_name(name: str):
        # Remove text inside brackets, e.g. "(Remastered)" or "[Live]".
        name = re.sub(r'[\(\[].*?[\)\]]', '', name)
        # NOTE(review): splitting on '-' also truncates hyphenated titles ("Anti-Hero" -> "Anti");
        # ' - ' may be the intended separator for suffixes like "Song - Remaster".
        return name.split("-")[0].strip()

    selected_tracks = []
    for element in data:
        artist_songs = [song for song in element['tracks'] if filter_by_time_range(song)]
        if max_num_items is not None and len(artist_songs) > max_num_items:
            artist_songs = random.sample(artist_songs, max_num_items)
        selected_tracks.extend(artist_songs)

    if max_num_items is not None and len(selected_tracks) > max_num_items:
        selected_tracks = random.sample(selected_tracks, max_num_items)

    prompts = [f"The song '{preprocess_name(track['name'])}'" for track in selected_tracks]

    return prompts, selected_tracks


In [18]:
# (layer, SAE latent index) pairs per model alias. For each model the first pair is the
# "known" latent and the second is the "unknown" latent.
latents_by_model = dict([
    ('gemma-2-2b', [(13, 7957), (15, 11898)]),               # [known, unknown]
    ('gemma-2-9b', [(22, 10424), (21, 4451)]),               # [known, unknown]
    ('meta-llama_Llama-3.1-8B', [(13, 21306), (15, 28509)]), # [known, unknown]
])

In [None]:
cutoff_songs_data_path = "./raw_songs.json"

# Song and artist names are often non-ASCII; read the file as UTF-8, not the platform default encoding.
with open(cutoff_songs_data_path, "r", encoding="utf-8") as f:
    cutoff_songs_data = json.load(f)

# Peek at the first artist record.
cutoff_songs_data[0]

In [None]:
# Seed the sampler so the known/unknown song subsets are reproducible across runs.
SONG_SAMPLING_SEED = 42
random.seed(SONG_SAMPLING_SEED)

known_prompts, known_songs = create_prompts(cutoff_songs_data, start_time="2010-01", end_time="2024-01", max_num_items=500)
unknown_prompts, unknown_songs = create_prompts(cutoff_songs_data, start_time="2024-08", end_time="2025-01", max_num_items=500)

# De-duplicate while keeping first-seen order. `list(set(...))` order depends on the per-process
# string hash seed, which would make the prompt order irreproducible.
# NOTE(review): after de-duplication the prompt lists are no longer index-aligned with *_songs.
known_prompts = list(dict.fromkeys(known_prompts))
unknown_prompts = list(dict.fromkeys(unknown_prompts))

print(f"len(known_prompts): {len(known_prompts)}")
print(f"len(unknown_prompts): {len(unknown_prompts)}")

In [None]:
from utils.utils import model_is_chat_model

latents = latents_by_model[model_alias]

known_prompts_latents_activations = defaultdict(list)
unknown_prompts_latents_activations = defaultdict(list)

# Chat models see both prompt sets through the chat template; base models see the raw prompts.
if model_is_chat_model[model_alias]:
    known_model_inputs = [format_instructions_chat_fn(instruction=p) for p in known_prompts]
    unknown_model_inputs = [format_instructions_chat_fn(instruction=p) for p in unknown_prompts]
else:
    known_model_inputs = known_prompts
    unknown_model_inputs = unknown_prompts
prompts = unknown_model_inputs

for model_inputs, store in (
    (known_model_inputs, known_prompts_latents_activations),
    (unknown_model_inputs, unknown_prompts_latents_activations),
):
    latents_activations = get_latent_activations(model, model_alias, tokenizer, model_inputs, latents)
    # `latents_activations` is ordered like `latents`.
    for latent, latent_acts in zip(latents, latents_activations):
        store[latent].extend(latent_acts)


In [None]:
# Number of unknown-song prompts (de-duplicated when they were created).
len(unknown_prompts)

In [40]:
def show_activation_frequency(latent_activations, last_n_positions: int = 2):
    """Report how often each latent fires on the final token positions of its prompts.

    Args:
        latent_activations: mapping from a ``(layer, latent_idx)`` key to a list of per-prompt
            records exposing ``token_strs`` and ``acts``.
        last_n_positions: a prompt counts as a hit when the latent's activation is positive at
            any of its last `last_n_positions` tokens. Shorter prompts are checked over all
            their tokens.

    Prints the number of prompts and the hit count/rate for each latent.

    Returns:
        dict mapping each latent key to its hit count.
    """
    frequency = {}

    for latent, prompts_acts in latent_activations.items():
        count = 0
        for prompt_acts in prompts_acts:
            n_toks = len(prompt_acts.token_strs)
            window = prompt_acts.acts[max(0, n_toks - last_n_positions):n_toks]
            if any(act > 0 for act in window):
                count += 1
        frequency[latent] = count

        total = len(prompts_acts)
        print('total', total)
        # Avoid ZeroDivisionError for a latent with no prompts.
        frequency_percentage = count / total if total else 0.0
        print(f"Latent L{latent[0]} F{latent[1]}: {count} ({frequency_percentage:.2%})")

    return frequency

In [None]:
# How often each latent fires on the final token positions of the unknown-song prompts.
show_activation_frequency(unknown_prompts_latents_activations)

In [None]:
# Inspect the unknown-song prompts.
unknown_prompts

In [8]:
# question_prompts = [
#     ('Who is the main actor in The Lord of the Rings?', 'Who is the main actor in The Man of the Rings?'),
#     ('Who directed the movie 12 Angry Men?', 'Who directed the movie 200 Angry Men?'),
#     ('What was the release year of the movie The Lord of the Rings?', 'What was the release year of the movie The Man of the Rings?'),
#     ('Who is the director of the movie Pulp Fiction?', 'Who is the director of the movie Pulping Fiction?'),
#     ('When was the player Cristiano Ronaldo born?', 'When was the player Cristiano Penalda born?'),
#     ('Where was the player Leo Messi born?', 'How many goals did the player Leo Messi score in his career?'),
#     ("What team (name at least one) signed the player 'Jake Bornheimer'?", "What team (name at least one) signed the player 'Jeff Van Gundy'?"),
#     ("What is the name of an actor starring in the movie 'The Ten Gladiators'?", "What is the name of an actor starring in the movie 'The Ten Gladiators'?"),
#     ("What genre label best describes the movie 'Stranger by the Lake'?", "What genre label best describes the movie 'Stranger by the Lake'?")
# ]

# question_prompts = [
#     ('Who is Michael Jordan?', 'When was Michael Joordan born?'),
#     ('When was Michael Jordan born?', 'How many points did Michael Jordan score in his career?'),
#     ('In which year did Michael Jordan retire?', 'What team did Michael Jordan play for in his career?'),
#     ('Who directed the movie the Godfather?', 'Who directed the movie the Godmother?'),
#     ('What was the release year of the movie The Lord of the Rings?', 'What was the release year of the movie The Man of the Rings?'),
#     ('Who is the director of the movie Pulp Fiction?', 'Who is the director of the movie Pulping Fiction?'),
#     ('When was the player Cristiano Ronaldo born?', 'When was the player Cristiano Penalda born?'),
# ]
question_prompts = [
    ('When was the player LeBron James born?', 'When was the player Wilson Brown born?'),
    ('How many iPhones were sold in 2008?', 'What are the sizes of an iPhone 13?'),
    ('What was the release year of the movie The Lord of the Rings?', 'What was the release year of the movie The Man of the Rings?'),
    ('Who is the director of the movie Pulp Fiction?', 'Who is the director of the movie Pulping Fiction?'),
    ('When was the player Cristiano Ronaldo born?', 'When was the player Cristiano Penalda born?'),
]

# Interleave the pairs: each first question is immediately followed by its counterpart.
prompts = []
for first, second in question_prompts:
    prompts += [first, second]

In [7]:
# (real title / sentence, altered counterpart) pairs.
movies_contrastive_pairs = [
    ('The Lord of the Rings', 'The Man of the Rings'),
    ('12 Angry Men', '20 Angry Men'),
    ('Do you know the movie The Lord of the Rings?', 'Do you know the movie The Man of the Rings?'),
    ('Have you seen the movie Pulp Fiction?', 'Have you seen the movie Pulping Fiction?'),
    ("Let's go watch the movie 12 Angry Men", "Let's go watch the movie 20 Angry Men"),
    ('I watched the movie Good Will Hunting the other day.', 'I watched the movie Good William Hunting the other day.'),
    ('The movie The Godfather is such a great movie.', 'The movie The Godmother is such a great movie.'),
]

prompts = []
for pair in movies_contrastive_pairs:
    prompts.extend(pair)

In [13]:
# (real city / sentence, misspelled counterpart) pairs.
city_contrastive_pairs = [
    ('Paris', 'Parris'),
    ('London', 'Londen'),
    ('Have you been to Paris?', 'Have you been to Parris?'),
    ('I am visiting Barcelona next week.', 'I am visiting Sarbelona next week.'),
    ('I was born in New York City.', 'I was born in Old Dork City.'),
    ('Do you know the city of San Francisco?', 'Do you know the city of Can Sancisco?'),
]

prompts = [city_prompt for pair in city_contrastive_pairs for city_prompt in pair]


In [35]:
# Bare title pairs. Defined for reference only; not included in `prompts` below.
movie_alterations = [
    ("The Godfather", "The Dogfather"),
    ("Jurassic Park", "Jurassic Parking"),
    ("The Silence of the Lambs", "The Silence of the Hams"),
    ("Forrest Gump", "Forrest Plump"),
    ("The Matrix", "The Mattress"),
    ("Schindler's List", "Schindler's Lisp"),
    ("The Shawshank Redemption", "The Shawshank Seduction"),
    ("Titanic", "Gigantic"),
    ("The Dark Knight", "The Dork Knight"),
    ("Eternal Sunshine of the Spotless Mind", "Eternal Sunshine of the Spotless Behind")
]

movie_phrases = [
    ("The Godfather", "The Dogfather"),
    ("The movie Jurassic Park was released in 1993.", "The movie Jurassic Parking was released in 1993."),
    ("Do you know the movie The Silence of the Lambs?", "Do you know the movie The Silence of the Hams?"),
    ("Have you seen the movie Forrest Gump?", "Have you seen the movie Forrest Plump?"),
    ("Let's go watch the movie The Matrix", "Let's go watch the movie The Mattress"),
    ("I watched the movie Schindler's List last night.", "I watched the movie Schindler's Lisp last night."),
    ("The movie The Shawshank Redemption is a classic.", "The movie The Shawshank Seduction is a classic."),
    ("Titanic", "Gigantic"),
    ("My favorite superhero film is The Dark Knight.", "My favorite superhero film is The Dork Knight."),
    ("Eternal Sunshine of the Spotless Mind is a unique love story.", "Eternal Sunshine of the Spotless Behind is a unique love story.")
]

city_alterations = [
    ("New York", "New Pork"),
    ("Los Angeles", "Lost Angeles"),
    ("I love visiting Paris", "I love visiting Pairs"),
    ("Have you ever been to London?", "Have you ever been to Londonut?"),
    ("Tokyo is a bustling metropolis", "Tokyolk is a bustling metropolis"),
    ("Rome", "Roam"),
    ("Let's take a trip to Sydney", "Let's take a trip to Kidney"),
    ("Berlin is known for its history", "Berlout is known for its history"),
    ("Moscow", "Cowscow"),
    ("The weather in Chicago is unpredictable", "The weather in Chicagoing is unpredictable")
]

athlete_alterations = [
    ("LeBron James", "LeBroom Games"),
    ("Serena Williams", "Serenade Williams"),
    ("Did you see Cristiano Ronaldo's goal?", "Did you see Cristina Penalda's goal?"),
    ("Usain Bolt is the fastest man alive", "Usain Jolt is the fastest man alive"),
    ("Michael Phelps", "Michael Yelps"),
    ("I admire the skills of Simone Biles", "I admire the skills of Lemon Biles"),
    ("Roger Federer", "Roger Cheddar"),
    ("Megan Rapinoe led the team to victory", "Vegan Rapinoe led the team to victory"),
    ("Tom Brady", "Tom Gravy"),
    ("The legacy of Muhammad Ali is inspiring", "The legacy of Muhammad Alley is inspiring")
]

# Flatten movie, city and athlete pairs, in that order; each original is followed by its alteration.
prompts = [
    phrase
    for pair in movie_phrases + city_alterations + athlete_alterations
    for phrase in pair
]

In [None]:
import json

# Load the entity sentences and entity names (records with 'sentence' and 'entity_name' keys).
# NOTE(review): the file read is `unknown_entities_sentences.json` although the variables are named
# `known_*`; the names are kept because later cells read `known_entity_sentences`.
entities_path = './random_entities/unknown_entities_sentences.json'
with open(entities_path, 'r', encoding='utf-8') as f:
    known_entities_data = json.load(f)

# Extract sentences and entity names
known_entity_sentences = [item['sentence'] for item in known_entities_data]
known_entity_names = [item['entity_name'] for item in known_entities_data]

print(f"Loaded {len(known_entity_sentences)} entity sentences from {entities_path}.")
print("First 5 sentences:")
for sentence in known_entity_sentences[:5]:
    print(f"- {sentence}")


In [None]:
# Pick a random subset of the loaded sentences. The sample size is capped at the number
# available, because random.sample raises ValueError if asked for more items than exist.
N_SAMPLE_PROMPTS = 20
random_ids = random.sample(range(len(known_entity_sentences)), min(N_SAMPLE_PROMPTS, len(known_entity_sentences)))
prompts = [known_entity_sentences[i] for i in random_ids]
prompts

In [None]:
from utils.utils import model_is_chat_model

latents = [
    #(13, 3130)
    # (13, 3130)
    #(13, 7957), # known
    #(15, 11898), # unknown
    # (14, 7769), # unknown
    # (13, 6), # unknown
    # (7, 3782), # unknown
    # (11, 8468), # unknown
    (14, 25742)
]

# Format into a separate list instead of overwriting `prompts`, so re-running this cell
# never applies the chat template twice.
model_inputs = (
    [format_instructions_chat_fn(instruction=p) for p in prompts]
    if model_is_chat_model[model_alias]
    else prompts
)
latents_activations = get_latent_activations(model, model_alias, tokenizer, model_inputs, latents)

In [None]:
# Print each prompt's per-token activations for the first latent in `latents`;
# PerTokenLatentActivations.__str__ highlights tokens with positive activation.
latent = latents[0]
idx = latents.index(latent)

print(f"Latent L{latent[0]} F{latent[1]}:")
for prompt_acts in latents_activations[idx]:
    print(prompt_acts)

In [None]:
# Print each prompt's per-token activations for the first latent in `latents`.
# NOTE(review): this cell's code is identical to the neighbouring inspection cells.
latent = latents[0]
idx = latents.index(latent)

print(f"Latent L{latent[0]} F{latent[1]}:")
for prompt_acts in latents_activations[idx]:
    print(prompt_acts)

In [None]:
# Print each prompt's per-token activations for the first latent in `latents`.
# NOTE(review): this cell's code is identical to the neighbouring inspection cells.
latent = latents[0]
idx = latents.index(latent)

print(f"Latent L{latent[0]} F{latent[1]}:")
for prompt_acts in latents_activations[idx]:
    print(prompt_acts)

In [None]:
# Display the current `prompts` list.
prompts

In [None]:
# (original, altered) sentence pairs.
final_contrastive_prompts = [
    ("Michael Jordan scored 30 points last night.", "Michael Joordan scored 30 points last night."),
    ("The movie Jurassic Park was released in 1993.", "The movie Jurassic Parking was released in 1993."),
    ("I've watched 12 Angry Men many times already.", "I've watched 20 Angry Men many times already."),
    ("The city of Berlin is known for its history.", "The city of Berlouin is known for its history."),
    ("Michael Phelps just won his 23rd Olympic gold medal", "Michael Yelps just won his 23rd Olympic gold medal"),
    ("The legacy of Muhammad Ali is inspiring", "The legacy of Muhammad Alley is inspiring"),
]

# Each original sentence is immediately followed by its altered version.
prompts = [sentence for pair in final_contrastive_prompts for sentence in pair]

# Apply the chat template for chat models, as the other activation cells do.
model_inputs = (
    [format_instructions_chat_fn(instruction=p) for p in prompts]
    if model_is_chat_model[model_alias]
    else prompts
)
latents_activations = get_latent_activations(model, model_alias, tokenizer, model_inputs, latents)

In [None]:
# (15, 11898) is the Gemma-2-2B "unknown" latent in `latents_by_model`. It must be one of the
# `latents` the activation cell above was run with, otherwise there is nothing to show.
latent = (15, 11898)
if latent not in latents:
    raise ValueError(
        f"Latent {latent} is not in `latents` ({latents}); add it there and re-run the "
        "activation cell before inspecting it."
    )
idx = latents.index(latent)

print(f"Latent L{latent[0]} F{latent[1]}:")
for prompt_acts in latents_activations[idx]:
    print(prompt_acts)

In [None]:
# One bracketed line per prompt listing (token, activation) pairs for the selected latent.
for prompt_acts in latents_activations[idx]:
    cells = "".join(
        f"({repr(token_str)}, {token_act:.2f})" + " "
        for token_str, token_act in zip(prompt_acts.token_strs, prompt_acts.acts)
    )
    print("[" + cells + "]")

