<a href="https://colab.research.google.com/github/ttogle918/news_by_kobert/blob/master/model_epoch20_len_128_kobert_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive

drive.mount('/content/drive')
data_path = '/content/drive/My Drive/Colab Notebooks/news_class9x1400'

Mounted at /content/drive


In [None]:
!pip install mxnet
!pip install gluonnlp pandas tqdm
!pip install sentencepiece
!pip install transformers
!pip install torch

In [None]:
!pip install git+https://git@github.com/SKTBrain/KoBERT.git@master

In [4]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm import tqdm, tqdm_notebook
from kobert.utils import get_tokenizer
from kobert.pytorch_kobert import get_pytorch_kobert_model
from transformers import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup

In [5]:
# gpu 연산이 가능하면 'cuda:0', 아니면 'cpu' 출력
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [6]:
#bert 모델, vocab 불러오기
bertmodel, vocab = get_pytorch_kobert_model()

/content/.cache/kobert_v1.zip[██████████████████████████████████████████████████]
/content/.cache/kobert_news_wiki_ko_cased-1087f8699e.spiece[██████████████████████████████████████████████████]


In [7]:
import os
max_seq_len = 128
dataset_train = []
labels = []
for (path, dir, files) in os.walk(data_path):
    for filename in files:
        ext = os.path.splitext(filename)[-1]
        if ext == '.txt':
            with open("%s/%s" % (path, filename), encoding="utf-8") as f:
              # labels.append( path[path.rindex('/')+1:])
              # dataset_train.append(f.read()[:max_seq_len])
              
              label = path[path.rindex('/')+1:]
              text = f.read()
              text = text.split('\n')
              saved_text, s, i = '', 0, 0
              for t in text :
                  if i == 2 :
                      break
                  if len(t) > max_seq_len :
                      labels.append( label )
                      dataset_train.append(saved_text)
                      labels.append( label )
                      dataset_train.append(t)
                      saved_text, s = '', 0
                      i += 1
                  elif s + len(t) > max_seq_len :
                      labels.append( label )
                      dataset_train.append(saved_text)
                      saved_text, s = t, len(t)
                      i += 1
                  else :
                      saved_text += t
                      s += len(t)
              labels.append( label )
              dataset_train.append(saved_text)


len(dataset_train)

41118

## 전처리

In [8]:
import pandas as pd

# labels_classes
df = pd.DataFrame({'content' : dataset_train, 'label' : labels})

# dataset의 balance 여부 확인
df.groupby(by=['label']).count()

Unnamed: 0_level_0,content
label,Unnamed: 1_level_1
ITscience,4504
culture,4543
economy,4618
entertainment,4544
health,4473
life,4517
politic,4771
social,4622
sport,4526


In [9]:
from sklearn.preprocessing import LabelEncoder

label_encoder = LabelEncoder()
label_encoder.fit(df['label'])
num_labels = len(label_encoder.classes_)

df['encoded_label'] = np.asarray(label_encoder.transform(df['label']), dtype=np.int32)
df.head(3)

Unnamed: 0,content,label,encoded_label
0,"원칙 있는 사회를 갈망한 노무현을 만나다노무현 전 대통령의 서거 후, 얼마 지나지 ...",social,7
1,"전 대통령이 스스로 목숨을 끊은 것은 무책임이라는 교수의 말에 내가 그만, 발끈해 ...",social,7
2,무책임이라는 말은 신념을 위해 자신의 목숨을 던진 정치가에 대한 평가로 적절치 않다...,social,7


In [10]:
from sklearn.model_selection import train_test_split

# Split Train and Validation data
train_texts, test_texts, train_labels, test_labels = train_test_split(df["content"].to_list(), df["encoded_label"].to_list(), test_size=0.2, random_state=0, shuffle=True, stratify=df["encoded_label"].to_list())

print(len(train_texts), len(train_labels), len(test_texts))

32894 32894 8224


In [11]:
dataset_train, dataset_test = [], []

for text, label in zip(train_texts, train_labels) :
  dataset_train.append([text, label])
for text, label in zip(test_texts, test_labels) :
  dataset_test.append([text, label])
len(dataset_train), len(dataset_test)

(32894, 8224)

