In [1]:
import numpy as np
import torch
from torch.utils import data
from transformers import AutoTokenizer
from itertools import chain

In [16]:
# Module-level tokenizer shared by the notebook; the dataset class defined below
# looks this name up when it splits words into sub-word tokens and converts them to ids.
tokenizer = AutoTokenizer.from_pretrained('roberta-base')

In [3]:
def save_vocab(fpath, target_list):
    """Write the tag vocabulary to ``fpath`` as a single space-separated line.

    The vocabulary always starts with the four special entries
    ``[PAD] [UNK] [CLS] [SEP]`` (indices 0-3), followed by every distinct tag
    found in ``target_list``.

    Args:
        fpath: Destination file path; an existing file is overwritten.
        target_list: Iterable of tag sequences (one sequence per sentence).
    """
    # Sort the distinct tags: iteration order of a set of strings changes from
    # one interpreter run to the next, which would silently re-number the tags
    # every time the vocabulary is regenerated and break older checkpoints.
    tag_set = sorted(set(chain.from_iterable(target_list)))
    VOCAB = ['[PAD]', '[UNK]', '[CLS]', '[SEP]']
    VOCAB.extend(tag_set)
    with open(fpath, 'w') as save:
        save.write(" ".join(VOCAB))

In [4]:
def load_vocab(fpath):
    """Read a vocabulary file and return its entries as a list of strings.

    Only the first line of the file is consulted; it is stripped and split on
    whitespace, so list positions correspond to vocabulary indices.
    """
    with open(fpath, 'r') as handle:
        first_line = handle.readline()
    entries = first_line.strip().split()
    return entries

In [5]:
class DataLoader(data.Dataset):
    """Sentence-level tagging dataset read from a tab-separated text file.

    Each non-blank line of the file is ``word<TAB>tag``; sentences are
    separated by an empty line. Indexing the dataset yields the sub-word ids,
    the aligned tag ids and a head mask for one sentence.
    """

    def __init__(self, fpath, args, tok=None):
        """Load the sentences in ``fpath`` and the tag vocabulary.

        Args:
            fpath: Path of the data file (read as UTF-8).
            args: Object exposing ``trainset`` and ``vocab_path``. When
                ``fpath`` equals ``args.trainset`` the vocabulary file is
                (re)written from this file's tags before being loaded.
            tok: Optional tokenizer exposing ``tokenize``,
                ``convert_tokens_to_ids``, ``cls_token``, ``sep_token`` and
                ``unk_token``. Defaults to the module-level ``tokenizer``.
        """
        # Context manager so the file handle is closed deterministically.
        with open(fpath, 'r', encoding='utf-8') as fin:
            entries = fin.read().strip().split('\n\n')

        source_list, target_list = [], []
        for lines in entries:
            # Split every line once; ignore stray blank lines inside an entry.
            rows = [line.split('\t') for line in lines.splitlines() if line.strip()]
            if not rows:
                # An empty file (or an all-blank entry) is not a sentence.
                continue
            source_list.append([row[0] for row in rows])
            target_list.append([row[1] for row in rows])

        self.source_list = source_list
        self.target_list = target_list
        self.tokenizer = tok if tok is not None else tokenizer

        if fpath == args.trainset:
            save_vocab(args.vocab_path, target_list)

        self.vocab = load_vocab(args.vocab_path)

        self.tag2idx = {tag: idx for idx, tag in enumerate(self.vocab)}
        self.idx2tag = {idx: tag for idx, tag in enumerate(self.vocab)}

    def __len__(self):
        """Return the number of sentences."""
        return len(self.source_list)

    def __getitem__(self, idx):
        """Encode sentence ``idx``.

        Returns:
            ``(x, x_seqlen, source, y, y_seqlen, target, is_heads)`` where
            ``x`` are sub-word ids, ``y`` the tag ids aligned with ``x``
            (continuation pieces carry the ``[PAD]`` tag), ``is_heads`` marks
            the first piece of every word with 1, and ``source``/``target``
            are the word and tag lists wrapped in ``'[CLS]'``/``'[SEP]'``.
        """
        tok = self.tokenizer
        words, tags = self.source_list[idx], self.target_list[idx]

        # The returned word/tag lists keep the '[CLS]'/'[SEP]' markers that
        # downstream code strips with [1:-1] ...
        source = ['[CLS]'] + words + ['[SEP]']
        target = ['[CLS]'] + tags + ['[SEP]']

        # ... but the ids fed to the encoder must be the tokenizer's own
        # boundary tokens: the literal '[CLS]'/'[SEP]' strings are only valid
        # token names for BERT-style vocabularies and would otherwise be
        # converted to the unknown-token id.
        pieces = [[tok.cls_token]]
        # A word that yields no sub-word pieces (e.g. whitespace only) still
        # needs one input position, otherwise x falls out of step with y and
        # is_heads; represent it with the unknown token.
        pieces.extend(tok.tokenize(w) or [tok.unk_token] for w in words)
        pieces.append([tok.sep_token])

        x, y, is_heads = [], [], []
        for tokens, t in zip(pieces, target):
            n_extra = len(tokens) - 1
            x.extend(tok.convert_tokens_to_ids(tokens))
            # Only the first piece of a word carries the real tag.
            y.extend(self.tag2idx[each] for each in [t] + ["[PAD]"] * n_extra)
            is_heads.extend([1] + [0] * n_extra)

        x_seqlen, y_seqlen = len(x), len(y)

        return x, x_seqlen, source, y, y_seqlen, target, is_heads

