In [1]:
import os, json
import torch
import argparse
import pandas as pd
from tqdm import tqdm
from mask_dataset import MaskedDataset
from modules.xlmr_base_model import XLMBaseModel
from modules.bert_base_model import BERTBaseModel
from mask_prediction import load_objects, predict_mask_tokens, batchify
os.getcwd()

'/home/xzhao/workspace/probing-mulitlingual/src'

In [2]:
device = torch.device("cuda:0")
model = XLMBaseModel(device)
# model = BERTBaseModel(device)

In [6]:
# load dataset
mlama = MaskedDataset("mlama", model.mask_token, model.name)
# mlama = MaskedDataset("mlama", '[MASK]', 1)
data = mlama.get_lang_type('en', 'P131')

lang2objs = {}
for lang in mlama.langs:
    objs = load_objects(lang, model.name, model)
    lang2objs.update({lang:objs})

In [7]:
# Test mbert get_mask_tokens function
sent = data['sent']
sent2 = mlama.replace_with_mask(sent, 2, model.mask_token)
predicts = model.get_mask_tokens(sent2[:32], 2)
predicts

[[['▁the'], ['▁area']],
 [['▁the'], ['▁Caribbean']],
 [['▁the'], ['▁Carolina']],
 [['▁the'], ['▁County']],
 [['▁the'], ['▁County']],
 [['▁North'], ['▁Carolina']],
 [['▁the'], ['▁distance']],
 [['▁the'], ['▁India']],
 [['▁the'], ['▁County']],
 [['▁the'], ['▁Netherlands']],
 [['▁the'], ['▁Park']],
 [['▁the'], ['▁of']],
 [['▁the'], ['▁Angeles']],
 [['▁San'], ['▁County']],
 [['▁New'], ['▁Carolina']],
 [['▁North'], ['▁County']],
 [['▁the'], ['▁County']],
 [['▁the'], ['▁County']],
 [['▁the'], ['▁China']],
 [['▁the'], ['▁distance']],
 [['▁the'], ['▁Park']],
 [['▁Uttar'], ['▁Pradesh']],
 [['▁the'], ['▁Philippines']],
 [['▁the'], ['▁County']],
 [['▁the'], ['▁County']],
 [['▁the'], ['▁area']],
 [['▁the'], ['▁City']],
 [['▁the'], ['▁County']],
 [['▁West'], ['▁County']],
 [['▁Grand'], ['▁Canyon']],
 [['▁the'], ['▁City']],
 [['▁the'], ['▁County']]]

In [16]:
def tokens2id(pred, tokenizer):
    token_ids = []
    for tokens in pred:
        assert(all([len(token) == 1 and type(token[0]) == str for token in tokens]) )
        tokens = [token[0] for token in tokens]
        token_ids.append(tokenizer.convert_tokens_to_ids(tokens))
    return token_ids
tokens2id(predicts, model.tokenizer)

[[70, 16128],
 [70, 223487],
 [70, 96220],
 [70, 47064],
 [70, 47064],
 [23924, 96220],
 [70, 62488],
 [70, 5596],
 [70, 47064],
 [70, 231118],
 [70, 5227],
 [70, 111],
 [70, 31754],
 [1735, 47064],
 [2356, 96220],
 [23924, 47064],
 [70, 47064],
 [70, 47064],
 [70, 9098],
 [70, 62488],
 [70, 5227],
 [156910, 21979],
 [70, 129535],
 [70, 47064],
 [70, 47064],
 [70, 16128],
 [70, 6406],
 [70, 47064],
 [10542, 47064],
 [12801, 193266],
 [70, 6406],
 [70, 47064]]

In [14]:
# Test predict_mask_tokens function for mbert
lang = "en"
rel = "P103"
predict_mask_tokens(model, mlama, objs, lang, rel, "../result/prediction-mbert")

Predict mask tokens for en-P103
Start to predict masked tokens for lang-en, relation P103


In [None]:
# Test evaluation.py:predict_mask_tokens() - 1
lang2objs = {}
for lang in mlama.langs:
    objs = load_objects(lang, xlmr)
    lang2objs.update({lang:objs})
    
# Test evaluation.py:predict_mask_tokens() - 2
for lang in mlama.langs:
    objs = load_objects(lang, xlmr)
    for rel in objs.keys():
        # predict_mask_tokens(xlmr, mlama, lang2objs[lang], lang, rel)
        predict_mask_tokens(xlmr, mlama, objs, lang, rel)

In [None]:
# Test generating batches of sentences
frame = pd.DataFrame(columns=['id', 'sent', 'prediction'])
for rel in objs.keys():
    print('Start to analyze {}'.format(rel))
    maxlen = max(list(objs[rel].keys()))
    for i in range(2):
        relations = mlama.get_lang_type('en', rel)
        org_sents = relations['sent']
        ids = relations.index
        sents = mlama.replace_with_mask(org_sents, i+2, model.mask_token)
        batches = batchify(list(zip(ids, sents)), 32)
        for batch in batches:
            ids = list(zip(*batch))[0]
            sents = list(zip(*batch))[1]
            results = xlmr.get_mask_tokens(sents, i+2)
            samples = list(zip(ids, sents, results))
            item = {
                'id': ids,
                'sent': sents,
                'prediction': results}
            frame = pd.concat([frame, pd.DataFrame(item)])
    break

In [11]:
frame

Unnamed: 0,id,sent,prediction
0,16521,A605 road is located in <mask> <mask> .,"[[▁the], [▁area]]"
1,16522,Kupreanof Island is located in <mask> <mask> .,"[[▁the], [▁Caribbean]]"
2,16523,Pershing County is located in <mask> <mask> .,"[[▁the], [▁Carolina]]"
3,16524,Porcupine Hills is located in <mask> <mask> .,"[[▁the], [▁County]]"
4,16525,Minnesota State Highway 36 is located in <mask...,"[[▁the], [▁County]]"
...,...,...,...
12,17397,John Paul II Catholic University of Lublin is ...,"[[▁Lublin], [,], [▁Poland]]"
13,17398,Sugarloaf Key is located in <mask> <mask> <mas...,"[[▁the], [,], [▁Island]]"
14,17399,Cheyenne Frontier Days is located in <mask> <m...,"[[▁the], [,], [,]]"
15,17400,Heaton Park is located in <mask> <mask> <mask> .,"[[▁the], [ton], [▁of]]"


In [25]:
# Test batchify - 1
batches = batchify(sentences, 32)

In [39]:
# Test load_objects
objs = load_objects('zh', None)

In [None]:
# Test batchify function - 2
def batchify(sents, batch_size):
    l = len(sents)
    for ndx in range(0, l, batch_size):
        yield sents[ndx:min(ndx + batch_size, l)]

for batch in batches:
    print(batch)
    results = xlmr.get_mask_tokens(batch, 1)
    print(results)