In [67]:
import os
import transformer_lens
from transformer_lens import HookedTransformer
from transformer_lens import utils
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig
import accelerate
import bitsandbytes
import torch
import plotly
import plotly.express as px
import einops
import numpy as np
import psutil
import pandas as pd
import random
import tuned_lens

In [4]:
# Turn off gradient tracking globally; the visible cells only run forward passes.
torch.set_grad_enabled(False)


<torch.autograd.grad_mode.set_grad_enabled at 0x198ac164a70>

# Load Data

In [20]:
# Example prompt fed to the model in a later cell: a box-contents description
# followed by a step-by-step statement, separated by a newline.
prompt = "Description 3: Box A contains the cow. Box B contains nothing. Box C contains the mouse. John moves the cow from Box A to Box B. Box C has no change in its content.\nStatement 3: Let's think step by step. Initially, box A contains the cow, and John moves the cow to Box B. So, Box A now contains nothing, and Box B contains the cow. Box C contains the mouse."

In [21]:
# Show the prompt with its embedded newline rendered.
print(prompt)

Description 3: Box A contains the cow. Box B contains nothing. Box C contains the mouse. John moves the cow from Box A to Box B. Box C has no change in its content.
Statement 3: Let's think step by step. Initially, box A contains the cow, and John moves the cow to Box B. So, Box A now contains nothing, and Box B contains the cow. Box C contains the mouse.


In [26]:
# Checkpoint to analyse.
# Alternative checkpoint that was previously assigned here and then overwritten:
#   "meta-llama/Llama-2-7b-chat-hf"
model_name = "stanford-crfm/alias-gpt2-small-x21"

# Hugging Face access token, read from the environment so that no credential is
# stored in the notebook. `None` when HF_TOKEN is unset.
# NOTE(review): a literal token used to be committed on this line; it should be
# revoked on the Hugging Face side, since it remains in the notebook history.
token = os.environ.get("HF_TOKEN")




