In [None]:
# Mount Google Drive into the Colab filesystem; later cells read the CSVs and
# write the checkpoint under /content/drive/MyDrive/natural/.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Runtime dependencies for the KoBERT fine-tuning cells below.
# Only `transformers` is version-pinned here.
# NOTE(review): the remaining packages are unpinned, so a fresh runtime may
# resolve different versions than the ones this notebook was run with — consider pinning.
!pip install mxnet
!pip install gluonnlp
!pip install sentencepiece
!pip install transformers==3
!pip install torch
!pip install konlpy
# KoBERT is installed straight from the master branch of its GitHub repository.
!pip install git+https://git@github.com/SKTBrain/KoBERT.git@master

In [None]:
import re
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm import tqdm, tqdm_notebook

from kobert.utils import get_tokenizer
from kobert.pytorch_kobert import get_pytorch_kobert_model

from transformers import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup

# Fix the PyTorch and NumPy global RNG seeds so repeated runs are comparable.
torch.manual_seed(615)
np.random.seed(615)

# Prefer the first GPU, but fall back to the CPU so the notebook still runs on
# a runtime without CUDA instead of failing at the first `.to(device)`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Pretrained KoBERT encoder and its vocabulary (used by the tokenizer cell).
bertmodel, vocab = get_pytorch_kobert_model()

In [None]:
import pandas as pd
import numpy as np
import re
import os
from konlpy.tag import Okt

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
# Global matplotlib settings for figures with Korean labels.
# Draw the minus sign as an ASCII hyphen instead of U+2212.
matplotlib.rcParams['axes.unicode_minus'] = False
# NOTE(review): 'Malgun Gothic' is a font that ships with Windows; confirm it is
# installed on this runtime, otherwise matplotlib falls back to its default font.
matplotlib.rc('font', family='Malgun Gothic')
# Render figures inline in the notebook output.
%matplotlib inline

In [None]:
import pandas as pd
# Directory on the mounted Drive that holds the competition CSV files.
path = '/content/drive/MyDrive/natural/'

# Labelled training data and the unlabelled test data.
train = pd.read_csv(f'{path}train.csv')
test = pd.read_csv(f'{path}test.csv')

In [None]:
# Keep only the text fields used as model input (plus the label for train).
# `.copy()` makes each result an independent frame, so the assignments below
# modify it directly instead of a view of the originally loaded DataFrame.
train = train[['사업명', '과제명', '요약문_연구목표', '요약문_연구내용', '요약문_한글키워드', '요약문_기대효과', 'label']].copy()
test = test[['사업명', '과제명', '요약문_연구목표', '요약문_연구내용', '요약문_기대효과', '요약문_한글키워드']].copy()

# Text fields in the order they are concatenated into the `data` column.
text_columns = ['사업명', '과제명', '요약문_연구목표', '요약문_연구내용', '요약문_기대효과', '요약문_한글키워드']

for frame in (train, test):
    # Fill every text field, including 사업명/과제명: a single NaN would turn the
    # whole concatenation into NaN and make clean_text() fail on a non-string.
    # Assigning the filled columns back (rather than `fillna(inplace=True)` on a
    # selected column) also works when pandas Copy-on-Write is enabled.
    frame[text_columns] = frame[text_columns].fillna('NAN')

    combined = frame[text_columns[0]]
    for column in text_columns[1:]:
        combined = combined + frame[column]
    frame['data'] = combined

def clean_text(sent):
    """Keep only Hangul characters in `sent`.

    Every character that is not a complete Hangul syllable (가-힣), a
    compatibility-jamo consonant (ㄱ-ㅎ) or a compatibility-jamo vowel (ㅏ-ㅣ)
    is replaced by a space, and runs of spaces are collapsed to one space.
    Leading/trailing spaces are not stripped.

    Args:
        sent: Input string.

    Returns:
        The cleaned string.
    """
    # The character class previously read `ㄱ-하-ㅣ`, which is a range from ㄱ to
    # 하 followed by a literal '-': it kept hyphens and every code point in
    # between (including CJK ideographs). The intended ranges are ㄱ-ㅎ and ㅏ-ㅣ.
    sent_clean = re.sub("[^가-힣ㄱ-ㅎㅏ-ㅣ]", " ", sent)
    sent_clean = re.sub(' +', ' ', sent_clean)
    return sent_clean

In [None]:
from sklearn.model_selection import train_test_split
# Hold out 20% of the labelled `train` frame for validation, with a fixed seed.
# Note that `tests` is this held-out part of `train`, not the unlabelled `test` frame.
trains, tests = train_test_split(train, test_size=0.2, random_state=615)

In [None]:
# Tokenizer resource provided by the kobert package.
# NOTE(review): presumably the SentencePiece model that BERTSPTokenizer loads — confirm.
tokenizer = get_tokenizer()
# Tokenizer bound to the KoBERT vocabulary; lower=False leaves the text's case unchanged.
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

using cached model


In [None]:
class BERTDataset(Dataset):
    """Dataset of BERT-ready inputs built from a frame's `data` and `label` columns.

    Each text in ``dataset['data']`` is cleaned with ``clean_text`` and run
    through ``nlp.data.BERTSentenceTransform``; each value in
    ``dataset['label']`` is converted to ``np.int32``.

    Args:
        dataset: Mapping/DataFrame exposing ``'data'`` and ``'label'``.
        sent_idx: Unused; kept for interface compatibility.
        label_idx: Unused; kept for interface compatibility.
        bert_tokenizer: Tokenizer handed to ``BERTSentenceTransform``.
        max_len: Forwarded as ``max_seq_length``.
        pad: Forwarded as ``pad``.
        pair: Forwarded as ``pair``.
    """

    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,
                 pad, pair):
        to_bert_inputs = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)

        # The transform takes a list of sentences; each sample is a single sentence.
        self.sentences = []
        for raw_text in dataset['data']:
            self.sentences.append(to_bert_inputs([clean_text(raw_text)]))
        self.labels = list(map(np.int32, dataset['label']))

    def __getitem__(self, i):
        """Return the transformed inputs for sample `i` with its label appended."""
        features = self.sentences[i]
        target = self.labels[i]
        return features + (target,)

    def __len__(self):
        """Number of samples."""
        return len(self.labels)

