# Prepare

In [1]:
import json
import re
import os, sys
import pickle

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm import tqdm

In [3]:
from kobert.utils import get_tokenizer
from kobert.pytorch_kobert import get_pytorch_kobert_model

In [4]:
from transformers import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup

In [5]:
def clean_sentence(sent):
    sent = sent.strip()
    sent = re.sub('{laughing}|{clearing}|{singing}|{applauding}', '', sent)
    sent = re.sub('[(][(].*?[)][)]|-.*?-', '', sent)
    return sent

In [6]:
get_main_topic = lambda x: x.split(' > ')[0]

def load_data():
    file_list = [f_name for f_name in os.listdir('data') if f_name[-5:] == '.json']
    
    np.random.seed(0)
    np.random.shuffle(file_list)
    
    valid_ratio = 0.1
    test_ratio = 0.1
    valid_size = int(len(file_list) * valid_ratio)
    test_size = int(len(file_list) * test_ratio)
    valid_list = file_list[-(valid_size + test_size):-test_size]
    test_list = file_list[-test_size:]
    
    train_data = []
    valid_data = []
    test_data = []
    for f_name in tqdm(file_list):
        with open('data/%s' %f_name, 'r') as f:
            data = json.loads(f.read())['document'][0]
            metadata = data['metadata']
            utterance = data['utterance']
            
            topic = get_main_topic(metadata['topic'])
            if topic[:4] == 'NWRW':
                continue

            last_speaker = None
            seg1 = seg2 = ''
            for u in utterance:
                if last_speaker is None:
                    last_speaker = u['speaker_id']
                    seg2 = u['form']
                elif last_speaker == u['speaker_id']:
                    seg2 += ' ' + u['original_form']
                else:
                    if seg1 and seg2:
                        if f_name in valid_list:
                            valid_data.append([f_name, topic, clean_sentence(seg1), clean_sentence(seg2)])
                        elif f_name in test_list:
                            test_data.append([f_name, topic, clean_sentence(seg1), clean_sentence(seg2)])
                        else:
                            train_data.append([f_name, topic, clean_sentence(seg1), clean_sentence(seg2)])
                    last_speaker = u['speaker_id']
                    seg1 = seg2
                    seg2 = u['original_form']
            if seg1 and seg2:
                if f_name in valid_list:
                    valid_data.append([f_name, topic, clean_sentence(seg1), clean_sentence(seg2)])
                elif f_name in test_list:
                    test_data.append([f_name, topic, clean_sentence(seg1), clean_sentence(seg2)])
                else:
                    train_data.append([f_name, topic, clean_sentence(seg1), clean_sentence(seg2)])
    return np.array(train_data), np.array(valid_data), np.array(test_data)

In [7]:
device = torch.device("cuda:0")

In [8]:
bertmodel, vocab = get_pytorch_kobert_model()
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

using cached model
using cached model
using cached model


In [9]:
## Setting parameters
pre_max_len = 4096
max_len = 512
batch_size = 6
warmup_ratio = 0.1
num_epochs = 20
max_grad_norm = 1
log_interval = 200
learning_rate =  5e-5

In [10]:
indices = np.indices([pre_max_len])[0]
def pad_transform(transform, sent1, sent2, max_len):
    tokens, valid_len, segs = transform([sent1, sent2])
    if valid_len < max_len:
        return tokens[:max_len], np.array((tokens != 1).sum(), dtype='int32'), segs[:max_len]
    
    idx2 = indices[segs == 1][:-1]
    idx1 = indices[:idx2[0]][1:-1]
    if len(idx1) < len(idx2) and (len(idx1) * 2 + 3) < max_len:
        idx2 = idx2[:(max_len - 3 - len(idx1))]
    elif len(idx2) < len(idx1) and (len(idx2) * 2 + 3) < max_len:
        idx1 = idx1[-(max_len - 3 - len(idx2)):]
    else:
        if len(idx1) < len(idx2):
            idx1, idx2 = idx1[-((max_len - 3) // 2):], idx2[:((max_len - 3) // 2 + 1)]
        else:
            idx1, idx2 = idx1[-((max_len - 3) // 2 + 1):], idx2[:((max_len - 3) // 2)]
    return (np.concatenate([[2], tokens[idx1], [3], tokens[idx2],[3]]), 
             np.array(max_len, dtype='int32'), 
             np.array([0] * (len(idx1) + 2) + [1] * (len(idx2) + 1)))

In [11]:
class BERTDataset(Dataset):
    def __init__(self, dataset, sent_idx1, sent_idx2, label_idx, bert_tokenizer, max_len,
                 pad, pair):
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=pre_max_len, pad=pad, pair=pair)

        self.sentences = [pad_transform(transform, d[sent_idx1], d[sent_idx2], max_len) for d in tqdm(dataset)]
#         self.sentences = [transform([d[sent_idx1], d[sent_idx2]]) for d in tqdm(dataset)]
        self.labels = [np.int32(d[label_idx]) for d in dataset]

    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i], ))

    def __len__(self):
        return (len(self.labels))

In [12]:
class BERTClassifier(nn.Module):
    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes=2,
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate
                 
        self.classifier = nn.Linear(hidden_size , num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)
    
    def gen_attention_mask(self, token_ids, valid_length):
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        
        ret = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask.float().to(token_ids.device))
        pooler = ret[1]
        if self.dr_rate:
            out = self.dropout(pooler)
        return self.classifier(out)

# Train

In [13]:
def row_to_list(row):
    return str(row[2]), str(row[3]), label_w2i[row[1]]

