In [1]:
!pip install mxnet
!pip install gluonnlp pandas tqdm
!pip install sentencepiece
!pip install transformers==3
!pip install torch



In [2]:
!pip install git+https://git@github.com/SKTBrain/KoBERT.git@master

Collecting git+https://****@github.com/SKTBrain/KoBERT.git@master
  Cloning https://****@github.com/SKTBrain/KoBERT.git (to revision master) to /tmp/pip-req-build-w43ljcn8
  Running command git clone -q 'https://****@github.com/SKTBrain/KoBERT.git' /tmp/pip-req-build-w43ljcn8
Building wheels for collected packages: kobert
  Building wheel for kobert (setup.py) ... [?25ldone
[?25h  Created wheel for kobert: filename=kobert-0.1.2-py3-none-any.whl size=12732 sha256=b54f8a29bb47646a60be2c436ac188abd9aafc168cffb930d05fff59ff485191
  Stored in directory: /tmp/pip-ephem-wheel-cache-78gtf233/wheels/d3/68/ca/334747dfb038313b49cf71f84832a33372f3470d9ddfd051c0
Successfully built kobert
Installing collected packages: kobert
Successfully installed kobert-0.1.2


In [3]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm import tqdm, tqdm_notebook

In [4]:
from kobert.utils import get_tokenizer
from kobert.pytorch_kobert import get_pytorch_kobert_model

In [5]:
from transformers import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup

In [6]:
##GPU 사용 시
device = torch.device("cuda:0")

In [7]:
bertmodel, vocab = get_pytorch_kobert_model()

[██████████████████████████████████████████████████]
[██████████████████████████████████████████████████]


In [8]:
!wget https://www.dropbox.com/s/374ftkec978br3d/ratings_train.txt?dl=1
!wget https://www.dropbox.com/s/977gbwh542gdy94/ratings_test.txt?dl=1

--2020-12-24 00:20:39--  https://www.dropbox.com/s/374ftkec978br3d/ratings_train.txt?dl=1
Resolving www.dropbox.com (www.dropbox.com)... 162.125.80.18, 2620:100:6030:18::a27d:5012
Connecting to www.dropbox.com (www.dropbox.com)|162.125.80.18|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: /s/dl/374ftkec978br3d/ratings_train.txt [following]
--2020-12-24 00:20:40--  https://www.dropbox.com/s/dl/374ftkec978br3d/ratings_train.txt
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://uc21a7edd911c0cecccdb93ea938.dl.dropboxusercontent.com/cd/0/get/BFkiZLIEIMoyATMc92ecpAqZhde8xtqUZMq42LZNLyyAEDz0wzM8G-dK7aAlJS5Dk-7fvcRCs9K8q4R90Q0jGK6Z1pjWpTZhu1C4Wnik2jRjzYr2JcyUwoOv_b-i_agddmA/file?dl=1# [following]
--2020-12-24 00:20:40--  https://uc21a7edd911c0cecccdb93ea938.dl.dropboxusercontent.com/cd/0/get/BFkiZLIEIMoyATMc92ecpAqZhde8xtqUZMq42LZNLyyAEDz0wzM8G-dK7aAlJS5Dk-7fvcRCs9K8q4R90Q0jGK6Z1pj

In [9]:
dataset_train = nlp.data.TSVDataset("ratings_train.txt?dl=1", field_indices=[1,2], num_discard_samples=1)
dataset_test = nlp.data.TSVDataset("ratings_test.txt?dl=1", field_indices=[1,2], num_discard_samples=1)

In [10]:
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

using cached model


In [11]:
class BERTDataset(Dataset):
    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,
                 pad, pair):
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)

        self.sentences = [transform([i[sent_idx]]) for i in dataset]
        self.labels = [np.int32(i[label_idx]) for i in dataset]

    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i], ))

    def __len__(self):
        return (len(self.labels))


In [12]:
## Setting parameters
max_len = 64
batch_size = 64
warmup_ratio = 0.1
num_epochs = 10
max_grad_norm = 1
log_interval = 200
learning_rate =  5e-5

In [13]:
data_train = BERTDataset(dataset_train, 0, 1, tok, max_len, True, False)
data_test = BERTDataset(dataset_test, 0, 1, tok, max_len, True, False)

