In [53]:
import logging
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler)
import random
import numpy as np
from finetuning.ner.data_loader import load_and_cache_examples, convert_examples_to_features, get_labels
from transformers import BertTokenizer
from util.arg import ModelConfig
from model.mlm_ner import ReformerNERModel
from finetuning.ner.utils import compute_metrics, show_report, get_test_texts
from finetuning.ner.data_loader import InputFeatures
import tqdm


logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
logger.info("device: {} n_gpu: {}".format(device, n_gpu))
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

04/10/2021 12:08:09 - INFO - __main__ -   device: cuda n_gpu: 1


<torch._C.Generator at 0x1ec97eec4f0>

In [105]:
class DecoderFromNamedEntitySequence():
    def __init__(self, tokenizer, index_to_ner):
        self.tokenizer = tokenizer
        self.index_to_ner = index_to_ner

    def __call__(self, input_text, list_of_pred_ids):
        input_token = input_text
        pred_ner_tag = list_of_pred_ids
        print(len(input_token), len(pred_ner_tag))
        # ----------------------------- parsing list_of_ner_word ----------------------------- #
        list_of_ner_word = []
        entity_word, entity_tag, prev_entity_tag = "", "", ""
        for i in range(len(input_token)):
            pred_ner_tag_str = pred_ner_tag[i]
            if "-B" in pred_ner_tag_str:
                entity_tag = pred_ner_tag_str[:3]
                if prev_entity_tag != entity_tag and prev_entity_tag != "":
                    list_of_ner_word.append({"word": entity_word.replace("▁", " "), "tag": prev_entity_tag, "prob": None})
                entity_word = input_token[i]
                prev_entity_tag = entity_tag
            elif entity_tag+"-I" in pred_ner_tag_str:
                entity_word += input_token[i]
            else:
                if entity_word != "" and entity_tag != "":
                    list_of_ner_word.append({"word":entity_word.replace("▁", " "), "tag":entity_tag, "prob":None})
                entity_word, entity_tag, prev_entity_tag = "", "", ""


        # ----------------------------- parsing decoding_ner_sentence ----------------------------- #
        decoding_ner_sentence = ""
        is_prev_entity = True
        prev_entity_tag = ""
        is_there_B_before_I = False

        for i, (token_str, pred_ner_tag_str) in enumerate(zip(input_token, pred_ner_tag)):
            if i == 0 or i == len(pred_ner_tag)-1: # remove [CLS], [SEP]
                continue
            token_str = token_str.replace('▁', ' ')  # '▁' 토큰을 띄어쓰기로 교체
            print(decoding_ner_sentence)
            if '-B' in pred_ner_tag_str:
                if is_prev_entity is True:
                    decoding_ner_sentence += ':' + prev_entity_tag+ '>'

                if token_str[0] == ' ':
                    token_str = list(token_str)
                    token_str[0] = ' <'
                    token_str = ''.join(token_str)
                    decoding_ner_sentence += token_str
                else:
                    decoding_ner_sentence += '<' + token_str
                is_prev_entity = True
                prev_entity_tag = pred_ner_tag_str[:3] # 첫번째 예측을 기준으로 하겠음
                is_there_B_before_I = True

            elif '-I' in pred_ner_tag_str:
                decoding_ner_sentence += token_str

                if is_there_B_before_I is True: # I가 나오기전에 B가 있어야하도록 체크
                    is_prev_entity = True
            else:
                if is_prev_entity is True:
                    decoding_ner_sentence += ':' + prev_entity_tag+ '>' + token_str
                    is_prev_entity = False
                    is_there_B_before_I = False
                else:
                    decoding_ner_sentence += token_str

        return list_of_ner_word, decoding_ner_sentence