In [6]:
def pad(batch, x_pad_id=None):
    """Collate dataset samples into right-padded tensors.

    Args:
        batch: List of samples shaped like the tuples returned by the dataset:
            ``(x, x_seqlen, source, y, y_seqlen, target, is_heads)``.
        x_pad_id: Id used to pad the input ids. When omitted, the module-level
            tokenizer's ``pad_token_id`` is used (0 if it has none).

    Returns:
        ``(x, x_seqlens, sources, y, y_seqlens, target, is_heads)`` where ``x``
        and ``y`` are ``LongTensor``s and the remaining items are plain lists.
        ``y`` is always padded with 0, the index of the ``[PAD]`` tag.
    """
    if x_pad_id is None:
        x_pad_id = getattr(tokenizer, 'pad_token_id', None)
        if x_pad_id is None:
            x_pad_id = 0

    column = lambda i: [sample[i] for sample in batch]
    x_seqlens = column(1)
    sources = column(2)
    y_seqlens = column(4)
    target = column(5)
    is_heads = column(6)

    # Each side is padded to its own longest sequence (the input width used to
    # be derived from the tag lengths).
    x_maxlen = max(len(sample[0]) for sample in batch)
    y_maxlen = max(len(sample[3]) for sample in batch)

    x = [sample[0] + [x_pad_id] * (x_maxlen - len(sample[0])) for sample in batch]
    y = [sample[3] + [0] * (y_maxlen - len(sample[3])) for sample in batch]

    return torch.LongTensor(x), x_seqlens, sources, torch.LongTensor(y), y_seqlens, target, is_heads

In [7]:
import torch
import torch.nn as nn
from transformers import AutoModel

