In [16]:
import transformers as trf
import numpy as np
import torch as pt
import pandas as pd
import tqdm
import json
from collections import OrderedDict

In [21]:
class HeadlessBert(pt.nn.Module):
    """BERT encoder whose hidden states are projected directly onto the vocabulary.

    The only layer after the encoder is a bias-free linear projection whose
    weight starts as a copy of the encoder's input word-embedding matrix.

    Args:
        bert_ver: name or path handed to ``BertModel.from_pretrained``.
    """

    def __init__(self, bert_ver='bert-base-uncased'):
        super().__init__()
        self.bert = trf.BertModel.from_pretrained(bert_ver)
        self.cls = pt.nn.Sequential(OrderedDict([
            ('linear', pt.nn.Linear(self.bert.config.hidden_size, self.bert.config.vocab_size, bias=False))
        ]))
        # clone() gives the projection its own tensor, so it is initialised from
        # the embedding matrix but is not the same Parameter object.
        self.cls.linear.weight = pt.nn.Parameter(self.bert.embeddings.word_embeddings.weight.clone())

    def forward(self, X):
        """Return vocabulary scores for every position of the token-id tensor ``X``.

        Callers in this notebook pass a ``(1, seq_len)`` tensor and index the
        result with ``[0, position]``.
        """
        # Take element 0 instead of unpacking ``a, b = ...``: unpacking breaks
        # when the encoder returns more than two items, and on dict-style model
        # outputs it iterates over the keys rather than the tensors. Indexing
        # works for both and matches BertLMImpl.forward.
        bert_out = self.bert(input_ids=X)[0]
        return self.cls(bert_out)
class Add(pt.nn.Module):
    """Learnable additive offset: ``forward(X)`` returns ``weight + X``.

    ``weight`` is a trainable parameter created as ``pt.zeros(size)``, so a
    freshly built layer leaves its input unchanged.
    """

    def __init__(self, size):
        super(Add, self).__init__()
        initial_offset = pt.zeros(size)
        self.weight = pt.nn.Parameter(initial_offset)
        self.size = size

    def forward(self, X):
        """Add the offset parameter to ``X``."""
        return pt.add(self.weight, X)

class Scale(pt.nn.Module):
    """Learnable element-wise gain: ``forward(X)`` returns ``weight * X``.

    ``weight`` is a trainable parameter created as ``pt.ones(size)``, so a
    freshly built layer leaves its input unchanged.
    """

    def __init__(self, size):
        super(Scale, self).__init__()
        initial_gain = pt.ones(size)
        self.weight = pt.nn.Parameter(initial_gain)
        self.size = size

    def forward(self, X):
        """Multiply ``X`` by the gain parameter."""
        return pt.mul(self.weight, X)
class GeLU(pt.nn.Module):
    """Module wrapper around ``pt.nn.functional.gelu`` so it can sit in a ``Sequential``."""

    def __init__(self):
        super().__init__()

    def forward(self, X):
        """Apply GELU to ``X`` and return the result."""
        activation = pt.nn.functional.gelu
        return activation(X)
    
