In [1]:
""" 모델 불러와서 MLM Test """

from plm_trainer import PLMDataset
import torch
import json
from model.bert import Bert
from transformers import BertTokenizer
import sys
import utils
from torch.utils.data import DataLoader
import os 

def predict_mask_token(bert,text,with_cuda=False):
    """Run masked-token prediction on ``text`` and print the decoded candidates.

    The text is tokenized with ``bert.tokenizer``, every token whose id equals
    ``bert.vocab[utils.MASK_TOKEN]`` is recorded as an MLM position, the
    position/mask lists are right-padded with zeros up to
    ``bert.config["max_mask_tokens"]``, and the resulting batch of one is fed
    to ``bert``. The output of ``bert.convert_mask_pred_to_token`` (called
    with ``top_k=6``) is printed.

    Args:
        bert: Model object exposing ``tokenizer``, ``config``, ``vocab`` and
            ``convert_mask_pred_to_token``; calling it on the batch dict must
            return ``(result, att_score_list)`` with ``result["mask_pred"]``.
        text: Input text containing one or more mask tokens.
        with_cuda: Currently unused; kept so existing callers keep working.

    Returns:
        None. The decoded predictions are printed to stdout.
    """
    max_seq_len = bert.config["max_seq_len"]
    max_mask_tokens = bert.config["max_mask_tokens"]

    # padding="max_length" only pads short inputs; truncation=True is needed so
    # that longer inputs are also cut down to max_seq_len.
    token_result = bert.tokenizer(
        text,
        is_split_into_words=True,
        max_length=max_seq_len,
        padding="max_length",
        truncation=True,
        return_token_type_ids=True,
    )
    input_ids = token_result["input_ids"]

    # Batch of size one: each field is wrapped in an outer list.
    data = {
        "input_ids": torch.tensor([input_ids], dtype=torch.int),
        "seg_ids": torch.tensor([token_result["token_type_ids"]], dtype=torch.int),
        "att_masks": torch.tensor([token_result["attention_mask"]], dtype=torch.int),
    }

    # Positions of every mask token, with a parallel list of 1s marking them as real.
    mask_token_id = bert.vocab[utils.MASK_TOKEN]
    mlm_positions = [
        index for index, token_id in enumerate(input_ids) if token_id == mask_token_id
    ]
    mlm_masks = [1] * len(mlm_positions)

    # Right-pad both lists with zeros up to max_mask_tokens. Inputs that contain
    # more masks than that are passed through unchanged.
    pad_num = max_mask_tokens - len(mlm_positions)
    if pad_num > 0:
        mlm_positions.extend([0] * pad_num)
        mlm_masks.extend([0] * pad_num)

    data["mlm_positions"] = torch.tensor([mlm_positions], dtype=torch.int)
    data["mlm_masks"] = torch.tensor([mlm_masks], dtype=torch.int)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        result, _att_score_list = bert(data)

    print(bert.convert_mask_pred_to_token(result["mask_pred"], data["mlm_masks"], top_k=6))




# Active run: debug corpus and the debug MLM checkpoint.
train_path = "/root/data/ojt/datasets/book_corpus_debug.txt"
model_dir = "/root/data/ojt/output/debug_model_mlm"
model_name = "checkpoint_6000.pt"

# Alternative run (disabled): swap these in to use the other corpus/checkpoint.
# train_path = "/root/data/ojt/datasets/books_corpus_p1_1.txt"
# model_dir = "/root/data/ojt/output/bert_small_book_1_mlm_sop_lr"
# model_name = "checkpoint_106000.pt"

# The model config, vocabulary file and checkpoint are all read from model_dir.
config_path, vocab_file_path, model_path = (
    os.path.join(model_dir, file_name)
    for file_name in (utils.MODEL_CONFIG_NAME, utils.VOCAB_TXT_FILE_NAME, model_name)
)

with open(config_path) as cfg_json:
    config = json.load(cfg_json)

tokenizer = BertTokenizer(vocab_file=vocab_file_path, do_lower_case=True)

# Build the model on CPU with only the MLM output enabled.
bert = Bert(
    config=config,
    tokenizer=tokenizer,
    with_cuda=False,
    return_mlm=True,
    return_sop=False,
)

device = torch.device("cpu")
print(f"Loading model {model_path}")
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
bert.load_state_dict(torch.load(model_path, map_location=device))
bert.eval()

print(config["max_seq_len"])

# Reverse lookup (token id -> token string) for decoding ids by hand.
id_to_vocab = {token_id: token for token, token_id in tokenizer.vocab.items()}

# Dataset over the training corpus, served one example per batch.
plm_dataset = PLMDataset(
    train_path,
    tokenizer,
    128,
    config["max_seq_len"],
    config["max_mask_tokens"],
    cached_dir=model_dir,
    use_cache=False,
    mlm_data=True,
)
plm_data_loader = DataLoader(plm_dataset, batch_size=1, num_workers=1)

