In [74]:
import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens import utils as tutils
from transformer_lens.evals import make_pile_data_loader, evaluate_on_dataset

from functools import partial
from datasets import load_dataset
import tqdm

from sae_lens import SparseAutoencoder
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes

# from steering.eval_utils import evaluate_completions
from steering.utils import get_activation_steering, get_sae_diff_steering, remove_sae_feats, text_to_sae_feats, top_activations
from steering.patch import generate, get_loss

from steering.visualization import Table

import plotly.express as px

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x1507c3790>

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = HookedTransformer.from_pretrained('gpt2-small', device=device)



Loaded pretrained model gpt2-small into HookedTransformer


In [112]:
from transformer_lens import HookedTransformer

def get_next_logits(
    model: HookedTransformer,
    fwd_hooks,
    prompt="",
):
    """
    
    """
    logits = model.run_with_hooks(prompt, fwd_hooks=fwd_hooks, prepend_bos=True)
    print(logits.shape)
    print(logits[0, -1, 255])
    top_logits = logits[0, -1].topk(20)
    # tokens = model.tokenizer.batch_decode(top_logits.indices)
    # values = top_logits.values.tolist()

    return top_logits

In [114]:
get_next_logits(model, [], "Once upon a time,")

torch.Size([1, 6, 50257])
tensor(-6.2251)


torch.return_types.topk(
values=tensor([12.8311, 12.0995, 11.9779, 11.8784, 11.5975, 11.3170, 11.1910, 10.8977,
        10.6984, 10.3434, 10.3292, 10.1011, 10.0181,  9.8481,  9.7113,  9.6436,
         9.6348,  9.6215,  9.5249,  9.4524]),
indices=tensor([ 262,  612,  257,  314,  618,  340,  356,  287,  345,  428,  661,  281,
         616, 5384, 1466,  355,  477,  530,  611, 5519]))

In [109]:
import pandas as pd
from IPython.display import display, clear_output
import ipywidgets as widgets

def generate_step_by_step(model: HookedTransformer, prompt: str, watch_tokens: list = []):
    """
    Generate a REPL for the model.
    """ 
    prompt_display = widgets.HTML(
        value=f"<b>Prompt:</b> {prompt}",
        placeholder='',
        description='',
        layout=widgets.Layout(width='100%')
    )

    input_box = widgets.Text(
        value='',
        placeholder='Enter a logit number to intervene or leave blank.',
        description='Next step:',
        disabled=False,
        layout=widgets.Layout(width='100%', margin='10px 0')
    )

    output_prompt = widgets.Output(layout={'border': '1px solid lightgrey'})
    output_area_left = widgets.Output(layout={'border': '1px solid lightgrey', 'width': '50%'})
    output_area_right = widgets.Output(layout={'border': '1px solid lightgrey', 'width': '50%'})
    side_by_side_output = widgets.HBox([output_area_left, output_area_right])

    def execute_command(logits, model, watch_tokens=[]):
        tokens = model.tokenizer.batch_decode(logits.indices)
        values = logits.values.tolist()

        with output_prompt:
            clear_output(wait=True)
            display(prompt_display)

        with output_area_left:
            clear_output(wait=True)
            
            top_tokens = tokens[:20]
            top_values = values[:20]
            positions: list = list(range(1, 21))
            try:
                Table("Top Tokens", ["Positions", "Token", "Prob"] , zip(positions, top_tokens, top_values))
                
            except Exception as e:
                print(f"Error: {e}")

        with output_area_right:
            clear_output(wait=True)
            

            top_tokens = tokens[:20]
            top_values = values[:20]
            positions: list = list(range(1, 21))
            try:
                Table("Other Tokens", ["Token", "Prob"] , zip(top_tokens, top_values))
            except Exception as e:
                print(f"Error: {e}")

    def on_submit(change):
        logits = get_next_logits(model, [], prompt)
        execute_command(logits, model, watch_tokens)

        input_box.value = ''
        input_box.focus = True

    input_box.on_submit(on_submit)

    display(input_box)
    display(output_prompt)
    display(side_by_side_output)

    # display_output()

    # execute_command("")
    # Initial display
    logits = get_next_logits(model, [], prompt)
    execute_command(logits, model, watch_tokens)

# Run the REPL
generate_step_by_step(model, "Once upon a time, there were", [555, 222])

  input_box.on_submit(on_submit)


