Linear Oriented Relational Embeddings

In [2]:
import torch
import sys
sys.path.append('../..')
from transformers import GPTJForCausalLM, AutoTokenizer
import lre.models as models
import lre.functional as functional
import os
from dataclasses import dataclass, field
import torch
from dataclasses_json import DataClassJsonMixin
# All tensors and the model live on the first CUDA device.
device = "cuda:0"

# Tokenizer for GPT-J; its EOS token is reused as the padding token.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer.pad_token = tokenizer.eos_token

# Load the float16 revision of GPT-J 6B, then move it onto `device`.
model = GPTJForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6B",
    revision="float16",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
model.to(device)

# Bundle model and tokenizer for the lre helper functions.
mt = models.ModelAndTokenizer(model, tokenizer)

In [45]:
i = 0  # not referenced anywhere in the visible code; kept for later cells
# Directory that holds the saved per-subject weight and bias tensors.
wdir = 'animal_youth'
# File-name prefixes identifying the weight and bias files of this relation.
weight_str = 'weight_animal - youth_sem1'
bias_str = 'bias_animal - youth_sem1'

# os.listdir() returns entries in arbitrary order, and the loading loop zips
# the two lists together. Sorting both makes the pairing deterministic.
# Assumes a weight file and its bias file share the same suffix after their
# respective prefixes -- TODO confirm against the files on disk.
_entries = os.listdir(wdir)
weight_paths = sorted(f for f in _entries if f.startswith(weight_str))
bias_paths = sorted(f for f in _entries if f.startswith(bias_str))

@dataclass(frozen=True, kw_only=True)
class SubjectWeight(DataClassJsonMixin):
    """Immutable, keyword-only record grouping the tensors of one subject.

    Instances are built in this file from a saved weight/bias file pair plus
    a hidden state computed for the subject word.
    """
    # Subject word; in this file it is parsed from the weight file name.
    subject: str
    # Weight tensor; in this file it is read with torch.load.
    weight: torch.Tensor
    # Bias tensor; in this file it is read with torch.load.
    bias: torch.Tensor
    # Hidden state of the subject; in this file it is the output of
    # no_context_hs(subject) moved to the CPU.
    hs: torch.Tensor

# Empty accumulators (four distinct lists) filled by the cells below.
weights, biases, subjects, subject_weights = [], [], [], []
# Layer index read by no_context_hs when it extracts hidden states.
h_layer = 1

def no_context_hs(word):
    """Return the hidden state at the subject token of ``word``.

    ``word`` is inserted into a fixed "offspring" prompt. The hidden states
    of layer ``h_layer`` (module-level) are computed for that prompt, and the
    slice at the subject token index is returned.
    """
    prompt = f'The offspring of a {word} is referred to as a '
    subject_index, model_inputs = functional.find_subject_token_index(
        mt=mt, prompt=prompt, subject=word)
    layer_states, _ = functional.compute_hidden_states(
        mt=mt, layers=[h_layer], inputs=model_inputs)
    # Exactly one layer was requested, so exactly one tensor is expected.
    (states,) = layer_states
    return states[:, subject_index]

# Build one SubjectWeight record per (bias file, weight file) pair.
for bias_path, weight_path in zip(bias_paths, weight_paths):
    # The subject is the fourth "_"-separated field of the weight file name.
    subject = weight_path.split("_")[3]
    subject_weights.append(
        SubjectWeight(
            subject=subject,
            weight=torch.load(f'{wdir}/' + weight_path),
            bias=torch.load(f'{wdir}/' + bias_path),
            hs=no_context_hs(subject).cpu(),
        )
    )

# Per-subject tensors, in the same order as subject_weights.
weights = [entry.weight for entry in subject_weights]
biases = [entry.bias for entry in subject_weights]

# Unweighted average over all subjects, moved to `device`.
weight = torch.mean(torch.stack(weights), dim=0).to(device)
bias = torch.mean(torch.stack(biases), dim=0).to(device)
#weigh similar subjects more than distant subjects.

In [46]:
#semantic method
import numpy as np
sim = torch.nn.CosineSimilarity(dim=1)


def lore_wb(word):
    """Return a similarity-weighted (weight, bias) pair for ``word``.

    The hidden state of ``word`` is compared by cosine similarity with the
    stored hidden state of every entry in ``subject_weights``. The
    similarities are standardised (mean subtracted, divided by the standard
    deviation), used to scale each subject's weight and bias, and the scaled
    tensors are averaged over subjects.

    Original author's note: "This doesn't work because more similar animals
    ~= similar youth" (presumably: similarity between animals does not imply
    similarity between their youth terms).

    Returns:
        Tuple ``(lore_weight, lore_bias)`` of tensors on ``device``.
    """
    word_hs = no_context_hs(word).cpu()
    sims = torch.stack([sim(word_hs, subject.hs) for subject in subject_weights])
    # Standardise the similarities across subjects.
    # NOTE(review): with a single subject torch.std yields NaN -- confirm
    # that at least two subjects are always loaded.
    sims = (sims - torch.mean(sims)) / torch.std(sims)
    # One scalar per subject, shaped to broadcast over the stacked tensors.
    # The subject count was previously hard-coded as 8; this assumes each
    # similarity holds one element, as the original reshape did.
    similarities = sims.reshape((len(subject_weights), 1, 1)).to(device)
    lore_weights = torch.mul(similarities, torch.stack(weights).to(device))
    lore_biases = torch.mul(similarities, torch.stack(biases).to(device))
    # Both products already live on `device`; no further transfer is needed.
    return (lore_weights.mean(dim=0), lore_biases.mean(dim=0))

