import os
import sys

import torch
from transformers import AutoTokenizer, GPTJForCausalLM

# Make the `lre` package importable from '../..' (relative to the working directory).
sys.path.append('../..')
import lre.functional as functional
import lre.models as models

# Device shared by the model and every tensor loaded in this notebook.
device = "cuda:1"
# NOTE(review): these module-level lists are not read by the visible helpers
# (they use locals of the same names); kept in case other cells rely on them.
weights = []
biases = []
subjects = []
model = GPTJForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6B",
    revision="float16",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
# Use the shared `device` constant rather than repeating the literal, so the
# model and the loaded LRE tensors cannot drift onto different GPUs.
model.to(device)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
# Reuse EOS as the padding token (presumably because none is defined — confirm).
tokenizer.pad_token = tokenizer.eos_token
mt = models.ModelAndTokenizer(model, tokenizer)

In [2]:
# Averaging the weights/biases of each layer seems the most intuitive start.
# First attempt: s -> s for layers 5-26, then s -> o for layers 26-27.

animals = [
    "dog", "duck", "fish", "horse",
    "mink", "seal", "shark", "trout",
]

# NOTE(review): not referenced by the loaders below, which build
# "{kind}_weight_h_{i}.pt" / "{kind}_bias_h_{i}.pt" paths themselves.
weight_str = 's_s_weight_'
bias_str = 's_s_bias_'

def _load_weight_bias(wdir, kind, i):
    """Load the ``kind`` weight and bias tensors of layer ``i`` stored under ``wdir``.

    Files are ``{wdir}/{kind}_weight_h_{i}.pt`` and ``{wdir}/{kind}_bias_h_{i}.pt``.
    ``map_location`` places them directly on the global ``device`` instead of
    first materialising them on whatever device they were saved from.
    """
    weight = torch.load(f"{wdir}/{kind}_weight_h_{i}.pt", map_location=device)
    bias = torch.load(f"{wdir}/{kind}_bias_h_{i}.pt", map_location=device)
    return weight, bias


def sample_weights_biases(subject, kind, i, samples) -> dict:
    """Return the ``kind`` weight/bias of layer ``i`` for a single ``subject`` directory.

    Args:
        subject: directory holding the per-subject tensors.
        kind: tensor-name prefix, e.g. ``"s_s"``, ``"s_o"`` or ``"o_o"``.
        i: layer index.
        samples: unused; accepted so the signature mirrors ``mean_weights_biases``.

    Returns:
        ``{"i": i, f"{kind}_weight": weight, f"{kind}_bias": bias}``.
    """
    weight, bias = _load_weight_bias(subject, kind, i)
    return {"i": i, f'{kind}_weight': weight, f'{kind}_bias': bias}


def mean_weights_biases(kind, i, samples) -> dict:
    """Average and sum the ``kind`` weight/bias of layer ``i`` across ``samples``.

    Args:
        kind: tensor-name prefix, e.g. ``"s_s"``, ``"s_o"`` or ``"o_o"``.
        i: layer index.
        samples: iterable of directories, one per subject, each holding the tensors.

    Returns:
        dict with keys ``"i"``, ``f"{kind}_mean_weight"``, ``f"{kind}_mean_bias"``,
        ``f"{kind}_sum_weight"`` and ``f"{kind}_sum_bias"``.

    Raises:
        ValueError: if ``samples`` is empty.
    """
    loaded = [_load_weight_bias(sample, kind, i) for sample in samples]
    if not loaded:
        raise ValueError(f"no samples given for {kind} layer {i}")
    # Stack each list once and reuse it for both reductions; the weights are
    # large matrices, so building the stack twice doubled peak memory.
    stacked_weights = torch.stack([weight for weight, _ in loaded])
    stacked_biases = torch.stack([bias for _, bias in loaded])
    return {
        "i": i,
        f'{kind}_mean_weight': stacked_weights.mean(dim=0),
        f'{kind}_mean_bias': stacked_biases.mean(dim=0),
        f'{kind}_sum_weight': stacked_weights.sum(dim=0),
        f'{kind}_sum_bias': stacked_biases.sum(dim=0),
    }

