# Convolution2d

## Configurations

### Install and import necessary libraries

In [1]:
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
from PIL import Image

import os
import random
from datetime import datetime

import torch
from torch.utils.data import Dataset, DataLoader,SubsetRandomSampler
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

import torchvision.models as models
from torchvision import transforms
from torchsummary import summary

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


#### GPU or not?

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

#### Path specified

In [3]:
DATA_DIR = 'data'
LABEL_DIR = 'label2.csv'
IMAGE_DIR = os.path.join(DATA_DIR, 'DaanForestPark')

### Hyperparameters

In [4]:
EPOCHES = 20
FILTER_NUMS = 8
FILTER_NUMS2 = 16
CHANNEL_NUMS = 3
KERNEL_SIZE = 13
STRIDE = 1
BATCH_SIZE = 32
NUM_WORKERS = 4
LR = 1e-2

In [5]:
# ToPILImage() -> Resize() -> ToTensor()
transform = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.3),
#         transforms.ColorJitter(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])        
        ])

### Define the DataSet and the DataLoader

In [6]:
class MyDataset(Dataset):
    def __init__(self, image_dir, label_dir, transform=None):
        _images, _labels = [], []
        # total amount of dataset 
        _number = 0
        # Reading the categorical file
        label_df = pd.read_csv(label_dir)
        
        # Iterate all files including .jpg inages  
        for subdir, dirs, files in tqdm(os.walk(image_dir)):
            for filename in files:
                corr_label = label_df[label_df['dirpath']==subdir[len(DATA_DIR)+1:]]['label'].values
                if corr_label.size!= 0 and filename.endswith(('jpg')):
                    _images.append(subdir + os.sep + filename)
                    _labels.append(corr_label)
                    _number+=1
        
        # Randomly arrange data pairs
        mapIndexPosition = list(zip(_images, _labels))
        random.shuffle(mapIndexPosition)
        _images, _labels = zip(*mapIndexPosition)

        self._image = iter(_images)
        self._labels = iter(_labels)
        self._number = _number
        self._category = label_df['label'].nunique()
        self.transform = transform
        
    def __len__(self):
        return self._number

    def __getitem__(self, index):    
        img = next(self._image)
        lab = next(self._labels)
        
        img = self._loadimage(img)
        if self.transform:
            img = self.transform(img)        
        return img, lab
     
    def _categorical(self, label):
        return np.arange(self._category) == label[:,None]
    
    def _loadimage(self, file):
        return Image.open(file).convert('RGB')
    
    def get_categorical_nums(self):
        return self._category
    

In [7]:
train_dataset = MyDataset(IMAGE_DIR, LABEL_DIR, transform=transform)

valid_size = .2
num_train = len(train_dataset)
indices = list(range(num_train))
split = int(np.floor(valid_size * num_train))
train_idx, valid_idx = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, sampler=train_sampler, drop_last=True)
valid_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, sampler=valid_sampler, drop_last=True)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [8]:
# def init_weights(m):
#     if type(m) == nn.Conv2d:
#         nn.init.xavier_uniform(m.weight)
#         m.bias.data.fill_(0.01)
#     if type(m) == nn.Linear:
#         nn.init.uniform_(m.weight)
#         m.bias.data.fill_(0.01)   
# class SimpleCNN(nn.Module):
    
#     def __init__(self, target):
#         super(SimpleCNN, self).__init__()
#         # Input size (3, 1136, 640)
# #         self.imgzipper  = nn.AvgPool2d(kernel_size=ZIPSIZE)
#         # Input size (3, 284, 160)
#         self.conv1 = nn.Sequential(
#             nn.Conv2d(in_channels=CHANNEL_NUMS,
#                         out_channels=FILTER_NUMS,
#                         kernel_size=KERNEL_SIZE,
#                         stride=STRIDE,
#                         padding=(KERNEL_SIZE-STRIDE)//2 # padding=(kernel_size-stride)/2 -> original size
#                     ),
#             nn.Dropout(0.5),
#             nn.ReLU(),
#             # (8, 1136, 640)
#             nn.MaxPool2d(kernel_size=KERNEL_SIZE)
#             # (8, 87, 49)
#             # zipper (8, 21, 12)
#         ).apply(init_weights)
        