In [47]:
#testing data
import json
json_path = 'animal-youth.json'
# Decode explicitly as UTF-8 rather than relying on the platform's default
# text encoding.
with open(json_path, 'r', encoding='utf-8') as file:
    data = json.load(file)
# One (subject, first element of 'object') tuple per entry in 'samples'.
pairs = [(sample['subject'], sample['object'][0]) for sample in data['samples']]
        
def is_nontrivial_prefix(prediction: str, target: str) -> bool:
    """Tell whether ``prediction`` is a prefix of ``target`` longer than one character.

    Both strings are lower-cased and stripped of surrounding whitespace
    before they are compared.
    """
    wanted = target.lower().strip()
    guess = prediction.lower().strip()
    # Empty or single-character guesses never count as a match.
    if len(guess) <= 1:
        return False
    return wanted.startswith(guess)

In [48]:
#mt = models.ModelAndTokenizer(model,tokenizer)

# Hidden-state layer used when evaluating the relation operators.
LRE_H = 5


def get_object(mt, subject, weight, bias, prompt, h_layer, beta, k=5):
    """Predict the top-``k`` object tokens for ``subject`` under a linear map.

    The hidden state ``h`` of the subject token at layer ``h_layer`` is
    mapped to ``z = (h @ weight.T) * beta + bias``. ``z`` is passed through
    ``mt.lm_head`` and the ``k`` most probable tokens are decoded.

    Args:
        mt: Model/tokenizer bundle providing ``lm_head`` and ``tokenizer``.
        subject: Subject string to locate inside ``prompt``.
        weight: Weight tensor of the linear map.
        bias: Bias tensor of the linear map.
        prompt: Prompt text containing ``subject``.
        h_layer: Layer whose hidden state is read.
        beta: Scalar multiplier applied to ``h @ weight.T``.
        k: Number of tokens to return.

    Returns:
        List of ``k`` decoded token strings, in the order given by ``topk``.
    """
    h_index, inputs = functional.find_subject_token_index(
        mt=mt, prompt=prompt, subject=subject)
    [[hs], _] = functional.compute_hidden_states(
        mt=mt, layers=[h_layer], inputs=inputs)

    # Pure inference: skip autograd bookkeeping for the vocabulary projection.
    with torch.no_grad():
        # Hidden state at the subject token, as float16 on `device`.
        h = hs[:, h_index].half().to(device)
        z = h.mm(weight.t()) * beta + bias
        logits = mt.lm_head(z)
        dist = torch.softmax(logits.float(), dim=-1)
        topk = dist.topk(dim=-1, k=k)

    # view(k) requires exactly k indices in total.
    token_ids = topk.indices.view(k).tolist()
    return [mt.tokenizer.decode(token_id) for token_id in token_ids]

def _count_correct(beta, params_for_subject):
    """Count pairs whose top-1 prediction is a non-trivial prefix of the target.

    Args:
        beta: Scalar multiplier forwarded to get_object.
        params_for_subject: Callable mapping a subject string to the
            ``(weight, bias)`` tuple used for that subject.

    Returns:
        Number of entries in ``pairs`` that were predicted correctly.
    """
    correct = 0
    for subj, obj in pairs:
        prompt = f'The offspring of a {subj} is referred to as a'
        subj_weight, subj_bias = params_for_subject(subj)
        # Both methods read the same layer, LRE_H. The LRE branch previously
        # hard-coded the literal 5.
        pred = get_object(mt, subj, subj_weight, subj_bias, prompt, LRE_H, beta)
        if is_nontrivial_prefix(pred[0], obj):
            correct += 1
    return correct


# Sweep beta over 1.5, 1.6, ..., 1.9 and compare both methods on every pair.
for step in range(5):
    beta = 1.5 + step / 10
    # LORE: per-subject similarity-weighted operator from lore_wb.
    LORE_correct = _count_correct(beta, lore_wb)
    # LRE: the same averaged operator (weight, bias) for every subject.
    LRE_correct = _count_correct(beta, lambda _subj: (weight, bias))
    print(f'beta: {beta} LORE: {LORE_correct} LRE: {LRE_correct}')

beta: 1.5 LORE: 28 LRE: 29
beta: 1.6 LORE: 28 LRE: 29
beta: 1.7 LORE: 29 LRE: 29
beta: 1.8 LORE: 30 LRE: 29
beta: 1.9 LORE: 29 LRE: 30
