In [1]:
import torchvision
from torchvision import models
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

import os, subprocess
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

In [2]:
for name in dir(torchvision.models):
    print(name)

AlexNet
AlexNet_Weights
ConvNeXt
ConvNeXt_Base_Weights
ConvNeXt_Large_Weights
ConvNeXt_Small_Weights
ConvNeXt_Tiny_Weights
DenseNet
DenseNet121_Weights
DenseNet161_Weights
DenseNet169_Weights
DenseNet201_Weights
EfficientNet
EfficientNet_B0_Weights
EfficientNet_B1_Weights
EfficientNet_B2_Weights
EfficientNet_B3_Weights
EfficientNet_B4_Weights
EfficientNet_B5_Weights
EfficientNet_B6_Weights
EfficientNet_B7_Weights
EfficientNet_V2_L_Weights
EfficientNet_V2_M_Weights
EfficientNet_V2_S_Weights
GoogLeNet
GoogLeNetOutputs
GoogLeNet_Weights
Inception3
InceptionOutputs
Inception_V3_Weights
MNASNet
MNASNet0_5_Weights
MNASNet0_75_Weights
MNASNet1_0_Weights
MNASNet1_3_Weights
MaxVit
MaxVit_T_Weights
MobileNetV2
MobileNetV3
MobileNet_V2_Weights
MobileNet_V3_Large_Weights
MobileNet_V3_Small_Weights
RegNet
RegNet_X_16GF_Weights
RegNet_X_1_6GF_Weights
RegNet_X_32GF_Weights
RegNet_X_3_2GF_Weights
RegNet_X_400MF_Weights
RegNet_X_800MF_Weights
RegNet_X_8GF_Weights
RegNet_Y_128GF_Weights
RegNet_Y_16GF_We

In [3]:
import torch
import torch.nn as nn

from tqdm import tqdm


model = models.resnet18(pretrained=True)
model



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [4]:
model.fc.out_features=20
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [5]:
#parameter 수 확인
total_params = sum(p.numel() for p in model.parameters())
total_params

11689512

In [6]:
data_dir = "./data"
dir_lst = os.listdir(data_dir)
for dir in dir_lst:
    dir_path = os.listdir(f"./data/{dir}")

    for name in dir_path:
        if name == ".ipynb_checkpoints":
            print(name)
            os.rmdir(f"./data/{dir}/{name}")

In [7]:
IMG_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((512,512), antialias = True),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),
    ])

In [8]:
train_imgfolder = ImageFolder(root="./data/train",
                              transform=IMG_TRANSFORM,
                              target_transform=None
                             )

test_imgfolder = ImageFolder(root="./data/test",
                              transform=IMG_TRANSFORM
                             )

In [9]:
Image.MAX_IMAGE_PIXELS = None

train_dataloader = DataLoader(dataset=train_imgfolder,
							  batch_size = 64,
                              num_workers=os.cpu_count(),
                              shuffle=True,
                              drop_last=True 
                             )

test_dataloader = DataLoader(dataset=test_imgfolder,
							  batch_size = 64,
                              num_workers=os.cpu_count(),
                              shuffle=True,
                              drop_last=True 
                            )

In [10]:
# model setting

EPOCH = 50
LEARNING_RATE = 1e-3 
DEVICE = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE, weight_decay = 0.001)
criterion = nn.CrossEntropyLoss()
print("Using device: ",DEVICE)

Using device:  cuda:2


In [11]:
def loss_fn(class_output, label, criterion):
    CE_loss = criterion(class_output, label)

    return CE_loss 

