<a href="https://colab.research.google.com/github/ttogle918/news_by_kobert/blob/master/classify_category/classify_news_category_kobert_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive

drive.mount('/content/drive')
data_path = '/content/drive/My Drive/Colab Notebooks/NextLab/news_class9x1400'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!pip install mxnet
!pip install gluonnlp pandas tqdm
!pip install sentencepiece
!pip install transformers
!pip install torch

In [None]:
!pip install git+https://git@github.com/SKTBrain/KoBERT.git@master

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm import tqdm, tqdm_notebook
from kobert.utils import get_tokenizer
from kobert.pytorch_kobert import get_pytorch_kobert_model
from transformers import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup

In [3]:
# gpu 연산이 가능하면 'cuda:0', 아니면 'cpu' 출력
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [4]:
#bert 모델, vocab 불러오기
bertmodel, vocab = get_pytorch_kobert_model()

using cached model. /content/.cache/kobert_v1.zip
using cached model. /content/.cache/kobert_news_wiki_ko_cased-1087f8699e.spiece


In [5]:
# text의 길이가 80이상 256이하의 문장 만들기
def text_splicing(text, max_seq_len) :
  text = text.split('\n')
  saved_text, s, i, data = '', 0, 0, []
  for t in text :
    if len(t) < 8 or t[0] in ['o','『', '(', '┌','│', '└', 'ㄴ', '┌','├','◎', '[', '■', 'ㄱ', '-', '.', '<'] : 
      continue
    if t[-1] in ['쪽', '-', '>', ']'] or t[1] in [')'] or t[0:2] in ['vs', '만,', '해!'] :
      continue
    if len(t) > max_seq_len :
        data.append(saved_text)
        data.append(t)
        saved_text, s = '', 0
        i += 1
    elif s + len(t) > max_seq_len :
        data.append(saved_text)
        saved_text, s = t, len(t)
        i += 1
    else :
        saved_text += t
        s += len(t)
  data.append(saved_text)
  
  ret_data = []
  for d in data :
    if len(d) >= 80 :
      ret_data.append(d)
  return ret_data

In [6]:
import os
max_seq_len = 256
dataset_train = []
labels = []
for (path, dir, files) in os.walk(data_path):
    for filename in files:
        ext = os.path.splitext(filename)[-1]
        if ext == '.txt':
            with open("%s/%s" % (path, filename), encoding="utf-8") as f:
              # labels.append( path[path.rindex('/')+1:])
              # dataset_train.append(f.read()[:max_seq_len])
              
              label = path[path.rindex('/')+1:]
              text = f.read()
              data = text_splicing(text, max_seq_len)
              dataset_train.extend(data)
              labels.extend([label]*len(data))

len(dataset_train), len(labels), set(labels)

(137023,
 137023,
 {'ITscience',
  'culture',
  'economy',
  'entertainment',
  'health',
  'life',
  'politic',
  'social',
  'sport'})

## 전처리

In [7]:
import pandas as pd

# labels_classes
df = pd.DataFrame({'content' : dataset_train, 'label' : labels})

# dataset의 balance 여부 확인
df.groupby(by=['label']).count()

Unnamed: 0_level_0,content
label,Unnamed: 1_level_1
ITscience,16362
culture,15139
economy,16302
entertainment,15334
health,16004
life,14764
politic,13766
social,14396
sport,14956


In [8]:
from sklearn.preprocessing import LabelEncoder

label_encoder = LabelEncoder()
label_encoder.fit(df['label'])
num_labels = len(label_encoder.classes_)

df['encoded_label'] = np.asarray(label_encoder.transform(df['label']), dtype=np.int32)
df.tail(3)
# label_encoder.inverse_transform(prediction_test)

Unnamed: 0,content,label,encoded_label
137020,된장국도 끓이고 배추 겉저리도 담궜다. 두루 둘러앉아 식사를 하는데 큰 형님 게오르...,life,5
137021,그리고 마무리는 게오르기 형님이 낮에 준비한 커다란 수박을 잘라먹었다.수박을 커다란...,life,5
137022,그리고는 내게 이렇게 오셔서 덕분에 우리가 고려 말을 많이 합니다.그리고 앞으로도 ...,life,5


In [9]:
from sklearn.model_selection import train_test_split

