In [1]:
import torch as t
from torch.nn import functional as F
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import LlamaForCausalLM 
import einops
import matplotlib.pyplot as plt
import numpy as np
from typing import Union, Optional, Tuple, Any
from torch import Tensor
from dataclasses import dataclass, field
from tqdm.notebook import tqdm
from jaxtyping import Int, Float
from typing import List, Dict
from collections import defaultdict
from torch.utils.data import DataLoader, Dataset
import datetime
llama_token = "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"
device = t.device("cuda:0" if t.cuda.is_available() else "cpu")

In [2]:
t.cuda.empty_cache()
n_param = 7
model = AutoModelForCausalLM.from_pretrained(
        f"meta-llama/Llama-2-{n_param}b-chat-hf", use_auth_token=llama_token
    ).to(device)
tokenizer = AutoTokenizer.from_pretrained(
            "meta-llama/Llama-2-7b-chat-hf", ignore_mismatched_sizes=True, use_auth_token=llama_token
        )
tokenizer.pad_token = tokenizer.eos_token



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



In [146]:
class CachedDataset(Dataset):
    def __init__(self, model, tokenizer, token_list, activation_list, name: str="magic", threshhold: float=0.5 ):
        super().__init__()
        self.B_INST, self.E_INST = "[INST]", "[/INST]"
        self.B_SYS, self.E_SYS = "<<SYS>>", "<</SYS>>"

        self.magic_token_ids = tokenizer.encode(name)[1]
        self.tokenizer = tokenizer
        self.model = model

        yes_label = tokenizer.encode("1")[-1]
        no_label = tokenizer.encode("0")[-1]

        systtem_prompt =""" Your task is to assess if a given token (word) from a sentence represents a specified concept. Provide a rating based on this assessment:
                            If the token represents the concept, respond with "Rating: 1".
                            If the token does not represent the concept, respond with "Rating: 0".
                            Focus solely on the token and use the sentence for context only. Be confident.
                        """
        systemprompt_ids = self.systemprompt_to_ids(tokenizer, systtem_prompt)
        system_promt_cache = self.get_cache(systemprompt_ids.to(device))

        sentence_caches = []

        for sentence in tqdm(token_list):
            sentence_ids = self.sentence_to_ids(sentence)
            sentence_cache = self.get_cache(sentence_ids.to(device), prev_cache=system_promt_cache)
            print(sentence_cache[0][0].shape)
            #make a deep copy of the cache
            sentence_cache = [[layer.clone() for layer in sub_cache] for sub_cache in sentence_cache]
            sentence_caches.append(sentence_cache)
        
        self.datapoint_counter = 0
        self.data_dict = dict()

        for sentence, sentence_cache, activations in tqdm(zip(token_list, sentence_caches, activation_list)):
            for token, activation in zip(sentence, activations):

                label = yes_label if activation > threshhold else no_label
                
                question_end_ids = self.question_end_to_ids( token)
                self.data_dict[self.datapoint_counter] =(sentence_cache, question_end_ids,label)
                self.datapoint_counter += 1

    def systemprompt_to_ids(self,tokenizer, systtem_prompt):
        prompt = self.B_INST + self.B_SYS + systtem_prompt + self.E_SYS + "Sentence: "
        ids = t.tensor(tokenizer.encode(prompt)).unsqueeze(0)
        return ids
    def get_cache(self, ids, prev_cache = None):
        with t.no_grad():
            if prev_cache is None:
                output = self.model(ids, return_dict=True)
            else:
                output = self.model(ids, past_key_values=prev_cache, return_dict=True)
        return output.past_key_values
    def sentence_to_ids(self, sentence):
        post_text = "Concept:"
        post_text_ids = self.tokenizer.encode(post_text)[1:]
        ids = t.tensor(sentence + post_text_ids).unsqueeze(0)
        return ids
    def question_end_to_ids(self, question_token_ids):
        text_1 = " Token:"
        ids_1 = tokenizer.encode(text_1)[1:]
        text_2 = self.E_INST +"The rating is "
        ids_2 = self.tokenizer.encode(text_2)[1:]
        ids =[self.magic_token_ids] + ids_1 + [question_token_ids] + ids_2
        return t.tensor(ids).unsqueeze(0)

    
    def __getitem__(self, idx):
        return self.data_dict[idx]

    def __len__(self):
        return self.datapoint_counter


In [147]:
string_list = ["Lore ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.",
                "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.",
                "Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.",
                "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.",
                "Sed ut perspiciatis unde omnis iste natus error sit voluptatem accusantium doloremque laudantium, totam rem aperiam.",
]
token_list = [tokenizer.encode(string) for string in string_list]
activation_list = [t.rand(len(tokens)) for tokens in token_list]

In [148]:
example_dataset = CachedDataset(model,tokenizer,token_list,activation_list)

  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([1, 32, 156, 128])
torch.Size([1, 32, 151, 128])
torch.Size([1, 32, 149, 128])
torch.Size([1, 32, 151, 128])
torch.Size([1, 32, 152, 128])


0it [00:00, ?it/s]

In [149]:

import itertools

def custom_collate_fn(batch):
    # Unzip the batch
    caches, sentence_ids, labels = zip(*batch)

    # Convert caches to a list of lists (flatten the tuples)
    caches_list = [list(itertools.chain.from_iterable(cache)) for cache in caches]

    # Pad the sentence_ids to the longest in the batch
    sentence_ids_padded = t.nn.utils.rnn.pad_sequence(sentence_ids, batch_first=True)

    # Convert labels to tensor
    labels_tensor = t.tensor(labels)

    return caches_list, sentence_ids_padded, labels_tensor