#         # (8, 87, 49)
#         self.conv2 = nn.Sequential(
#             nn.Conv2d(in_channels=FILTER_NUMS,
#                         out_channels=FILTER_NUMS2,
#                         kernel_size=KERNEL_SIZE,
#                         stride=STRIDE,
#                         padding=(KERNEL_SIZE-STRIDE)//2 # padding=(kernel_size-stride)/2 -> original size
#                     ),
#             nn.Dropout(0.5),
#             nn.ReLU(),
#             # (16, 87, 49)
#             nn.MaxPool2d(kernel_size=5)
#             # (16, 6, 3)
#         ).apply(init_weights)
#         self.MLP = nn.Sequential(
#             nn.Linear(128, 81),
#             nn.ReLU(),
#             nn.Linear(81, 81),
#             nn.ReLU(),
#             nn.Linear(81, target)
#         ).apply(init_weights)
#     def forward(self, x):
#         x = self.imgzipper(x)
#         x = self.conv1(x)
#         x = self.conv2(x)
#         x = x.view(x.size(0), -1)
#         x = self.MLP(x)
#         return x

### ResNet18

In [16]:
# model = SimpleCNN(train_dataset.get_categorical_nums()).to(device)
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, train_dataset.get_categorical_nums())

In [17]:
# channels, H, W
model = model_ft.to(device=device)
summary(model, input_size=(CHANNEL_NUMS, 340, 192))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 64, 170, 96]           9,408
       BatchNorm2d-2          [-1, 64, 170, 96]             128
              ReLU-3          [-1, 64, 170, 96]               0
         MaxPool2d-4           [-1, 64, 85, 48]               0
            Conv2d-5           [-1, 64, 85, 48]          36,864
       BatchNorm2d-6           [-1, 64, 85, 48]             128
              ReLU-7           [-1, 64, 85, 48]               0
            Conv2d-8           [-1, 64, 85, 48]          36,864
       BatchNorm2d-9           [-1, 64, 85, 48]             128
             ReLU-10           [-1, 64, 85, 48]               0
       BasicBlock-11           [-1, 64, 85, 48]               0
           Conv2d-12           [-1, 64, 85, 48]          36,864
      BatchNorm2d-13           [-1, 64, 85, 48]             128
             ReLU-14           [-1, 64,

In [11]:
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

In [12]:
import shutil 
def save_checkpoint(state, filename='ckpt/experiment/model.ckpt'):
    torch.save(state, filename)
    shutil.copyfile(filename, 'model_best.ckpt')

In [14]:
optimizer = torch.optim.Adam(model.parameters(), lr=LR)   # optimize all cnn parameters
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=0, verbose=True)
criterion = nn.CrossEntropyLoss().to(device=device)

# early stopping
min_val_loss = np.Inf
patience = 3
global_step = 1
writer = SummaryWriter('runs/experiment_{}'.format(datetime.now().strftime('%f')))

model_ft.train()

for epoch in range(EPOCHES):
    for i, (img_batch, label_batch) in tqdm(enumerate(train_loader)):    
        optimizer.zero_grad()      
        img_batch = img_batch.to(device=device)
        label_batch = label_batch.to(device=device)  
        output = model_ft(img_batch)
        loss = criterion(output, label_batch.squeeze())
        
        # l2 Regularization loss
        l2_regularization = 0
        for p in model.parameters():
            l2_regularization += torch.norm(p, 2)
