In [None]:
!pip install mxnet-cu101
!pip install gluonnlp pandas tqdm
!pip install sentencepiece==0.1.85
!pip install transformers==2.1.1
!pip install torch==1.5.0

In [None]:
!pip install git+https://git@github.com/SKTBrain/KoBERT.git@master

In [0]:
import pandas as pd
import numpy as np

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from torch.utils.data import Dataset, DataLoader

import gluonnlp as nlp
from tqdm import tqdm, tqdm_notebook

from kobert.utils import get_tokenizer
from kobert.pytorch_kobert import get_pytorch_kobert_model

from transformers import AdamW
from transformers.optimization import WarmupLinearSchedule

In [0]:
device = torch.device("cuda:0")

In [8]:
bertmodel, vocab = get_pytorch_kobert_model()

[██████████████████████████████████████████████████]
[██████████████████████████████████████████████████]


In [0]:
train.to_csv('train.txt', sep='\t', index=False)
test.to_csv('test.txt', sep='\t', index=False)

In [0]:
dataset_train = nlp.data.TSVDataset("train.txt", field_indices = [1, 2], num_discard_samples=1)
dataset_test = nlp.data.TSVDataset("test.txt", field_indices = [1, 2], num_discard_samples=1)

In [11]:
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

using cached model


In [0]:
class BERTDataset(Dataset):
    
    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len, pad, pair):
        
        transform = nlp.data.BERTSentenceTransform(bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)

        self.sentences = [transform([i[sent_idx]]) for i in dataset]
        self.labels = [np.int32(i[label_idx]) for i in dataset]

    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i],))

    def __len__(self):
        return (len(self.labels))
    

In [0]:
max_len = 64
batch_size = 64
warmup_ratio = 0.02
num_epochs = 50
max_grad_norm = 1
log_interval = 200
learning_rate =  5e-5

In [0]:
data_train = BERTDataset(dataset_train, 0, 1, tok, max_len, True, False)
data_test = BERTDataset(dataset_test, 0, 1, tok, max_len, True, False)

In [0]:
train_dataloader = torch.utils.data.DataLoader(data_train, batch_size=batch_size, num_workers=5)
test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=batch_size, num_workers=5)

In [0]:
class BERTClassifier(nn.Module):
    
    def __init__(self, bert, hidden_size=768, num_classes=2, dr_rate=None, params=None):
        
        super(BERTClassifier, self).__init__()
        
        self.bert = bert
        self.dr_rate = dr_rate
        self.classifier = nn.Linear(hidden_size, num_classes)
        
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)
    
    
    def gen_attention_mask(self, token_ids, valid_length):
        
        attention_mask = torch.zeros_like(token_ids)
        
        for i, v in enumerate(valid_length):
            attention_mask[i][:v]=1
            
        return attention_mask.float()

    
    def forward(self, token_ids, valid_length, segment_ids):
        
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        
        _, pooler = self.bert(input_ids=token_ids, token_type_ids=segment_ids.long(), attention_mask=attention_mask.float().to(token_ids.device))
        
        if self.dr_rate:
            out = self.dropout(pooler)
            
        return self.classifier(out)
    

In [0]:
model = BERTClassifier(bertmodel, dr_rate=0.5).to(device)

In [0]:
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.01}
]

In [0]:
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

In [0]:
t_total = len(train_dataloader) * num_epochs
warmup_step = int(t_total * warmup_ratio)

In [0]:
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_step, t_total=t_total)

In [0]:
def calc_accuracy(X, Y):
    
    max_vals, max_indices = torch.max(X, 1)
    train_acc = (max_indices == Y).sum().data.cpu().numpy() / max_indices.size()[0]
    
    return train_acc


In [0]:
for e in range(num_epochs):

    train_acc = 0.0
    test_acc = 0.0

    model.train()

    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(train_dataloader)):

        optimizer.zero_grad()

        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)

        valid_length = valid_length
        
        label = label.long().to(device)

        out = model(token_ids, valid_length, segment_ids)

        loss = loss_fn(out, label)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        
        optimizer.step()
        scheduler.step()
        
        train_acc += calc_accuracy(out, label)

        if batch_id % log_interval == 0:
            print("epoch {} batch id {} loss {} train acc {}".format(e + 1, batch_id + 1, loss.data.cpu().numpy(), train_acc / (batch_id + 1)))

    print("epoch {} train acc {}".format(e + 1, train_acc / (batch_id + 1)))

    model.eval()

    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(test_dataloader)):

        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        
        valid_length = valid_length
        
        label = label.long().to(device)
        
        out = model(token_ids, valid_length, segment_ids)
        
        test_acc += calc_accuracy(out, label)
        
    print("epoch {} test acc {}".format(e + 1, test_acc / (batch_id + 1)))
    

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