In [3]:
from baukit.baukit import parameter_names, get_parameter

# Collect the ln_1 LayerNorm parameters of a range of transformer blocks.
def get_layer_norm_params(model, start, end):
    """Return the ``ln_1`` weight and bias of blocks ``start`` .. ``end - 1``.

    Keys are the full parameter names ``'transformer.h.{i}.ln_1.weight'`` and
    ``'transformer.h.{i}.ln_1.bias'``; values are the parameters' ``.data``
    (detached from autograd) moved to the global ``device``.
    """
    layer_norm_params = {}
    for i in range(start, end):
        for suffix in ("weight", "bias"):
            name = f'transformer.h.{i}.ln_1.{suffix}'
            # A single .to(device) suffices; the result already lives there.
            layer_norm_params[name] = get_parameter(model=model, name=name).data.to(device)
    return layer_norm_params

# ln_1 parameters for blocks 4 through 27 (the end bound is exclusive).
# Author's note: "we should add 1 to the layer ct." — target range is layers 5-27 of 0-27;
# block 4 is presumably included because the o->o approximation reads block i - 1.
params = get_layer_norm_params(model, start=4, end=28)

In [4]:
# mt.lm_head maps a hidden state to vocabulary logits; for GPT-J it is a
# LayerNorm followed by a Linear layer: (1, 4096) -> (1, 50400).
def get_object(mt, z, k=5):
    """Decode the top-``k`` next-token predictions for a single hidden state ``z``.

    Logits are upcast to float32 before the softmax.

    Args:
        mt: object exposing ``lm_head`` and ``tokenizer`` (with ``decode``).
        z: one hidden state (a single row).
        k: number of predictions to return.

    Returns:
        ``(words, probs)``: two lists of length ``k``, most probable first.
    """
    logits = mt.lm_head(z)
    dist = torch.softmax(logits.float(), dim=-1)
    topk = dist.topk(k=k, dim=-1)
    # view(k) instead of the former hard-coded view(5), which crashed for any
    # k != 5. A batch of several rows still cannot be viewed as k elements,
    # so passing more than one hidden state keeps raising.
    probs = topk.values.view(k).tolist()
    token_ids = topk.indices.view(k).tolist()
    words = [mt.tokenizer.decode(token_id) for token_id in token_ids]
    return (words, probs)

In [5]:
def layer_norm(
    x: torch.Tensor, dim, eps: float = 0.00001
) -> torch.Tensor:
    """Normalise ``x`` along ``dim`` to zero mean and unit (biased) variance.

    No affine weight or bias is applied; ``eps`` is added to the variance
    before the square root for numerical stability.
    """
    centered = x - x.mean(dim=dim, keepdim=True)
    variance = centered.square().mean(dim=dim, keepdim=True)
    return centered / (variance + eps).sqrt()

In [62]:
def get_layer_dict(i):
    """Return the entry of the global ``layer_dicts`` whose ``'i'`` equals ``i``.

    Linear scan; returns None when no entry matches.
    """
    return next((entry for entry in layer_dicts if entry['i'] == i), None)


def _ln_then_project(hs, i, kind, ln_layer, add_bias):
    """Apply ``ln_1`` of block ``ln_layer``, then the averaged ``kind`` map of layer ``i``.

    Shared core of the approx_*_layer functions. They differ only in the map
    kind, which block's LayerNorm parameters they use, whether the averaged
    bias is added, and how the residual is mixed back in. ``hs`` must be 2-D
    because the projection uses ``.mm``.

    Raises:
        KeyError: if ``layer_dicts`` has no entry for layer ``i``, or ``params``
            lacks the ln_1 parameters of block ``ln_layer``.
    """
    layer_dict = get_layer_dict(i)
    if layer_dict is None:
        # Clear message instead of "'NoneType' object is not subscriptable".
        raise KeyError(f"layer_dicts has no entry for layer {i}")
    ln_weight = params[f'transformer.h.{ln_layer}.ln_1.weight']
    ln_bias = params[f'transformer.h.{ln_layer}.ln_1.bias']
    # layer_norm only normalises; apply ln_1's affine weight and bias here.
    normed = layer_norm(hs, 1) * ln_weight + ln_bias
    projected = normed.mm(layer_dict[f'{kind}_mean_weight'].t())
    if add_bias:
        projected = projected + layer_dict[f'{kind}_mean_bias']
    return projected


