In [2]:
import torch
import sys
sys.path.append('../..')
from transformers import GPTJForCausalLM, AutoTokenizer
import lre.models as models
import lre.functional as functional
import os

# Target device string used by every `.to(device)` call in this notebook.
device = "cuda:1"

# Module-level scratch lists (three distinct empty lists).
weights, biases, subjects = [], [], []

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Idea: average the weight/bias of each layer across animals, since that
# seems the most intuitive way to combine them.
# First attempt: s --> s for layers 5-26, then s --> o for 26-27.

animals = "dog duck fish horse mink seal shark trout".split()

# File-name prefixes for the saved subject-to-subject tensors.
weight_str = "s_s_weight_"
bias_str = "s_s_bias_"

def sample_weights_biases(subject, kind, i, samples) -> dict:
    """Load one subject's saved weight and bias tensors for layer ``i``.

    Reads ``{subject}/{kind}_weight_h_{i}.pt`` and
    ``{subject}/{kind}_bias_h_{i}.pt`` with ``torch.load`` and places both on
    the module-level ``device``.

    Args:
        subject: Directory that holds this subject's saved ``.pt`` files.
        kind: File-name prefix, e.g. ``"s_s"``, ``"s_o"`` or ``"o_o"``.
        i: Layer index; used in the file names and stored under key ``"i"``.
        samples: Unused. Kept so existing callers keep working.

    Returns:
        A dict with keys ``"i"``, ``f"{kind}_weight"`` and ``f"{kind}_bias"``.

    Raises:
        FileNotFoundError: If either file does not exist (raised by
            ``torch.load``).
    """
    weight_path = f"{subject}/{kind}_weight_h_{i}.pt"
    bias_path = f"{subject}/{kind}_bias_h_{i}.pt"
    # map_location sends the tensors straight to the target device instead of
    # first restoring them on whatever device they were saved from.
    weight = torch.load(weight_path, map_location=device)
    bias = torch.load(bias_path, map_location=device)
    return {"i": i, f"{kind}_weight": weight, f"{kind}_bias": bias}
    
def mean_weights_biases(kind, i, samples) -> dict:
    """Average the saved weight and bias tensors for layer ``i`` over samples.

    For every directory in ``samples`` this reads
    ``{sample}/{kind}_weight_h_{i}.pt`` and ``{sample}/{kind}_bias_h_{i}.pt``,
    stacks the loaded tensors and takes the element-wise mean over the sample
    axis. All tensors are placed on the module-level ``device``.

    Args:
        kind: File-name prefix, e.g. ``"s_s"``, ``"s_o"`` or ``"o_o"``.
        i: Layer index; used in the file names and stored under key ``"i"``.
        samples: Iterable of directories, one per sample to average over.

    Returns:
        A dict with keys ``"i"``, ``f"{kind}_weight"`` and ``f"{kind}_bias"``
        holding the averaged tensors.

    Raises:
        ValueError: If ``samples`` is empty (there is nothing to average).
        FileNotFoundError: If a file is missing (raised by ``torch.load``).
    """
    samples = list(samples)
    if not samples:
        # torch.stack([]) would otherwise fail with an opaque RuntimeError.
        raise ValueError("mean_weights_biases needs at least one sample directory")

    weights = []
    biases = []
    for sample in samples:
        # map_location loads straight onto the target device.
        weights.append(
            torch.load(f"{sample}/{kind}_weight_h_{i}.pt", map_location=device)
        )
        biases.append(
            torch.load(f"{sample}/{kind}_bias_h_{i}.pt", map_location=device)
        )

    # The inputs already live on `device`, so the means do too.
    return {
        "i": i,
        f"{kind}_weight": torch.stack(weights).mean(dim=0),
        f"{kind}_bias": torch.stack(biases).mean(dim=0),
    }


In [4]:
# Load the float16 revision of GPT-J 6B and its tokenizer.
model = GPTJForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6B",
    revision="float16",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