# Split Train and Validation data
train_texts, test_texts, train_labels, test_labels = train_test_split(df.content, df.encoded_label, test_size=0.2, random_state=123, shuffle=True, stratify=df.encoded_label)
train_texts, val_texts, train_labels, val_labels = train_test_split(train_texts, train_labels, test_size=0.2, random_state=123, shuffle=True, stratify=train_labels)

print(len(train_texts), len(train_labels))
print(len(val_texts), len(val_labels))
print(len(test_texts), len(test_labels))

87694 87694
21924 21924
27405 27405


In [10]:
dataset_train, dataset_val, dataset_test = [], [], []
text, label = '', ''
for text, label in zip(train_texts, train_labels) :
  dataset_train.append([text, label])
for text, label in zip(val_texts, val_labels) :
  dataset_val.append([text, label])
for text, label in zip(test_texts, test_labels) :
  dataset_test.append([text, label])
len(dataset_train), len(dataset_val), len(dataset_test)

(87694, 21924, 27405)

In [11]:
class BERTDataset(Dataset):
    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,
                 pad, pair):
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)

        self.sentences = [transform([i[sent_idx]]) for i in dataset]
        self.labels = [np.int32(i[label_idx]) for i in dataset]

    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i], ))

    def __len__(self):
        return (len(self.labels))

In [12]:
## Setting parameters
max_len = 256
batch_size = 16
warmup_ratio = 0.1
num_epochs = 10
max_grad_norm = 1
log_interval = 200

In [13]:
#토큰화
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)
#BERTDataset 클래스 이용, TensorDataset으로 만들어주기
data_train = BERTDataset(dataset_train, 0, 1, tok, max_len, True, False)
data_val = BERTDataset(dataset_val, 0, 1, tok, max_len, True, False)
data_test = BERTDataset(dataset_test, 0, 1, tok, max_len, True, False)

using cached model. /content/.cache/kobert_news_wiki_ko_cased-1087f8699e.spiece


In [14]:
#배치 및 데이터로더 설정
train_dataloader = torch.utils.data.DataLoader(data_train, batch_size=batch_size)
val_dataloader = torch.utils.data.DataLoader(data_val, batch_size=batch_size)
test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=1)

In [15]:
del labels

del df
del train_texts
del train_labels
del val_texts
del val_labels
del test_texts
del test_labels

del data_train
del data_val
del data_test
del tok
del tokenizer
del dataset_train
del dataset_val
del dataset_test

In [16]:
class BERTClassifier(nn.Module):
    def __init__(self, bert, hidden_size = 768,
                 num_classes=num_labels, ## 분류할 클래스 수
                 dr_rate=None, params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate
                 
        self.classifier = nn.Linear(hidden_size , num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)
    
    def gen_attention_mask(self, token_ids, valid_length):
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        
        _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask.float().to(token_ids.device))
        if self.dr_rate:
            out = self.dropout(pooler)
        return self.classifier(out)

In [17]:
model = BERTClassifier(bertmodel, dr_rate=0.5).to(device) #gpu

In [18]:
def calc_accuracy(X,Y):
    max_vals, max_indices = torch.max(X, 1)
    train_acc = (max_indices == Y).sum().data.cpu().numpy()/max_indices.size()[0]
    return train_acc

In [19]:
def train_for_epochs(num_epochs, train_dataloader, val_dataloader) :
  # Prepare optimizer and schedule (linear warmup and decay)
  no_decay = ['bias', 'LayerNorm.weight']
  optimizer_grouped_parameters = [
      {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
      {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
  ]

  optimizer = AdamW(optimizer_grouped_parameters, lr=5e-5)  # lr = learning rate
  loss_fn = nn.CrossEntropyLoss()

  t_total = len(train_dataloader) * num_epochs
  warmup_step = int(t_total * warmup_ratio)

  scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)
  
  # train
  for e in range(num_epochs):
      train_acc = 0.0
      val_acc = 0.0
      val_loss = 0.0
      # training
      model.train()
      for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(train_dataloader)):
          optimizer.zero_grad()
          token_ids = token_ids.long().to(device)
          segment_ids = segment_ids.long().to(device)
          valid_length= valid_length
          label = label.long().to(device)
          out = model(token_ids, valid_length, segment_ids)
          loss = loss_fn(out, label)
          loss.backward()
          torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
          optimizer.step()
          scheduler.step()  # Update learning rate schedule
          train_acc += calc_accuracy(out, label)
          if batch_id % log_interval == 0:
              print("epoch {} batch id {} loss {} train acc {}".format(e+1, batch_id+1, loss.data.cpu().numpy(), train_acc / (batch_id+1)))
      print("epoch {} train acc {}".format(e+1, train_acc / (batch_id+1)))
      model.eval()
      # validation
      for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(val_dataloader)):
          token_ids = token_ids.long().to(device)
          segment_ids = segment_ids.long().to(device)
          valid_length= valid_length
          label = label.long().to(device)
          out = model(token_ids, valid_length, segment_ids)
          val_loss += loss_fn(out, label)

          val_acc += calc_accuracy(out, label)
      print("epoch {} validation acc {} loss {}".format(e+1, val_acc / (batch_id+1), val_loss.data.cpu().numpy()))

