In [1]:
import torch
import sys
sys.path.append('..')
from transformers import GPTJForCausalLM, AutoTokenizer
import lre.models as models
import lre.functional as functional
import os

import json
import random
from lre.data import Relation, RelationSample, Sequence
import lre.metrics as metrics
import lre.functional as functional

device = 'cuda:0'
weights = []
biases = []
subjects = []
model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", torch_dtype=torch.float16, low_cpu_mem_usage=True)
model.to(device)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer.pad_token = tokenizer.eos_token
mt = models.ModelAndTokenizer(model,tokenizer)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#OBTAIN LRE WEIGHTS: Method 1
#Load weights saved with new naming convention

def mean_weight_or_bias(wdir, kind, samples):
    weights = []
    for sample in samples:
        weight_path = f"{wdir}/{sample}/{kind}.pt"
        #load s_s_weight and s_s_bias
        weight = torch.load(weight_path).to(device)
        #append to lists
        weights.append(weight)
    mean_weight = torch.stack(weights).mean(dim=0).to(device)
    return mean_weight
    
samples = ["ahead", "backward", "down", "inbound", "input", "mortal", "off", "top"]
lre_weight = mean_weight_or_bias('qapprox', 's_o_weight_5_27', samples)
lre_bias = mean_weight_or_bias('qapprox', 's_o_bias_5_27', samples)

In [2]:
#OBTAIN LRE WEIGHTS: Method 2
#Method 2: An LRE loading script for the older weights in 'approx' & 'wapprox'

#verb + er

weight_str = 'meronyms - member_weight'
bias_str = 'meronyms - member_bias'
wdir = "approx/meronyms - member"

weight_paths = [f for f in os.listdir(wdir) if f.startswith(weight_str)]
bias_paths = [f for f in os.listdir(wdir) if f.startswith(bias_str)]

for weight_str, bias_str in zip(weight_paths, bias_paths):
    weight = torch.load(wdir + "/" + weight_str).to(device)
    bias = torch.load(wdir + "/" + bias_str).to(device)
    weights.append(weight)
    biases.append(bias)
    
lre_weight = torch.stack(weights).mean(dim=0).to(device)
lre_bias = torch.stack(biases).mean(dim=0).to(device)
weight_paths

['meronyms - member_weight_3_140707.pt',
 'meronyms - member_weight_5_140715.pt',
 'meronyms - member_weight_7_140723.pt',
 'meronyms - member_weight_6_140719.pt',
 'meronyms - member_weight_1_140659.pt',
 'meronyms - member_weight_2_140703.pt',
 'meronyms - member_weight_0_140655.pt',
 'meronyms - member_weight_4_140711.pt']

In [12]:
import llra.build as build
from importlib import reload
reload(build)
build.determine_device(mt)
build.determine_params(mt)
S_O_start = 21
S_O_end = 26
start, end = 7, 27
beta = 2.75
wdir = 'capprox/COMPLETE grads/0member-team'

In [13]:
# get relevant weights and biases from layers 0-27
llra_samples = os.listdir(wdir)
# clre_samples = ["begin", "compose", "determine", "examine", "intrude", "recommend", "speak", "tell"]
# assert(len(clre_samples) == 8)

json_path = 'json/lexsem/L05 [meronyms - member].json'
DEFAULT_N_ICL = 8 
N_TRIALS = 8
VIEW_SAMPLES = 5

file = open(json_path, 'r')
data = json.load(file)
file.close()

beta = 1
relation = Relation.from_dict(data)
prompt_template = relation.prompt_templates[0]
#ASSEMBLE PROMPTS AND OBJECT ANSWERS
clozed_prompts = []
clozed_answers = []

for x in relation.samples:
    samples = [x] + random.sample(relation.samples, DEFAULT_N_ICL - 1)
    #print(f'{samples} samples)')
    cloze_prompt = functional.make_prompt(
        template = prompt_template, 
        target = x,
        examples = samples
        )
    clozed_prompts.append(cloze_prompt)
    clozed_answers.append(x.object)

