In [54]:
from transformers import AutoTokenizer, AutoModelForMaskedLM, BertTokenizer, BertForMaskedLM
import torch
from tqdm import tqdm

# Masked LM used to score the templates below. To try LUKE instead, load
# AutoTokenizer / AutoModelForMaskedLM with "studio-ousia/luke-base".
MODEL_NAME = "bert-base-uncased"

tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = BertForMaskedLM.from_pretrained(MODEL_NAME)
# from_pretrained already returns the model in eval mode; make it explicit so
# dropout stays off and the losses computed by run() are deterministic.
model.eval()

def _load_triples(path, relation):
    """Read a tab-separated ``head<TAB>relation<TAB>tail`` file, keeping one relation.

    Args:
        path: path to the triples file, one triple per line.
        relation: relation string to keep (exact string equality).

    Returns:
        List of ``(head, relation, tail)`` string tuples, in file order.

    Raises:
        ValueError: if a non-blank line does not have exactly three fields.
    """
    triples = []
    # `with` guarantees the handle is closed; assumes the dataset is UTF-8 — confirm.
    with open(path, encoding='utf-8') as f:
        for lineno, line in enumerate(f, start=1):
            # Strip only the line terminator (not "last character"), so a final
            # line without a trailing newline keeps its full tail entity.
            line = line.rstrip('\n')
            if not line:
                continue  # tolerate blank / trailing empty lines
            fields = line.split('\t')
            if len(fields) != 3:
                raise ValueError(
                    f"{path}:{lineno}: expected 3 tab-separated fields, got {len(fields)}"
                )
            h, r, t = fields
            if r == relation:
                triples.append((h, r, t))
    return triples


def _triple_loss(triple, template):
    """Masked-LM loss of the module-level `model` on the [MASK] positions of one triple.

    Uses the module-level `tokenizer`, `model` and `convert_from_triple_to_sentence`.
    The last is defined elsewhere; presumably it fills [X]/[Y] from the triple and,
    with ``mask=True``, puts [MASK] token(s) where the tail goes — confirm.

    Returns:
        The loss as a float, or None when the triple cannot be scored: either the
        masked and unmasked sentences tokenize to different lengths (labels cannot
        be aligned position-by-position), or the masked sentence has no [MASK].
    """
    text_masked = convert_from_triple_to_sentence(triple=triple, template=template, mask=True)
    text_label = convert_from_triple_to_sentence(triple=triple, template=template)
    inputs = tokenizer(text_masked, return_tensors="pt")
    labels = tokenizer(text_label, return_tensors="pt")["input_ids"]

    # Different lengths mean the label ids cannot be lined up with the input ids.
    if labels.shape != inputs.input_ids.shape:
        return None
    is_mask = inputs.input_ids == tokenizer.mask_token_id
    if not is_mask.any():
        # Every label would be -100, giving a NaN loss that would poison the mean.
        return None

    # Score only the [MASK] positions; -100 is the label index the HF loss ignores.
    labels = torch.where(is_mask, labels, -100)
    # Inference only: no autograd graph needed.
    with torch.no_grad():
        outputs = model(**inputs, labels=labels)
    return outputs.loss.item()


def run(relation, template, path='../../dataset/FB60K-NYT10-100/test.txt'):
    """Average masked-LM loss of `template` over all test triples of `relation`.

    For each triple, the template is rendered masked and unmasked by
    `convert_from_triple_to_sentence`. The model's loss on the [MASK] positions is
    computed, and the per-triple losses are averaged. Triples that cannot be
    aligned (see `_triple_loss`) are skipped.

    Args:
        relation: relation to evaluate, e.g. '/people/person/place_lived'.
        template: template string with [X] / [Y] placeholders.
        path: triples file with one tab-separated (head, relation, tail) per line.

    Returns:
        Mean loss over the scored triples (lower = the model predicts the masked
        tokens better), or ``float('nan')`` if no triple could be scored.
    """
    triples = _load_triples(path, relation)

    total_loss = 0.0
    cnt = 0
    for triple in tqdm(triples):
        triple_loss = _triple_loss(triple, template)
        if triple_loss is None:
            continue
        total_loss += triple_loss
        cnt += 1

    if cnt == 0:
        return float('nan')
    return total_loss / cnt



Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
# Score a few hand-written templates for the place_lived relation; run()
# returns the average masked-LM loss of each template (lower is better).
relation = '/people/person/place_lived'
templates = [
    '[X] lives in [Y]',
    '[X] lived in [Y]',
    '[X] , born in [Y]',
]

for template in templates:
    print(template)
    loss = run(relation=relation, template=template)
    print(loss)