In [22]:
def test(test_dataloader) :
  out_li = []
  no_decay = ['bias', 'LayerNorm.weight']
  optimizer_grouped_parameters = [
      {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
      {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
  ]

  optimizer = AdamW(optimizer_grouped_parameters, lr=5e-5)  # lr = learning rate
  loss_fn = nn.CrossEntropyLoss()

  t_total = len(train_dataloader) * num_epochs
  warmup_step = int(t_total * warmup_ratio)
  test_acc, test_loss = 0.0, 0.0
  model.eval()

  scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)
  for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(test_dataloader)):
    token_ids = token_ids.long().to(device)
    segment_ids = segment_ids.long().to(device)
    valid_length= valid_length
    label = label.long().to(device)
    out = model(token_ids, valid_length, segment_ids)
    out_li.append(out)
    test_acc += calc_accuracy(out, label)
    test_loss += loss_fn(out, label)
  print("Test acc {}, loss {}".format(test_acc / (batch_id+1), test_loss))
  return out_li

In [28]:
import gc
gc.collect()
# torch.cuda.empty_cache()

476

In [21]:
train_for_epochs(num_epochs, train_dataloader, val_dataloader)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


  0%|          | 0/5481 [00:00<?, ?it/s]

epoch 1 batch id 1 loss 2.239077568054199 train acc 0.1875
epoch 1 batch id 201 loss 2.2967329025268555 train acc 0.12468905472636815
epoch 1 batch id 401 loss 1.9808170795440674 train acc 0.1561720698254364
epoch 1 batch id 601 loss 1.3976389169692993 train acc 0.25946339434276205
epoch 1 batch id 801 loss 1.4425305128097534 train acc 0.3578339575530587
epoch 1 batch id 1001 loss 1.2798954248428345 train acc 0.4288836163836164
epoch 1 batch id 1201 loss 0.8035870790481567 train acc 0.4777268942547877
epoch 1 batch id 1401 loss 0.6184626817703247 train acc 0.5149892933618844
epoch 1 batch id 1601 loss 0.6582783460617065 train acc 0.5441130543410369
epoch 1 batch id 1801 loss 0.8104749321937561 train acc 0.5659355913381455
epoch 1 batch id 2001 loss 0.6125224232673645 train acc 0.5858633183408296
epoch 1 batch id 2201 loss 0.3132786452770233 train acc 0.6002953203089505
epoch 1 batch id 2401 loss 0.8216688632965088 train acc 0.6132340691378593
epoch 1 batch id 2601 loss 0.27481183409690

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


  0%|          | 0/1371 [00:00<?, ?it/s]

epoch 1 validation acc 0.7677789934354485 loss 0.14821141958236694


  0%|          | 0/5481 [00:00<?, ?it/s]

