In [None]:
# --- Imports -----------------------------------------------------------------
import string

import torch
from transformers import (
    BartForConditionalGeneration,
    BartTokenizer,
    BertForMaskedLM,
    BertTokenizer,
    ElectraForMaskedLM,
    ElectraTokenizer,
    RobertaForMaskedLM,
    RobertaTokenizer,
    XLMRobertaForMaskedLM,
    XLMRobertaTokenizer,
    XLNetLMHeadModel,
    XLNetTokenizer,
)

# --- Pretrained checkpoints --------------------------------------------------
# One tokenizer/model pair per language model under comparison.  Every model
# is switched to evaluation mode with `.eval()` right after loading, since the
# notebook only runs inference.

# BERT (uncased, base)
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertForMaskedLM.from_pretrained('bert-base-uncased').eval()

# XLNet (cased, base)
xlnet_tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
xlnet_model = XLNetLMHeadModel.from_pretrained('xlnet-base-cased').eval()

# XLM-RoBERTa (base)
xlmroberta_tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-base')
xlmroberta_model = XLMRobertaForMaskedLM.from_pretrained('xlm-roberta-base').eval()

# BART (large)
bart_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
bart_model = BartForConditionalGeneration.from_pretrained('facebook/bart-large').eval()

# ELECTRA (small generator)
electra_tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-generator')
electra_model = ElectraForMaskedLM.from_pretrained('google/electra-small-generator').eval()

# RoBERTa (base)
roberta_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
roberta_model = RobertaForMaskedLM.from_pretrained('roberta-base').eval()

In [None]:
def decode(tokenizer, pred_idx):
    """Turn predicted vocabulary ids into at most ten cleaned token strings.

    Each id in ``pred_idx`` is decoded on its own with ``tokenizer.decode``,
    all whitespace is removed from the decoded text, and WordPiece ``##``
    continuation markers are stripped.  The following tokens are skipped:

    * empty strings,
    * tokens made up solely of ``string.punctuation`` characters,
    * the literal ``'[PAD]'`` and, when the tokenizer exposes one, its own
      ``pad_token``.

    Args:
        tokenizer: object providing ``decode(token_id) -> str``.  A
            ``pad_token`` attribute is optional.
        pred_idx: iterable of token ids, best candidate first.

    Returns:
        list[str]: the first ten surviving tokens, in the order given.
    """
    # Exact / per-character membership on purpose.  Testing
    # `token in string.punctuation + '[PAD]'` is a *substring* test and would
    # also throw away genuine predictions such as 'A', 'P', 'D' or 'PAD'.
    punctuation_chars = set(string.punctuation)
    pad_tokens = {'[PAD]'}
    tokenizer_pad = getattr(tokenizer, 'pad_token', None)
    if tokenizer_pad:
        pad_tokens.add(tokenizer_pad)

    tokens = []
    for w in pred_idx:
        token = ''.join(tokenizer.decode(w).split())
        if not token or token in pad_tokens:
            continue
        if all(ch in punctuation_chars for ch in token):
            continue
        tokens.append(token.replace('##', ''))
    return tokens[:10]


def encode(tokenizer, text_sentence, add_special_tokens=True):
    """Tokenize a sentence containing a ``<mask>`` placeholder.

    The generic ``<mask>`` placeholder is replaced with the tokenizer's own
    mask token.  When the mask token is the last whitespace-separated word of
    the sentence, `` .`` is appended so the mask is followed by a full stop.

    Args:
        tokenizer: object providing ``mask_token``, ``mask_token_id`` and
            ``encode(text, add_special_tokens=...) -> list[int]``.
        text_sentence: input text; expected to contain ``<mask>`` (or the
            tokenizer's mask token) at least once.
        add_special_tokens: forwarded unchanged to ``tokenizer.encode``.

    Returns:
        tuple: ``(input_ids, mask_idx)`` where ``input_ids`` is a tensor with
        a leading batch dimension of one and ``mask_idx`` is the position
        (int) of the first mask token inside that single sequence.

    Raises:
        ValueError: if the encoded sentence contains no mask token (this
            includes empty or whitespace-only input).
    """
    text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)

    # `words` may be empty for blank input; guard before looking at the last word.
    words = text_sentence.split()
    if words and words[-1] == tokenizer.mask_token:
        text_sentence += ' .'

    input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])

    # torch.where on the (1, seq_len) tensor yields (row_indices, column_indices);
    # the column indices are the positions inside the sequence.
    mask_positions = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()
    if not mask_positions:
        raise ValueError(
            "no mask token found in %r; include '<mask>' in the sentence" % (text_sentence,)
        )
    return input_ids, mask_positions[0]

