In [1]:
import sys
import json
sys.path.append('..')
from lre.data import Relation, RelationSample
from lre.operators import JacobianIclEstimator, Word2VecIclEstimator
import lre.functional as functional
import lre.models as models
import lre.metrics as metrics
from collections import defaultdict
import random

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Running tally, across *all* calls of test_operator_on_relation, of how many
# samples the LM got right (True) / wrong (False). Kept for backward
# compatibility; each call now also reports and returns its own counts.
counts_by_lm_correct: dict[bool, int] = defaultdict(int)


def test_operator_on_relation(operator, relation, mt, h_layer, z_layer, n_icl=8, k=5, verbose=True):
    """Compare few-shot LM predictions with an LRE operator's predictions on one relation.

    For every sample in ``relation`` a few-shot (ICL) prompt is built from up to
    ``n_icl`` *other* samples of the relation, and the LM's top-``k`` next tokens
    are collected. Separately, an LRE operator is estimated with ``operator`` and
    queried with each bare subject for its top-``k`` predictions. Recall of both
    is computed against the samples' objects, and the LRE predictions are also
    split by whether the LM itself was correct.

    Args:
        operator: estimator class (e.g. ``JacobianIclEstimator``). It is
            instantiated as ``operator(mt=mt, h_layer=h_layer, z_layer=z_layer)``
            and the instance is then called on ``relation`` to obtain the fitted
            operator.
        relation: relation to evaluate; ``relation.prompt_templates[0]`` is the
            prompt template used.
        mt: loaded model bundle (as produced by ``models.load_model``).
        h_layer: layer index passed to the estimator as ``h_layer``.
        z_layer: layer index passed to the estimator as ``z_layer``.
        n_icl: number of in-context examples per prompt; capped at the number
            of other samples so small relations do not raise.
        k: number of top predictions kept from the LM and from the operator.
        verbose: if True (default, as before), print every ICL prompt.

    Returns:
        dict with keys ``recall_lm``, ``recall_lre`` (as returned by
        ``metrics.recall``), ``recall_lre_by_lm_correct`` (LRE recall restricted
        to samples the LM got right / wrong) and ``counts_by_lm_correct``
        (per-call counts of LM-correct samples).
    """
    prompt_template = relation.prompt_templates[0]

    # --- Build one few-shot prompt per sample --------------------------------
    clozed_prompts = []
    clozed_answers = []
    for sample in relation.samples:
        # Exclude the query sample so its own answer never appears in its prompt.
        others = [s for s in relation.samples if s != sample]
        examples = random.sample(others, min(n_icl, len(others)))
        cloze_template = functional.make_prompt(
            prompt_template=prompt_template,
            subject="{}",
            examples=examples,
        )
        clozed_prompts.append(cloze_template.format(sample.subject))
        clozed_answers.append(sample.object)

    if verbose:
        for prompt in clozed_prompts:
            print(f'Prompt: \n{prompt}\n')

    # --- LM baseline: top-k next tokens for every few-shot prompt ------------
    # NOTE(review): all prompts go into a single call; GPT-J's context is 2048
    # tokens, and if predict_next_token does not batch internally this may be
    # memory-heavy on MPS — confirm.
    outputs_lm = functional.predict_next_token(mt=mt, prompt=clozed_prompts, k=k)
    preds_lm = [[token.token for token in tokens] for tokens in outputs_lm]
    recall_lm = metrics.recall(preds_lm, clozed_answers)

    # --- LRE: estimate the operator (the expensive step), then query it ------
    estimator = operator(mt=mt, h_layer=h_layer, z_layer=z_layer)
    # NOTE(review): the estimator receives the whole relation, including the
    # samples evaluated below, so these LRE numbers are presumably in-sample —
    # confirm this is intended.
    lre_operator = estimator(relation)
    # Each operator output carries .predictions, a list of (token, prob) objects.
    preds_lre = [
        [token.token for token in lre_operator(sample.subject, k=k).predictions]
        for sample in relation.samples
    ]
    recall_lre = metrics.recall(preds_lre, clozed_answers)

    # --- Split LRE predictions by whether the LM was correct -----------------
    preds_by_lm_correct = defaultdict(list)
    targets_by_lm_correct = defaultdict(list)
    counts = defaultdict(int)
    for pred_lm, pred_lre, target in zip(preds_lm, preds_lre, clozed_answers):
        lm_correct = metrics.any_is_nontrivial_prefix(pred_lm, target)
        preds_by_lm_correct[lm_correct].append(pred_lre)
        targets_by_lm_correct[lm_correct].append(target)
        counts[lm_correct] += 1

    # Only groups that actually occurred are scored, so no empty lists are passed.
    recall_lre_by_lm_correct = {
        lm_correct: metrics.recall(preds_by_lm_correct[lm_correct], targets_by_lm_correct[lm_correct])
        for lm_correct in preds_by_lm_correct
    }

    # Keep the module-level running total up to date for existing readers of it.
    for lm_correct, count in counts.items():
        counts_by_lm_correct[lm_correct] += count

    # `operator` is still the estimator class here, so its name identifies the run.
    operator_name = getattr(operator, "__name__", type(operator).__name__)
    print(f'For {operator_name} on {relation.name} (out of correct): {dict(counts)}')

    return {
        "recall_lm": recall_lm,
        "recall_lre": recall_lre,
        "recall_lre_by_lm_correct": recall_lre_by_lm_correct,
        "counts_by_lm_correct": dict(counts),
    }