epoch 2 batch id 1 loss 0.473906546831131 train acc 0.9375
epoch 2 batch id 201 loss 0.8161687850952148 train acc 0.7646144278606966
epoch 2 batch id 401 loss 0.6354888677597046 train acc 0.7605985037406484
epoch 2 batch id 601 loss 0.40739431977272034 train acc 0.7606073211314476
epoch 2 batch id 801 loss 1.0934593677520752 train acc 0.7653714107365793
epoch 2 batch id 1001 loss 0.6499863862991333 train acc 0.7698551448551448
epoch 2 batch id 1201 loss 0.4825512766838074 train acc 0.7718047460449625
epoch 2 batch id 1401 loss 0.6302244663238525 train acc 0.7743576017130621
epoch 2 batch id 1601 loss 0.8570829629898071 train acc 0.7776780137414117
epoch 2 batch id 1801 loss 0.332589328289032 train acc 0.7788034425319267
epoch 2 batch id 2001 loss 0.46415477991104126 train acc 0.7817028985507246
epoch 2 batch id 2201 loss 0.1691857874393463 train acc 0.7811506133575647
epoch 2 batch id 2401 loss 0.8651722073554993 train acc 0.7832153269471054
epoch 2 batch id 2601 loss 0.267309278249740

  0%|          | 0/1371 [00:00<?, ?it/s]

epoch 2 validation acc 0.8005105762217359 loss 0.07032507658004761


  0%|          | 0/5481 [00:00<?, ?it/s]

epoch 3 batch id 1 loss 0.6870565414428711 train acc 0.875
epoch 3 batch id 201 loss 0.4472864866256714 train acc 0.8330223880597015
epoch 3 batch id 401 loss 0.5483356714248657 train acc 0.8340087281795511
epoch 3 batch id 601 loss 0.08045139908790588 train acc 0.8320507487520798
epoch 3 batch id 801 loss 0.865158200263977 train acc 0.8352059925093633
epoch 3 batch id 1001 loss 0.7181187272071838 train acc 0.8385989010989011
epoch 3 batch id 1201 loss 0.5141580104827881 train acc 0.8400291423813488
epoch 3 batch id 1401 loss 0.5435770153999329 train acc 0.8416309778729479
epoch 3 batch id 1601 loss 0.5675480961799622 train acc 0.8436524047470331
epoch 3 batch id 1801 loss 0.3525419533252716 train acc 0.8433856191004997
epoch 3 batch id 2001 loss 0.3458406925201416 train acc 0.8450774612693653
epoch 3 batch id 2201 loss 0.11505425721406937 train acc 0.8450420263516584
epoch 3 batch id 2401 loss 1.0012986660003662 train acc 0.8462359433569346
epoch 3 batch id 2601 loss 0.358532845973968

  0%|          | 0/1371 [00:00<?, ?it/s]

epoch 3 validation acc 0.8085339168490153 loss 0.019931610673666


  0%|          | 0/5481 [00:00<?, ?it/s]

epoch 4 batch id 1 loss 0.5802267789840698 train acc 0.8125
epoch 4 batch id 201 loss 0.34816527366638184 train acc 0.8781094527363185
epoch 4 batch id 401 loss 0.5022806525230408 train acc 0.8799875311720698
epoch 4 batch id 601 loss 0.4066784679889679 train acc 0.8787437603993344
epoch 4 batch id 801 loss 0.5045167803764343 train acc 0.8798377028714107
epoch 4 batch id 1001 loss 0.7391334772109985 train acc 0.8826173826173827
epoch 4 batch id 1201 loss 0.30086061358451843 train acc 0.8831182348043297
epoch 4 batch id 1401 loss 0.5367258787155151 train acc 0.8836991434689507
epoch 4 batch id 1601 loss 0.3262191116809845 train acc 0.8845643347907558
epoch 4 batch id 1801 loss 0.10635478049516678 train acc 0.8850638534147696
epoch 4 batch id 2001 loss 0.28313612937927246 train acc 0.8863380809595203
epoch 4 batch id 2201 loss 0.05095259100198746 train acc 0.8853646069968196
epoch 4 batch id 2401 loss 0.8608707785606384 train acc 0.8863754685547689
epoch 4 batch id 2601 loss 0.0330398045

  0%|          | 0/1371 [00:00<?, ?it/s]

epoch 4 validation acc 0.8203409919766593 loss 0.017821144312620163


  0%|          | 0/5481 [00:00<?, ?it/s]