In [None]:
#text_sentence = 'Bremen is a <mask>.'
def _predict_masked_lm(tokenizer, model, text_sentence, top_k=10):
    """Fill the mask with a model that takes only ``input_ids`` and return decoded candidates.

    Encodes ``text_sentence`` with special tokens, runs ``model`` without
    gradient tracking, takes the ``top_k`` highest-scoring vocabulary ids at
    the mask position and passes them through ``decode``.
    """
    input_ids, mask_idx = encode(tokenizer, text_sentence, add_special_tokens=True)
    with torch.no_grad():
        predict = model(input_ids)[0]
    return decode(tokenizer, predict[0, mask_idx, :].topk(top_k).indices.tolist())


def get_all_predictions(text_sentence):
    """Predict the ``<mask>`` in ``text_sentence`` with all six loaded models.

    Args:
        text_sentence: sentence containing a ``<mask>`` placeholder,
            e.g. ``'Bremen is a <mask>.'``.

    Returns:
        tuple: six lists of decoded candidate tokens, in the order
        ``(bert, xlnet, xlm, bart, electra, roberta)``.  Each list holds the
        tokens that ``decode`` keeps out of that model's top-10 ids.
    """
    # ========================= BERT =================================
    bert = _predict_masked_lm(bert_tokenizer, bert_model, text_sentence)

    # ========================= XLNET =================================
    # XLNet is queried differently from the other models: the sentence is
    # encoded without special tokens and the model is given explicit
    # permutation / target tensors.
    input_ids, mask_idx = encode(xlnet_tokenizer, text_sentence, False)
    seq_len = input_ids.shape[1]
    # Per the Transformers XLNet docs, perm_mask[b, i, j] = 1 means token i
    # does not attend to token j; here the mask position is hidden from
    # every token.
    perm_mask = torch.zeros((1, seq_len, seq_len), dtype=torch.float)
    perm_mask[:, :, mask_idx] = 1.0
    # A single prediction target (index 0) pointing at the mask position.
    target_mapping = torch.zeros((1, 1, seq_len), dtype=torch.float)
    target_mapping[0, 0, mask_idx] = 1.0

    with torch.no_grad():
        predict = xlnet_model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping)[0]
    # The output has one row per prediction target, so row 0 is the mask.
    xlnet = decode(xlnet_tokenizer, predict[0, 0, :].topk(10).indices.tolist())

    # ========================= XLM ROBERTA BASE =================================
    xlm = _predict_masked_lm(xlmroberta_tokenizer, xlmroberta_model, text_sentence)

    # ========================= BART =================================
    bart = _predict_masked_lm(bart_tokenizer, bart_model, text_sentence)

    # ========================= ELECTRA =================================
    electra = _predict_masked_lm(electra_tokenizer, electra_model, text_sentence)

    # ========================= ROBERTA =================================
    roberta = _predict_masked_lm(roberta_tokenizer, roberta_model, text_sentence)

    return bert, xlnet, xlm, bart, electra, roberta


In [None]:
# Predicate templates inserted between the entity name and ' <mask>.'.
# Every entry needs a trailing comma: adjacent string literals without one
# are concatenated by Python into a single string.
predicate = [
    # Alternative templates, currently disabled:
    # ' is a',
    # ' is an',
    # ' has class',
    # ' has type',
    # ' is a particular',
    # ' is a specific',
    # ' is an individual',
    # ' is a unique',
    # ' is an example of',
    # ' has superclass',
    ' is also a',
    ' subtype of',
    ' is a subtype of',
    ' subcategory of',
    ' is a category of',
    ' is thereby also a',
    ' is necessarily also a',
]

In [None]:
import random

# -----------------------------------------------------------------------------
# Evaluation against Wikidata.
#
# For every sampled entity, each predicate template is turned into a sentence
# "<entity><template> <mask>." and sent to all six models.  Per model, the
# first TOP_K predictions of every template are pooled into a set; the pooled
# predictions are then compared with the entity's gold values in `wikidata`.
#
# `wikidata` is expected to be defined by an earlier cell as a mapping from
# entity name to a sized container of gold values (it is used with `.items()`,
# `len(...)` and `in`).
#
# NOTE(review): the sample is drawn from the unseeded global `random` state,
# so the reported numbers differ between runs.
# -----------------------------------------------------------------------------

