In [17]:
import os
import torch
import gdown
import random
import logging
import numpy as np
from collections import Counter
from seqeval.metrics import precision_score, recall_score, f1_score, classification_report

In [3]:
# Module-level logger, named after this module; handlers/format come from init_logger().
logger = logging.getLogger(name=__name__)

In [None]:
def get_test_texts(args):
    """Read the test split and return every sentence as a list of tokens.

    Each line of ``os.path.join(args.data_dir, args.test_file)`` must have the form
    ``<space-separated sentence>\\t<tags>``; the tag column is discarded.

    Returns:
        list[list[str]]: one whitespace-tokenized sentence per input line.

    Raises:
        ValueError: if a line does not split into exactly two tab-separated fields.
    """
    test_path = os.path.join(args.data_dir, args.test_file)
    with open(test_path, 'r', encoding='utf-8') as test_file:
        # Unpacking in the comprehension target enforces "exactly one tab per line".
        return [
            sentence.split()
            for sentence, _tags in (raw_line.split('\t') for raw_line in test_file)
        ]

build_vocab

In [7]:
def _write_vocab(path, tokens):
    """Write ``tokens`` to ``path`` as UTF-8, one token per line."""
    with open(path, 'w', encoding='utf-8') as f:
        for token in tokens:
            f.write(token + "\n")