In [12]:
class BERTDataset(Dataset):
    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,
                 pad, pair):
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)

        self.sentences = [transform([i[sent_idx]]) for i in dataset]
        self.labels = [np.int32(i[label_idx]) for i in dataset]

    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i], ))

    def __len__(self):
        return (len(self.labels))

In [13]:
## Setting parameters
max_len = 128
batch_size = 32
warmup_ratio = 0.1
num_epochs = 10
max_grad_norm = 1
log_interval = 200

In [14]:
#토큰화
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)
#BERTDataset 클래스 이용, TensorDataset으로 만들어주기
data_train = BERTDataset(dataset_train, 0, 1, tok, max_len, True, False)
data_test = BERTDataset(dataset_test, 0, 1, tok, max_len, True, False)

using cached model. /content/.cache/kobert_news_wiki_ko_cased-1087f8699e.spiece


In [15]:
#배치 및 데이터로더 설정
train_dataloader = torch.utils.data.DataLoader(data_train, batch_size=batch_size, num_workers=4)
test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=batch_size, num_workers=4)

In [16]:
del df
# del data_train
# del data_test
# del tok
# del tokenizer
# del dataset_train
# del dataset_test
del train_texts
del train_labels

In [17]:
class BERTClassifier(nn.Module):
    def __init__(self, bert, hidden_size = 768,
                 num_classes=num_labels, ## 분류할 클래스 수
                 dr_rate=None, params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate
                 
        self.classifier = nn.Linear(hidden_size , num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)
    
    def gen_attention_mask(self, token_ids, valid_length):
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        
        _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask.float().to(token_ids.device))
        if self.dr_rate:
            out = self.dropout(pooler)
        return self.classifier(out)

In [18]:
model = BERTClassifier(bertmodel, dr_rate=0.5).to(device) #gpu

In [19]:
def calc_accuracy(X,Y):
    max_vals, max_indices = torch.max(X, 1)
    train_acc = (max_indices == Y).sum().data.cpu().numpy()/max_indices.size()[0]
    return train_acc

In [25]:
def train_for_epochs(num_epochs, train_dataloader, test_dataloader) :
  # Prepare optimizer and schedule (linear warmup and decay)
  no_decay = ['bias', 'LayerNorm.weight']
  optimizer_grouped_parameters = [
      {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
      {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
  ]

  optimizer = AdamW(optimizer_grouped_parameters, lr=5e-5)  # lr = learning rate
  loss_fn = nn.CrossEntropyLoss()

  t_total = len(train_dataloader) * num_epochs
  warmup_step = int(t_total * warmup_ratio)

  scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)
  
  # train
  for e in range(num_epochs):
      train_acc = 0.0
      test_acc = 0.0
      # training
      model.train()
      for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(train_dataloader)):
          optimizer.zero_grad()
          token_ids = token_ids.long().to(device)
          segment_ids = segment_ids.long().to(device)
          valid_length= valid_length
          label = label.long().to(device)
          out = model(token_ids, valid_length, segment_ids)
          loss = loss_fn(out, label)
          loss.backward()
          torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
          optimizer.step()
          scheduler.step()  # Update learning rate schedule
          train_acc += calc_accuracy(out, label)
          if batch_id % log_interval == 0:
              print("epoch {} batch id {} loss {} train acc {}".format(e+1, batch_id+1, loss.data.cpu().numpy(), train_acc / (batch_id+1)))
      print("epoch {} train acc {}".format(e+1, train_acc / (batch_id+1)))
      model.eval()
      # validation
      for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(test_dataloader)):
          token_ids = token_ids.long().to(device)
          segment_ids = segment_ids.long().to(device)
          valid_length= valid_length
          label = label.long().to(device)
          out = model(token_ids, valid_length, segment_ids)
          test_acc += calc_accuracy(out, label)
      print("epoch {} validation acc {}".format(e+1, test_acc / (batch_id+1)))


## 1.

epoch 5 batch id 1 loss 0.07639920711517334 train acc 0.984375

epoch 5 train acc 0.977254746835443


epoch 5 validation acc 0.8131510416666666

+ Setting parameters
    + max_len = 64
    + batch_size = 64
    + warmup_ratio = 0.1
    + num_epochs = 5
    + max_grad_norm = 1
    + log_interval = 200
    + learning_rate =  5e-5


In [26]:
# train_for_epochs(num_epochs, train_dataloader, test_dataloader)

## 2
epoch 5 batch id 1 loss 0.02432561106979847 train acc 1.0