HBox(children=(FloatProgress(value=0.0, max=813.0), HTML(value='')))

	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha)


epoch 1 batch id 1 loss 0.692930281162262 train acc 0.53125
epoch 1 batch id 201 loss 0.7156891822814941 train acc 0.5268967661691543
epoch 1 batch id 401 loss 0.7080405950546265 train acc 0.5340554862842892
epoch 1 batch id 601 loss 0.7054370641708374 train acc 0.5375155990016639
epoch 1 batch id 801 loss 0.717943549156189 train acc 0.5377848002496879

epoch 1 train acc 0.5373795612956129


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, max=204.0), HTML(value='')))


epoch 1 test acc 0.5448069852941176


HBox(children=(FloatProgress(value=0.0, max=813.0), HTML(value='')))

epoch 2 batch id 1 loss 0.6814784407615662 train acc 0.53125
epoch 2 batch id 201 loss 0.6741767525672913 train acc 0.544853855721393
epoch 2 batch id 401 loss 0.6934770345687866 train acc 0.55143391521197
epoch 2 batch id 601 loss 0.6703057289123535 train acc 0.5588342346089851
epoch 2 batch id 801 loss 0.68296217918396 train acc 0.5645677278401997

epoch 2 train acc 0.5645256765067651


HBox(children=(FloatProgress(value=0.0, max=204.0), HTML(value='')))


epoch 2 test acc 0.5871629901960784


HBox(children=(FloatProgress(value=0.0, max=813.0), HTML(value='')))

epoch 3 batch id 1 loss 0.6804776191711426 train acc 0.5625
epoch 3 batch id 201 loss 0.6585829257965088 train acc 0.6001243781094527
epoch 3 batch id 401 loss 0.6297572255134583 train acc 0.6083229426433915
epoch 3 batch id 601 loss 0.6650080680847168 train acc 0.6206842762063228
epoch 3 batch id 801 loss 0.5909916758537292 train acc 0.6337780898876404

epoch 3 train acc 0.6342263735137351


HBox(children=(FloatProgress(value=0.0, max=204.0), HTML(value='')))


epoch 3 test acc 0.6237745098039216


HBox(children=(FloatProgress(value=0.0, max=813.0), HTML(value='')))

epoch 4 batch id 1 loss 0.6280972361564636 train acc 0.6875
epoch 4 batch id 201 loss 0.5934402942657471 train acc 0.7018034825870647
epoch 4 batch id 401 loss 0.5327677726745605 train acc 0.7041380922693267
epoch 4 batch id 601 loss 0.4488980174064636 train acc 0.7157341930116472
epoch 4 batch id 801 loss 0.4769892394542694 train acc 0.7282303370786517

epoch 4 train acc 0.728920664206642


HBox(children=(FloatProgress(value=0.0, max=204.0), HTML(value='')))


epoch 4 test acc 0.6468290441176471


HBox(children=(FloatProgress(value=0.0, max=813.0), HTML(value='')))

epoch 5 batch id 1 loss 0.5188838243484497 train acc 0.765625
epoch 5 batch id 201 loss 0.47200506925582886 train acc 0.7866915422885572
epoch 5 batch id 401 loss 0.5321133732795715 train acc 0.7900950748129676
epoch 5 batch id 601 loss 0.3766273558139801 train acc 0.7979409317803661
epoch 5 batch id 801 loss 0.5463053584098816 train acc 0.8051068976279651

epoch 5 train acc 0.8051481139811398


HBox(children=(FloatProgress(value=0.0, max=204.0), HTML(value='')))


epoch 5 test acc 0.6477481617647058


HBox(children=(FloatProgress(value=0.0, max=813.0), HTML(value='')))

