## mGENRE entity prediction on NER-annotated questions, with evaluation

### Downloading model weights and data

In [None]:
#download weights of mgenre

!wget https://dl.fbaipublicfiles.com/GENRE/fairseq_multilingual_entity_disambiguation.tar.gz
! tar -xvf workspace/kbqa/kbqa/fairseq_multilingual_entity_disambiguation.tar.gz

In [2]:
import pickle
from genre.trie import Trie, MarisaTrie
import torch
from genre.fairseq_model import mGENRE
import pandas as pd
from tqdm import tqdm

In [None]:
# Download once (saved where the `open` below expects it):
#   !wget -P workspace/kbqa/kbqa https://dl.fbaipublicfiles.com/GENRE/lang_title2wikidataID-normalized_with_redirect.pkl

# mapping between (lang, title) pairs and Wikidata IDs; `pred` uses it to turn a
# generated "title >> lang" string into a Wikidata ID (values are presumably
# collections of "Q<number>" IDs — `pred` picks the one with the largest number)
# NOTE: pickle.load can execute arbitrary code — only load the trusted GENRE release file.
with open("workspace/kbqa/kbqa/lang_title2wikidataID-normalized_with_redirect.pkl", "rb") as f:
    lang_title2wikidataID = pickle.load(f)

In [None]:
# Download once (saved where the `open` below expects it):
#   !wget -P workspace/kbqa/kbqa https://dl.fbaipublicfiles.com/GENRE/wikidataID2lang_title-normalized_with_redirect.pkl

# mapping between Wikidata IDs and (lang, title) in all languages
# NOTE(review): not referenced by any later cell in this notebook — kept for inspection.
# NOTE: pickle.load can execute arbitrary code — only load the trusted GENRE release file.
with open("workspace/kbqa/kbqa/wikidataID2lang_title-normalized_with_redirect.pkl", "rb") as f:
    wikidataID2lang_title = pickle.load(f)

In [None]:
# Download once (saved where the `open` below expects it):
#   !wget -P workspace/kbqa/kbqa https://dl.fbaipublicfiles.com/GENRE/titles_lang_all105_marisa_trie_with_redirect.pkl

# memory efficient but slower prefix tree (trie) -- it is implemented with `marisa_trie`;
# `pred` queries it to constrain decoding to valid entity titles
# NOTE: pickle.load can execute arbitrary code — only load the trusted GENRE release file.
with open("workspace/kbqa/kbqa/titles_lang_all105_marisa_trie_with_redirect.pkl", "rb") as f:
    trie = pickle.load(f)

In [None]:
# Download once (saved where the `open` below expects it):
#   !wget -P workspace/kbqa/kbqa https://dl.fbaipublicfiles.com/GENRE/mention2wikidataID_with_titles_label_alias_redirect.pkl

# mapping between mentions and Wikidata IDs and number of times they appear on Wikipedia
# NOTE(review): not referenced by any later cell in this notebook — kept for inspection.
# NOTE: pickle.load can execute arbitrary code — only load the trusted GENRE release file.
with open("workspace/kbqa/kbqa/mention2wikidataID_with_titles_label_alias_redirect.pkl", "rb") as f:
    mention2wikidataID = pickle.load(f)

### Predicting with mGENRE

In [3]:
# Questions with NER-processed mentions plus gold answers; later cells read the
# 'question_ner_spacy_pretrained_largecase' and 'subject' columns.
data_ner = pd.read_csv(filepath_or_buffer="ner_experiments_mgenre.csv")

In [None]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# NOTE(review): absolute path at the filesystem root — presumably where the
# checkpoint archive was extracted; confirm it matches the download cell.
model_mGENRE = mGENRE.from_pretrained("/fairseq_multilingual_entity_disambiguation")
model_mGENRE = model_mGENRE.eval()  # switch to evaluation mode for inference
model_mGENRE.to(device)
print("mGENRE loaded")

In [None]:
#prediction procedure is taken from here https://github.com/facebookresearch/GENRE/tree/main/examples_mgenre