In [None]:
# Training hyperparameters shared by the cells below.
max_len = 120          # max_seq_length passed to the BERT sentence transform
batch_size = 120       # batch size for both data loaders
warmup_ratio = 0.1     # fraction of all scheduler steps used for warm-up
# NOTE(review): same value as BERTClassifier's default num_classes (46) — confirm
# this is the intended epoch count.
num_epochs = 46
max_grad_norm = 1      # gradient-norm clipping threshold
log_interval = 200     # print training progress every this many batches
learning_rate = 3e-5   # AdamW learning rate

In [None]:
# Tokenize the training and validation splits (pad=True, pair=False).
# The positional 3 and 2 are BERTDataset's unused sent_idx/label_idx arguments.
data_train = BERTDataset(trains, 3, 2, tok, max_len, True, False)
data_test = BERTDataset(tests, 3, 2, tok, max_len, True, False)

# Reshuffle the training samples every epoch; without `shuffle=True` each epoch
# replays exactly the same batches in the same order. The validation loader
# keeps a fixed order.
train_dataloader = torch.utils.data.DataLoader(data_train, batch_size=batch_size, shuffle=True, num_workers=2)
test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=batch_size, num_workers=2)

In [None]:
class BERTClassifier(nn.Module):
    """Linear classification head on top of a BERT encoder.

    The encoder's pooled output is optionally passed through dropout and then
    through one linear layer that produces `num_classes` logits.

    Args:
        bert: Encoder called as
            ``bert(input_ids=..., token_type_ids=..., attention_mask=...)``;
            its return value is unpacked as ``(_, pooled_output)``.
        hidden_size: Width of the pooled output fed to the linear layer.
        num_classes: Number of output logits.
        dr_rate: Dropout probability for the pooled output; ``None`` or ``0``
            disables dropout.
        params: Unused; kept for interface compatibility.
    """

    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes = 46,
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate

        self.classifier = nn.Linear(hidden_size , num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)

    def gen_attention_mask(self, token_ids, valid_length):
        """Build a float mask with 1.0 on the first `valid_length[i]` positions of row i.

        Args:
            token_ids: 2-D tensor; only its shape and device are used.
            valid_length: One length per row of `token_ids` (tensor or sequence of ints).

        Returns:
            Float tensor shaped like `token_ids`, on the same device.
        """
        # Compare a row of position indices against a column of lengths so the
        # whole mask is built in one broadcasted operation instead of a Python
        # loop with one slice assignment per row.
        positions = torch.arange(token_ids.size(1), device=token_ids.device)
        lengths = torch.as_tensor(valid_length, device=token_ids.device).reshape(-1, 1)
        return (positions.unsqueeze(0) < lengths).float()

    def forward(self, token_ids, valid_length, segment_ids):
        """Return classification logits for a batch.

        Args:
            token_ids: Token id tensor passed to the encoder as ``input_ids``.
            valid_length: Number of non-padding tokens per row.
            segment_ids: Segment ids, cast to long and passed as ``token_type_ids``.

        Returns:
            Tensor of logits with ``num_classes`` columns.
        """
        attention_mask = self.gen_attention_mask(token_ids, valid_length)

        _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask)
        # Without dropout the pooled output goes straight to the classifier;
        # previously `out` was undefined in that case and forward() raised
        # UnboundLocalError with the default dr_rate=None.
        out = self.dropout(pooler) if self.dr_rate else pooler
        return self.classifier(out)

In [None]:
# Wrap the pretrained encoder with the classification head and move it to the target device.
model = BERTClassifier(bertmodel, dr_rate=0.5).to(device)

# Parameters whose names contain any of these substrings get no weight decay.
no_decay = ['bias', 'LayerNorm.weight']

decay_params = []
no_decay_params = []
for param_name, param in model.named_parameters():
    if any(nd in param_name for nd in no_decay):
        no_decay_params.append(param)
    else:
        decay_params.append(param)

optimizer_grouped_parameters = [
    {'params': decay_params, 'weight_decay': 0.01},
    {'params': no_decay_params, 'weight_decay': 0.0},
]

optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

# Total number of scheduler steps: one per training batch, over all epochs.
t_total = len(train_dataloader) * num_epochs
warmup_step = int(t_total * warmup_ratio)

scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_step,
    num_training_steps=t_total,
)


def calc_accuracy(X,Y):
    """Fraction of rows of `X` whose arg-max along dim 1 equals the matching entry of `Y`.

    Args:
        X: 2-D tensor of per-class scores, one row per sample.
        Y: 1-D tensor of target class indices.

    Returns:
        The accuracy as a NumPy scalar.
    """
    _, predicted = torch.max(X, 1)
    n_correct = (predicted == Y).sum().data.cpu().numpy()
    n_samples = predicted.size()[0]
    return n_correct / n_samples

In [None]:
# Fine-tune the classifier with per-epoch validation, best-checkpoint saving and
# early stopping (training stops once validation accuracy has failed to improve
# for more than 5 consecutive epochs).
# NOTE: `tqdm_notebook` is deprecated (see the warning in this cell's output);
# importing `tqdm.notebook.tqdm` instead would silence it.
checkpoint_path = '/content/drive/MyDrive/natural/torchckpt/model.pt'
highest_acc = 0
patience = 0