epoch 6 batch id 1 loss 0.4682958126068115 train acc 0.8125
epoch 6 batch id 201 loss 0.3030034899711609 train acc 0.8425062189054726
epoch 6 batch id 401 loss 0.37666550278663635 train acc 0.8428148379052369
epoch 6 batch id 601 loss 0.32966142892837524 train acc 0.8479877287853578
epoch 6 batch id 801 loss 0.45731475949287415 train acc 0.8536009675405742

epoch 6 train acc 0.8537963817138171


HBox(children=(FloatProgress(value=0.0, max=204.0), HTML(value='')))


epoch 6 test acc 0.6539522058823529


HBox(children=(FloatProgress(value=0.0, max=813.0), HTML(value='')))

epoch 7 batch id 1 loss 0.2964637279510498 train acc 0.890625
epoch 7 batch id 201 loss 0.289321631193161 train acc 0.8737562189054726
epoch 7 batch id 401 loss 0.3423621952533722 train acc 0.8739479426433915
epoch 7 batch id 601 loss 0.3185420334339142 train acc 0.8774178452579035
epoch 7 batch id 801 loss 0.3027728199958801 train acc 0.8799547440699126

epoch 7 train acc 0.8800468942189421


HBox(children=(FloatProgress(value=0.0, max=204.0), HTML(value='')))


epoch 7 test acc 0.6672794117647058


HBox(children=(FloatProgress(value=0.0, max=813.0), HTML(value='')))

epoch 8 batch id 1 loss 0.22589358687400818 train acc 0.9375
epoch 8 batch id 201 loss 0.2522638142108917 train acc 0.902285447761194
epoch 8 batch id 401 loss 0.48050081729888916 train acc 0.9006390274314214
epoch 8 batch id 601 loss 0.2216798961162567 train acc 0.9012583194675541
epoch 8 batch id 801 loss 0.13167265057563782 train acc 0.9032069288389513

epoch 8 train acc 0.902976373513735


HBox(children=(FloatProgress(value=0.0, max=204.0), HTML(value='')))


epoch 8 test acc 0.6643688725490197


HBox(children=(FloatProgress(value=0.0, max=813.0), HTML(value='')))

epoch 9 batch id 1 loss 0.2561289072036743 train acc 0.921875
epoch 9 batch id 201 loss 0.08328679949045181 train acc 0.9154228855721394
epoch 9 batch id 401 loss 0.45293453335762024 train acc 0.9168874688279302
epoch 9 batch id 601 loss 0.13704177737236023 train acc 0.9185472129783694
epoch 9 batch id 801 loss 0.14496387541294098 train acc 0.9186953807740325

epoch 9 train acc 0.9188691574415744


HBox(children=(FloatProgress(value=0.0, max=204.0), HTML(value='')))


epoch 9 test acc 0.6710324754901961


HBox(children=(FloatProgress(value=0.0, max=813.0), HTML(value='')))

epoch 10 batch id 1 loss 0.2308560609817505 train acc 0.921875
epoch 10 batch id 201 loss 0.09695502370595932 train acc 0.9196983830845771
epoch 10 batch id 401 loss 0.2843998670578003 train acc 0.922069825436409
epoch 10 batch id 601 loss 0.12790842354297638 train acc 0.9264247088186356
epoch 10 batch id 801 loss 0.13445614278316498 train acc 0.9281952247191011

epoch 10 train acc 0.9281544690446903


HBox(children=(FloatProgress(value=0.0, max=204.0), HTML(value='')))


epoch 10 test acc 0.6770833333333334


HBox(children=(FloatProgress(value=0.0, max=813.0), HTML(value='')))

epoch 11 batch id 1 loss 0.20925888419151306 train acc 0.890625
epoch 11 batch id 201 loss 0.06553097814321518 train acc 0.9344682835820896
epoch 11 batch id 401 loss 0.15908880531787872 train acc 0.937110349127182
epoch 11 batch id 601 loss 0.09162919968366623 train acc 0.9377599833610649
epoch 11 batch id 801 loss 0.23906871676445007 train acc 0.9362710674157303

epoch 11 train acc 0.9360496104961049


HBox(children=(FloatProgress(value=0.0, max=204.0), HTML(value='')))


epoch 11 test acc 0.6724877450980392


