# ResNet 18

## Configurations

### Install and import necessary libraries

In [1]:
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
from PIL import Image

import os
import random
from datetime import datetime

import torch
from torch.utils.data import Dataset, DataLoader,SubsetRandomSampler
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

import torchvision.models as models
from torchvision import transforms
from torchsummary import summary

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


#### GPU or not?

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

#### Path specified

In [3]:
DATA_DIR = 'data'
LABEL_DIR = 'label2.csv'
IMAGE_DIR = os.path.join(DATA_DIR, 'DaanForestPark')

### Hyperparameters

In [4]:
EPOCHES = 40
# FILTER_NUMS = 8
# FILTER_NUMS2 = 16
CHANNEL_NUMS = 3
# KERNEL_SIZE = 13
# STRIDE = 1
BATCH_SIZE = 32
NUM_WORKERS = 4
LR = 1e-2
VALID_RATIO = .2
PATIENCE = 3

In [5]:
# ToPILImage() -> Resize() -> ToTensor()
transform = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.3),
#         transforms.ColorJitter(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])        
        ])

### Define the DataSet and the DataLoader

In [6]:
class MyDataset(Dataset):
    def __init__(self, image_dir, label_dir, transform=None):
        _images, _labels = [], []
        # total amount of dataset 
        _number = 0
        # Reading the categorical file
        label_df = pd.read_csv(label_dir)
        
        # Iterate all files including .jpg inages  
        for subdir, dirs, files in tqdm(os.walk(image_dir)):
            for filename in files:
                corr_label = label_df[label_df['dirpath']==subdir[len(DATA_DIR)+1:]]['label'].values
                if corr_label.size!= 0 and filename.endswith(('jpg')):
                    _images.append(subdir + os.sep + filename)
                    _labels.append(corr_label)
                    _number+=1
        
        # Randomly arrange data pairs
        mapIndexPosition = list(zip(_images, _labels))
        random.shuffle(mapIndexPosition)
        _images, _labels = zip(*mapIndexPosition)

        self._image = iter(_images)
        self._labels = iter(_labels)
        self._number = _number
        self._category = label_df['label'].nunique()
        self.transform = transform
        
    def __len__(self):
        return self._number

    def __getitem__(self, index):    
        img = next(self._image)
        lab = next(self._labels)
        
        img = self._loadimage(img)
        if self.transform:
            img = self.transform(img)        
        return img, lab
     
    def _categorical(self, label):
        return np.arange(self._category) == label[:,None]
    
    def _loadimage(self, file):
        return Image.open(file).convert('RGB')
    
    def get_categorical_nums(self):
        return self._category
    

In [7]:
train_dataset = MyDataset(IMAGE_DIR, LABEL_DIR, transform=transform)

valid_size = VALID_RATIO
num_train = len(train_dataset)
indices = list(range(num_train))
split = int(np.floor(valid_size * num_train))
train_idx, valid_idx = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, sampler=train_sampler, drop_last=True)
valid_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, sampler=valid_sampler, drop_last=True)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [8]:
# def init_weights(m):
#     if type(m) == nn.Conv2d:
#         nn.init.xavier_uniform(m.weight)
#         m.bias.data.fill_(0.01)
#     if type(m) == nn.Linear:
#         nn.init.uniform_(m.weight)
#         m.bias.data.fill_(0.01)   
# class SimpleCNN(nn.Module):
    
#     def __init__(self, target):
#         super(SimpleCNN, self).__init__()
#         # Input size (3, 1136, 640)
# #         self.imgzipper  = nn.AvgPool2d(kernel_size=ZIPSIZE)
#         # Input size (3, 284, 160)
#         self.conv1 = nn.Sequential(
#             nn.Conv2d(in_channels=CHANNEL_NUMS,
#                         out_channels=FILTER_NUMS,
#                         kernel_size=KERNEL_SIZE,
#                         stride=STRIDE,
#                         padding=(KERNEL_SIZE-STRIDE)//2 # padding=(kernel_size-stride)/2 -> original size
#                     ),
#             nn.Dropout(0.5),
#             nn.ReLU(),
#             # (8, 1136, 640)
#             nn.MaxPool2d(kernel_size=KERNEL_SIZE)
#             # (8, 87, 49)
#             # zipper (8, 21, 12)
#         ).apply(init_weights)
        