In [14]:
train_dataloader = torch.utils.data.DataLoader(data_train, batch_size=batch_size, num_workers=5)
test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=batch_size, num_workers=5)

In [15]:
class BERTClassifier(nn.Module):
    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes=2,
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate
                 
        self.classifier = nn.Linear(hidden_size , num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=0.5)
    
    def gen_attention_mask(self, token_ids, valid_length):
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        
        _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask.float().to(token_ids.device))
        if self.dr_rate:
            out = self.dropout(pooler)
        return self.classifier(out)

In [16]:
model = BERTClassifier(bertmodel,  dr_rate=0.5).to(device)

In [17]:
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

In [18]:
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

In [24]:
t_total = len(train_dataloader) * num_epochs
warmup_step = int(t_total * warmup_ratio)

In [20]:
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)

In [21]:
def calc_accuracy(X,Y):
    max_vals, max_indices = torch.max(X, 1)
    train_acc = (max_indices == Y).sum().data.cpu().numpy()/max_indices.size()[0]
    return train_acc

In [22]:
for e in range(num_epochs):
    train_acc = 0.0
    test_acc = 0.0
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(train_dataloader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()  # Update learning rate schedule
        train_acc += calc_accuracy(out, label)
        if batch_id % log_interval == 0:
            print("epoch {} batch id {} loss {} train acc {}".format(e+1, batch_id+1, loss.data.cpu().numpy(), train_acc / (batch_id+1)))
    print("epoch {} train acc {}".format(e+1, train_acc / (batch_id+1)))
    model.eval()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(test_dataloader)):
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        test_acc += calc_accuracy(out, label)
    print("epoch {} test acc {}".format(e+1, test_acc / (batch_id+1)))

HBox(children=(IntProgress(value=0, max=2344), HTML(value='')))

epoch 1 batch id 1 loss 0.7366753220558167 train acc 0.5
epoch 1 batch id 201 loss 0.6568121910095215 train acc 0.5171797263681592
epoch 1 batch id 401 loss 0.4852443337440491 train acc 0.6189993765586035
epoch 1 batch id 601 loss 0.41338542103767395 train acc 0.684900166389351
epoch 1 batch id 801 loss 0.42175182700157166 train acc 0.7240948813982522
epoch 1 batch id 1001 loss 0.33205199241638184 train acc 0.7482985764235764
epoch 1 batch id 1201 loss 0.4162306785583496 train acc 0.7663405495420483
epoch 1 batch id 1401 loss 0.40020865201950073 train acc 0.7788187009279086
epoch 1 batch id 1601 loss 0.35288363695144653 train acc 0.7895846346033729
epoch 1 batch id 1801 loss 0.29010167717933655 train acc 0.7977946279844531
epoch 1 batch id 2001 loss 0.28915029764175415 train acc 0.8049725137431284
epoch 1 batch id 2201 loss 0.35682129859924316 train acc 0.8109524079963653

epoch 1 train acc 0.8147842007963595


HBox(children=(IntProgress(value=0, max=782), HTML(value='')))


epoch 1 test acc 0.8812140345268542


HBox(children=(IntProgress(value=0, max=2344), HTML(value='')))

epoch 2 batch id 1 loss 0.5659840106964111 train acc 0.8125
epoch 2 batch id 201 loss 0.21055644750595093 train acc 0.8741449004975125
epoch 2 batch id 401 loss 0.2255171537399292 train acc 0.8764806733167082
epoch 2 batch id 601 loss 0.38153547048568726 train acc 0.8796017054908486
epoch 2 batch id 801 loss 0.3388769328594208 train acc 0.8806569912609239
epoch 2 batch id 1001 loss 0.25130346417427063 train acc 0.8821334915084915
epoch 2 batch id 1201 loss 0.34472060203552246 train acc 0.8838598043297252
epoch 2 batch id 1401 loss 0.17869064211845398 train acc 0.8850820842255531
epoch 2 batch id 1601 loss 0.26015505194664 train acc 0.8869944565896315
epoch 2 batch id 1801 loss 0.22083118557929993 train acc 0.8883953359244864
epoch 2 batch id 2001 loss 0.22348390519618988 train acc 0.8909295352323838
epoch 2 batch id 2201 loss 0.23048672080039978 train acc 0.8925985347569286