epoch 5 batch id 201 loss 0.13712140917778015 train acc 0.9689054726368159

epoch 5 train acc 0.9723214285714286

epoch 5 validation acc 0.8468253968253968

+ max_len = 128
+ batch_size = 32
+ warmup_ratio = 0.1
+ num_epochs = 5
+ max_grad_norm = 1
+ log_interval = 200
+ learning_rate =  5e-5

In [27]:
# train_for_epochs(num_epochs, train_dataloader, test_dataloader)

## 3.

epoch 5 batch id 1 loss 0.05628944933414459 train acc 1.0

epoch 5 batch id 201 loss 0.007393729407340288 train acc 0.9704601990049752

epoch 5 batch id 401 loss 0.018062269315123558 train acc 0.9730361596009975

epoch 5 batch id 601 loss 0.2231445163488388 train acc 0.9751455906821963

epoch 5 train acc 0.9745039682539682

epoch 5 validation acc 0.8631329113924051

+ max_len = 256
+ batch_size = 32
+ warmup_ratio = 0.1
+ num_epochs = 5
+ max_grad_norm = 1
+ log_interval = 200
+ learning_rate =  5e-5

## 4.

epoch 5 train acc 0.9691468253968254

epoch 5 validation acc 0.8753968253968254

+ max_len = 512
+ batch_size = 8
+ warmup_ratio = 0.1
+ num_epochs = 5
+ max_grad_norm = 1
+ log_interval = 200

In [28]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [29]:
train_for_epochs(20, train_dataloader, test_dataloader)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


  0%|          | 0/1028 [00:00<?, ?it/s]

epoch 1 batch id 1 loss 0.8518131971359253 train acc 0.65625
epoch 1 batch id 201 loss 0.8585792183876038 train acc 0.7271455223880597
epoch 1 batch id 401 loss 0.7697248458862305 train acc 0.736284289276808
epoch 1 batch id 601 loss 0.6003919839859009 train acc 0.7464122296173045
epoch 1 batch id 801 loss 0.8277593851089478 train acc 0.7567883895131086
epoch 1 batch id 1001 loss 0.775217592716217 train acc 0.7643294205794205
epoch 1 train acc 0.7647029020752271


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


  0%|          | 0/257 [00:00<?, ?it/s]

epoch 1 validation acc 0.7399075875486382


  0%|          | 0/1028 [00:00<?, ?it/s]

epoch 2 batch id 1 loss 0.9427915811538696 train acc 0.65625
epoch 2 batch id 201 loss 0.7082683444023132 train acc 0.7465796019900498
epoch 2 batch id 401 loss 0.6777576804161072 train acc 0.7537406483790524
epoch 2 batch id 601 loss 0.5744646191596985 train acc 0.7626351913477537
epoch 2 batch id 801 loss 0.6947447061538696 train acc 0.7724719101123596
epoch 2 batch id 1001 loss 0.5085083842277527 train acc 0.7797202797202797
epoch 2 train acc 0.7799894617380025


  0%|          | 0/257 [00:00<?, ?it/s]

epoch 2 validation acc 0.7340710116731517


  0%|          | 0/1028 [00:00<?, ?it/s]

epoch 3 batch id 1 loss 0.7468093037605286 train acc 0.71875
epoch 3 batch id 201 loss 0.55743408203125 train acc 0.792910447761194
epoch 3 batch id 401 loss 0.5291843414306641 train acc 0.7982387780548629
epoch 3 batch id 601 loss 0.45964327454566956 train acc 0.8060004159733777
epoch 3 batch id 801 loss 0.48825517296791077 train acc 0.8150358926342073
epoch 3 batch id 1001 loss 0.4480017125606537 train acc 0.8216471028971029
epoch 3 train acc 0.8221283236057069


  0%|          | 0/257 [00:00<?, ?it/s]

epoch 3 validation acc 0.7491488326848249


  0%|          | 0/1028 [00:00<?, ?it/s]

epoch 4 batch id 1 loss 0.7196599841117859 train acc 0.8125
epoch 4 batch id 201 loss 0.48828113079071045 train acc 0.8421952736318408
epoch 4 batch id 401 loss 0.5376256108283997 train acc 0.8431265586034913
epoch 4 batch id 601 loss 0.2988562285900116 train acc 0.8493136439267887
epoch 4 batch id 801 loss 0.28824976086616516 train acc 0.8521379525593009
epoch 4 batch id 1001 loss 0.138186976313591 train acc 0.8576423576423576
epoch 4 train acc 0.8579077496757458


  0%|          | 0/257 [00:00<?, ?it/s]

