In [1]:
# Text attached to each triple, keyed by the tab-joined head / relation / tail
# exactly as written in triple2text.txt (convert_from_triple_to_sentence looks
# it up with entity labels). Each source line reads "h||r||t####SPLIT####text".
triple2text = {}
with open('../../dataset/FB60K-NYT10-100/triple2text.txt') as triple_file:
    for line in triple_file:
        if not line.strip():
            continue  # tolerate blank lines, e.g. an empty last line
        # maxsplit=1: a delimiter occurring inside the text must not break unpacking.
        triple, text = line.split('####SPLIT####', 1)
        h, r, t = triple.split('||')
        triple_ = h + '\t' + r + '\t' + t
        # Drop only the line terminator; slicing off the last character would
        # also eat real text on a final line that has no trailing newline.
        triple2text[triple_] = text.rstrip('\n')


# Entity id -> human-readable label, one "id<TAB>label" pair per line.
entity2label = {}
entity_label_file = '../../dataset/FB60K-NYT10-100/entity2label.txt'
with open(entity_label_file) as label_file:
    lines = label_file.readlines()
for line in lines:
    if not line.strip():
        continue  # tolerate blank lines
    entity, label = line.strip().split('\t')
    entity2label[entity] = label


In [2]:
def convert_from_triple_to_sentence(triple, template, mask=False):
    """Fill a relation template with a triple's entity labels.

    Args:
        triple: ``(h, r, t)`` -- head entity id, relation, tail entity id.
        template: prompt containing exactly one ``[X]`` (head slot) and one
            ``[Y]`` (tail slot), in either order.
        mask: if true, the tail slot is filled with ``'[MASK]'`` instead of
            the tail label.

    Returns:
        One string: the pieces ``prefix, head, middle, tail, suffix`` (head
        and tail swapped when ``[Y]`` precedes ``[X]``) joined by single
        spaces. When the module-level ``triple2text`` is not None, the text
        stored for the tab-joined (head label, relation, tail label) -- or the
        literal ``'None'`` when there is no entry -- is prepended as
        ``"<text> . "``. Without a label table (``entity2label is None``) the
        raw entity ids are used.

    Raises:
        AssertionError: if the template does not contain exactly two slots.
        KeyError: if ``entity2label`` lacks the head or tail id.
    """
    h, r, t = triple

    # Resolve ids to labels once, honouring the "no label table" case.
    if entity2label is not None:
        h, t = entity2label[h], entity2label[t]
    # The triple2text key always uses the unmasked tail label.
    triple_key = h + '\t' + r + '\t' + t

    if mask:
        t = '[MASK]'

    # Split only the template, then attach the retrieved text, so an "[X]" or
    # "[Y]" occurring inside that text cannot corrupt the slot split.
    pieces = template.replace('[Y]', '[X]').split('[X]')
    assert len(pieces) == 3, f'template needs exactly one [X] and one [Y]: {template!r}'
    prompts = [piece.strip() for piece in pieces]

    if triple2text is not None:
        # Missing triples get the literal placeholder 'None' so prompts stay
        # identical to those used for the recorded losses.
        sentence = triple2text.get(triple_key, 'None')
        prompts[0] = f'{sentence} . {pieces[0]}'.strip()

    if template.find('[X]') < template.find('[Y]'):
        final_list = [prompts[0], h.strip(), prompts[1], t.strip(), prompts[2]]
    else:
        final_list = [prompts[0], t.strip(), prompts[1], h.strip(), prompts[2]]
    return ' '.join(final_list)

In [3]:
from transformers import AutoTokenizer, AutoModelForMaskedLM, BertTokenizer, BertForMaskedLM
import torch
from tqdm import tqdm

# tokenizer = AutoTokenizer.from_pretrained("studio-ousia/luke-base")

# model = AutoModelForMaskedLM.from_pretrained("studio-ousia/luke-base")

# BERT-base (uncased) masked-LM head plus its matching WordPiece tokenizer,
# loaded tokenizer-first. The checkpoint's cls.seq_relationship weights are
# not used by BertForMaskedLM, which produces the "Some weights ... were not
# used" warning in this cell's output.
tokenizer, model = (
    BertTokenizer.from_pretrained("bert-base-uncased"),
    BertForMaskedLM.from_pretrained("bert-base-uncased"),
)