SAMPLE_SIZE = 1000  # entities drawn at random from `wikidata`
MAX_ENTITIES = 100  # entities actually evaluated before the loop stops
TOP_K = 3           # predictions kept per model and per predicate template

# Model names in the order in which get_all_predictions returns its results.
MODEL_NAMES = ('bert', 'xlnet', 'xlm', 'bart', 'electra', 'roberta')


def _mean(values):
    """Return the arithmetic mean of `values` as a float, or NaN if `values` is empty."""
    return sum(values, 0.0) / len(values) if values else float('nan')


def _ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float('nan')


testcount = 0

# Per-entity number of pooled predictions found among the gold values.
bertcountlist = []
xlnetcountlist = []
xlmcountlist = []
bartcountlist = []
electracountlist = []
robertacountlist = []

# Per-entity number of distinct pooled predictions (LM_count1..6 follow MODEL_NAMES).
LM_count1 = []
LM_count2 = []
LM_count3 = []
LM_count4 = []
LM_count5 = []
LM_count6 = []
# Per-entity number of gold values.
WK_count = []

# Name-indexed views on the lists above so the loop can treat all models alike.
hit_counts = dict(zip(MODEL_NAMES, (bertcountlist, xlnetcountlist, xlmcountlist,
                                    bartcountlist, electracountlist, robertacountlist)))
pred_counts = dict(zip(MODEL_NAMES, (LM_count1, LM_count2, LM_count3,
                                     LM_count4, LM_count5, LM_count6)))

# Clamp the sample size: random.sample raises ValueError when asked for more
# items than the population holds.
r_wiki = random.sample(list(wikidata.items()), min(SAMPLE_SIZE, len(wikidata)))
wikidata = dict(r_wiki)

for key in wikidata.keys():
    gold = wikidata[key]
    pooled = {name: set() for name in MODEL_NAMES}

    for template in predicate:
        text_sentence = str(key) + template + ' <mask>.'
        for name, tokens in zip(MODEL_NAMES, get_all_predictions(text_sentence)):
            # Slicing (instead of indexing [0], [1], [2]) tolerates a model
            # that returns fewer than TOP_K usable tokens.
            pooled[name].update(tokens[:TOP_K])

    WK_count.append(len(gold))
    for name in MODEL_NAMES:
        pred_counts[name].append(len(pooled[name]))
        hit_counts[name].append(sum(1 for token in pooled[name] if token in gold))

    testcount += 1
    if testcount == MAX_ENTITIES:
        break


# ---- average number of correct predictions per entity -----------------------
avg_hits = {name: _mean(hit_counts[name]) for name in MODEL_NAMES}
bertAVG, xlnetAVG, xlmAVG, bartAVG, electraAVG, robertaAVG = (
    avg_hits[name] for name in MODEL_NAMES
)
for name in MODEL_NAMES:
    print(name + "Average :", avg_hits[name])

# ---- average gold-set and prediction-set sizes ------------------------------
AVG_WKlist = _mean(WK_count)
print("AVG_WKlist :", AVG_WKlist)

avg_preds = [_mean(pred_counts[name]) for name in MODEL_NAMES]
AVG_LMlist1, AVG_LMlist2, AVG_LMlist3, AVG_LMlist4, AVG_LMlist5, AVG_LMlist6 = avg_preds
for number, value in enumerate(avg_preds, 1):
    print("AVG_LMlist{} :".format(number), value)

# ---- precision: correct predictions / predictions made ----------------------
# (The denominator is the prediction-set size; dividing by the gold-set size
# would yield recall instead.)
precisions = [_ratio(avg_hits[name], avg) for name, avg in zip(MODEL_NAMES, avg_preds)]
Precision1, Precision2, Precision3, Precision4, Precision5, Precision6 = precisions
for number, value in enumerate(precisions, 1):
    print("Precision{} :".format(number), value)

# ---- recall: correct predictions / gold values ------------------------------
recalls = [_ratio(avg_hits[name], AVG_WKlist) for name in MODEL_NAMES]
Recall1, Recall2, Recall3, Recall4, Recall5, Recall6 = recalls
for number, value in enumerate(recalls, 1):
    print("Recall{} :".format(number), value)