In [0]:
# Mount Google Drive at /content/gdrive; the dataset archive and the model
# checkpoints used by this notebook are read from / written to paths under it.
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [0]:
# import os
# os.environ['KAGGLE_CONFIG_DIR'] = "/content/gdrive/My Drive/kaggle"

In [0]:
# %cd /content/gdrive/My Drive/kaggle
# %pwd

In [0]:
# ! kaggle datasets download -d paramaggarwal/fashion-product-images-small

In [0]:
# %cd ~

In [0]:
!unzip "/content/gdrive/My Drive/kaggle/fashion-product-images-small.zip"

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: myntradataset/images/58131.jpg  
  inflating: myntradataset/images/58132.jpg  
  inflating: myntradataset/images/58133.jpg  
  inflating: myntradataset/images/58135.jpg  
  inflating: myntradataset/images/58136.jpg  
  inflating: myntradataset/images/58137.jpg  
  inflating: myntradataset/images/58138.jpg  
  inflating: myntradataset/images/58139.jpg  
  inflating: myntradataset/images/5814.jpg  
  inflating: myntradataset/images/58140.jpg  
  inflating: myntradataset/images/58141.jpg  
  inflating: myntradataset/images/58143.jpg  
  inflating: myntradataset/images/58144.jpg  
  inflating: myntradataset/images/58145.jpg  
  inflating: myntradataset/images/58146.jpg  
  inflating: myntradataset/images/58147.jpg  
  inflating: myntradataset/images/58148.jpg  
  inflating: myntradataset/images/58149.jpg  
  inflating: myntradataset/images/5815.jpg  
  inflating: myntradataset/images/58150.jpg  
  inflating: mynt

In [0]:
# List the working directory to see what the archive extracted.
%ls