def run(relation, template, test_path='../../dataset/FB60K-NYT10-100/test.txt'):
    """Mean masked-LM loss on the tail entity for one relation and template.

    Reads ``head<TAB>relation<TAB>tail`` lines from ``test_path`` and keeps the
    triples whose relation equals ``relation``. Each triple is rendered with
    ``convert_from_triple_to_sentence`` twice: tail masked (model input) and
    unmasked (labels). Only the ``[MASK]`` position is scored; every other
    position gets label -100. Scoring uses the module-level ``model`` and
    ``tokenizer``.

    Triples whose tail label does not tokenize to exactly one token are
    skipped. For those, the label sequence length differs from the
    single-``[MASK]`` input.

    Args:
        relation: relation string to filter the test triples by.
        template: prompt template with one ``[X]`` and one ``[Y]`` slot.
        test_path: tab-separated test file (defaults to the FB60K-NYT10-100 split).

    Returns:
        The average loss over scored triples, or ``float('nan')`` when no
        triple could be scored.
    """
    triples = []
    with open(test_path) as test_file:
        for line in test_file:
            # Strip only the terminator; slicing off the last character would
            # truncate a final line that has no trailing newline.
            line = line.rstrip('\n')
            if not line:
                continue
            h, r, t = line.split('\t')
            if r == relation:
                triples.append((h, r, t))

    loss = 0.0
    cnt = 0
    for triple in tqdm(triples):
        text_masked = convert_from_triple_to_sentence(triple=triple, template=template, mask=True)
        text_label = convert_from_triple_to_sentence(triple=triple, template=template)
        inputs = tokenizer(text_masked, return_tensors="pt")
        label_ids = tokenizer(text_label, return_tensors="pt")["input_ids"]

        # Multi-token (or empty) tails cannot be aligned with a single [MASK]; skip them explicitly.
        if label_ids.shape != inputs.input_ids.shape:
            continue
        labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, label_ids, -100)

        # A single forward pass without autograd is enough: only the scalar
        # loss is read, and no gradients are needed.
        with torch.no_grad():
            outputs = model(**inputs, labels=labels)
        loss += outputs.loss.item()
        cnt += 1

    return loss / cnt if cnt else float('nan')



  from .autonotebook import tqdm as notebook_tqdm
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
relation = '/people/person/place_lived'
templates = ['[X] lives in [Y]', '[X] lived in [Y]', '[X] , born in [Y]']

# Print each template and its mean tail loss as it finishes. Also keep the
# numbers in a dict, so completed results survive an interrupted loop and can
# be ranked afterwards, e.g. sorted(place_lived_losses.items(), key=lambda kv: kv[1]).
place_lived_losses = {}
for template in templates:
    print(template)
    loss = run(relation, template)
    place_lived_losses[template] = loss
    print(loss)

## 11 16 11
# NOTE(review): the numbers above look like rounded losses for the three
# templates from an earlier run -- confirm and move them into a markdown note.

In [58]:
# Set explicitly so this cell does not silently reuse whatever `relation` an
# earlier cell left behind.
# NOTE(review): run top-to-bottom, this cell previously inherited place_lived
# from the cell above -- confirm that is the intended relation.
relation = '/people/person/place_lived'
templates = [
    # "[Y] , [X]",  -- excluded from this run
    "[X] , [Y]",
    "[Y] [X]",
    "[X] [Y]",
    "[X] - [Y]",
    "[X] and in [Y]",
    "[X] lived in [Y]",
    "[X] , [Y] ,",
    "[X] at the [Y]",
    "[X] from the [Y]",
    "[Y] [X] ,",
    "[X] at [Y]",
    "[Y] - born [X]",
    "[X] , born in [Y]",
    "[X] , who was born in [Y]",
    "award : [X] , [Y]",
]

# Keep per-template losses so completed results survive an interrupted loop.
place_lived_candidate_losses = {}
for template in templates:
    print(template)
    loss = run(relation, template)
    place_lived_candidate_losses[template] = loss
    print(loss)

# 9.536
# NOTE(review): unlabelled figure -- the template that produced it is not
# recorded, and the saved output of this cell stops partway through the first template.

[X] , [Y]


 48%|████▊     | 843/1758 [03:20<04:43,  3.23it/s]