epoch 2 train acc 0.8939624217861205


HBox(children=(IntProgress(value=0, max=782), HTML(value='')))


epoch 2 test acc 0.8840113491048593


HBox(children=(IntProgress(value=0, max=2344), HTML(value='')))

epoch 3 batch id 1 loss 0.48373836278915405 train acc 0.828125
epoch 3 batch id 201 loss 0.15840297937393188 train acc 0.9125466417910447
epoch 3 batch id 401 loss 0.18017694354057312 train acc 0.9167705735660848
epoch 3 batch id 601 loss 0.3068211078643799 train acc 0.9188071963394343
epoch 3 batch id 801 loss 0.2723780572414398 train acc 0.9201388888888888
epoch 3 batch id 1001 loss 0.2513832449913025 train acc 0.9218906093906094
epoch 3 batch id 1201 loss 0.17808102071285248 train acc 0.9232150291423813
epoch 3 batch id 1401 loss 0.15271508693695068 train acc 0.9240163276231264
epoch 3 batch id 1601 loss 0.2967056334018707 train acc 0.9252517957526546
epoch 3 batch id 1801 loss 0.17671093344688416 train acc 0.9265078428650749
epoch 3 batch id 2001 loss 0.17577962577342987 train acc 0.9278720014992504
epoch 3 batch id 2201 loss 0.24299463629722595 train acc 0.9289385506587915

epoch 3 train acc 0.9296497262514222


HBox(children=(IntProgress(value=0, max=782), HTML(value='')))


epoch 3 test acc 0.8899256713554987


HBox(children=(IntProgress(value=0, max=2344), HTML(value='')))

epoch 4 batch id 1 loss 0.4444595277309418 train acc 0.859375
epoch 4 batch id 201 loss 0.13889901340007782 train acc 0.9422419154228856
epoch 4 batch id 401 loss 0.11540228873491287 train acc 0.9442799251870324
epoch 4 batch id 601 loss 0.24448156356811523 train acc 0.9441035773710482
epoch 4 batch id 801 loss 0.31389155983924866 train acc 0.9452637328339576
epoch 4 batch id 1001 loss 0.08649826049804688 train acc 0.9469124625374625
epoch 4 batch id 1201 loss 0.10248437523841858 train acc 0.9479860532889259
epoch 4 batch id 1401 loss 0.09261193871498108 train acc 0.9481843326195575
epoch 4 batch id 1601 loss 0.14506103098392487 train acc 0.949289506558401
epoch 4 batch id 1801 loss 0.09147702157497406 train acc 0.9500537895613548
epoch 4 batch id 2001 loss 0.16411834955215454 train acc 0.9510010619690155
epoch 4 batch id 2201 loss 0.24735620617866516 train acc 0.9514709223080418

epoch 4 train acc 0.9520228953356086


HBox(children=(IntProgress(value=0, max=782), HTML(value='')))


epoch 4 test acc 0.8882273017902813


HBox(children=(IntProgress(value=0, max=2344), HTML(value='')))

epoch 5 batch id 1 loss 0.3622145354747772 train acc 0.859375
epoch 5 batch id 201 loss 0.060819655656814575 train acc 0.9612873134328358
epoch 5 batch id 401 loss 0.21588994562625885 train acc 0.9616193890274314
epoch 5 batch id 601 loss 0.27169376611709595 train acc 0.96141846921797
epoch 5 batch id 801 loss 0.17355506122112274 train acc 0.9628199126092385
epoch 5 batch id 1001 loss 0.017851151525974274 train acc 0.963676948051948
epoch 5 batch id 1201 loss 0.03543621674180031 train acc 0.9637541631973355
epoch 5 batch id 1401 loss 0.04253193736076355 train acc 0.9641550678087081
epoch 5 batch id 1601 loss 0.0490616112947464 train acc 0.9645436445971268
epoch 5 batch id 1801 loss 0.06676553189754486 train acc 0.9651929483620211
epoch 5 batch id 2001 loss 0.027783142402768135 train acc 0.9655328585707147
epoch 5 batch id 2201 loss 0.1519196778535843 train acc 0.9659955701953657