# NOTE(review): an earlier run recorded "11 16 11" here — presumably the
# rounded losses of the three templates above, in order; confirm.

In [58]:
# Larger batch of candidate templates for place_lived. `relation` is set here
# explicitly so this cell does not depend on another cell having been run.
relation = '/people/person/place_lived'
templates = [
    # "[Y] , [X]",
    "[X] , [Y]",
    "[Y] [X]",
    "[X] [Y]",
    "[X] - [Y]",
    "[X] and in [Y]",
    "[X] lived in [Y]",
    "[X] , [Y] ,",
    "[X] at the [Y]",
    "[X] from the [Y]",
    "[Y] [X] ,",
    "[X] at [Y]",
    "[Y] - born [X]",
    "[X] , born in [Y]",
    "[X] , who was born in [Y]",
    "award : [X] , [Y]",
]
for template in templates:
    print(template)
    loss = run(relation, template)
    print(loss)

# NOTE(review): "9.536" was recorded here without context — possibly the loss
# of the commented-out "[Y] , [X]" template; confirm.

[X] , [Y]


100%|██████████| 1758/1758 [17:30<00:00,  1.67it/s]


13.38991347293714
[Y] [X]


100%|██████████| 1758/1758 [20:01<00:00,  1.46it/s]


11.40917787927411
[X] [Y]


100%|██████████| 1758/1758 [23:56<00:00,  1.22it/s] 


19.017863129521466
[X] - [Y]


100%|██████████| 1758/1758 [19:50<00:00,  1.48it/s]


15.160402970436293
[X] and in [Y]


100%|██████████| 1758/1758 [20:00<00:00,  1.46it/s]


16.080653077080136
[X] lived in [Y]


100%|██████████| 1758/1758 [22:12<00:00,  1.32it/s] 


15.239846391992254
[X] , [Y] ,


100%|██████████| 1758/1758 [20:04<00:00,  1.46it/s]


8.072169558638912
[X] at the [Y]


100%|██████████| 1758/1758 [24:20<00:00,  1.20it/s] 


13.412744543054602
[X] from the [Y]


100%|██████████| 1758/1758 [20:01<00:00,  1.46it/s]


12.315905910271864
[Y] [X] ,


100%|██████████| 1758/1758 [18:10<00:00,  1.61it/s]


11.163660777357471
[X] at [Y]


100%|██████████| 1758/1758 [22:21<00:00,  1.31it/s]


16.10741245702946
[Y] - born [X]


100%|██████████| 1758/1758 [19:49<00:00,  1.48it/s]


6.795481565752964
[X] , born in [Y]


100%|██████████| 1758/1758 [20:32<00:00,  1.43it/s] 


11.677484704024625
[X] , who was born in [Y]


100%|██████████| 1758/1758 [32:00<00:00,  1.09s/it] 


13.455974311619014
award : [X] , [Y]


100%|██████████| 1758/1758 [20:41<00:00,  1.42it/s]

11.337127789909587





In [59]:
# Candidate templates for the location/contains relation ([X] contains [Y]);
# each is scored by run() as an average masked-LM loss (lower is better).
relation = '/location/location/contains'
templates = [
    "[X] geographically contains [Y] .",
    "[X] contains [Y] .",
    "[Y] located in [X] .",
    "[Y] is located in [X] .",
    "[Y] is in [X] .",
    "[Y] [X]",
    "[X] [Y]",
    "[Y] , [X]",
    "people from [Y] , [X]",
    "[Y] , [X] ,",
    "[Y] in [X]",
    "[Y] and [X]",
    "[Y] , the [X]",
    "[Y] , in [X]",
    "in [Y] , [X]",
    "[X] , [Y]",
    "university of [X] , [Y]",
    "east [Y] , [X]",
    "university of [Y] in [X]",
    "[X] populated places on the [Y]",
    "school in [Y] , [X] ,",
    "city of [Y] , [X]",
    "[Y] district of [X]",
    "school in [Y] , [X]",
    "home in [Y] , [X] ,",
]

for template in templates:
    print(template)
    loss = run(relation=relation, template=template)
    print(loss)

[X] geographically contains [Y] .


100%|██████████| 2941/2941 [40:01<00:00,  1.22it/s]  


8.066132125692704
[X] contains [Y] .


100%|██████████| 2941/2941 [14:02<00:00,  3.49it/s]


9.507862770814024
[Y] located in [X] .


100%|██████████| 2941/2941 [13:34<00:00,  3.61it/s]


12.145631687685999
[Y] is located in [X] .


100%|██████████| 2941/2941 [21:51<00:00,  2.24it/s]


10.213791070346938
[Y] is in [X] .


100%|██████████| 2941/2941 [13:02<00:00,  3.76it/s]