In [61]:
class ModelHelper:
    """Wrap a Hugging Face causal LM in a TransformerLens ``HookedTransformer``.

    The helper keeps the tokenizer and the hooked model together and exposes
    ``logits_all_layers``, which decodes every layer's residual stream with the
    model's own final LayerNorm and unembedding.

    Attributes:
        tokenizer: Tokenizer loaded with ``AutoTokenizer.from_pretrained``.
        model: The ``HookedTransformer`` built from the Hugging Face weights.
        device: Device of the first model parameter after loading.
        d_vocab: ``model.cfg.d_vocab``.
        n_layers: ``model.cfg.n_layers``.
    """

    def __init__(self, token, device=None, load_in_8bit=False, pretrained_name=None):
        """Load tokenizer and model.

        Args:
            token: Hugging Face access token forwarded to both
                ``from_pretrained`` calls (may be ``None``).
            device: Device passed to ``HookedTransformer.from_pretrained``.
                When ``None``, "cuda" is used if available, otherwise "cpu".
            load_in_8bit: Kept for backward compatibility. 8-bit loading is
                NOT implemented; a warning is emitted when this is truthy.
            pretrained_name: Checkpoint name. When ``None`` the module-level
                ``model_name`` variable is used, as before.
        """
        if load_in_8bit:
            # Previously this flag was dropped silently, which made call sites
            # look as if the model were quantized.
            import warnings

            warnings.warn(
                "ModelHelper: load_in_8bit=True is not implemented; "
                "the model is loaded without 8-bit quantization.",
                stacklevel=2,
            )

        if pretrained_name is None:
            # Backward-compatible fallback to the notebook-level setting.
            pretrained_name = model_name

        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_name, token=token)

        hf_model = AutoModelForCausalLM.from_pretrained(pretrained_name, token=token,
                                                        device_map='auto')
        # All weight-processing options are disabled so the hooked model keeps
        # the Hugging Face weights as loaded.
        self.model = HookedTransformer.from_pretrained(pretrained_name,
                                                       hf_model=hf_model,
                                                       fold_ln=False,
                                                       fold_value_biases=False,
                                                       center_writing_weights=False,
                                                       center_unembed=False,
                                                       tokenizer=self.tokenizer,
                                                       device=self.device)

        print(self.model)
        print(self.model.cfg.n_layers)
        # Replace the requested device string with where the weights actually live.
        self.device = next(self.model.parameters()).device
        self.d_vocab = self.model.cfg.d_vocab
        self.n_layers = self.model.cfg.n_layers

    def logits_all_layers(self, text):
        """Decode every layer's residual stream into vocabulary logits.

        For each layer, the cached ``blocks.{layer}.hook_resid_post`` activation
        is passed through ``model.ln_final`` and ``model.unembed``.

        Args:
            text: A single string; it is tokenized with ``self.tokenizer`` into
                one sequence (batch size 1).

        Returns:
            Tensor of shape ``(n_layers, seq_len, d_vocab)`` allocated with
            ``torch.zeros`` defaults (float32 on CPU unless the process-wide
            defaults were changed).
        """
        input_ids = self.tokenizer(text, return_tensors="pt")["input_ids"]
        seq_len = input_ids.shape[1]

        # Cache only the residual stream after each block.
        self.model.reset_hooks()
        _, cache = self.model.run_with_cache(
            input_ids,
            names_filter=lambda name: name.endswith("resid_post"),
            return_type=None,
        )

        # Filled one layer at a time so only a single layer's logits are held
        # on the model's device at once.
        layer_logit_all = torch.zeros(self.n_layers, seq_len, self.d_vocab)
        for layer in range(self.n_layers):
            resid_ln = self.model.ln_final(cache[f'blocks.{layer}.hook_resid_post'])
            layer_logit = self.model.unembed(resid_ln)
            # Drop the size-1 batch axis explicitly rather than via broadcasting.
            layer_logit_all[layer] = layer_logit.squeeze(0)
        return layer_logit_all
                    
            

In [62]:
# Build the helper; its constructor loads the tokenizer and model for `model_name`.
# NOTE(review): ModelHelper.__init__ accepts `load_in_8bit` but does not apply it,
# so no 8-bit quantization happens here despite the argument.
model = ModelHelper(token, load_in_8bit=True)

Loaded pretrained model stanford-crfm/alias-gpt2-small-x21 into HookedTransformer
HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
 

In [63]:
# Per-layer log-probabilities for the prompt: log-softmax over the last
# (vocabulary) axis of the (n_layers, seq_len, d_vocab) logits, as a NumPy array.
log_probs = model.logits_all_layers(prompt).float().log_softmax(dim=-1).numpy()


In [65]:
# Optional Google Colab setup: enable its custom widget manager when the
# `google.colab` package is importable (presumably needed for the widgets
# below to render there). Outside Colab the import fails and this is skipped.
try:
    from google.colab import output
    output.enable_custom_widget_manager()
except ImportError:
    # Not running on Colab. Only the missing-module case is ignored, so
    # interrupts and genuine errors are no longer swallowed.
    pass

from tuned_lens.plotting import PredictionTrajectory
import ipywidgets as widgets
from plotly import graph_objects as go
import numpy as np

# Tokenizer shared by the plotting callbacks below.
tokenizer = model.tokenizer


