# Convolution2d

## Configurations

### Install and import necessary libraries

In [15]:
import pandas as pd
import numpy as np
import cv2
from tqdm.notebook import tqdm
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader,SubsetRandomSampler
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

import torchvision.models as models
from torchvision import transforms
from torchsummary import summary

import os
import random
from datetime import datetime

#### GPU or not?

In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

#### Path specified

In [17]:
DATA_DIR = 'data'
LABEL_DIR = 'label2.csv'
IMAGE_DIR = os.path.join(DATA_DIR, 'DaanForestPark')
# IMAGE_DIR = os.path.join(DATA_DIR, '17category')

### Hyperparameters

In [18]:
EPOCHES = 20
ZIPSIZE = 4
FILTER_NUMS = 8
FILTER_NUMS2 = 16
CHANNEL_NUMS = 3
KERNEL_SIZE = 13
STRIDE = 1
BATCH_SIZE = 32
NUM_WORKERS = 4
LR = 1e-2
HEIGHT, WEIDTH = 1136, 640
# 340 * 192

In [19]:
# ToPILImage() -> Resize() -> ToTensor()
transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))
        ])

### Define the DataSet and the DataLoader

In [27]:
class MyDataset(Dataset):
    def __init__(self, image_dir, label_dir, transform=None):
        _images, _labels = [], []
        # total amount of dataset 
        _number = 0
        # Reading the categorical file
        label_df = pd.read_csv(label_dir)
        
        # Iterate all files including .jpg inages  
        for subdir, dirs, files in tqdm(os.walk(image_dir)):
            for filename in files:
                corr_label = label_df[label_df['dirpath']==subdir[len(DATA_DIR)+1:]]['label'].values
                if corr_label.size!= 0 and filename.endswith(('jpg')):
                    _images.append(subdir + os.sep + filename)
                    _labels.append(corr_label)
                    _number+=1
        
        # Randomly arrange data pairs
        mapIndexPosition = list(zip(_images, _labels))
        random.shuffle(mapIndexPosition)
        _images, _labels = zip(*mapIndexPosition)

        self._image = iter(_images)
        self._labels = iter(_labels)
        self._number = _number
        self._category = label_df['label'].nunique()
        self.transform = transform
        
    def __len__(self):
        return self._number

    def __getitem__(self, index):    
        img = next(self._image)
        lab = next(self._labels)
        
        img = self._loadimage(img)
        if self.transform:
            img = self.transform(img)
        
        return img, lab
     
    def _categorical(self, label):
        return np.arange(self._category) == label[:,None]
    
    def _loadimage(self, file):
#         image = cv2.imread(file).reshape(HEIGHT, WEIDTH,CHANNEL_NUMS)
        image = cv2.imread(file)
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    
    def _resizeimage(self, img, scale_percent=30):
        width = int(img.shape[1] * scale_percent / 100)
        height = int(img.shape[0] * scale_percent / 100)
        dim = (width, height)
        resized = cv2.resize(img, dim, interpolation=cv2.INTER_AREA)
        return resized
    
    def get_categorical_nums(self):
        return self._category
    

In [28]:
train_dataset = MyDataset(IMAGE_DIR, LABEL_DIR, transform=transform)

valid_size = .1
num_train = len(train_dataset)
indices = list(range(num_train))
split = int(np.floor(valid_size * num_train))
train_idx, valid_idx = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, sampler=train_sampler, drop_last=True)
valid_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, sampler=valid_sampler, drop_last=True)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [29]:
# def init_weights(m):
#     if type(m) == nn.Conv2d:
#         nn.init.xavier_uniform(m.weight)
#         m.bias.data.fill_(0.01)
#     if type(m) == nn.Linear:
#         nn.init.uniform_(m.weight)
#         m.bias.data.fill_(0.01)   
# class SimpleCNN(nn.Module):
    