HBox(children=(FloatProgress(value=0.0, max=813.0), HTML(value='')))

epoch 12 batch id 1 loss 0.20514090359210968 train acc 0.90625
epoch 12 batch id 201 loss 0.1589888036251068 train acc 0.9421641791044776
epoch 12 batch id 401 loss 0.1688346415758133 train acc 0.9423316708229427
epoch 12 batch id 601 loss 0.15104588866233826 train acc 0.942231697171381
epoch 12 batch id 801 loss 0.09832236170768738 train acc 0.9431374843945068

epoch 12 train acc 0.9430722119721198


HBox(children=(FloatProgress(value=0.0, max=204.0), HTML(value='')))


epoch 12 test acc 0.6787683823529411


HBox(children=(FloatProgress(value=0.0, max=813.0), HTML(value='')))

epoch 13 batch id 1 loss 0.25816062092781067 train acc 0.90625
epoch 13 batch id 201 loss 0.08018875867128372 train acc 0.9498600746268657
epoch 13 batch id 401 loss 0.15577250719070435 train acc 0.9480595386533666
epoch 13 batch id 601 loss 0.05561491847038269 train acc 0.9488872712146422
epoch 13 batch id 801 loss 0.055585119873285294 train acc 0.9506671348314607

epoch 13 train acc 0.9507380073800739


HBox(children=(FloatProgress(value=0.0, max=204.0), HTML(value='')))


epoch 13 test acc 0.6776960784313726


HBox(children=(FloatProgress(value=0.0, max=813.0), HTML(value='')))

epoch 14 batch id 1 loss 0.14896172285079956 train acc 0.9375
epoch 14 batch id 201 loss 0.23625095188617706 train acc 0.9552238805970149
epoch 14 batch id 401 loss 0.12904895842075348 train acc 0.9532808603491272
epoch 14 batch id 601 loss 0.05612000450491905 train acc 0.9545289101497504
epoch 14 batch id 801 loss 0.07964742183685303 train acc 0.9551342072409488

epoch 14 train acc 0.9546945469454694


HBox(children=(FloatProgress(value=0.0, max=204.0), HTML(value='')))


epoch 14 test acc 0.6868872549019608


HBox(children=(FloatProgress(value=0.0, max=813.0), HTML(value='')))

epoch 15 batch id 1 loss 0.12512989342212677 train acc 0.96875
epoch 15 batch id 201 loss 0.04740113392472267 train acc 0.9585665422885572
epoch 15 batch id 401 loss 0.16497136652469635 train acc 0.95768391521197
epoch 15 batch id 601 loss 0.0693722516298294 train acc 0.9585326539101497
epoch 15 batch id 801 loss 0.1298936903476715 train acc 0.9590550873907615

epoch 15 train acc 0.9588714637146372


HBox(children=(FloatProgress(value=0.0, max=204.0), HTML(value='')))


epoch 15 test acc 0.6870404411764706


HBox(children=(FloatProgress(value=0.0, max=813.0), HTML(value='')))

epoch 16 batch id 1 loss 0.09715751558542252 train acc 0.953125
epoch 16 batch id 201 loss 0.10357582569122314 train acc 0.9621424129353234
epoch 16 batch id 401 loss 0.11330382525920868 train acc 0.9608790523690773
epoch 16 batch id 601 loss 0.089016392827034 train acc 0.9625883943427621
epoch 16 batch id 801 loss 0.04175589978694916 train acc 0.9637757490636704

epoch 16 train acc 0.963599323493235


HBox(children=(FloatProgress(value=0.0, max=204.0), HTML(value='')))


epoch 16 test acc 0.6803002450980392


HBox(children=(FloatProgress(value=0.0, max=813.0), HTML(value='')))

epoch 17 batch id 1 loss 0.19140055775642395 train acc 0.9375
epoch 17 batch id 201 loss 0.07048177719116211 train acc 0.962764303482587
epoch 17 batch id 401 loss 0.12977814674377441 train acc 0.9624766209476309
epoch 17 batch id 601 loss 0.05514690279960632 train acc 0.9639403078202995
epoch 17 batch id 801 loss 0.07592938095331192 train acc 0.9653167915106118

epoch 17 train acc 0.9651932144321443


HBox(children=(FloatProgress(value=0.0, max=204.0), HTML(value='')))