9.13350577041097
[Y] [X]


100%|██████████| 2941/2941 [12:47<00:00,  3.83it/s]


10.506316365277888
[X] [Y]


100%|██████████| 2941/2941 [15:03<00:00,  3.26it/s]


21.505543280958957
[Y] , [X]


100%|██████████| 2941/2941 [11:46<00:00,  4.16it/s]


7.311318049853511
people from [Y] , [X]


100%|██████████| 2941/2941 [03:14<00:00, 15.16it/s]


7.186374879230216
[Y] , [X] ,


100%|██████████| 2941/2941 [03:16<00:00, 15.00it/s]


7.361007068125183
[Y] in [X]


100%|██████████| 2941/2941 [03:45<00:00, 13.02it/s]


11.420278889117641
[Y] and [X]


100%|██████████| 2941/2941 [03:26<00:00, 14.21it/s]


9.049167473348971
[Y] , the [X]


100%|██████████| 2941/2941 [03:18<00:00, 14.81it/s]


9.234255067615967
[Y] , in [X]


100%|██████████| 2941/2941 [03:20<00:00, 14.67it/s]


8.504075358086704
in [Y] , [X]


100%|██████████| 2941/2941 [03:18<00:00, 14.82it/s]


7.983050687173038
[X] , [Y]


100%|██████████| 2941/2941 [03:57<00:00, 12.36it/s]


16.849903240562693
university of [X] , [Y]


100%|██████████| 2941/2941 [03:31<00:00, 13.93it/s]


13.710433042342688
east [Y] , [X]


100%|██████████| 2941/2941 [03:37<00:00, 13.53it/s]


9.254566068324367
university of [Y] in [X]


100%|██████████| 2941/2941 [03:42<00:00, 13.22it/s]


10.64755086078849
[X] populated places on the [Y]


100%|██████████| 2941/2941 [05:21<00:00,  9.16it/s]


12.788118642594466
school in [Y] , [X] ,


100%|██████████| 2941/2941 [04:30<00:00, 10.88it/s]


7.331460873544311
city of [Y] , [X]


100%|██████████| 2941/2941 [03:40<00:00, 13.33it/s]


8.195436552876677
[Y] district of [X]


100%|██████████| 2941/2941 [03:37<00:00, 13.54it/s]


10.813832975654456
school in [Y] , [X]


100%|██████████| 2941/2941 [04:11<00:00, 11.68it/s]


7.279377772047234
home in [Y] , [X] ,


100%|██████████| 2941/2941 [04:36<00:00, 10.62it/s]

7.593409359094912





In [None]:
# Candidate templates for the person/nationality relation ([X] = person,
# [Y] = nationality); each is scored by run() as an average masked-LM loss.
relation = '/people/person/nationality'

templates = [
    # Hand-written natural-language templates.
    "The nationality of [X] is [Y] .",
    "[X]'s nationality is [Y] .",
    "nationality of [X] - [Y] .",
    "[X] is from [Y] .",
    "[X] is from [Y] (country) .",
    "[X] born in [Y] .",
    "[X] was born in [Y] .",
    "[X] is a [Y] citizen .",
    # Short surface patterns.
    "[Y] [X]",
    "[X] ( [Y] )",
    "[X] is an [Y]",
    "[X] in [Y]",
    "[X] as [Y]",
    "[X] , [Y]",
    "[Y] , [X]",
    "[X] ( [Y] ),",
    "[X] of [Y]",
    "[Y] , [X] ,",
    "[X] of [Y] ,",
    "[X] in , [Y]",
    "[Y] - [X]",
    "[X] from [Y] ,",
    "[X] in the [Y]",
    "[X] , the [Y]",
    "[X] from [Y]",
    "[X] and the [Y]",
    "[X] - [Y]",
    "[X] of the times of [Y]",
    "[X] , [Y] )",
    "( [Y] ) [X]",
    "[X] [Y]",
    "[Y] ' s [X]",
    "[X] , a [Y]",
    "[X] , an [Y]",
    "[Y] and [X]",
    "[X] ' s [Y]",
    "[X] , [Y] ,",
    "[X] , [Y] '",
    "[Y] under [X]",
    "[Y] by [X]",
    "[X] , - [Y]",
    "[Y] captain [X]",
    "[Y] ' s [X] ,",
    "[X] of the [Y]",
    "[X] and [Y]",
    "[Y] : [X]",
    "[Y] with [X]",
    "[Y] after [X]",
    "[X] [Y] '",
    "[X] the [Y]",
    "[Y] leader [X]",
]

for template in templates:
    print(template)
    loss = run(relation=relation, template=template)
    print(loss)