class LMHeadTrim(pt.nn.Module):
    """Masked-LM head of a pretrained BERT, unrolled into single steps and truncated.

    The head's transform is split into six steps, in this order:
    ``dense_weight``, ``dense_bias``, ``gelu``, ``layernorm_center``,
    ``layernorm_scale``, ``layernorm_bias``. Only the first ``trim`` of them are
    kept; the bias-free vocabulary projection ``decoder_weight`` always follows.

    Args:
        bert_variant: name or path handed to ``BertForMaskedLM.from_pretrained``.
        trim: number of leading transform steps to keep (0 keeps none, 6 keeps
            all). It is used as a slice bound, so values above 6 behave like 6.
        decoder_bias: when true, append the pretrained decoder bias as a final
            ``decoder_bias`` step. The default ``False`` reproduces the previous
            behaviour, where the bias was loaded but never applied.
    """

    def __init__(self, bert_variant, trim, decoder_bias=False):
        super(LMHeadTrim, self).__init__()
        # The full masked-LM model is loaded only to read its head weights; it
        # is not stored on the module.
        tmp = trf.BertForMaskedLM.from_pretrained(bert_variant)
        hidden_size = tmp.config.hidden_size
        vocab_size = tmp.config.vocab_size
        transform = tmp.cls.predictions.transform
        decoder = tmp.cls.predictions.decoder

        dense_w = pt.nn.Linear(hidden_size, hidden_size, bias=False)
        dense_b = Add(hidden_size)
        # This LayerNorm's own affine parameters are never overwritten; the
        # pretrained gain and bias are applied by the two explicit steps below.
        layernorm_c = pt.nn.LayerNorm(eps=tmp.config.layer_norm_eps,
                                      normalized_shape=(hidden_size,))
        layernorm_w = Scale(hidden_size)
        layernorm_b = Add(hidden_size)
        decoder_w = pt.nn.Linear(hidden_size, vocab_size, bias=False)
        # Only build the bias step when it will actually be used.
        decoder_b = Add(vocab_size) if decoder_bias else None

        with pt.no_grad():
            dense_w.weight.copy_(transform.dense.weight)
            dense_b.weight.copy_(transform.dense.bias)
            layernorm_w.weight.copy_(transform.LayerNorm.weight)
            layernorm_b.weight.copy_(transform.LayerNorm.bias)
            decoder_w.weight.copy_(decoder.weight)
            if decoder_b is not None:
                decoder_b.weight.copy_(decoder.bias)

        steps = [
            ('dense_weight', dense_w),
            ('dense_bias', dense_b),
            ('gelu', GeLU()),
            ('layernorm_center', layernorm_c),
            ('layernorm_scale', layernorm_w),
            ('layernorm_bias', layernorm_b),
        ][:trim]
        steps.append(('decoder_weight', decoder_w))
        if decoder_b is not None:
            steps.append(('decoder_bias', decoder_b))
        self.seq = pt.nn.Sequential(OrderedDict(steps))

    def forward(self, X):
        """Run ``X`` through the kept transform steps and the vocabulary projection."""
        return self.seq(X)
        
class BertLMImpl(pt.nn.Module):
    """Pretrained BERT encoder followed by an ``LMHeadTrim`` head.

    Args:
        trim: forwarded to ``LMHeadTrim`` as its ``trim`` argument.
        bert_variant: checkpoint name or path used for both the encoder and the head.
    """

    def __init__(self, trim, bert_variant='bert-base-uncased'):
        super().__init__()
        self.bert = trf.BertModel.from_pretrained(bert_variant)
        self.head = LMHeadTrim(bert_variant, trim)

    def forward(self, X):
        """Encode token ids ``X`` and feed element 0 of the encoder output to the head."""
        encoder_out = self.bert(input_ids=X)
        hidden_states = encoder_out[0]
        return self.head(hidden_states)

In [3]:
# Tokenizer, the stock masked-LM model (baseline) and the head-less variant,
# all built from the same 'bert-base-uncased' checkpoint.
tkr = trf.BertTokenizer.from_pretrained('bert-base-uncased')
bertlm = trf.BertForMaskedLM.from_pretrained('bert-base-uncased')
berthl = HeadlessBert(bert_ver='bert-base-uncased')

In [4]:
# Load the ConceptNet test file: one JSON object per line.
data = []
path_to_conceptnet = '../data/ConceptNet/'
with open(f'{path_to_conceptnet}/test.jsonl', 'r', encoding='utf-8') as f:
    # Stream the file line by line instead of materialising it with readlines().
    for line in f:
        # A blank line is not a JSON document and would make json.loads raise.
        if line.strip():
            data.append(json.loads(line))

In [5]:
def get_obj(x):
    """Return the ``'obj_label'`` entry of record ``x``.

    If the key is missing, or ``x`` cannot be subscripted with a string, the
    record is printed for inspection and ``None`` is returned.
    """
    try:
        return x['obj_label']
    except (KeyError, TypeError):
        # Only lookup failures are tolerated. A bare ``except`` would also
        # swallow KeyboardInterrupt and hide unrelated bugs.
        print(x)
        return None

# Gold answer (object label) of every record, in dataset order.
correct_ans = [get_obj(record) for record in data]