In [14]:
raw_train_dataset, raw_valid_dataset, raw_test_dataset = load_data()
label_i2w = np.unique(raw_train_dataset[:, 1]).tolist()
label_w2i = {w: i for i, w in enumerate(label_i2w)}
train_dataset = [row_to_list(r) for r in raw_train_dataset]
valid_dataset = [row_to_list(r) for r in raw_valid_dataset]
test_dataset = [row_to_list(r) for r in raw_test_dataset]

100%|██████████| 2232/2232 [00:02<00:00, 1046.09it/s]


In [15]:
data_train = BERTDataset(train_dataset, 0, 1, 2, tok, max_len, True, True)
data_valid = BERTDataset(valid_dataset, 0, 1, 2, tok, max_len, True, True)
data_test = BERTDataset(test_dataset, 0, 1, 2, tok, max_len, True, True)

100%|██████████| 18834/18834 [00:21<00:00, 870.52it/s] 
100%|██████████| 2270/2270 [00:02<00:00, 862.15it/s] 
100%|██████████| 2253/2253 [00:02<00:00, 862.33it/s]


In [16]:
train_dataloader = torch.utils.data.DataLoader(data_train, batch_size=batch_size, num_workers=5)
valid_dataloader = torch.utils.data.DataLoader(data_valid, batch_size=batch_size, num_workers=5)
test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=batch_size, num_workers=5)

In [17]:
model = BERTClassifier(bertmodel, dr_rate=0.5, num_classes=len(label_i2w)).to(device)

In [18]:
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
# loss_fn = nn.CrossEntropyLoss(weight=weight)
loss_fn = nn.CrossEntropyLoss()

t_total = len(train_dataloader) * num_epochs
warmup_step = int(t_total * warmup_ratio)

scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)

In [19]:
def calc_accuracy(X,Y):
    max_vals, max_indices = torch.max(X, 1)
    train_acc = (max_indices == Y).sum().data.cpu().numpy()/max_indices.size()[0]
    return train_acc

In [None]:
pickle.dump(label_w2i, open('ver_1.w2i', 'wb'))
result = []

n_batch = len(train_dataloader)
for e in range(num_epochs):
    train_acc = 0.0
    valid_acc = 0.0
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm(train_dataloader, file=sys.stdout)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()  # Update learning rate schedule
        train_acc += calc_accuracy(out, label)
        if batch_id % log_interval == 0:
            tqdm.write("epoch %02d batch id %04d/%d loss %f train acc %f" %(e+1, batch_id+1, n_batch, loss.data.cpu().numpy(), train_acc / (batch_id+1)))
    train_acc /= batch_id + 1
    print("epoch %02d train acc %f" %(e+1, train_acc))
    model.eval()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(valid_dataloader):
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        valid_acc += calc_accuracy(out, label)
    valid_acc /=  batch_id+1
    print("epoch %02d valid acc %f" %(e+1, valid_acc))
    torch.save(model, 'ver_1_%02d_%.04f_%.04f.model' %(e+1, train_acc, valid_acc))

epoch 01 batch id 0001/3139 loss 2.776338 train acc 0.000000
epoch 01 batch id 0201/3139 loss 2.672711 train acc 0.087894
epoch 01 batch id 0401/3139 loss 2.796778 train acc 0.081879
epoch 01 batch id 0601/3139 loss 2.769908 train acc 0.078203
epoch 01 batch id 0801/3139 loss 2.996637 train acc 0.082813
epoch 01 batch id 1001/3139 loss 2.425073 train acc 0.089910
epoch 01 batch id 1201/3139 loss 2.491398 train acc 0.092978
epoch 01 batch id 1401/3139 loss 2.903411 train acc 0.117297
epoch 01 batch id 1601/3139 loss 1.749927 train acc 0.164689
 52%|█████▏    | 1636/3139 [06:36<06:09,  4.07it/s]

In [176]:
tqdm._instances.clear()

# Predict

In [14]:
model = torch.load('ver.1.model')
label_w2i = pickle.load(open('ver.1.w2i', 'rb'))

In [15]:
raw_data = [
    ('혹시 마블 좋아하세요?', '네. 최근에 스파이더맨 봤어요.'),
    ('점심 뭐 드셨어요?', '근처에서 마라탕 먹었어요.')
]
data = BERTDataset([d + (0,) for d in raw_data], 0, 1, 2, tok, max_len, True, False)
dataloader = DataLoader(data, batch_size=batch_size, num_workers=5)

100%|██████████| 2/2 [00:00<00:00, 91.12it/s]


In [16]:
model.eval()
outputs = []
for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(dataloader)):
    token_ids = token_ids.long().to(device)
    segment_ids = segment_ids.long().to(device)
    valid_length = valid_length
    label = label.long().to(device)
    out = model(token_ids, valid_length, segment_ids)
    outputs.append(out.cpu().detach().numpy())

result = np.concatenate(outputs)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  This is separate from the ipykernel package so we can avoid doing imports until


  0%|          | 0/1 [00:00<?, ?it/s]

In [17]:
result

array([[-1.6374834 , -1.3215848 ,  0.12254745, -2.5353487 , -0.89550805,
        -1.4683422 ,  1.9166542 , -2.0298762 , -1.7553192 ,  1.1083449 ,
        -1.3344933 , -0.5012208 ,  0.7741338 ,  8.988237  , -0.67473793],
       [ 0.14410633,  0.48958924, -0.43005556, -1.36853   , 10.405009  ,
        -1.5924621 , -1.6080155 ,  0.31998512, -2.137295  , -0.46647248,
        -0.10372277,  2.2267296 , -1.6280773 , -1.0927978 , -1.1962178 ]],
      dtype=float32)

In [62]:
full_length = list(map(lambda x: len(x[2] + x[3]), raw_train_dataset))

In [65]:
np.argmax(full_length)

18082