In [1]:
import torch
import sys
sys.path.append('../..')
from transformers import GPTJForCausalLM, AutoTokenizer
import lre.models as models
import lre.functional as functional
import os

import json
import random
from lre.data import Relation, RelationSample, Sequence
import lre.metrics as metrics
import lre.functional as functional

device = "cuda:1"
weights = []
biases = []
subjects = []
model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", torch_dtype=torch.float16, low_cpu_mem_usage=True)
model.to('cuda:1')
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer.pad_token = tokenizer.eos_token
mt = models.ModelAndTokenizer(model,tokenizer)

  from .autonotebook import tqdm as notebook_tqdm


In [20]:
#Let's still try averaging the w/b for each layer, because that seems the most intuitive.

def sample_weights_biases(subject, kind, i, samples) -> dict:
    layer_dict = {"i": i}
    weights = []
    biases = []
    wdir = f'qapprox/{subject}/{i}'
    weight_path = f"{wdir}/{kind}_weight_h_{i}.pt"
    bias_path = f"{wdir}/{kind}_bias_h_{i}.pt"
    #load s_s_weight and s_s_bias
    weight = torch.load(weight_path).to(device)
    bias = torch.load(bias_path).to(device)
    layer_dict[f'{kind}_weight'] = weight
    layer_dict[f'{kind}_bias'] = bias
    return layer_dict
    
def mean_weights_biases(kind, i, samples) -> dict:
    layer_dict = {"i": i}
    weights = []
    biases = []
    for sample in samples:
        wdir = f'qapprox/{sample}/{i}'
        weight_path = f"{wdir}/{kind}_weight_h_{i}.pt"
        bias_path = f"{wdir}/{kind}_bias_h_{i}.pt"
        #load s_s_weight and s_s_bias
        weight = torch.load(weight_path).to(device)
        bias = torch.load(bias_path).to(device)
        #print(f'weight is {tp(weight)}')
        #append to lists
        weights.append(weight)
        biases.append(bias)
    mean_weight = torch.stack(weights).mean(dim=0).to(device)
    mean_bias = torch.stack(biases).mean(dim=0).to(device)
    sum_weight = torch.stack(weights).sum(dim=0).to(device)
    sum_bias = torch.stack(biases).sum(dim=0).to(device)
    
    layer_dict[f'{kind}_mean_weight'] = mean_weight
    layer_dict[f'{kind}_mean_bias'] = mean_bias
    layer_dict[f'{kind}_sum_weight'] = sum_weight
    layer_dict[f'{kind}_sum_bias'] = sum_bias
    
    return layer_dict

In [21]:
from baukit.baukit import parameter_names, get_parameter

#returns weight and bias for lns
def get_layer_norm_params(model, start, end):
    layer_norm_params = {}
    for i in range(start, end):
        w_name = f'transformer.h.{i}.ln_1.weight'
        b_name = f'transformer.h.{i}.ln_1.bias'
        weight = get_parameter(model=model,name=w_name).data.to(device)
        bias = get_parameter(model=model,name=b_name).data.to(device)
        layer_norm_params[w_name] = weight.to(device)
        layer_norm_params[b_name] = bias.to(device)
        
    ln_f_w_name = 'transformer.ln_f.weight'
    ln_f_b_name = 'transformer.ln_f.bias'
    weight = get_parameter(model=model,name=ln_f_w_name).data.to(device)
    bias = get_parameter(model=model,name=ln_f_b_name).data.to(device)
    layer_norm_params[ln_f_w_name] = weight.to(device)
    layer_norm_params[ln_f_b_name] = bias.to(device)
    return layer_norm_params

#we should add 1 to the layer ct.
#layers 5-27 out of 0-27
params = get_layer_norm_params(model,4,28)

In [22]:
#mt.lm_head applies a linear map to get the token-space (50400)
# (1,4096) -layernorm, linear-> (1,50400) -softmax-> (1,50400) -topk-> (1,5)
def get_object(mt, z, k=5):
    logits = mt.lm_head(z)
    dist = torch.softmax(logits.float(), dim=-1)
    topk = dist.topk(k=k, dim=-1)
    probs = topk.values.view(5).tolist()
    token_ids = topk.indices.view(5).tolist()
    words = [mt.tokenizer.decode(token_id) for token_id in token_ids]
    return (words, probs)