epoch 17 test acc 0.6850490196078431


HBox(children=(FloatProgress(value=0.0, max=813.0), HTML(value='')))

epoch 18 batch id 1 loss 0.05971921980381012 train acc 0.984375
epoch 18 batch id 201 loss 0.13260996341705322 train acc 0.9661847014925373
epoch 18 batch id 401 loss 0.13884475827217102 train acc 0.9666848503740648
epoch 18 batch id 601 loss 0.12568482756614685 train acc 0.966410149750416
epoch 18 batch id 801 loss 0.0428154394030571 train acc 0.9671894506866417

epoch 18 train acc 0.9671356088560885


HBox(children=(FloatProgress(value=0.0, max=204.0), HTML(value='')))


epoch 18 test acc 0.6886488970588235


HBox(children=(FloatProgress(value=0.0, max=813.0), HTML(value='')))

epoch 19 batch id 1 loss 0.08581168949604034 train acc 0.96875
epoch 19 batch id 201 loss 0.02817576751112938 train acc 0.9704601990049752
epoch 19 batch id 401 loss 0.04964839667081833 train acc 0.969334476309227
epoch 19 batch id 601 loss 0.20482110977172852 train acc 0.9698159317803661
epoch 19 batch id 801 loss 0.11519742757081985 train acc 0.9700569600499376

epoch 19 train acc 0.9699774497744977


HBox(children=(FloatProgress(value=0.0, max=204.0), HTML(value='')))


epoch 19 test acc 0.6824448529411765


HBox(children=(FloatProgress(value=0.0, max=813.0), HTML(value='')))

epoch 20 batch id 1 loss 0.06300149112939835 train acc 0.984375
epoch 20 batch id 201 loss 0.07845138013362885 train acc 0.9717817164179104
epoch 20 batch id 401 loss 0.09650184214115143 train acc 0.9707761845386533
epoch 20 batch id 601 loss 0.06060398370027542 train acc 0.9715058236272879
epoch 20 batch id 801 loss 0.013496166095137596 train acc 0.9719881398252185

epoch 20 train acc 0.9719570008200082


HBox(children=(FloatProgress(value=0.0, max=204.0), HTML(value='')))


epoch 20 test acc 0.6893382352941176


HBox(children=(FloatProgress(value=0.0, max=813.0), HTML(value='')))

epoch 21 batch id 1 loss 0.056874219328165054 train acc 0.984375
epoch 21 batch id 201 loss 0.13110560178756714 train acc 0.974735696517413
epoch 21 batch id 401 loss 0.07515087723731995 train acc 0.9726854738154613
epoch 21 batch id 601 loss 0.06371209770441055 train acc 0.9732737104825291
epoch 21 batch id 801 loss 0.03488355502486229 train acc 0.9741534019975031

epoch 21 train acc 0.9738776137761378


HBox(children=(FloatProgress(value=0.0, max=204.0), HTML(value='')))


epoch 21 test acc 0.6901807598039216


HBox(children=(FloatProgress(value=0.0, max=813.0), HTML(value='')))

epoch 22 batch id 1 loss 0.07412464916706085 train acc 0.96875
epoch 22 batch id 201 loss 0.1301984041929245 train acc 0.9735696517412935
epoch 22 batch id 401 loss 0.13240262866020203 train acc 0.9739323566084788
epoch 22 batch id 601 loss 0.06818997859954834 train acc 0.9751975873544093
epoch 22 batch id 801 loss 0.012170707806944847 train acc 0.9754993757802747

epoch 22 train acc 0.97539975399754


HBox(children=(FloatProgress(value=0.0, max=204.0), HTML(value='')))


epoch 22 test acc 0.6868872549019608


HBox(children=(FloatProgress(value=0.0, max=813.0), HTML(value='')))

epoch 23 batch id 1 loss 0.055073875933885574 train acc 0.984375
epoch 23 batch id 201 loss 0.05522094666957855 train acc 0.9780783582089553
epoch 23 batch id 401 loss 0.1758449375629425 train acc 0.975802680798005
epoch 23 batch id 601 loss 0.005678997840732336 train acc 0.9767834858569051
epoch 23 batch id 801 loss 0.11004841327667236 train acc 0.9775280898876404