#         # (8, 87, 49)
#         self.conv2 = nn.Sequential(
#             nn.Conv2d(in_channels=FILTER_NUMS,
#                         out_channels=FILTER_NUMS2,
#                         kernel_size=KERNEL_SIZE,
#                         stride=STRIDE,
#                         padding=(KERNEL_SIZE-STRIDE)//2 # padding=(kernel_size-stride)/2 -> original size
#                     ),
#             nn.Dropout(0.5),
#             nn.ReLU(),
#             # (16, 87, 49)
#             nn.MaxPool2d(kernel_size=5)
#             # (16, 6, 3)
#         ).apply(init_weights)
#         self.MLP = nn.Sequential(
#             nn.Linear(128, 81),
#             nn.ReLU(),
#             nn.Linear(81, 81),
#             nn.ReLU(),
#             nn.Linear(81, target)
#         ).apply(init_weights)
#     def forward(self, x):
#         x = self.imgzipper(x)
#         x = self.conv1(x)
#         x = self.conv2(x)
#         x = x.view(x.size(0), -1)
#         x = self.MLP(x)
#         return x

### ResNet18

In [8]:
# model = SimpleCNN(train_dataset.get_categorical_nums()).to(device)
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, train_dataset.get_categorical_nums())

In [None]:
""" 分層設定 lr """
# large_lr_layers = list(map(id,model.fc.parameters()))
# small_lr_layers = filter(lambda p:id(p) not in large_lr_layers,model.parameters())
# optimizer = torch.optim.SGD([
#             {"params":large_lr_layers},
#             {"params":small_lr_layers,"lr":1e-4}
#             ],lr = 1e-2,momenum=0.9)

In [9]:
# channels, H, W
model = model_ft.to(device=device)
summary(model, input_size=(CHANNEL_NUMS, 340, 192))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 64, 170, 96]           9,408
       BatchNorm2d-2          [-1, 64, 170, 96]             128
              ReLU-3          [-1, 64, 170, 96]               0
         MaxPool2d-4           [-1, 64, 85, 48]               0
            Conv2d-5           [-1, 64, 85, 48]          36,864
       BatchNorm2d-6           [-1, 64, 85, 48]             128
              ReLU-7           [-1, 64, 85, 48]               0
            Conv2d-8           [-1, 64, 85, 48]          36,864
       BatchNorm2d-9           [-1, 64, 85, 48]             128
             ReLU-10           [-1, 64, 85, 48]               0
       BasicBlock-11           [-1, 64, 85, 48]               0
           Conv2d-12           [-1, 64, 85, 48]          36,864
      BatchNorm2d-13           [-1, 64, 85, 48]             128
             ReLU-14           [-1, 64,

### learning Rate dynamically tunning

In [None]:
def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = LR * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
        
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

In [11]:
import shutil 
def save_checkpoint(model, state, filename='runs/ckpt/model.ckpt'):
    torch.save(state, filename)
    torch.save(model, 'model_best.ckpt')
#     shutil.copyfile(filename, 'model_best.ckpt')

In [12]:
optimizer = torch.optim.Adam(model.parameters(), lr=LR)   # optimize all cnn parameters
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=0, verbose=True)
criterion = nn.CrossEntropyLoss().to(device=device)

# early stopping
min_val_loss = np.Inf
patience = PATIENCE
global_step = 1
# build the writer file
write_file = 'runs/experiment_{}'.format(datetime.now().strftime('%f'))
writer = SummaryWriter(write_file)
os.mkdir(write_file + '/ckpt')

model_ft.train()