In [184]:
def layer_norm(
    x: torch.Tensor, dim, eps: float = 0.00001
) -> torch.Tensor:
    mean = torch.mean(x, dim=dim, keepdim=True)
    var = torch.square(x - mean).mean(dim=dim, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)

In [378]:
S_O_LAYER = 15
START_LAYER = 5
END_LAYER = 27
beta = 2.75

In [404]:
def get_layer_dict(i):
    return next((item for item in layer_dicts if item['i'] == i), None)

def approx_s_s_layer(hs, i, beta=1):
    layer_dict = get_layer_dict(i)
    layer_weight = layer_dict['s_s_mean_weight']
    layer_bias = layer_dict['s_s_mean_bias']
    ln_weight = params[f'transformer.h.{i}.ln_1.weight']
    ln_bias = params[f'transformer.h.{i}.ln_1.bias']
    _hs = hs

    hs = layer_norm(hs, (1)) * ln_weight + ln_bias
    hs = beta * hs.mm(layer_weight.t()) #+ layer_bias
    #hs = hs + _hs
    
    return hs

def approx_s_o_layer(hs, i, beta=1):
    layer_dict = get_layer_dict(i)
    layer_weight = layer_dict['s_o_mean_weight']
    layer_bias = layer_dict['s_o_mean_bias']  
    ln_weight = params[f'transformer.h.{i}.ln_1.weight']
    ln_bias = params[f'transformer.h.{i}.ln_1.bias']
    _hs = hs
    hs = layer_norm(hs, (1)) * ln_weight + ln_bias
    #this transformation should encompass the work of the MHSA and MLP layer.
    hs = beta * hs.mm(layer_weight.t()) + layer_bias
    hs = hs + _hs
    
    return hs
    
def approx_o_o_layer(hs, i, beta=1):
    layer_dict = get_layer_dict(i)
    layer_weight = layer_dict['o_o_mean_weight']
    layer_bias = layer_dict['o_o_mean_bias']  
    ln_weight = params[f'transformer.h.{i}.ln_1.weight']
    ln_bias = params[f'transformer.h.{i}.ln_1.bias']
    _hs = hs
    
    hs = layer_norm(hs, (1)) * ln_weight + ln_bias
    hs = beta + hs.mm(layer_weight.t()) #+ layer_bias
    #hs = hs + _hs
    return hs

In [405]:
#now we want to do h' = h.mm(weight.t()) * beta + bias for each layer 5-26 (s -> s')
#then, finally z = h.mm(weight.t()) * beta + bias (s' -> o)

def tp(state: torch.Tensor):
    return state.cpu().detach().numpy()[0]

def approx_lm(hs, beta, beta_layer):
    
    #apply s_s
    for i in range(START_LAYER, S_O_LAYER):
        if i == beta_layer:
            hs = approx_s_s_layer(hs, i, beta)
        else:
            hs = approx_s_s_layer(hs,i)
        
    #apply s_o
    for i in range(S_O_LAYER, S_O_LAYER +1):
        if i == beta_layer:
            hs = approx_s_o_layer(hs, i, beta)
        else:
            hs = approx_s_o_layer(hs, i)
        
    #apply o_o
    for i in range(S_O_LAYER + 1, END_LAYER):
        if i == beta_layer:
            hs = approx_o_o_layer(hs, i, beta)
        else:
            hs = approx_o_o_layer(hs, i)
    
    ln_weight = params['transformer.ln_f.weight']
    ln_bias = params['transformer.ln_f.bias']
    #hs = layer_norm(hs, (1)) * ln_weight + ln_bias
        
    return hs

In [406]:
def get_hidden_state(mt, prompt, subject, h_layer):
    prompt = prompt.format(subject)
    h_index, inputs = functional.find_subject_token_index(
        mt = mt, prompt=prompt, subject=subject)
    #print(f'h_index is {h_index}, inputs is {inputs}')
    [[hs], _] = functional.compute_hidden_states(
        mt = mt, layers = [h_layer], inputs = inputs)
    #h is hs @ h_layer @ h_index
    h = hs[:, h_index]
    h = h.to(device)
    return h