[0m[01;34mgdrive[0m/  [01;34mimages[0m/  [01;34mmyntradataset[0m/  [01;34msample_data[0m/  styles.csv


In [0]:
# Remove the top-level images/ folder and styles.csv; every path used further
# down points into myntradataset/ instead.
%rm -r images
%rm styles.csv

In [0]:
# Re-list the working directory to confirm the removal.
%ls

[0m[01;34mgdrive[0m/  [01;34mmyntradataset[0m/  [01;34msample_data[0m/


### Clean and filter metadata

In [0]:
import numpy as np
import pandas as pd
import os

# Root folder of the extracted dataset; every other path is derived from it
base_path = "myntradataset/"

# original metadata, repaired metadata, image folder (in that order)
old_csv_path, csv_path, img_path = (
    os.path.join(base_path, leaf)
    for leaf in ("styles.csv", "styles_fixed.csv", "images")
)

In [0]:
# Repair rows of styles.csv that were broken by unquoted commas in the product name.
#
# productDisplayName is the last column, so every field beyond the header's
# column count is a fragment of a display name that was split on its own commas.
# The fragments are re-joined and written back as one (quoted) field, while all
# columns in front of it -- including the one right before the name -- are kept.

import csv

# newline='' is what the csv module asks for on both the reading and the writing side
with open(old_csv_path, newline='') as rf, open(csv_path, 'w', newline='') as wf:
    csv_reader = csv.reader(rf, delimiter=',')
    csv_writer = csv.writer(wf, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)

    header = next(csv_reader, None)
    if header is not None:
        n_cols = len(header)
        csv_writer.writerow(header)

        for row in csv_reader:
            if len(row) > n_cols:
                # Keep the first n_cols-1 fields untouched and glue the rest back together
                save_row = row[:n_cols - 1]
                save_row.append(','.join(row[n_cols - 1:]))
            else:
                save_row = row
            csv_writer.writerow(save_row)

In [0]:
# %cp myntradataset/styles_fixed.csv "/content/gdrive/My Drive/kaggle/fashion-product-images-small/styles_fixed.csv"

In [0]:
# Load the repaired metadata file into the working DataFrame.
styles = pd.read_csv(csv_path)

In [0]:
# Image file name of each product: "<id>.jpg"
styles['image'] = styles['id'].astype(str) + ".jpg"
print(styles['image'])

0        15970.jpg
1        39386.jpg
2        59263.jpg
3        21379.jpg
4        53759.jpg
           ...    
44441    17036.jpg
44442     6461.jpg
44443    18842.jpg
44444    46694.jpg
44445    51623.jpg
Name: image, Length: 44446, dtype: object


In [0]:
# Report every metadata row whose image file is absent from the image folder
missing_images = [
    file_name for file_name in styles['image']
    if not os.path.exists(os.path.join(img_path, file_name))
]
for file_name in missing_images:
    print("Image {} doesn't exist!".format(file_name))

Image 39403.jpg doesn't exist!
Image 39410.jpg doesn't exist!
Image 39401.jpg doesn't exist!
Image 39425.jpg doesn't exist!
Image 12347.jpg doesn't exist!


In [0]:
# Filter out rows for which images don't exist
def _image_on_disk(file_name):
    """True when `file_name` is present in the image folder."""
    return os.path.exists(os.path.join(img_path, file_name))

img_exists = styles['image'].map(_image_on_disk)

styles = styles.loc[img_exists]
print(styles)

          id gender  ...                               productDisplayName      image
0      15970    Men  ...                 Turtle Check Men Navy Blue Shirt  15970.jpg
1      39386    Men  ...               Peter England Men Party Blue Jeans  39386.jpg
2      59263  Women  ...                         Titan Women Silver Watch  59263.jpg
3      21379    Men  ...    Manchester United Men Solid Black Track Pants  21379.jpg
4      53759    Men  ...                            Puma Men Grey T-shirt  53759.jpg
...      ...    ...  ...                                              ...        ...
44441  17036    Men  ...                        Gas Men Caddy Casual Shoe  17036.jpg
44442   6461    Men  ...               Lotto Men's Soccer Track Flip Flop   6461.jpg
44443  18842    Men  ...             Puma Men Graphic Stellar Blue Tshirt  18842.jpg
44444  46694  Women  ...                   Rasasi Women Blue Lady Perfume  46694.jpg
44445  51623  Women  ...  Fossil Women Pink Dial Chronograph Watc

### Create train-test splits

In [0]:
# Even-numbered years go to the train split, everything else to the test split

is_train = styles['year'].mod(2).eq(0)
is_test = ~is_train

full_train, full_test = styles[is_train], styles[is_test]

print(full_train.shape)
print(full_test.shape)

(23787, 11)
(20654, 11)


### Sub-split training data for pre-training and fine-tuning

In [0]:
# Find the 20 most frequent articleType classes.
# NOTE(review): the counts are taken over `styles` (train and test rows together),
# not over `full_train` alone -- confirm that ranking classes with test rows is intended.

top_articleType = styles.groupby('articleType').size().sort_values(ascending=False).head(20).reset_index()

print("Top 20 classes:")
print(top_articleType)

Top 20 classes:
              articleType     0
0                 Tshirts  7069
1                  Shirts  3215
2            Casual Shoes  2846
3                 Watches  2542
4            Sports Shoes  2036
5                  Kurtas  1844
6                    Tops  1762
7                Handbags  1759
8                   Heels  1323
9              Sunglasses  1073
10                Wallets   936
11             Flip Flops   916
12                Sandals   897
13                 Briefs   849
14                  Belts   813
15              Backpacks   724
16                  Socks   686
17           Formal Shoes   637
18  Perfume and Body Mist   614
19                  Jeans   608


In [0]:
# Every class that is not one of the top-20 belongs to the fine-tuning set.
# The remainder is taken as the complement of `top_articleType` instead of as a
# second, independent "N smallest" ranking, so a tie in counts at the cut-off can
# never put a class into both sets or into neither of them.

total_classes = len(styles['articleType'].unique())

article_counts = styles.groupby('articleType').size().sort_values(ascending=True)
in_top20 = article_counts.index.isin(top_articleType['articleType'])

# Same layout as before: one row per class, ordered from rarest to most frequent
rest_articleType = article_counts[~in_top20].reset_index()
rest_classes = len(rest_articleType)

print("Remaining {} classes:".format(rest_classes))
print(rest_articleType)

Remaining 122 classes:
             articleType    0
0                   Ipad    1
1         Hair Accessory    1
2         Cushion Covers    1
3      Mens Grooming Kit    1
4    Body Wash and Scrub    1
..                   ...  ...
117              Dresses  464
118                  Bra  477
119                Flats  500
120             Trousers  530
121               Shorts  547

[122 rows x 2 columns]


In [0]:
# Convert the articleType column of both class tables to the categorical dtype

for class_table in (top_articleType, rest_articleType):
    class_table['articleType'] = class_table['articleType'].astype('category')

In [0]:
# class name -> integer category code, one mapping per class table
classmap_top20 = {
    class_name: code
    for class_name, code in zip(top_articleType['articleType'],
                                top_articleType['articleType'].cat.codes)
}
classmap_ft = {
    class_name: code
    for class_name, code in zip(rest_articleType['articleType'],
                                rest_articleType['articleType'].cat.codes)
}

In [0]:
# Number of class names shared by the top-20 map and the fine-tuning map

shared_class_names = set(classmap_ft) & set(classmap_top20)
print(len(shared_class_names))

0


In [0]:
# Rows of the top-20 classes make up the pre-training set and its matching test set

top20_names = top_articleType['articleType']

filter_topArticles = full_train['articleType'].isin(top20_names)
filter_topArticles_test = full_test['articleType'].isin(top20_names)

train_top20_data = full_train.loc[filter_topArticles]
test_top20_data = full_test.loc[filter_topArticles_test]

for top20_split in (train_top20_data, test_top20_data):
    print(top20_split.shape)

(18000, 11)
(15149, 11)


In [0]:
# All remaining rows make up the fine-tuning set and its matching test set

train_ft_data, test_ft_data = (
    full_train[~filter_topArticles],
    full_test[~filter_topArticles_test],
)

for ft_split in (train_ft_data, test_ft_data):
    print(ft_split.shape)

(5787, 11)
(5505, 11)


In [0]:
# Distinct classes present in each training subset and in the whole train split
for train_frame in (train_top20_data, train_ft_data, full_train):
    print(len(train_frame['articleType'].unique()))

19
88
107


In [0]:
# Distinct classes present in each test subset and in the whole test split
for test_frame in (test_top20_data, test_ft_data, full_test):
    print(len(test_frame['articleType'].unique()))

20
102
122


In [0]:
# Class overlap between the top-20 data and the fine-tune data, train first, then test (should be 0)

for top20_frame, ft_frame in ((train_top20_data, train_ft_data),
                              (test_top20_data, test_ft_data)):
    top20_seen = set(top20_frame['articleType'].unique())
    ft_seen = set(ft_frame['articleType'].unique())
    print(len(top20_seen & ft_seen))

0
0


In [0]:
from sklearn.model_selection import train_test_split

# Fixed seed so the train/validation partition is the same on every run.
# The checkpoints written by train_model are reloaded in later sessions; with an
# unseeded split the reloaded model would be validated on rows it was trained on.
SPLIT_SEED = 42

# NOTE: this cell rebinds train_top20_data / train_ft_data to their 80% training
# part, so re-running it without re-running the cells that build those frames
# splits the already reduced training sets again.
train_top20_data, val_top20_data = train_test_split(train_top20_data, test_size=0.2,
                                                    random_state=SPLIT_SEED)
train_ft_data, val_ft_data = train_test_split(train_ft_data, test_size=0.2,
                                              random_state=SPLIT_SEED)

In [0]:
# (train, val) sizes: top-20 data first, fine-tune data second
for train_part, val_part in ((train_top20_data, val_top20_data),
                             (train_ft_data, val_ft_data)):
    print(len(train_part), len(val_part))

14400 3600
4629 1158


In [0]:
# Every split of both tasks under one flat name
data_map = dict(
    train_top20=train_top20_data,
    val_top20=val_top20_data,
    train_ft=train_ft_data,
    val_ft=val_ft_data,
    test_top20=test_top20_data,
    test_ft=test_ft_data,
)

# Per-task views keyed by phase: 'train' / 'val' / 'test'
data_top20_map = dict(train=train_top20_data, val=val_top20_data, test=test_top20_data)

data_ft_map = dict(train=train_ft_data, val=val_ft_data, test=test_ft_data)

### Load Dataset

In [0]:
import torch
import torchvision
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from PIL import Image

class FashionDatasetSmall(Dataset):
    """
    Custom dataset class that uses metadata stored in a DataFrame and returns
    the actual data: (x,y) = (img, label) -
        img is a Tensor representing the image
        label is the categorical label corresponding to the articleType
    """
    def __init__(self, df, img_path, class_map, data_transforms=None):
        """
        Args:
            df (DataFrame): pandas DataFrame containing the metadata; its
                'image' and 'articleType' columns are read
            img_path (string): path to the folder where images are
            class_map (dict): maps an articleType value to its label
            data_transforms: pytorch transforms transformations; when None the
                image is only converted with ToTensor()
        """
        self.image_arr = np.asarray(df['image'].values)
        self.label_arr = np.asarray(df['articleType'].values)
        self.to_tensor = transforms.ToTensor()
        self.img_path = img_path
        self.class_map = class_map
        self.data_transforms = data_transforms

    def __getitem__(self, index):
        """
        Load the image at position `index` and return (image tensor, label).

        Raises:
            IndexError: if `index` is outside the dataset.
            Exception: anything raised while opening/transforming the image or
                looking up the label is re-raised unchanged after the index and
                the file name have been printed.
        """
        # Looked up outside the try block: an out-of-range index has to surface
        # as a plain IndexError, and the handler below needs img_name to exist.
        img_name = self.image_arr[index]

        try:
            # The context manager closes the underlying file as soon as the
            # tensor has been produced.
            with Image.open(os.path.join(self.img_path, img_name)) as opened:
                img_as_img = opened
                if img_as_img.mode != 'RGB':
                    img_as_img = img_as_img.convert('RGB')

                if self.data_transforms is not None:
                    img_as_tensor = self.data_transforms(img_as_img)
                else:
                    img_as_tensor = self.to_tensor(img_as_img)

            label = self.class_map[self.label_arr[index]]
        except Exception:
            print("Exception while trying to fetch item at index",index)
            print("Image =",img_name)
            # bare raise keeps the original traceback
            raise

        return (img_as_tensor, label)

    def __len__(self):
        """Number of samples (one per metadata row)."""
        return len(self.image_arr)

In [0]:
# Training data gets augmentation + normalization,
# validation and test data get the deterministic resize/crop + normalization

# channel statistics passed to Normalize
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

train_transforms = [
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
]

test_transforms = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
]

val_transforms = test_transforms


def _steps_for(phase_name):
    """Pick the transform list for a phase name ('train...', 'val...', anything else -> test)."""
    if 'train' in phase_name:
        return train_transforms
    if 'val' in phase_name:
        return val_transforms
    return test_transforms


data_transforms = {key: transforms.Compose(_steps_for(key))
                   for key in data_top20_map.keys()}

datasets_top20 = {}
for x in data_top20_map.keys():
    datasets_top20[x] = FashionDatasetSmall(data_top20_map[x], img_path,
                                            classmap_top20, data_transforms[x])

for name, dataset in datasets_top20.items():
    print("Created {} dataset with {} samples".format(name, len(dataset)))

Created train dataset with 14400 samples
Created val dataset with 3600 samples
Created test dataset with 15149 samples


In [0]:
from torch.utils.data import DataLoader

# Only the training split is reshuffled every epoch; validation and test keep a
# fixed order. Without shuffling, every epoch would present the exact same
# sequence of mini-batches to the optimizer.
dataloaders_top20 = {x: DataLoader(datasets_top20[x], batch_size=64,
                             shuffle=(x == 'train'), num_workers=1)
              for x in data_top20_map.keys()}

In [0]:
# Optional smoke test: walk every top-20 loader once and print the size of
# every 10th batch. Switched off by default.
verify_datasets = False

if verify_datasets:
    for name, dataloader in dataloaders_top20.items():
        print("\nVerifying dataloader for {}".format(name))
        for i, item in enumerate(dataloader):
            if i % 10 != 0:
                continue
            print(item[0].size())

In [0]:
# Adapted from: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html

import time
import copy
from tqdm.notebook import tqdm

SAVE_PATH = '/content/gdrive/My Drive/kaggle/fashion-product-images-small/ckpts'
CKPT_PATH = os.path.join(SAVE_PATH, 'best_val_top20.ckpt')
CKPT_PATH_FT = os.path.join(SAVE_PATH, 'best_val_ft.ckpt')

def train_model(model, criterion, optimizer, data_loaders, scheduler=None,
                num_epochs=10, lock_weights=False, load_acc=0., CKPT_PATH=CKPT_PATH):
    """
    Train `model` for `num_epochs`, validating after every epoch, and return the
    model loaded with the weights of the epoch with the best validation accuracy.

    Args:
        model: network to train. It is trained on whatever device its parameters
            already live on, so move it there before calling.
        criterion: loss called as criterion(outputs, labels).
        optimizer: optimizer stepping the model's parameters.
        data_loaders (dict): loaders under the keys 'train' and 'val', each
            yielding (inputs, labels) batches.
        scheduler: optional LR scheduler, stepped once after every training phase.
        num_epochs (int): number of train+val passes.
        lock_weights (bool): freeze the network and train only its `fc` head.
        load_acc (float): validation accuracy of the checkpoint already on disk;
            a new checkpoint is written only when this run beats it.
        CKPT_PATH (str): file the checkpoint is written to.

    Returns:
        The same `model` object, holding the best weights seen in this call.
    """
    if lock_weights:
        for param in model.parameters():
            param.requires_grad = False
        # With every parameter frozen the loss has no graph and backward() raises,
        # so the classification head stays trainable.
        # NOTE(review): assumes the head is exposed as `model.fc`, as on the
        # torchvision ResNet used in this notebook.
        head = getattr(model, 'fc', None)
        if head is not None:
            for param in head.parameters():
                param.requires_grad = True

    # Follow the model's own device instead of depending on a notebook-level
    # `device` variable that is only defined in a later cell.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    # Create the checkpoint folder up front so a missing directory cannot make
    # torch.save fail after an epoch of training has already been spent.
    ckpt_dir = os.path.dirname(CKPT_PATH)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            current_loss = 0.0
            current_corrects = 0

            print('Iterating through data for phase: {}...'.format(phase))

            for inputs, labels in tqdm(data_loaders[phase]):
                inputs = inputs.to(run_device)
                labels = labels.to(run_device).long()

                optimizer.zero_grad()

                # Gradients are tracked in the training phase only
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # loss is a batch mean; weight it by the batch size so the epoch
                # figure is a per-sample average even with a smaller last batch
                current_loss += loss.item() * inputs.size(0)
                current_corrects += torch.sum(preds == labels).item()

            if phase == 'train' and scheduler is not None:
                scheduler.step()

            n_samples = len(data_loaders[phase].dataset)
            epoch_loss = current_loss / n_samples
            # plain Python float, so it can be stored and compared without a device
            epoch_acc = current_corrects / n_samples

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                # Check if the current best model is better than the loaded model
                if best_acc > load_acc:
                    print('--Saving checkpoint--')
                    # Save checkpoint file
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': best_model_wts,
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': epoch_loss,
                        'acc': epoch_acc
                    }, CKPT_PATH)

        print()

    time_since = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_since // 60, time_since % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    # Now we'll load in the best model weights and return it
    model.load_state_dict(best_model_wts)

    return model

In [0]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

### Training on top-20 classes

In [0]:
from torchvision import models
import torch.nn as nn
import torch
import torch.optim as optim

# Start from a ResNet-50 with torchvision's pretrained weights.
resnet = models.resnet50(pretrained=True)

In [0]:
# Swap the final fully connected layer for one with an output per top-20 class

num_features = resnet.fc.in_features
resnet.fc = nn.Linear(in_features=num_features,
                      out_features=len(classmap_top20))

In [0]:
# Show the replaced classification head.
print(resnet.fc)

Linear(in_features=2048, out_features=20, bias=True)


In [0]:
# Estimate class weights to handle imbalance

def get_class_weights(df, classmap, eps=1):
    """
    Returns class weights corresponding to each class using the formula:
        weights = n_samples / (n_classes * (np.bincount(y)+eps))
    'eps' is added to handle classes that don't have any samples

    Args:
        df (DataFrame): metadata whose 'articleType' column holds the class names
        classmap (dict): maps a class name to its integer label
        eps: value added to every class count

    Returns:
        numpy array with one weight per class; entry i belongs to label i. The
        array always covers every label in `classmap`, including labels that do
        not occur in `df`.
    """
    labels = [classmap[x] for x in df['articleType']]
    # Size the count vector by the largest label in the map: without minlength,
    # bincount stops at the largest label that actually occurs in df, and the
    # weight vector would be too short for the loss whenever the highest-numbered
    # classes have no samples.
    n_labels = max((int(code) for code in classmap.values()), default=-1) + 1
    labels_count = np.bincount(np.asarray(labels, dtype=np.int64), minlength=n_labels) + eps
    return len(labels) / (len(classmap) * labels_count)

# Per-class loss weights computed from the top-20 training split.
weights_top20 = get_class_weights(data_top20_map['train'], classmap_top20)
print(weights_top20)

[3.61809045e+00 1.79104478e+00 1.36105860e+00 5.76923077e-01
 1.78217822e+00 2.50000000e+00 9.00000000e-01 1.16129032e+00
 2.67657993e+00 7.88608981e-01 7.20000000e+02 1.87500000e+00
 8.03571429e-01 1.81818182e+00 1.01265823e+00 8.68516285e-01
 8.62275449e-01 3.25497288e-01 1.48453608e+00 3.60180090e-01]


In [0]:
from torch.optim import lr_scheduler

# The model has to be on its final device before the optimizer is created and
# restored: Optimizer.load_state_dict places the restored state next to the
# parameters as they are at that moment, and a later .to(device) moves the
# parameters but not that state.
resnet = resnet.to(device)

# Class-weighted loss to counter the class imbalance
criterion = nn.CrossEntropyLoss(weight=torch.Tensor(weights_top20).to(device))
# criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(resnet.parameters(), eps=1e-07)

load_flag = True

# Accuracy a new checkpoint has to beat; stays 0 when training starts from
# scratch, so the training call further down always finds the name defined.
load_acc = 0.

# Load previous checkpoint if it exists
if load_flag and os.path.exists(CKPT_PATH):
    # map_location lets a checkpoint written on a GPU be opened on a CPU-only machine
    checkpoint = torch.load(CKPT_PATH, map_location=device)
    resnet.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    load_acc = checkpoint['acc']
    print("Loaded previous checkpoint trained on {} epoch(s) with final loss={:.4f}, acc={:.4f}".format(epoch+1, loss, load_acc))

Loaded previous checkpoint trained on 10 epoch(s) with final loss=0.1642, acc=0.9369


In [0]:
# Move the model to the selected device before training.
resnet = resnet.to(device)

In [0]:
# Train on the top-20 classes. load_acc is the accuracy stored in the loaded
# checkpoint; train_model only writes a new checkpoint when validation accuracy beats it.
resnet = train_model(resnet, criterion, optimizer, dataloaders_top20,
                     num_epochs=10, load_acc=load_acc)

Epoch 0/9
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=225), HTML(value='')))


train Loss: 0.3188 Acc: 0.8748
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=57), HTML(value='')))


val Loss: 0.1869 Acc: 0.9297

Epoch 1/9
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=225), HTML(value='')))


train Loss: 0.3166 Acc: 0.8800
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=57), HTML(value='')))


val Loss: 0.1854 Acc: 0.9286

Epoch 2/9
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=225), HTML(value='')))


