# Overview

This is a sanity check to see whether TransformerLens can replicate what I have in this [notebook](https://colab.research.google.com/drive/1nFX9O8ahmtJT2jIL9gMDVdNQNg61BsEM?usp=sharing), which was done without TransformerLens.
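One concrete way to run this sanity check is to compare the generations directly. With greedy decoding (`temperature=0`), the Hugging Face pipeline and TransformerLens should produce the same token sequence, although small numerical differences can occasionally flip a near-tied token. Below is a minimal sketch, assuming the generated token IDs from both runs are available as plain Python lists. `first_divergence` is a hypothetical helper and is not part of either library.

```python
from typing import Optional, Sequence


def first_divergence(tokens_a: Sequence[int], tokens_b: Sequence[int]) -> Optional[int]:
    """Return the first position where two token sequences differ, or None if identical."""
    for i, (a, b) in enumerate(zip(tokens_a, tokens_b)):
        if a != b:
            return i
    # No mismatch in the shared prefix: the sequences differ only if their lengths do.
    if len(tokens_a) != len(tokens_b):
        return min(len(tokens_a), len(tokens_b))
    return None


# Illustrative token IDs standing in for the greedy outputs of the two pipelines.
hf_tokens = [1, 2, 3, 4]
tl_tokens = [1, 2, 3, 4]
print(first_divergence(hf_tokens, tl_tokens))  # None: the generations match
```

If the sequences diverge, the returned index tells you which generation step to inspect, for example by comparing the logits at that position.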

In [11]:
import os
import transformer_lens
from transformer_lens import HookedTransformer
from transformer_lens import utils
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
import accelerate
import bitsandbytes
import torch
import plotly
import plotly.express as px
import einops
import numpy as np
import time
from torch.utils.data import DataLoader
from datasets import Dataset
import tqdm


# Load Model and Tokenizer

- Load the model and tokenizer **locally**; otherwise they are not compatible with TransformerLens

In [2]:
LLAMA_PATH = "D:/Data/Llama/Llama_2/7b_chat_hf"  # local copy of the HF checkpoint
LLAMA_NAME = "meta-llama/Llama-2-7b-chat-hf"      # model name TransformerLens uses to select the config

In [3]:
model_hf = AutoModelForCausalLM.from_pretrained(LLAMA_PATH,
                                                device_map='cpu')
tokenizer = AutoTokenizer.from_pretrained(LLAMA_PATH)

In [4]:
inference_dtype = torch.float32

# Wrap the local HF model. Weight processing (LayerNorm folding, centering) is
# disabled so the model stays as close as possible to the HF checkpoint.
model = HookedTransformer.from_pretrained(LLAMA_NAME,
                                          hf_model=model_hf,
                                          dtype=inference_dtype,
                                          fold_ln=False,
                                          fold_value_biases=False,
                                          center_writing_weights=False,
                                          center_unembed=False,
                                          tokenizer=tokenizer)
model.generate("The capital of Germany is", max_new_tokens=2, temperature=0)  # quick smoke test
model = model.to("cuda" if torch.cuda.is_available() else "cpu")

Loaded pretrained model meta-llama/Llama-2-7b-chat-hf into HookedTransformer



Moving model to device:  cuda


In [5]:
free_bytes, total_bytes = torch.cuda.mem_get_info()  # (free, total) device memory in bytes
print("free(Gb):", free_bytes / 1e9, "total(Gb):", total_bytes / 1e9)


free(Gb): 0.0 total(Gb): 11.81089792
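The output above reports 0.0 GB free out of about 11.8 GB. That is at least consistent with a back-of-envelope estimate: with `inference_dtype = torch.float32`, the weights alone need more memory than the card has. Below is a rough sketch. It assumes a round figure of 7e9 parameters for Llama-2-7B and ignores activations, the KV cache, and framework overhead. `weight_memory_gb` is a hypothetical helper.

```python
# Approximate storage cost per parameter, in bytes.
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "bfloat16": 2, "int8": 1}


def weight_memory_gb(n_params: float, dtype: str) -> float:
    """Estimate memory for the model weights alone, in GB (1 GB = 1e9 bytes)."""
    return n_params * BYTES_PER_PARAM[dtype] / 1e9


# Assumption: ~7e9 parameters as a round figure for Llama-2-7B.
for dtype in ("float32", "float16", "int8"):
    print(f"{dtype}: ~{weight_memory_gb(7e9, dtype):.0f} GB")
```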


# Check performance on entity tracking with one state update




- The task structure is:
  - Three boxes
  - One state update
  - Maximum of 1 object per box
  - Without exact template
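To score the model's answers, it helps to have the ground-truth final state. Below is a minimal sketch of that bookkeeping for the task structure above: three boxes, at most one object per box, and moves applied in order. `apply_moves` and the `(object, source, target)` move format are hypothetical choices made for illustration.

```python
from typing import Dict, List, Optional, Tuple

State = Dict[str, Optional[str]]  # box name -> object (None means empty)


def apply_moves(initial: State, moves: List[Tuple[str, str, str]]) -> State:
    """Apply (object, source_box, target_box) moves, enforcing one object per box."""
    state = dict(initial)  # copy so the initial state is left untouched
    for obj, src, dst in moves:
        if state[src] != obj:
            raise ValueError(f"{obj!r} is not in Box {src}")
        if state[dst] is not None:
            raise ValueError(f"Box {dst} is already occupied")
        state[src], state[dst] = None, obj
    return state


# The single-update example from the zero-shot CoT prompt.
initial = {"A": "cow", "B": None, "C": "mouse"}
final = apply_moves(initial, [("cow", "A", "B")])
print(final)  # {'A': None, 'B': 'cow', 'C': 'mouse'}
```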


- Comparing three kinds of prompt structure:
  1. Few-shot prompt (without CoT)
  2. Few-shot CoT + "think step by step"
  3. Zero-shot CoT ("think step by step")
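The three variants differ only in whether solved examples are prepended and whether the CoT trigger is added. Below is a sketch of a prompt builder that reuses the instruction text and trigger from the prompt in this notebook. `build_prompt` and its parameters are hypothetical. The wording of the example statements (with or without reasoning) is left to the caller. Numbering starts at 1 and there is no leading newline, so the output does not match the hand-written "Description 3" prompt below token for token.

```python
from typing import Iterable, Tuple

INSTRUCTION = (
    'Given the description after "Description:", write a true statement about all '
    'boxes and their contents after "Statement:". Make sure to keep track of the '
    "changes and update the contents of the boxes according to the changes."
)
COT_TRIGGER = "Let's think step by step."


def build_prompt(description: str,
                 examples: Iterable[Tuple[str, str]] = (),
                 cot: bool = False) -> str:
    """Assemble a prompt; examples are (description, statement) pairs, empty means zero-shot."""
    examples = list(examples)
    parts = [INSTRUCTION]
    for i, (desc, stmt) in enumerate(examples, start=1):
        parts.append(f"Description {i}: {desc}\n\nStatement {i}: {stmt}")
    n = len(examples) + 1  # index of the query item
    trigger = f"{COT_TRIGGER} " if cot else ""
    parts.append(f"Description {n}: {description}\n\nStatement {n}: {trigger}Box A contains")
    return "\n\n".join(parts)


desc = ("Box A contains the cow. Box B contains nothing. Box C contains the mouse. "
        "John moves the cow from Box A to Box B. Box C has no change in its content.")
print(build_prompt(desc, cot=True))  # variant 3: zero-shot CoT
```

Variant 1 corresponds to `examples=[...]` with `cot=False`, and variant 2 corresponds to examples whose statements spell out the reasoning, with `cot=True`.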





## Zero-shot CoT

In [11]:
prompt = """
Given the description after "Description:", write a true statement about all boxes and their contents after "Statement:". Make sure to keep track of the changes and update the contents of the boxes according to the changes.

Description 3: Box A contains the cow. Box B contains nothing. Box C contains the mouse. John moves the cow from Box A to Box B. Box C has no change in its content.

Statement 3: Let's think step by step. Box A contains"""
start_time = time.time()
outputs = model.generate(prompt, max_new_tokens=100, temperature=0)
print(outputs)
print("--- %s seconds ---" % (time.time() - start_time))




Given the description after "Description:", write a true statement about all boxes and their contents after "Statement:". Make sure to keep track of the changes and update the contents of the boxes according to the changes.

Description 3: Box A contains the cow. Box B contains nothing. Box C contains the mouse. John moves the cow from Box A to Box B. Box C has no change in its content.

Statement 3: Let's think step by step. Box A contains the cow, and John moves the cow to Box B. So, Box A is now empty, and Box B contains the cow.

Box A: Empty
Box B: Cow
Box C: Mouse

Please update the contents of the boxes according to the statements.</s>
--- 118.9385597705841 seconds ---
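To check an answer automatically rather than by eye, the final `Box X: content` lines can be parsed and compared with the expected state (A empty, B the cow, C the mouse). Below is a sketch that assumes the model keeps emitting one `Box X: ...` line per box, as it did in this run. Nothing forces that format, and a generation that answers only in prose would parse to an empty dict. `parse_final_state` is a hypothetical helper.

```python
import re
from typing import Dict, Optional

# One "Box X: content" line per box; [ \t]* keeps matches within a single line.
LINE_RE = re.compile(r"^Box ([A-Z]):[ \t]*(.+?)[ \t]*$", re.MULTILINE)


def parse_final_state(generation: str) -> Dict[str, Optional[str]]:
    """Extract 'Box X: content' lines; 'Empty'/'nothing' map to None."""
    state: Dict[str, Optional[str]] = {}
    for box, content in LINE_RE.findall(generation):
        content = content.lower()
        state[box] = None if content in ("empty", "nothing") else content
    return state


sample = """Statement 3: Let's think step by step. Box A contains the cow, and John moves the cow to Box B. So, Box A is now empty, and Box B contains the cow.

Box A: Empty
Box B: Cow
Box C: Mouse

Please update the contents of the boxes according to the statements."""
expected = {"A": None, "B": "cow", "C": "mouse"}
print(parse_final_state(sample) == expected)  # True
```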


In [6]:
prompt = """
Given the description after "Description:", write a true statement about all boxes and their contents after "Statement:". Make sure to keep track of the changes and update the contents of the boxes according to the changes.

Description 3: Box A contains the cow. Box B contains nothing. Box C contains the mouse. John moves the cow from Box A to Box B. Box C has no change in its content.

Statement 3: Let's think step by step. Box A contains"""
start_time = time.time()
outputs = model.generate(prompt, max_new_tokens=100, temperature=0)
print(outputs)
print("--- %s seconds ---" % (time.time() - start_time))




Given the description after "Description:", write a true statement about all boxes and their contents after "Statement:". Make sure to keep track of the changes and update the contents of the boxes according to the changes.

Description 3: Box A contains the cow. Box B contains nothing. Box C contains the mouse. John moves the cow from Box A to Box B. Box C has no change in its content.

Statement 3: Let's think step by step. Box A contains the cow, and John moves the cow to Box B. So, Box A is now empty, and Box B contains the cow.

Box A: Empty
Box B: Cow
Box C: Mouse

Please update the contents of the boxes according to the statements.</s>
--- 127.31756210327148 seconds ---
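The same greedy generation took about 118.9 s in one run and 127.3 s in another, so a single wall-clock measurement varies by several seconds. Below is a small sketch that times each run with `time.perf_counter` (a monotonic clock intended for measuring intervals) instead of `time.time`, and collects the timings in a dict so repeated runs can be compared. `timed` is a hypothetical helper.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label: str, results: dict):
    """Record wall-clock seconds for the enclosed block under results[label]."""
    start = time.perf_counter()
    try:
        yield
    finally:
        results[label] = time.perf_counter() - start


# Usage with the notebook's model (not executed here):
# timings = {}
# with timed("zero_shot_cot", timings):
#     outputs = model.generate(prompt, max_new_tokens=100, temperature=0)
# print(timings)
```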


In [16]:

batch_size = 1
max_new_tokens = 100
prompt = """
Given the description after "Description:", write a true statement about all boxes and their contents after "Statement:". Make sure to keep track of the changes and update the contents of the boxes according to the changes.

Description 3: Box A contains the cow. Box B contains nothing. Box C contains the mouse. John moves the cow from Box A to Box B. Box C has no change in its content.

Statement 3: Let's think step by step. Box A contains"""

# to_tokens prepends the BOS token and returns a [1, seq_len] tensor for one prompt
input_tokens = model.to_tokens(prompt)

dataset = Dataset.from_dict(
    {
        "input_ids": input_tokens,
    }).with_format("torch")

dataloader = DataLoader(dataset, batch_size=batch_size)

with torch.no_grad():
    for inputs in tqdm.tqdm(dataloader):
        input_ids = inputs["input_ids"].to(model.cfg.device)
        # logits = model(input_ids)  # next-token prediction (single forward pass)
        output = model.generate(input_ids,
                                max_new_tokens=max_new_tokens,
                                temperature=0)  # greedy generation




100%|███████████████████████████████████████████████████████████████████████████████████| 1/1 [02:03<00:00, 123.60s/it]
1it [02:03, 123.60s/it]
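In the loop above, `output` is overwritten on each iteration. That is harmless with a single prompt, but only the last batch would survive once more prompts are added. Below is a minimal pure-Python sketch of the batch-and-collect pattern. `batched` mimics what `DataLoader(..., batch_size=...)` does for a list without shuffling, and `fake_generate` is a hypothetical stand-in for `model.generate`.

```python
from itertools import islice
from typing import Callable, Iterable, Iterator, List, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def batched(items: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Yield consecutive lists of up to batch_size items (the last batch may be short)."""
    it = iter(items)
    while batch := list(islice(it, batch_size)):
        yield batch


def run_all(prompts: List[T], generate: Callable[[List[T]], List[R]], batch_size: int) -> List[R]:
    """Collect outputs from every batch instead of keeping only the last one."""
    outputs: List[R] = []
    for batch in batched(prompts, batch_size):
        outputs.extend(generate(batch))
    return outputs


def fake_generate(batch: List[str]) -> List[str]:
    """Hypothetical stand-in for model.generate: upper-cases each prompt."""
    return [p.upper() for p in batch]


print(run_all(["a", "b", "c"], fake_generate, batch_size=2))  # ['A', 'B', 'C']
```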


In [21]:
print(model.tokenizer.decode(output[0]))

<s> 
Given the description after "Description:", write a true statement about all boxes and their contents after "Statement:". Make sure to keep track of the changes and update the contents of the boxes according to the changes.

Description 3: Box A contains the cow. Box B contains nothing. Box C contains the mouse. John moves the cow from Box A to Box B. Box C has no change in its content.

Statement 3: Let's think step by step. Box A contains the cow, and John moves the cow to Box B. So, Box A is now empty, and Box B contains the cow.

Box A: Empty
Box B: Cow
Box C: Mouse

Please update the contents of the boxes according to the statements.</s>
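The decoded text still contains the `<s>`/`</s>` special tokens and the whole prompt. Below is a sketch that keeps only the newly generated part. It assumes the decoded text contains the prompt verbatim, which holds for the output above but is not guaranteed for every tokenizer. `continuation` is a hypothetical helper.

```python
def continuation(decoded: str, prompt: str, special_tokens=("<s>", "</s>")) -> str:
    """Return only the newly generated text, without the prompt or special tokens."""
    text = decoded
    for tok in special_tokens:
        text = text.replace(tok, "")
    idx = text.find(prompt)
    if idx == -1:
        raise ValueError("prompt not found in decoded text")
    return text[idx + len(prompt):].strip()


decoded = "<s> \nGiven X.\n\nStatement 3: Box A contains the cow.</s>"
print(continuation(decoded, "\nGiven X.\n\nStatement 3: Box A contains"))  # the cow.
```

When the token tensors are available, string matching can be avoided. Slice off the prompt tokens before decoding, e.g. `model.tokenizer.decode(output[0, input_tokens.shape[1]:], skip_special_tokens=True)`. This assumes `generate` returns the prompt tokens followed by the new ones, which is consistent with the decoded output above.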