def make_plot(text, layer_stride, statistic, token_range):
    """Build the lens figure for ``text``.

    Args:
        text: Input string to tokenize and run through the model.
        layer_stride: Stride passed to ``.stride(...)`` on the computed statistic.
        statistic: Name of the ``PredictionTrajectory`` method to call
            (e.g. "entropy", "cross_entropy", "forward_kl").
        token_range: ``(start, stop)`` pair used to slice the token sequence.

    Returns:
        The figure produced by the selected statistic, or a ``widgets.Text``
        carrying a message when the text is empty or the range is empty.
    """
    input_ids = tokenizer.encode(text)

    # Validate first so that no model work is done for unusable input.
    if len(input_ids) == 0:
        return widgets.Text("Please enter some text.")

    if token_range[0] == token_range[1]:
        return widgets.Text("Please provide valid token range.")

    # Next-token targets: the input shifted left by one, padded with EOS.
    targets = input_ids[1:] + [tokenizer.eos_token_id]

    # Single forward pass; previously the model was run twice, once only to
    # print the logits' shape.
    log_probs = model.logits_all_layers(text).float().log_softmax(dim=-1).numpy()

    pred_traj = PredictionTrajectory(
        log_probs=log_probs,
        input_ids=np.asarray(input_ids),
        targets=np.asarray(targets),
        anti_targets=None,
        tokenizer=tokenizer,
    )
    pred_traj = pred_traj.slice_sequence(slice(*token_range))
    return getattr(pred_traj, statistic)().stride(layer_stride).figure(
        title=f"LLamav2lense {statistic}",
    )

# Style dict shared by the widgets that pass `style=`.
style = dict(description_width='initial')

# Chooses which PredictionTrajectory statistic make_plot calls:
# (label shown in the dropdown, method name).
statistic_wdg = widgets.Dropdown(
    options=[('Entropy', 'entropy'), ('Cross Entropy', 'cross_entropy'), ('Forward KL', 'forward_kl')],
    description='Select Statistic:',
    style=style,
)

# Supplies make_plot's `layer_stride` argument.
layer_stride_wdg = widgets.BoundedIntText(
    value=2, min=1, max=10, step=1,
    description='Layer Stride:',
    disabled=False,
)

# Supplies make_plot's `token_range`; `max` starts at a placeholder of 1 and
# is overwritten by update_token_range().
token_range_wdg = widgets.IntRangeSlider(
    description='Token Range',
    min=0, max=1, step=1,
    style=style,
)


# def update_token_range(response_len=1,*args):
#     token_range_wdg.max = len(tokenizer.encode(text_wdg.value))+ response_len

def update_token_range(*args):
    """Set the token-range slider's upper bound to the current text's token count.

    Any positional arguments are accepted and ignored, so the function can be
    called directly or registered as a widget observer callback.
    """
    n_tokens = len(tokenizer.encode(text_wdg.value))
    token_range_wdg.max = n_tokens


In [68]:
# Rebinds `prompt` (first defined in an earlier cell) to a multi-line variant
# with leading blank lines; it is the text area's initial value.
prompt = """

Description 3: Box A contains the cow. Box B contains nothing. Box C contains the mouse. John moves the cow from Box A to Box B. Box C has no change in its content.

Statement 3: Let's think step by step. Initially, box A contains the cow, and John moves the cow to Box B. So, Box A now contains nothing, and Box B contains the cow. Box C contains the mouse."""
# Text area whose value is passed to make_plot as `text`.
text_wdg = widgets.Textarea(
    description="Input Text",
    value =prompt
)


# Size the slider from the initial text before choosing its default value.
update_token_range()

# Default selection: the full token range.
token_range_wdg.value = [0, token_range_wdg.max]
# Re-run update_token_range whenever the text area's value changes.
text_wdg.observe(update_token_range, 'value')

# Manual mode with the button labelled 'Run Lens'; presumably make_plot then
# runs on button press rather than on every widget change (confirm against
# the ipywidgets docs).
interact = widgets.interact.options(manual_name='Run Lens', manual=True)

# Bind each widget to the make_plot parameter of the same name.
plot = interact(
    make_plot,
    text=text_wdg,
    statistic=statistic_wdg,
    layer_stride=layer_stride_wdg,
    token_range=token_range_wdg,
)

interactive(children=(Textarea(value="\n\nDescription 3: Box A contains the cow. Box B contains nothing. Box C…