# Use the shared `device` constant so the model and every tensor moved with
# `.to(device)` elsewhere in the notebook always end up on the same device.
model.to(device)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
# Reuse the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token
mt = models.ModelAndTokenizer(model, tokenizer)

In [5]:
from baukit.baukit import parameter_names, get_parameter

#returns weight and bias for lns
def get_layer_norm_params(model, start, end):
    """Collect the ``ln_1`` weight and bias of transformer blocks ``start..end-1``.

    Each parameter is fetched by name with ``get_parameter`` and its ``.data``
    tensor is moved to the module-level ``device``.

    Args:
        model: Model passed through to ``get_parameter``.
        start: First block index (inclusive).
        end: Last block index (exclusive).

    Returns:
        A dict keyed by parameter name, ``transformer.h.{i}.ln_1.weight`` and
        ``transformer.h.{i}.ln_1.bias``, in layer order with weight before bias.
    """
    collected = {}
    for layer_idx in range(start, end):
        for suffix in ("weight", "bias"):
            name = f'transformer.h.{layer_idx}.ln_1.{suffix}'
            collected[name] = get_parameter(model=model, name=name).data.to(device)
    return collected

# `end` is exclusive, so it has to be one past the last layer we want:
# this collects ln_1 parameters for layers 5-27 (out of 0-27).
params = get_layer_norm_params(model, start=5, end=28)

In [6]:
#lm_head applies LayerNorm and then a linear map to get the token-space (50400)
# (1,4096) -layernorm, linear-> (1,50400) -softmax-> (1,50400) -topk-> (1,5)
def get_object(mt, z, k=5):
    """Decode the ``k`` most probable next tokens for hidden state ``z``.

    ``z`` is passed through ``mt.lm_head``, the logits are soft-maxed in
    float32, and the top ``k`` entries are decoded with ``mt.tokenizer``.

    Args:
        mt: Object exposing ``lm_head`` (callable on ``z``) and
            ``tokenizer.decode``.
        z: Hidden state for a single position. The top-k result must contain
            exactly ``k`` elements, i.e. one row.
        k: Number of candidates to return.

    Returns:
        ``(words, probs)``: two lists of length ``k``, ordered from most to
        least probable.
    """
    logits = mt.lm_head(z)
    # Softmax in float32 even if the logits are half precision.
    dist = torch.softmax(logits.float(), dim=-1)
    topk = dist.topk(k=k, dim=-1)
    # Flatten with `k`, not a hard-coded 5, so any value of `k` works.
    probs = topk.values.view(k).tolist()
    token_ids = topk.indices.view(k).tolist()
    words = [mt.tokenizer.decode(token_id) for token_id in token_ids]
    return (words, probs)

In [7]:
def layer_norm(
    x: torch.Tensor, dim, eps: float = 0.00001
) -> torch.Tensor:
    """Normalize ``x`` to zero mean and unit variance along ``dim``.

    Uses the biased variance (mean of squared deviations) and adds ``eps``
    under the square root. No learned scale or shift is applied here.
    """
    centered = x - torch.mean(x, dim=dim, keepdim=True)
    variance = torch.square(centered).mean(dim=dim, keepdim=True)
    return centered / torch.sqrt(variance + eps)