In [None]:
relation = '/location/location/contains'
templates = [
    "[X] geographically contains [Y] .",
    "[X] contains [Y] .",
    "[Y] located in [X] .",
    "[Y] is located in [X] .",
    "[Y] is in [X] .",
    "[Y] [X]",
    "[X] [Y]",
    "[Y] , [X]",
    "people from [Y] , [X]",
    "[Y] , [X] ,",
    "[Y] in [X]",
    "[Y] and [X]",
    "[Y] , the [X]",
    "[Y] , in [X]",
    "in [Y] , [X]",
    "[X] , [Y]",
    "university of [X] , [Y]",
    "east [Y] , [X]",
    "university of [Y] in [X]",
    "[X] populated places on the [Y]",
    "school in [Y] , [X] ,",
    "city of [Y] , [X]",
    "[Y] district of [X]",
    "school in [Y] , [X]",
    "home in [Y] , [X] ,"
    ]

# Print as we go and keep per-template losses, so completed results survive
# an interrupted loop and can be ranked afterwards.
contains_losses = {}
for template in templates:
    print(template)
    loss = run(relation, template)
    contains_losses[template] = loss
    print(loss)

In [4]:
relation = '/people/person/nationality'

# Each template took roughly 23-55 minutes in the saved run, so this loop is
# long; results are kept in a dict so completed templates are not lost when
# it is interrupted.
templates = [
    "The nationality of [X] is [Y] .",
    "[X]'s nationality is [Y] .",
    "nationality of [X] - [Y] .",
    "[X] is from [Y] .",
    "[X] is from [Y] (country) .",
    "[X] born in [Y] .",
    "[X] was born in [Y] .",
    "[X] is a [Y] citizen .",
    "[Y] [X]",
    "[X] ( [Y] )",
    "[X] is an [Y]",
    "[X] in [Y]",
    "[X] as [Y]",
    "[X] , [Y]",
    "[Y] , [X]",
    "[X] ( [Y] ),",
    "[X] of [Y]",
    "[Y] , [X] ,",
    "[X] of [Y] ,",
    "[X] in , [Y]",
    "[Y] - [X]",
    "[X] from [Y] ,",
    "[X] in the [Y]",
    "[X] , the [Y]",
    "[X] from [Y]",
    "[X] and the [Y]",
    "[X] - [Y]",
    "[X] of the times of [Y]",
    "[X] , [Y] )",
    "( [Y] ) [X]",
    "[X] [Y]",
    "[Y] ' s [X]",
    "[X] , a [Y]",
    "[X] , an [Y]",
    "[Y] and [X]",
    "[X] ' s [Y]",
    "[X] , [Y] ,",
    "[X] , [Y] '",
    "[Y] under [X]",
    "[Y] by [X]",
    "[X] , - [Y]",
    "[Y] captain [X]",
    "[Y] ' s [X] ,",
    "[X] of the [Y]",
    "[X] and [Y]",
    "[Y] : [X]",
    "[Y] with [X]",
    "[Y] after [X]",
    "[X] [Y] '",
    "[X] the [Y]",
    "[Y] leader [X]"
    ]

nationality_losses = {}
for template in templates:
    print(template)
    loss = run(relation, template)
    nationality_losses[template] = loss
    print(loss)

The nationality of [X] is [Y] .


100%|██████████| 2494/2494 [38:53<00:00,  1.07it/s] 


8.225764293140836
[X]'s nationality is [Y] .


100%|██████████| 2494/2494 [37:05<00:00,  1.12it/s] 


5.785918122430642
nationality of [X] - [Y] .


100%|██████████| 2494/2494 [29:12<00:00,  1.42it/s] 


5.86999755523271
[X] is from [Y] .


100%|██████████| 2494/2494 [24:32<00:00,  1.69it/s] 


3.2586262934323815
[X] is from [Y] (country) .


100%|██████████| 2494/2494 [55:04<00:00,  1.33s/it] 


3.231822348733743
[X] born in [Y] .


100%|██████████| 2494/2494 [25:16<00:00,  1.64it/s] 


2.8473114577742913
[X] was born in [Y] .


100%|██████████| 2494/2494 [33:19<00:00,  1.25it/s] 


3.464634246604724
[X] is a [Y] citizen .


100%|██████████| 2494/2494 [28:32<00:00,  1.46it/s] 


9.536245273219215
[Y] [X]


100%|██████████| 2494/2494 [29:04<00:00,  1.43it/s] 


10.08024020936754
[X] ( [Y] )


100%|██████████| 2494/2494 [23:03<00:00,  1.80it/s] 


7.576448968119092
[X] is an [Y]


100%|██████████| 2494/2494 [28:35<00:00,  1.45it/s] 


16.982728010813396
[X] in [Y]


 15%|█▍        | 366/2494 [02:22<13:50,  2.56it/s]  


KeyboardInterrupt: 