In [1]:
!git clone https://github.com/toy-f-rebellion/toy_ai.git
%cd /content/toy_ai
!pip install -r requirements.txt

Cloning into 'toy_ai'...
remote: Enumerating objects: 97, done.[K
remote: Counting objects: 100% (97/97), done.[K
remote: Compressing objects: 100% (78/78), done.[K
remote: Total 97 (delta 28), reused 14 (delta 3), pack-reused 0[K
Receiving objects: 100% (97/97), 2.51 MiB | 4.90 MiB/s, done.
Resolving deltas: 100% (28/28), done.
/content/toy_ai
Cloning into 'Modified-KoBERT'...
remote: Enumerating objects: 449, done.[K
remote: Counting objects: 100% (174/174), done.[K
remote: Compressing objects: 100% (66/66), done.[K
remote: Total 449 (delta 136), reused 118 (delta 108), pack-reused 275[K
Receiving objects: 100% (449/449), 226.32 KiB | 3.02 MiB/s, done.
Resolving deltas: 100% (227/227), done.
Collecting kobert_tokenizer (from -r requirements.txt (line 11))
  Cloning https://github.com/SKTBrain/KoBERT.git to /tmp/pip-install-fx61d4s1/kobert-tokenizer_8037825d59ac43a0aa57db92a45d9a75
  Running command git clone --filter=blob:none --quiet https://github.com/SKTBrain/KoBERT.git /t

# 1.환경설정

In [7]:
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


## (1) 라이브러리 불러오기

In [8]:
import warnings
warnings.filterwarnings('ignore')

# Setting Library
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm import tqdm, tqdm_notebook
import pandas as pd
from sklearn.model_selection import train_test_split
import pickle

# koBERT
from kobert.utils import get_tokenizer
from kobert.pytorch_kobert import get_pytorch_kobert_model

# Transformers
from transformers import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup

device = torch.device('cuda')

## (2) 데이터 불러오기

In [5]:
with open('/content/toy_ai/train_data.pkl', 'rb') as f:
    train_data = pickle.load(f)
with open('/content/toy_ai/test_data.pkl', 'rb') as f:
    test_data = pickle.load(f)

## (3) KoBERT 모델 불러오기

In [9]:
# KoBERT로부터 model, vocabulary 불러오기
bertmodel, vocab = get_pytorch_kobert_model()

/content/toy_ai/.cache/kobert_v1.zip[██████████████████████████████████████████████████]
/content/toy_ai/.cache/kobert_news_wiki_ko_cased-1087f8699e.spiece[██████████████████████████████████████████████████]


# 2.Modeling

## (1) 입력 데이터셋 토큰화

- 함수 정의

In [10]:
# 각 데이터가 BERT 모델의 입력으로 들어갈 수 있도록 함수 정의
class BERTDataset(Dataset):
    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,
                 pad, pair):

        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)

        self.sentences = [transform([i[sent_idx]]) for i in dataset]
        self.labels = [np.int32(i[label_idx]) for i in dataset]
    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i], ))

    def __len__(self):
        return (len(self.labels))

In [11]:
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

using cached model. /content/toy_ai/.cache/kobert_news_wiki_ko_cased-1087f8699e.spiece


- Setting parameters

In [12]:
max_len = 64
batch_size = 64
warmup_ratio = 0.1
num_epochs = 1000
max_grad_norm = 1
log_interval = 200
learning_rate =  5e-5

- 토큰화

In [13]:
train_data = BERTDataset(train_data, 0, 1, tok, max_len, True, False)
test_data = BERTDataset(test_data,0, 1, tok, max_len, True, False)

- torch 형식으로 반환

In [14]:
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, num_workers=5)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, num_workers=5)

## (2) EarlyStopping

In [15]:
# Early Stopping
class EarlyStopping:
    def __init__(self, patience=5, verbose=False):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float("Inf")

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        if self.verbose:
            print(f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...")
        torch.save(model.state_dict(), "checkpoint.pt")
        self.val_loss_min = val_loss

## (3) Modeling

- KoBERT 함수 정의

In [16]:
class BERTClassifier(nn.Module):
    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes = 7,   # 감정 클래스 수로 조정
                 dr_rate = None,
                 params = None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate

        self.classifier = nn.Linear(hidden_size , num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p = dr_rate)

    def gen_attention_mask(self, token_ids, valid_length):
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)

        _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask.float().to(token_ids.device),return_dict = False)
        if self.dr_rate:
            out = self.dropout(pooler)
        return self.classifier(out)