for epoch in range(EPOCHES):
    for i, (img_batch, label_batch) in tqdm(enumerate(train_loader)):    
        optimizer.zero_grad()      
        img_batch = img_batch.to(device=device)
        label_batch = label_batch.to(device=device)  
        output = model_ft(img_batch)
        loss = criterion(output, label_batch.squeeze())
        
        # l2 Regularization loss
        l1_regularization = 0
        l2_regularization = 0
        for p in model.parameters():
            l1_regularization += torch.norm(p, 1)
            l2_regularization += torch.norm(p, 2)
        loss = loss + 1e-3 * l2_regularization

        loss.backward()
        # clip the grandient value for avoiding explosion
        nn.utils.clip_grad_norm_(model.parameters(), 0.9) 
        optimizer.step()

        # Compute accuracy
        _, predicted = torch.max(output.cpu().data, 1)
        accuracy = torch.sum(predicted == label_batch.cpu().data.view(-1), dtype=torch.float32) / BATCH_SIZE
        
        # Write tensorboard
        writer.add_scalar('train/Accuracy', accuracy.item(), global_step)
        writer.add_scalar('train/Loss', loss.item(), global_step)
        writer.add_scalar('train/L1RegLoss', l1_regularization.item(), global_step)
        writer.add_scalar('train/L2RegLoss', l2_regularization.item(), global_step)
        writer.add_scalar('train/LR', get_lr(optimizer), global_step)
                
        global_step += 1
        
        if i % 50== 0:
            print('epoch {}, step {}, \
            total_loss={:.3f}, \
            accuracy={:.3f}'.format(epoch+1, i, loss.item(), accuracy.item()))
    
    
    print('--- Validation phase ---')
    eval_loss = 0
    with torch.no_grad():
        for i, (img_batch, label_batch) in enumerate(valid_loader):
            output = model(img_batch.to(device))
            _, predicted = torch.max(output.cpu().data, 1)
            loss = criterion(output, label_batch.to(device).squeeze())
            accuracy = torch.sum(predicted == label_batch.data.view(-1), dtype=torch.float32) / BATCH_SIZE
            eval_loss += loss.item()
            
            # Write tensorboard
            for cat in range(train_dataset.get_categorical_nums()):
                writer.add_pr_curve('valid/pr_curve', (label_batch == cat).int().squeeze(),
                                    (predicted == cat).int().squeeze(), epoch*len(valid_loader)+i)
            
            writer.add_images('valid/image_batch', img_batch, epoch*len(valid_loader)+i)
            writer.add_scalar('valid/Accuracy', accuracy.item(), epoch*len(valid_loader)+i)
            writer.add_scalar('valid/Loss', loss.item(), epoch*len(valid_loader)+i)
    
    eval_loss = eval_loss / len(valid_loader)
    
    scheduler.step(eval_loss)
    
    print('epoch {}, val_loss={:.3f}'.format(epoch+1, eval_loss))

    ## Early Stopping
    if eval_loss < min_val_loss:
        save_checkpoint(model, 
            {
            'epoch': epoch+1,
            'state_dict': model.state_dict(),
            'best_loss': eval_loss,
            'optimizer' :optimizer.state_dict(),
            }, os.path.join(write_file, 'ckpt/resNet_{}.ckpt'.format(epoch+1)))
        min_val_loss = eval_loss
    else:
        patience-=1
    if patience == 0:
        print('Early stopping')
        break

writer.close()
print('Finish all training !')

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 1, step 0,             total_loss=5.353,             accuracy=0.000
epoch 1, step 50,             total_loss=6.396,             accuracy=0.000
epoch 1, step 100,             total_loss=5.273,             accuracy=0.062
epoch 1, step 150,             total_loss=4.536,             accuracy=0.156
epoch 1, step 200,             total_loss=4.746,             accuracy=0.062
epoch 1, step 250,             total_loss=4.282,             accuracy=0.125
epoch 1, step 300,             total_loss=4.505,             accuracy=0.094
epoch 1, step 350,             total_loss=4.209,             accuracy=0.125
epoch 1, step 400,             total_loss=5.827,             accuracy=0.062
epoch 1, step 450,             total_loss=4.463,             accuracy=0.125
epoch 1, step 500,             total_loss=4.614,             accuracy=0.156