In [6]:
# Look up each answer as a single vocabulary token; every entry is a list of ids.
correct_ans_tks = [tkr.convert_tokens_to_ids([label]) for label in correct_ans]

In [7]:
# Sanity check: list every answer that is NOT a single in-vocabulary token.
# convert_tokens_to_ids([x]) returns one id per input token, so a length test
# alone can never fire; an unknown answer shows up as the tokenizer's unk id.
# An empty list means the data is clean.
[ans for ans, ids in zip(correct_ans, correct_ans_tks)
 if len(ids) != 1 or ids[0] == tkr.unk_token_id]

[]

In [8]:
# Unwrap each single-element id list. Entries that are already plain ids are
# kept unchanged, so re-running this cell no longer raises TypeError.
correct_ans_tks = [ids[0] if isinstance(ids, list) else ids for ids in correct_ans_tks]

In [9]:
# Encode the first masked sentence of every record into token ids.
# tqdm wraps the list itself (not a map iterator), so it knows the total and
# can show a percentage and ETA.
data_tks = [tkr.encode(record['masked_sentences'][0]) for record in tqdm.tqdm(data)]

29774it [00:07, 3817.76it/s]


In [10]:
# Position of the first mask token inside each encoded sentence.
data_tks_targets = [token_ids.index(tkr.mask_token_id) for token_ids in data_tks]

In [11]:
# STATS
# Hits@k of the stock masked-LM model (`bertlm`) at the mask position.
h1 = h2 = h5 = h10 = h20 = h50 = h100 = 0
total = 0
with pt.no_grad():
    for tk, target, ans in tqdm.tqdm(zip(data_tks, data_tks_targets, correct_ans_tks),
                                     total=len(data_tks)):
        total += 1
        scores = bertlm(pt.tensor(tk).unsqueeze(0))[0][0, target]
        # Vocabulary ids ordered from highest to lowest score. tolist() yields
        # plain ints, so the lookup below does not compare 0-d tensors one by one.
        top_ranks = (-scores).argsort()[:100].tolist()
        if ans not in top_ranks:
            # Outside the top 100: counted in `total` only. This replaces a bare
            # `except:` that would also have swallowed KeyboardInterrupt.
            continue
        idx = top_ranks.index(ans)
        # Cumulative counters: rank idx counts towards every k with idx < k.
        h1 += idx < 1
        h2 += idx < 2
        h5 += idx < 5
        h10 += idx < 10
        h20 += idx < 20
        h50 += idx < 50
        h100 += 1

29774it [20:51, 23.80it/s]


In [12]:
# Print the raw hits@k counters (counts, not rates) and the number of examples
# currently held in h1..h100 and total.
print(f"""H@1: {h1}
H@2: {h2}
H@5: {h5}
H@10: {h10}
H@20: {h20}
H@50: {h50}
H@100: {h100}
Total: {total}""")

H@1: 3710
H@2: 5285
H@5: 7743
H@10: 9460
H@20: 11219
H@50: 13825
H@100: 15843
Total: 29774


In [14]:
# STATS
# Hits@k of the head-less model (`berthl`) at the mask position.
h1 = h2 = h5 = h10 = h20 = h50 = h100 = 0
total = 0
with pt.no_grad():
    for tk, target, ans in tqdm.tqdm(zip(data_tks, data_tks_targets, correct_ans_tks),
                                     total=len(data_tks)):
        total += 1
        scores = berthl(pt.tensor(tk).unsqueeze(0))[0, target]
        # Vocabulary ids ordered from highest to lowest score. tolist() yields
        # plain ints, so the lookup below does not compare 0-d tensors one by one.
        top_ranks = (-scores).argsort()[:100].tolist()
        if ans not in top_ranks:
            # Outside the top 100: counted in `total` only. This replaces a bare
            # `except:` that would also have swallowed KeyboardInterrupt.
            continue
        idx = top_ranks.index(ans)
        # Cumulative counters: rank idx counts towards every k with idx < k.
        h1 += idx < 1
        h2 += idx < 2
        h5 += idx < 5
        h10 += idx < 10
        h20 += idx < 20
        h50 += idx < 50
        h100 += 1

29774it [16:49, 29.48it/s]