In [17]:
#BERT 모델 불러오기
model = BERTClassifier(bertmodel,  dr_rate=0.5).to(device)

#optimizer와 schedule 설정
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
loss_fn = nn.CrossEntropyLoss() # 다중분류를 위한 대표적인 loss func

t_total = len(train_dataloader) * num_epochs
warmup_step = int(t_total * warmup_ratio)

scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)

def calc_accuracy(X,Y):
    max_vals, max_indices = torch.max(X, 1)
    train_acc = (max_indices == Y).sum().data.cpu().numpy()/max_indices.size()[0]
    return train_acc

train_dataloader

<torch.utils.data.dataloader.DataLoader at 0x7e0675e25420>

- Training

In [18]:
train_history=[]
test_history=[]
loss_history=[]

early_stopping = EarlyStopping(patience=5, verbose=True) # EarlyStopping 적용

for e in range(num_epochs):
    train_acc = 0.0
    test_acc = 0.0
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(train_dataloader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)

        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()
        train_acc += calc_accuracy(out, label)
        if batch_id % log_interval == 0:
            print("epoch {} batch id {} loss {} train acc {}".format(e+1, batch_id+1, loss.data.cpu().numpy(), train_acc / (batch_id+1)))
            train_history.append(train_acc / (batch_id+1))
            loss_history.append(loss.data.cpu().numpy())
    print("epoch {} train acc {}".format(e+1, train_acc / (batch_id+1)))
    train_history.append(train_acc / (batch_id+1))

    model.eval()
    y_pred, y_test = [], []
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(test_dataloader)):
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        test_acc += calc_accuracy(out, label)
    print("epoch {} test acc {}".format(e+1, test_acc / (batch_id+1)))
    test_history.append(test_acc / (batch_id+1))

    # Early stopping 및 ModelCheckpoint 적용
    epoch_test_loss = loss.item()
    early_stopping(epoch_test_loss, model)

    if early_stopping.early_stop:
        print("Early stopping")
        break

# 최적의 모델 불러오기
model.load_state_dict(torch.load("checkpoint.pt"))

  0%|          | 0/243 [00:00<?, ?it/s]