In [17]:
class Net(nn.Module):
    """Pretrained ``roberta-base`` encoder followed by a 2-layer BiLSTM and a linear tagger."""

    def __init__(self, device='cpu', hidden_size=None, finetuning=False, tag_size=None, dr=0.0):
        """Build the network.

        Args:
            device: Device that ``forward`` moves its inputs to.
            hidden_size: Width of the BiLSTM output (split over both
                directions). Defaults to the encoder's hidden size.
            finetuning: When true the encoder receives gradients; otherwise it
                is run frozen under ``torch.no_grad()``.
            tag_size: Number of output classes.
            dr: Dropout probability applied to the BiLSTM features.
        """
        super().__init__()
        self.bert = AutoModel.from_pretrained('roberta-base')

        # The LSTM consumes the encoder output, so its input width is the
        # encoder's hidden size regardless of the requested LSTM width.
        encoder_size = self.bert.config.hidden_size
        if hidden_size is None:
            hidden_size = encoder_size
        per_direction = hidden_size // 2

        self.rnn = nn.LSTM(bidirectional=True, num_layers=2, input_size=encoder_size,
                           hidden_size=per_direction, batch_first=True)

        self.dropout = nn.Dropout(p=dr)
        # Both directions are concatenated; equals hidden_size when it is even.
        self.fc = nn.Linear(2 * per_direction, tag_size)

        self.device = device
        self.finetuning = finetuning

    def forward(self, x, y, ):
        """Tag a batch.

        Args:
            x: Integer tensor of input ids, batch first.
            y: Tensor of gold tag ids; only moved to ``self.device`` and
                returned so the caller can compute the loss.

        Returns:
            ``(logits, y, y_hat)``: per-token class scores, the gold tags on
            the model device, and the arg-max predictions.
        """
        x = x.to(self.device)
        y = y.to(self.device)

        # Mask out positions holding the encoder's padding id so they are not
        # attended to. If the batch was padded with a different id the mask is
        # simply all ones, i.e. the same as passing no mask.
        pad_id = self.bert.config.pad_token_id
        attention_mask = None if pad_id is None else x.ne(pad_id).long()

        if self.finetuning:
            # Follow the module's own train/eval mode. Forcing train() here
            # would re-enable encoder dropout after the caller's model.eval().
            self.bert.train(self.training)
            enc = self.bert(x, attention_mask=attention_mask)[0]
        else:
            self.bert.eval()
            with torch.no_grad():
                enc = self.bert(x, attention_mask=attention_mask)[0]

        enc, _ = self.rnn(enc)
        # Regularise the features, not the class scores: dropout on the logits
        # would randomly zero scores that the arg-max below then reads.
        logits = self.fc(self.dropout(enc))
        y_hat = logits.argmax(-1)

        return logits, y, y_hat

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

from torch.utils import data

import argparse

from model import Net

import os

In [10]:
def train(model, iterator, optimizer, criterion):
    """Run one optimisation pass over ``iterator``.

    Every batch is a 7-tuple whose first and fourth items are the input ids
    and the gold tags. ``model(x, y)`` must return ``(logits, y, y_hat)``; the
    logits and tags are flattened over all positions before ``criterion`` is
    applied. The loss is printed on every tenth step, step 0 excluded.
    """
    model.train()
    for step, batch in enumerate(iterator):
        inputs, _x_lens, _sources, gold, _y_lens, _targets, _heads = batch

        optimizer.zero_grad()

        scores, gold, _predictions = model(inputs, gold)

        n_classes = scores.shape[-1]
        loss = criterion(scores.view(-1, n_classes), gold.view(-1))
        loss.backward()

        optimizer.step()

        if step != 0 and step % 10 == 0:
            print(f'step: {step}, loss: {loss.item()}')

In [11]:
def calc_score(true_list, pred_list, tag2idx):
    """Compute and print precision, recall and F1 over non-special tags.

    Tags whose index is 3 or lower are ignored; with a vocabulary that starts
    with ``[PAD] [UNK] [CLS] [SEP]`` those are exactly the special entries.

    Args:
        true_list: Gold tag strings, one per word.
        pred_list: Predicted tag strings, aligned with ``true_list``.
        tag2idx: Mapping from tag string to vocabulary index.

    Returns:
        ``(precision, recall, f1)`` as Python floats. Precision is 1.0 when
        nothing was proposed, recall is 1.0 when there is nothing to find, and
        F1 is 0.0 when precision and recall are both zero.
    """
    y_true = np.array([tag2idx[each] for each in true_list])
    y_pred = np.array([tag2idx[each] for each in pred_list])

    # Plain Python ints: dividing NumPy scalars by zero yields nan/inf plus a
    # warning rather than raising ZeroDivisionError, so zero denominators are
    # checked explicitly instead of being caught.
    num_proposed = int((y_pred > 3).sum())
    num_correct = int(np.logical_and(y_true == y_pred, y_true > 3).sum())
    num_gold = int((y_true > 3).sum())

    pre = num_correct / num_proposed if num_proposed else 1.0
    re = num_correct / num_gold if num_gold else 1.0
    # No correct prediction at all must score 0, not be reported as perfect.
    f1 = 2 * pre * re / (pre + re) if (pre + re) else 0.0

    print(f'---- evaluation result ----')
    print(f'precision : {pre}')
    print(f'recall : {re}')
    print(f'f1 score : {f1}')

    return pre, re, f1