train Loss: 0.3124 Acc: 0.8809
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=57), HTML(value='')))


val Loss: 0.1785 Acc: 0.9294

Epoch 3/9
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=225), HTML(value='')))


train Loss: 0.3116 Acc: 0.8779
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=57), HTML(value='')))


val Loss: 0.1737 Acc: 0.9325

Epoch 4/9
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=225), HTML(value='')))


train Loss: 0.3078 Acc: 0.8788
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=57), HTML(value='')))


val Loss: 0.1844 Acc: 0.9292

Epoch 5/9
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=225), HTML(value='')))


train Loss: 0.3042 Acc: 0.8801
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=57), HTML(value='')))


val Loss: 0.1721 Acc: 0.9356
--Saving checkpoint--

Epoch 6/9
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=225), HTML(value='')))


train Loss: 0.2940 Acc: 0.8864
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=57), HTML(value='')))


val Loss: 0.1802 Acc: 0.9328

Epoch 7/9
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=225), HTML(value='')))


train Loss: 0.2827 Acc: 0.8913
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=57), HTML(value='')))


val Loss: 0.1829 Acc: 0.9336

Epoch 8/9
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=225), HTML(value='')))


train Loss: 0.2719 Acc: 0.8942
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=57), HTML(value='')))


val Loss: 0.1803 Acc: 0.9336