Text(value='', description='Next step:', layout=Layout(margin='10px 0', width='100%'), placeholder='Enter a lo…

Output(layout=Layout(border_bottom='1px solid lightgrey', border_left='1px solid lightgrey', border_right='1px…

HBox(children=(Output(layout=Layout(border_bottom='1px solid lightgrey', border_left='1px solid lightgrey', bo…

In [104]:
logits = get_next_logits(model, [], "The cat sat on the mat.")


In [105]:
logits

torch.return_types.topk(
values=tensor([ 17.2598,  16.8689,  16.6076,  ..., -19.3946, -20.7319, -20.7520]),
indices=tensor([  198,   383,   632,  ..., 11039, 13945, 19476]))

['\n',
 ' The',
 ' It',
 ' He',
 ' She',
 ' His',
 ' "',
 ' Her',
 ' I',
 ' A',
 ' Its',
 ' In',
 ' And',
 '\n\n',
 ' There',
 ' As',
 ' When',
 ' This',
 ' That',
 ' On',
 ' No',
 ' Then',
 ' With',
 ' My',
 ' One',
 ' An',
 ' But',
 ' After',
 ' At',
 ' Not',
 ' They',
 ' If',
 ' For',
 '<|endoftext|>',
 ' We',
 ' What',
 ' You',
 ' Two',
 ' (',
 ' To',
 ' Even',
 ' L',
 ' Just',
 ' All',
 ' Like',
 ' Now',
 ' How',
 ' From',
 " '",
 ' Harry',
 ' So',
 ' While',
 ' Ruby',
 ' Before',
 ' Behind',
 ' Every',
 ' T',
 ' Suddenly',
 ' Inside',
 ' Was',
 ' Looking',
 ' Some',
 ' W',
 ' Another',
 ' C',
 ' Once',
 ' Although',
 ' F',
 ' P',
 ' Nothing',
 ' Sitting',
 ' Why',
 ' S',
 ' Both',
 ' Only',
 ' Each',
 ' Or',
 ' Despite',
 ' R',
 ' Three',
 ' Slowly',
 ' By',
 ' Something',
 ' Maybe',
 ' Standing',
 ' Well',
 ' J',
 ' M',
 ' Had',
 ' Don',
 ' Though',
 ' H',
 ' Over',
 ' Sh',
 ' Blake',
 ' St',
 ' D',
 '\xa0',
 ' Yang',
 ' Someone',
 ' Cat',
 ' Without',
 ' Yes',
 ' Y',
 ' B',
 ' 

In [140]:
def get_indices(tensor, values):
    """
    Get the indices of the given values in the tensor.

    Args:
    tensor (torch.Tensor): The tensor to search in.
    values (list): The list of values to find in the tensor.

    Returns:
    list: A list of indices where the values are found in the tensor.
    """
    indices = []
    for value in values:
        idx = torch.nonzero(torch.eq(tensor, value)).squeeze()
        if idx.numel() > 0:  # Check if there are any indices found
            indices.append(idx.item())
    return indices

def preview_step(model: HookedTransformer, prompt: str, fwd_hooks=[], watch_logits: list = []):
    logits = model.run_with_hooks(prompt, fwd_hooks=fwd_hooks, prepend_bos=True)

    ranked_logits = logits[0, -1].topk(logits.shape[-1])
    indices = ranked_logits.indices
    values = ranked_logits.values

    positions = list(range(1, 21))
    top_values = [value.item() for value in values[0:20]]

    top_indices = indices[0:20]
    top_tokens = model.tokenizer.batch_decode(top_indices)
    Table("Top Tokens", ["Positions", "Token", "Prob"] , zip(positions, top_tokens, top_values))

    if len(watch_logits) > 0:
        watch_positions = get_indices(indices, watch_logits)
        watch_values = [value.item() for value in values[watch_positions]]

        watch_indices = indices[watch_positions]
        watch_tokens = model.tokenizer.batch_decode(watch_indices)

        Table("Watch Tokens", ["Positions", "Token", "Prob"] , zip(watch_positions, watch_tokens, watch_values))


preview_step(model, "Today I want", watch_logits=[555, 1203])


Unnamed: 0,Positions,Token,Prob
0,1,to,19.326347
1,2,you,15.544235
2,3,a,13.122031
3,4,the,12.669929
4,5,everyone,12.155163
5,6,my,11.714208
6,7,people,11.512468
7,8,your,11.403308
8,9,some,11.29584
9,10,an,11.197874


Unnamed: 0,Positions,Token,Prob
0,307,un,6.272503
1,3261,less,3.372362
