In [1]:
from os import listdir
from os.path import join, splitext, basename
import glob

import torch.utils.data as data
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from PIL import Image
import PIL
import numpy as np
import torch
import torchvision
import random
import torchvision.models as models
import torch.optim as optim
import torch.nn as nn

from imgaug import augmenters as iaa
import imgaug as ia

In [2]:
code2names = {
    0:"Bread",
    1:"Dairy_product",
    2:"Dessert",
    3:"Egg",
    4:"Fried_food",
    5:"Meat",
    6:"Noodles",
    7:"Rice",
    8:"Seafood",
    9:"Soup",
    10:"Vegetable_fruit"
}

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])

def load_img(filepath):
    img = Image.open(filepath)
    return img


In [3]:
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib as mpl
mpl.rcParams['axes.grid'] = False
mpl.rcParams['image.interpolation'] = 'nearest'
mpl.rcParams['figure.figsize'] = 15, 25
def show_dataset(dataset, n=6):
    img = np.vstack((np.hstack((np.asarray(dataset[i][0]) for _ in range(n))) for i in range(5)))
    plt.imshow(img)
    plt.axis('off')

In [4]:
class ImgAugTransform:
    def __init__(self):
        self.aug = iaa.Sequential([
            iaa.Scale((224, 224)),
            iaa.Sometimes(0.25, iaa.GaussianBlur(sigma=(0, 3.0))),
            iaa.Fliplr(0.5),
            iaa.Affine(rotate=(-20, 20), mode='symmetric'),
            iaa.Sometimes(0.25,
                      iaa.OneOf([iaa.Dropout(p=(0, 0.1)),
                                 iaa.CoarseDropout(0.1, size_percent=0.5)])),  # 對batch中的一部分圖片應用一部分Augmenters,剩下的圖片應用另外的Augmenters。
            iaa.AddToHueAndSaturation(value=(-10, 10), per_channel=True)  # 即修改色調和飽和度
        ])
      
    def __call__(self, img):
        img = np.array(img)
        return self.aug.augment_image(img)