Epoch 9/9
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=225), HTML(value='')))


train Loss: 0.2640 Acc: 0.8924
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=57), HTML(value='')))


val Loss: 0.1642 Acc: 0.9369
--Saving checkpoint--

Training complete in 14m 39s
Best val Acc: 0.936944


In [0]:
# torch.cuda.empty_cache()

In [0]:
def get_accuracy(model, dataloader, topk=(1,5)):
    """
    Computes the accuracy@k for the specified values of k

    Args:
        model: network to evaluate, on the device its parameters already live on.
        dataloader: loader yielding (images, labels) batches.
        topk (tuple of int): the k values to report.

    Returns:
        list of floats, one accuracy per entry of `topk`, in the same order. A k
        larger than the number of classes counts every sample as correct.
    """
    total = len(dataloader.dataset)
    correct_count = {k: 0 for k in topk}

    # Follow the model's own device instead of depending on a notebook-level variable
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    # Evaluate in eval mode (fixed batch-norm statistics, no dropout) and put the
    # model back into whatever mode the caller left it in.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for images, labels in tqdm(dataloader):
                images = images.to(run_device)
                labels = labels.to(run_device).long()
                outputs = model(images)

                # torch.topk rejects a k larger than the class dimension
                maxk = min(max(topk), outputs.size(1))
                _, pred = torch.topk(outputs, maxk, 1)

                # (batch, maxk) hit matrix: column j is True where the
                # (j+1)-th ranked prediction equals the label
                correct = pred.eq(labels.reshape(-1, 1))
                for k in topk:
                    correct_count[k] += int(correct[:, :k].sum().item())
    finally:
        model.train(was_training)

    return [hits / total for hits in correct_count.values()]