In [407]:
# CREATE ResLRE
samples = ['ahead', 'backward', 'down', 'inbound', 'input', 'mortal', 'off', 'top']
#layers: 0-27
#Consistent with previous results, from the early layers s' seems to represent o.
layer_dicts = []
### S --> S'
for i in range(START_LAYER, S_O_LAYER):
    #layer_dict = sample_weights_biases(animal,"s_s", i, animals)
    layer_dict = mean_weights_biases("s_s", i, samples)
    layer_dicts.append(layer_dict)

#### S' --> O
for i in range(S_O_LAYER,S_O_LAYER+1):
    #layer_dict = sample_weights_biases(animal,"s_o", i, animals)
    layer_dict = mean_weights_biases("s_o", i, samples)
    layer_dicts.append(layer_dict)

### O --> O'
for i in range(S_O_LAYER+1, END_LAYER):
    #layer_dict = sample_weights_biases(animal,"o_o", i, animals)
    layer_dict = mean_weights_biases("o_o", i, samples)
    layer_dicts.append(layer_dict)

In [408]:
import lre.functional as functional
from importlib import reload
reload(functional)
import lre.functional as functional

In [None]:
#test the approximator
json_path = 'qapprox/antonyms-binary.json'

DEFAULT_N_ICL = 8 
N_TRIALS = 8

with open(json_path, 'r') as file:
    data = json.load(file)
    relation = Relation.from_dict(data)
    prompt = "The opposite of {} is"
    
    #counts_by_lre_correct: dict[bool, int] = defaultdict(int)
    prompt_template = relation.prompt_templates[0]
    #TEST LRE ON LM CORRECT
    beta = 2.75
    for beta_layer in range(5,27):
        print(f'{beta_layer=}')
        for _ in range(0,N_TRIALS):
            #RELATION SAMPLES
            clozed_prompts = []
            clozed_answers = []
            for x in relation.samples:
                samples = [x] + random.sample(relation.samples, DEFAULT_N_ICL - 1)
                #print(f'{samples} samples)')
                cloze_prompt = functional.make_prompt(
                    template = prompt_template, 
                    target = x,
                    examples = samples
                    )
                clozed_prompts.append(cloze_prompt)
                clozed_answers.append(x.object)
            #LM PREDICTION
            outputs_lm = functional.predict_next_token(mt=mt, prompt=clozed_prompts)
            preds_lm =  [[x.token for x in xs] for xs in outputs_lm]
            recall_lm = metrics.recall(preds_lm, clozed_answers)
            #print(recall_lm)
            lre_correct = 0
            lm_correct = 0
            for _, sample, objs, prompt, preds in zip(range(50), relation.samples, clozed_answers, clozed_prompts, preds_lm):
                if (metrics.any_is_nontrivial_prefix(
                    predictions=preds, 
                    targets=objs)):
                    #print(f'{sample.subject}', end='')
                    hs = get_hidden_state(mt, prompt, sample.subject, 1) #layer 5
                    object_hs = approx_lm(hs, beta, beta_layer)
                    lre_preds = get_object(mt, object_hs)[0]
                    #print(lre_preds, end='')
                    if(metrics.any_is_nontrivial_prefix(predictions=lre_preds, targets=objs)):
                        lre_correct += 1
                    lm_correct += 1
            print(f'{lre_correct} {lm_correct}')

beta_layer=5
13 47
14 47
12 47
14 48
13 47
14 47
15 47
13 48
beta_layer=6
14 48
13 46
14 47
12 47
13 47
14 48
14 47
15 47
beta_layer=7
15 47
13 47
15 48
14 48
13 46
14 48
13 47
13 47
beta_layer=8
14 47
15 47
13 48
15 48
15 47
14 46
15 48
13 47
beta_layer=9
13 47
13 47
13 46
14 47
15 47
14 47
14 48
14 46
beta_layer=10
14 48
13 47
14 48
14 47
13 47
15 45
14 47
14 48
beta_layer=11
11 47
15 47
14 47
13 46
13 46
13 47
12 47
14 47
beta_layer=12
12 47
14 48
15 47
13 47
14 48
15 47
14 48
13 48
beta_layer=13
13 48
14 46
13 47
15 47
14 48
14 47
13 48
13 47
beta_layer=14
2 46
2 47
2 47
2 48
2 47
2 47
2 47
2 47
beta_layer=15
25 48
24 47
24 47
24 47
24 46
24 47
25 48
25 46
beta_layer=16
13 47
15 46
12 47
15 47
13 48
15 47
14 47
13 47
beta_layer=17
15 46
13 47
14 47
13 48
13 47
14 47
11 48
14 48
beta_layer=18
15 47
13 47
15 48
14 47