In [11]:
def _approx_layer(hs, i, kind):
    """Apply the saved linear approximation of layer ``i`` to ``hs``.

    Looks up the entry for layer ``i`` in the module-level ``layer_dicts``,
    applies layer normalization along dim 1 with that layer's ``ln_1`` scale
    and shift from the module-level ``params``, then computes
    ``hs @ weight.T + bias`` with the ``{kind}_weight`` / ``{kind}_bias``
    tensors.

    Args:
        hs: 2-D hidden state (``Tensor.mm`` requires a matrix).
        i: Layer index to approximate.
        kind: Which saved map to use: ``"s_s"``, ``"s_o"`` or ``"o_o"``.

    Returns:
        The transformed hidden state.

    Raises:
        ValueError: If ``layer_dicts`` has no entry for layer ``i``.
        KeyError: If the entry lacks the ``kind`` tensors, or ``params`` lacks
            the layer's ``ln_1`` parameters.
    """
    layer_dict = next((item for item in layer_dicts if item['i'] == i), None)
    if layer_dict is None:
        # Without this guard the lookup below fails with a confusing
        # "'NoneType' object is not subscriptable".
        raise ValueError(f"no entry for layer {i} in layer_dicts")

    layer_weight = layer_dict[f'{kind}_weight']
    layer_bias = layer_dict[f'{kind}_bias']
    ln_weight = params[f'transformer.h.{i}.ln_1.weight']
    ln_bias = params[f'transformer.h.{i}.ln_1.bias']

    # Layer normalization followed by the layer's learned scale and shift.
    hs = layer_norm(hs, 1) * ln_weight + ln_bias

    # The approximated layer itself. The residual connection is deliberately
    # not added back (it was commented out in the original code).
    return hs.mm(layer_weight.t()) + layer_bias


def approx_s_s_layer(hs, i):
    """Apply the saved subject-to-subject (``s_s``) map of layer ``i`` to ``hs``."""
    return _approx_layer(hs, i, 's_s')


def approx_s_o_layer(hs, i):
    """Apply the saved subject-to-object (``s_o``) map of layer ``i`` to ``hs``."""
    return _approx_layer(hs, i, 's_o')


def approx_o_o_layer(hs, i):
    """Apply the saved object-to-object (``o_o``) map of layer ``i`` to ``hs``."""
    return _approx_layer(hs, i, 'o_o')

In [12]:
#now we want to do h' = h.mm(weight.t()) * beta + bias for each layer 5-26 (s -> s')
#then, finally z = h.mm(weight.t()) * beta + bias (s' -> o)

def tp(state: torch.Tensor):
    """Return the first row of ``state`` as a NumPy array on the CPU."""
    as_numpy = state.detach().cpu().numpy()
    return as_numpy[0]

def approx_lm(hs):
    """Run ``hs`` through the chain of approximated layers.

    Layers are applied in three consecutive stages, using the module-level
    ``START_LAYER``, ``S_O_LAYER`` and ``END_LAYER``:

    * ``approx_s_s_layer`` for layers ``START_LAYER .. S_O_LAYER - 1``
    * ``approx_s_o_layer`` for layer ``S_O_LAYER``
    * ``approx_o_o_layer`` for layers ``S_O_LAYER + 1 .. END_LAYER - 1``

    Returns the hidden state produced by the last applied layer.
    """
    stages = (
        (approx_s_s_layer, range(START_LAYER, S_O_LAYER)),
        (approx_s_o_layer, range(S_O_LAYER, S_O_LAYER + 1)),
        (approx_o_o_layer, range(S_O_LAYER + 1, END_LAYER)),
    )
    state = hs
    for apply_layer, layer_indices in stages:
        for layer_idx in layer_indices:
            state = apply_layer(state, layer_idx)
    return state

In [13]:
#start with loading h @ 5 (@ subject index)
animals = ["dog", "duck", "fish", "horse", "mink", "seal", "shark", "trout"]

# GPT-J blocks are indexed 0-27.
START_LAYER = 5
S_O_LAYER = 26  # TODO: generate s_o_weight_27
END_LAYER = 27

