In [8]:
import sys
sys.path.append('..')
from lre.data import Relation, RelationSample
from lre.operators import JacobianIclEstimator, Word2VecIclEstimator
import lre.functional as functional

In [9]:
import lre.models as models
# Device string handed to load_model. "mps" is presumably PyTorch's
# Apple-silicon (Metal) backend — TODO confirm, and change it when running
# on other hardware.
device = "mps"
# Load the model registered under the name "gptj" with the fp16 flag enabled.
# NOTE(review): the saved output of this cell is a KeyboardInterrupt, i.e. the
# load was interrupted in the recorded run; every later cell that uses `mt`
# needs this cell to finish first.
mt = models.load_model("gptj", device=device, fp16=True)

KeyboardInterrupt: 

In [6]:
# Goal: evaluate each subject against any of its acceptable objects.
animal_path = 'data/enckno/animal-young.txt'

# Module-level accumulators that relation_from_path appends to:
#   subjects             - every subject read from a data file
#   subject_object_pairs - (subject, first listed object) per line
#   all_pairs            - (subject, object) for every listed object
subjects, subject_object_pairs, all_pairs = [], [], []

def relation_from_path(path, relation_name, prompt):
    """Build a ``Relation`` from a tab-separated subject/object file.

    Every non-blank line must have the form ``subject<TAB>obj1/obj2/...``.
    The first listed object becomes the object of that subject's
    ``RelationSample``; the remaining alternatives are only recorded in the
    module-level ``all_pairs`` list.

    Side effects (kept for backward compatibility with cells that read them):
    appends to the module-level lists ``subjects``, ``subject_object_pairs``
    and ``all_pairs``. The returned relation, however, is built only from the
    lines of *this* file, so calling the function more than once no longer
    leaks samples from earlier calls into later relations.

    Args:
        path: Path of the data file to read.
        relation_name: Value passed as ``name`` to ``Relation``.
        prompt: Prompt template; becomes the single entry of
            ``prompt_templates``.

    Returns:
        A ``Relation`` with one sample per non-blank line of the file and an
        empty ``prompt_templates_zs``.

    Raises:
        ValueError: If a non-blank line does not contain exactly one tab.
    """
    # Pairs read from this file only; the globals below may already hold
    # entries from previous calls and must not be used to build the relation.
    file_pairs = []

    with open(path, "r") as f:
        for raw_line in f:
            line = raw_line.rstrip('\n')
            if not line.strip():
                # Blank (or trailing empty) lines used to crash the unpack below.
                continue

            subject, objects_field = line.split('\t')
            objects = objects_field.split('/')
            primary_pair = (subject, objects[0])

            file_pairs.append(primary_pair)

            # Keep the shared module-level bookkeeping up to date.
            subjects.append(subject)
            subject_object_pairs.append(primary_pair)
            all_pairs.extend((subject, obj) for obj in objects)

    return Relation(
        name=relation_name,
        prompt_templates=[prompt],
        prompt_templates_zs=[],
        samples=[RelationSample(*pair) for pair in file_pairs],
    )

# Relation mapping each animal in the data file to the name of its young.
animal_rel = relation_from_path(
    path=animal_path,
    relation_name="animal-youth",
    prompt="The young version of {} is",
)

In [4]:
import lre.metrics as metrics
from collections import defaultdict

# Running tally keyed by whether the LM prediction was judged correct
# (True/False). test_operator_on_relation increments it and nothing in this
# notebook resets it, so it accumulates across calls until this cell is re-run.
counts_by_lm_correct: dict[bool, int] = defaultdict(int)