epoch 5 batch id 1 loss 0.2763855755329132 train acc 0.9375
epoch 5 batch id 201 loss 0.23679119348526 train acc 0.914179104477612
epoch 5 batch id 401 loss 0.14344049990177155 train acc 0.9089775561097256
epoch 5 batch id 601 loss 0.21487963199615479 train acc 0.9092138103161398
epoch 5 batch id 801 loss 0.16703841090202332 train acc 0.9106585518102372
epoch 5 batch id 1001 loss 1.0989773273468018 train acc 0.9119005994005994
epoch 5 batch id 1201 loss 0.17856532335281372 train acc 0.9123646960865945
epoch 5 batch id 1401 loss 0.6907854676246643 train acc 0.9140346181299072
epoch 5 batch id 1601 loss 0.07043783366680145 train acc 0.915365396627108
epoch 5 batch id 1801 loss 0.13698044419288635 train acc 0.9154289283731261
epoch 5 batch id 2001 loss 0.2522393763065338 train acc 0.9166041979010495
epoch 5 batch id 2201 loss 0.028657404705882072 train acc 0.9171399363925489
epoch 5 batch id 2401 loss 0.38521480560302734 train acc 0.9182892544773011
epoch 5 batch id 2601 loss 0.0217505618

  0%|          | 0/1371 [00:00<?, ?it/s]

epoch 5 validation acc 0.8239879649890591 loss 0.0040969387628138065


  0%|          | 0/5481 [00:00<?, ?it/s]

epoch 6 batch id 1 loss 0.141564279794693 train acc 0.9375
epoch 6 batch id 201 loss 0.4022715389728546 train acc 0.9430970149253731
epoch 6 batch id 401 loss 0.005338583141565323 train acc 0.9395261845386533
epoch 6 batch id 601 loss 0.023725692182779312 train acc 0.9381239600665557
epoch 6 batch id 801 loss 0.2377147078514099 train acc 0.9392166042446941
epoch 6 batch id 1001 loss 0.7407792806625366 train acc 0.9407467532467533
epoch 6 batch id 1201 loss 0.02545190043747425 train acc 0.941455037468776
epoch 6 batch id 1401 loss 0.38775911927223206 train acc 0.9410242683797287
epoch 6 batch id 1601 loss 0.2772865891456604 train acc 0.9426139912554653
epoch 6 batch id 1801 loss 0.25490084290504456 train acc 0.9424625208217657
epoch 6 batch id 2001 loss 0.20843011140823364 train acc 0.9429035482258871
epoch 6 batch id 2201 loss 0.003821031656116247 train acc 0.9435767832803271
epoch 6 batch id 2401 loss 0.6941861510276794 train acc 0.9445543523531862
epoch 6 batch id 2601 loss 0.0050424

  0%|          | 0/1371 [00:00<?, ?it/s]

epoch 6 validation acc 0.8312363238512035 loss 0.0018540300661697984


  0%|          | 0/5481 [00:00<?, ?it/s]

epoch 7 batch id 1 loss 0.48754969239234924 train acc 0.875
epoch 7 batch id 201 loss 0.058231171220541 train acc 0.960820895522388
epoch 7 batch id 401 loss 0.13661795854568481 train acc 0.960567331670823
epoch 7 batch id 601 loss 0.002736012451350689 train acc 0.961522462562396
epoch 7 batch id 801 loss 0.07056961953639984 train acc 0.9620006242197253
epoch 7 batch id 1001 loss 0.7010361552238464 train acc 0.9622877122877123
epoch 7 batch id 1201 loss 0.16288365423679352 train acc 0.961958784346378
epoch 7 batch id 1401 loss 0.35229235887527466 train acc 0.962615988579586
epoch 7 batch id 1601 loss 0.0037573506124317646 train acc 0.9634603372891942
epoch 7 batch id 1801 loss 0.009468049742281437 train acc 0.963041365907829
epoch 7 batch id 2001 loss 0.2280537486076355 train acc 0.9638930534732634
epoch 7 batch id 2201 loss 0.0019398084841668606 train acc 0.9639368468877783
epoch 7 batch id 2401 loss 0.012539787217974663 train acc 0.9642336526447314
epoch 7 batch id 2601 loss 0.037562

  0%|          | 0/1371 [00:00<?, ?it/s]

epoch 7 validation acc 0.8361141502552881 loss 0.0008498610695824027


  0%|          | 0/5481 [00:00<?, ?it/s]