#     def __init__(self, target):
#         super(SimpleCNN, self).__init__()
#         # Input size (3, 1136, 640)
# #         self.imgzipper  = nn.AvgPool2d(kernel_size=ZIPSIZE)
#         # Input size (3, 284, 160)
#         self.conv1 = nn.Sequential(
#             nn.Conv2d(in_channels=CHANNEL_NUMS,
#                         out_channels=FILTER_NUMS,
#                         kernel_size=KERNEL_SIZE,
#                         stride=STRIDE,
#                         padding=(KERNEL_SIZE-STRIDE)//2 # padding=(kernel_size-stride)/2 -> original size
#                     ),
#             nn.Dropout(0.5),
#             nn.ReLU(),
#             # (8, 1136, 640)
#             nn.MaxPool2d(kernel_size=KERNEL_SIZE)
#             # (8, 87, 49)
#             # zipper (8, 21, 12)
#         ).apply(init_weights)
        
#         # (8, 87, 49)
#         self.conv2 = nn.Sequential(
#             nn.Conv2d(in_channels=FILTER_NUMS,
#                         out_channels=FILTER_NUMS2,
#                         kernel_size=KERNEL_SIZE,
#                         stride=STRIDE,
#                         padding=(KERNEL_SIZE-STRIDE)//2 # padding=(kernel_size-stride)/2 -> original size
#                     ),
#             nn.Dropout(0.5),
#             nn.ReLU(),
#             # (16, 87, 49)
#             nn.MaxPool2d(kernel_size=5)
#             # (16, 6, 3)
#         ).apply(init_weights)
#         self.MLP = nn.Sequential(
#             nn.Linear(128, 81),
#             nn.ReLU(),
#             nn.Linear(81, 81),
#             nn.ReLU(),
#             nn.Linear(81, target)
#         ).apply(init_weights)
#     def forward(self, x):
#         x = self.imgzipper(x)
#         x = self.conv1(x)
#         x = self.conv2(x)
#         x = x.view(x.size(0), -1)
#         x = self.MLP(x)
#         return x

### ResNet18

In [30]:
# model = SimpleCNN(train_dataset.get_categorical_nums()).to(device)
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 85)

In [31]:
# channels, H, W
model = model_ft.to(device=device)
summary(model, input_size=(CHANNEL_NUMS, 340, 192))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 64, 170, 96]           9,408
       BatchNorm2d-2          [-1, 64, 170, 96]             128
              ReLU-3          [-1, 64, 170, 96]               0
         MaxPool2d-4           [-1, 64, 85, 48]               0
            Conv2d-5           [-1, 64, 85, 48]          36,864
       BatchNorm2d-6           [-1, 64, 85, 48]             128
              ReLU-7           [-1, 64, 85, 48]               0
            Conv2d-8           [-1, 64, 85, 48]          36,864
       BatchNorm2d-9           [-1, 64, 85, 48]             128
             ReLU-10           [-1, 64, 85, 48]               0
       BasicBlock-11           [-1, 64, 85, 48]               0
           Conv2d-12           [-1, 64, 85, 48]          36,864
      BatchNorm2d-13           [-1, 64, 85, 48]             128
             ReLU-14           [-1, 64,

In [32]:
def clip_gradient(optimizer, grad_clip):
    """
    Clips gradients computed during backpropagation to avoid explosion of gradients.

    :param optimizer: optimizer with the gradients to be clipped
    :param grad_clip: clip value
    """
    for group in optimizer.param_groups:
        for param in group["params"]:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)

In [33]:
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

In [34]:
import shutil 
def save_checkpoint(state, filename='model.ckpt'):
    torch.save(state, filename)
    shutil.copyfile(filename, 'model_best.ckpt')

In [35]:
optimizer = torch.optim.Adam(model.parameters(), lr=LR)   # optimize all cnn parameters
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=0, verbose=True)
criterion = nn.CrossEntropyLoss().to(device=device)

# early stopping
min_val_loss = np.Inf
patience = 3
global_step = 1
writer = SummaryWriter('runs/experiment_{}'.format(datetime.now().strftime("%f")))

for epoch in range(EPOCHES):
    val_loss = 0
    model_ft.train()
    for i, (img_batch, label_batch) in tqdm(enumerate(train_loader)):    
        optimizer.zero_grad()      
        img_batch = img_batch.to(device=device)
        label_batch = label_batch.to(device=device)  
        output = model_ft(img_batch)
        loss = criterion(output, label_batch.squeeze())
        clip_regurization = nn.utils.clip_grad_norm_(model.parameters(), 0.9) 