for e in range(num_epochs):
    train_acc = 0.0
    test_acc = 0.0
    test_loss = 0.0

    # ---- training pass ----
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(train_dataloader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()  # the schedule advances once per batch
        train_acc += calc_accuracy(out, label)
        if batch_id % log_interval == 0:
            print("epoch {} batch id {} loss {} train acc {}".format(e+1, batch_id+1, loss.data.cpu().numpy(), train_acc / (batch_id+1)))
    print("epoch {} train acc {}".format(e+1, train_acc / (batch_id+1)))

    # ---- validation pass ----
    model.eval()
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for test_batch_id, (test_token_ids, test_valid_length, test_segment_ids, test_label) in enumerate(tqdm_notebook(test_dataloader)):
            test_token_ids = test_token_ids.long().to(device)
            test_segment_ids = test_segment_ids.long().to(device)
            test_label = test_label.long().to(device)
            # Score the validation batch itself. The previous code re-used the
            # last *training* batch (`token_ids`, `out`, `label`) here, so the
            # reported "test acc" never measured the validation data.
            test_out = model(test_token_ids, test_valid_length, test_segment_ids)
            test_loss += loss_fn(test_out, test_label).item()
            test_acc += calc_accuracy(test_out, test_label)

    # Turn the per-batch sums into per-batch means for this epoch.
    num_test_batches = test_batch_id + 1
    test_acc /= num_test_batches
    test_loss /= num_test_batches
    print("epoch {} test acc {}".format(e+1, test_acc))

    if test_acc > highest_acc:
        # Remember the best score; without this update every epoch counted as an
        # improvement, the checkpoint was overwritten each time and early
        # stopping could never trigger.
        highest_acc = test_acc
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        torch.save({
            'epoch': e,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': test_loss,  # mean validation loss of this epoch (Python float)
            }, checkpoint_path)
        patience = 0
    else:
        print("test acc did not improved. best:{} current:{}".format(highest_acc, test_acc))
        patience += 1
        if patience > 5:
            break
    print('current patience: {}'.format(patience))
    print("************************************************************************************")


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  if __name__ == '__main__':


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 1 batch id 1 loss 3.730957269668579 train acc 0.03333333333333333
epoch 1 batch id 201 loss 2.5041677951812744 train acc 0.4369402985074626
epoch 1 batch id 401 loss 1.5741596221923828 train acc 0.6256234413965086
epoch 1 batch id 601 loss 1.2751870155334473 train acc 0.68991957848031
epoch 1 batch id 801 loss 0.9575108289718628 train acc 0.7220661672908861
epoch 1 batch id 1001 loss 1.067197322845459 train acc 0.7411338661338661
epoch 1 train acc 0.7533247348810549


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 1 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 2 batch id 1 loss 1.1128082275390625 train acc 0.75
epoch 2 batch id 201 loss 0.6845892667770386 train acc 0.8397180762852402
epoch 2 batch id 401 loss 0.7873134016990662 train acc 0.8421238570241059
epoch 2 batch id 601 loss 0.8229278326034546 train acc 0.8468247365501937
epoch 2 batch id 801 loss 0.5637807250022888 train acc 0.8504577611319191
epoch 2 batch id 1001 loss 0.6164284348487854 train acc 0.8531218781218791
epoch 2 train acc 0.8554958440814


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 2 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 3 batch id 1 loss 0.6960445046424866 train acc 0.8166666666666667
epoch 3 batch id 201 loss 0.49364858865737915 train acc 0.8730514096185742
epoch 3 batch id 401 loss 0.5746456384658813 train acc 0.8749376558603494
epoch 3 batch id 601 loss 0.641198456287384 train acc 0.8777454242928452
epoch 3 batch id 801 loss 0.386289119720459 train acc 0.8797856845609635
epoch 3 batch id 1001 loss 0.4066770374774933 train acc 0.8814935064935028
epoch 3 train acc 0.8827601031814227


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 3 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 4 batch id 1 loss 0.5035029649734497 train acc 0.825
epoch 4 batch id 201 loss 0.393719881772995 train acc 0.894237147595357
epoch 4 batch id 401 loss 0.3975681960582733 train acc 0.8948254364089783
epoch 4 batch id 601 loss 0.5176234245300293 train acc 0.8973100388241828
epoch 4 batch id 801 loss 0.2687014639377594 train acc 0.899094881398251
epoch 4 batch id 1001 loss 0.2548816204071045 train acc 0.9004995004994962
epoch 4 train acc 0.9017053597019151


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 4 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 5 batch id 1 loss 0.3736937344074249 train acc 0.8833333333333333
epoch 5 batch id 201 loss 0.33587342500686646 train acc 0.9088308457711444
epoch 5 batch id 401 loss 0.29087895154953003 train acc 0.9101413133832097
epoch 5 batch id 601 loss 0.45850399136543274 train acc 0.912021630615641
epoch 5 batch id 801 loss 0.19804280996322632 train acc 0.9139096962130647
epoch 5 batch id 1001 loss 0.2272442728281021 train acc 0.9154761904761863
epoch 5 train acc 0.9163442247062151


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 5 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 6 batch id 1 loss 0.3424355387687683 train acc 0.9
epoch 6 batch id 201 loss 0.2607234716415405 train acc 0.9240049751243782
epoch 6 batch id 401 loss 0.2444974035024643 train acc 0.9241064006650042
epoch 6 batch id 601 loss 0.38876888155937195 train acc 0.9261508596783141
epoch 6 batch id 801 loss 0.15830470621585846 train acc 0.9286620890553481
epoch 6 batch id 1001 loss 0.1381760537624359 train acc 0.9300033300033311
epoch 6 train acc 0.9304098595586136


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 6 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 7 batch id 1 loss 0.2312782108783722 train acc 0.9166666666666666
epoch 7 batch id 201 loss 0.16954441368579865 train acc 0.9398839137645109
epoch 7 batch id 401 loss 0.2508619427680969 train acc 0.9399625935162098
epoch 7 batch id 601 loss 0.2614334523677826 train acc 0.9405435385468675
epoch 7 batch id 801 loss 0.16269832849502563 train acc 0.9413753641281765
epoch 7 batch id 1001 loss 0.14581140875816345 train acc 0.9425574425574479
epoch 7 train acc 0.9432358842075146


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 7 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 8 batch id 1 loss 0.14887526631355286 train acc 0.95
epoch 8 batch id 201 loss 0.099470354616642 train acc 0.9493781094527363
epoch 8 batch id 401 loss 0.2099006175994873 train acc 0.9483374896093099
epoch 8 batch id 601 loss 0.23863789439201355 train acc 0.9494453688297294
epoch 8 batch id 801 loss 0.10650515556335449 train acc 0.9510403662089107
epoch 8 batch id 1001 loss 0.09543074667453766 train acc 0.9519230769230851
epoch 8 train acc 0.9522499283462375


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 8 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 9 batch id 1 loss 0.17012490332126617 train acc 0.9583333333333334
epoch 9 batch id 201 loss 0.11321315914392471 train acc 0.9539386401326702
epoch 9 batch id 401 loss 0.20563024282455444 train acc 0.9549875311720684
epoch 9 batch id 601 loss 0.2375619113445282 train acc 0.9571270105379935
epoch 9 batch id 801 loss 0.17236129939556122 train acc 0.9585830212234769
epoch 9 batch id 1001 loss 0.040407873690128326 train acc 0.9592990342990432
epoch 9 train acc 0.9595514474061407


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 9 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 10 batch id 1 loss 0.2584867775440216 train acc 0.95
epoch 10 batch id 201 loss 0.15819750726222992 train acc 0.96231343283582
epoch 10 batch id 401 loss 0.14493612945079803 train acc 0.962697423108893
epoch 10 batch id 601 loss 0.21254666149616241 train acc 0.9635191347753762
epoch 10 batch id 801 loss 0.09972628206014633 train acc 0.9642322097378344
epoch 10 batch id 1001 loss 0.09000253677368164 train acc 0.9650016650016747
epoch 10 train acc 0.9650902837489327


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 10 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 11 batch id 1 loss 0.18815824389457703 train acc 0.9416666666666667
epoch 11 batch id 201 loss 0.07737249881029129 train acc 0.9676202321724696
epoch 11 batch id 401 loss 0.08136334270238876 train acc 0.9671446384039896
epoch 11 batch id 601 loss 0.29408469796180725 train acc 0.9679284525790375
epoch 11 batch id 801 loss 0.08369878679513931 train acc 0.9694652517686297
epoch 11 batch id 1001 loss 0.03464324772357941 train acc 0.9701964701964804
epoch 11 train acc 0.9702851820005804


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 11 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 12 batch id 1 loss 0.06337656825780869 train acc 0.9916666666666667
epoch 12 batch id 201 loss 0.02722727321088314 train acc 0.9723466003316734
epoch 12 batch id 401 loss 0.06432587653398514 train acc 0.9728387364921027
epoch 12 batch id 601 loss 0.16466815769672394 train acc 0.9737520798668906
epoch 12 batch id 801 loss 0.046798162162303925 train acc 0.9738347898460326
epoch 12 batch id 1001 loss 0.09652991592884064 train acc 0.9739094239094332
epoch 12 train acc 0.9740685010031589


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 12 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 13 batch id 1 loss 0.05622393637895584 train acc 0.9916666666666667
epoch 13 batch id 201 loss 0.04677430912852287 train acc 0.9762023217247083
epoch 13 batch id 401 loss 0.14914311468601227 train acc 0.9760182876142971
epoch 13 batch id 601 loss 0.15719886124134064 train acc 0.9762340543538572
epoch 13 batch id 801 loss 0.05221116915345192 train acc 0.9769870994590161
epoch 13 batch id 1001 loss 0.05007857084274292 train acc 0.9772061272061362
epoch 13 train acc 0.9772929206076288


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 13 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 14 batch id 1 loss 0.10122717171907425 train acc 0.9583333333333334
epoch 14 batch id 201 loss 0.01488072145730257 train acc 0.9771144278606952
epoch 14 batch id 401 loss 0.04318295046687126 train acc 0.9771820448877809
epoch 14 batch id 601 loss 0.20664571225643158 train acc 0.977925679423187
epoch 14 batch id 801 loss 0.0665556937456131 train acc 0.9787765293383339
epoch 14 batch id 1001 loss 0.02097744680941105 train acc 0.9789960039960129
epoch 14 train acc 0.9791344224706272


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 14 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 15 batch id 1 loss 0.03768059238791466 train acc 0.9916666666666667
epoch 15 batch id 201 loss 0.014681565575301647 train acc 0.9811774461028179
epoch 15 batch id 401 loss 0.03102988749742508 train acc 0.9807772236076484
epoch 15 batch id 601 loss 0.17805816233158112 train acc 0.9813227953411017
epoch 15 batch id 801 loss 0.017072878777980804 train acc 0.9817311693716253
epoch 15 batch id 1001 loss 0.042046524584293365 train acc 0.9816017316017401
epoch 15 train acc 0.9818071080538874


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 15 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 16 batch id 1 loss 0.045795876532793045 train acc 0.9833333333333333
epoch 16 batch id 201 loss 0.031794317066669464 train acc 0.98217247097844
epoch 16 batch id 401 loss 0.06378388404846191 train acc 0.9826891105569425
epoch 16 batch id 601 loss 0.18006275594234467 train acc 0.9831114808652288
epoch 16 batch id 801 loss 0.009272752329707146 train acc 0.9831668747399154
epoch 16 batch id 1001 loss 0.030894223600625992 train acc 0.9830419580419668
epoch 16 train acc 0.9832975064488434


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 16 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 17 batch id 1 loss 0.013288116082549095 train acc 1.0
epoch 17 batch id 201 loss 0.0518495999276638 train acc 0.9854477611940291
epoch 17 batch id 401 loss 0.11159209907054901 train acc 0.9846841230257705
epoch 17 batch id 601 loss 0.1386571079492569 train acc 0.9851774819744906
epoch 17 batch id 801 loss 0.0025806569028645754 train acc 0.9851851851851913
epoch 17 batch id 1001 loss 0.1309502273797989 train acc 0.9850316350316427
epoch 17 train acc 0.9852321582115254


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 17 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 18 batch id 1 loss 0.030025038868188858 train acc 0.9916666666666667
epoch 18 batch id 201 loss 0.01475230697542429 train acc 0.9872719734660019
epoch 18 batch id 401 loss 0.07548500597476959 train acc 0.9864297589359948
epoch 18 batch id 601 loss 0.12517999112606049 train acc 0.9868552412645627
epoch 18 batch id 801 loss 0.004090254195034504 train acc 0.9868705784436177
epoch 18 batch id 1001 loss 0.024929575622081757 train acc 0.9868215118215187
epoch 18 train acc 0.9866580682143901


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 18 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 19 batch id 1 loss 0.015834899619221687 train acc 1.0
epoch 19 batch id 201 loss 0.011023002676665783 train acc 0.9870232172470972
epoch 19 batch id 401 loss 0.03182969242334366 train acc 0.9868246051537841
epoch 19 batch id 601 loss 0.11995285004377365 train acc 0.9871186910704421
epoch 19 batch id 801 loss 0.00863428134471178 train acc 0.9873907615480706
epoch 19 batch id 1001 loss 0.010661553591489792 train acc 0.9873293373293442
epoch 19 train acc 0.9874390942963626


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 19 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 20 batch id 1 loss 0.005075446795672178 train acc 1.0
epoch 20 batch id 201 loss 0.0625009834766388 train acc 0.9881011608623541
epoch 20 batch id 401 loss 0.02984583005309105 train acc 0.9881961762261033
epoch 20 batch id 601 loss 0.13115184009075165 train acc 0.9886300610094324
epoch 20 batch id 801 loss 0.018415119498968124 train acc 0.9887432376196473
epoch 20 batch id 1001 loss 0.02248639240860939 train acc 0.98893606393607
epoch 20 train acc 0.989115792490686


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 20 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 21 batch id 1 loss 0.01771310158073902 train acc 0.9916666666666667
epoch 21 batch id 201 loss 0.00542370043694973 train acc 0.990174129353233
epoch 21 batch id 401 loss 0.040876876562833786 train acc 0.9895677472984223
epoch 21 batch id 601 loss 0.056349270045757294 train acc 0.9898225180255166
epoch 21 batch id 801 loss 0.0580267459154129 train acc 0.9898148148148198
epoch 21 batch id 1001 loss 0.003943885210901499 train acc 0.9897019647019706
epoch 21 train acc 0.9896316996274007


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 21 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 22 batch id 1 loss 0.07247115671634674 train acc 0.9916666666666667
epoch 22 batch id 201 loss 0.004170529544353485 train acc 0.9899668325041447
epoch 22 batch id 401 loss 0.031974129378795624 train acc 0.9896716541978404
epoch 22 batch id 601 loss 0.0420839786529541 train acc 0.990446478092072
epoch 22 batch id 801 loss 0.015480514615774155 train acc 0.9905014565126969
epoch 22 batch id 1001 loss 0.005618082359433174 train acc 0.9905844155844209
epoch 22 train acc 0.9905846947549445


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 22 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 23 batch id 1 loss 0.02255900762975216 train acc 0.9833333333333333
epoch 23 batch id 201 loss 0.013433055020868778 train acc 0.9911691542288551
epoch 23 batch id 401 loss 0.029736999422311783 train acc 0.9912718204488794
epoch 23 batch id 601 loss 0.05671485885977745 train acc 0.9915141430948451
epoch 23 batch id 801 loss 0.0034232214093208313 train acc 0.991468997086979
epoch 23 batch id 1001 loss 0.055007562041282654 train acc 0.9914502164502215
epoch 23 train acc 0.9915878475207802


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 23 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 24 batch id 1 loss 0.011342356912791729 train acc 1.0
epoch 24 batch id 201 loss 0.019877366721630096 train acc 0.9919983416252071
epoch 24 batch id 401 loss 0.024110278114676476 train acc 0.9920199501246906
epoch 24 batch id 601 loss 0.14865462481975555 train acc 0.9922767609539689
epoch 24 batch id 801 loss 0.0010337315034121275 train acc 0.992270079067836
epoch 24 batch id 1001 loss 0.06280355900526047 train acc 0.992365967365972
epoch 24 train acc 0.9922828890799663


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 24 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 25 batch id 1 loss 0.029159851372241974 train acc 0.9916666666666667
epoch 25 batch id 201 loss 0.02155504748225212 train acc 0.9934908789386394
epoch 25 batch id 401 loss 0.06689658761024475 train acc 0.9926849542809657
epoch 25 batch id 601 loss 0.1007944792509079 train acc 0.9928452579034971
epoch 25 batch id 801 loss 0.0008173169335350394 train acc 0.9929359134415352
epoch 25 batch id 1001 loss 0.007155990693718195 train acc 0.9930486180486222
epoch 25 train acc 0.9930639151619368


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 25 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 26 batch id 1 loss 0.058147966861724854 train acc 0.9916666666666667
epoch 26 batch id 201 loss 0.053858380764722824 train acc 0.9926616915422884
epoch 26 batch id 401 loss 0.028611497953534126 train acc 0.9928096425602678
epoch 26 batch id 601 loss 0.02956177107989788 train acc 0.9930948419301195
epoch 26 batch id 801 loss 0.05974783003330231 train acc 0.9931127756970493
epoch 26 batch id 1001 loss 0.0783378928899765 train acc 0.9932400932400974
epoch 26 train acc 0.993393522499284


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 26 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 27 batch id 1 loss 0.019862432032823563 train acc 0.9916666666666667
epoch 27 batch id 201 loss 0.009165416471660137 train acc 0.9930348258706467
epoch 27 batch id 401 loss 0.010287592187523842 train acc 0.9930382377389881
epoch 27 batch id 601 loss 0.044173821806907654 train acc 0.9933582917359985
epoch 27 batch id 801 loss 0.0004930093418806791 train acc 0.9934248855597207
epoch 27 batch id 1001 loss 0.00988300796598196 train acc 0.9935564435564476
epoch 27 train acc 0.9936156491831457


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 27 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 28 batch id 1 loss 0.004109537228941917 train acc 1.0
epoch 28 batch id 201 loss 0.04386534169316292 train acc 0.9941127694859032
epoch 28 batch id 401 loss 0.011347831226885319 train acc 0.994264339152121
epoch 28 batch id 601 loss 0.028053201735019684 train acc 0.9943427620632302
epoch 28 batch id 801 loss 0.0054199849255383015 train acc 0.9944132334581802
epoch 28 batch id 1001 loss 0.01129225268959999 train acc 0.9944888444888478
epoch 28 train acc 0.9944468329034094


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 28 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 29 batch id 1 loss 0.003263238351792097 train acc 1.0
epoch 29 batch id 201 loss 0.01782440021634102 train acc 0.9949834162520722
epoch 29 batch id 401 loss 0.002661709673702717 train acc 0.9949916874480482
epoch 29 batch id 601 loss 0.07393981516361237 train acc 0.9953549639489763
epoch 29 batch id 801 loss 0.001338057336397469 train acc 0.995349563046195
epoch 29 batch id 1001 loss 0.011017593555152416 train acc 0.9952797202797234
epoch 29 train acc 0.9952278589853827


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 29 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 30 batch id 1 loss 0.0005838843062520027 train acc 1.0
epoch 30 batch id 201 loss 0.10450397431850433 train acc 0.9949834162520724
epoch 30 batch id 401 loss 0.011150260455906391 train acc 0.9950332502078151
epoch 30 batch id 601 loss 0.024192750453948975 train acc 0.9952995008319487
epoch 30 batch id 801 loss 0.0005088913603685796 train acc 0.9951206824802358
epoch 30 batch id 1001 loss 0.011912509799003601 train acc 0.9952380952380983
epoch 30 train acc 0.9952923473774712


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 30 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 31 batch id 1 loss 0.0008611927623860538 train acc 1.0
epoch 31 batch id 201 loss 0.016222039237618446 train acc 0.9953565505804312
epoch 31 batch id 401 loss 0.0022706068120896816 train acc 0.995594347464673
epoch 31 batch id 601 loss 0.016089599579572678 train acc 0.9957293399889093
epoch 31 batch id 801 loss 0.00031930734985508025 train acc 0.9958697461506474
epoch 31 batch id 1001 loss 0.03165801241993904 train acc 0.9958624708624735
epoch 31 train acc 0.9958010891372874


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 31 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 32 batch id 1 loss 0.01738010160624981 train acc 0.9916666666666667
epoch 32 batch id 201 loss 0.0028484275098890066 train acc 0.9958126036484244
epoch 32 batch id 401 loss 0.009553368203341961 train acc 0.9955320033250225
epoch 32 batch id 601 loss 0.019923869520425797 train acc 0.9957709373266799
epoch 32 batch id 801 loss 0.00042484409641474485 train acc 0.9958697461506475
epoch 32 batch id 1001 loss 0.012945896945893764 train acc 0.9959207459207486
epoch 32 train acc 0.9959802235597581


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 32 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 33 batch id 1 loss 0.0010990611044690013 train acc 1.0
epoch 33 batch id 201 loss 0.03869185224175453 train acc 0.9961857379767827
epoch 33 batch id 401 loss 0.001294709974899888 train acc 0.9963632585203669
epoch 33 batch id 601 loss 0.012468636967241764 train acc 0.9964642262895194
epoch 33 batch id 801 loss 0.002822539769113064 train acc 0.9966396171452372
epoch 33 batch id 1001 loss 0.052685484290122986 train acc 0.9966033966033989
epoch 33 train acc 0.9966179421037534


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 33 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 34 batch id 1 loss 0.005934748332947493 train acc 1.0
epoch 34 batch id 201 loss 0.004345015622675419 train acc 0.9964344941956877
epoch 34 batch id 401 loss 0.000667196698486805 train acc 0.9963632585203667
epoch 34 batch id 601 loss 0.013867695815861225 train acc 0.9965335551858031
epoch 34 batch id 801 loss 0.00038560599205084145 train acc 0.9967436537661275
epoch 34 batch id 1001 loss 0.0013712274376302958 train acc 0.9965784215784238
epoch 34 train acc 0.9965821152192597


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 34 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 35 batch id 1 loss 0.0015938612632453442 train acc 1.0
epoch 35 batch id 201 loss 0.02715771645307541 train acc 0.9963930348258704
epoch 35 batch id 401 loss 0.0014279949245974422 train acc 0.9964048212801342
epoch 35 batch id 601 loss 0.02193731814622879 train acc 0.9966306156406005
epoch 35 batch id 801 loss 0.00021451547218021005 train acc 0.9965251768622575
epoch 35 batch id 1001 loss 0.008600577712059021 train acc 0.9965617715617738
epoch 35 train acc 0.9966394382344503


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 35 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 36 batch id 1 loss 0.05010809376835823 train acc 0.9916666666666667
epoch 36 batch id 201 loss 0.0063397763296961784 train acc 0.9967247097844113
epoch 36 batch id 401 loss 0.048327527940273285 train acc 0.9964256026600179
epoch 36 batch id 601 loss 0.0083948764950037 train acc 0.9966999445368846
epoch 36 batch id 801 loss 0.0003973706334363669 train acc 0.9966916354556824
epoch 36 batch id 1001 loss 0.02863279916346073 train acc 0.9966616716616739
epoch 36 train acc 0.9967039266265394


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 36 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 37 batch id 1 loss 0.001325409160926938 train acc 1.0
epoch 37 batch id 201 loss 0.0054891291074454784 train acc 0.9966003316749587
epoch 37 batch id 401 loss 0.0009931671665981412 train acc 0.9966126350789704
epoch 37 batch id 601 loss 0.00965342577546835 train acc 0.9969633943427635
epoch 37 batch id 801 loss 0.00019907984824385494 train acc 0.9970973782771552
epoch 37 batch id 1001 loss 0.005382658913731575 train acc 0.9970362970362989
epoch 37 train acc 0.9970836916021771


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 37 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 38 batch id 1 loss 0.0018603861099109054 train acc 1.0
epoch 38 batch id 201 loss 0.017149129882454872 train acc 0.9969320066334995
epoch 38 batch id 401 loss 0.0011779364431276917 train acc 0.9969243557772248
epoch 38 batch id 601 loss 0.006105189677327871 train acc 0.9970604547975613
epoch 38 batch id 801 loss 0.011665606871247292 train acc 0.9972014148980459
epoch 38 batch id 1001 loss 0.005568103399127722 train acc 0.9972194472194492
epoch 38 train acc 0.9971553453711649


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 38 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 39 batch id 1 loss 0.0006090205279178917 train acc 1.0
epoch 39 batch id 201 loss 0.005396130960434675 train acc 0.9968076285240465
epoch 39 batch id 401 loss 0.00040430770604871213 train acc 0.9969451371571083
epoch 39 batch id 601 loss 0.022575618699193 train acc 0.9969079312257365
epoch 39 batch id 801 loss 0.00023660586157348007 train acc 0.9969725343320868
epoch 39 batch id 1001 loss 0.003826993517577648 train acc 0.9969530469530491
epoch 39 train acc 0.996961880194897


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 39 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 40 batch id 1 loss 0.015035848133265972 train acc 0.9916666666666667
epoch 40 batch id 201 loss 0.012948472052812576 train acc 0.9969734660033174
epoch 40 batch id 401 loss 0.000332687224727124 train acc 0.9970074812967594
epoch 40 batch id 601 loss 0.01844598911702633 train acc 0.9970327232390478
epoch 40 batch id 801 loss 0.00027040610439144075 train acc 0.9969205160216417
epoch 40 batch id 1001 loss 0.0028772051446139812 train acc 0.9969197469197492
epoch 40 train acc 0.9969117225566053


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 40 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 41 batch id 1 loss 0.0023859506472945213 train acc 1.0
epoch 41 batch id 201 loss 0.0024737513158470392 train acc 0.9961028192371475
epoch 41 batch id 401 loss 0.005532281938940287 train acc 0.996467165419785
epoch 41 batch id 601 loss 0.024388961493968964 train acc 0.9965612867443167
epoch 41 batch id 801 loss 0.00022399674344342202 train acc 0.9962338743237642
epoch 41 batch id 1001 loss 0.0038536761421710253 train acc 0.9964285714285738
epoch 41 train acc 0.9965462883347662


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 41 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 42 batch id 1 loss 0.0021657096222043037 train acc 1.0
epoch 42 batch id 201 loss 0.005899610463529825 train acc 0.9961028192371479
epoch 42 batch id 401 loss 0.0030007031746208668 train acc 0.9959684123025784
epoch 42 batch id 601 loss 0.012762398459017277 train acc 0.9963117027176945
epoch 42 batch id 801 loss 0.00024056177062448114 train acc 0.9960466084061614
epoch 42 batch id 1001 loss 0.0200730599462986 train acc 0.9962287712287737
epoch 42 train acc 0.9962525078819132


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 42 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 43 batch id 1 loss 0.03258727118372917 train acc 0.9916666666666667
epoch 43 batch id 201 loss 0.0020456863567233086 train acc 0.9952736318407953
epoch 43 batch id 401 loss 0.03266225755214691 train acc 0.9954488778054873
epoch 43 batch id 601 loss 0.024182302877306938 train acc 0.9957848031059363
epoch 43 batch id 801 loss 0.00020809596753679216 train acc 0.9956200582605101
epoch 43 batch id 1001 loss 0.0060185762122273445 train acc 0.9956876456876483
epoch 43 train acc 0.9958010891372882


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 43 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 44 batch id 1 loss 0.0066623808816075325 train acc 1.0
epoch 44 batch id 201 loss 0.002886832458898425 train acc 0.9962271973466003
epoch 44 batch id 401 loss 0.0268320944160223 train acc 0.9958021612635094
epoch 44 batch id 601 loss 0.015917671844363213 train acc 0.995965058236275
epoch 44 batch id 801 loss 0.015113134868443012 train acc 0.995942571785271
epoch 44 batch id 1001 loss 0.008184673264622688 train acc 0.9960456210456238
epoch 44 train acc 0.9961092003439368


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 44 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 45 batch id 1 loss 0.0386185422539711 train acc 0.9916666666666667
epoch 45 batch id 201 loss 0.003743949346244335 train acc 0.9960199004975124
epoch 45 batch id 401 loss 0.019963396713137627 train acc 0.9958852867830438
epoch 45 batch id 601 loss 0.018620463088154793 train acc 0.9960759844703293
epoch 45 batch id 801 loss 0.00031268640304915607 train acc 0.9961090303786957
epoch 45 batch id 1001 loss 0.026594113558530807 train acc 0.9961455211455238
epoch 45 train acc 0.9961951848667232


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 45 test acc 1.0
current patience: 0
************************************************************************************


  0%|          | 0/1163 [00:00<?, ?it/s]

epoch 46 batch id 1 loss 0.007861709222197533 train acc 0.9916666666666667
epoch 46 batch id 201 loss 0.004305522423237562 train acc 0.9963515754560527
epoch 46 batch id 401 loss 0.034994784742593765 train acc 0.9964256026600176
epoch 46 batch id 601 loss 0.02057291753590107 train acc 0.9964642262895191
epoch 46 batch id 801 loss 0.00023636898549739271 train acc 0.9964315439034561
epoch 46 batch id 1001 loss 0.014523906633257866 train acc 0.9964119214119238
epoch 46 train acc 0.9964817999426758


  0%|          | 0/291 [00:00<?, ?it/s]

epoch 46 test acc 1.0
current patience: 0
************************************************************************************


In [None]:
class BERTDataset1(Dataset):
    """Dataset over the ``'data'`` column of an unlabelled frame.

    Each text is passed through ``clean_text`` and then through a
    ``nlp.data.BERTSentenceTransform`` built from the given tokenizer.
    Every item carries a constant placeholder label of ``np.int32(0)``.

    Args:
        dataset: Mapping/frame exposing an iterable ``'data'`` column of texts.
        sent_idx: Accepted for signature compatibility; not used here.
        label_idx: Accepted for signature compatibility; not used here.
        bert_tokenizer: Tokenizer handed to ``BERTSentenceTransform``.
        max_len: Forwarded as ``max_seq_length``.
        pad: Forwarded as ``pad``.
        pair: Forwarded as ``pair``.
    """

    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,
                 pad, pair):
        to_features = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)

        encoded = []
        for raw_text in dataset['data']:
            # The transform is called with a one-element list holding the cleaned text.
            encoded.append(to_features([clean_text(raw_text)]))

        self.sentences = encoded
        # One dummy label per encoded text.
        self.labels = [np.int32(0)] * len(encoded)

    def __len__(self):
        """Return the number of encoded texts."""
        return len(self.sentences)

    def __getitem__(self, i):
        """Return the i-th encoded text with its placeholder label appended."""
        features = self.sentences[i]
        placeholder = (self.labels[i],)
        return features + placeholder