In [15]:
# Print the raw hits@k counters (counts, not rates) and the number of examples
# currently held in h1..h100 and total.
print(f"""H@1: {h1}
H@2: {h2}
H@5: {h5}
H@10: {h10}
H@20: {h20}
H@50: {h50}
H@100: {h100}
Total: {total}""")

H@1: 0
H@2: 0
H@5: 36
H@10: 66
H@20: 116
H@50: 349
H@100: 1311
Total: 29774


In [22]:
bert = BertLMImpl(trim=0)
# STATS
# Hits@k with trim=0: no transform step is kept, so the encoder output goes
# straight into the decoder weights.
h1 = h2 = h5 = h10 = h20 = h50 = h100 = 0
total = 0
with pt.no_grad():
    for tk, target, ans in tqdm.tqdm(zip(data_tks, data_tks_targets, correct_ans_tks),
                                     total=len(data_tks)):
        total += 1
        scores = bert(pt.tensor(tk).unsqueeze(0))[0, target]
        # Vocabulary ids ordered from highest to lowest score. tolist() yields
        # plain ints, so the lookup below does not compare 0-d tensors one by one.
        top_ranks = (-scores).argsort()[:100].tolist()
        if ans not in top_ranks:
            # Outside the top 100: counted in `total` only. This replaces a bare
            # `except:` that would also have swallowed KeyboardInterrupt.
            continue
        idx = top_ranks.index(ans)
        # Cumulative counters: rank idx counts towards every k with idx < k.
        h1 += idx < 1
        h2 += idx < 2
        h5 += idx < 5
        h10 += idx < 10
        h20 += idx < 20
        h50 += idx < 50
        h100 += 1

29774it [18:20, 27.05it/s]


In [23]:
# Print the raw hits@k counters (counts, not rates) and the number of examples
# currently held in h1..h100 and total.
print(f"""H@1: {h1}
H@2: {h2}
H@5: {h5}
H@10: {h10}
H@20: {h20}
H@50: {h50}
H@100: {h100}
Total: {total}""")

H@1: 0
H@2: 0
H@5: 36
H@10: 66
H@20: 116
H@50: 349
H@100: 1311
Total: 29774


In [24]:
bert = BertLMImpl(trim=1)
# STATS
# Hits@k with trim=1: only `dense_weight` is kept in front of the decoder weights.
h1 = h2 = h5 = h10 = h20 = h50 = h100 = 0
total = 0
with pt.no_grad():
    for tk, target, ans in tqdm.tqdm(zip(data_tks, data_tks_targets, correct_ans_tks),
                                     total=len(data_tks)):
        total += 1
        scores = bert(pt.tensor(tk).unsqueeze(0))[0, target]
        # Vocabulary ids ordered from highest to lowest score. tolist() yields
        # plain ints, so the lookup below does not compare 0-d tensors one by one.
        top_ranks = (-scores).argsort()[:100].tolist()
        if ans not in top_ranks:
            # Outside the top 100: counted in `total` only. This replaces a bare
            # `except:` that would also have swallowed KeyboardInterrupt.
            continue
        idx = top_ranks.index(ans)
        # Cumulative counters: rank idx counts towards every k with idx < k.
        h1 += idx < 1
        h2 += idx < 2
        h5 += idx < 5
        h10 += idx < 10
        h20 += idx < 20
        h50 += idx < 50
        h100 += 1

29774it [18:13, 27.23it/s]


In [25]:
# Print the raw hits@k counters (counts, not rates) and the number of examples
# currently held in h1..h100 and total.
print(f"""H@1: {h1}
H@2: {h2}
H@5: {h5}
H@10: {h10}
H@20: {h20}
H@50: {h50}
H@100: {h100}
Total: {total}""")

H@1: 66
H@2: 121
H@5: 219
H@10: 405
H@20: 654
H@50: 1241
H@100: 1927
Total: 29774