def pred(sentences, model):
    """Run constrained mGENRE decoding on a batch of sentences.

    Follows the GENRE `examples_mgenre` recipe: decoding is restricted to token
    sequences allowed by the global `trie`, and each generated "title >> lang"
    string is mapped to a Wikidata ID via the global `lang_title2wikidataID`.

    Args:
        sentences: list of input strings.
        model: a loaded mGENRE model exposing `.sample(...)`.

    Returns:
        Whatever `model.sample` returns — indexed per input sentence, each entry a
        list of candidate dicts carrying an 'id' key (as used by the prediction loop).
    """

    def allowed_tokens(batch_id, sent):
        # Only keep trie continuations that exist in the model's target vocabulary.
        return [
            token
            for token in trie.get(sent.tolist())
            if token < len(model.task.target_dictionary)
        ]

    def title_to_wikidata_id(text):
        # "title >> lang" -> (lang, title) key; among several IDs keep the one
        # with the largest numeric part (the "Q" prefix is stripped).
        lang_title = tuple(reversed(text.split(" >> ")))
        return max(lang_title2wikidataID[lang_title], key=lambda qid: int(qid[1:]))

    return model.sample(
        sentences,
        prefix_allowed_tokens_fn=allowed_tokens,
        text_to_id=title_to_wikidata_id,
        marginalize=True,
    )

In [None]:
# Prediction for the whole dataset: run mGENRE on every NER-processed question and
# store the ranked candidate Wikidata IDs as one ", "-joined string per row.
predicted_ids = []
for question in tqdm(data_ner['question_ner_spacy_pretrained_largecase'].tolist()):
    if not isinstance(question, str):
        # Missing question (NaN after CSV load) — nothing to disambiguate;
        # an empty string counts as "no prediction" in topk_accuracy.
        predicted_ids.append('')
        continue
    prediction = pred([question], model_mGENRE)
    # prediction[0] holds the candidates for the single input sentence, best first
    # (order as returned by model.sample).
    predicted_ids.append(', '.join(candidate['id'] for candidate in prediction[0]))

# Assign the whole column at once instead of growing it row-by-row with .loc.
data_ner['pred_ner_spacy_pretrained_largecase'] = predicted_ids

### Evaluate accuracy

In [None]:
def topk_accuracy(df, col, k=5, subject_col='subject'):
    """Print (and return) top-1 ... top-k accuracy of ranked entity predictions.

    A row is a hit at cut-off i when its gold ID appears among the first i
    predicted IDs.

    Args:
        df: DataFrame holding predictions and gold labels. Rows are read
            positionally, so any index (not only 0..n-1) works.
        col: column whose values are predicted IDs joined by ", ", best first.
            Non-string values (e.g. NaN) count as "no prediction".
        k: largest cut-off to report (default 5, as before).
        subject_col: column with the gold ID (default 'subject').

    Returns:
        dict mapping each cut-off i (1..k) to its accuracy. Accuracies are NaN
        when df is empty (instead of raising ZeroDivisionError).
    """
    n_rows = len(df)
    # first_hit_counts[r] = number of rows whose gold ID first appears at rank r (0-based)
    first_hit_counts = [0] * k
    for predictions, subject in zip(df[col], df[subject_col]):
        if not isinstance(predictions, str):
            continue
        ranked = predictions.split(', ')[:k]
        if subject in ranked:
            first_hit_counts[ranked.index(subject)] += 1

    accuracies = {}
    cumulative_hits = 0
    for rank in range(1, k + 1):
        # A hit at rank r is also a hit at every larger cut-off.
        cumulative_hits += first_hit_counts[rank - 1]
        accuracies[rank] = cumulative_hits / n_rows if n_rows else float('nan')
        print(f'Top-{rank} accuracy:', accuracies[rank])
    return accuracies


In [None]:
# Evaluate the predictions stored on data_ner (`df` was never defined in this notebook).
topk_accuracy(data_ner, 'pred_ner_spacy_pretrained_largecase')