# Encode the test frame with placeholder labels; keyword arguments make the
# (unused) column indices and the pad/pair flags explicit.
test_set = BERTDataset1(
    dataset=test,
    sent_idx=3,
    label_idx=2,
    bert_tokenizer=tok,
    max_len=max_len,
    pad=True,
    pair=False,
)
# One sample per batch, loaded by two worker processes.
test_input = DataLoader(test_set, batch_size=1, num_workers=2)

In [None]:
# Build a fresh classifier around `bertmodel` and restore the state stored in
# the checkpoint file.
model = BERTClassifier(bertmodel, dr_rate=0.5).to(device)

# map_location places the checkpoint tensors on `device` regardless of which
# device they were saved from, so the load does not depend on the saving setup.
checkpoint = torch.load('/content/drive/MyDrive/natural/torchckpt/model.pt',
                        map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
# NOTE(review): `optimizer` is not created in this cell; it must already exist
# in the kernel (presumably from the training cell). It is only needed if
# training is resumed, not for the inference below.
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

In [None]:
# `tqdm.tqdm_notebook` is deprecated; `tqdm.notebook.tqdm` is its replacement.
from tqdm.notebook import tqdm as notebook_tqdm

# Switch the model to evaluation mode before predicting.
model.eval()

# Predicted class index for every test sample, in loader order.
result = []

# Inference needs no gradients: no_grad avoids building an autograd graph for
# every forward pass.
with torch.no_grad():
    for token_ids, valid_length, segment_ids, _ in notebook_tqdm(test_input):
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        # valid_length is handed to the model unchanged; the placeholder label
        # yielded by the dataset is ignored.
        out = model(token_ids, valid_length, segment_ids)
        # Record one prediction per row of the output so the loop stays
        # correct for any batch size, not only batch_size=1.
        for row_logits in out.detach().cpu().numpy():
            result.append(np.argmax(row_logits))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  # Remove the CWD from sys.path while we load stuff.


  0%|          | 0/43576 [00:00<?, ?it/s]

In [None]:
# Load the sample submission, replace its 'label' column with the collected
# predictions, and write the result back to Drive without the index column.
submission = pd.read_csv(
    '/content/drive/MyDrive/natural/sample_submission.csv'
).assign(label=np.array(result))

submission.to_csv(
    '/content/drive/MyDrive/natural/kobert_clean_baseline.csv',
    index=False,
)