#         regularization_loss = 0      
#         for param in model.parameters():
#             regularization_loss += torch.sum(torch.abs(param))
#         loss = loss + 0.1 * regularization_loss
        loss = loss + clip_regurization
        loss.backward()
        optimizer.step()

        # Compute accuracy
        _, predicted = torch.max(output.cpu().data, 1)
        accuracy = torch.sum(predicted == label_batch.cpu().data.view(-1), dtype=torch.float32) / BATCH_SIZE
        
        # Write tensorboard
        writer.add_scalar('Accuracy/train', accuracy.item(), global_step)
        writer.add_scalar('Loss/train', loss.item(), global_step)
        writer.add_scalar('LR/train', get_lr(optimizer), global_step)
                
        global_step += 1
#         val_loss += loss.item()
        
        if i % 20== 0:
            print('epoch {}, step {}, loss={}, accuracy={}'.format(epoch+1, i, loss.item(), accuracy.item()))         
    
#     val_loss = val_loss / len(train_loader)  
#     scheduler.step(val_loss)
    
    model.eval()
    eval_loss = 0
    for i, (img_batch, label_batch) in enumerate(valid_loader):
        output = model(img_batch.to(device))
        _, predicted = torch.max(output.cpu().data, 1)
        loss = criterion(output, label_batch.to(device).squeeze())
        accuracy = torch.sum(predicted == label_batch.data.view(-1), dtype=torch.float32) / BATCH_SIZE
        eval_loss += loss.item()
        writer.add_scalar('Accuracy/valid', accuracy.item(), global_step)
        writer.add_scalar('Loss/valid', loss.item(), global_step)
    
    eval_loss = eval_loss / len(valid_loader)
    scheduler.step(eval_loss)
    
    print('epoch {}, val_loss={}'.format(epoch+1, eval_loss))         

    ## Early Stopping
    if eval_loss < min_val_loss:
        torch.save(model, 'ckpt/resNet_{}_compose.ckpt'.format(epoch+1))
#         save_checkpoint({
#             'epoch': epoch+1,
#             'state_dict': model.state_dict(),
#             'best_loss': val_loss,
#             'optimizer' :optimizer.state_dict(),
#             }, 'conv_{}.ckpt'.format(epoch+1))
        min_val_loss = eval_loss
    else:
        patience-=1
    if patience == 0:
        print('Early stopping')
        break
        
