In [65]:
import torch
import sys
sys.path.append('../..')
from transformers import GPTJForCausalLM, AutoTokenizer
import lre.models as models
import lre.functional as functional
import os

# Average the per-run relation weight and bias tensors stored under `wdir`.
# `device` is set here so this cell works on a fresh kernel; previously it was only
# defined in a later cell, which made Restart & Run All fail with NameError.
device = 'cuda:1'

weights = []
biases = []

wdir = 'antonymsbinary/h5/'
weight_str = 'antonyms - binary_weight'
bias_str = 'antonyms - binary_bias'

# os.listdir returns entries in arbitrary order; sort for a deterministic load order.
weight_paths = sorted(f for f in os.listdir(wdir) if f.startswith(weight_str))
bias_paths = sorted(f for f in os.listdir(wdir) if f.startswith(bias_str))

# torch.stack on an empty list gives an unhelpful error; fail early with context instead.
if not weight_paths or not bias_paths:
    raise FileNotFoundError(
        f'expected files starting with {weight_str!r} and {bias_str!r} in {wdir!r}, '
        f'found {len(weight_paths)} weight and {len(bias_paths)} bias files'
    )

# Weights and biases are averaged independently, so load each list in full.
# zip() would silently drop files whenever the two counts differ.
for weight_path in weight_paths:
    weights.append(torch.load(os.path.join(wdir, weight_path)))
for bias_path in bias_paths:
    biases.append(torch.load(os.path.join(wdir, bias_path)))

weight = torch.stack(weights).mean(dim=0).to(device)
bias = torch.stack(biases).mean(dim=0).to(device)

In [2]:
# Load the GPT-J 6B tokenizer and its float16 weights, place the model on GPU 1,
# and wrap both in the project's ModelAndTokenizer container.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
# Reuse EOS as the padding token (presumably the tokenizer defines none — verify).
tokenizer.pad_token = tokenizer.eos_token

model = GPTJForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6B",
    revision="float16",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
model.to('cuda:1')

mt = models.ModelAndTokenizer(model, tokenizer)

In [66]:
# Test data: (subject, object) pairs read from the relation's JSON file.
# Each sample's 'object' is used downstream as a collection of accepted answers.
import json
json_path = 'antonymsbinary/antonymbinary.json'
with open(json_path, 'r') as file:
    data = json.load(file)
pairs = [(sample['subject'], sample['object']) for sample in data['samples']]

In [75]:
import numpy as np
def get_object(mt, subject, weight, bias, prompt, h_layer, beta, k=5):
    """Predict the top-1 object token for `subject` using the averaged linear relation.

    Reads the hidden state of the subject token at layer `h_layer`, maps it with
    ``z = (h @ weight.T) * beta + bias``, decodes ``z`` through ``mt.lm_head``, and
    returns the most probable token.

    Args:
        mt: ModelAndTokenizer wrapper; ``mt.lm_head`` and ``mt.tokenizer`` are used.
        subject: subject string, located inside `prompt`.
        weight: averaged relation matrix; the hidden state is moved to its device.
        bias: averaged relation bias, added after scaling.
        prompt: full prompt text containing `subject`.
        h_layer: index of the layer whose hidden state is read.
        beta: scale applied to ``h @ weight.T`` only (not to `bias`).
        k: number of top candidates ranked; only the first is returned.

    Returns:
        (word, prob): decoded top-1 token string and its softmax probability.
    """
    h_index, inputs = functional.find_subject_token_index(
        mt=mt, prompt=prompt, subject=subject)
    [[hs], _] = functional.compute_hidden_states(
        mt=mt, layers=[h_layer], inputs=inputs)
    # h is the layer-`h_layer` hidden state at the subject token position.
    # Follow the weights' device instead of a global `device` defined in a later cell.
    h = hs[:, h_index].to(weight.device)

    # Apply the mean Jacobian (scaled by beta) and the mean bias.
    z = h.mm(weight.t()) * beta + bias

    logits = mt.lm_head(z)
    dist = torch.softmax(logits.float(), dim=-1)
    topk = dist.topk(k=k, dim=-1)
    # Flatten with -1 rather than a fixed length so any `k` works
    # (a single prompt, i.e. batch size 1, is assumed — TODO confirm for batched use).
    probs = topk.values.view(-1).tolist()
    token_ids = topk.indices.view(-1).tolist()
    # Only the top candidate is returned, so only that one is decoded.
    word = mt.tokenizer.decode(token_ids[0])
    return (word, probs[0])