epoch 8 batch id 1 loss 0.08902984857559204 train acc 0.9375
epoch 8 batch id 201 loss 0.1803007572889328 train acc 0.9741915422885572
epoch 8 batch id 401 loss 0.7121216654777527 train acc 0.9766209476309227
epoch 8 batch id 601 loss 0.17898161709308624 train acc 0.9761855241264559
epoch 8 batch id 801 loss 0.004388557281345129 train acc 0.9765917602996255
epoch 8 batch id 1001 loss 0.32275548577308655 train acc 0.9775224775224776
epoch 8 batch id 1201 loss 0.013381151482462883 train acc 0.97751873438801
epoch 8 batch id 1401 loss 0.10342729836702347 train acc 0.9772037830121342
epoch 8 batch id 1601 loss 0.41482117772102356 train acc 0.9775140537164272
epoch 8 batch id 1801 loss 0.010129486210644245 train acc 0.977547196002221
epoch 8 batch id 2001 loss 0.16489101946353912 train acc 0.9779485257371314
epoch 8 batch id 2201 loss 0.0008446586434729397 train acc 0.9782201272149024
epoch 8 batch id 2401 loss 0.022375378757715225 train acc 0.9785766347355269
epoch 8 batch id 2601 loss 0.0

  0%|          | 0/1371 [00:00<?, ?it/s]

epoch 8 validation acc 0.8405816921954777 loss 0.0005085912998765707


  0%|          | 0/5481 [00:00<?, ?it/s]

epoch 9 batch id 1 loss 0.04567670822143555 train acc 1.0
epoch 9 batch id 201 loss 0.2575564384460449 train acc 0.9869402985074627
epoch 9 batch id 401 loss 0.14742781221866608 train acc 0.98893391521197
epoch 9 batch id 601 loss 0.0006994801224209368 train acc 0.9882487520798668
epoch 9 batch id 801 loss 0.03303229808807373 train acc 0.988061797752809
epoch 9 batch id 1001 loss 0.08788710832595825 train acc 0.9880744255744256
epoch 9 batch id 1201 loss 0.0006090165697969496 train acc 0.988134887593672
epoch 9 batch id 1401 loss 0.0006143934442661703 train acc 0.9884457530335474
epoch 9 batch id 1601 loss 0.22167281806468964 train acc 0.988366645846346
epoch 9 batch id 1801 loss 0.0031903814524412155 train acc 0.9884092171016102
epoch 9 batch id 2001 loss 0.246619313955307 train acc 0.9882871064467766
epoch 9 batch id 2201 loss 0.00039722578367218375 train acc 0.9884143571104044
epoch 9 batch id 2401 loss 0.0006321587134152651 train acc 0.9886245314452311
epoch 9 batch id 2601 loss 0.

  0%|          | 0/1371 [00:00<?, ?it/s]

epoch 9 validation acc 0.8425419401896426 loss 0.0003792193892877549


  0%|          | 0/5481 [00:00<?, ?it/s]

epoch 10 batch id 1 loss 0.0009809585753828287 train acc 1.0
epoch 10 batch id 201 loss 0.0005467613809742033 train acc 0.9934701492537313
epoch 10 batch id 401 loss 0.0005796737968921661 train acc 0.9932980049875312
epoch 10 batch id 601 loss 0.00042638438753783703 train acc 0.9931364392678869
epoch 10 batch id 801 loss 0.0007492119329981506 train acc 0.9925873907615481
epoch 10 batch id 1001 loss 0.01030302420258522 train acc 0.9927572427572428
epoch 10 batch id 1201 loss 0.0004734884132631123 train acc 0.992714404662781
epoch 10 batch id 1401 loss 0.0008631906239315867 train acc 0.9928622412562456
epoch 10 batch id 1601 loss 0.0036301936488598585 train acc 0.99289506558401
epoch 10 batch id 1801 loss 0.009461541660130024 train acc 0.9929553026096613
epoch 10 batch id 2001 loss 0.18386955559253693 train acc 0.9927848575712144
epoch 10 batch id 2201 loss 0.00038276176201179624 train acc 0.9929009541117674
epoch 10 batch id 2401 loss 0.0005393975297920406 train acc 0.9929196168263223
e

  0%|          | 0/1371 [00:00<?, ?it/s]

epoch 10 validation acc 0.8440007293946025 loss 0.0003491892130114138


In [26]:
torch.save(model, data_path+'/classify_category/epoch10_val_acc:0.844.pt')