epoch 23 train acc 0.9774369618696187


HBox(children=(FloatProgress(value=0.0, max=204.0), HTML(value='')))


epoch 23 test acc 0.6883425245098039


HBox(children=(FloatProgress(value=0.0, max=813.0), HTML(value='')))

epoch 24 batch id 1 loss 0.08312152326107025 train acc 0.96875
epoch 24 batch id 201 loss 0.08087216317653656 train acc 0.9757462686567164
epoch 24 batch id 401 loss 0.04915575310587883 train acc 0.9751792394014963
epoch 24 batch id 601 loss 0.05406063050031662 train acc 0.9768094841930116
epoch 24 batch id 801 loss 0.07738521695137024 train acc 0.9780547752808989

epoch 24 train acc 0.9780506867568676


HBox(children=(FloatProgress(value=0.0, max=204.0), HTML(value='')))


epoch 24 test acc 0.6920955882352942


HBox(children=(FloatProgress(value=0.0, max=813.0), HTML(value='')))

epoch 25 batch id 1 loss 0.06915892660617828 train acc 0.984375
epoch 25 batch id 201 loss 0.026809310540556908 train acc 0.9803327114427861
epoch 25 batch id 401 loss 0.08428865671157837 train acc 0.9796602244389028
epoch 25 batch id 601 loss 0.08819082379341125 train acc 0.9798512895174709
epoch 25 batch id 801 loss 0.07845597714185715 train acc 0.9798884207240949

epoch 25 train acc 0.9798739237392374


HBox(children=(FloatProgress(value=0.0, max=204.0), HTML(value='')))


epoch 25 test acc 0.6888786764705882


HBox(children=(FloatProgress(value=0.0, max=813.0), HTML(value='')))

epoch 26 batch id 1 loss 0.23546326160430908 train acc 0.9375
epoch 26 batch id 201 loss 0.008207015693187714 train acc 0.9807213930348259
epoch 26 batch id 401 loss 0.07974331080913544 train acc 0.9785692019950125
epoch 26 batch id 601 loss 0.025305375456809998 train acc 0.9797212978369384
epoch 26 batch id 801 loss 0.018563343212008476 train acc 0.9808247503121099

epoch 26 train acc 0.9807208384583846


HBox(children=(FloatProgress(value=0.0, max=204.0), HTML(value='')))


epoch 26 test acc 0.6883425245098039


HBox(children=(FloatProgress(value=0.0, max=813.0), HTML(value='')))

epoch 27 batch id 1 loss 0.047116804867982864 train acc 0.984375
epoch 27 batch id 201 loss 0.0214090496301651 train acc 0.9803327114427861
epoch 27 batch id 401 loss 0.11256546527147293 train acc 0.9795043640897756
epoch 27 batch id 601 loss 0.008239291608333588 train acc 0.9808392262895175
epoch 27 batch id 801 loss 0.018261387944221497 train acc 0.9812148876404494

epoch 27 train acc 0.9812410311603116


HBox(children=(FloatProgress(value=0.0, max=204.0), HTML(value='')))


epoch 27 test acc 0.6953890931372549


HBox(children=(FloatProgress(value=0.0, max=813.0), HTML(value='')))

epoch 28 batch id 1 loss 0.1691398322582245 train acc 0.96875
epoch 28 batch id 201 loss 0.038298483937978745 train acc 0.9820429104477612
epoch 28 batch id 401 loss 0.12452568113803864 train acc 0.9818033042394015
epoch 28 batch id 601 loss 0.004071078263223171 train acc 0.9830230865224625
epoch 28 batch id 801 loss 0.09122873842716217 train acc 0.9831655742821473

epoch 28 train acc 0.9830104551045511


HBox(children=(FloatProgress(value=0.0, max=204.0), HTML(value='')))


epoch 28 test acc 0.6933976715686274


HBox(children=(FloatProgress(value=0.0, max=813.0), HTML(value='')))

epoch 29 batch id 1 loss 0.062094125896692276 train acc 0.984375
epoch 29 batch id 201 loss 0.026307323947548866 train acc 0.9828980099502488
epoch 29 batch id 401 loss 0.1942361444234848 train acc 0.9831281172069826
epoch 29 batch id 601 loss 0.007811684161424637 train acc 0.9833350665557404