In [76]:
def matches(ab, pred):
    """Check whether `pred` is a non-trivial prefix of the second element of `ab`.

    After stripping whitespace, the prediction must be longer than one character,
    must not equal the first element, and must be a prefix of the second element.
    """
    candidate = pred.strip()
    source, target = ab[0].strip(), ab[1].strip()
    if len(candidate) <= 1 or candidate == source:
        return False
    return target.startswith(candidate)


assert matches(("angry", "angrier"), "ang") is True
assert matches(("angry", "angrier"), "angry") is False
assert matches(("angry", "angrier"), "c") is False

# Matching rule used for most relations.
def is_nontrivial_prefix(prediction: str, target: str) -> bool:
    """Case-insensitively test whether `prediction` is a prefix of `target`.

    Both strings are lower-cased and whitespace-stripped first. A prediction of
    one character or fewer never counts as a match.
    """
    normalized_target = target.lower().strip()
    normalized_prediction = prediction.lower().strip()
    if len(normalized_prediction) <= 1:
        return False
    return normalized_target.startswith(normalized_prediction)

def any_is_nontrivial_prefix(prediction, targets) -> bool:
    """Return True if `prediction` is a non-trivial prefix of at least one target.

    Args:
        prediction: predicted token or word.
        targets: iterable of accepted answers, or a single answer string.

    Returns:
        True if `is_nontrivial_prefix(prediction, t)` holds for some target `t`;
        False for an empty collection of targets.
    """
    # A bare string would otherwise be iterated character by character. Single
    # characters can never match a prefix longer than one character, so the
    # result was always False.
    if isinstance(targets, str):
        targets = [targets]
    return any(is_nontrivial_prefix(prediction, target) for target in targets)

In [77]:
# Quick single-subject check: top-1 prediction at hidden layer 5 with beta = 3.
subj = 'ahead'
prompt = f'The opposite of {subj} is'
pred, prob = get_object(mt, subj, weight, bias, prompt, h_layer=5, beta=3)

In [78]:
# Word-specific beta sweep: for each (subject, objects) pair, find the first beta in
# 1.0, 1.1, ..., 3.9 whose top-1 prediction is a non-trivial prefix of an accepted
# object. Then report how many pairs have such a beta.
device = 'cuda:1'
import logging

# Set up the basic configuration for logging.
# NOTE: logging.basicConfig does nothing if the root logger already has handlers,
# so re-running this cell keeps the first configuration.
logging.basicConfig(
    filename='_antonymbinary.log',         # Specify the file name
    filemode='a',               # 'a' for append mode, 'w' for write mode (overwrites the file)
    format='%(asctime)s - %(levelname)s - %(message)s',  # Log format
    level=logging.DEBUG         # Set the log level
)

H_LAYER = 5
# Integer range divided by 10 avoids accumulating float steps: 1.0 .. 3.9.
BETAS = [step / 10 for step in range(10, 40)]

weight = weight.to(device)
bias = bias.to(device)
corr = 0
logging.info('pred,obj,beta')
for subj, obj in pairs:
    prompt = f'The opposite of {subj} is'
    for beta in BETAS:
        (pred, prob) = get_object(mt, subj, weight, bias, prompt, H_LAYER, beta)
        if any_is_nontrivial_prefix(pred, obj):
            corr += 1
            print(f'{pred},{beta}')
            # Log all three columns named in the header row.
            logging.info(f'{pred},{obj[0]},{beta}')
            break
    else:
        # No beta in the sweep produced a matching prediction.
        print(f'{obj[0]}, None')
        logging.info(f'None,{obj[0]},None')
print(f'{corr}/{len(pairs)} correct')

 before,1.9
behind, None
 posterior,1.8
 forward,1.0
 after,1.4
 end,1.7
 up,1.1
descend, None
 alive,2.2
 increment,3.1
 up,1.0
emerge, None
 up,1.0
 up,1.0
lift, None
 static,2.1
dismiss, None
 entrance,1.7
 up,1.0
last, None
 remember,2.8
 backward,1.0
 back,1.0
 out,1.0
 out,1.0
 exh,3.8
 out,1.6
 out,1.0
 outside,1.4
 out,1.4
reverse, None
 immortal,2.3
vacant, None
 on,1.0
 in,1.3
upward, None
under, None
 later,2.3
retreat, None
 fall,2.0
 north,1.2
 north,1.4
emerge, None
 bottom,1.0
away, None
 false,1.4
 over,1.0
 down,1.0
 downhill,1.2
 east,1.0
37/50 correct


In [59]:
!pip install wordfreq

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