epoch 5 train acc 0.9666057842718999


HBox(children=(IntProgress(value=0, max=782), HTML(value='')))


epoch 5 test acc 0.8941815856777494


HBox(children=(IntProgress(value=0, max=2344), HTML(value='')))

epoch 6 batch id 1 loss 0.45254993438720703 train acc 0.859375
epoch 6 batch id 201 loss 0.023230426013469696 train acc 0.9709266169154229
epoch 6 batch id 401 loss 0.10994812846183777 train acc 0.9746337281795511
epoch 6 batch id 601 loss 0.21196112036705017 train acc 0.9741316555740432
epoch 6 batch id 801 loss 0.1503436267375946 train acc 0.9750702247191011
epoch 6 batch id 1001 loss 0.013627885840833187 train acc 0.9755244755244755
epoch 6 batch id 1201 loss 0.029128827154636383 train acc 0.9757753955037469
epoch 6 batch id 1401 loss 0.10501827299594879 train acc 0.9759323697359029
epoch 6 batch id 1601 loss 0.020535580813884735 train acc 0.9761769987507808
epoch 6 batch id 1801 loss 0.017443349584937096 train acc 0.9767230011104941
epoch 6 batch id 2001 loss 0.015322821214795113 train acc 0.9770271114442779
epoch 6 batch id 2201 loss 0.10562796890735626 train acc 0.9772120626987733

epoch 6 train acc 0.9774424061433447


HBox(children=(IntProgress(value=0, max=782), HTML(value='')))


epoch 6 test acc 0.896579283887468


HBox(children=(IntProgress(value=0, max=2344), HTML(value='')))

epoch 7 batch id 1 loss 0.38538122177124023 train acc 0.90625
epoch 7 batch id 201 loss 0.02471243217587471 train acc 0.9814987562189055
epoch 7 batch id 401 loss 0.020376408472657204 train acc 0.9821150249376559
epoch 7 batch id 601 loss 0.13279977440834045 train acc 0.98193115640599
epoch 7 batch id 801 loss 0.13205182552337646 train acc 0.9820731897627965
epoch 7 batch id 1001 loss 0.004017830826342106 train acc 0.9825174825174825
epoch 7 batch id 1201 loss 0.009073221124708652 train acc 0.9827097210657785
epoch 7 batch id 1401 loss 0.04712768644094467 train acc 0.982746698786581
epoch 7 batch id 1601 loss 0.002499643713235855 train acc 0.9830184259837601
epoch 7 batch id 1801 loss 0.07258564233779907 train acc 0.9835074264297612
epoch 7 batch id 2001 loss 0.009655407629907131 train acc 0.9838362068965517
epoch 7 batch id 2201 loss 0.04316674917936325 train acc 0.9839135620172649

epoch 7 train acc 0.9841883532423208


HBox(children=(IntProgress(value=0, max=782), HTML(value='')))


epoch 7 test acc 0.8965193414322251


HBox(children=(IntProgress(value=0, max=2344), HTML(value='')))

epoch 8 batch id 1 loss 0.3220914304256439 train acc 0.921875
epoch 8 batch id 201 loss 0.028787951916456223 train acc 0.9860851990049752
epoch 8 batch id 401 loss 0.01827303133904934 train acc 0.9872584164588528
epoch 8 batch id 601 loss 0.1193886399269104 train acc 0.9871308236272879
epoch 8 batch id 801 loss 0.13345305621623993 train acc 0.9873205368289638
epoch 8 batch id 1001 loss 0.003282591700553894 train acc 0.9877622377622378
epoch 8 batch id 1201 loss 0.00876672100275755 train acc 0.9879397377185679
epoch 8 batch id 1401 loss 0.01857699081301689 train acc 0.9878992683797287
epoch 8 batch id 1601 loss 0.018004994839429855 train acc 0.9879469862585883
epoch 8 batch id 1801 loss 0.042217861860990524 train acc 0.9881923237090505
epoch 8 batch id 2001 loss 0.07355353236198425 train acc 0.9884042353823088
epoch 8 batch id 2201 loss 0.08019272983074188 train acc 0.9883291685597456