In [12]:
def eval(model, iterator, epoch, tag2idx, idx2tag):
    """Score ``model`` on ``iterator`` and return ``(precision, recall, f1)``.

    Args:
        model: Callable as ``model(x, y)`` returning ``(logits, y, y_hat)``.
        iterator: Yields 7-tuples ``(x, x_seqlens, sources, y, y_seqlens,
            targets, is_heads)``.
        epoch: Accepted for interface compatibility; not used.
        tag2idx: Tag string to index mapping, forwarded to ``calc_score``.
        idx2tag: Index to tag string mapping used to decode predictions.
    """
    model.eval()

    Targets, Y_hat, Is_heads = [], [], []

    with torch.no_grad():
        for batch in iterator:
            x, _, _, y, _, targets, is_heads = batch

            _, _, y_hat = model(x, y)

            # Only the predictions are needed below, so the gold id tensor is
            # no longer copied back to the host on every batch.
            Targets.extend(targets)
            Is_heads.extend(is_heads)
            Y_hat.extend(y_hat.cpu().numpy().tolist())

    pred_list, true_list = [], []
    for targets, y_hat, is_heads in zip(Targets, Y_hat, Is_heads):
        # is_heads is unpadded, so zip() also drops the padded tail of y_hat;
        # keep one prediction per word (the first sub-word piece).
        preds = [idx2tag[hat] for hat, head in zip(y_hat, is_heads) if head == 1]
        # [1:-1] strips the sentence-boundary markers on both sides.
        for t, p in zip(targets[1:-1], preds[1:-1]):
            pred_list.append(p)
            true_list.append(t)

    return calc_score(true_list, pred_list, tag2idx)

In [13]:
def pred(model, iterator, tag2idx, idx2tag, result_path='result.txt'):
    """Tag every sentence in ``iterator``, write the result file and print the scores.

    The output file has one ``word gold_tag predicted_tag`` line per word and
    an empty line after each sentence.

    Args:
        model: Callable as ``model(x, y)`` returning ``(logits, y, y_hat)``.
        iterator: Yields 7-tuples ``(x, x_seqlens, sources, y, y_seqlens,
            targets, is_heads)``.
        tag2idx: Tag string to index mapping, forwarded to ``calc_score``.
        idx2tag: Index to tag string mapping used to decode predictions.
        result_path: Where to write the predictions.
    """
    model.eval()

    Words, Targets, Y_hat, Is_heads = [], [], [], []

    with torch.no_grad():
        for batch in iterator:
            x, _, sources, y, _, targets, is_heads = batch

            _, _, y_hat = model(x, y)

            Words.extend(sources)
            Targets.extend(targets)
            Is_heads.extend(is_heads)
            Y_hat.extend(y_hat.cpu().numpy().tolist())

    true_list, pred_list = [], []
    # The words come from a UTF-8 data file; write them back as UTF-8 instead
    # of relying on the platform's default encoding.
    with open(result_path, 'w', encoding='utf-8') as fout:
        for x, targets, y_hat, is_heads in zip(Words, Targets, Y_hat, Is_heads):
            # One prediction per word: the first sub-word piece of each word.
            preds = [idx2tag[hat] for hat, head in zip(y_hat, is_heads) if head == 1]
            # [1:-1] strips the sentence-boundary markers on both sides.
            for w, t, p in zip(x[1:-1], targets[1:-1], preds[1:-1]):
                true_list.append(t)
                pred_list.append(p)
                fout.write(f'{w} {t} {p}\n')
            fout.write('\n')

    calc_score(true_list, pred_list, tag2idx)

