### Corner Visualization Experiment

This is an imprecise term that may refer to an intersection of embeddings.

In [2]:
import sys
sys.path.append("..")

In [3]:
import torch, baukit
from transformers import AutoModelForCausalLM, AutoTokenizer
#MODEL_NAME = "gpt2-xl"  # gpt2-xl or EleutherAI/gpt-j-6B
MODEL_NAME = "EleutherAI/gpt-j-6B"
import os 
import models

#os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.7"

device = "mps"
mt =  models.load_model("gptj", device=device, fp16=True)

model, tok = mt.model, mt.tokenizer

baukit.set_requires_grad(False, model)

In [21]:
@torch.inference_mode()
def complete(prompt, max_new_tokens=1):
    #tokenize inputs as pytorch tensors
    inputs = mt.tokenizer(prompt, return_tensors="pt").to(device)
    #generate from inputs
    outputs = mt.model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        pad_token_id=mt.tokenizer.pad_token_id,
    )
    return mt.tokenizer.batch_decode(outputs)[0]

complete("The Eiffel Tower is located in the city of")

'The Eiffel Tower is located in the city of Paris'

I am commenting this out - but I have included a world-cities csv for doing a corner at world cities.

Here are a few hacks for experimenting with corner calculations in different ways.

`make_corner_vector` tries to calculate a corner with linear algebra - it sorta comes close, but it doesn't actually do what it aims to do, because the layernorm nonlinearity interferes.

In [22]:
generic_text = 'The quick brown fox jumped over the lazy dogs, while the rain in Spain fell mainly in the plain.'

def get_logit_scale(model, tok):
    #from tokenized, make v into k:tensor pair. 
    #None expands the dimensions
    inp = {k: torch.tensor(v)[None].to('mps') for k, v in tok(generic_text).items()}
    embed_layer = [n for n, _ in model.named_modules() if 'lm_head' in n][0]

    with baukit.Trace(model, embed_layer) as t:
        model(**inp)
        return t.output.max(2)[0].mean() # scale is: average maximum logit

# Normalizing pre-layernorm vectors to the empirically measured scale seems to overestimate
def get_prenorm_scale(model, tok):
    inp = {k: torch.tensor(v)[None].to('mps') for k, v in tok(generic_text).items()}
    with baukit.Trace(model, 'transformer.ln_f', retain_input=True) as t:
        model(**inp)
        return (t.input ** 2).mean().sqrt() # This is about 2.86 for GPT-J, which is too big.

def make_corner_vector(words, model, tok, logit_scale=None, prenorm_scale=None, add_space=True):
    if logit_scale is None:
        logit_scale = get_logit_scale(model, tok)
    if prenorm_scale is None:
        prenorm_scale = get_prenorm_scale(model, tok)
    decoding_vectors = model.lm_head.weight
    decoding_bias = model.lm_head.bias
    
    # Form a list of target tokens
    token_numbers = [tok((' '  if add_space else '') + w)['input_ids'][0] for w in words]
    token_numbers = list(set(token_numbers))

    # Solve the linear algebra
    A = decoding_vectors[token_numbers]
    b = logit_scale - decoding_bias[token_numbers]
    x = torch.linalg.lstsq(A, b).solution
    x_scale = (x ** 2).mean().sqrt()
    prenorm_x = x / x_scale * prenorm_scale
    return prenorm_x

Here is an optimization-based corner-finder.

`optimize_corner_vector` initializes with linear algebra but then finishes off with an optimizer.  I found that RMSprop actually works pretty well without tuning.