# Top-1 / top-5 accuracy of the top-20 model on its test split
acc = get_accuracy(resnet, dataloaders_top20['test'])

print("Test accuracy: Top-1 = {}, Top-5 = {}".format(*acc[:2]))

HBox(children=(IntProgress(value=0, max=237), HTML(value='')))


Test accuracy: Top-1 = 0.8670539309525381, Top-5 = 0.9569608555020134


### Training on fine-tune set

In [0]:
# One dataset per phase for the fine-tuning classes, reusing the per-phase transforms
datasets_ft = {}
for x in data_ft_map.keys():
    datasets_ft[x] = FashionDatasetSmall(data_ft_map[x], img_path,
                                         classmap_ft, data_transforms[x])

for name, dataset in datasets_ft.items():
    print("Created {} dataset with {} samples".format(name, len(dataset)))

Created train dataset with 4629 samples
Created val dataset with 1158 samples
Created test dataset with 5505 samples


In [0]:
# Only the training split is reshuffled every epoch; validation and test keep a
# fixed order. Without shuffling, every epoch would present the exact same
# sequence of mini-batches to the optimizer.
dataloaders_ft = {x: DataLoader(datasets_ft[x], batch_size=64,
                             shuffle=(x == 'train'), num_workers=1)
              for x in data_ft_map.keys()}

In [0]:
# Smoke test: read every batch of every fine-tuning loader once, so an
# unreadable image shows up here instead of in the middle of training.
verify_datasets = True

if verify_datasets:
    for name, dataloader in dataloaders_ft.items():
        print("\nVerifying dataloader for {}".format(name))
        for item in tqdm(dataloader):
            continue


Verifying dataloader for train


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))



Verifying dataloader for val


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))



Verifying dataloader for test


HBox(children=(IntProgress(value=0, max=87), HTML(value='')))




In [0]:
# Modify the final fully connected layer for the fine-tuning classes
# NOTE: `resnet_ft = resnet` is an alias, not a copy -- assigning resnet_ft.fc
# below also replaces the head of `resnet`, so after this cell the top-20
# classifier head is no longer held by either name.

print(resnet.fc)

resnet_ft = resnet

num_features = resnet.fc.in_features
resnet_ft.fc = nn.Linear(num_features, len(classmap_ft))

print(resnet_ft.fc)

Linear(in_features=2048, out_features=20, bias=True)
Linear(in_features=2048, out_features=122, bias=True)


In [0]:
# Estimate class weights to handle imbalance in the fine-tuning training split;
# the printed shape shows how many per-class weights were produced.