In [14]:
if __name__=='__main__':
    # Command-line options, declared as (flag, add_argument keyword arguments)
    # pairs and registered in this order.
    cli_options = (
        ('--batch_size', dict(type=int, default=32)),
        ("--lr", dict(type=float, default=1e-5)),
        ("--n_epochs", dict(type=int, default=300)),
        ("--hidden_size", dict(type=int, default=768)),
        ("--early_stopping_step", dict(type=int, default=15)),
        ("--logdir", dict(type=str, default="checkpoints")),
        ("--trainset", dict(type=str, default="data/train.txt")),
        ("--testset", dict(type=str, default="data/test.txt")),
        ("--validset", dict(type=str, default="data/val.txt")),
        ("--model_path", dict(type=str, default="checkpoints/129.pt")),
        ("--vocab_path", dict(type=str, default="vocab.txt")),
        ("--finetuning", dict(dest="finetuning", action='store_true')),
    )

    parser = argparse.ArgumentParser()
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)

    # An explicit empty argument list: the defaults above are always used.
    args = parser.parse_args(args=[])

In [None]:
    # Create the checkpoint directory together with any missing parents; this
    # is a no-op when it already exists.
    os.makedirs(args.logdir, exist_ok=True)

    args.finetuning = True
    train_dataset = DataLoader(args.trainset, args)
    train_iter = data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=pad)
    eval_dataset = DataLoader(args.validset, args)
    eval_iter = data.DataLoader(eval_dataset, batch_size=args.batch_size, shuffle=False, collate_fn=pad)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Net(device=device, hidden_size=args.hidden_size, finetuning=args.finetuning, tag_size=len(train_dataset.vocab))
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # Index 0 is the '[PAD]' tag, used for padding and for continuation pieces.
    criterion = nn.CrossEntropyLoss(ignore_index=0)

    maxf1 = 0.0
    step = 0  # epochs elapsed since the last improvement of the validation F1
    for epoch in range(1, args.n_epochs+1):
        print(f'train in {epoch}...')
        train(model, train_iter, optimizer, criterion)
        pre, re, f1 = eval(model, eval_iter, epoch, eval_dataset.tag2idx, eval_dataset.idx2tag)

        if maxf1 < f1:
            maxf1 = f1
            new_path = f'{args.logdir}/{str(epoch)}.pt'
            # Write the new best checkpoint first and only then drop the
            # previous one, so a failed save never leaves us without a model.
            torch.save(model, new_path)
            if args.model_path != new_path and os.path.exists(args.model_path):
                os.remove(args.model_path)
            # The prediction cell reads the best checkpoint from args.model_path.
            args.model_path = new_path
            step = 0

        if step > args.early_stopping_step:
            print(f'early stopping on {epoch}')
            break
        else:
            step += 1

model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


train in 1...
step: 10, loss: 2.9695372581481934


In [None]:

    pred_dataset = DataLoader(args.testset, args)
    pred_iter = data.DataLoader(pred_dataset, batch_size=args.batch_size, shuffle=False, collate_fn=pad)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # The checkpoint is a whole pickled model object (it was written with
    # torch.save(model, ...)), so building a fresh Net first only costs an
    # extra encoder load that is thrown away immediately.
    # weights_only=False is required to unpickle a full module on PyTorch
    # releases where torch.load defaults to weights-only loading.
    # NOTE(review): this executes pickled code - only load checkpoints you trust.
    model = torch.load(args.model_path, map_location=device, weights_only=False)
    # The unpickled object still carries its training-time settings: point it at
    # the current device and run the encoder in frozen inference mode.
    model.device = device
    model.finetuning = False
    model.to(device)
    print(f'load model: {args.model_path}')

    pred(model, pred_iter, pred_dataset.tag2idx, pred_dataset.idx2tag)