# Example usage with DataLoader
dataloader = DataLoader(example_dataset, batch_size=2, shuffle=True, collate_fn=custom_collate_fn)


In [150]:
for tokens in token_list:
    print(len(tokens))

41
36
34
36
37


In [151]:
cache_0, sentence_ids_0, labels_0 = example_dataset[0]
cache_1, sentence_ids_1, labels_1 = example_dataset[1]

In [153]:
print(t.sum(cache_0[0][0]))
print(t.sum(cache_1[0][0]))
print(t.sum(cache_0[0][0]))
print(t.sum(cache_1[0][0]))

tensor(5921.4639, device='cuda:0')
tensor(5921.4639, device='cuda:0')
tensor(5921.4639, device='cuda:0')
tensor(5921.4639, device='cuda:0')


In [97]:
cache_0, sentence_ids_0, labels_0 = example_dataset[0]
cache_1, sentence_ids_1, labels_1 = example_dataset[1]
print(cache_0[0][0].shape)
print(cache_1[0][0].shape)

torch.Size([1, 32, 156, 128])
torch.Size([1, 32, 156, 128])


In [57]:
for cache, sentence, target in dataloader:
    pass

In [62]:
cache[0][0].shape

torch.Size([1, 32, 151, 128])

In [79]:
string_list = ["Lore ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.",
                "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.",
                "Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.",
                "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.",
                "Sed ut perspiciatis unde omnis iste natus error sit voluptatem accusantium doloremque laudantium, totam rem aperiam.",
]
token_list = [tokenizer.encode(string) for string in string_list]
activation_list = [t.rand(len(tokens)) for tokens in token_list]


systtem_prompt =   """Your task is to assess if a given token (word) from a sentence represents a specified concept. Provide a rating based on this assessment:
    If the token represents the concept, respond with "Rating: 1".
    If the token does not represent the concept, respond with "Rating: 0".
    Focus solely on the token and use the sentence for context only. Be confident.
    """

magic_token_ids = tokenizer.encode("magic")[1]
def systemprompt_to_ids(tokenizer, systtem_prompt, post_text = "Sentence: "):
    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"
    prompt = B_INST + B_SYS + systtem_prompt + E_SYS + post_text
    ids = t.tensor(tokenizer.encode(prompt)).unsqueeze(0)
    return ids
def get_cache(model, ids, prev_cache = None):
    with t.no_grad():
        if prev_cache is None:
            output = model(ids, return_dict=True)
        else:
            output = model(ids, past_key_values=prev_cache, return_dict=True)
    return output.past_key_values
def sentence_to_ids(tokenizer, sentence, post_text = "Concept:"):
    post_text_ids = tokenizer.encode(post_text)[1:]
    ids = t.tensor(sentence + post_text_ids).unsqueeze(0)
    return ids
def question_end_to_ids(tokenizer,magic_token_ids, question_token_ids):
    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"
    text_1 = " Token:"
    ids_1 = tokenizer.encode(text_1)[1:]
    text_2 = E_INST +"The rating is "
    ids_2 = tokenizer.encode(text_2)[1:]
    ids =[magic_token_ids] + ids_1 + [question_token_ids] + ids_2
    return t.tensor(ids).unsqueeze(0)



In [113]:
example_dataset = CachedDataset(model,tokenizer,token_list,activation_list)

  0%|          | 0/5 [00:00<?, ?it/s]

TypeError: TextEncodeInput must be Union[TextInputSequence, Tuple[InputSequence, InputSequence]]

In [82]:
sentence_1 = token_list[0]

systemprompt_ids = systemprompt_to_ids(tokenizer, systtem_prompt)
sentence_1_ids = sentence_to_ids(tokenizer, sentence_1)
question_end_ids = question_end_to_ids(tokenizer, magic_token_ids, sentence_1[2])

total_ids = t.cat([systemprompt_ids, sentence_1_ids, question_end_ids], dim = 1)

In [84]:
tokenizer.decode(total_ids[0].tolist())

'<s> [INST]<<SYS>>Your task is to assess if a given token (word) from a sentence represents a specified concept. Provide a rating based on this assessment:\n    If the token represents the concept, respond with "Rating: 1".\n    If the token does not represent the concept, respond with "Rating: 0".\n    Focus solely on the token and use the sentence for context only. Be confident.\n    <</SYS>>Sentence: <s> Lore ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Concept: magic  Token:ore [/INST]The rating is '

In [86]:
systemprompt_ids = systemprompt_to_ids(tokenizer, systtem_prompt)
system_promt_cache = get_cache(model, systemprompt_ids.to(device))
sentence_1_ids = sentence_to_ids(tokenizer, sentence_1)
sentence_1_cache = get_cache(model, sentence_1_ids.to(device), system_promt_cache)
question_end_ids = question_end_to_ids(tokenizer, magic_token_ids, sentence_1[2])

In [60]:
first_sentence_ids = sentence_to_ids(tokenizer, string_list[0])
first_sentence_cache = get_cache(model, first_sentence_ids, systemprompt_cache)

In [48]:
sentence_caches = [get_cache(model, sentence_to_ids(tokenizer, string), systemprompt_cache) for string in string_list]

In [49]:
len(sentence_caches[0])

32

In [108]:
tokenizer.encode("0")[-1]

29900

In [109]:
tokenizer.decode([29900])

'0'