In [106]:
def token2id(input_texts,
             max_seq_len,
             tokenizer,
             pad_token_label_id=-100,
             cls_token_segment_id=0,
             pad_token_segment_id=0,
             sequence_a_segment_id=0,
             mask_padding_with_zero=True):

    cls_token = tokenizer.cls_token
    sep_token = tokenizer.sep_token
    unk_token = tokenizer.unk_token
    pad_token_id = tokenizer.pad_token_id

    ids=[]
    masks = []
    tokenss = []
    for input_text in input_texts:
        # Tokenize word by word (for NER)
        tokens = []
        word_tokens = tokenizer.tokenize(input_text)
        if not word_tokens:
            word_tokens = [unk_token]  # For handling the bad-encoded word
        tokens.extend(word_tokens)

        # Account for [CLS] and [SEP]
        special_tokens_count = 2
        if len(tokens) > max_seq_len - special_tokens_count:
            tokens = tokens[: (max_seq_len - special_tokens_count)]

        # Add [SEP] token
        tokens += [sep_token]

        # Add [CLS] token
        tokens = [cls_token] + tokens

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding_length = max_seq_len - len(input_ids)
        input_ids = input_ids + ([pad_token_id] * padding_length)
        attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
        ids.append(input_ids)
        masks.append(attention_mask)
        tokenss.append(tokens)

    input_ids = torch.tensor(ids, dtype=torch.long)
    attention_mask = torch.tensor(masks, dtype=torch.long)

    return input_ids, attention_mask, tokenss

In [123]:
def main(input_texts):
    logger.info("***** INTERACTIVE *****")
    config_path = "../config/mlm/ner-pretrain-small.json"
    config = ModelConfig(config_path).get_config()
    tokenizer = BertTokenizer(vocab_file=config.vocab_path, do_lower_case=False)
    input_ids, attention_mask, tokenss = token2id(input_texts, config.max_seq_len, tokenizer)

    model = ReformerNERModel(
        num_tokens=tokenizer.vocab_size,
        dim=config.dim,
        depth=config.depth,
        heads=config.n_head,
        max_seq_len=config.max_seq_len,
        causal=False,  # auto-regressive 학습을 위한 설정
        num_labels=len(get_labels(config))
    ).to(device)

    checkpoint = torch.load(config.ner_checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    logger.info(f'Load Reformer[NER] Model')
    logger.info("Num features=%s", len(input_ids))

    pad_token_label_id = torch.nn.CrossEntropyLoss().ignore_index
    model.eval()
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(SEED)

    logger.info("Start evaluating!")
    label_lst = get_labels(config)
    with torch.no_grad():
        input_ids = torch.tensor(input_ids, dtype=torch.long).to(device)
        logits = model(input_ids)
    preds = logits.detach().cpu().numpy()

    # Slot result
    preds = np.argmax(preds, axis=2)
    slot_label_map = {i: label for i, label in enumerate(label_lst)}
    preds_list = [[] for _ in range(input_ids.shape[0])]

    for i in range(len(input_texts)):
        for j in range(len(input_texts[i])):
            if input_ids[i, j] != pad_token_label_id:
                preds_list[i].append(slot_label_map[preds[i][j]])
    labels = get_labels(config)
    index_to_ner = {k: v for k, v in enumerate(labels)}
    decoder_from_res = DecoderFromNamedEntitySequence(tokenizer=tokenizer, index_to_ner=index_to_ner)
    for input_text, preds in zip(input_texts, preds_list):
        line = ""
        for word, pred in zip(input_text, preds):
            if pred == 'O':
                line = line + word + " "
            else:
                line = line + "[{}:{}] ".format(word, pred)
        print(line)
    logger.info("***** Eval results *****")

In [124]:
input_texts = []
input_text = '지난 1 신인드래프트 일부 8순위로 코카콜라음료에 입단한 박성국은 파인크리크 풍후 아펜젤치즈로 출격했다 .'
input_texts.append(input_text)
main(input_texts)

04/10/2021 13:04:43 - INFO - __main__ -   ***** INTERACTIVE *****
04/10/2021 13:04:43 - INFO - __main__ -   Load Reformer[NER] Model
04/10/2021 13:04:43 - INFO - __main__ -   Num features=1
04/10/2021 13:04:43 - INFO - __main__ -   Start evaluating!
  input_ids = torch.tensor(input_ids, dtype=torch.long).to(device)
04/10/2021 13:04:43 - INFO - __main__ -   ***** Eval results *****


[지:EVT-B] [난:DAT-B] [ :DAT-I] [1:EVT-B] [ :CVL-B] [신:EVT-B] [인:EVT-B] 드 [래:NUM-B] [프:NUM-B] 트 [ :CVL-B] [일:CVL-B] [부:CVL-B] [ :CVL-B] [8:CVL-I] 순 위 로 [ :PER-B] [코:PER-I] [카:PER-B] [콜:PER-I] [라:PER-I] 음 료 [에:CVL-B] [ :CVL-B] [입:ORG-I] [단:ORG-I] [한:ORG-I]   박 성 국 은 [ :PER-B] 파 인 크 리 크   풍 후   아 펜 젤 치 즈 로   출 격 했 다   . 