In [23]:
def optimize_corner_vector(words, model, tok, logit_scale=None, prenorm_scale=None,
                           lr=1e-4, iters=500, add_space=True):
    if logit_scale is None:
        logit_scale = get_logit_scale(model, tok)
    if prenorm_scale is None:
        prenorm_scale = get_prenorm_scale(model, tok)
    decoding_vectors = model.lm_head.weight
    decoding_bias = model.lm_head.bias
    print(f'decoding_bias.device: {decoding_bias.device}')
    # Form a list of target tokens
    token_numbers = [tok((' '  if add_space else '') + w)['input_ids'][0] for w in words]
    token_numbers = list(set(token_numbers))

    # Solve the linear algebra
    A = decoding_vectors[token_numbers]
    b = logit_scale - decoding_bias[token_numbers]
    x = torch.linalg.lstsq(A, b).solution
    x_scale = (x ** 2).mean().sqrt()
    prenorm_x = x / x_scale * prenorm_scale

    # Now optimize to make it better.
    decoder = torch.nn.Sequential(model.transformer.ln_f, model.lm_head, torch.nn.Softmax(dim=-1))
    x = prenorm_x
    x.requires_grad = True
    best_loss = None
    optimizer = torch.optim.RMSprop([x], lr=lr)
    
    for _ in range(iters):
        p = decoder(x)[token_numbers]
        m = p.mean()
        loss = (p - m).abs().mean() - m
        if best_loss is None or loss < best_loss:
            best_loss = loss.clone().detach()
            result = x.clone().detach()
        loss.backward()
        optimizer.step()
    return result


Here we find the corner between eight words (eight color words, for fun).

In [27]:
colors = ['red', 'green', 'blue', 'orange', 'yellow', 'purple', 'gray', 'brown']
token_numbers = [tok(' ' + c)['input_ids'][0] for c in colors]

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"]= "1"
print(os.environ["PYTORCH_ENABLE_MPS_FALLBACK"])

##(moved both model and input vectors (in optimize_corner_vector) to 'mps')
model = model.to('mps')

v = optimize_corner_vector(colors, model, tok)

decoder = torch.nn.Sequential(model.transformer.ln_f, model.lm_head, torch.nn.Softmax(dim=-1))
decoder(v)[token_numbers]

1
decoding_bias.device: mps:0


NotImplementedError: The operator 'aten::linalg_lstsq.out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

In [None]:
import matplotlib
from matplotlib import pyplot as plt
cmap = matplotlib.colors.ListedColormap(colors)

x, y = torch.meshgrid(torch.linspace(-1, 1, 100), torch.linspace(-1, 1, 100))
how_many = 3

for j in range(4096):
    for i in range(j+1, 4096):
        vv = torch.zeros_like(x)[:,:,None] + v.cpu()[None,None,:]
        vv[:,:,j] += x * 10
        vv[:,:,i] += y * 10
        cindex = decoder(vv)[:,:,token_numbers].argmax(dim=2).cpu()
        if len(cindex.unique()) == len(colors):
            print(j, i)
            plt.scatter(x, y, c=cindex, cmap=cmap)
            plt.axis('square')
            plt.show()
            how_many -= 1
            if how_many <= 0:
                break
    if how_many <= 0:
            break

print('done')

## Attribute lens based-on-corner test

Here is an attribute-lens style test.

In [None]:
def make_corner_readout(model, tok, words, logit_scale=None, prenorm_scale=None):
    # Get the corner vector
    x = optimize_corner_vector(words, model, tok, logit_scale=logit_scale, prenorm_scale=prenorm_scale)

    decoder = torch.nn.Sequential(model.transformer.ln_f, model.lm_head, torch.nn.Softmax(dim=-1))
    def corner_readout(h):
        import numpy
        cuda_h = h
        return decoder(cuda_h + x)
    return corner_readout

In [None]:
import csv
with open('worldcities.csv') as w:
    records = list(csv.DictReader(w))
big_city_list = [r['city'] for r in records if float(r['population'] or 0) >= 1000000]


In [None]:
short_city_list = [c for c in big_city_list if len(tok(' ' + c)['input_ids']) <= 1]
len(short_city_list)

In [None]:
f = make_corner_readout(
    model, tok,
    short_city_list)

# Try running f on some zero vectors
probs = f(torch.zeros(1, 5, 1, 3, 2, 4096))
print(probs.shape)
probs.sum(dim=-1).flatten()  # Veriy that Probabilities add up to 1.0

This function gathers the hidden state.