--- Validation phase ---
epoch 1, val_loss=3.739


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 2, step 0,             total_loss=4.097,             accuracy=0.125
epoch 2, step 50,             total_loss=4.061,             accuracy=0.125
epoch 2, step 100,             total_loss=4.057,             accuracy=0.188
epoch 2, step 150,             total_loss=3.911,             accuracy=0.125
epoch 2, step 200,             total_loss=4.178,             accuracy=0.156
epoch 2, step 250,             total_loss=3.109,             accuracy=0.250
epoch 2, step 300,             total_loss=3.945,             accuracy=0.125
epoch 2, step 350,             total_loss=3.530,             accuracy=0.156
epoch 2, step 400,             total_loss=4.281,             accuracy=0.219
epoch 2, step 450,             total_loss=3.473,             accuracy=0.281
epoch 2, step 500,             total_loss=3.630,             accuracy=0.250

--- Validation phase ---
epoch 2, val_loss=2.788


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 3, step 0,             total_loss=3.203,             accuracy=0.281
epoch 3, step 50,             total_loss=3.462,             accuracy=0.156
epoch 3, step 100,             total_loss=3.686,             accuracy=0.219
epoch 3, step 150,             total_loss=2.847,             accuracy=0.344
epoch 3, step 200,             total_loss=2.884,             accuracy=0.438
epoch 3, step 250,             total_loss=2.589,             accuracy=0.406
epoch 3, step 300,             total_loss=3.098,             accuracy=0.219
epoch 3, step 350,             total_loss=3.077,             accuracy=0.281
epoch 3, step 400,             total_loss=3.606,             accuracy=0.312
epoch 3, step 450,             total_loss=2.767,             accuracy=0.531
epoch 3, step 500,             total_loss=3.216,             accuracy=0.281

--- Validation phase ---
epoch 3, val_loss=2.164


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 4, step 0,             total_loss=2.962,             accuracy=0.375
epoch 4, step 50,             total_loss=2.386,             accuracy=0.469
epoch 4, step 100,             total_loss=2.998,             accuracy=0.281
epoch 4, step 150,             total_loss=2.178,             accuracy=0.594
epoch 4, step 200,             total_loss=2.191,             accuracy=0.562
epoch 4, step 250,             total_loss=2.136,             accuracy=0.531
epoch 4, step 300,             total_loss=2.686,             accuracy=0.406
epoch 4, step 350,             total_loss=2.583,             accuracy=0.344
epoch 4, step 400,             total_loss=3.025,             accuracy=0.469
epoch 4, step 450,             total_loss=2.763,             accuracy=0.438
epoch 4, step 500,             total_loss=2.919,             accuracy=0.375

--- Validation phase ---
epoch 4, val_loss=1.952


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 5, step 0,             total_loss=2.779,             accuracy=0.438
epoch 5, step 50,             total_loss=2.212,             accuracy=0.531
epoch 5, step 100,             total_loss=2.674,             accuracy=0.406
epoch 5, step 150,             total_loss=1.999,             accuracy=0.656
epoch 5, step 200,             total_loss=1.936,             accuracy=0.688
epoch 5, step 250,             total_loss=1.716,             accuracy=0.688
epoch 5, step 300,             total_loss=2.311,             accuracy=0.500
epoch 5, step 350,             total_loss=1.942,             accuracy=0.625
epoch 5, step 400,             total_loss=2.592,             accuracy=0.438
epoch 5, step 450,             total_loss=2.211,             accuracy=0.500
epoch 5, step 500,             total_loss=2.762,             accuracy=0.406