writer.close()
print('Finish all training !')

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 1, step 0, loss=4.671567916870117, accuracy=0.03125
epoch 1, step 20, loss=6.5467095375061035, accuracy=0.03125
epoch 1, step 40, loss=5.381259441375732, accuracy=0.03125
epoch 1, step 60, loss=5.342200756072998, accuracy=0.0625
epoch 1, step 80, loss=4.601925373077393, accuracy=0.03125
epoch 1, step 100, loss=3.9714481830596924, accuracy=0.0625
epoch 1, step 120, loss=4.118907451629639, accuracy=0.0625
epoch 1, step 140, loss=4.026341915130615, accuracy=0.03125
epoch 1, step 160, loss=4.102416515350342, accuracy=0.09375
epoch 1, step 180, loss=3.995544910430908, accuracy=0.125
epoch 1, step 200, loss=3.7514302730560303, accuracy=0.0625
epoch 1, step 220, loss=3.7040164470672607, accuracy=0.09375
epoch 1, step 240, loss=3.688523530960083, accuracy=0.15625
epoch 1, step 260, loss=3.9163589477539062, accuracy=0.09375
epoch 1, step 280, loss=3.6367268562316895, accuracy=0.125
epoch 1, step 300, loss=3.7971279621124268, accuracy=0.15625
epoch 1, step 320, loss=3.50956130027771, accur

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 2, step 0, loss=3.3742737770080566, accuracy=0.125
epoch 2, step 20, loss=3.097829580307007, accuracy=0.0625
epoch 2, step 40, loss=3.063203811645508, accuracy=0.1875
epoch 2, step 60, loss=2.8077919483184814, accuracy=0.21875
epoch 2, step 80, loss=3.7494421005249023, accuracy=0.0625
epoch 2, step 100, loss=3.0976271629333496, accuracy=0.21875
epoch 2, step 120, loss=3.4226791858673096, accuracy=0.09375
epoch 2, step 140, loss=3.2383439540863037, accuracy=0.09375
epoch 2, step 160, loss=3.17246675491333, accuracy=0.125
epoch 2, step 180, loss=3.6158957481384277, accuracy=0.15625
epoch 2, step 200, loss=3.100374698638916, accuracy=0.15625
epoch 2, step 220, loss=3.1065409183502197, accuracy=0.25
epoch 2, step 240, loss=3.112060308456421, accuracy=0.21875
epoch 2, step 260, loss=3.355567455291748, accuracy=0.15625
epoch 2, step 280, loss=2.8525829315185547, accuracy=0.21875
epoch 2, step 300, loss=2.863534688949585, accuracy=0.3125
epoch 2, step 320, loss=2.798459768295288, accura

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 3, step 0, loss=2.579862117767334, accuracy=0.25
epoch 3, step 20, loss=2.6845734119415283, accuracy=0.1875
epoch 3, step 40, loss=2.589587688446045, accuracy=0.28125
epoch 3, step 60, loss=2.313965320587158, accuracy=0.3125
epoch 3, step 80, loss=3.583925247192383, accuracy=0.09375
epoch 3, step 100, loss=2.5829851627349854, accuracy=0.1875
epoch 3, step 120, loss=2.9595861434936523, accuracy=0.28125
epoch 3, step 140, loss=2.672018051147461, accuracy=0.21875
epoch 3, step 160, loss=2.6360838413238525, accuracy=0.40625
epoch 3, step 180, loss=3.342761993408203, accuracy=0.125
epoch 3, step 200, loss=2.5084497928619385, accuracy=0.3125
epoch 3, step 220, loss=2.72106671333313, accuracy=0.375
epoch 3, step 240, loss=2.7086031436920166, accuracy=0.40625
epoch 3, step 260, loss=2.5347771644592285, accuracy=0.34375
epoch 3, step 280, loss=2.2406883239746094, accuracy=0.40625
epoch 3, step 300, loss=2.691155195236206, accuracy=0.375
epoch 3, step 320, loss=2.1123046875, accuracy=0.281

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 4, step 0, loss=2.147911310195923, accuracy=0.3125
epoch 4, step 20, loss=2.1254758834838867, accuracy=0.28125
epoch 4, step 40, loss=1.6722394227981567, accuracy=0.53125
epoch 4, step 60, loss=1.701303482055664, accuracy=0.6875
epoch 4, step 80, loss=2.5288493633270264, accuracy=0.40625
epoch 4, step 100, loss=2.0208992958068848, accuracy=0.375
epoch 4, step 120, loss=2.347501516342163, accuracy=0.25
epoch 4, step 140, loss=1.8022918701171875, accuracy=0.34375
epoch 4, step 160, loss=2.2203285694122314, accuracy=0.375
epoch 4, step 180, loss=2.680549383163452, accuracy=0.3125
epoch 4, step 200, loss=2.2338039875030518, accuracy=0.4375
epoch 4, step 220, loss=2.0390448570251465, accuracy=0.28125
epoch 4, step 240, loss=1.8643133640289307, accuracy=0.53125
epoch 4, step 260, loss=2.1024093627929688, accuracy=0.40625
epoch 4, step 280, loss=1.5703718662261963, accuracy=0.4375
epoch 4, step 300, loss=2.194206476211548, accuracy=0.40625
epoch 4, step 320, loss=1.7361360788345337, acc

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 5, step 0, loss=1.878986120223999, accuracy=0.375
epoch 5, step 20, loss=1.9198325872421265, accuracy=0.53125
epoch 5, step 40, loss=1.5494792461395264, accuracy=0.625
epoch 5, step 60, loss=1.2075430154800415, accuracy=0.65625
epoch 5, step 80, loss=2.4435997009277344, accuracy=0.5
epoch 5, step 100, loss=1.6132625341415405, accuracy=0.5625
epoch 5, step 120, loss=2.4430055618286133, accuracy=0.3125
epoch 5, step 140, loss=1.6947693824768066, accuracy=0.375
epoch 5, step 160, loss=1.7711858749389648, accuracy=0.40625
epoch 5, step 180, loss=2.2767140865325928, accuracy=0.3125
epoch 5, step 200, loss=2.2410836219787598, accuracy=0.4375
epoch 5, step 220, loss=1.834011435508728, accuracy=0.46875
epoch 5, step 240, loss=1.6932880878448486, accuracy=0.5625
epoch 5, step 260, loss=1.8983572721481323, accuracy=0.40625
epoch 5, step 280, loss=1.542979121208191, accuracy=0.53125
epoch 5, step 300, loss=1.6023783683776855, accuracy=0.5625
epoch 5, step 320, loss=1.9003419876098633, accur

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 6, step 0, loss=1.4143675565719604, accuracy=0.53125
epoch 6, step 20, loss=1.452266812324524, accuracy=0.71875
epoch 6, step 40, loss=1.6065306663513184, accuracy=0.6875
epoch 6, step 60, loss=1.2399283647537231, accuracy=0.625
epoch 6, step 80, loss=1.6992249488830566, accuracy=0.5625
epoch 6, step 100, loss=1.5151106119155884, accuracy=0.5625
epoch 6, step 120, loss=2.130459785461426, accuracy=0.4375
epoch 6, step 140, loss=1.4047186374664307, accuracy=0.53125
epoch 6, step 160, loss=1.4300258159637451, accuracy=0.59375
epoch 6, step 180, loss=1.8323861360549927, accuracy=0.4375
epoch 6, step 200, loss=1.942037582397461, accuracy=0.4375
epoch 6, step 220, loss=1.4173675775527954, accuracy=0.625
epoch 6, step 240, loss=1.4227826595306396, accuracy=0.625
epoch 6, step 260, loss=1.4679360389709473, accuracy=0.59375
epoch 6, step 280, loss=1.1961770057678223, accuracy=0.625
epoch 6, step 300, loss=1.4692515134811401, accuracy=0.625
epoch 6, step 320, loss=1.2574423551559448, accur

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 7, step 0, loss=1.3905975818634033, accuracy=0.5625
epoch 7, step 20, loss=0.9200831055641174, accuracy=0.78125
epoch 7, step 40, loss=1.3943487405776978, accuracy=0.59375
epoch 7, step 60, loss=1.0035028457641602, accuracy=0.71875
epoch 7, step 80, loss=1.9390852451324463, accuracy=0.5
epoch 7, step 100, loss=1.0279790163040161, accuracy=0.6875
epoch 7, step 120, loss=1.7618608474731445, accuracy=0.34375
epoch 7, step 140, loss=1.0907901525497437, accuracy=0.75
epoch 7, step 160, loss=1.2691212892532349, accuracy=0.59375
epoch 7, step 180, loss=1.5272905826568604, accuracy=0.53125
epoch 7, step 200, loss=1.7150086164474487, accuracy=0.5
epoch 7, step 220, loss=1.5620183944702148, accuracy=0.625
epoch 7, step 240, loss=1.3125243186950684, accuracy=0.65625
epoch 7, step 260, loss=1.3492952585220337, accuracy=0.65625
epoch 7, step 280, loss=0.869901716709137, accuracy=0.71875
epoch 7, step 300, loss=0.8641990423202515, accuracy=0.75
epoch 7, step 320, loss=1.0220186710357666, accur

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 8, step 0, loss=0.9910120964050293, accuracy=0.65625
epoch 8, step 20, loss=1.4238473176956177, accuracy=0.625
epoch 8, step 40, loss=1.0123062133789062, accuracy=0.625
epoch 8, step 60, loss=0.9814603328704834, accuracy=0.75
epoch 8, step 80, loss=1.2214009761810303, accuracy=0.59375
epoch 8, step 100, loss=1.015940546989441, accuracy=0.625
epoch 8, step 120, loss=1.0569865703582764, accuracy=0.6875
epoch 8, step 140, loss=1.2031550407409668, accuracy=0.71875
epoch 8, step 160, loss=1.2540826797485352, accuracy=0.6875
epoch 8, step 180, loss=1.59602689743042, accuracy=0.375
epoch 8, step 200, loss=1.44540536403656, accuracy=0.59375
epoch 8, step 220, loss=0.9929584264755249, accuracy=0.59375
epoch 8, step 240, loss=1.2782355546951294, accuracy=0.625
epoch 8, step 260, loss=1.065795660018921, accuracy=0.71875
epoch 8, step 280, loss=0.4706387519836426, accuracy=0.84375
epoch 8, step 300, loss=0.6425600051879883, accuracy=0.78125
epoch 8, step 320, loss=0.9390038251876831, accurac

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 9, step 0, loss=0.8609794974327087, accuracy=0.71875
epoch 9, step 20, loss=0.8460136651992798, accuracy=0.75
epoch 9, step 40, loss=1.2633569240570068, accuracy=0.65625
epoch 9, step 60, loss=0.8856645822525024, accuracy=0.84375
epoch 9, step 80, loss=1.5712525844573975, accuracy=0.5625
epoch 9, step 100, loss=1.183443546295166, accuracy=0.625
epoch 9, step 120, loss=0.8830428123474121, accuracy=0.75
epoch 9, step 140, loss=1.2420364618301392, accuracy=0.625
epoch 9, step 160, loss=0.7978774309158325, accuracy=0.8125
epoch 9, step 180, loss=1.0106369256973267, accuracy=0.625
epoch 9, step 200, loss=0.9383394122123718, accuracy=0.65625
epoch 9, step 220, loss=1.3423166275024414, accuracy=0.65625
epoch 9, step 240, loss=1.00179123878479, accuracy=0.71875
epoch 9, step 260, loss=1.3496997356414795, accuracy=0.59375
epoch 9, step 280, loss=0.6681790351867676, accuracy=0.78125
epoch 9, step 300, loss=0.8995953798294067, accuracy=0.75
epoch 9, step 320, loss=0.7750880122184753, accura

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 10, step 0, loss=0.6417156457901001, accuracy=0.78125
epoch 10, step 20, loss=0.9771443009376526, accuracy=0.71875
epoch 10, step 40, loss=0.9411560893058777, accuracy=0.65625
epoch 10, step 60, loss=0.6277148723602295, accuracy=0.71875
epoch 10, step 80, loss=0.9328272938728333, accuracy=0.65625
epoch 10, step 100, loss=0.527763843536377, accuracy=0.8125
epoch 10, step 120, loss=0.862009584903717, accuracy=0.78125
epoch 10, step 140, loss=0.9952446818351746, accuracy=0.65625
epoch 10, step 160, loss=1.2737388610839844, accuracy=0.71875
epoch 10, step 180, loss=0.812660813331604, accuracy=0.65625
epoch 10, step 200, loss=1.1284786462783813, accuracy=0.75
epoch 10, step 220, loss=1.151834487915039, accuracy=0.59375
epoch 10, step 240, loss=0.8821616172790527, accuracy=0.6875
epoch 10, step 260, loss=1.1319342851638794, accuracy=0.71875
epoch 10, step 280, loss=0.42669951915740967, accuracy=0.875
epoch 10, step 300, loss=0.6595611572265625, accuracy=0.8125
epoch 10, step 320, loss=

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 11, step 0, loss=0.8188281059265137, accuracy=0.8125
epoch 11, step 20, loss=1.1006977558135986, accuracy=0.625
epoch 11, step 40, loss=0.7519148588180542, accuracy=0.71875
epoch 11, step 60, loss=0.5427334308624268, accuracy=0.84375
epoch 11, step 80, loss=0.5198497772216797, accuracy=0.84375
epoch 11, step 100, loss=0.46386831998825073, accuracy=0.8125
epoch 11, step 120, loss=1.3164974451065063, accuracy=0.46875
epoch 11, step 140, loss=0.9705235958099365, accuracy=0.78125
epoch 11, step 160, loss=0.5924906730651855, accuracy=0.84375
epoch 11, step 180, loss=0.932207465171814, accuracy=0.71875
epoch 11, step 200, loss=1.1032992601394653, accuracy=0.6875
epoch 11, step 220, loss=0.8325086832046509, accuracy=0.78125
epoch 11, step 240, loss=0.7403455376625061, accuracy=0.78125
epoch 11, step 260, loss=0.8395848274230957, accuracy=0.71875
epoch 11, step 280, loss=0.8282145261764526, accuracy=0.78125
epoch 11, step 300, loss=0.38563960790634155, accuracy=0.875
epoch 11, step 320, 

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 12, step 0, loss=0.726274847984314, accuracy=0.78125
epoch 12, step 20, loss=0.7743198871612549, accuracy=0.75
epoch 12, step 40, loss=0.42054471373558044, accuracy=0.875
epoch 12, step 60, loss=0.5211969614028931, accuracy=0.84375
epoch 12, step 80, loss=0.4188309907913208, accuracy=0.875
epoch 12, step 100, loss=0.503167986869812, accuracy=0.78125
epoch 12, step 120, loss=0.6776552200317383, accuracy=0.75
epoch 12, step 140, loss=0.6904396414756775, accuracy=0.8125
epoch 12, step 160, loss=0.4916115403175354, accuracy=0.875
epoch 12, step 180, loss=0.6215370297431946, accuracy=0.75
epoch 12, step 200, loss=1.2064335346221924, accuracy=0.625
epoch 12, step 220, loss=0.46538108587265015, accuracy=0.875
epoch 12, step 240, loss=0.7371988892555237, accuracy=0.78125
epoch 12, step 260, loss=0.8843119144439697, accuracy=0.8125
epoch 12, step 280, loss=0.5459361672401428, accuracy=0.84375
epoch 12, step 300, loss=0.7316473722457886, accuracy=0.78125
epoch 12, step 320, loss=1.03103017

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 13, step 0, loss=0.29041939973831177, accuracy=0.90625
epoch 13, step 20, loss=0.9072789549827576, accuracy=0.84375
epoch 13, step 40, loss=0.5607641935348511, accuracy=0.84375
epoch 13, step 60, loss=0.3788691759109497, accuracy=0.875
epoch 13, step 80, loss=0.3634062111377716, accuracy=0.9375
epoch 13, step 100, loss=0.3827592134475708, accuracy=0.875
epoch 13, step 120, loss=0.9552939534187317, accuracy=0.71875
epoch 13, step 140, loss=0.5062313079833984, accuracy=0.84375
epoch 13, step 160, loss=0.6559087038040161, accuracy=0.75
epoch 13, step 180, loss=0.5442779064178467, accuracy=0.78125
epoch 13, step 200, loss=0.3603441119194031, accuracy=0.84375
epoch 13, step 220, loss=0.44518256187438965, accuracy=0.875
epoch 13, step 240, loss=0.6436267495155334, accuracy=0.8125
epoch 13, step 260, loss=0.8646172881126404, accuracy=0.71875
epoch 13, step 280, loss=0.28785213828086853, accuracy=0.9375
epoch 13, step 300, loss=0.2279684990644455, accuracy=0.9375
epoch 13, step 320, loss

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 14, step 0, loss=0.5227958559989929, accuracy=0.78125
epoch 14, step 20, loss=0.44250357151031494, accuracy=0.8125
epoch 14, step 40, loss=0.44137880206108093, accuracy=0.8125
epoch 14, step 60, loss=0.3505966067314148, accuracy=0.90625
epoch 14, step 80, loss=0.28216707706451416, accuracy=0.875
epoch 14, step 100, loss=0.3717271089553833, accuracy=0.84375
epoch 14, step 120, loss=0.575495719909668, accuracy=0.78125
epoch 14, step 140, loss=0.5702390670776367, accuracy=0.8125
epoch 14, step 160, loss=0.350790411233902, accuracy=0.90625
epoch 14, step 180, loss=0.4424876868724823, accuracy=0.84375
epoch 14, step 200, loss=0.40065857768058777, accuracy=0.78125
epoch 14, step 220, loss=0.38664937019348145, accuracy=0.8125
epoch 14, step 240, loss=0.5138797163963318, accuracy=0.78125
epoch 14, step 260, loss=0.6058288216590881, accuracy=0.78125
epoch 14, step 280, loss=0.40364667773246765, accuracy=0.90625
epoch 14, step 300, loss=0.5628494620323181, accuracy=0.8125
epoch 14, step 32

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 15, step 0, loss=0.7606164216995239, accuracy=0.75
epoch 15, step 20, loss=0.4262882471084595, accuracy=0.84375
epoch 15, step 40, loss=0.731155276298523, accuracy=0.8125
epoch 15, step 60, loss=0.16557937860488892, accuracy=0.90625
epoch 15, step 80, loss=0.16377417743206024, accuracy=0.9375
epoch 15, step 100, loss=0.3474748134613037, accuracy=0.9375
epoch 15, step 120, loss=0.45050346851348877, accuracy=0.84375
epoch 15, step 140, loss=0.483173131942749, accuracy=0.84375
epoch 15, step 160, loss=0.40129369497299194, accuracy=0.8125
epoch 15, step 180, loss=0.2126186639070511, accuracy=0.90625
epoch 15, step 200, loss=0.3793339133262634, accuracy=0.875
epoch 15, step 220, loss=0.365367591381073, accuracy=0.875
epoch 15, step 240, loss=0.5208755731582642, accuracy=0.84375
epoch 15, step 260, loss=0.5872870683670044, accuracy=0.75
epoch 15, step 280, loss=0.44610559940338135, accuracy=0.875
epoch 15, step 300, loss=0.27168160676956177, accuracy=0.90625
epoch 15, step 320, loss=0.

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 16, step 0, loss=0.3891902565956116, accuracy=0.875
epoch 16, step 20, loss=0.532086968421936, accuracy=0.8125
epoch 16, step 40, loss=0.20959125459194183, accuracy=0.90625
epoch 16, step 60, loss=0.24494829773902893, accuracy=0.90625
epoch 16, step 80, loss=0.17341430485248566, accuracy=0.90625
epoch 16, step 100, loss=0.371090292930603, accuracy=0.875
epoch 16, step 120, loss=0.3871783912181854, accuracy=0.90625
epoch 16, step 140, loss=0.5676125884056091, accuracy=0.84375
epoch 16, step 160, loss=0.384351521730423, accuracy=0.875
epoch 16, step 180, loss=0.4249950647354126, accuracy=0.84375
epoch 16, step 200, loss=0.7180385589599609, accuracy=0.8125
epoch 16, step 220, loss=0.33886852860450745, accuracy=0.875
epoch 16, step 240, loss=0.531898021697998, accuracy=0.84375
epoch 16, step 260, loss=0.4720737636089325, accuracy=0.8125
epoch 16, step 280, loss=0.37970882654190063, accuracy=0.875
epoch 16, step 300, loss=0.12875989079475403, accuracy=0.96875
epoch 16, step 320, loss=