In [26]:
bert = BertLMImpl(trim=2)
# STATS
# Hits@k with trim=2: `dense_weight` and `dense_bias` are kept in front of the
# decoder weights.
h1 = h2 = h5 = h10 = h20 = h50 = h100 = 0
total = 0
with pt.no_grad():
    for tk, target, ans in tqdm.tqdm(zip(data_tks, data_tks_targets, correct_ans_tks),
                                     total=len(data_tks)):
        total += 1
        scores = bert(pt.tensor(tk).unsqueeze(0))[0, target]
        # Vocabulary ids ordered from highest to lowest score. tolist() yields
        # plain ints, so the lookup below does not compare 0-d tensors one by one.
        top_ranks = (-scores).argsort()[:100].tolist()
        if ans not in top_ranks:
            # Outside the top 100: counted in `total` only. This replaces a bare
            # `except:` that would also have swallowed KeyboardInterrupt.
            continue
        idx = top_ranks.index(ans)
        # Cumulative counters: rank idx counts towards every k with idx < k.
        h1 += idx < 1
        h2 += idx < 2
        h5 += idx < 5
        h10 += idx < 10
        h20 += idx < 20
        h50 += idx < 50
        h100 += 1

29774it [18:05, 27.43it/s]


In [27]:
# Print the raw hits@k counters (counts, not rates) and the number of examples
# currently held in h1..h100 and total.
print(f"""H@1: {h1}
H@2: {h2}
H@5: {h5}
H@10: {h10}
H@20: {h20}
H@50: {h50}
H@100: {h100}
Total: {total}""")

H@1: 60
H@2: 105
H@5: 189
H@10: 349
H@20: 581
H@50: 1147
H@100: 1849
Total: 29774


In [28]:
bert = BertLMImpl(trim=3)
# STATS
# Hits@k with trim=3: the dense layer (weight and bias) and `gelu` are kept in
# front of the decoder weights.
h1 = h2 = h5 = h10 = h20 = h50 = h100 = 0
total = 0
with pt.no_grad():
    for tk, target, ans in tqdm.tqdm(zip(data_tks, data_tks_targets, correct_ans_tks),
                                     total=len(data_tks)):
        total += 1
        scores = bert(pt.tensor(tk).unsqueeze(0))[0, target]
        # Vocabulary ids ordered from highest to lowest score. tolist() yields
        # plain ints, so the lookup below does not compare 0-d tensors one by one.
        top_ranks = (-scores).argsort()[:100].tolist()
        if ans not in top_ranks:
            # Outside the top 100: counted in `total` only. This replaces a bare
            # `except:` that would also have swallowed KeyboardInterrupt.
            continue
        idx = top_ranks.index(ans)
        # Cumulative counters: rank idx counts towards every k with idx < k.
        h1 += idx < 1
        h2 += idx < 2
        h5 += idx < 5
        h10 += idx < 10
        h20 += idx < 20
        h50 += idx < 50
        h100 += 1

29774it [18:19, 27.08it/s]


In [29]:
# Print the raw hits@k counters (counts, not rates) and the number of examples
# currently held in h1..h100 and total.
print(f"""H@1: {h1}
H@2: {h2}
H@5: {h5}
H@10: {h10}
H@20: {h20}
H@50: {h50}
H@100: {h100}
Total: {total}""")

H@1: 4
H@2: 12
H@5: 593
H@10: 978
H@20: 1354
H@50: 2137
H@100: 3247
Total: 29774


In [30]:
bert = BertLMImpl(trim=4)
# STATS
# Hits@k with trim=4: the dense layer, `gelu` and `layernorm_center` are kept in
# front of the decoder weights.
h1 = h2 = h5 = h10 = h20 = h50 = h100 = 0
total = 0
with pt.no_grad():
    for tk, target, ans in tqdm.tqdm(zip(data_tks, data_tks_targets, correct_ans_tks),
                                     total=len(data_tks)):
        total += 1
        scores = bert(pt.tensor(tk).unsqueeze(0))[0, target]
        # Vocabulary ids ordered from highest to lowest score. tolist() yields
        # plain ints, so the lookup below does not compare 0-d tensors one by one.
        top_ranks = (-scores).argsort()[:100].tolist()
        if ans not in top_ranks:
            # Outside the top 100: counted in `total` only. This replaces a bare
            # `except:` that would also have swallowed KeyboardInterrupt.
            continue
        idx = top_ranks.index(ans)
        # Cumulative counters: rank idx counts towards every k with idx < k.
        h1 += idx < 1
        h2 += idx < 2
        h5 += idx < 5
        h10 += idx < 10
        h20 += idx < 20
        h50 += idx < 50
        h100 += 1