--- Validation phase ---
epoch 5, val_loss=1.627


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 6, step 0,             total_loss=2.390,             accuracy=0.500
epoch 6, step 50,             total_loss=2.016,             accuracy=0.688
epoch 6, step 100,             total_loss=2.451,             accuracy=0.375
epoch 6, step 150,             total_loss=2.000,             accuracy=0.531
epoch 6, step 200,             total_loss=1.918,             accuracy=0.562
epoch 6, step 250,             total_loss=1.490,             accuracy=0.812
epoch 6, step 300,             total_loss=2.229,             accuracy=0.594
epoch 6, step 350,             total_loss=2.140,             accuracy=0.594
epoch 6, step 400,             total_loss=2.344,             accuracy=0.438
epoch 6, step 450,             total_loss=2.143,             accuracy=0.531
epoch 6, step 500,             total_loss=2.430,             accuracy=0.406

--- Validation phase ---
epoch 6, val_loss=1.483


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 7, step 0,             total_loss=2.357,             accuracy=0.469
epoch 7, step 50,             total_loss=1.688,             accuracy=0.656
epoch 7, step 100,             total_loss=2.491,             accuracy=0.531
epoch 7, step 150,             total_loss=2.047,             accuracy=0.531
epoch 7, step 200,             total_loss=1.604,             accuracy=0.688
epoch 7, step 250,             total_loss=1.606,             accuracy=0.750
epoch 7, step 300,             total_loss=2.303,             accuracy=0.438
epoch 7, step 350,             total_loss=2.115,             accuracy=0.625
epoch 7, step 400,             total_loss=2.273,             accuracy=0.500
epoch 7, step 450,             total_loss=1.953,             accuracy=0.625
epoch 7, step 500,             total_loss=2.086,             accuracy=0.594

--- Validation phase ---
epoch 7, val_loss=1.290


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 8, step 0,             total_loss=2.234,             accuracy=0.625
epoch 8, step 50,             total_loss=1.894,             accuracy=0.594
epoch 8, step 100,             total_loss=2.111,             accuracy=0.594
epoch 8, step 150,             total_loss=1.623,             accuracy=0.750
epoch 8, step 200,             total_loss=1.873,             accuracy=0.656
epoch 8, step 250,             total_loss=1.462,             accuracy=0.844
epoch 8, step 300,             total_loss=1.982,             accuracy=0.594
epoch 8, step 350,             total_loss=2.143,             accuracy=0.531
epoch 8, step 400,             total_loss=2.232,             accuracy=0.531
epoch 8, step 450,             total_loss=1.685,             accuracy=0.656
epoch 8, step 500,             total_loss=1.919,             accuracy=0.562

--- Validation phase ---
epoch 8, val_loss=1.180


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 9, step 0,             total_loss=2.189,             accuracy=0.594
epoch 9, step 50,             total_loss=1.905,             accuracy=0.625
epoch 9, step 100,             total_loss=1.760,             accuracy=0.531
epoch 9, step 150,             total_loss=1.966,             accuracy=0.594
epoch 9, step 200,             total_loss=1.550,             accuracy=0.688
epoch 9, step 250,             total_loss=1.371,             accuracy=0.844
epoch 9, step 300,             total_loss=1.879,             accuracy=0.625
epoch 9, step 350,             total_loss=1.643,             accuracy=0.688
epoch 9, step 400,             total_loss=2.173,             accuracy=0.500
epoch 9, step 450,             total_loss=1.699,             accuracy=0.625
epoch 9, step 500,             total_loss=1.927,             accuracy=0.719

--- Validation phase ---
epoch 9, val_loss=1.173


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 10, step 0,             total_loss=1.841,             accuracy=0.625
epoch 10, step 50,             total_loss=1.492,             accuracy=0.688
epoch 10, step 100,             total_loss=2.261,             accuracy=0.500
epoch 10, step 150,             total_loss=1.740,             accuracy=0.750
epoch 10, step 200,             total_loss=1.287,             accuracy=0.812
epoch 10, step 250,             total_loss=1.535,             accuracy=0.719
epoch 10, step 300,             total_loss=2.094,             accuracy=0.531
epoch 10, step 350,             total_loss=1.948,             accuracy=0.719
epoch 10, step 400,             total_loss=2.212,             accuracy=0.562
epoch 10, step 450,             total_loss=1.412,             accuracy=0.812
epoch 10, step 500,             total_loss=1.937,             accuracy=0.625