In [54]:
#for most relations.
def is_nontrivial_prefix(prediction: str, target: str) -> bool:
    target = target.lower().strip()
    prediction = prediction.lower().strip()
    # if len(prediction) > 0 and target.startswith(prediction):
    #     print(f"{prediction} matches {target}")
    return len(prediction) > 1 and target.startswith(prediction)

def any_is_nontrivial_prefix(prediction, targets) -> bool:
    return any(is_nontrivial_prefix(prediction, target) for target in targets)

In [55]:
for (subj, obj) in pairs:
    for beta in range(10,50, 1):
        beta /= 10
        hs = get_hidden_state(mt, subj, 5) #layer 5
        object_hs = approx_lm(hs, beta) #beta
        pred = get_object(mt, object_hs)[0]
        if (any_is_nontrivial_prefix(pred[0], obj)):
            print(f"{subj} matches {pred[0]}: {beta}")
            break
    
# for (subj, obj) in pairs:
#     hs = get_hidden_state(mt, subj, 5)
#     object_hs = approx_lm(hs, 2.4)
#     print(f'{subj}: {get_object(mt, object_hs)[0]} {obj}')

NameError: name 'pairs' is not defined

In [None]:
#get tokens in GPT-J
#get the hidden state of them at the last layer (after the 28th layer, or s->o @ 27)
import pickle
from tqdm import tqdm

def get_hidden_state(mt, subject, h_layer, h=None, k=5):
    prompt = f" {subject}"
    h_index, inputs = functional.find_subject_token_index(
        mt = mt, prompt=prompt, subject=subject)
    #print(f'h_index is {h_index}, inputs is {inputs}')
    [[hs], _] = functional.compute_hidden_states(
        mt = mt, layers = [h_layer], inputs = inputs)
    #h is hs @ h_layer @ h_index
    h = hs[:, h_index]
    h = h.to(device)
    return h
    
#Spaces are converted in a special character (the Ġ ) in the tokenizer prior to BPE splitting
#mostly to avoid digesting spaces since the standard BPE algorithm used spaces in its process 

#all animal encodings are at [-0.4153   2.023   -2.23    ... -0.785    0.06323 -0.1819 ]

text = "our classic pre-baked blueberry pie filled with delicious plump and juicy wild blueberries"
encoded_input = mt.tokenizer(text, return_tensors="pt")
token_ids = range(0,50400)
tokens = tokenizer.convert_ids_to_tokens(token_ids)
tokens = [token.replace("Ġ", " ") for token in tokens]

#this is too slow and not useful.
dict27 = {}
for i in tqdm(range(len(tokens))):
    token = tokens[i]
    dict27[token] = get_hidden_state(mt, token, 27)
    
with open('animal_youth_27.pkl', 'wb') as file:
    pickle.dump(dict27, file)

In [211]:
mt.lm_head

Sequential(
  (0): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
  (1): Linear(in_features=4096, out_features=50400, bias=True)
)

In [8]:
animals = ["dog", "duck", "fish", "horse", "mink", "seal", "shark", "trout"]

[' puppy', ' pup', ' p', ' dog', ' �']
[' duck', ' dra', ' g', ' �', ' "']
[' fry', ' prog', ' F', ' �', ' lar']
[' fo', ' col', ' horse', ' pony', ' �']
[' kit', ' m', ' "', ' �', ' p']
[' seal', ' "', ' �', ' pup', ' p']
[' shark', ' "', ' �', ' p', ' pup']
[' fry', ' trout', ' "', ' �', ' rainbow']