epoch 8 train acc 0.9884279010238908


HBox(children=(IntProgress(value=0, max=782), HTML(value='')))


epoch 8 test acc 0.8961596867007673


HBox(children=(IntProgress(value=0, max=2344), HTML(value='')))

epoch 9 batch id 1 loss 0.14218182861804962 train acc 0.96875
epoch 9 batch id 201 loss 0.016207313165068626 train acc 0.990282960199005
epoch 9 batch id 401 loss 0.013090400025248528 train acc 0.991232855361596
epoch 9 batch id 601 loss 0.1411430388689041 train acc 0.9908225873544093
epoch 9 batch id 801 loss 0.11252908408641815 train acc 0.9909683208489388
epoch 9 batch id 1001 loss 0.003230966627597809 train acc 0.991087037962038
epoch 9 batch id 1201 loss 0.00725606270134449 train acc 0.9909970857618651
epoch 9 batch id 1401 loss 0.09745766967535019 train acc 0.9909774268379729
epoch 9 batch id 1601 loss 0.01987529546022415 train acc 0.9910212367270456
epoch 9 batch id 1801 loss 0.003428582102060318 train acc 0.9911941282620766
epoch 9 batch id 2001 loss 0.0553654283285141 train acc 0.9912075212393803
epoch 9 batch id 2201 loss 0.017177462577819824 train acc 0.9911900840527034

epoch 9 train acc 0.9912942619453925


HBox(children=(IntProgress(value=0, max=782), HTML(value='')))


epoch 9 test acc 0.896639226342711


HBox(children=(IntProgress(value=0, max=2344), HTML(value='')))

epoch 10 batch id 1 loss 0.16917265951633453 train acc 0.96875
epoch 10 batch id 201 loss 0.014045952819287777 train acc 0.9929259950248757
epoch 10 batch id 401 loss 0.07299713790416718 train acc 0.9937655860349127
epoch 10 batch id 601 loss 0.11329394578933716 train acc 0.9930324459234608
epoch 10 batch id 801 loss 0.11510379612445831 train acc 0.9929580212234707
epoch 10 batch id 1001 loss 0.001598995178937912 train acc 0.9929445554445554
epoch 10 batch id 1201 loss 0.010951431468129158 train acc 0.9927924646128227
epoch 10 batch id 1401 loss 0.0175294429063797 train acc 0.9928176302640971
epoch 10 batch id 1601 loss 0.0019454564899206161 train acc 0.9928755465334166
epoch 10 batch id 1801 loss 0.003645370714366436 train acc 0.9929986812881733
epoch 10 batch id 2001 loss 0.006257026456296444 train acc 0.9930034982508745
epoch 10 batch id 2201 loss 0.016531750559806824 train acc 0.9928796569741026

epoch 10 train acc 0.9928274317406144


HBox(children=(IntProgress(value=0, max=782), HTML(value='')))


epoch 10 test acc 0.8967790920716112


In [23]:
import csv

data_valid = nlp.data.TSVDataset("valid.txt", field_indices=[1,2], num_discard_samples=1)
data_valid = BERTDataset(data_valid, 0, 1, tok, max_len, True, False)
valid_dataloader = torch.utils.data.DataLoader(data_valid, batch_size=1, num_workers=5)

valid_result = []

model.eval()
for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(valid_dataloader)):
    token_ids = token_ids.long().to(device)
    segment_ids = segment_ids.long().to(device)
    valid_length= valid_length
    out = model(token_ids, valid_length, segment_ids)
    _, max_idx = torch.max(out, dim=-1)
    pred_emotion = max_idx.tolist()[0]    
    #logits = out.detach().cpu().numpy()        
    valid_result.append(pred_emotion)
    #print(pred_emotion)


with open('labeled.csv', 'w', newline='') as csvfile:
    fieldnames = ['Id', 'Predicted']
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

    writer.writeheader()
    for idx,row in enumerate(valid_result):
      writer.writerow({'Id': idx, 'Predicted': row})

HBox(children=(IntProgress(value=0, max=11187), HTML(value='')))