epoch 4 validation acc 0.7402723735408561


  0%|          | 0/1028 [00:00<?, ?it/s]

epoch 5 batch id 1 loss 0.5524066686630249 train acc 0.8125
epoch 5 batch id 201 loss 0.40970373153686523 train acc 0.8826181592039801
epoch 5 batch id 401 loss 0.4074438810348511 train acc 0.8786627182044888
epoch 5 batch id 601 loss 0.23068329691886902 train acc 0.879731697171381
epoch 5 batch id 801 loss 0.1749836653470993 train acc 0.8827637328339576
epoch 5 batch id 1001 loss 0.10233385115861893 train acc 0.8859265734265734
epoch 5 train acc 0.8856335116731517


  0%|          | 0/257 [00:00<?, ?it/s]

epoch 5 validation acc 0.7474464980544747


  0%|          | 0/1028 [00:00<?, ?it/s]

epoch 6 batch id 1 loss 0.5871011018753052 train acc 0.84375
epoch 6 batch id 201 loss 0.4316767454147339 train acc 0.8950559701492538
epoch 6 batch id 401 loss 0.43840834498405457 train acc 0.8963528678304239
epoch 6 batch id 601 loss 0.12696395814418793 train acc 0.8985024958402662
epoch 6 batch id 801 loss 0.4267238974571228 train acc 0.9005149812734082
epoch 6 batch id 1001 loss 0.18562452495098114 train acc 0.9027847152847153
epoch 6 train acc 0.902320444228275


  0%|          | 0/257 [00:00<?, ?it/s]

epoch 6 validation acc 0.743920233463035


  0%|          | 0/1028 [00:00<?, ?it/s]

epoch 7 batch id 1 loss 0.41620635986328125 train acc 0.84375
epoch 7 batch id 201 loss 0.5446838736534119 train acc 0.912157960199005
epoch 7 batch id 401 loss 0.3931114077568054 train acc 0.9136533665835411
epoch 7 batch id 601 loss 0.5425512194633484 train acc 0.9145694675540765
epoch 7 batch id 801 loss 0.15470312535762787 train acc 0.9158863920099876
epoch 7 batch id 1001 loss 0.15211740136146545 train acc 0.91624000999001
epoch 7 train acc 0.9156371595330739


  0%|          | 0/257 [00:00<?, ?it/s]

epoch 7 validation acc 0.7529182879377432


  0%|          | 0/1028 [00:00<?, ?it/s]

epoch 8 batch id 1 loss 0.36818474531173706 train acc 0.90625
epoch 8 batch id 201 loss 0.45562055706977844 train acc 0.9245957711442786
epoch 8 batch id 401 loss 0.39149683713912964 train acc 0.9234725685785536
epoch 8 batch id 601 loss 0.28385818004608154 train acc 0.9244488352745425
epoch 8 batch id 801 loss 0.33365023136138916 train acc 0.9252887016229713
epoch 8 batch id 1001 loss 0.25981605052948 train acc 0.9266046453546454
epoch 8 train acc 0.9258815661478599


  0%|          | 0/257 [00:00<?, ?it/s]

epoch 8 validation acc 0.7395428015564203


  0%|          | 0/1028 [00:00<?, ?it/s]

epoch 9 batch id 1 loss 0.24752385914325714 train acc 0.90625
epoch 9 batch id 201 loss 0.5155288577079773 train acc 0.929570895522388
epoch 9 batch id 401 loss 0.37127479910850525 train acc 0.9311876558603491
epoch 9 batch id 601 loss 0.056808341294527054 train acc 0.9318843594009983
epoch 9 batch id 801 loss 0.11401622742414474 train acc 0.9335986267166042
epoch 9 batch id 1001 loss 0.270746648311615 train acc 0.9343156843156843
epoch 9 train acc 0.9336332684824903


  0%|          | 0/257 [00:00<?, ?it/s]

epoch 9 validation acc 0.7508511673151751


  0%|          | 0/1028 [00:00<?, ?it/s]