In [37]:
model.eval()
for i, (img_batch, label_batch) in enumerate(valid_loader):
    output = model(img_batch.to(device))
    _, predicted = torch.max(output.cpu().data, 1)
    accuracy = torch.sum(predicted == label_batch.data.view(-1), dtype=torch.float32) / BATCH_SIZE
    print('accuracy={}'.format(accuracy))

accuracy=0.84375
accuracy=0.8125
accuracy=0.84375
accuracy=0.8125
accuracy=0.875
accuracy=0.84375
accuracy=0.875
accuracy=0.78125
accuracy=0.875
accuracy=0.8125
accuracy=0.875
accuracy=0.8125
accuracy=0.8125
accuracy=0.78125
accuracy=0.84375
accuracy=0.84375
accuracy=0.84375
accuracy=0.8125
accuracy=0.84375
accuracy=0.8125
accuracy=0.84375
accuracy=0.875
accuracy=0.90625
accuracy=0.90625
accuracy=0.9375
accuracy=0.90625
accuracy=0.875
accuracy=0.96875
accuracy=0.75
accuracy=0.8125
accuracy=0.75
accuracy=0.8125
accuracy=0.84375
accuracy=0.90625
accuracy=0.96875
accuracy=0.90625
accuracy=0.84375
accuracy=0.9375
accuracy=0.84375
accuracy=0.78125
accuracy=0.84375
accuracy=0.8125
accuracy=0.84375
accuracy=0.84375
accuracy=0.8125
accuracy=0.6875
accuracy=0.84375
accuracy=0.84375
accuracy=0.90625
accuracy=0.90625
accuracy=0.75
accuracy=0.84375
accuracy=0.96875
accuracy=0.96875
accuracy=0.9375
accuracy=0.9375
accuracy=0.84375
accuracy=0.96875
accuracy=0.84375
accuracy=0.9375
accuracy=0.84375
a