outputs_lm = functional.predict_next_token(mt=mt, prompt=clozed_prompts)
preds_lm =  [[x.token for x in xs] for xs in outputs_lm]
recall_lm = metrics.recall(preds_lm, clozed_answers)

lre_correct = 0
llra_correct = 0
lm_correct = 0

print(f'{start=} {S_O_start=} {S_O_end=} {end=}')
layer_dicts = build.build_llra(wdir, llra_samples, start,S_O_start, S_O_end, end)
#print(layer_dicts)
llra = build.LLRA(layer_dicts=layer_dicts)

for i, sample, objs, prompt, preds in \
zip(range(0,50), relation.samples, clozed_answers, clozed_prompts, preds_lm):
    
    if (metrics.any_is_nontrivial_prefix(predictions=preds, targets=objs)):
        hs = build.get_hidden_state(mt, prompt, sample.subject, start) #layer 5
        #use the complete LRE
        llra_object_hs = llra.approx_lm(hs, (S_O_start, S_O_end))
        llra_preds = build.get_object(mt, llra_object_hs)[0]
        
        #use the simple LRE
        lre_object_hs = hs.mm(lre_weight.t()) * beta + lre_bias
        lre_preds = build.get_object(mt, lre_object_hs)[0]
        
        if(metrics.any_is_nontrivial_prefix(predictions=llra_preds, targets=objs)):
            llra_correct += 1

        if(metrics.any_is_nontrivial_prefix(predictions=lre_preds, targets=objs)):
            lre_correct += 1
            
        if(i < VIEW_SAMPLES):
            print(f'{sample.subject} {preds[0]} {llra_preds}')
            
        lm_correct += 1
print(f'S_O_START,{S_O_start},S_O_END,{S_O_end},beta,{beta},clre,{llra_correct},lre,{lre_correct},lm,{lm_correct}')

start=7 S_O_start=21 S_O_end=26 end=27
calling mwb with 21
constructed layer_dict for 21
calling mwb with 22
constructed layer_dict for 22
calling mwb with 23
constructed layer_dict for 23
calling mwb with 24
constructed layer_dict for 24
calling mwb with 25
constructed layer_dict for 25
calling mwb with 7
constructed layer_dict for 7
calling mwb with 8
constructed layer_dict for 8
calling mwb with 9
constructed layer_dict for 9
calling mwb with 10
constructed layer_dict for 10
calling mwb with 11
constructed layer_dict for 11
calling mwb with 12
constructed layer_dict for 12
calling mwb with 13
constructed layer_dict for 13
calling mwb with 14
constructed layer_dict for 14
calling mwb with 15
constructed layer_dict for 15
calling mwb with 16
constructed layer_dict for 16
calling mwb with 17
constructed layer_dict for 17
calling mwb with 18
constructed layer_dict for 18
calling mwb with 19
constructed layer_dict for 19
calling mwb with 20
constructed layer_dict for 20
calling mwb with 

KeyError: 's_s_mean_weight'

In [None]:
### LRE BETA INSERTION POSITION
json_path = 'qapprox/antonyms-binary.json'

DEFAULT_N_ICL = 8 
N_TRIALS = 8

with open(json_path, 'r') as file:
    data = json.load(file)
    relation = Relation.from_dict(data)
    prompt = "The opposite of {} is"
    prompt_template = relation.prompt_templates[0]
    beta = 2.75
    for beta_layer in range(5,27):
        print(f'{beta_layer=}')
        for _ in range(0,N_TRIALS):
            clozed_prompts = []
            clozed_answers = []
            for x in relation.samples:
                samples = [x] + random.sample(relation.samples, DEFAULT_N_ICL - 1)
                cloze_prompt = functional.make_prompt(
                    template = prompt_template, 
                    target = x,
                    examples = samples
                    )
                clozed_prompts.append(cloze_prompt)
                clozed_answers.append(x.object)
            outputs_lm = functional.predict_next_token(mt=mt, prompt=clozed_prompts)
            preds_lm =  [[x.token for x in xs] for xs in outputs_lm]
            recall_lm = metrics.recall(preds_lm, clozed_answers)
            lre_correct = 0
            lm_correct = 0
            
            for _, sample, objs, prompt, preds in zip(range(50), relation.samples, clozed_answers, clozed_prompts, preds_lm):
                if (metrics.any_is_nontrivial_prefix(
                    predictions=preds, 
                    targets=objs)):
                    hs = build.get_hidden_state(mt, prompt, sample.subject, 1) #layer 5
                    object_hs = approx_lm(hs, beta, beta_layer)
                    lre_preds = get_object(mt, object_hs)[0]
                    if(metrics.any_is_nontrivial_prefix(predictions=lre_preds, targets=objs)):
                        lre_correct += 1
                    lm_correct += 1
                    
            print(f'{lre_correct} {lm_correct}')