#         loss = loss + 0.1 * l2_regularization

        loss.backward()
        # clip the grandient value for avoiding explosion
        nn.utils.clip_grad_norm_(model.parameters(), 0.9) 
        optimizer.step()

        # Compute accuracy
        _, predicted = torch.max(output.cpu().data, 1)
        accuracy = torch.sum(predicted == label_batch.cpu().data.view(-1), dtype=torch.float32) / BATCH_SIZE
        
        # Write tensorboard
        writer.add_scalar('train/Accuracy', accuracy.item(), global_step)
        writer.add_scalar('train/Loss', loss.item(), global_step)
        writer.add_scalar('train/L2RegLoss', l2_regularization.item(), global_step)
        writer.add_scalar('train/LR', get_lr(optimizer), global_step)
                
        global_step += 1
        
        if i % 50== 0:
            print('epoch {}, step {}, \
            l2_loss={:.3f}, total_loss={:.3f}, \
            accuracy={:.3f}'.format(epoch+1, i, l2_regularization.item() , loss.item(), accuracy.item()))
    
    
    print('--- Validation phase ---')
    eval_loss = 0
    with torch.no_grad():
        for i, (img_batch, label_batch) in enumerate(valid_loader):
            output = model(img_batch.to(device))
            _, predicted = torch.max(output.cpu().data, 1)
            loss = criterion(output, label_batch.to(device).squeeze())
            accuracy = torch.sum(predicted == label_batch.data.view(-1), dtype=torch.float32) / BATCH_SIZE
            eval_loss += loss.item()
            
            # Write tensorboard
            writer.add_pr_curve('valid/pr_curve', label_batch.squeeze(), predicted.squeeze(), epoch*len(valid_loader)+i)
            writer.add_images('valid/image_batch', img_batch, epoch*len(valid_loader)+i)
            writer.add_scalar('valid/Accuracy', accuracy.item(), epoch*len(valid_loader)+i)
            writer.add_scalar('valid/Loss', loss.item(), epoch*len(valid_loader)+i)
    
    eval_loss = eval_loss / len(valid_loader)
    
    scheduler.step(eval_loss)
    
    print('epoch {}, val_loss={:.3f}'.format(epoch+1, eval_loss))

    ## Early Stopping
    if eval_loss < min_val_loss:
        save_checkpoint({
            'epoch': epoch+1,
            'state_dict': model.state_dict(),
            'best_loss': eval_loss,
            'optimizer' :optimizer.state_dict(),
            }, 'ckpt/Transform/resNet_{}.ckpt'.format(epoch+1))
        min_val_loss = eval_loss
    else:
        patience-=1
    if patience == 0:
        print('Early stopping')
        break
        
writer.close()
print('Finish all training !')

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f01b3896d08>
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f01b3896d08>
Traceback (most recent call last):
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f01b3896d08>
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/silence/.pyenv/versions/3.7.2/envs/python-3.7.2/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 961, in __del__
    self._shutdown_workers()
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f01b3896d08>
  File "/home/silence/.pyenv/versions/3.7.2/envs/python-3.7.2/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 941, in _shutdown_workers
  File "/home/silence/.pyenv/versions/3.7.2/envs/python-3.7.2/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 961, in __del__
Traceback (most recent call last):
    self._shutdown_workers()


epoch 1, step 0,             l2_loss=498.729, total_loss=3.501,             accuracy=0.094
epoch 1, step 50,             l2_loss=1032.029, total_loss=4.039,             accuracy=0.000
epoch 1, step 100,             l2_loss=1153.462, total_loss=4.069,             accuracy=0.125
epoch 1, step 150,             l2_loss=1189.826, total_loss=4.079,             accuracy=0.062
epoch 1, step 200,             l2_loss=1209.372, total_loss=4.199,             accuracy=0.094
epoch 1, step 250,             l2_loss=1234.380, total_loss=3.965,             accuracy=0.094
epoch 1, step 300,             l2_loss=1281.459, total_loss=3.769,             accuracy=0.031
epoch 1, step 350,             l2_loss=1333.341, total_loss=3.288,             accuracy=0.188
epoch 1, step 400,             l2_loss=1382.245, total_loss=3.586,             accuracy=0.031
epoch 1, step 450,             l2_loss=1424.712, total_loss=3.621,             accuracy=0.125
epoch 1, step 500,             l2_loss=1463.885, total_loss=3.47

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f01b3896d08>
Traceback (most recent call last):
  File "/home/silence/.pyenv/versions/3.7.2/envs/python-3.7.2/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 961, in __del__
    self._shutdown_workers()
  File "/home/silence/.pyenv/versions/3.7.2/envs/python-3.7.2/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 941, in _shutdown_workers
    w.join()
  File "/home/silence/.pyenv/versions/3.7.2/lib/python3.7/multiprocessing/process.py", line 138, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f01b3896d08>
Traceback (most recent call last):
  File "/home/silence/.pyenv/versions/3.7.2/envs/python-3.7.2/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 961, in __del__
Exception ignored in: <function _MultiP