epoch 1 batch id 1 loss 2.0302786827087402 train acc 0.15625
epoch 1 batch id 201 loss 1.9005939960479736 train acc 0.14412313432835822
epoch 1 train acc 0.1509539842873176


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 1 test acc 0.18812939110070256
Validation loss decreased (inf --> 2.063841). Saving model...


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 2 batch id 1 loss 1.9411746263504028 train acc 0.140625
epoch 2 batch id 201 loss 1.8398562669754028 train acc 0.20856654228855723
epoch 2 train acc 0.21858632622521512


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 2 test acc 0.3555986533957845
Validation loss decreased (2.063841 --> 1.755566). Saving model...


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 3 batch id 1 loss 1.7424836158752441 train acc 0.296875
epoch 3 batch id 201 loss 1.6935451030731201 train acc 0.33488805970149255
epoch 3 train acc 0.34677796483352036


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 3 test acc 0.49732874707259955
Validation loss decreased (1.755566 --> 1.598049). Saving model...


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 4 batch id 1 loss 1.5568076372146606 train acc 0.5
epoch 4 batch id 201 loss 1.4199926853179932 train acc 0.5165578358208955
epoch 4 train acc 0.5341902824541713


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 4 test acc 0.7059133489461358
Validation loss decreased (1.598049 --> 1.354909). Saving model...


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 5 batch id 1 loss 1.2347124814987183 train acc 0.765625
epoch 5 batch id 201 loss 1.042399287223816 train acc 0.6926305970149254
epoch 5 train acc 0.7034815750093527


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 5 test acc 0.7977385831381734
Validation loss decreased (1.354909 --> 0.984819). Saving model...


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 6 batch id 1 loss 0.8094711303710938 train acc 0.90625
epoch 6 batch id 201 loss 0.7163453698158264 train acc 0.8055814676616916
epoch 6 train acc 0.8118043864571642


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 6 test acc 0.8532786885245901
Validation loss decreased (0.984819 --> 0.615283). Saving model...


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 7 batch id 1 loss 0.5325483679771423 train acc 0.90625
epoch 7 batch id 201 loss 0.48149412870407104 train acc 0.8599191542288557
epoch 7 train acc 0.8648405349794238


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 7 test acc 0.8796179742388759
Validation loss decreased (0.615283 --> 0.280110). Saving model...


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 8 batch id 1 loss 0.31621530652046204 train acc 0.96875
epoch 8 batch id 201 loss 0.37122464179992676 train acc 0.8910914179104478
epoch 8 train acc 0.8945473251028807


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 8 test acc 0.8940866510538641
Validation loss decreased (0.280110 --> 0.182634). Saving model...


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 9 batch id 1 loss 0.25102970004081726 train acc 0.96875
epoch 9 batch id 201 loss 0.29193228483200073 train acc 0.9095926616915423
epoch 9 train acc 0.9112011316872428


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 9 test acc 0.8997218969555034
Validation loss decreased (0.182634 --> 0.084821). Saving model...


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 10 batch id 1 loss 0.18063707649707794 train acc 0.96875
epoch 10 batch id 201 loss 0.24039623141288757 train acc 0.9228855721393034
epoch 10 train acc 0.9241898148148148


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 10 test acc 0.9063817330210773
Validation loss decreased (0.084821 --> 0.063178). Saving model...


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 11 batch id 1 loss 0.14516757428646088 train acc 0.953125
epoch 11 batch id 201 loss 0.22549870610237122 train acc 0.9333799751243781
epoch 11 train acc 0.9355066872427984


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 11 test acc 0.9081747658079625
Validation loss decreased (0.063178 --> 0.035831). Saving model...


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 12 batch id 1 loss 0.07803201675415039 train acc 0.984375
epoch 12 batch id 201 loss 0.20475730299949646 train acc 0.9431747512437811
epoch 12 train acc 0.943994341563786


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 12 test acc 0.9089432084309133
Validation loss decreased (0.035831 --> 0.032701). Saving model...


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 13 batch id 1 loss 0.14848850667476654 train acc 0.96875
epoch 13 batch id 201 loss 0.17096775770187378 train acc 0.9478389303482587
epoch 13 train acc 0.9494598765432098


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 13 test acc 0.9097116510538641
Validation loss decreased (0.032701 --> 0.017112). Saving model...


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 14 batch id 1 loss 0.0864713117480278 train acc 0.984375
epoch 14 batch id 201 loss 0.17376822233200073 train acc 0.9575559701492538
epoch 14 train acc 0.9586548353909465


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 14 test acc 0.912317037470726
EarlyStopping counter: 1 out of 5


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 15 batch id 1 loss 0.03237147629261017 train acc 1.0
epoch 15 batch id 201 loss 0.19523926079273224 train acc 0.9607431592039801
epoch 15 train acc 0.9621270576131687


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 15 test acc 0.912317037470726
Validation loss decreased (0.017112 --> 0.012283). Saving model...


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 16 batch id 1 loss 0.03229549527168274 train acc 1.0
epoch 16 batch id 201 loss 0.13124492764472961 train acc 0.966806592039801
epoch 16 train acc 0.9674639917695473


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 16 test acc 0.9153908079625293
Validation loss decreased (0.012283 --> 0.011894). Saving model...


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 17 batch id 1 loss 0.07131589949131012 train acc 0.984375
epoch 17 batch id 201 loss 0.13169802725315094 train acc 0.9703047263681592
epoch 17 train acc 0.9711291152263375


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 17 test acc 0.9145784543325527
Validation loss decreased (0.011894 --> 0.005770). Saving model...


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 18 batch id 1 loss 0.013992941938340664 train acc 1.0
epoch 18 batch id 201 loss 0.07916980236768723 train acc 0.9733364427860697
epoch 18 train acc 0.974022633744856


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 18 test acc 0.9109923887587822
Validation loss decreased (0.005770 --> 0.004734). Saving model...


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 19 batch id 1 loss 0.09010907262563705 train acc 0.96875
epoch 19 batch id 201 loss 0.04058951139450073 train acc 0.9755130597014925
epoch 19 train acc 0.9760802469135802


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 19 test acc 0.9166276346604215
Validation loss decreased (0.004734 --> 0.004029). Saving model...


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 20 batch id 1 loss 0.04644201695919037 train acc 0.984375
epoch 20 batch id 201 loss 0.07728546112775803 train acc 0.9783115671641791
epoch 20 train acc 0.9789094650205762


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 20 test acc 0.9089871194379391
Validation loss decreased (0.004029 --> 0.003221). Saving model...


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 21 batch id 1 loss 0.052427612245082855 train acc 0.984375
epoch 21 batch id 201 loss 0.03474579006433487 train acc 0.9787779850746269
epoch 21 train acc 0.979809670781893


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 21 test acc 0.9118047423887587
Validation loss decreased (0.003221 --> 0.002668). Saving model...


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 22 batch id 1 loss 0.0136365732178092 train acc 1.0
epoch 22 batch id 201 loss 0.03769530728459358 train acc 0.9821983830845771
epoch 22 train acc 0.9827031893004116


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 22 test acc 0.9176083138173302
EarlyStopping counter: 1 out of 5


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 23 batch id 1 loss 0.02658260613679886 train acc 0.984375
epoch 23 batch id 201 loss 0.017800800502300262 train acc 0.9827425373134329
epoch 23 train acc 0.98315329218107


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 23 test acc 0.9140661592505854
Validation loss decreased (0.002668 --> 0.001581). Saving model...


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 24 batch id 1 loss 0.0036328190471976995 train acc 1.0
epoch 24 batch id 201 loss 0.011637282557785511 train acc 0.9817319651741293
epoch 24 train acc 0.9821244855967078


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 24 test acc 0.914322306791569
Validation loss decreased (0.001581 --> 0.001396). Saving model...


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 25 batch id 1 loss 0.006546442396938801 train acc 1.0
epoch 25 batch id 201 loss 0.08820851147174835 train acc 0.9844527363184079
epoch 25 train acc 0.9849537037037037


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 25 test acc 0.9181206088992975
EarlyStopping counter: 1 out of 5


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 26 batch id 1 loss 0.004659601487219334 train acc 1.0
epoch 26 batch id 201 loss 0.12608622014522552 train acc 0.9858519900497512
epoch 26 train acc 0.9864326131687243


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 26 test acc 0.9176522248243559
Validation loss decreased (0.001396 --> 0.000843). Saving model...


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 27 batch id 1 loss 0.006362922955304384 train acc 1.0
epoch 27 batch id 201 loss 0.0551166795194149 train acc 0.9868625621890548
epoch 27 train acc 0.9868184156378601


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 27 test acc 0.9160714285714285
Validation loss decreased (0.000843 --> 0.000838). Saving model...


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 28 batch id 1 loss 0.0018340353854000568 train acc 1.0
epoch 28 batch id 201 loss 0.003591499524191022 train acc 0.9862406716417911
epoch 28 train acc 0.9867541152263375


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 28 test acc 0.9109923887587822
Validation loss decreased (0.000838 --> 0.000691). Saving model...


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 29 batch id 1 loss 0.003117840038612485 train acc 1.0
epoch 29 batch id 201 loss 0.02799053117632866 train acc 0.9874067164179104
epoch 29 train acc 0.9864969135802469


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 29 test acc 0.9150907494145198
Validation loss decreased (0.000691 --> 0.000546). Saving model...


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 30 batch id 1 loss 0.003055269829928875 train acc 1.0
epoch 30 batch id 201 loss 0.06792864203453064 train acc 0.9873289800995025
epoch 30 train acc 0.9875900205761317


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 30 test acc 0.9209382318501171
Validation loss decreased (0.000546 --> 0.000484). Saving model...


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 31 batch id 1 loss 0.0029917601495981216 train acc 1.0
epoch 31 batch id 201 loss 0.005019358359277248 train acc 0.9870180348258707
epoch 31 train acc 0.9866898148148148


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 31 test acc 0.9091554449648712
EarlyStopping counter: 1 out of 5


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 32 batch id 1 loss 0.0036651163827627897 train acc 1.0
epoch 32 batch id 201 loss 0.01350217591971159 train acc 0.9867070895522388
epoch 32 train acc 0.9870756172839507


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 32 test acc 0.9140661592505854
EarlyStopping counter: 2 out of 5


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 33 batch id 1 loss 0.0021898930426687002 train acc 1.0
epoch 33 batch id 201 loss 0.024953579530119896 train acc 0.9885727611940298
epoch 33 train acc 0.9888117283950617


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 33 test acc 0.9161153395784543
EarlyStopping counter: 3 out of 5


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 34 batch id 1 loss 0.004097661003470421 train acc 1.0
epoch 34 batch id 201 loss 0.016104871407151222 train acc 0.9891169154228856
epoch 34 train acc 0.9895833333333334


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 34 test acc 0.9145784543325527
Validation loss decreased (0.000484 --> 0.000360). Saving model...


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 35 batch id 1 loss 0.004090734291821718 train acc 1.0
epoch 35 batch id 201 loss 0.0304888803511858 train acc 0.9891169154228856
epoch 35 train acc 0.9891332304526749


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 35 test acc 0.9173960772833724
EarlyStopping counter: 1 out of 5


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 36 batch id 1 loss 0.02004646323621273 train acc 0.984375
epoch 36 batch id 201 loss 0.01458366122096777 train acc 0.9884172885572139
epoch 36 train acc 0.9888117283950617


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 36 test acc 0.9153468969555034
EarlyStopping counter: 2 out of 5


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 37 batch id 1 loss 0.0011399193899706006 train acc 1.0
epoch 37 batch id 201 loss 0.06629371643066406 train acc 0.9893501243781094
epoch 37 train acc 0.9893261316872428


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 37 test acc 0.9138100117096019
EarlyStopping counter: 3 out of 5


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 38 batch id 1 loss 0.0027085545007139444 train acc 1.0
epoch 38 batch id 201 loss 0.0047619459219276905 train acc 0.9890391791044776
epoch 38 train acc 0.9895190329218106


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 38 test acc 0.9158591920374707
Validation loss decreased (0.000360 --> 0.000297). Saving model...


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 39 batch id 1 loss 0.005632504355162382 train acc 1.0
epoch 39 batch id 201 loss 0.013778961263597012 train acc 0.988106343283582
epoch 39 train acc 0.9886188271604939


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 39 test acc 0.9146223653395784
Validation loss decreased (0.000297 --> 0.000276). Saving model...


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 40 batch id 1 loss 0.004914435558021069 train acc 1.0
epoch 40 batch id 201 loss 0.011142357252538204 train acc 0.9890391791044776
epoch 40 train acc 0.9892618312757202


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 40 test acc 0.9181206088992975
EarlyStopping counter: 1 out of 5


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 41 batch id 1 loss 0.006962301209568977 train acc 1.0
epoch 41 batch id 201 loss 0.062268372625112534 train acc 0.9897388059701493
epoch 41 train acc 0.9902263374485597


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 41 test acc 0.9166276346604215
EarlyStopping counter: 2 out of 5


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 42 batch id 1 loss 0.003040491370484233 train acc 1.0
epoch 42 batch id 201 loss 0.01357151661068201 train acc 0.9890391791044776
epoch 42 train acc 0.9891975308641975


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 42 test acc 0.9153468969555034
EarlyStopping counter: 3 out of 5


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 43 batch id 1 loss 0.003461226122453809 train acc 1.0
epoch 43 batch id 201 loss 0.008647197857499123 train acc 0.9911380597014925
epoch 43 train acc 0.9905478395061729


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 43 test acc 0.9146223653395784
EarlyStopping counter: 4 out of 5


  0%|          | 0/243 [00:00<?, ?it/s]

epoch 44 batch id 1 loss 0.0037550036795437336 train acc 1.0
epoch 44 batch id 201 loss 0.08241549879312515 train acc 0.990282960199005
epoch 44 train acc 0.9904835390946503


  0%|          | 0/61 [00:00<?, ?it/s]

epoch 44 test acc 0.9158591920374707
EarlyStopping counter: 5 out of 5
Early stopping


<All keys matched successfully>

# 3. Model Save

In [19]:
torch.save(model, '/content/toy_ai/model.pth')