# Author's note: the LM predictions can be reconstructed perfectly from the
# Jacobians of the particular animal, so each animal uses its own saved maps.
for animal in animals:
    # (kind, layers) for each stage, in the order approx_lm applies them:
    # S --> S', then S' --> O, then O --> O'.
    stages = (
        ("s_s", range(START_LAYER, S_O_LAYER)),
        ("s_o", range(S_O_LAYER, S_O_LAYER + 1)),
        ("o_o", range(S_O_LAYER + 1, END_LAYER)),
    )

    # The approx_*_layer functions read this module-level list, so it is
    # rebuilt from scratch for every animal.
    layer_dicts = []
    for kind, layer_range in stages:
        for i in layer_range:
            # To use the cross-animal average instead, swap in:
            #   mean_weights_biases(kind, i, animals)
            layer_dict = sample_weights_biases(animal, kind, i, animals)
            layer_dicts.append(layer_dict)

    # Saved hidden state at START_LAYER, with a leading axis added by [None].
    animal_hs = torch.load(f'{animal}/hs_h_{START_LAYER}.pt').to(device)[None]
    object_hs = approx_lm(animal_hs)
    # Debug aid: print(torch.linalg.norm(object_hs, dim=1, ord=2).cpu().detach().numpy())
    print(f'{animal}: {get_object(mt, object_hs)[0]}')

dog: [' puppy', ' pup', ' p', ' whe', ' dog']
duck: [' duck', ' g', ' dra', ' chick', ' hatch']
fish: [' fry', ' finger', ' lar', ' baby', ' young']
horse: [' fo', ' col', ' fill', ' pony', ' baby']
mink: [' kit', ' p', ' pup', ' kits', ' cub']
seal: [' pup', ' p', ' puppy', ' baby', ' seal']
shark: [' pup', ' baby', ' shark', ' young', ' kit']
trout: [' fry', ' finger', ' lar', ' baby', ' trout']


In [None]:
#get tokens in GPT-J
#get the hidden state of them at the last layer (after the 28th layer, or s->o @ 27)
import pickle
from tqdm import tqdm

def get_hidden_state(mt, subject, h_layer, h=None, k=5):
    """Return the hidden state of ``subject`` at layer ``h_layer``.

    The prompt is ``subject`` preceded by a single space. The subject's token
    index and the model inputs come from
    ``functional.find_subject_token_index``; the hidden states for
    ``h_layer`` come from ``functional.compute_hidden_states``. The result is
    the slice ``hs[:, h_index]`` moved to the module-level ``device``.

    Args:
        mt: Model-and-tokenizer object forwarded to the ``functional`` helpers.
        subject: Text whose hidden state is wanted.
        h_layer: Layer to read the hidden state from.
        h: Unused; any value passed in is ignored.
        k: Unused.
    """
    h_index, inputs = functional.find_subject_token_index(
        mt=mt, prompt=f" {subject}", subject=subject)
    # compute_hidden_states must return a pair whose first item holds exactly
    # one entry, because a single layer was requested.
    per_layer_states, _ = functional.compute_hidden_states(
        mt=mt, layers=[h_layer], inputs=inputs)
    (hs,) = per_layer_states
    # Keep only the subject position, then move it to the target device.
    return hs[:, h_index].to(device)
    
#Spaces are converted into a special character (Ġ) by the tokenizer prior to BPE splitting,
#mostly to avoid digesting spaces, since the standard BPE algorithm uses spaces in its process.

#all animal encodings are at [-0.4153   2.023   -2.23    ... -0.785    0.06323 -0.1819 ]

text = "our classic pre-baked blueberry pie filled with delicious plump and juicy wild blueberries"
encoded_input = mt.tokenizer(text, return_tensors="pt")

# Every token id from 0 to 50399, turned back into its token string with the
# tokenizer's "Ġ" space marker replaced by a real space.
token_ids = range(50400)
tokens = [
    token.replace("Ġ", " ")
    for token in tokenizer.convert_ids_to_tokens(token_ids)
]

# Author's note: this is too slow and not useful.
# Maps each token string to its hidden state at layer 27.
dict27 = {}
for token in tqdm(tokens):
    dict27[token] = get_hidden_state(mt, token, 27)

with open('animal_youth_27.pkl', 'wb') as file:
    pickle.dump(dict27, file)

In [211]:
# Display the LM head's module structure; the repr below shows a LayerNorm
# over 4096 features followed by a Linear map from 4096 to 50400.
mt.lm_head

Sequential(
  (0): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
  (1): Linear(in_features=4096, out_features=50400, bias=True)
)