In [5]:
def input_transform():
    return transforms.Compose([
        ImgAugTransform(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

def not_train_transform():
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

In [6]:
class Food11Dataset(data.Dataset):
    def __init__(self, image_dir, input_transform=input_transform, is_train=False):
        super(Food11Dataset, self).__init__()
        self.path_pattern = image_dir + '/**/*.*'
        self.files_list = glob.glob(self.path_pattern, recursive=True)
        self.datapath = image_dir
        self.image_filenames = []
        self.classes_file_name = {}
        self.num_per_classes = {}
        self.class_name_list = []
        for file in self.files_list:
            if is_image_file(file):
                self.image_filenames.append(file)
                class_name = int(basename(file).split("_")[0])
                if class_name not in self.class_name_list:
                    self.class_name_list.append(class_name)
                if class_name in self.num_per_classes:
                    self.num_per_classes[class_name] += 1
                    self.classes_file_name[class_name].append(file)
                else:
                    self.num_per_classes[class_name] = 1
                    self.classes_file_name[class_name] = []
                    self.classes_file_name[class_name].append(file)
                    

        self.input_transform = input_transform
        self.is_train = is_train

    def __getitem__(self, index):
        # TODO [Lab 2-1] Try to embed third-party augmentation functions into pytorch flow
        input_file = self.image_filenames[index]
        input = load_img(input_file)
        if self.input_transform:
            input = self.input_transform()(input)
        label = basename(self.image_filenames[index])
        label = int(label.split("_")[0])
        return input, label

    def __len__(self):
        return len(self.image_filenames)

    def show_details(self):
        for key in sorted(self.num_per_classes.keys()):
            print("{:<8}|{:<20}|{:<12}".format(
                key,
                code2names[key],
                self.num_per_classes[key]
            ))
    
    ''' TODO [Lab 2-1]
    #please add a new function "augmentation(self, wts)"
    #it can change the number of data according to weight of each category
    #"weight" represents the ratio comparing with the amount of original set
    #if the weight > 100, we create new data by copying
    #if the weight < 100, we will delete the original data
    #[hint]you only need to edit the "self.image_filenames" 
ODOT [Lab 2-1]'''
    
    def augmentation(self):
        if self.is_train:
            pass

    def augmentation(self, wts): 
        outsample_num = {}
        if self.is_train:
            for w in range(len(wts)):
                outsample_num[w] = (int(self.num_per_classes[w] * wts[w]))
        
            for class_name in self.class_name_list:
                while self.num_per_classes[class_name] < outsample_num[class_name]:
                    random_file = random.sample(self.classes_file_name[class_name], 1)[0]
                    self.image_filenames.append(random_file)
                    self.num_per_classes[class_name] += 1
                    self.classes_file_name[class_name].append(random_file)
                while self.num_per_classes[class_name] > outsample_num[class_name]:
                    random_file = random.sample(self.classes_file_name[class_name], 1)[0]
                    self.image_filenames.remove(random_file)
                    self.num_per_classes[class_name] -= 1
                    self.classes_file_name[class_name].remove(random_file)

                    
                       

In [7]:
def WRSampler(dataset, wts):
    class_name_list = dataset.class_name_list
    num_per_classes = dataset.num_per_classes
    
    each_data_wts = []
    for class_name in class_name_list:
        class_item_num = num_per_classes[class_name]
        for i in range(class_item_num):
            each_data_wts.append(wts[class_name])
    
    sampler = torch.utils.data.sampler.WeightedRandomSampler(each_data_wts, len(each_data_wts), replacement=True)
    
    return sampler

In [8]:
def data_loading(loader, dataset):

    num_per_classes = {}
    for batch_idx, (data, label) in enumerate(loader):
        for l in label:
            if l.item() in num_per_classes:
                num_per_classes[l.item()] += 1
            else:
                num_per_classes[l.item()] = 1

    print("----------------------------------------------------------------------------------")
    print("Dataset - ", dataset.datapath)
    print("{:<20}|{:<15}|{:<15}".format("class_name", "bf. loading", "af. loading"))
    for key in sorted(num_per_classes.keys()):
        print("{:<20}|{:<15}|{:<15}".format(
            code2names[key],
            dataset.num_per_classes[key],
            num_per_classes[key]
        ))


In [9]:
def main():
    train_datapath = "./training"
    valid_datapath = "./validation"
    test_datapath = "./evaluation"

    train_dataset = Food11Dataset(train_datapath, input_transform=input_transform, is_train=True)
    valid_dataset = Food11Dataset(valid_datapath, input_transform=not_train_transform, is_train=False)
    test_dataset = Food11Dataset(test_datapath, input_transform=not_train_transform, is_train=False)

    ''' For [Lab 2-1] debugging
    train_dataset.augmentation()
    '''
    wts = [ 1.5, 3.5, 1, 1.52, 1.76, 1.13, 3.4, 5.35, 1.75, 1, 2.1 ]
    
    train_dataset.augmentation(wts)
    #show_dataset(train_dataset)
    
    print("----------------------------------------------------------------------------------")
    print("Dataset bf. loading - ", train_datapath)
    print(train_dataset.show_details())

    print("----------------------------------------------------------------------------------")
    print("Dataset bf. loading - ", valid_datapath)
    print(valid_dataset.show_details())

    print("----------------------------------------------------------------------------------")
    print("Dataset bf. loading - ", test_datapath)
    print(test_dataset.show_details())
   
    #sampler = WRSampler(train_dataset, wts)
    #sampler = torch.utils.data.sampler.RandomSampler(data_source=train_dataset, replacement=True, num_samples=len(train_dataset))

    train_loader = DataLoader(dataset=train_dataset, num_workers=4, batch_size=8)#, sampler = sampler)
    valid_loader = DataLoader(dataset=valid_dataset, num_workers=4, batch_size=8, shuffle=False)
    test_loader = DataLoader(dataset=test_dataset, num_workers=4, batch_size=8, shuffle=False)

    data_loading(train_loader, train_dataset)
    data_loading(valid_loader, valid_dataset)
    data_loading(test_loader, test_dataset)



In [10]:
if __name__ == '__main__':
    #main()
    pass

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
train_datapath = "./training"
valid_datapath = "./validation"
test_datapath = "./evaluation"

wts = [ 1.5, 3.5, 1, 1.52, 1.76, 1.13, 3.4, 5.35, 1.75, 1, 2.1 ]

train_dataset = Food11Dataset(train_datapath, input_transform=input_transform, is_train=True)
valid_dataset = Food11Dataset(valid_datapath, input_transform=not_train_transform, is_train=False)
test_dataset = Food11Dataset(test_datapath, input_transform=not_train_transform, is_train=False)
#train_dataset.augmentation(wts)

sampler = WRSampler(train_dataset, wts)

train_loader = DataLoader(dataset=train_dataset, num_workers=4, batch_size=32, sampler = sampler)
valid_loader = DataLoader(dataset=valid_dataset, num_workers=4, batch_size=64, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, num_workers=4, batch_size=8, shuffle=False)

net = models.resnet18(pretrained=False)
net.fc = nn.Sequential(nn.Linear(512,256),nn.LeakyReLU(),nn.Linear(256,128),nn.LeakyReLU(),nn.Linear(128,11))
net = net.to(device) 

learning_rate = 0.0001
criterion = nn.CrossEntropyLoss() #定義損失函數
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, betas=[0.9, 0.999]) #定優化函數

net.train()
num_epoch = 200
run_epoch = 0
t_loss = []
v_loss = []
training_accuracy = []
validation_accuracy = []
total = 0

for epoch in range(200):  # loop over the dataset multiple times
    train_loss = 0.0
    validation_loss = 0.0
    correct_train = 0
    correct_validation = 0
    train_num = 0
    val_num = 0
    cls = []

    for i in range(11):
        cls.append(0)
    
    ########################
    # train the model      #
    ########################
    
    net.train()
    for i, (inputs, labels) in enumerate(train_loader, 0):
        
        #change the type into cuda tensor 
        inputs = inputs.to(device) 
        labels = labels.to(device) 

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        # select the class with highest probability
        _, pred = outputs.max(1)
        # if the model predicts the same results as the true
        # label, then the correct counter will plus 1
        correct_train += pred.eq(labels).sum().item()
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        # print statistics
        train_loss += loss.item()
        train_num += 1
   

        
    ########################
    # validate the model   #
    ########################
    
    net.eval()
    for i, (inputs, labels) in enumerate(valid_loader, 0):
        # move tensors to GPU if CUDA is available
        inputs = inputs.to(device) 
        labels = labels.to(device)
        # forward pass: compute predicted outputs by passing inputs to the model
        outputs = net(inputs)
        _, pred = outputs.max(1)
        correct_validation += pred.eq(labels).sum().item()
        # calculate the batch loss
        loss = criterion(outputs, labels)
        # update average validation loss 
        validation_loss += loss.item()
        val_num += 1
        
        
    if epoch % 1 == 0:    # print every 200 mini-batches
        print('[%d, %5d] train_loss: %.3f' % (epoch + 1, num_epoch, train_loss / train_num))
        print('[%d, %5d] validation_loss: %.3f' % (epoch + 1, num_epoch, validation_loss / val_num))
        print('%d epoch, training accuracy: %.4f' % (epoch + 1, correct_train / len(train_dataset)))
        print('%d epoch, validation accuracy: %.4f' % (epoch + 1, correct_validation / len(valid_dataset)))
        print('-----------------------------------------')
        
        t_loss.append(train_loss / train_num)
        training_accuracy.append(correct_train / len(train_dataset))
        validation_accuracy.append(correct_validation / len(valid_dataset))
        running_loss = 0.0
        validation_loss = 0.0
        train_num = 0
        val_num = 0
        correct_train = 0
        correct_validation = 0
        total = 0
        run_epoch += 1
        
        torch.save(net.state_dict(), './resnet18_no_augmentation/' + str(epoch) + '.pth')
    
print('Finished Training')

  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[1,   200] train_loss: 2.103
[1,   200] validation_loss: 2.188
1 epoch, training accuracy: 0.2605
1 epoch, validation accuracy: 0.2312
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[2,   200] train_loss: 1.819
[2,   200] validation_loss: 1.766
2 epoch, training accuracy: 0.3689
2 epoch, validation accuracy: 0.3840
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[3,   200] train_loss: 1.641
[3,   200] validation_loss: 2.075
3 epoch, training accuracy: 0.4332
3 epoch, validation accuracy: 0.3303
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[4,   200] train_loss: 1.519
[4,   200] validation_loss: 1.696
4 epoch, training accuracy: 0.4744
4 epoch, validation accuracy: 0.4143
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[5,   200] train_loss: 1.417
[5,   200] validation_loss: 1.541
5 epoch, training accuracy: 0.5182
5 epoch, validation accuracy: 0.4889
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[6,   200] train_loss: 1.291
[6,   200] validation_loss: 1.377
6 epoch, training accuracy: 0.5615
6 epoch, validation accuracy: 0.5327
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[7,   200] train_loss: 1.244
[7,   200] validation_loss: 1.293
7 epoch, training accuracy: 0.5797
7 epoch, validation accuracy: 0.5633
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[8,   200] train_loss: 1.188
[8,   200] validation_loss: 1.239
8 epoch, training accuracy: 0.5957
8 epoch, validation accuracy: 0.5819
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[9,   200] train_loss: 1.074
[9,   200] validation_loss: 1.374
9 epoch, training accuracy: 0.6373
9 epoch, validation accuracy: 0.5580
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[10,   200] train_loss: 1.013
[10,   200] validation_loss: 1.187
10 epoch, training accuracy: 0.6623
10 epoch, validation accuracy: 0.6140
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[11,   200] train_loss: 0.975
[11,   200] validation_loss: 1.315
11 epoch, training accuracy: 0.6677
11 epoch, validation accuracy: 0.5761
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[12,   200] train_loss: 0.908
[12,   200] validation_loss: 1.289
12 epoch, training accuracy: 0.6926
12 epoch, validation accuracy: 0.5816
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[13,   200] train_loss: 0.871
[13,   200] validation_loss: 1.152
13 epoch, training accuracy: 0.7068
13 epoch, validation accuracy: 0.6318
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[14,   200] train_loss: 0.853
[14,   200] validation_loss: 1.251
14 epoch, training accuracy: 0.7143
14 epoch, validation accuracy: 0.5907
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[15,   200] train_loss: 0.794
[15,   200] validation_loss: 1.425
15 epoch, training accuracy: 0.7301
15 epoch, validation accuracy: 0.5630
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[16,   200] train_loss: 0.774
[16,   200] validation_loss: 1.214
16 epoch, training accuracy: 0.7397
16 epoch, validation accuracy: 0.6265
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[17,   200] train_loss: 0.727
[17,   200] validation_loss: 1.072
17 epoch, training accuracy: 0.7558
17 epoch, validation accuracy: 0.6493
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[18,   200] train_loss: 0.711
[18,   200] validation_loss: 1.045
18 epoch, training accuracy: 0.7601
18 epoch, validation accuracy: 0.6714
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[19,   200] train_loss: 0.680
[19,   200] validation_loss: 1.125
19 epoch, training accuracy: 0.7676
19 epoch, validation accuracy: 0.6466
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[20,   200] train_loss: 0.660
[20,   200] validation_loss: 1.094
20 epoch, training accuracy: 0.7774
20 epoch, validation accuracy: 0.6647
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[21,   200] train_loss: 0.608
[21,   200] validation_loss: 1.226
21 epoch, training accuracy: 0.7981
21 epoch, validation accuracy: 0.6198
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[22,   200] train_loss: 0.582
[22,   200] validation_loss: 1.084
22 epoch, training accuracy: 0.8037
22 epoch, validation accuracy: 0.6659
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[23,   200] train_loss: 0.580
[23,   200] validation_loss: 1.129
23 epoch, training accuracy: 0.8060
23 epoch, validation accuracy: 0.6569
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[24,   200] train_loss: 0.565
[24,   200] validation_loss: 1.066
24 epoch, training accuracy: 0.8116
24 epoch, validation accuracy: 0.6851
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[25,   200] train_loss: 0.526
[25,   200] validation_loss: 1.081
25 epoch, training accuracy: 0.8239
25 epoch, validation accuracy: 0.6776
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[26,   200] train_loss: 0.512
[26,   200] validation_loss: 1.161
26 epoch, training accuracy: 0.8257
26 epoch, validation accuracy: 0.6618
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[27,   200] train_loss: 0.497
[27,   200] validation_loss: 1.036
27 epoch, training accuracy: 0.8317
27 epoch, validation accuracy: 0.6898
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[28,   200] train_loss: 0.460
[28,   200] validation_loss: 1.148
28 epoch, training accuracy: 0.8417
28 epoch, validation accuracy: 0.6714
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[29,   200] train_loss: 0.452
[29,   200] validation_loss: 1.108
29 epoch, training accuracy: 0.8455
29 epoch, validation accuracy: 0.6758
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[30,   200] train_loss: 0.460
[30,   200] validation_loss: 1.187
30 epoch, training accuracy: 0.8439
30 epoch, validation accuracy: 0.6554
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[31,   200] train_loss: 0.413
[31,   200] validation_loss: 1.115
31 epoch, training accuracy: 0.8581
31 epoch, validation accuracy: 0.6810
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[32,   200] train_loss: 0.398
[32,   200] validation_loss: 1.160
32 epoch, training accuracy: 0.8631
32 epoch, validation accuracy: 0.6726
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[33,   200] train_loss: 0.422
[33,   200] validation_loss: 1.143
33 epoch, training accuracy: 0.8587
33 epoch, validation accuracy: 0.6650
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[34,   200] train_loss: 0.375
[34,   200] validation_loss: 1.061
34 epoch, training accuracy: 0.8726
34 epoch, validation accuracy: 0.7000
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[35,   200] train_loss: 0.353
[35,   200] validation_loss: 1.183
35 epoch, training accuracy: 0.8803
35 epoch, validation accuracy: 0.6700
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[36,   200] train_loss: 0.349
[36,   200] validation_loss: 1.138
36 epoch, training accuracy: 0.8819
36 epoch, validation accuracy: 0.6889
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[37,   200] train_loss: 0.355
[37,   200] validation_loss: 1.125
37 epoch, training accuracy: 0.8771
37 epoch, validation accuracy: 0.6895
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[38,   200] train_loss: 0.339
[38,   200] validation_loss: 1.278
38 epoch, training accuracy: 0.8826
38 epoch, validation accuracy: 0.6665
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[39,   200] train_loss: 0.311
[39,   200] validation_loss: 1.366
39 epoch, training accuracy: 0.8991
39 epoch, validation accuracy: 0.6481
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[40,   200] train_loss: 0.331
[40,   200] validation_loss: 1.251
40 epoch, training accuracy: 0.8893
40 epoch, validation accuracy: 0.6743
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[41,   200] train_loss: 0.296
[41,   200] validation_loss: 1.264
41 epoch, training accuracy: 0.8969
41 epoch, validation accuracy: 0.6741
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[42,   200] train_loss: 0.303
[42,   200] validation_loss: 1.098
42 epoch, training accuracy: 0.8977
42 epoch, validation accuracy: 0.7093
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[43,   200] train_loss: 0.293
[43,   200] validation_loss: 1.264
43 epoch, training accuracy: 0.9033
43 epoch, validation accuracy: 0.6810
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[44,   200] train_loss: 0.269
[44,   200] validation_loss: 1.226
44 epoch, training accuracy: 0.9098
44 epoch, validation accuracy: 0.6895
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[45,   200] train_loss: 0.267
[45,   200] validation_loss: 1.364
45 epoch, training accuracy: 0.9081
45 epoch, validation accuracy: 0.6729
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[46,   200] train_loss: 0.256
[46,   200] validation_loss: 1.195
46 epoch, training accuracy: 0.9130
46 epoch, validation accuracy: 0.6828
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[47,   200] train_loss: 0.246
[47,   200] validation_loss: 1.124
47 epoch, training accuracy: 0.9160
47 epoch, validation accuracy: 0.7152
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[48,   200] train_loss: 0.234
[48,   200] validation_loss: 1.232
48 epoch, training accuracy: 0.9206
48 epoch, validation accuracy: 0.7038
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[49,   200] train_loss: 0.242
[49,   200] validation_loss: 1.198
49 epoch, training accuracy: 0.9191
49 epoch, validation accuracy: 0.7131
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[50,   200] train_loss: 0.231
[50,   200] validation_loss: 1.290
50 epoch, training accuracy: 0.9208
50 epoch, validation accuracy: 0.6878
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[51,   200] train_loss: 0.234
[51,   200] validation_loss: 1.372
51 epoch, training accuracy: 0.9187
51 epoch, validation accuracy: 0.6764
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[52,   200] train_loss: 0.215
[52,   200] validation_loss: 1.427
52 epoch, training accuracy: 0.9285
52 epoch, validation accuracy: 0.6694
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[53,   200] train_loss: 0.233
[53,   200] validation_loss: 1.226
53 epoch, training accuracy: 0.9226
53 epoch, validation accuracy: 0.7015
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[54,   200] train_loss: 0.200
[54,   200] validation_loss: 1.267
54 epoch, training accuracy: 0.9319
54 epoch, validation accuracy: 0.7023
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[55,   200] train_loss: 0.208
[55,   200] validation_loss: 1.297
55 epoch, training accuracy: 0.9303
55 epoch, validation accuracy: 0.7000
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[56,   200] train_loss: 0.215
[56,   200] validation_loss: 1.349
56 epoch, training accuracy: 0.9261
56 epoch, validation accuracy: 0.6950
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[57,   200] train_loss: 0.174
[57,   200] validation_loss: 1.333
57 epoch, training accuracy: 0.9428
57 epoch, validation accuracy: 0.6837
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[58,   200] train_loss: 0.194
[58,   200] validation_loss: 1.227
58 epoch, training accuracy: 0.9345
58 epoch, validation accuracy: 0.7210
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[59,   200] train_loss: 0.179
[59,   200] validation_loss: 1.397
59 epoch, training accuracy: 0.9362
59 epoch, validation accuracy: 0.6913
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[60,   200] train_loss: 0.181
[60,   200] validation_loss: 1.365
60 epoch, training accuracy: 0.9385
60 epoch, validation accuracy: 0.6913
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[61,   200] train_loss: 0.178
[61,   200] validation_loss: 1.288
61 epoch, training accuracy: 0.9401
61 epoch, validation accuracy: 0.7111
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[62,   200] train_loss: 0.164
[62,   200] validation_loss: 1.353
62 epoch, training accuracy: 0.9444
62 epoch, validation accuracy: 0.6936
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[63,   200] train_loss: 0.167
[63,   200] validation_loss: 1.312
63 epoch, training accuracy: 0.9416
63 epoch, validation accuracy: 0.6971
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[64,   200] train_loss: 0.165
[64,   200] validation_loss: 1.317
64 epoch, training accuracy: 0.9469
64 epoch, validation accuracy: 0.7009
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[65,   200] train_loss: 0.162
[65,   200] validation_loss: 1.320
65 epoch, training accuracy: 0.9468
65 epoch, validation accuracy: 0.7000
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[66,   200] train_loss: 0.169
[66,   200] validation_loss: 1.392
66 epoch, training accuracy: 0.9423
66 epoch, validation accuracy: 0.6883
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[67,   200] train_loss: 0.153
[67,   200] validation_loss: 1.383
67 epoch, training accuracy: 0.9506
67 epoch, validation accuracy: 0.6878
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[68,   200] train_loss: 0.154
[68,   200] validation_loss: 1.296
68 epoch, training accuracy: 0.9524
68 epoch, validation accuracy: 0.7120
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[69,   200] train_loss: 0.149
[69,   200] validation_loss: 1.491
69 epoch, training accuracy: 0.9493
69 epoch, validation accuracy: 0.6892
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[70,   200] train_loss: 0.157
[70,   200] validation_loss: 1.349
70 epoch, training accuracy: 0.9463
70 epoch, validation accuracy: 0.6918
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[71,   200] train_loss: 0.133
[71,   200] validation_loss: 1.441
71 epoch, training accuracy: 0.9555
71 epoch, validation accuracy: 0.6866
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[72,   200] train_loss: 0.145
[72,   200] validation_loss: 1.900
72 epoch, training accuracy: 0.9499
72 epoch, validation accuracy: 0.6385
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[73,   200] train_loss: 0.146
[73,   200] validation_loss: 1.345
73 epoch, training accuracy: 0.9493
73 epoch, validation accuracy: 0.7082
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[74,   200] train_loss: 0.140
[74,   200] validation_loss: 1.376
74 epoch, training accuracy: 0.9528
74 epoch, validation accuracy: 0.7017
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[75,   200] train_loss: 0.126
[75,   200] validation_loss: 1.456
75 epoch, training accuracy: 0.9558
75 epoch, validation accuracy: 0.6988
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[76,   200] train_loss: 0.129
[76,   200] validation_loss: 1.446
76 epoch, training accuracy: 0.9563
76 epoch, validation accuracy: 0.6939
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[77,   200] train_loss: 0.131
[77,   200] validation_loss: 1.388
77 epoch, training accuracy: 0.9560
77 epoch, validation accuracy: 0.7038
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[78,   200] train_loss: 0.127
[78,   200] validation_loss: 1.365
78 epoch, training accuracy: 0.9572
78 epoch, validation accuracy: 0.7140
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[79,   200] train_loss: 0.122
[79,   200] validation_loss: 1.337
79 epoch, training accuracy: 0.9568
79 epoch, validation accuracy: 0.7120
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[80,   200] train_loss: 0.115
[80,   200] validation_loss: 1.331
80 epoch, training accuracy: 0.9623
80 epoch, validation accuracy: 0.7155
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[81,   200] train_loss: 0.111
[81,   200] validation_loss: 1.370
81 epoch, training accuracy: 0.9612
81 epoch, validation accuracy: 0.7169
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[82,   200] train_loss: 0.117
[82,   200] validation_loss: 1.409
82 epoch, training accuracy: 0.9599
82 epoch, validation accuracy: 0.6939
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[83,   200] train_loss: 0.125
[83,   200] validation_loss: 1.404
83 epoch, training accuracy: 0.9591
83 epoch, validation accuracy: 0.7058
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[84,   200] train_loss: 0.119
[84,   200] validation_loss: 1.622
84 epoch, training accuracy: 0.9605
84 epoch, validation accuracy: 0.6810
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[85,   200] train_loss: 0.113
[85,   200] validation_loss: 1.367
85 epoch, training accuracy: 0.9631
85 epoch, validation accuracy: 0.7181
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[86,   200] train_loss: 0.116
[86,   200] validation_loss: 1.452
86 epoch, training accuracy: 0.9622
86 epoch, validation accuracy: 0.7026
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[87,   200] train_loss: 0.116
[87,   200] validation_loss: 1.366
87 epoch, training accuracy: 0.9596
87 epoch, validation accuracy: 0.7079
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[88,   200] train_loss: 0.112
[88,   200] validation_loss: 1.395
88 epoch, training accuracy: 0.9620
88 epoch, validation accuracy: 0.7096
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[89,   200] train_loss: 0.098
[89,   200] validation_loss: 1.419
89 epoch, training accuracy: 0.9656
89 epoch, validation accuracy: 0.7082
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[90,   200] train_loss: 0.100
[90,   200] validation_loss: 1.537
90 epoch, training accuracy: 0.9651
90 epoch, validation accuracy: 0.6875
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[91,   200] train_loss: 0.109
[91,   200] validation_loss: 1.435
91 epoch, training accuracy: 0.9640
91 epoch, validation accuracy: 0.7003
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[92,   200] train_loss: 0.096
[92,   200] validation_loss: 1.379
92 epoch, training accuracy: 0.9681
92 epoch, validation accuracy: 0.7187
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[93,   200] train_loss: 0.100
[93,   200] validation_loss: 1.486
93 epoch, training accuracy: 0.9660
93 epoch, validation accuracy: 0.7076
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[94,   200] train_loss: 0.114
[94,   200] validation_loss: 1.405
94 epoch, training accuracy: 0.9634
94 epoch, validation accuracy: 0.7157
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[95,   200] train_loss: 0.097
[95,   200] validation_loss: 1.345
95 epoch, training accuracy: 0.9662
95 epoch, validation accuracy: 0.7187
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[96,   200] train_loss: 0.095
[96,   200] validation_loss: 1.592
96 epoch, training accuracy: 0.9686
96 epoch, validation accuracy: 0.6907
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[97,   200] train_loss: 0.098
[97,   200] validation_loss: 1.413
97 epoch, training accuracy: 0.9654
97 epoch, validation accuracy: 0.7157
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[98,   200] train_loss: 0.081
[98,   200] validation_loss: 1.514
98 epoch, training accuracy: 0.9726
98 epoch, validation accuracy: 0.7061
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[99,   200] train_loss: 0.094
[99,   200] validation_loss: 1.510
99 epoch, training accuracy: 0.9690
99 epoch, validation accuracy: 0.7102
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[100,   200] train_loss: 0.100
[100,   200] validation_loss: 1.594
100 epoch, training accuracy: 0.9653
100 epoch, validation accuracy: 0.6848
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[101,   200] train_loss: 0.094
[101,   200] validation_loss: 1.457
101 epoch, training accuracy: 0.9686
101 epoch, validation accuracy: 0.7082
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[102,   200] train_loss: 0.093
[102,   200] validation_loss: 1.417
102 epoch, training accuracy: 0.9684
102 epoch, validation accuracy: 0.7070
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[103,   200] train_loss: 0.097
[103,   200] validation_loss: 1.475
103 epoch, training accuracy: 0.9690
103 epoch, validation accuracy: 0.6948
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[104,   200] train_loss: 0.075
[104,   200] validation_loss: 1.480
104 epoch, training accuracy: 0.9755
104 epoch, validation accuracy: 0.7096
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[105,   200] train_loss: 0.088
[105,   200] validation_loss: 1.540
105 epoch, training accuracy: 0.9695
105 epoch, validation accuracy: 0.6988
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[106,   200] train_loss: 0.088
[106,   200] validation_loss: 1.463
106 epoch, training accuracy: 0.9723
106 epoch, validation accuracy: 0.7204
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[107,   200] train_loss: 0.066
[107,   200] validation_loss: 1.613
107 epoch, training accuracy: 0.9770
107 epoch, validation accuracy: 0.7070
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[108,   200] train_loss: 0.096
[108,   200] validation_loss: 1.611
108 epoch, training accuracy: 0.9674
108 epoch, validation accuracy: 0.6831
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


[109,   200] train_loss: 0.087
[109,   200] validation_loss: 1.468
109 epoch, training accuracy: 0.9699
109 epoch, validation accuracy: 0.7108
-----------------------------------------


  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)
  warn_deprecated(msg, stacklevel=3)


In [None]:
def reload_net():
    trainednet = models.resnet18()  
    #trainednet.fc = torch.nn.Linear(1000, 11)  
    trainednet.fc = nn.Sequential(nn.Linear(512,256),nn.LeakyReLU(),nn.Linear(256,128),nn.LeakyReLU(),nn.Linear(128,11))
    trainednet.load_state_dict(torch.load('./resnet18_no_augmentation/38.pth'))
    return trainednet

In [None]:
pretrain_net = reload_net().to(device)
test_loss = 0.0
correct_test = 0
test_num = 0
cls = np.zeros(11)
correct_top3 = 0

In [None]:
pretrain_net.eval()
for i, (inputs, labels) in enumerate(test_loader, 0):
    # move tensors to GPU if CUDA is available
    inputs = inputs.to(device) 
    labels = labels.to(device)
    # forward pass: compute predicted outputs by passing inputs to the model
    outputs = pretrain_net(inputs)
    _, pred = outputs.max(1)
    correct_test += pred.eq(labels).sum().item()
    _, top3 = outputs.topk(3)
    correct_top3 += top3.eq(labels.view(-1,1).expand_as(top3)).sum().item()
    
    cls[0] += (pred.eq(0) * pred.eq(labels)).sum().item()
    cls[1] += (pred.eq(1) * pred.eq(labels)).sum().item()
    cls[2] += (pred.eq(2) * pred.eq(labels)).sum().item()
    cls[3] += (pred.eq(3) * pred.eq(labels)).sum().item()
    cls[4] += (pred.eq(4) * pred.eq(labels)).sum().item()
    cls[5] += (pred.eq(5) * pred.eq(labels)).sum().item()
    cls[6] += (pred.eq(6) * pred.eq(labels)).sum().item()
    cls[7] += (pred.eq(7) * pred.eq(labels)).sum().item()
    cls[8] += (pred.eq(8) * pred.eq(labels)).sum().item()
    cls[9] += (pred.eq(9) * pred.eq(labels)).sum().item()
    cls[10] += (pred.eq(10) * pred.eq(labels)).sum().item()

In [None]:
print('Test set: Top 1 Accuracy: %d/3347 (%.2f%%), Top 3 Accuracy: %d/3347 (%.2f%%)' 
      % (correct_test, correct_test / len(test_dataset)*100, correct_top3, correct_top3/ len(test_dataset)*100))
print('%-20s : %d/%d    %10f%%' % (code2names[0], cls[0], 368, cls[0]/368*100))
print('%-20s : %d/%d    %10f%%' % (code2names[1], cls[1], 148, cls[1]/148*100))
print('%-20s : %d/%d    %10f%%' % (code2names[2], cls[2], 500, cls[2]/500*100))
print('%-20s : %d/%d    %10f%%' % (code2names[3], cls[3], 335, cls[3]/335*100))
print('%-20s : %d/%d    %10f%%' % (code2names[4], cls[4], 287, cls[4]/287*100))
print('%-20s : %d/%d    %10f%%' % (code2names[5], cls[5], 432, cls[5]/432*100))
print('%-20s : %d/%d    %10f%%' % (code2names[6], cls[6], 147, cls[6]/147*100))
print('%-20s : %d/%d    %10f%%' % (code2names[7], cls[7], 96, cls[7]/96*100))
print('%-20s : %d/%d    %10f%%' % (code2names[8], cls[8], 303, cls[8]/303*100))
print('%-20s : %d/%d    %10f%%' % (code2names[9], cls[9], 500, cls[9]/500*100))
print('%-20s : %d/%d    %10f%%' % (code2names[10], cls[10], 231, cls[10]/231*100))

avg = []
avg.append(cls[0]/368*100)
avg.append(cls[1]/148*100)
avg.append(cls[2]/500*100)
avg.append(cls[3]/335*100)
avg.append(cls[4]/287*100)
avg.append(cls[5]/432*100)
avg.append(cls[6]/147*100)
avg.append(cls[7]/96*100)
avg.append(cls[8]/303*100)
avg.append(cls[9]/500*100)
avg.append(cls[10]/231*100)
print('Average per case accuracy: %10f%%' % (sum(avg)/len(avg)))
print('-----------------------------------------')