### Test prompt construction for BATS relations loaded from JSON

In [4]:
import json
import os
import sys
import random
sys.path.append('..')
from lre.functional import Relation, make_prompt

def all_file_paths(directory):
    """List every file below ``directory`` (recursively).

    Returns:
        list of paths relative to ``directory``, in ``os.walk`` order.
    """
    return [
        os.path.relpath(os.path.join(dirpath, filename), directory)
        for dirpath, _subdirs, filenames in os.walk(directory)
        for filename in filenames
    ]
# Directory (relative to this notebook) holding the BATS relation JSON files.
directory = 'json'
# Keep only .json files: os.walk also returns stray files such as macOS
# `.DS_Store`, which json.load cannot parse. Sorting makes the iteration order
# reproducible, and with it which relation is left in `relation`.
file_paths = sorted(p for p in all_file_paths(directory) if p.lower().endswith('.json'))

for json_path in file_paths:
    with open(os.path.join(directory, json_path), 'r', encoding='utf-8') as file:
        data = json.load(file)
    # `relation` deliberately outlives this loop: later cells pass it to
    # test_operator_on_relation, i.e. they use the *last* relation loaded here.
    relation = Relation.from_dict(data)
    prompt_template = relation.prompt_templates[0]
    clozed_prompts = []
    clozed_answers = []
    for x in relation.samples:
        # NOTE(review): unlike test_operator_on_relation (subject="{}" plus a
        # pre-sampled example list that excludes x), this passes the sample
        # itself as `subject` and *all* samples as `examples`. Depending on how
        # make_prompt picks examples, x's own answer may leak into its prompt
        # — confirm against make_prompt.
        cloze_template = make_prompt(
            prompt_template=prompt_template,
            subject=x,
            examples=relation.samples,
        )
        cloze_prompt = cloze_template.format(x.subject)
        clozed_prompts.append(cloze_prompt)
        clozed_answers.append(x.object)

    # One prompt/answer pair is printed per sample of this relation.
    for prompt, answer in zip(clozed_prompts, clozed_answers):
        print(f'Prompt: \n{prompt}\n')
        print(f'Answer: \n{answer}\n')

Prompt: 
A part of a day is a nanosecond
A part of a car is a head_restraint
A part of a castle is a keep
A part of a seafront is a dock
A part of a tonne is a myriagram
A part of a torso is a axillary_artery
A part of a car is a lights
A part of a door is a doorsill
A part of a academia is a

Answer: 
['college', 'university', 'institute']

Prompt: 
A part of a torso is a rib_cage
A part of a jewellery is a stone
A part of a piano is a hammer
A part of a dress is a sleeve
A part of a seafront is a harbour
A part of a torso is a deltoid_muscle
A part of a womb is a myometrium
A part of a torso is a bellybutton
A part of a apartment is a

Answer: 
['bedroom', 'room', 'bathroom', 'kitchen', 'kitchenette', 'living_room', 'pantry', 'toilet', 'shower_room']

Prompt: 
A part of a day is a midafternoon
A part of a torso is a corpus_sternum
A part of a sonata is a movement
A part of a car is a window
A part of a tonne is a g
A part of a typewriter is a typewriter_carriage
A part of a car is a 

In [4]:
# Load GPT-J in half precision onto the Apple-silicon GPU ("mps") backend.
device = "mps"
mt = models.load_model("gptj", fp16=True, device=device)

In [5]:
import os

# NOTE(review): this cell runs *after* the model was loaded onto MPS. PyTorch
# appears to read PYTORCH_MPS_HIGH_WATERMARK_RATIO only when the MPS allocator
# is first initialised, so this likely has no effect here; the later OOM still
# reports "max allowed: 18.13 GB". A ratio of 0.7 would presumably *lower* the
# cap rather than raise it. If needed, set it before anything touches MPS —
# confirm.
os.environ.update({'PYTORCH_MPS_HIGH_WATERMARK_RATIO': '0.7'})

In [None]:
test_operator_on_relation(Word2VecIclEstimator, relation, mt, h_layer=5, z_layer=27, k=5)

In [7]:
test_operator_on_relation(JacobianIclEstimator, relation, mt, h_layer=5, z_layer=27)

Prompt: 
The country with tbilisi as its capital is known as georgia
The country with ankara as its capital is known as turkey
The country with manila as its capital is known as philippines
The country with kiev as its capital is known as ukraine
The country with bern as its capital is known as switzerland
The country with zagreb as its capital is known as croatia
The country with kabul as its capital is known as afghanistan
The country with bangkok as its capital is known as thailand
The country with abuja as its capital is known as

Prompt: 
The country with conakry as its capital is known as guinea
The country with copenhagen as its capital is known as denmark
The country with bangkok as its capital is known as thailand
The country with lima as its capital is known as peru
The country with stockholm as its capital is known as sweden
The country with sofia as its capital is known as bulgaria
The country with hanoi as its capital is known as vietnam
The country with kiev as its capita

RuntimeError: MPS backend out of memory (MPS allocated: 17.53 GB, other allocations: 418.77 MB, max allowed: 18.13 GB). Tried to allocate 202.50 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).