def approx_s_s_layer(hs, i):
    """Approximate subject->subject layer ``i``.

    Applies ln_1 of block ``i`` and the averaged ``s_s`` map. In this
    experiment neither the averaged bias nor a residual connection is added.
    """
    return _ln_then_project(hs, i, 's_s', ln_layer=i, add_bias=False)


def approx_s_o_layer(hs, i, residual_scale=3.25):
    """Approximate subject->object layer ``i``.

    Applies ln_1 of block ``i`` and the averaged ``s_o`` map plus its bias,
    then adds the input scaled by ``residual_scale`` (default 3.25).
    """
    return _ln_then_project(hs, i, 's_o', ln_layer=i, add_bias=True) + hs * residual_scale


def approx_o_o_layer(hs, i):
    """Approximate object->object layer ``i``.

    Applies ln_1 of block ``i - 1`` and the averaged ``o_o`` map plus its bias,
    then adds the unscaled residual.
    """
    # NOTE(review): uses the LayerNorm of the *previous* block (i - 1), unlike the
    # other two approximations; confirm this offset is intended.
    return _ln_then_project(hs, i, 'o_o', ln_layer=i - 1, add_bias=True) + hs

In [63]:
# Plan: h' = h.mm(weight.t()) * beta + bias for each layer 5-26 (s -> s'),
# then finally z = h.mm(weight.t()) * beta + bias (s' -> o).
# NOTE(review): the approx_*_layer functions in this notebook apply no beta factor.

def tp(state: torch.Tensor):
    """Return row 0 of ``state`` as a NumPy array, after detaching it and moving it to the CPU."""
    cpu_state = state.detach().cpu()
    return cpu_state.numpy()[0]

def approx_lm(hs):
    """Push hidden state ``hs`` through the chain of averaged-LRE layer approximations.

    Layers START_LAYER .. S_O_LAYER-1 use the subject->subject map, layer
    S_O_LAYER uses the subject->object map, and layers S_O_LAYER+1 ..
    END_LAYER-1 (if any) use the object->object map. The layer-range
    constants are module globals read at call time.
    """
    for layer in range(START_LAYER, S_O_LAYER):
        hs = approx_s_s_layer(hs, layer)

    # Exactly one subject->object layer.
    hs = approx_s_o_layer(hs, S_O_LAYER)

    for layer in range(S_O_LAYER + 1, END_LAYER):
        hs = approx_o_o_layer(hs, layer)

    return hs

In [64]:
# CREATE ResLRE
# Average each layer's weights/biases over the animals, then decode every subject.

animals = ["dog", "duck", "fish", "horse", "mink", "seal", "shark", "trout"]
# Transformer layers are numbered 0-27.
S_O_LAYER = 26
START_LAYER = 5
# S_O_LAYER = 7  # TODO: generate s_o_weight_27
END_LAYER = S_O_LAYER + 1  # leaves the o->o range below empty

# Consistent with previous results: from the early layers s' seems to represent o.
layer_dicts = []
# Stage order matters: s->s, then the single s->o layer, then o->o.
for kind, layer_range in (
    ("s_s", range(START_LAYER, S_O_LAYER)),
    ("s_o", range(S_O_LAYER, S_O_LAYER + 1)),
    ("o_o", range(S_O_LAYER + 1, END_LAYER)),
):
    for i in layer_range:
        # Per-subject alternative: sample_weights_biases(animal, kind, i, animals)
        layer_dicts.append(mean_weights_biases(kind, i, animals))

# NOTE(review): `pairs` and `get_hidden_state` are defined in later cells of this
# notebook; run those first when executing top to bottom.
for subj, obj in pairs:
    hs = get_hidden_state(mt, subj, 5)  # hidden state at layer 5
    object_hs = approx_lm(hs)
    print(f'{subj}: {get_object(mt, object_hs)[0]}')