def build_vocab(args):
    """Build word- and character-level vocabularies from the training file and save them.

    Reads ``os.path.join(args.data_dir, args.data_name)``. Every non-blank line must be
    ``<space-separated sentence>\\t<ner tags>``; the tag column is ignored and blank lines
    are skipped. Words and characters are ranked by frequency (ties keep first-seen
    order). They are written after the reserved tokens "PAD" (id 0) and "UNK" (id 1) to
    ``<args.vocab_dir>/word_vocab`` and ``<args.vocab_dir>/char_vocab``, one token per line.

    Side effects:
        Creates ``args.vocab_dir`` if needed, writes both vocab files, and clips
        ``args.word_vocab_size`` / ``args.char_vocab_size`` to the real vocab sizes.

    Raises:
        ValueError: if a non-blank line does not split into exactly two tab-separated fields.
    """
    special_tokens = ["PAD", "UNK"]

    # Count while streaming the file instead of materialising every word/char in lists.
    word_counts, char_counts = Counter(), Counter()
    with open(os.path.join(args.data_dir, args.data_name), 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines, e.g. an extra newline at EOF
            sentence, _ner_tag = line.split('\t')
            words = sentence.split()
            word_counts.update(words)
            for word in words:
                char_counts.update(word)

    os.makedirs(args.vocab_dir, exist_ok=True)
    word_vocab_path = os.path.join(args.vocab_dir, "word_vocab")
    char_vocab_path = os.path.join(args.vocab_dir, "char_vocab")

    # Reserved tokens keep ids 0/1. Drop them from the frequency ranking so a literal
    # "PAD"/"UNK" in the corpus cannot create a duplicate (and conflicting) vocab entry.
    word_vocab = special_tokens + [w for w, _ in word_counts.most_common() if w not in special_tokens]
    logger.info("Total word vocabulary size: {}".format(len(word_vocab)))
    _write_vocab(word_vocab_path, word_vocab)

    char_vocab = special_tokens + [c for c, _ in char_counts.most_common() if c not in special_tokens]
    logger.info("Total char vocabulary size: {}".format(len(char_vocab)))
    _write_vocab(char_vocab_path, char_vocab)

    # Set the exact vocab size: if the real vocab is smaller than requested, use the real
    # size. Each file holds exactly one line per token, so the list lengths equal the
    # line counts and the files need not be re-read.
    args.word_vocab_size = min(len(word_vocab), args.word_vocab_size)
    args.char_vocab_size = min(len(char_vocab), args.char_vocab_size)

    logger.info("args.word_vocab_size: {}".format(args.word_vocab_size))
    logger.info("args.char_vocab_size: {}".format(args.char_vocab_size))


01/21/2022 16:42:36 - INFO - __main__ -   Total word vocabulary size: 307250
01/21/2022 16:42:36 - INFO - __main__ -   Total char vocabulary size: 2163
01/21/2022 16:42:36 - INFO - __main__ -   args.word_vocab_size: 307250
01/21/2022 16:42:36 - INFO - __main__ -   args.char_vocab_size: 2163


load_vocab

In [8]:
def _read_vocab_file(path, max_size):
    """Read a one-token-per-line vocab file, keeping at most ``max_size`` entries.

    Returns:
        (token_to_id, ids_to_tokens, size), where ``size = min(#lines, max_size)``.
        If a token appears more than once, ``token_to_id`` keeps its first (lowest) id,
        so e.g. "PAD" stays at id 0 even in files that contain a duplicate "PAD".
    """
    with open(path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    size = min(len(lines), max_size)
    ids_to_tokens = [line.strip() for line in lines[:size]]
    token_to_id = {}
    for idx, token in enumerate(ids_to_tokens):
        token_to_id.setdefault(token, idx)
    return token_to_id, ids_to_tokens, size


def load_vocab(args):
    """Load the word and char vocabularies written to ``args.vocab_dir``.

    Side effect: clips ``args.word_vocab_size`` / ``args.char_vocab_size`` to the number
    of lines in the corresponding vocab file.

    Returns:
        (word_vocab, word_ids_to_tokens, char_vocab, char_ids_to_tokens). Each vocab maps
        token -> id and each ids_to_tokens list maps id -> token. Returns None, after
        logging a warning, if either vocab file is missing.
    """
    word_vocab_path = os.path.join(args.vocab_dir, "word_vocab")
    char_vocab_path = os.path.join(args.vocab_dir, "char_vocab")

    if not os.path.exists(word_vocab_path):
        logger.warning("Please build word vocab first!!!")
        return None

    if not os.path.exists(char_vocab_path):
        logger.warning("Please build char vocab first!!!")
        return None

    word_vocab, word_ids_to_tokens, args.word_vocab_size = _read_vocab_file(
        word_vocab_path, args.word_vocab_size)
    char_vocab, char_ids_to_tokens, args.char_vocab_size = _read_vocab_file(
        char_vocab_path, args.char_vocab_size)

    return word_vocab, word_ids_to_tokens, char_vocab, char_ids_to_tokens

Download pretrained word vectors (w2v)

In [9]:
# Google Drive location of the pretrained word vectors used by this project.
_W2V_DOWNLOAD_URL = "https://drive.google.com/uc?id=1YX7yHm5MHZ-Icdm1ZX4X9_wD7UrXexJ-"


def download_w2v(args, url=_W2V_DOWNLOAD_URL):
    """Download pretrained word vectors to ``args.wordvec_dir/args.w2v_file`` if absent.

    Args:
        args: needs ``wordvec_dir`` and ``w2v_file``.
        url: download source; defaults to the project's Google Drive file.

    Returns:
        str: the local path of the word-vector file (whether or not it was downloaded).
    """
    w2v_path = os.path.join(args.wordvec_dir, args.w2v_file)
    if not os.path.exists(w2v_path):
        # gdown writes directly into the target directory, so make sure it exists.
        if args.wordvec_dir:
            os.makedirs(args.wordvec_dir, exist_ok=True)
        logger.info("Downloading pretrained word vectors...")
        gdown.download(url, w2v_path, quiet=False)
    return w2v_path

Label utilities

In [None]:
def get_labels(args):
    """Return the label names from ``args.data_dir/args.label_file``, one per line.

    Each line is stripped of surrounding whitespace and file order is preserved.
    """
    label_path = os.path.join(args.data_dir, args.label_file)
    # Context manager so the handle is closed deterministically (the original leaked it).
    with open(label_path, 'r', encoding='utf-8') as f:
        return [label.strip() for label in f]

In [None]:
def load_label_vocab(args):
    """Map each label returned by ``get_labels(args)`` to its position in that list.

    If a label is listed more than once, its last position wins.
    """
    return {label: position for position, label in enumerate(get_labels(args))}

In [None]:
def init_logger():
    """Configure the root logger at INFO level with a timestamped line format.

    Delegates to ``logging.basicConfig``, so it has no effect if the root logger
    already has handlers.
    """
    log_format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s'
    date_format = '%m/%d/%Y %H:%M:%S'
    logging.basicConfig(level=logging.INFO, format=log_format, datefmt=date_format)

In [None]:
def set_seed(args):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with ``args.seed``.

    All CUDA devices are also seeded when ``args.no_cuda`` is false and CUDA is available.
    """
    seed = args.seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    cuda_requested = not args.no_cuda
    if cuda_requested and torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

In [None]:
def compute_metrics(labels, preds):
    """Score predicted tag sequences against the gold sequences.

    Args:
        labels: gold label sequences.
        preds: predicted label sequences; must contain as many sequences as ``labels``.

    Returns:
        The result of ``f1_pre_rec(labels, preds)``.
    """
    assert len(labels) == len(preds)
    scores = f1_pre_rec(labels, preds)
    return scores

In [None]:
def f1_pre_rec(labels, preds):
    """Return seqeval entity-level precision, recall and F1 as a dict.

    ``suffix=True`` tells seqeval that the chunk marker (B/I/...) comes at the end of
    each tag rather than at the start.
    """
    scorers = (
        ("precision", precision_score),
        ("recall", recall_score),
        ("f1", f1_score),
    )
    return {metric: scorer(labels, preds, suffix=True) for metric, scorer in scorers}

In [None]:
def show_report(labels, preds):
    """Return seqeval's per-entity classification report (suffix-style tags)."""
    report = classification_report(labels, preds, suffix=True)
    return report