weights_ft = get_class_weights(data_ft_map['train'], classmap_ft)
print(weights_ft.shape)

(122,)


In [0]:
# The model (including its freshly created fc layer) has to be on its final
# device before the optimizer is created and restored: Optimizer.load_state_dict
# places the restored state next to the parameters as they are at that moment.
resnet_ft = resnet_ft.to(device)

# Class-weighted loss to counter the class imbalance of the fine-tuning set
criterion_ft = nn.CrossEntropyLoss(weight=torch.Tensor(weights_ft).to(device))

optimizer_ft = optim.Adam(resnet_ft.parameters(), eps=1e-07)

load_flag = True

# Load previous checkpoint if it exists
if load_flag and os.path.exists(CKPT_PATH_FT):
    # map_location lets a checkpoint written on a GPU be opened on a CPU-only machine
    checkpoint = torch.load(CKPT_PATH_FT, map_location=device)
    resnet_ft.load_state_dict(checkpoint['model_state_dict'])
    # Restore into the fine-tuning optimizer -- the one that is handed to
    # train_model for this model -- not into the top-20 `optimizer`.
    optimizer_ft.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    load_acc = checkpoint['acc']
    print("Loaded previous checkpoint trained on {} epoch(s) with final loss={:.4f}, acc={:.4f}".format(epoch+1, loss, load_acc))
else:
    # Nothing on disk to beat yet
    load_acc = 0.

In [0]:
# Move the model, including the newly created fc layer, to the selected device.
resnet_ft = resnet_ft.to(device)

In [0]:
# Fine-tune on the remaining classes; checkpoints go to CKPT_PATH_FT and are
# only written when validation accuracy beats load_acc.
resnet_ft = train_model(resnet_ft, criterion_ft, optimizer_ft, dataloaders_ft,
                     num_epochs=10, load_acc=load_acc, CKPT_PATH=CKPT_PATH_FT)

Epoch 0/9
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 5.5425 Acc: 0.0261
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 4.5207 Acc: 0.0285
--Saving checkpoint--

Epoch 1/9
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 4.4414 Acc: 0.0497
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 4.4502 Acc: 0.0553
--Saving checkpoint--

Epoch 2/9
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 4.0786 Acc: 0.1145
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.9911 Acc: 0.1425
--Saving checkpoint--

Epoch 3/9
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 3.7616 Acc: 0.1426
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.7518 Acc: 0.1658
--Saving checkpoint--

Epoch 4/9
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 3.3971 Acc: 0.2022
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.3897 Acc: 0.3247
--Saving checkpoint--

Epoch 5/9
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 3.1150 Acc: 0.2361
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.2527 Acc: 0.3135

Epoch 6/9
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 2.8661 Acc: 0.2620
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.8413 Acc: 0.1459

Epoch 7/9
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 2.7655 Acc: 0.2817
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.2304 Acc: 0.3169

Epoch 8/9
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 2.5148 Acc: 0.3167
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.0834 Acc: 0.3282
--Saving checkpoint--

Epoch 9/9
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 2.3684 Acc: 0.3346
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.3293 Acc: 0.3221

Training complete in 4m 51s
Best val Acc: 0.328152


In [0]:
# Continue fine-tuning for 20 more epochs.
# train_model starts every call with best_acc = 0, so the only bar for
# overwriting the checkpoint is load_acc. Refresh it from the checkpoint file
# first: with the stale value from before the previous run, an epoch that is
# worse than the best one already saved could overwrite it.
if os.path.exists(CKPT_PATH_FT):
    load_acc = float(torch.load(CKPT_PATH_FT, map_location=device)['acc'])

resnet_ft = train_model(resnet_ft, criterion_ft, optimizer_ft, dataloaders_ft,
                     num_epochs=20, load_acc=load_acc, CKPT_PATH=CKPT_PATH_FT)

Epoch 0/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 2.4031 Acc: 0.3372
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.0859 Acc: 0.4413
--Saving checkpoint--

Epoch 1/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 2.2067 Acc: 0.3793
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.0317 Acc: 0.5225
--Saving checkpoint--

Epoch 2/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 2.0791 Acc: 0.3945
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.2450 Acc: 0.4136

Epoch 3/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 2.0164 Acc: 0.3975
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.0958 Acc: 0.4870

Epoch 4/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 1.8642 Acc: 0.4277
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.0512 Acc: 0.4758

Epoch 5/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))

Exception ignored in: <bound method _MultiProcessingDataLoaderIter.__del__ of <torch.utils.data.dataloader._MultiProcessingDataLoaderIter object at 0x7f91b00d42b0>>
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 961, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 941, in _shutdown_workers
    w.join()
  File "/usr/lib/python3.6/multiprocessing/process.py", line 122, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process



train Loss: 1.8255 Acc: 0.4372
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Exception ignored in: <bound method _MultiProcessingDataLoaderIter.__del__ of <torch.utils.data.dataloader._MultiProcessingDataLoaderIter object at 0x7f91b00d42b0>>
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 961, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 941, in _shutdown_workers
    w.join()
  File "/usr/lib/python3.6/multiprocessing/process.py", line 122, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process



val Loss: 4.1717 Acc: 0.2720

Epoch 6/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))

Exception ignored in: <bound method _MultiProcessingDataLoaderIter.__del__ of <torch.utils.data.dataloader._MultiProcessingDataLoaderIter object at 0x7f91b00d42b0>>
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 961, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 941, in _shutdown_workers
    w.join()
  File "/usr/lib/python3.6/multiprocessing/process.py", line 122, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process



train Loss: 1.7568 Acc: 0.4526
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Exception ignored in: <bound method _MultiProcessingDataLoaderIter.__del__ of <torch.utils.data.dataloader._MultiProcessingDataLoaderIter object at 0x7f91b00d42b0>>
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 961, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 941, in _shutdown_workers
    w.join()
  File "/usr/lib/python3.6/multiprocessing/process.py", line 122, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process