""" 학습 데이터들에 대한 예측 결과 보기 """
# Disabled: walk the training batches and print the MLM labels next to the
# model's top-1 prediction for each batch.
# for i,data in enumerate(plm_data_loader):
#     # print(data)
#     input_ids=data["input_ids"].view(-1).tolist()
#     # for i in input_ids:
#     #     print(id_to_vocab[i],end=" ")
#     mask_labels=data["mlm_labels"].view(-1).tolist()
#     for i in mask_labels:
#         print(id_to_vocab[i],end=" ")
#     result,att_score_list = bert(data)
#     print("")
#     print(bert.convert_mask_pred_to_token(result["mask_pred"],data["mlm_masks"],top_k=1))
#     print("")


# Show the predictions for hand-written input texts.
# Unmasked reference: i wish i had a better answer to that question
masked_texts = [
    "i wish i had a better [MASK] to that question . starlings , new york is not the place youd expect much to happen .",
    "i wish i [MASK] a better answer to that question . starlings , new york is not the place youd expect much to happen .",
    "i wish [MASK] had a better answer [MASK] that question . starlings , new york is not the place youd expect much to happen .",
    "i wish i had a better answer to that question . starlings , new york is not the [MASK] youd expect much to happen .",
    "i wish i had a better answer to that question . starlings , [MASK] [MASK] york is not the place youd expect much to happen .",
]

# One prediction printout per text, in the order listed above.
for text in masked_texts:
    predict_mask_token(bert, text)


Loading model /root/data/ojt/output/debug_model_mlm/checkpoint_6000.pt
128
[[['an', 'so', 'q', 'be', '##king', '##v']]]
[[['ha', 'to', '##e', '##all', 'no', '##er']]]
[[['##us', 'i', 'p', '##pp', 'me', '##nt'], ['to', '##e', '##all', '##er', 'ha', 'a']]]
[[['##t', 'place', 'you', '##i', 'to', 'where']]]
[[['ne', 'where', 'know', '##pt', '##v', '##nt'], ['##w', 'the', '##ad', '##en', 'star', '##ch']]]


In [5]:
""" 데이터 제작 확인 """
from transformers import BertTokenizer
from plm_dataset import PLMDataset
from torch.utils.data import DataLoader

vocab_file_path = "/root/data/ojt/output/debug_model/vocab.txt"
train_path = "/root/data/ojt/datasets/book_corpus_debug.txt"

tokenizer=BertTokenizer(vocab_file=vocab_file_path,do_lower_case=True)
plm_dataset=PLMDataset(train_path,tokenizer,data_max_seq_len=64,model_max_seq_len=128,max_mask_tokens=20)
train_data_loader = DataLoader(plm_dataset, batch_size=4,num_workers=1)

vocab = tokenizer.vocab
# Reverse lookup: token id -> token string.
id_to_vocab={v:k for k,v in vocab.items()}

# Use the built-in next(): the Python 2 style ``.next()`` method is not
# available on DataLoader iterators in current PyTorch releases.
print(next(iter(train_data_loader)))

# Take one more batch and print each sequence as text, writing "[M]<label>"
# at every masked position instead of the token stored in input_ids.
data = next(iter(train_data_loader))
input_ids=data["input_ids"].tolist()
mlm_positions=data["mlm_positions"].tolist()
mlm_labels=data["mlm_labels"].tolist()

for seq, positions, labels in zip(input_ids, mlm_positions, mlm_labels):
    # Pair every masked position with its own label so the output does not
    # depend on the positions being stored in ascending order. Position 0 is
    # skipped, as in the original check (presumably the padding value for
    # unused mask slots; confirm against PLMDataset).
    label_at = {pos: label for pos, label in zip(positions, labels) if pos != 0}

    for j, token_id in enumerate(seq):
        if j in label_at:
            print("[M]"+id_to_vocab[label_at[j]],end=" ")
        elif token_id!=0:
            # Id 0 is skipped; it looks like the padding id, judging by the
            # trailing zeros in the printed batch.
            print(id_to_vocab[token_id],end=" ")

    print("")

Reading pkl file = /root/data/ojt/datasets/book_corpus_debug.pkl
doc num = 1, sentence num = 31
{'input_ids': tensor([[  2,  26, 163,  53,   4,  26, 129,  48,  18, 100,  58,  58,  79,  85,
          53,  52,  79,  86,   4,  34,  60, 124,  58, 133,   7,   3, 127,  44,
         174,   5, 158,  52,  42, 140,  64, 118,   4,  58,  78, 188,  99,  48,
          22,  65,   4,  47,   4,  58,  30,   4, 143,  86, 129, 122, 111,   4,
           3,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0],
        [  2, 106,   4, 100, 171, 169,  47, 116, 115,  23,   4,  44,  58, 131,
          36,  49, 174,  47,   5, 131,  20, 141,  72,  55,   7,  78,   4, 153,
 