def test_operator_on_relation(operator, relation, mt, h_layer, z_layer, k=5, verbose=True):
    """Compare an LRE estimator with the language model itself on one relation.

    For every sample a leave-one-out in-context prompt is built (all other
    samples serve as examples), the LM's top-``k`` next tokens are collected,
    and the fitted operator's top-``k`` predictions for the bare subject are
    collected. Recall is computed for both, and the operator's predictions are
    bucketed by whether the LM got that sample right.

    Args:
        operator: Estimator class/factory, called as
            ``operator(mt=mt, h_layer=h_layer, z_layer=z_layer)``; the result
            is then called with ``relation`` to obtain the fitted operator.
        relation: Object exposing ``name``, ``prompt_templates`` and
            ``samples`` (each sample has ``subject`` and ``object``).
        mt: Model handle forwarded to the estimator and to
            ``functional.predict_next_token``.
        h_layer: Forwarded to the estimator unchanged.
        z_layer: Forwarded to the estimator unchanged.
        k: Number of top predictions requested from both the LM and the
            operator (previously hard-coded to 5).
        verbose: When true (default, matching the previous behaviour), print
            every assembled prompt and the fitted operator's type.

    Returns:
        dict with keys ``recall_lm``, ``recall_lre`` (whatever
        ``metrics.recall`` returns), ``counts_by_lm_correct``,
        ``preds_by_lm_correct`` and ``targets_by_lm_correct`` (plain dicts
        keyed by the LM-correctness flag) for this call only. Previously these
        values were computed and discarded.

    Side effects:
        Still increments the module-level ``counts_by_lm_correct`` for
        backward compatibility; the printed summary now reports this call's
        counts rather than the cumulative global tally.
    """
    # Distinct names for the three stages that used to share `operator`.
    estimator = operator(mt=mt, h_layer=h_layer, z_layer=z_layer)
    lre_operator = estimator(relation)
    prompt_template = relation.prompt_templates[0]

    # Assemble leave-one-out in-context prompts: every other sample is an
    # example, the held-out subject is the query.
    clozed_prompts = []
    clozed_answers = []
    for held_out in relation.samples:
        examples = [s for s in relation.samples if s != held_out]
        cloze_template = functional.make_prompt(
            prompt_template=prompt_template,
            subject="{}",
            examples=examples,
        )
        clozed_prompts.append(cloze_template.format(held_out.subject))
        clozed_answers.append(held_out.object)

    if verbose:
        for prompt in clozed_prompts:
            print(f'Prompt: \n{prompt}\n')
        # Loop-invariant, so report it once instead of once per sample.
        print(f'operator has {type(lre_operator)} (should be LinearRelationOperator)')

    # NOTE(review): the original author noted a 2048-token limit here; prompts
    # grow with the number of samples — confirm large relations still fit.
    # functional.predict_next_token queries the LM itself.
    outputs_lm = functional.predict_next_token(mt=mt, prompt=clozed_prompts, k=k)
    preds_lm = [[pred.token for pred in preds] for preds in outputs_lm]
    recall_lm = metrics.recall(preds_lm, clozed_answers)

    # The fitted operator sees only the subject, not the in-context prompt.
    outputs_lre = [
        lre_operator(sample.subject, k=k).predictions
        for sample in relation.samples
    ]

    # Each prediction carries a token (and a probability); keep the tokens.
    preds_lre = [[pred.token for pred in preds] for preds in outputs_lre]
    recall_lre = metrics.recall(preds_lre, clozed_answers)

    preds_by_lm_correct = defaultdict(list)
    targets_by_lm_correct = defaultdict(list)
    call_counts = defaultdict(int)

    # Bucket the operator's predictions by whether the LM was right on that
    # sample, e.g. counts like {True: 5, False: 2}.
    for pred_lm, pred_lre, target in zip(preds_lm, preds_lre, clozed_answers):
        lm_correct = metrics.any_is_nontrivial_prefix(pred_lm, target)
        preds_by_lm_correct[lm_correct].append(pred_lre)
        targets_by_lm_correct[lm_correct].append(target)
        call_counts[lm_correct] += 1
        counts_by_lm_correct[lm_correct] += 1  # shared cumulative tally

    print(f'For {lre_operator} on {relation.name} (out of correct): {call_counts}')

    return {
        "recall_lm": recall_lm,
        "recall_lre": recall_lre,
        "counts_by_lm_correct": dict(call_counts),
        "preds_by_lm_correct": dict(preds_by_lm_correct),
        "targets_by_lm_correct": dict(targets_by_lm_correct),
    }

In [7]:
# Evaluate the Word2Vec-style estimator on the animal/young relation.
# To evaluate the Jacobian-based estimator instead, pass JacobianIclEstimator
# as `operator`.
# NOTE(review): 5 and 27 are the layer indices forwarded to the estimator;
# their meaning for this model is not visible here — confirm before changing.
test_operator_on_relation(
    operator=Word2VecIclEstimator,
    relation=animal_rel,
    mt=mt,
    h_layer=5,
    z_layer=27,
)

Prompt: 
The young version of badger is kit
The young version of bear is cub
The young version of beaver is kit
The young version of bee is larva
The young version of beetle is larva
The young version of buffalo is calf
The young version of butterfly is larva
The young version of camel is calf
The young version of cat is kitten
The young version of cattle is calf
The young version of chimpanzee is baby
The young version of cicada is nymph
The young version of cockroach is nymph
The young version of cricket is larva
The young version of deer is fawn
The young version of dog is puppy
The young version of ape is

Prompt: 
The young version of ape is baby
The young version of bear is cub
The young version of beaver is kit
The young version of bee is larva
The young version of beetle is larva
The young version of buffalo is calf
The young version of butterfly is larva
The young version of camel is calf
The young version of cat is kitten
The young version of cattle is calf
The young version 