In [None]:
def get_hidden_states(model, tok, prefix):
    import re
    from baukit import TraceDict
    inp = {k: torch.tensor(v)[None] for k, v in tok(prefix).items()}
    layer_names = [n for n, _ in model.named_modules()
                   if re.match(r'^transformer.h.\d+$', n)]
    with TraceDict(model, layer_names) as tr:
        logits = model(**inp)['logits']
    return torch.stack([tr[layername].output[0] for layername in layer_names])

prompt = 'Hello, my name is also'
hs = get_hidden_states(model, tok, prompt)
hs.shape

Here is the basic logit lens visualization.  Comments inline.

In [None]:
def show_logit_lens(model, tok, prefix, topk=5, color=None, hs=None, decoder=None):
    from baukit import show

    # You can pass in a function to compute the hidden states, or just the tensor of hidden states.
    if hs is None:
        hs = get_hidden_states
    if callable(hs):
        hs = hs(model, tok, prefix)

    # The full decoder head normalizes hidden state and applies softmax at the end.
    if decoder is None:
        decoder = torch.nn.Sequential(model.transformer.ln_f, model.lm_head, torch.nn.Softmax(dim=-1))

    probs = decoder(hs) # Apply the decoder head to every hidden state
    favorite_probs, favorite_tokens = probs.topk(k=topk, dim=-1)
    # Let's also plot hidden state magnitudes
    magnitudes = hs.norm(dim=-1)
    # For some reason the 0th token always has huge magnitudes, so normalize based on subsequent token max.
    magnitudes = magnitudes / magnitudes[:,:,1:].max()
    
    # All the input tokens.
    prompt_tokens = [tok.decode(t) for t in tok.encode(prefix)]

    # Foreground color shows token probability, and background color shows hs magnitude
    if color is None:
        color = [0, 0, 255]
    def color_fn(m, p):
        #a = [int(255 * (1-m) + c * m) for c in color]
        a = [int(255 * (1-p) + c * p) for c in color]
        b = [int(196 * (1-p) + 0 * p)] * 2 + [0]
        return show.style(background=f'rgb({a[0]}, {a[1]}, {a[2]})',
                          #color=f'rgb({b[0]}, {b[1]}, {b[2]})' )
                          color='black' if p < 0.75 else 'white' )

    # In the hover popup, show topk probabilities beyond the 0th.
    def hover(tok, prob, toks, m):
        lines = [f'mag: {m:.2f}']
        for p, t in zip(prob, toks):
            lines.append(f'{tok.decode(t)}: prob {p:.2f}')
        return show.attr(title='\n'.join(lines))
    
    # Construct the HTML output using show.
    header_line = [ # header line
             [[show.style(fontWeight='bold'), 'Layer']] +
             [
                 [show.style(background='yellow'), show.attr(title=f'Token {i}'), t]
                 for i, t in enumerate(prompt_tokens)
             ]
         ]
    layer_logits = [
             # first column
             [[show.style(fontWeight='bold'), layer]] +
             [
                 # subsequent columns
                 [color_fn(m, p[0]), hover(tok, p, t, m), show.style(overflowX='hide'), tok.decode(t[0])]
                 for m, p, t in zip(wordmags, wordprobs, words)
             ]
        for layer, wordmags, wordprobs, words in
                zip(range(len(magnitudes)), magnitudes[:, 0], favorite_probs[:, 0], favorite_tokens[:,0])]
    
    # If you want to get the html without showing it, use show.html(...)
    show(header_line + layer_logits + header_line)


An example.

In [None]:
# year_list = [str(y) for y in range(1800, 2023)]

f = make_corner_readout(
    model, tok,
    short_city_list, prenorm_scale=0.75)

show_logit_lens(model, tok, 'Brent Cross is a neighborhood in the city of', decoder=f)
#show_logit_lens(model, tok, 'Prudential Center is a mall in the city of', decoder=f)
show_logit_lens(model, tok, 'South of Houston Street is a neighborhood in the city of', decoder=f)
# show_logit_lens(model, tok, 'South of Houston Street is a neighborhood in the city of', decoder=f)
#show_logit_lens(model, tok, 'Eisenhower was born in')
#show_logit_lens(model, tok, 'Harrison Ford was born in', decoder=f)
#show_logit_lens(model, tok, 'Harrison Ford was born in')