val Loss: 3.0542 Acc: 0.4801

Epoch 7/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))

Exception ignored in: <bound method _MultiProcessingDataLoaderIter.__del__ of <torch.utils.data.dataloader._MultiProcessingDataLoaderIter object at 0x7f91b00d42b0>>
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 961, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 941, in _shutdown_workers
    w.join()
  File "/usr/lib/python3.6/multiprocessing/process.py", line 122, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process



train Loss: 1.6680 Acc: 0.4774
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

Exception ignored in: <bound method _MultiProcessingDataLoaderIter.__del__ of <torch.utils.data.dataloader._MultiProcessingDataLoaderIter object at 0x7f91b00d42b0>>
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 961, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 941, in _shutdown_workers
    w.join()
  File "/usr/lib/python3.6/multiprocessing/process.py", line 122, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process



val Loss: 2.8936 Acc: 0.5432
--Saving checkpoint--

Epoch 8/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 1.5806 Acc: 0.4925
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.1014 Acc: 0.5570
--Saving checkpoint--

Epoch 9/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 1.4623 Acc: 0.5131
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.1022 Acc: 0.5328

Epoch 10/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 1.5980 Acc: 0.4893
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.2835 Acc: 0.5147

Epoch 11/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 1.4560 Acc: 0.5016
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 2.9642 Acc: 0.6339
--Saving checkpoint--

Epoch 12/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 1.3761 Acc: 0.5416
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 2.9514 Acc: 0.6114

Epoch 13/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 1.2159 Acc: 0.5649
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.1481 Acc: 0.5959

Epoch 14/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 1.2334 Acc: 0.5587
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.1124 Acc: 0.6434
--Saving checkpoint--

Epoch 15/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 1.2402 Acc: 0.5671
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.0608 Acc: 0.6554
--Saving checkpoint--

Epoch 16/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 1.1166 Acc: 0.5878
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.4238 Acc: 0.5907

Epoch 17/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 1.1470 Acc: 0.5813
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.0325 Acc: 0.6200

Epoch 18/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 1.0866 Acc: 0.5908
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.0519 Acc: 0.6753
--Saving checkpoint--

Epoch 19/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 1.0660 Acc: 0.6049
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.9075 Acc: 0.4568

Training complete in 9m 37s
Best val Acc: 0.675302


In [0]:
# Continue fine-tuning `resnet_ft` for another 20 epochs, rebinding the name
# to whatever train_model returns.
# NOTE(review): the recorded output suggests the `load_acc` baseline is stale.
# The preceding run finished with "Best val Acc: 0.675302", yet epoch 0 of
# this run saves a checkpoint at val Acc 0.6718. Presumably `load_acc` is not
# refreshed between calls, so a weaker model can overwrite the file at
# CKPT_PATH_FT -- confirm against train_model and where load_acc is assigned.
resnet_ft = train_model(
    resnet_ft,
    criterion_ft,
    optimizer_ft,
    dataloaders_ft,
    num_epochs=20,
    load_acc=load_acc,
    CKPT_PATH=CKPT_PATH_FT,
)

Epoch 0/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 1.0597 Acc: 0.6237
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.1083 Acc: 0.6718
--Saving checkpoint--

Epoch 1/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 1.0253 Acc: 0.6209
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.1725 Acc: 0.6667

Epoch 2/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.9206 Acc: 0.6291
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.3345 Acc: 0.6831
--Saving checkpoint--

Epoch 3/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 1.0697 Acc: 0.6140
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 2.9881 Acc: 0.6718

Epoch 4/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.8295 Acc: 0.6589
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.2085 Acc: 0.6926
--Saving checkpoint--

Epoch 5/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.9314 Acc: 0.6394
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.3448 Acc: 0.6848

Epoch 6/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.8840 Acc: 0.6587
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.6225 Acc: 0.6727

Epoch 7/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.9600 Acc: 0.6317
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.0687 Acc: 0.6554

Epoch 8/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.8563 Acc: 0.6563
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.1868 Acc: 0.7366
--Saving checkpoint--

Epoch 9/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.7908 Acc: 0.6848
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.0409 Acc: 0.7176

Epoch 10/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.8860 Acc: 0.6546
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.3945 Acc: 0.6770

Epoch 11/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.8149 Acc: 0.6682
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.5779 Acc: 0.7211

Epoch 12/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.7950 Acc: 0.6872
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.4071 Acc: 0.6762

Epoch 13/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.8115 Acc: 0.6811
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.2461 Acc: 0.7107

Epoch 14/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.7256 Acc: 0.7017
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.1790 Acc: 0.6710

Epoch 15/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.7618 Acc: 0.6928
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.4716 Acc: 0.7409
--Saving checkpoint--

Epoch 16/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.6482 Acc: 0.7256
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.5389 Acc: 0.7349

Epoch 17/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.6565 Acc: 0.7233
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.5951 Acc: 0.7599
--Saving checkpoint--

Epoch 18/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.6370 Acc: 0.7228
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.4751 Acc: 0.7599

Epoch 19/19
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.6031 Acc: 0.7354
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.7121 Acc: 0.7029

Training complete in 9m 35s
Best val Acc: 0.759931


In [0]:
# Third fine-tuning pass: 50 more epochs on the same model, loss, optimizer
# and dataloaders, rebinding `resnet_ft` to train_model's return value.
# NOTE(review): same stale-baseline symptom as the previous pass. That run
# ended with "Best val Acc: 0.759931", but epoch 0 here still saves a
# checkpoint at val Acc 0.7539. Presumably `load_acc` should be updated (or
# re-read from the checkpoint) before re-invoking -- TODO confirm.
resnet_ft = train_model(
    resnet_ft,
    criterion_ft,
    optimizer_ft,
    dataloaders_ft,
    num_epochs=50,
    load_acc=load_acc,
    CKPT_PATH=CKPT_PATH_FT,
)

Epoch 0/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.6454 Acc: 0.7209
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.5974 Acc: 0.7539
--Saving checkpoint--