epoch 1, val_loss=3.281


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 2, step 0,             l2_loss=1505.968, total_loss=3.515,             accuracy=0.219
epoch 2, step 50,             l2_loss=1544.397, total_loss=2.982,             accuracy=0.312
epoch 2, step 100,             l2_loss=1584.044, total_loss=3.270,             accuracy=0.125
epoch 2, step 150,             l2_loss=1618.714, total_loss=3.341,             accuracy=0.094
epoch 2, step 200,             l2_loss=1678.672, total_loss=3.352,             accuracy=0.125
epoch 2, step 250,             l2_loss=1744.354, total_loss=3.354,             accuracy=0.125
epoch 2, step 300,             l2_loss=1788.474, total_loss=3.238,             accuracy=0.156
epoch 2, step 350,             l2_loss=1824.340, total_loss=2.883,             accuracy=0.250
epoch 2, step 400,             l2_loss=1868.870, total_loss=3.181,             accuracy=0.219
epoch 2, step 450,             l2_loss=1900.854, total_loss=3.291,             accuracy=0.219
epoch 2, step 500,             l2_loss=1933.787, total_loss=3.2

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 3, step 0,             l2_loss=1967.072, total_loss=2.860,             accuracy=0.281
epoch 3, step 50,             l2_loss=2001.847, total_loss=2.559,             accuracy=0.312
epoch 3, step 100,             l2_loss=2037.356, total_loss=2.853,             accuracy=0.125
epoch 3, step 150,             l2_loss=2067.986, total_loss=2.792,             accuracy=0.219
epoch 3, step 200,             l2_loss=2100.204, total_loss=2.726,             accuracy=0.281
epoch 3, step 250,             l2_loss=2127.115, total_loss=3.177,             accuracy=0.219
epoch 3, step 300,             l2_loss=2161.536, total_loss=2.530,             accuracy=0.312
epoch 3, step 350,             l2_loss=2187.573, total_loss=2.129,             accuracy=0.406
epoch 3, step 400,             l2_loss=2220.009, total_loss=2.484,             accuracy=0.250
epoch 3, step 450,             l2_loss=2245.386, total_loss=2.859,             accuracy=0.219
epoch 3, step 500,             l2_loss=2273.683, total_loss=2.2

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 4, step 0,             l2_loss=2301.588, total_loss=2.566,             accuracy=0.375
epoch 4, step 50,             l2_loss=2330.956, total_loss=2.095,             accuracy=0.375
epoch 4, step 100,             l2_loss=2356.584, total_loss=2.607,             accuracy=0.250
epoch 4, step 150,             l2_loss=2383.306, total_loss=2.315,             accuracy=0.375
epoch 4, step 200,             l2_loss=2415.496, total_loss=2.057,             accuracy=0.406
epoch 4, step 250,             l2_loss=2440.919, total_loss=2.753,             accuracy=0.250
epoch 4, step 300,             l2_loss=2472.382, total_loss=2.202,             accuracy=0.438
epoch 4, step 350,             l2_loss=2496.306, total_loss=1.871,             accuracy=0.469
epoch 4, step 400,             l2_loss=2521.280, total_loss=1.980,             accuracy=0.375
epoch 4, step 450,             l2_loss=2541.902, total_loss=2.566,             accuracy=0.281
epoch 4, step 500,             l2_loss=2569.971, total_loss=1.9

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 5, step 0,             l2_loss=2605.767, total_loss=2.331,             accuracy=0.375
epoch 5, step 50,             l2_loss=2634.507, total_loss=1.651,             accuracy=0.500
epoch 5, step 100,             l2_loss=2668.616, total_loss=2.166,             accuracy=0.312
epoch 5, step 150,             l2_loss=2691.591, total_loss=1.884,             accuracy=0.281
epoch 5, step 200,             l2_loss=2712.852, total_loss=1.884,             accuracy=0.438
epoch 5, step 250,             l2_loss=2735.174, total_loss=2.389,             accuracy=0.312
epoch 5, step 300,             l2_loss=2759.368, total_loss=2.114,             accuracy=0.406
epoch 5, step 350,             l2_loss=2780.916, total_loss=1.447,             accuracy=0.625
epoch 5, step 400,             l2_loss=2802.073, total_loss=1.799,             accuracy=0.344
epoch 5, step 450,             l2_loss=2830.143, total_loss=2.054,             accuracy=0.406
epoch 5, step 500,             l2_loss=2855.910, total_loss=2.0

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 6, step 0,             l2_loss=2882.608, total_loss=2.183,             accuracy=0.406
epoch 6, step 50,             l2_loss=2908.012, total_loss=1.517,             accuracy=0.594
epoch 6, step 100,             l2_loss=2927.029, total_loss=2.028,             accuracy=0.438
epoch 6, step 150,             l2_loss=2949.371, total_loss=1.462,             accuracy=0.594
epoch 6, step 200,             l2_loss=2967.582, total_loss=1.505,             accuracy=0.531
epoch 6, step 250,             l2_loss=2988.384, total_loss=1.960,             accuracy=0.438
epoch 6, step 300,             l2_loss=3016.040, total_loss=1.641,             accuracy=0.594
epoch 6, step 350,             l2_loss=3037.905, total_loss=1.215,             accuracy=0.594
epoch 6, step 400,             l2_loss=3058.981, total_loss=1.643,             accuracy=0.500
epoch 6, step 450,             l2_loss=3080.887, total_loss=1.478,             accuracy=0.438
epoch 6, step 500,             l2_loss=3104.628, total_loss=1.3

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 7, step 0,             l2_loss=3132.623, total_loss=1.731,             accuracy=0.469
epoch 7, step 50,             l2_loss=3155.863, total_loss=0.914,             accuracy=0.688
epoch 7, step 100,             l2_loss=3173.207, total_loss=1.670,             accuracy=0.469
epoch 7, step 150,             l2_loss=3195.291, total_loss=1.455,             accuracy=0.625
epoch 7, step 200,             l2_loss=3216.220, total_loss=1.212,             accuracy=0.656
epoch 7, step 250,             l2_loss=3236.025, total_loss=1.877,             accuracy=0.500
epoch 7, step 300,             l2_loss=3262.333, total_loss=1.667,             accuracy=0.594
epoch 7, step 350,             l2_loss=3281.756, total_loss=1.183,             accuracy=0.594
epoch 7, step 400,             l2_loss=3304.215, total_loss=1.463,             accuracy=0.500
epoch 7, step 450,             l2_loss=3326.282, total_loss=1.673,             accuracy=0.562
epoch 7, step 500,             l2_loss=3352.019, total_loss=1.6

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 8, step 0,             l2_loss=3380.879, total_loss=1.459,             accuracy=0.531
epoch 8, step 50,             l2_loss=3404.839, total_loss=0.944,             accuracy=0.719
epoch 8, step 100,             l2_loss=3421.717, total_loss=1.635,             accuracy=0.531
epoch 8, step 150,             l2_loss=3441.150, total_loss=1.417,             accuracy=0.562
epoch 8, step 200,             l2_loss=3463.680, total_loss=1.081,             accuracy=0.719
epoch 8, step 250,             l2_loss=3483.100, total_loss=1.648,             accuracy=0.469
epoch 8, step 300,             l2_loss=3503.276, total_loss=1.537,             accuracy=0.625
epoch 8, step 350,             l2_loss=3520.439, total_loss=0.885,             accuracy=0.656
epoch 8, step 400,             l2_loss=3536.891, total_loss=1.651,             accuracy=0.500
epoch 8, step 450,             l2_loss=3561.492, total_loss=1.736,             accuracy=0.500
epoch 8, step 500,             l2_loss=3578.805, total_loss=1.3

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 9, step 0,             l2_loss=3605.349, total_loss=1.420,             accuracy=0.562
epoch 9, step 50,             l2_loss=3624.835, total_loss=0.796,             accuracy=0.781
epoch 9, step 100,             l2_loss=3641.301, total_loss=1.197,             accuracy=0.594
epoch 9, step 150,             l2_loss=3656.838, total_loss=1.064,             accuracy=0.688
epoch 9, step 200,             l2_loss=3675.490, total_loss=1.199,             accuracy=0.688
epoch 9, step 250,             l2_loss=3691.177, total_loss=1.433,             accuracy=0.469
epoch 9, step 300,             l2_loss=3711.055, total_loss=1.400,             accuracy=0.562
epoch 9, step 350,             l2_loss=3725.696, total_loss=0.777,             accuracy=0.781
epoch 9, step 400,             l2_loss=3739.050, total_loss=1.395,             accuracy=0.625
epoch 9, step 450,             l2_loss=3758.344, total_loss=0.986,             accuracy=0.719
epoch 9, step 500,             l2_loss=3774.152, total_loss=1.4

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 10, step 0,             l2_loss=3794.368, total_loss=1.120,             accuracy=0.688
epoch 10, step 50,             l2_loss=3814.518, total_loss=0.669,             accuracy=0.719
epoch 10, step 100,             l2_loss=3832.568, total_loss=1.404,             accuracy=0.562
epoch 10, step 150,             l2_loss=3848.726, total_loss=0.864,             accuracy=0.750
epoch 10, step 200,             l2_loss=3868.152, total_loss=1.058,             accuracy=0.625
epoch 10, step 250,             l2_loss=3884.509, total_loss=0.966,             accuracy=0.719
epoch 10, step 300,             l2_loss=3900.419, total_loss=1.065,             accuracy=0.688
epoch 10, step 350,             l2_loss=3915.320, total_loss=0.669,             accuracy=0.812
epoch 10, step 400,             l2_loss=3928.321, total_loss=1.196,             accuracy=0.812
epoch 10, step 450,             l2_loss=3947.436, total_loss=0.956,             accuracy=0.656
epoch 10, step 500,             l2_loss=3963.096, tot

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 11, step 0,             l2_loss=3983.469, total_loss=0.943,             accuracy=0.688
epoch 11, step 50,             l2_loss=4000.359, total_loss=0.737,             accuracy=0.719
epoch 11, step 100,             l2_loss=4015.674, total_loss=1.100,             accuracy=0.656
epoch 11, step 150,             l2_loss=4029.589, total_loss=0.802,             accuracy=0.656
epoch 11, step 200,             l2_loss=4046.073, total_loss=1.144,             accuracy=0.594
epoch 11, step 250,             l2_loss=4059.652, total_loss=0.811,             accuracy=0.688
epoch 11, step 300,             l2_loss=4079.741, total_loss=0.942,             accuracy=0.719
epoch 11, step 350,             l2_loss=4092.622, total_loss=0.545,             accuracy=0.906
epoch 11, step 400,             l2_loss=4107.122, total_loss=1.455,             accuracy=0.562
epoch 11, step 450,             l2_loss=4123.946, total_loss=1.308,             accuracy=0.656
epoch 11, step 500,             l2_loss=4141.405, tot

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 12, step 0,             l2_loss=4159.882, total_loss=1.077,             accuracy=0.688
epoch 12, step 50,             l2_loss=4174.384, total_loss=0.560,             accuracy=0.844
epoch 12, step 100,             l2_loss=4186.660, total_loss=1.098,             accuracy=0.781
epoch 12, step 150,             l2_loss=4200.418, total_loss=0.916,             accuracy=0.688
epoch 12, step 200,             l2_loss=4218.945, total_loss=0.878,             accuracy=0.719
epoch 12, step 250,             l2_loss=4233.623, total_loss=0.964,             accuracy=0.719
epoch 12, step 300,             l2_loss=4248.164, total_loss=0.848,             accuracy=0.719
epoch 12, step 350,             l2_loss=4262.421, total_loss=0.540,             accuracy=0.781
epoch 12, step 400,             l2_loss=4276.633, total_loss=1.133,             accuracy=0.656
epoch 12, step 450,             l2_loss=4290.013, total_loss=0.659,             accuracy=0.812
epoch 12, step 500,             l2_loss=4302.478, tot

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 13, step 0,             l2_loss=4317.831, total_loss=1.029,             accuracy=0.688
epoch 13, step 50,             l2_loss=4334.532, total_loss=0.282,             accuracy=0.906
epoch 13, step 100,             l2_loss=4345.997, total_loss=0.922,             accuracy=0.656
epoch 13, step 150,             l2_loss=4356.713, total_loss=0.660,             accuracy=0.812



KeyboardInterrupt: 

In [19]:
CKPT_PATH = 'model_best.ckpt'
model.load_state_dict(torch.load(CKPT_PATH)['state_dict'])
model.eval()
acc = 0
for i, (img_batch, label_batch) in enumerate(valid_loader):
    output = model(img_batch.to(device))
    _, predicted = torch.max(output.cpu().data, 1)
    accuracy = torch.sum(predicted == label_batch.data.view(-1), dtype=torch.float32) / BATCH_SIZE
    acc += accuracy
print('accuracy={}'.format(acc/len(valid_loader)))

accuracy=0.7737226486206055