29774it [18:12, 27.26it/s]


In [31]:
# Print the raw hits@k counters (counts, not rates) and the number of examples
# currently held in h1..h100 and total.
print(f"""H@1: {h1}
H@2: {h2}
H@5: {h5}
H@10: {h10}
H@20: {h20}
H@50: {h50}
H@100: {h100}
Total: {total}""")

H@1: 1858
H@2: 2571
H@5: 3726
H@10: 4770
H@20: 5984
H@50: 7776
H@100: 9488
Total: 29774


In [32]:
bert = BertLMImpl(trim=5)
# STATS
# Hits@k with trim=5: every transform step except `layernorm_bias` is kept in
# front of the decoder weights.
h1 = h2 = h5 = h10 = h20 = h50 = h100 = 0
total = 0
with pt.no_grad():
    for tk, target, ans in tqdm.tqdm(zip(data_tks, data_tks_targets, correct_ans_tks),
                                     total=len(data_tks)):
        total += 1
        scores = bert(pt.tensor(tk).unsqueeze(0))[0, target]
        # Vocabulary ids ordered from highest to lowest score. tolist() yields
        # plain ints, so the lookup below does not compare 0-d tensors one by one.
        top_ranks = (-scores).argsort()[:100].tolist()
        if ans not in top_ranks:
            # Outside the top 100: counted in `total` only. This replaces a bare
            # `except:` that would also have swallowed KeyboardInterrupt.
            continue
        idx = top_ranks.index(ans)
        # Cumulative counters: rank idx counts towards every k with idx < k.
        h1 += idx < 1
        h2 += idx < 2
        h5 += idx < 5
        h10 += idx < 10
        h20 += idx < 20
        h50 += idx < 50
        h100 += 1

29774it [18:13, 27.23it/s]


In [33]:
# Print the raw hits@k counters (counts, not rates) and the number of examples
# currently held in h1..h100 and total.
print(f"""H@1: {h1}
H@2: {h2}
H@5: {h5}
H@10: {h10}
H@20: {h20}
H@50: {h50}
H@100: {h100}
Total: {total}""")

H@1: 3737
H@2: 5339
H@5: 7581
H@10: 9437
H@20: 11224
H@50: 13726
H@100: 15628
Total: 29774


In [34]:
bert = BertLMImpl(trim=6)
# STATS
# Hits@k with trim=6: all six transform steps are kept in front of the decoder
# weights (the decoder bias is still not applied).
h1 = h2 = h5 = h10 = h20 = h50 = h100 = 0
total = 0
with pt.no_grad():
    for tk, target, ans in tqdm.tqdm(zip(data_tks, data_tks_targets, correct_ans_tks),
                                     total=len(data_tks)):
        total += 1
        scores = bert(pt.tensor(tk).unsqueeze(0))[0, target]
        # Vocabulary ids ordered from highest to lowest score. tolist() yields
        # plain ints, so the lookup below does not compare 0-d tensors one by one.
        top_ranks = (-scores).argsort()[:100].tolist()
        if ans not in top_ranks:
            # Outside the top 100: counted in `total` only. This replaces a bare
            # `except:` that would also have swallowed KeyboardInterrupt.
            continue
        idx = top_ranks.index(ans)
        # Cumulative counters: rank idx counts towards every k with idx < k.
        h1 += idx < 1
        h2 += idx < 2
        h5 += idx < 5
        h10 += idx < 10
        h20 += idx < 20
        h50 += idx < 50
        h100 += 1

29774it [18:08, 27.34it/s]


In [35]:
# Print the raw hits@k counters (counts, not rates) and the number of examples
# currently held in h1..h100 and total.
print(f"""H@1: {h1}
H@2: {h2}
H@5: {h5}
H@10: {h10}
H@20: {h20}
H@50: {h50}
H@100: {h100}
Total: {total}""")

H@1: 3762
H@2: 5359
H@5: 7804
H@10: 9579
H@20: 11224
H@50: 13862
H@100: 15905
Total: 29774
