In [7]:
import torch
import sys
sys.path.append('../..')
from transformers import GPTJForCausalLM, AutoTokenizer
import lre.models as models
import lre.functional as functional

In [11]:
# Read the eight saved weight/bias tensors ('adj - comparative_<i>_*.pt',
# i = 0..7) from the current working directory.
weights = [torch.load(f'adj - comparative_{idx}_weight.pt') for idx in range(8)]
biases = [torch.load(f'adj - comparative_{idx}_bias.pt') for idx in range(8)]

# Collapse the eight samples into a single operator by averaging them
# element-wise along the newly stacked leading dimension.
weight = torch.stack(weights).mean(dim=0)
bias = torch.stack(biases).mean(dim=0)

In [12]:
# Load the float16 revision of GPT-J 6B and move it to GPU index 1.
model = GPTJForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6B",
    revision="float16",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
model.to('cuda:1')

# Reuse the end-of-sequence token as the padding token.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer.pad_token = tokenizer.eos_token

# Bundle model and tokenizer into the wrapper the lre helpers take as `mt`.
mt = models.ModelAndTokenizer(model, tokenizer)

In [34]:
# (adjective, comparative) pairs evaluated below, in alphabetical order.
adj_comp = [
    ("angry", "angrier"), ("cheap", "cheaper"), ("clever", "cleverer"),
    ("coarse", "coarser"), ("costly", "costlier"), ("cute", "cuter"),
    ("dense", "denser"), ("dumb", "dumber"), ("fierce", "fiercer"),
    ("handy", "handier"), ("happy", "happier"), ("hardy", "hardier"),
    ("harsh", "harsher"), ("healthy", "healthier"), ("hot", "hotter"),
    ("huge", "huger"), ("hungry", "hungrier"), ("lazy", "lazier"),
    ("lengthy", "lengthier"), ("lucky", "luckier"), ("mad", "madder"),
    ("merry", "merrier"), ("mild", "milder"), ("moist", "moister"),
    ("nasty", "nastier"), ("neat", "neater"), ("nice", "nicer"),
    ("noisy", "noisier"), ("proud", "prouder"), ("pure", "purer"),
    ("risky", "riskier"), ("rocky", "rockier"), ("rude", "ruder"),
    ("sad", "sadder"), ("scary", "scarier"), ("sexy", "sexier"),
    ("sticky", "stickier"), ("strict", "stricter"), ("strong", "stronger"),
    ("subtle", "subtler"), ("sunny", "sunnier"), ("tasty", "tastier"),
    ("tiny", "tinier"), ("tricky", "trickier"), ("ugly", "uglier"),
    ("vague", "vaguer"), ("vast", "vaster"), ("weak", "weaker"),
    ("wealthy", "wealthier"), ("weird", "weirder"),
]

In [46]:
import numpy as np
def get_object(mt, subject, prompt, h_layer, beta, k=5):
    """Return the top predicted token for ``subject`` under the averaged linear map.

    The hidden state at the subject token is read from layer ``h_layer``,
    transformed as ``z = (h @ weight.T) * beta + bias`` using the notebook-level
    ``weight`` and ``bias`` tensors, and projected to vocabulary logits with
    ``mt.lm_head``.

    Note that the hidden state does not depend on ``beta``; callers that sweep
    ``beta`` for a fixed prompt repeat the same hidden-state computation on
    every call.

    Args:
        mt: model/tokenizer wrapper; this function uses ``mt.lm_head`` and
            ``mt.tokenizer`` and passes ``mt`` on to the ``functional`` helpers.
        subject: subject string to locate inside ``prompt``.
        prompt: prompt text that is run through the model.
        h_layer: layer whose hidden state is requested.
        beta: scalar multiplier applied to the linear term before the bias is added.
        k: number of candidates requested from ``topk``. Only the highest
            probability candidate is returned.

    Returns:
        A ``(word, prob)`` tuple: the decoded highest-probability token and its
        softmax probability as a Python float.
    """
    h_index, inputs = functional.find_subject_token_index(
        mt=mt, prompt=prompt, subject=subject)

    # Pure inference: nothing below is ever back-propagated through, so make
    # sure autograd does not record the forward pass or the lm_head projection.
    with torch.no_grad():
        [[hs], _] = functional.compute_hidden_states(
            mt=mt, layers=[h_layer], inputs=inputs)
        # Select the hidden state at the subject token position.
        # NOTE(review): presumably hs is (batch, seq, hidden) -- confirm.
        h = hs[:, h_index]
        z = h.mm(weight.t()) * beta + bias

        logits = mt.lm_head(z)
        # Softmax in float32 rather than in the logits' own dtype.
        dist = torch.softmax(logits.float(), dim=-1)
        topk = dist.topk(dim=-1, k=k)

    # view(k) only succeeds when topk holds exactly k values, i.e. a single
    # row; topk sorts in descending order, so index 0 is the best candidate.
    best_prob = topk.values.view(k)[0].item()
    best_token_id = topk.indices.view(k)[0].item()

    # Only the winning token is returned, so only that one is decoded.
    return (mt.tokenizer.decode(best_token_id), best_prob)

def matches(ab, pred):
    """Decide whether ``pred`` counts as a correct start of the target word.

    Args:
        ab: ``(subject, target)`` pair of strings, e.g. ``("angry", "angrier")``.
        pred: predicted text, typically a single decoded token that may carry
            leading/trailing whitespace (e.g. ``" angr"``).

    Returns:
        True when, after stripping whitespace from all three strings, ``pred``
        is a prefix of the target, differs from the subject, and is longer
        than one character.
    """
    a, b = ab[0].strip(), ab[1].strip()
    # Strip the prediction as well: comparing a whitespace-prefixed token
    # against the stripped target would make startswith() fail every time and
    # would let a prediction equal to the subject slip past the `a != pred` test.
    pred = pred.strip()
    return a != pred and b.startswith(pred) and len(pred) > 1
        
# For each adjective, sweep beta over 0.0, 0.2, ..., 5.8 and print the first
# value at which the top prediction is accepted by matches(); adjectives for
# which no beta matches print nothing.
for (adj, comp) in adj_comp:
    # Few-shot prompt ending in the adjective under test; it is the same for
    # every beta, so it is built once per adjective.
    prompt = f'merry merrier\n healthy healthier\n scary scarier\n vague vaguer\n wealthy wealthier\n cheap cheaper\n {adj}'
    for step in range(30):
        beta = step / 5
        pred, prob = get_object(mt, adj, prompt, 5, beta)
        if not matches((adj, comp), pred):
            continue
        print(f"{beta}: {pred} @ {prob}")
        break

KeyboardInterrupt: 