epoch 10 batch id 1 loss 0.1302991807460785 train acc 0.96875
epoch 10 batch id 201 loss 0.33230769634246826 train acc 0.9364116915422885
epoch 10 batch id 401 loss 0.5459650158882141 train acc 0.9358634663341646
epoch 10 batch id 601 loss 0.08903054893016815 train acc 0.9360440931780366
epoch 10 batch id 801 loss 0.06233544647693634 train acc 0.9377340823970037
epoch 10 batch id 1001 loss 0.04438771307468414 train acc 0.938936063936064
epoch 10 train acc 0.9384666828793774


  0%|          | 0/257 [00:00<?, ?it/s]

epoch 10 validation acc 0.7653210116731517


  0%|          | 0/1028 [00:00<?, ?it/s]

epoch 11 batch id 1 loss 0.12023554742336273 train acc 0.96875
epoch 11 batch id 201 loss 0.3101385533809662 train acc 0.9430970149253731
epoch 11 batch id 401 loss 0.2006029188632965 train acc 0.9427992518703242
epoch 11 batch id 601 loss 0.0024206473026424646 train acc 0.942283693843594
epoch 11 batch id 801 loss 0.05962695926427841 train acc 0.9443274032459426
epoch 11 batch id 1001 loss 0.06242094188928604 train acc 0.9450237262737263
epoch 11 train acc 0.9445160505836576


  0%|          | 0/257 [00:00<?, ?it/s]

epoch 11 validation acc 0.7588764591439688


  0%|          | 0/1028 [00:00<?, ?it/s]

epoch 12 batch id 1 loss 0.26583942770957947 train acc 0.90625
epoch 12 batch id 201 loss 0.14128001034259796 train acc 0.9472947761194029
epoch 12 batch id 401 loss 0.18179872632026672 train acc 0.9480205735660848
epoch 12 batch id 601 loss 0.09410177916288376 train acc 0.9462354409317804
epoch 12 batch id 801 loss 0.06838011741638184 train acc 0.9479946941323346
epoch 12 batch id 1001 loss 0.051492754369974136 train acc 0.9491133866133866
epoch 12 train acc 0.9483807555123217


  0%|          | 0/257 [00:00<?, ?it/s]

epoch 12 validation acc 0.7592412451361867


  0%|          | 0/1028 [00:00<?, ?it/s]

epoch 13 batch id 1 loss 0.1312480866909027 train acc 0.9375
epoch 13 batch id 201 loss 0.13131403923034668 train acc 0.9485385572139303
epoch 13 batch id 401 loss 0.17932632565498352 train acc 0.9480985037406484
epoch 13 batch id 601 loss 0.1466955989599228 train acc 0.9481073211314476
epoch 13 batch id 801 loss 0.07401932030916214 train acc 0.9496722846441947
epoch 13 batch id 1001 loss 0.05479883775115013 train acc 0.9506431068931069
epoch 13 train acc 0.9500506647211413


  0%|          | 0/257 [00:00<?, ?it/s]

epoch 13 validation acc 0.7638618677042801


  0%|          | 0/1028 [00:00<?, ?it/s]

epoch 14 batch id 1 loss 0.12221885472536087 train acc 0.9375
epoch 14 batch id 201 loss 0.14922356605529785 train acc 0.9538246268656716
epoch 14 batch id 401 loss 0.19025015830993652 train acc 0.952930174563591
epoch 14 batch id 601 loss 0.0010590155143290758 train acc 0.9521110648918469
epoch 14 batch id 801 loss 0.06405200809240341 train acc 0.9532225343320849
epoch 14 batch id 1001 loss 0.05586092174053192 train acc 0.9539210789210789
epoch 14 train acc 0.9533357652399481


  0%|          | 0/257 [00:00<?, ?it/s]

epoch 14 validation acc 0.7653210116731517


  0%|          | 0/1028 [00:00<?, ?it/s]

epoch 15 batch id 1 loss 0.13573624193668365 train acc 0.9375
epoch 15 batch id 201 loss 0.12220367044210434 train acc 0.9547574626865671
epoch 15 batch id 401 loss 0.19581614434719086 train acc 0.9540211970074813
epoch 15 batch id 601 loss 0.0010231357300654054 train acc 0.9535669717138103
epoch 15 batch id 801 loss 0.06433054059743881 train acc 0.954939138576779
epoch 15 batch id 1001 loss 0.06019100174307823 train acc 0.9554195804195804
epoch 15 train acc 0.9548253080415046


  0%|          | 0/257 [00:00<?, ?it/s]

