In [None]:
import pandas as pd
from pyknp import Juman
import utils
from model import *
import torch
import os
from pytorch_pretrained_bert import BertTokenizer, BertConfig, WEIGHTS_NAME, CONFIG_NAME
from torch.utils.data import DataLoader, TensorDataset
# Prefer the GPU when CUDA is usable; otherwise run everything on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

In [None]:
# Directory of the fine-tuned checkpoint; later cells read the tokenizer
# files, 'config.json' and 'pytorch_model.bin' from it.
MODEL_URL = "checkpoints/ALL_20200830171210/cv0/"
# Input workbook whose rows are tagged by this notebook.
excel_file = "行動データ3言語54subs (002).xlsx"

In [None]:
class MorphologicalAnalyzer(object):
    """Uniform word-segmentation front end over Juman or MeCab.

    Args:
        analyzer_name: 'juman' (default) or 'mecab'. The backend library is
            imported lazily so only the selected one has to be installed.

    Raises:
        ValueError: if ``analyzer_name`` is not a supported backend.
    """

    SUPPORTED_ANALYZERS = ('juman', 'mecab')

    def __init__(self, analyzer_name='juman'):
        # Fail fast instead of leaving `self.analyzer` unset and returning
        # None from analyze() later on.
        if analyzer_name not in self.SUPPORTED_ANALYZERS:
            raise ValueError(
                "Unsupported analyzer '{}'; expected one of {}".format(
                    analyzer_name, self.SUPPORTED_ANALYZERS
                )
            )
        self.analyzer_name = analyzer_name
        if self.analyzer_name == 'juman':
            from pyknp import Juman
            self.analyzer = Juman()
        else:
            import MeCab
            # '-Owakati' makes parse() return space-separated surface forms.
            # The default output format also carries feature columns and an
            # 'EOS' line, which a plain split() would return as tokens.
            self.analyzer = MeCab.Tagger('-Owakati')

    def analyze(self, text):
        """Segment ``text`` and return the list of surface-form tokens."""
        if self.analyzer_name == 'juman':
            return [w.midasi for w in self.analyzer.analysis(text).mrph_list()]
        elif self.analyzer_name == 'mecab':
            return self.analyzer.parse(text).split()
        raise ValueError(
            "Unsupported analyzer '{}'".format(self.analyzer_name)
        )

def analyze_morph_and_mask(sent, analyzer, tokenizer, event_item_id=2):
    """Tokenize a multi-segment sentence and mark the event segment.

    Each element of ``sent`` is segmented with ``analyzer.analyze``, joined
    with spaces and passed through ``tokenizer.tokenize``. The resulting
    sub-tokens are concatenated and wrapped in '[CLS]' / '[SEP]'.

    Args:
        sent: sequence of text segments.
        analyzer: object exposing ``analyze(text) -> list of str``.
        tokenizer: object exposing ``tokenize(text) -> list of str``.
        event_item_id: index of the segment in ``sent`` whose tokens get
            mask value 1. Defaults to 2 (the third segment), which is the
            previously hard-coded behaviour.

    Returns:
        (tokenized_sent, event_mask): two lists of equal length. The mask is
        1 for tokens of the event segment and 0 elsewhere, including the
        '[CLS]' and '[SEP]' positions.
    """
    tokenized_sent = []
    event_mask = []
    for item_id, item in enumerate(sent):
        analyzed_item = ' '.join(analyzer.analyze(item))
        tokenized_item = tokenizer.tokenize(analyzed_item)
        tokenized_sent.extend(tokenized_item)
        mask_value = 1 if item_id == event_item_id else 0
        event_mask.extend([mask_value] * len(tokenized_item))

    assert len(tokenized_sent) == len(event_mask)
    # The special tokens never belong to the event span.
    tokenized_sent = ['[CLS]'] + tokenized_sent + ['[SEP]']
    event_mask = [0] + event_mask + [0]
    return tokenized_sent, event_mask

In [None]:
# Excel workbooks carry their own text encoding, and current pandas rejects
# `encoding=` on read_excel as an unexpected keyword, so it is not passed.
in_data = pd.read_excel(excel_file)

# Skip the first row and the last six rows, and keep columns 6, 7 and 8:
# the three text segments that make up one sentence.
sentences = in_data.values[1:-6, [6, 7, 8]].tolist()

ma = MorphologicalAnalyzer()
tokenizer = BertTokenizer.from_pretrained(
    MODEL_URL,
    do_lower_case=False,
    do_basic_tokenize=False
)