In [12]:
def train(train_loader, model, optimizer, criterion, epoch):
    model.train()
    correct = 0
    total = 0
    
    epoch_loss = []
    for index, (img_data,labels) in enumerate(train_loader):
        img_data, labels = img_data.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        
        outputs = model(img_data)
        loss = loss_fn(outputs, labels, criterion)
        
        _, predicted = torch.max(outputs.data, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
        
        
        loss.backward()
        optimizer.step()

        epoch_loss.append(loss.item())
        if index%100==0:
            print(f'train loss : {np.mean(epoch_loss):>7f}, [epoch:{epoch}, iter:{index}]')
            
    accuracy = correct / total
    print(f"TRAIN ACCURACY : {(100*accuracy):>0.1f}% [{correct}/{total}]")
    return np.mean(epoch_loss)

In [13]:
def test(test_loader, model, criterion, epoch):
    model.eval()

    epoch_loss = []
    correct = 0
    total = 0
    
    with torch.no_grad():
        for index, (img_data,labels) in enumerate(test_loader):
            img_data, labels = img_data.to(DEVICE), labels.to(DEVICE)
            outputs = model(img_data)
    
            loss = loss_fn(outputs, labels, criterion)

            _, predicted = torch.max(outputs.data, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
            
            epoch_loss.append(loss.item())
    
    accuracy = correct / total
    
    print(f"TEST ACCURACY : {(100*accuracy):>0.1f}% [{correct}/{total}]")
    return np.mean(epoch_loss)

In [14]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
train_loss_lst, test_loss_lst = [], []

model_version = 1


test_best_loss = 100
test_current_loss = 100
early_stop_threshold = 5
early_stop_trigger = 0

model = model.to(DEVICE)

for i in tqdm(range(EPOCH), desc = 'Train'):
    print(f"EPOCH {i+1} \n-------------------------")
    train_loss = train(train_dataloader, model, optimizer, criterion, i+1)
    test_loss = test(test_dataloader, model, criterion, i+1)

    train_loss_lst.append(train_loss)
    test_loss_lst.append(test_loss)

    if test_loss < test_best_loss :
        print("...MODEL SAVE...")
        test_best_loss = test_loss 
        
        if model_version == 0:
            torch.save(model.state_dict(),'contrastive_model_for_style_weight_v2_without_target_style.pth')
        else:
            torch.save(model.state_dict(),'contrastive_model_for_style_weight_resnet18.pth')
        
    if test_current_loss < test_loss:
        early_stop_trigger += 1 
    else:
        early_stop_trigger = 0 
    test_current_loss  = test_loss 

    print(f'\nEPOCH:{i+1} | train loss : {train_loss}, test loss : {test_loss}\n')
    
    if early_stop_trigger >= early_stop_threshold:
        break
        
print("DONE!")

Train:   0%|                                       | 0/50 [00:00<?, ?it/s]

EPOCH 1 
-------------------------
train loss : 8.987834, [epoch:1, iter:0]




train loss : 2.298772, [epoch:1, iter:100]




train loss : 2.067837, [epoch:1, iter:200]
train loss : 1.943789, [epoch:1, iter:300]
train loss : 1.872494, [epoch:1, iter:400]
train loss : 1.827620, [epoch:1, iter:500]
train loss : 1.791963, [epoch:1, iter:600]
train loss : 1.764496, [epoch:1, iter:700]
train loss : 1.739869, [epoch:1, iter:800]
train loss : 1.723020, [epoch:1, iter:900]
TRAIN ACCURACY : 45.7% [28149/61632]


Train:   2%|▌                           | 1/50 [10:50<8:51:15, 650.52s/it]

TEST ACCURACY : 29.4% [3421/11648]
...MODEL SAVE...

EPOCH:1 | train loss : 1.714554451955318, test loss : 2.682305021600409

EPOCH 2 
-------------------------




train loss : 1.409655, [epoch:2, iter:0]
train loss : 1.531189, [epoch:2, iter:100]
train loss : 1.522819, [epoch:2, iter:200]
train loss : 1.538915, [epoch:2, iter:300]
train loss : 1.532045, [epoch:2, iter:400]




train loss : 1.526812, [epoch:2, iter:500]
train loss : 1.524522, [epoch:2, iter:600]
train loss : 1.521263, [epoch:2, iter:700]
train loss : 1.514198, [epoch:2, iter:800]
train loss : 1.515567, [epoch:2, iter:900]
TRAIN ACCURACY : 50.7% [31248/61632]


Train:   4%|█                           | 2/50 [22:12<8:54:59, 668.73s/it]

TEST ACCURACY : 32.1% [3737/11648]
...MODEL SAVE...

EPOCH:2 | train loss : 1.5158448782541547, test loss : 2.2113801208171218

EPOCH 3 
-------------------------
train loss : 1.493022, [epoch:3, iter:0]
train loss : 1.451385, [epoch:3, iter:100]
train loss : 1.439528, [epoch:3, iter:200]
train loss : 1.444452, [epoch:3, iter:300]




train loss : 1.440256, [epoch:3, iter:400]
train loss : 1.439969, [epoch:3, iter:500]
train loss : 1.442058, [epoch:3, iter:600]
train loss : 1.441310, [epoch:3, iter:700]
train loss : 1.437975, [epoch:3, iter:800]
train loss : 1.440111, [epoch:3, iter:900]
TRAIN ACCURACY : 53.2% [32809/61632]


Train:   6%|█▋                          | 3/50 [34:02<8:58:41, 687.68s/it]

TEST ACCURACY : 36.0% [4191/11648]
...MODEL SAVE...

EPOCH:3 | train loss : 1.4387350769429192, test loss : 2.0568444591302137

EPOCH 4 
-------------------------
train loss : 1.678887, [epoch:4, iter:0]
train loss : 1.366955, [epoch:4, iter:100]




train loss : 1.374354, [epoch:4, iter:200]
train loss : 1.382975, [epoch:4, iter:300]
train loss : 1.374962, [epoch:4, iter:400]
train loss : 1.377744, [epoch:4, iter:500]




train loss : 1.378542, [epoch:4, iter:600]
train loss : 1.379782, [epoch:4, iter:700]
train loss : 1.379457, [epoch:4, iter:800]
train loss : 1.377592, [epoch:4, iter:900]
TRAIN ACCURACY : 55.1% [33977/61632]


Train:   8%|██▏                         | 4/50 [45:14<8:42:38, 681.70s/it]

TEST ACCURACY : 40.3% [4691/11648]
...MODEL SAVE...

EPOCH:4 | train loss : 1.375707357781451, test loss : 1.9852931578080732

EPOCH 5 
-------------------------
train loss : 1.116356, [epoch:5, iter:0]




train loss : 1.288447, [epoch:5, iter:100]
train loss : 1.308718, [epoch:5, iter:200]




train loss : 1.312303, [epoch:5, iter:300]
train loss : 1.325294, [epoch:5, iter:400]
train loss : 1.325087, [epoch:5, iter:500]
train loss : 1.324672, [epoch:5, iter:600]
train loss : 1.327446, [epoch:5, iter:700]
train loss : 1.332970, [epoch:5, iter:800]
train loss : 1.334152, [epoch:5, iter:900]
TRAIN ACCURACY : 56.4% [34788/61632]


Train:  10%|██▊                         | 5/50 [56:39<8:32:11, 682.91s/it]

TEST ACCURACY : 43.8% [5101/11648]
...MODEL SAVE...

EPOCH:5 | train loss : 1.3329688326096856, test loss : 1.7643767693540553

EPOCH 6 
-------------------------
train loss : 1.118202, [epoch:6, iter:0]




train loss : 1.260561, [epoch:6, iter:100]




train loss : 1.270809, [epoch:6, iter:200]
train loss : 1.284568, [epoch:6, iter:300]
train loss : 1.284718, [epoch:6, iter:400]
train loss : 1.293439, [epoch:6, iter:500]
train loss : 1.292802, [epoch:6, iter:600]
train loss : 1.293742, [epoch:6, iter:700]
train loss : 1.294093, [epoch:6, iter:800]
train loss : 1.296533, [epoch:6, iter:900]
TRAIN ACCURACY : 57.5% [35459/61632]


Train:  12%|███                       | 6/50 [1:08:00<8:20:07, 682.00s/it]

TEST ACCURACY : 42.4% [4938/11648]

EPOCH:6 | train loss : 1.297459176581968, test loss : 1.8911626528907608

EPOCH 7 
-------------------------
train loss : 1.107564, [epoch:7, iter:0]
train loss : 1.235846, [epoch:7, iter:100]




train loss : 1.242498, [epoch:7, iter:200]
train loss : 1.248926, [epoch:7, iter:300]
train loss : 1.244994, [epoch:7, iter:400]
train loss : 1.250693, [epoch:7, iter:500]
train loss : 1.254574, [epoch:7, iter:600]
train loss : 1.258687, [epoch:7, iter:700]




train loss : 1.262955, [epoch:7, iter:800]
train loss : 1.263774, [epoch:7, iter:900]
TRAIN ACCURACY : 58.6% [36110/61632]


Train:  14%|███▋                      | 7/50 [1:18:55<8:02:34, 673.36s/it]

TEST ACCURACY : 39.4% [4595/11648]

EPOCH:7 | train loss : 1.2647960875512159, test loss : 2.0208221660865533

EPOCH 8 
-------------------------
train loss : 1.042290, [epoch:8, iter:0]




train loss : 1.166914, [epoch:8, iter:100]
train loss : 1.212521, [epoch:8, iter:200]
train loss : 1.219120, [epoch:8, iter:300]
train loss : 1.227342, [epoch:8, iter:400]




train loss : 1.233287, [epoch:8, iter:500]
train loss : 1.235673, [epoch:8, iter:600]
train loss : 1.239996, [epoch:8, iter:700]
train loss : 1.242908, [epoch:8, iter:800]
train loss : 1.244151, [epoch:8, iter:900]
TRAIN ACCURACY : 59.1% [36455/61632]


Train:  16%|████▏                     | 8/50 [1:30:01<7:49:42, 671.02s/it]

TEST ACCURACY : 37.4% [4360/11648]

EPOCH:8 | train loss : 1.2449244839503881, test loss : 1.9978906152012583

EPOCH 9 
-------------------------
train loss : 1.212343, [epoch:9, iter:0]




train loss : 1.186526, [epoch:9, iter:100]
train loss : 1.191636, [epoch:9, iter:200]
train loss : 1.207223, [epoch:9, iter:300]
train loss : 1.214596, [epoch:9, iter:400]
train loss : 1.224556, [epoch:9, iter:500]
train loss : 1.223562, [epoch:9, iter:600]
train loss : 1.221041, [epoch:9, iter:700]




train loss : 1.222042, [epoch:9, iter:800]
train loss : 1.226882, [epoch:9, iter:900]
TRAIN ACCURACY : 59.8% [36865/61632]


Train:  18%|████▋                     | 9/50 [1:41:20<7:40:05, 673.32s/it]

TEST ACCURACY : 41.5% [4839/11648]

EPOCH:9 | train loss : 1.2289359972484386, test loss : 1.940482919687753

EPOCH 10 
-------------------------
train loss : 1.066191, [epoch:10, iter:0]
train loss : 1.173605, [epoch:10, iter:100]
train loss : 1.176728, [epoch:10, iter:200]




train loss : 1.172124, [epoch:10, iter:300]
train loss : 1.181968, [epoch:10, iter:400]
train loss : 1.183945, [epoch:10, iter:500]




train loss : 1.189873, [epoch:10, iter:600]
train loss : 1.196634, [epoch:10, iter:700]
train loss : 1.195547, [epoch:10, iter:800]
train loss : 1.201659, [epoch:10, iter:900]
TRAIN ACCURACY : 60.8% [37478/61632]


Train:  20%|█████                    | 10/50 [1:52:38<7:30:00, 675.01s/it]

TEST ACCURACY : 42.8% [4989/11648]

EPOCH:10 | train loss : 1.2026123765349264, test loss : 1.8170290320784181

EPOCH 11 
-------------------------
train loss : 1.255784, [epoch:11, iter:0]
train loss : 1.118835, [epoch:11, iter:100]




train loss : 1.141459, [epoch:11, iter:200]
train loss : 1.155315, [epoch:11, iter:300]
train loss : 1.165298, [epoch:11, iter:400]
train loss : 1.170520, [epoch:11, iter:500]




train loss : 1.176247, [epoch:11, iter:600]
train loss : 1.175779, [epoch:11, iter:700]
train loss : 1.178169, [epoch:11, iter:800]
train loss : 1.180707, [epoch:11, iter:900]
TRAIN ACCURACY : 61.5% [37892/61632]


Train:  22%|█████▌                   | 11/50 [2:03:43<7:16:44, 671.90s/it]

TEST ACCURACY : 46.0% [5359/11648]
...MODEL SAVE...

EPOCH:11 | train loss : 1.1819401125917801, test loss : 1.6561037511615964

EPOCH 12 
-------------------------




train loss : 1.370647, [epoch:12, iter:0]
train loss : 1.094568, [epoch:12, iter:100]
train loss : 1.104326, [epoch:12, iter:200]




train loss : 1.128557, [epoch:12, iter:300]
train loss : 1.132047, [epoch:12, iter:400]
train loss : 1.146635, [epoch:12, iter:500]
train loss : 1.152607, [epoch:12, iter:600]
train loss : 1.163890, [epoch:12, iter:700]
train loss : 1.168357, [epoch:12, iter:800]
train loss : 1.172823, [epoch:12, iter:900]
TRAIN ACCURACY : 61.4% [37860/61632]


Train:  24%|██████                   | 12/50 [2:15:04<7:07:20, 674.76s/it]

TEST ACCURACY : 42.1% [4899/11648]

EPOCH:12 | train loss : 1.1747131234388857, test loss : 1.8143794615190107

EPOCH 13 
-------------------------
train loss : 1.060805, [epoch:13, iter:0]
train loss : 1.147825, [epoch:13, iter:100]
train loss : 1.140674, [epoch:13, iter:200]
train loss : 1.132008, [epoch:13, iter:300]
train loss : 1.137269, [epoch:13, iter:400]
train loss : 1.138794, [epoch:13, iter:500]
train loss : 1.144823, [epoch:13, iter:600]




train loss : 1.148767, [epoch:13, iter:700]
train loss : 1.146530, [epoch:13, iter:800]




train loss : 1.148674, [epoch:13, iter:900]
TRAIN ACCURACY : 62.1% [38293/61632]


Train:  26%|██████▌                  | 13/50 [2:26:17<6:55:36, 673.97s/it]

TEST ACCURACY : 43.4% [5052/11648]

EPOCH:13 | train loss : 1.1503312049872654, test loss : 1.8470720131318648

EPOCH 14 
-------------------------
train loss : 1.132159, [epoch:14, iter:0]
train loss : 1.083514, [epoch:14, iter:100]
train loss : 1.106857, [epoch:14, iter:200]
train loss : 1.103810, [epoch:14, iter:300]
train loss : 1.111695, [epoch:14, iter:400]




train loss : 1.122655, [epoch:14, iter:500]
train loss : 1.128454, [epoch:14, iter:600]




train loss : 1.131344, [epoch:14, iter:700]
train loss : 1.135465, [epoch:14, iter:800]
train loss : 1.138012, [epoch:14, iter:900]
TRAIN ACCURACY : 62.7% [38653/61632]


Train:  28%|███████                  | 14/50 [2:37:21<6:42:39, 671.10s/it]

TEST ACCURACY : 39.3% [4580/11648]

EPOCH:14 | train loss : 1.1377122309349037, test loss : 2.0367282655212904

EPOCH 15 
-------------------------
train loss : 1.114927, [epoch:15, iter:0]




train loss : 1.059304, [epoch:15, iter:100]
train loss : 1.074821, [epoch:15, iter:200]
train loss : 1.084307, [epoch:15, iter:300]




train loss : 1.094718, [epoch:15, iter:400]
train loss : 1.101315, [epoch:15, iter:500]
train loss : 1.107743, [epoch:15, iter:600]
train loss : 1.110926, [epoch:15, iter:700]
train loss : 1.117188, [epoch:15, iter:800]
train loss : 1.122495, [epoch:15, iter:900]
TRAIN ACCURACY : 62.9% [38740/61632]


Train:  30%|███████▌                 | 15/50 [2:48:12<6:27:57, 665.08s/it]

TEST ACCURACY : 44.0% [5120/11648]

EPOCH:15 | train loss : 1.12497683174266, test loss : 1.8503057969795478

EPOCH 16 
-------------------------
train loss : 1.045979, [epoch:16, iter:0]
train loss : 1.070574, [epoch:16, iter:100]
train loss : 1.066800, [epoch:16, iter:200]
train loss : 1.077166, [epoch:16, iter:300]




train loss : 1.084672, [epoch:16, iter:400]
train loss : 1.084765, [epoch:16, iter:500]
train loss : 1.096924, [epoch:16, iter:600]




train loss : 1.105130, [epoch:16, iter:700]
train loss : 1.109505, [epoch:16, iter:800]
train loss : 1.115801, [epoch:16, iter:900]
TRAIN ACCURACY : 63.2% [38949/61632]


Train:  32%|████████                 | 16/50 [2:59:22<6:17:43, 666.57s/it]

TEST ACCURACY : 43.7% [5085/11648]

EPOCH:16 | train loss : 1.118119881160534, test loss : 1.894145525418795

EPOCH 17 
-------------------------
train loss : 0.736936, [epoch:17, iter:0]




train loss : 1.033812, [epoch:17, iter:100]
train loss : 1.047596, [epoch:17, iter:200]




train loss : 1.063191, [epoch:17, iter:300]
train loss : 1.074994, [epoch:17, iter:400]
train loss : 1.086891, [epoch:17, iter:500]
train loss : 1.093787, [epoch:17, iter:600]
train loss : 1.099772, [epoch:17, iter:700]
train loss : 1.103872, [epoch:17, iter:800]
train loss : 1.107562, [epoch:17, iter:900]
TRAIN ACCURACY : 63.3% [39028/61632]


Train:  34%|████████▌                | 17/50 [3:10:53<6:10:39, 673.94s/it]

TEST ACCURACY : 42.7% [4974/11648]

EPOCH:17 | train loss : 1.1090533863965844, test loss : 1.9016262954407996

EPOCH 18 
-------------------------
train loss : 1.048547, [epoch:18, iter:0]
train loss : 1.006583, [epoch:18, iter:100]
train loss : 1.019218, [epoch:18, iter:200]




train loss : 1.041133, [epoch:18, iter:300]
train loss : 1.053420, [epoch:18, iter:400]
train loss : 1.072063, [epoch:18, iter:500]
train loss : 1.076873, [epoch:18, iter:600]
train loss : 1.084679, [epoch:18, iter:700]
train loss : 1.089067, [epoch:18, iter:800]




train loss : 1.095963, [epoch:18, iter:900]
TRAIN ACCURACY : 63.9% [39364/61632]


Train:  36%|█████████                | 18/50 [3:22:06<5:59:18, 673.71s/it]

TEST ACCURACY : 44.8% [5218/11648]

EPOCH:18 | train loss : 1.099544215115679, test loss : 1.811018827852312

EPOCH 19 
-------------------------
train loss : 1.066224, [epoch:19, iter:0]
train loss : 1.003148, [epoch:19, iter:100]
train loss : 1.029757, [epoch:19, iter:200]
train loss : 1.051111, [epoch:19, iter:300]
train loss : 1.053009, [epoch:19, iter:400]
train loss : 1.055947, [epoch:19, iter:500]




train loss : 1.063562, [epoch:19, iter:600]
train loss : 1.071719, [epoch:19, iter:700]
train loss : 1.077475, [epoch:19, iter:800]
train loss : 1.082908, [epoch:19, iter:900]
TRAIN ACCURACY : 64.2% [39576/61632]


Train:  38%|█████████▌               | 19/50 [3:33:10<5:46:34, 670.79s/it]

TEST ACCURACY : 39.0% [4547/11648]

EPOCH:19 | train loss : 1.0866351310214026, test loss : 1.988170442345378

EPOCH 20 
-------------------------
train loss : 0.831658, [epoch:20, iter:0]




train loss : 0.987621, [epoch:20, iter:100]
train loss : 1.004146, [epoch:20, iter:200]
train loss : 1.019862, [epoch:20, iter:300]
train loss : 1.031255, [epoch:20, iter:400]
train loss : 1.043543, [epoch:20, iter:500]
train loss : 1.060842, [epoch:20, iter:600]
train loss : 1.063791, [epoch:20, iter:700]
train loss : 1.071011, [epoch:20, iter:800]




train loss : 1.077743, [epoch:20, iter:900]
TRAIN ACCURACY : 64.4% [39717/61632]


Train:  40%|██████████               | 20/50 [3:44:22<5:35:27, 670.91s/it]

TEST ACCURACY : 40.4% [4711/11648]

EPOCH:20 | train loss : 1.0799733617595424, test loss : 2.1595089632076223

EPOCH 21 
-------------------------
train loss : 1.263471, [epoch:21, iter:0]
train loss : 1.006890, [epoch:21, iter:100]
train loss : 1.016773, [epoch:21, iter:200]




train loss : 1.034425, [epoch:21, iter:300]
train loss : 1.046259, [epoch:21, iter:400]
train loss : 1.056746, [epoch:21, iter:500]
train loss : 1.060549, [epoch:21, iter:600]
train loss : 1.063013, [epoch:21, iter:700]




train loss : 1.065568, [epoch:21, iter:800]
train loss : 1.069314, [epoch:21, iter:900]
TRAIN ACCURACY : 64.8% [39914/61632]


Train:  42%|██████████▌              | 21/50 [3:55:19<5:22:18, 666.85s/it]

TEST ACCURACY : 40.8% [4757/11648]

EPOCH:21 | train loss : 1.0706028815122903, test loss : 1.9306557505995363

EPOCH 22 
-------------------------
train loss : 1.049485, [epoch:22, iter:0]


In [None]:
figure = plt.figure(figsize = (10,5))
rows, cols = 1, 2 

title = ['train loss', 'test loss']
values = [train_loss_lst, test_loss_lst]
for i in range(1, rows* cols +1):
    figure.add_subplot(rows,cols,i)
    plt.title(title[i-1])
    plt.plot(values[i-1])

plt.show()