--- Validation phase ---
epoch 10, val_loss=1.098


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 11, step 0,             total_loss=1.511,             accuracy=0.688
epoch 11, step 50,             total_loss=1.387,             accuracy=0.812
epoch 11, step 100,             total_loss=2.054,             accuracy=0.656
epoch 11, step 150,             total_loss=1.495,             accuracy=0.812
epoch 11, step 200,             total_loss=1.565,             accuracy=0.656
epoch 11, step 250,             total_loss=1.239,             accuracy=0.844
epoch 11, step 300,             total_loss=1.802,             accuracy=0.656
epoch 11, step 350,             total_loss=1.556,             accuracy=0.688
epoch 11, step 400,             total_loss=2.251,             accuracy=0.594
epoch 11, step 450,             total_loss=1.459,             accuracy=0.781
epoch 11, step 500,             total_loss=1.792,             accuracy=0.656

--- Validation phase ---
epoch 11, val_loss=1.011


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 12, step 0,             total_loss=1.820,             accuracy=0.594
epoch 12, step 50,             total_loss=1.294,             accuracy=0.875
epoch 12, step 100,             total_loss=2.104,             accuracy=0.531
epoch 12, step 150,             total_loss=1.602,             accuracy=0.625
epoch 12, step 200,             total_loss=1.555,             accuracy=0.750
epoch 12, step 250,             total_loss=1.620,             accuracy=0.688
epoch 12, step 300,             total_loss=1.801,             accuracy=0.719
epoch 12, step 350,             total_loss=1.660,             accuracy=0.688
epoch 12, step 400,             total_loss=2.322,             accuracy=0.719
epoch 12, step 450,             total_loss=1.821,             accuracy=0.625
epoch 12, step 500,             total_loss=1.477,             accuracy=0.750

--- Validation phase ---
epoch 12, val_loss=0.863


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 13, step 0,             total_loss=1.535,             accuracy=0.688
epoch 13, step 50,             total_loss=1.376,             accuracy=0.781
epoch 13, step 100,             total_loss=2.153,             accuracy=0.656
epoch 13, step 150,             total_loss=1.372,             accuracy=0.875
epoch 13, step 200,             total_loss=1.408,             accuracy=0.812
epoch 13, step 250,             total_loss=1.179,             accuracy=0.844
epoch 13, step 300,             total_loss=2.202,             accuracy=0.531
epoch 13, step 350,             total_loss=1.659,             accuracy=0.688
epoch 13, step 400,             total_loss=1.670,             accuracy=0.625
epoch 13, step 450,             total_loss=1.271,             accuracy=0.844
epoch 13, step 500,             total_loss=1.447,             accuracy=0.812

--- Validation phase ---
Epoch    13: reducing learning rate of group 0 to 9.0000e-03.
epoch 13, val_loss=0.869


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 14, step 0,             total_loss=1.728,             accuracy=0.625
epoch 14, step 50,             total_loss=1.562,             accuracy=0.719
epoch 14, step 100,             total_loss=1.707,             accuracy=0.688
epoch 14, step 150,             total_loss=1.639,             accuracy=0.781
epoch 14, step 200,             total_loss=1.502,             accuracy=0.750
epoch 14, step 250,             total_loss=1.336,             accuracy=0.812
epoch 14, step 300,             total_loss=1.722,             accuracy=0.750
epoch 14, step 350,             total_loss=1.512,             accuracy=0.781
epoch 14, step 400,             total_loss=2.132,             accuracy=0.625
epoch 14, step 450,             total_loss=1.285,             accuracy=0.875
epoch 14, step 500,             total_loss=1.537,             accuracy=0.812