# Preview the first sentence; as the last expression it is actually shown.
sentences[0]

In [None]:
# Analyze every sentence once and split the (tokens, mask) pairs afterwards.
# Calling analyze_morph_and_mask separately for tokens and masks repeated the
# whole morphological analysis for each sentence.
analyzed = [analyze_morph_and_mask(sent, ma, tokenizer) for sent in sentences]
sent_toks = [toks for toks, _ in analyzed]
event_masks = [mask for _, mask in analyzed]

# Pad everything to the longest sentence in the batch of inputs.
max_len = max([len(sent) for sent in sent_toks])
padded_sent_toks = utils.padding_2d(sent_toks, max_len, pad_tok='[PAD]')
padded_event_masks = utils.padding_2d(event_masks, max_len, pad_tok=0)
# 1 for real tokens, 0 for padding positions.
attn_masks = [[1] * len(sent) + [0] * (max_len - len(sent)) for sent in sent_toks]
padded_sent_tids = [tokenizer.convert_tokens_to_ids(sent) for sent in padded_sent_toks]

test_tensors = TensorDataset(
    torch.tensor(padded_sent_tids),
    torch.tensor(padded_event_masks),
    torch.tensor(attn_masks)
)

# shuffle=False keeps predictions aligned with the row order of `sentences`.
test_dataset = DataLoader(test_tensors, batch_size=16, shuffle=False)

In [None]:
# Label inventory of the classifier. It is defined here because this cell
# needs it for `num_labels`; relying on a later cell to define it breaks
# "Restart Kernel -> Run All".
# NOTE(review): must stay identical to the mapping used when decoding
# predictions further down.
tag2id = {'AFTER': 4, 'OVERLAP-OR-AFTER': 3, 'OVERLAP': 2, 'BEFORE-OR-OVERLAP': 1, 'BEFORE': 0, 'VAGUE': 5}

config = BertConfig.from_json_file(os.path.join(MODEL_URL, 'config.json'))
model = DocEmbMultiTaskTRC(
    config,
    num_emb=2,
    task_list=['DCT', 'T2E', 'E2E', 'MAT'],
    num_labels=len(tag2id)
)
# map_location remaps the stored tensors onto the selected device, so a
# checkpoint saved on a GPU can still be loaded on a CPU-only machine.
state_dict = torch.load(
    os.path.join(MODEL_URL, 'pytorch_model.bin'),
    map_location=device
)
model.load_state_dict(state_dict)
model.to(device)

In [None]:
# Label <-> id mapping used to decode the predicted class indices.
tag2id = {'AFTER': 4, 'OVERLAP-OR-AFTER': 3, 'OVERLAP': 2, 'BEFORE-OR-OVERLAP': 1, 'BEFORE': 0, 'VAGUE': 5}
id2tag = {v: k for k, v in tag2id.items()}

output_ids = []
model.eval()
for b_tok_ids, b_event_mask, b_attn_mask in test_dataset:
    b_tok_ids = b_tok_ids.to(device)
    b_event_mask = b_event_mask.to(device)
    b_attn_mask = b_attn_mask.to(device)
    with torch.no_grad():
        # NOTE(review): the event mask is passed for both mask arguments;
        # presumably the 'DCT' task only needs the event span. Confirm
        # against DocEmbMultiTaskTRC.forward.
        b_pred_logits = model(
            b_tok_ids,
            b_event_mask,
            b_event_mask,
            'DCT',
            attention_mask=b_attn_mask,
            labels=None
        )
        b_pred_ids = torch.argmax(b_pred_logits, dim=-1).squeeze(1)
        output_ids.extend(b_pred_ids.cpu().detach().tolist())
outputs = [id2tag[tag_id] for tag_id in output_ids]

# DataFrame.insert raises ValueError when the column already exists, so drop
# a column left over from a previous run to keep this cell re-runnable.
if 'DCT Tag' in in_data.columns:
    del in_data['DCT Tag']
# One header-like entry for the first row, one tag per sentence, and six
# blanks for the trailing rows that were excluded when building `sentences`.
in_data.insert(9, 'DCT Tag', ['DCT Tag'] + outputs + ([''] * 6))
in_data.to_excel('行動データ3言語54subs (002)_tagged.xlsx', index = None, header=True)

In [None]:
# Show every tokenized sentence next to its predicted tag.
for tok_and_tag in zip(sent_toks, outputs):
    print(*tok_and_tag)