ape: [' jo', ' kid', ' wee', ' c', ' fo']
badger: [' cub', ' pup', ' kit', ' p', ' kid']
bear: [' cub', ' pup', ' c', ' t', ' paw']
beaver: [' pup', ' p', ' kit', ' cub', ' kid']
bee: [' pup', ' lar', ' c', ' p', ' nest']
beetle: [' lar', ' pup', ' p', ' gr', ' larvae']
buffalo: [' calf', ' kid', ' herd', ' c', ' fo']
butterfly: [' lar', ' wing', ' chick', ' flying', ' p']
camel: [' fo', ' calf', ' kid', ' toe', ' fetus']
cat: [' p', ' kitten', ' litter', ' c', ' jo']
cattle: [' calf', ' fo', ' jo', ' herd', ' tick']
chimpanzee: [' jo', ' fo', ' c', ' baby', ' p']
cicada: [' prince', ' bite', ' p', ' lar', ' baby']
cockroach: [' p', ' lar', ' z', ' baby', ' tad']
cricket: [' lar', ' hopping', ' frog', 'ing', ' tad']
deer: [' fo', ' calf', ' kid', ' f', ' herd']
dog: [' puppy', ' k', ' litter', ' whe', ' p']
duck: ['ling', ' kid', ' fetus', ' fo', 'lings']
elephant: [' calf', ' fo', ' seal', ' t', ' baby']
ferret: [' kit', ' p', ' jo', ' litter', ' pup']
fish: [' fry', ' c', ' n', ' flo

### Compare ResLRE to LRE

In [23]:
def get_hidden_state(mt, subject, h_layer):
    """Return the layer-``h_layer`` hidden state for ``subject`` in an offspring prompt.

    The prompt is "The offspring of a {subject} is referred to as a". The
    token position is the index returned by
    ``functional.find_subject_token_index``, and the result is moved to the
    global ``device``.
    """
    prompt = f"The offspring of a {subject} is referred to as a"
    subject_index, inputs = functional.find_subject_token_index(
        mt=mt, prompt=prompt, subject=subject
    )
    [[layer_states], _] = functional.compute_hidden_states(
        mt=mt, layers=[h_layer], inputs=inputs
    )
    return layer_states[:, subject_index].to(device)

In [24]:
import json
json_path = 'animal-youth.json'
# JSON text is UTF-8 (RFC 8259); pass the encoding explicitly rather than
# relying on the platform default, which differs on some systems.
with open(json_path, 'r', encoding='utf-8') as file:
    data = json.load(file)
# (subject, object) pairs taken from the file's 'samples' list.
pairs = [(sample['subject'], sample['object']) for sample in data['samples']]

In [29]:
# Prefix matching used to score predictions, for most relations.
def is_nontrivial_prefix(prediction: str, target: str) -> bool:
    """Return True if ``prediction`` is a prefix of ``target`` at least two characters long.

    Both strings are lower-cased and stripped first, so a space-prefixed
    token such as " Pup" matches "puppy". Single-character predictions are
    rejected as trivial.
    """
    target = target.lower().strip()
    prediction = prediction.lower().strip()
    return len(prediction) > 1 and target.startswith(prediction)


def any_is_nontrivial_prefix(prediction, targets) -> bool:
    """Return True if ``prediction`` is a non-trivial prefix of any of ``targets``.

    ``targets`` may be an iterable of strings or a single string. A bare string
    is treated as one target. Iterating it directly would test each character
    separately, and a single character can never start with a prefix of two or
    more characters, so that call would always return False.
    """
    if isinstance(targets, str):
        targets = (targets,)
    return any(is_nontrivial_prefix(prediction, target) for target in targets)

In [35]:
# For each subject, sweep beta from 1.0 to 4.9 in steps of 0.1 and report the
# first beta whose top prediction is a non-trivial prefix of a target object.
for (subj, obj) in pairs:
    # The layer-5 hidden state does not depend on beta: compute it once per
    # subject (one forward pass) instead of once per beta (40 forward passes).
    hs = get_hidden_state(mt, subj, 5)
    for step in range(10, 50):
        beta = step / 10
        # NOTE(review): `approx_lm` as defined in this notebook takes only `hs`,
        # so this two-argument call raises TypeError. The recorded output below
        # presumably came from an earlier signature that accepted beta.
        object_hs = approx_lm(hs, beta)
        pred = get_object(mt, object_hs)[0]
        if any_is_nontrivial_prefix(pred[0], obj):
            print(f"{subj} matches {pred[0]}: {beta}")
            break

# Fixed-beta variant:
# for (subj, obj) in pairs:
#     hs = get_hidden_state(mt, subj, 5)
#     object_hs = approx_lm(hs, 2.4)
#     print(f'{subj}: {get_object(mt, object_hs)[0]} {obj}')

ape matches  baby: 1.0
bear matches  cub: 1.2
bee matches  lar: 3.7
beetle matches  lar: 1.9
buffalo matches  calf: 1.2
butterfly matches  pup: 1.0
cattle matches  calf: 1.6
chimpanzee matches  baby: 1.0
cricket matches  lar: 1.7
dog matches  pup: 1.0
elephant matches  calf: 1.4
ferret matches  kit: 2.9
fish matches  fry: 1.2
fly matches  mag: 3.8
fox matches  pup: 1.0
goat matches  kid: 1.0
goldfish matches  fry: 1.0
horse matches  fo: 1.0
insect matches  lar: 1.7
lion matches  cub: 1.0
mink matches  kit: 1.4
panda matches  cub: 1.7
raccoon matches  kit: 1.7
seal matches  pup: 1.0
shark matches  pup: 1.0
tiger matches  cub: 1.0
trout matches  finger: 4.9
whale matches  calf: 1.8
wolf matches  pup: 1.0


In [None]:
#get tokens in GPT-J
#get the hidden state of them at the last layer (after the 28th layer, or s->o @ 27)
import pickle
from tqdm import tqdm

def get_hidden_state(mt, subject, h_layer, h=None, k=5):
    """Return the layer-``h_layer`` hidden state for ``subject`` tokenized on its own.

    The prompt is the subject preceded by a single space. The token position
    is the index returned by ``functional.find_subject_token_index``, and the
    result is moved to the global ``device``. ``h`` and ``k`` are accepted but
    unused; any ``h`` passed in is ignored.
    """
    prompt = f" {subject}"
    subject_index, inputs = functional.find_subject_token_index(
        mt=mt, prompt=prompt, subject=subject
    )
    [[layer_states], _] = functional.compute_hidden_states(
        mt=mt, layers=[h_layer], inputs=inputs
    )
    return layer_states[:, subject_index].to(device)
    
# Spaces are converted into a special character (Ġ) by the tokenizer before BPE
# splitting, mostly so the standard BPE algorithm does not have to digest spaces.

# All animal encodings are at [-0.4153   2.023   -2.23    ... -0.785    0.06323 -0.1819 ]

text = "our classic pre-baked blueberry pie filled with delicious plump and juicy wild blueberries"
encoded_input = mt.tokenizer(text, return_tensors="pt")
# Every id the tokenizer knows: len(tokenizer) counts added tokens too, instead of
# hard-coding the vocabulary size. Ids outside the tokenizer would decode to None.
token_ids = range(len(tokenizer))
# Turn the byte-level "Ġ" marker back into a real leading space.
tokens = [token.replace("Ġ", " ") for token in tokenizer.convert_ids_to_tokens(token_ids)]

# NOTE: this is too slow and not useful (one hidden-state computation per token).
dict27 = {}
for token in tqdm(tokens):
    dict27[token] = get_hidden_state(mt, token, 27)

with open('animal_youth_27.pkl', 'wb') as file:
    pickle.dump(dict27, file)

In [211]:
# Display the language-model head used by get_object. Per the cell output it is
# a LayerNorm followed by a Linear layer mapping 4096 features to 50400 logits.
mt.lm_head

Sequential(
  (0): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
  (1): Linear(in_features=4096, out_features=50400, bias=True)
)