Epoch 1/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.6453 Acc: 0.7205
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.4213 Acc: 0.7409

Epoch 2/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.6712 Acc: 0.7274
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 4.0679 Acc: 0.6762

Epoch 3/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.7904 Acc: 0.6859
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.5920 Acc: 0.6943

Epoch 4/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.6501 Acc: 0.7174
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.3601 Acc: 0.7297

Epoch 5/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.6474 Acc: 0.7367
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.5221 Acc: 0.7496

Epoch 6/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.6074 Acc: 0.7310
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.9447 Acc: 0.7219

Epoch 7/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.6592 Acc: 0.7222
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 5.0590 Acc: 0.5406

Epoch 8/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.7071 Acc: 0.7192
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.9988 Acc: 0.5967

Epoch 9/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.7070 Acc: 0.7146
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 4.0352 Acc: 0.6934

Epoch 10/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.7448 Acc: 0.7010
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.5333 Acc: 0.7211

Epoch 11/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.5575 Acc: 0.7485
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.7101 Acc: 0.7608
--Saving checkpoint--

Epoch 12/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.5316 Acc: 0.7630
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.5952 Acc: 0.7694
--Saving checkpoint--

Epoch 13/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.5354 Acc: 0.7712
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.8419 Acc: 0.7668

Epoch 14/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.6667 Acc: 0.7358
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.6499 Acc: 0.6779

Epoch 15/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.7623 Acc: 0.6926
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.4980 Acc: 0.7038

Epoch 16/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.5905 Acc: 0.7537
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.6587 Acc: 0.7314

Epoch 17/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.6791 Acc: 0.7228
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.2257 Acc: 0.7193

Epoch 18/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.7093 Acc: 0.7233
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.7956 Acc: 0.7383

Epoch 19/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.5370 Acc: 0.7645
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.5534 Acc: 0.7263

Epoch 20/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.5034 Acc: 0.7803
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.6262 Acc: 0.7945
--Saving checkpoint--

Epoch 21/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.5082 Acc: 0.7691
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.7695 Acc: 0.7254

Epoch 22/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.5383 Acc: 0.7758
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 4.0563 Acc: 0.7418

Epoch 23/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.5467 Acc: 0.7717
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.8199 Acc: 0.7366

Epoch 24/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.5028 Acc: 0.7827
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 4.0133 Acc: 0.7522

Epoch 25/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.4702 Acc: 0.7928
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.5422 Acc: 0.7660

Epoch 26/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.5030 Acc: 0.7755
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.9386 Acc: 0.7565

Epoch 27/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.4066 Acc: 0.7963
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.8451 Acc: 0.7522

Epoch 28/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.4319 Acc: 0.8142
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 4.1403 Acc: 0.7712

Epoch 29/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.4428 Acc: 0.8034
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.7030 Acc: 0.7858

Epoch 30/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.4042 Acc: 0.8043
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.8665 Acc: 0.7979
--Saving checkpoint--

Epoch 31/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.4100 Acc: 0.8159
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.8686 Acc: 0.7651

Epoch 32/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.3393 Acc: 0.8213
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.9714 Acc: 0.7841

Epoch 33/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.5386 Acc: 0.7758
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.6928 Acc: 0.7539

Epoch 34/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.3841 Acc: 0.8146
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.7903 Acc: 0.7608

Epoch 35/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.5341 Acc: 0.7885
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 4.7178 Acc: 0.5734

Epoch 36/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.5930 Acc: 0.7570
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.8653 Acc: 0.7651

Epoch 37/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.5823 Acc: 0.7809
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 4.1058 Acc: 0.7358

Epoch 38/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.5226 Acc: 0.7835
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 4.0070 Acc: 0.7781

Epoch 39/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.4232 Acc: 0.8267
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 4.1053 Acc: 0.7358

Epoch 40/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.4616 Acc: 0.7971
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.7150 Acc: 0.7522

Epoch 41/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.4010 Acc: 0.8142
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.9001 Acc: 0.7850

Epoch 42/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.4167 Acc: 0.8270
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 4.2765 Acc: 0.7686

Epoch 43/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.4204 Acc: 0.8190
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 4.2220 Acc: 0.7677

Epoch 44/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.4312 Acc: 0.8101
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 4.2373 Acc: 0.7815

Epoch 45/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.4198 Acc: 0.8134
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 4.2744 Acc: 0.7832

Epoch 46/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.4009 Acc: 0.8263
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 4.3482 Acc: 0.7444

Epoch 47/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.3993 Acc: 0.8213
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 4.6417 Acc: 0.7841

Epoch 48/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.3489 Acc: 0.8419
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 4.1638 Acc: 0.7927

Epoch 49/49
----------
Iterating through data for phase: train...


HBox(children=(IntProgress(value=0, max=73), HTML(value='')))


train Loss: 0.3676 Acc: 0.8358
Iterating through data for phase: val...


HBox(children=(IntProgress(value=0, max=19), HTML(value='')))


val Loss: 3.9802 Acc: 0.7746

Training complete in 23m 48s
Best val Acc: 0.797927


In [151]:
# Score the fine-tuned network on the 'test' dataloader. The first two entries
# of the result are reported as Top-1 and Top-5 accuracy respectively.
# NOTE(review): the recorded output shows Top-1 ~= 0.42 here, against a best
# validation accuracy of ~0.80 in the training log above. Worth checking how
# the test split / label mapping is built, and whether `resnet_ft` holds the
# best-checkpoint weights or the last-epoch weights -- cannot tell from here.
acc = get_accuracy(resnet_ft, dataloaders_ft['test'])

top1_acc = acc[0]
top5_acc = acc[1]
print("Test accuracy: Top-1 = {}, Top-5 = {}".format(top1_acc, top5_acc))

HBox(children=(IntProgress(value=0, max=87), HTML(value='')))


Test accuracy: Top-1 = 0.4225249772933697, Top-5 = 0.5831062670299727