--- Validation phase ---
epoch 14, val_loss=0.802


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 15, step 0,             total_loss=1.566,             accuracy=0.656
epoch 15, step 50,             total_loss=1.197,             accuracy=0.875
epoch 15, step 100,             total_loss=1.967,             accuracy=0.719
epoch 15, step 150,             total_loss=1.242,             accuracy=0.875
epoch 15, step 200,             total_loss=1.530,             accuracy=0.781
epoch 15, step 250,             total_loss=1.317,             accuracy=0.781
epoch 15, step 300,             total_loss=1.662,             accuracy=0.656
epoch 15, step 350,             total_loss=1.775,             accuracy=0.688
epoch 15, step 400,             total_loss=1.493,             accuracy=0.656
epoch 15, step 450,             total_loss=1.522,             accuracy=0.750
epoch 15, step 500,             total_loss=1.288,             accuracy=0.812

--- Validation phase ---
Epoch    15: reducing learning rate of group 0 to 8.1000e-03.
epoch 15, val_loss=0.844


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 16, step 0,             total_loss=1.925,             accuracy=0.625
epoch 16, step 50,             total_loss=1.356,             accuracy=0.812
epoch 16, step 100,             total_loss=1.761,             accuracy=0.750
epoch 16, step 150,             total_loss=1.400,             accuracy=0.781
epoch 16, step 200,             total_loss=1.309,             accuracy=0.875
epoch 16, step 250,             total_loss=1.303,             accuracy=0.781
epoch 16, step 300,             total_loss=1.236,             accuracy=0.844
epoch 16, step 350,             total_loss=1.424,             accuracy=0.781
epoch 16, step 400,             total_loss=2.118,             accuracy=0.688
epoch 16, step 450,             total_loss=1.250,             accuracy=0.781
epoch 16, step 500,             total_loss=1.267,             accuracy=0.812

--- Validation phase ---
epoch 16, val_loss=0.722


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 17, step 0,             total_loss=1.590,             accuracy=0.719
epoch 17, step 50,             total_loss=1.346,             accuracy=0.812
epoch 17, step 100,             total_loss=1.609,             accuracy=0.719
epoch 17, step 150,             total_loss=1.611,             accuracy=0.781
epoch 17, step 200,             total_loss=1.430,             accuracy=0.781
epoch 17, step 250,             total_loss=1.061,             accuracy=0.906
epoch 17, step 300,             total_loss=1.519,             accuracy=0.719
epoch 17, step 350,             total_loss=1.416,             accuracy=0.781
epoch 17, step 400,             total_loss=1.706,             accuracy=0.750
epoch 17, step 450,             total_loss=1.226,             accuracy=0.781
epoch 17, step 500,             total_loss=1.530,             accuracy=0.719

--- Validation phase ---
epoch 17, val_loss=0.694


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 18, step 0,             total_loss=1.362,             accuracy=0.812
epoch 18, step 50,             total_loss=1.211,             accuracy=0.844
epoch 18, step 100,             total_loss=1.444,             accuracy=0.844
epoch 18, step 150,             total_loss=1.739,             accuracy=0.719
epoch 18, step 200,             total_loss=1.345,             accuracy=0.812
epoch 18, step 250,             total_loss=1.378,             accuracy=0.781
epoch 18, step 300,             total_loss=1.593,             accuracy=0.719
epoch 18, step 350,             total_loss=1.448,             accuracy=0.812
epoch 18, step 400,             total_loss=1.690,             accuracy=0.750
epoch 18, step 450,             total_loss=1.117,             accuracy=0.938
epoch 18, step 500,             total_loss=1.331,             accuracy=0.812

--- Validation phase ---
Epoch    18: reducing learning rate of group 0 to 7.2900e-03.
epoch 18, val_loss=0.747
Early stopping
Finish all training !


In [13]:
# CKPT_PATH = 'model_best.ckpt'
# model.load_state_dict(torch.load(CKPT_PATH)['state_dict'])
model.eval()
acc = 0
for i, (img_batch, label_batch) in enumerate(valid_loader):
    output = model(img_batch.to(device))
    _, predicted = torch.max(output.cpu().data, 1)
    accuracy = torch.sum(predicted == label_batch.data.view(-1), dtype=torch.float32) / BATCH_SIZE
    acc += accuracy
print('accuracy={}'.format(acc/len(valid_loader)))

accuracy=0.8102189898490906


In [20]:
# torch.save(model, 'whole_model.ckpt')