In [54]:
#for most relations.
def is_nontrivial_prefix(prediction: str, target: str) -> bool:
    target = target.lower().strip()
    prediction = prediction.lower().strip()
    # if len(prediction) > 0 and target.startswith(prediction):
    #     print(f"{prediction} matches {target}")
    return len(prediction) > 1 and target.startswith(prediction)

def any_is_nontrivial_prefix(prediction, targets) -> bool:
    return any(is_nontrivial_prefix(prediction, target) for target in targets)

In [55]:
for (subj, obj) in pairs:
    for beta in range(10,50, 1):
        beta /= 10
        hs = get_hidden_state(mt, subj, 5) #layer 5
        object_hs = approx_lm(hs, beta) #beta
        pred = get_object(mt, object_hs)[0]
        if (any_is_nontrivial_prefix(pred[0], obj)):
            print(f"{subj} matches {pred[0]}: {beta}")
            break
    
# for (subj, obj) in pairs:
#     hs = get_hidden_state(mt, subj, 5)
#     object_hs = approx_lm(hs, 2.4)
#     print(f'{subj}: {get_object(mt, object_hs)[0]} {obj}')

NameError: name 'pairs' is not defined

In [None]:
#get tokens in GPT-J
#get the hidden state of them at the last layer (after the 28th layer, or s->o @ 27)
import pickle
from tqdm import tqdm

def get_hidden_state(mt, subject, h_layer, h=None, k=5):
    prompt = f" {subject}"
    h_index, inputs = functional.find_subject_token_index(
        mt = mt, prompt=prompt, subject=subject)
    #print(f'h_index is {h_index}, inputs is {inputs}')
    [[hs], _] = functional.compute_hidden_states(
        mt = mt, layers = [h_layer], inputs = inputs)
    #h is hs @ h_layer @ h_index
    h = hs[:, h_index]
    h = h.to(device)
    return h
    
#Spaces are converted in a special character (the Ġ ) in the tokenizer prior to BPE splitting
#mostly to avoid digesting spaces since the standard BPE algorithm used spaces in its process 

#all animal encodings are at [-0.4153   2.023   -2.23    ... -0.785    0.06323 -0.1819 ]

text = "our classic pre-baked blueberry pie filled with delicious plump and juicy wild blueberries"
encoded_input = mt.tokenizer(text, return_tensors="pt")
token_ids = range(0,50400)
tokens = tokenizer.convert_ids_to_tokens(token_ids)
tokens = [token.replace("Ġ", " ") for token in tokens]

#this is too slow and not useful.
dict27 = {}
for i in tqdm(range(len(tokens))):
    token = tokens[i]
    dict27[token] = get_hidden_state(mt, token, 27)
    
with open('animal_youth_27.pkl', 'wb') as file:
    pickle.dump(dict27, file)

In [211]:
mt.lm_head

Sequential(
  (0): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
  (1): Linear(in_features=4096, out_features=50400, bias=True)
)

In [8]:
animals = ["dog", "duck", "fish", "horse", "mink", "seal", "shark", "trout"]

[' puppy', ' pup', ' p', ' dog', ' �']
[' duck', ' dra', ' g', ' �', ' "']
[' fry', ' prog', ' F', ' �', ' lar']
[' fo', ' col', ' horse', ' pony', ' �']
[' kit', ' m', ' "', ' �', ' p']
[' seal', ' "', ' �', ' pup', ' p']
[' shark', ' "', ' �', ' p', ' pup']
[' fry', ' trout', ' "', ' �', ' rainbow']