epoch 15 validation acc 0.7656857976653697


  0%|          | 0/1028 [00:00<?, ?it/s]

epoch 16 batch id 1 loss 0.13737276196479797 train acc 0.9375
epoch 16 batch id 201 loss 0.13567806780338287 train acc 0.9558457711442786
epoch 16 batch id 401 loss 0.1948680728673935 train acc 0.9540211970074813
epoch 16 batch id 601 loss 0.0005703341448679566 train acc 0.9539829450915142
epoch 16 batch id 801 loss 0.06485196202993393 train acc 0.9556413857677902
epoch 16 batch id 1001 loss 0.050106145441532135 train acc 0.9566058941058941
epoch 16 train acc 0.9558872405966277


  0%|          | 0/257 [00:00<?, ?it/s]

epoch 16 validation acc 0.7678745136186771


  0%|          | 0/1028 [00:00<?, ?it/s]

epoch 17 batch id 1 loss 0.12606869637966156 train acc 0.9375
epoch 17 batch id 201 loss 0.13917843997478485 train acc 0.9560012437810945
epoch 17 batch id 401 loss 0.19846078753471375 train acc 0.955423940149626
epoch 17 batch id 601 loss 0.0007560323574580252 train acc 0.9556988352745425
epoch 17 batch id 801 loss 0.05677701532840729 train acc 0.9569288389513109
epoch 17 batch id 1001 loss 0.0539688803255558 train acc 0.9575424575424576
epoch 17 train acc 0.9568620298313878


  0%|          | 0/257 [00:00<?, ?it/s]

epoch 17 validation acc 0.7715223735408561


  0%|          | 0/1028 [00:00<?, ?it/s]

epoch 18 batch id 1 loss 0.1288568377494812 train acc 0.9375
epoch 18 batch id 201 loss 0.14771442115306854 train acc 0.9580223880597015
epoch 18 batch id 401 loss 0.19144180417060852 train acc 0.9564370324189526
epoch 18 batch id 601 loss 0.0008694581338204443 train acc 0.9556988352745425
epoch 18 batch id 801 loss 0.05923460051417351 train acc 0.9569288389513109
epoch 18 batch id 1001 loss 0.04243139550089836 train acc 0.9576361138861139
epoch 18 train acc 0.9569836251621271


  0%|          | 0/257 [00:00<?, ?it/s]

epoch 18 validation acc 0.7692120622568094


  0%|          | 0/1028 [00:00<?, ?it/s]

epoch 19 batch id 1 loss 0.11256863921880722 train acc 0.96875
epoch 19 batch id 201 loss 0.1331794112920761 train acc 0.9583333333333334
epoch 19 batch id 401 loss 0.18915899097919464 train acc 0.9561253117206983
epoch 19 batch id 601 loss 0.0006006392650306225 train acc 0.9558028286189684
epoch 19 batch id 801 loss 0.06125078350305557 train acc 0.9570068664169787
epoch 19 batch id 1001 loss 0.05064002797007561 train acc 0.9576985514485514
epoch 19 train acc 0.957014023994812


  0%|          | 0/257 [00:00<?, ?it/s]

epoch 19 validation acc 0.7700632295719845


  0%|          | 0/1028 [00:00<?, ?it/s]

epoch 20 batch id 1 loss 0.1210225373506546 train acc 0.9375
epoch 20 batch id 201 loss 0.13035458326339722 train acc 0.9580223880597015
epoch 20 batch id 401 loss 0.18618766963481903 train acc 0.9561253117206983
epoch 20 batch id 601 loss 0.0005952384672127664 train acc 0.9557508319467554
epoch 20 batch id 801 loss 0.06226697191596031 train acc 0.956850811485643
epoch 20 batch id 1001 loss 0.05128684267401695 train acc 0.9576361138861139
epoch 20 train acc 0.9569836251621271


  0%|          | 0/257 [00:00<?, ?it/s]

epoch 20 validation acc 0.7701848249027238


In [31]:
torch.save(model, data_path+'/model_epoch20_len_128.pt')

In [None]:
# train_for_epochs(5, train_dataloader, test_dataloader)

In [None]:
# torch.save(model, data_path+'model_epoch5.pt')