In [None]:
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import dataset, dataloader

import matplotlib.pyplot as plt
import time
import os, json
import copy

from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split

import os
from collections import Counter, OrderedDict
import re
import requests
import tarfile
from PIL import Image

In [None]:
# Root data directory; the per-split directories are derived from it.
data_dir = 'data2'
train_dir, valid_dir, test_dir = (f'{data_dir}/{split}' for split in ('train', 'valid', 'test'))

## Data preparation

Prepare the Oxford-IIIT Pet dataset following this [notebook](https://colab.research.google.com/github/akashmehra/blog/blob/fastbook/lessons/_notebooks/2021-07-20-pets_classifier.ipynb#scrollTo=ekNHMAUtklXS).

In [None]:
# Per-channel mean/std from the torchvision docs (the statistics the pretrained models expect).
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

# Training pipeline: random resized crop + horizontal flip for augmentation,
# then conversion to a tensor and normalization.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# Evaluation pipeline: deterministic resize to 224x224. The CenterCrop(224) that
# follows is a no-op after the square resize; it is kept so the pipeline is unchanged.
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])


In [None]:
def fetch_data(url, data_dir, download=False, timeout=60):
    """Stream a gzipped tarball from ``url`` and extract it into ``data_dir``.

    Nothing happens unless ``download`` is true. The archive is extracted while it
    streams, without being written to disk first.

    Args:
        url: HTTP(S) URL of a ``.tgz`` / ``.tar.gz`` archive.
        data_dir: directory to extract into.
        download: when false, the function returns immediately without network access.
        timeout: seconds passed to ``requests.get`` as its connect/read timeout.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    if not download:
        return
    # `with` closes the HTTP connection and the tar stream even if extraction fails.
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Fail loudly on 4xx/5xx instead of feeding an error page to tarfile.
        response.raise_for_status()
        # "r|gz" reads the archive as a forward-only stream, which is what response.raw offers.
        # NOTE(review): extractall trusts member paths inside the archive; only use it with
        # trusted sources (newer Python versions accept filter='data' to guard against this).
        with tarfile.open(fileobj=response.raw, mode="r|gz") as archive:
            archive.extractall(path=data_dir)

In [None]:
#collapse-hide
pets_url = 'https://s3.amazonaws.com/fast-ai-imageclas/oxford-iiit-pet.tgz'
#data_dir = os.path.join('drive', 'MyDrive', 'pets_data')
base_img_dir = os.path.join(data_dir, 'oxford-iiit-pet', 'images')
# Download only when the images directory is missing, so re-running the notebook
# does not re-fetch and re-extract the whole archive every time.
# NOTE: a partially extracted directory also counts as present; delete it to force a re-download.
fetch_data(pets_url, data_dir, not os.path.isdir(base_img_dir))

In [None]:
class RegexLabelExtractor:
    """Extract one label per string using a regular expression.

    The label is the first element of ``re.findall(pattern, value)``. With one
    capture group that is the group's text, with no groups the whole match, and
    with several groups a tuple of the groups.

    Args:
        pattern: regex (string or compiled pattern) applied with ``re.findall``.
    """

    def __init__(self, pattern):
        self.pattern = pattern

    def __call__(self, iterable):
        """Return the list of labels extracted from each string in ``iterable``.

        Raises:
            ValueError: if the pattern does not match one of the strings.
        """
        labels = []
        for value in iterable:
            matches = re.findall(self.pattern, value)
            if not matches:
                # Name the offending value instead of an opaque IndexError from `[0]`.
                raise ValueError(f'pattern {self.pattern!r} did not match {value!r}')
            labels.append(matches[0])
        return labels

In [None]:
class LabelManager:
    """Bidirectional mapping between label values and contiguous integer ids.

    Ids are assigned 0, 1, 2, ... in order of each label's first appearance in
    ``labels``; repeated labels reuse their existing id.
    """

    def __init__(self, labels):
        # dict.fromkeys keeps only the first occurrence of each label, in order.
        first_seen = dict.fromkeys(labels)
        self._label_to_idx = OrderedDict(
            (label, position) for position, label in enumerate(first_seen)
        )
        # Ids are exactly 0..n-1 in insertion order, so enumerate gives the inverse map.
        self._idx_to_label = dict(enumerate(self._label_to_idx))

    @property
    def keys(self):
        """A new list of the labels, ordered by id."""
        return [*self._label_to_idx]

    def id_for_label(self, label):
        """Return the id of ``label``; raises KeyError for an unknown label."""
        return self._label_to_idx[label]

    def label_for_id(self, idx):
        """Return the label with id ``idx``; raises KeyError for an unknown id."""
        return self._idx_to_label[idx]

    def __len__(self):
        """Number of distinct labels."""
        return len(self._idx_to_label)

In [None]:
class Splitter:
    """Callable that shuffles and splits a sequence into train/valid parts with sklearn.

    Args:
        valid_pct: forwarded to ``train_test_split`` as ``test_size``.
        seed: seed for a fresh ``np.random.RandomState`` built on every call, so the
            same input always splits the same way; ``None`` gives a non-reproducible split.
    """

    def __init__(self, valid_pct=0.2, seed=None):
        self.valid_pct = valid_pct
        self.seed = seed

    def __call__(self, dataset):
        """Return ``train_test_split``'s result: the train part, then the valid part."""
        shuffler = np.random.RandomState(self.seed)
        return train_test_split(dataset, random_state=shuffler, test_size=self.valid_pct)


In [None]:
class PetsDataset(dataset.Dataset):
    """Map-style dataset over ``(image_path, label_id)`` samples.

    Args:
        data: indexable sequence whose items hold the image path at index 0 and the
            label at index 1.
        tfms: optional callable applied to the RGB PIL image (e.g. a torchvision transform).
    """

    def __init__(self, data, tfms=None):
        super().__init__()
        self.data = data
        self.transforms = tfms

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; the image is RGB before ``tfms``."""
        sample = self.data[idx]
        path, y = sample[0], sample[1]
        # Open inside `with` so the file handle is released immediately; otherwise PIL
        # keeps it open until garbage collection, which can exhaust file descriptors.
        # convert() reads the pixels and returns an independent copy (also turning
        # grayscale/RGBA/palette images into RGB).
        with Image.open(path) as img:
            X = img.convert('RGB')
        if self.transforms is not None:
            X = self.transforms(X)
        return (X, y)

    def __len__(self):
        return len(self.data)
    

In [None]:
class DatasetManager():
    """Label, split and wrap image files as train/valid ``PetsDataset`` objects.

    Labels come from ``label_extractor(paths)``. A ``LabelManager`` maps them to
    integer ids, where ids follow first appearance in ``paths``. The
    ``(absolute_path, label_id)`` pairs are then split with ``Splitter``.

    Args:
        base_dir: directory that the relative ``paths`` are joined onto.
        paths: image file names relative to ``base_dir``.
        label_extractor: callable returning one label per entry of ``paths``.
        tfms: transform for training samples, and for validation samples when
            ``valid_tfms`` is ``None``.
        valid_pct: fraction of samples held out for validation.
        seed: seed forwarded to ``Splitter`` for a reproducible split.
        valid_tfms: optional separate transform for validation samples, so that
            training-only augmentation stays out of evaluation.

    Raises:
        ValueError: if ``label_extractor`` does not return exactly one label per path.
    """

    def __init__(self, base_dir, paths, label_extractor, tfms=None, valid_pct=0.2, seed=None,
                 valid_tfms=None):
        # Materialize so generators are not exhausted by the first of several passes.
        paths = list(paths)
        self._labels = list(label_extractor(paths))
        if len(self._labels) != len(paths):
            raise ValueError(
                f'label_extractor returned {len(self._labels)} labels for {len(paths)} paths')
        self.tfms = tfms
        self.valid_tfms = valid_tfms
        self._label_manager = LabelManager(self._labels)
        self._label_ids = [self._label_manager.id_for_label(label) for label in self._labels]

        self.abs_paths = [os.path.join(base_dir, path) for path in paths]
        samples = list(zip(self.abs_paths, self._label_ids))
        self.train_data, self.valid_data = Splitter(valid_pct=valid_pct, seed=seed)(samples)

    @property
    def label_manager(self):
        """The ``LabelManager`` mapping label strings to integer ids."""
        return self._label_manager

    @property
    def train_dataset(self):
        """A new ``PetsDataset`` over the training split, using ``tfms``."""
        return PetsDataset(self.train_data, tfms=self.tfms)

    @property
    def valid_dataset(self):
        """A new ``PetsDataset`` over the validation split (``valid_tfms``, else ``tfms``)."""
        tfms = self.tfms if self.valid_tfms is None else self.valid_tfms
        return PetsDataset(self.valid_data, tfms=tfms)
    

In [None]:
# File names look like "<breed>_<number>.jpg"; the breed part becomes the label.
paths = [path for path in sorted(os.listdir(base_img_dir)) if path.endswith('.jpg')]
# Raw string: "\d" inside a normal literal is an invalid escape (SyntaxWarning on
# Python 3.12+). The dot before "jpg" is escaped so it only matches a literal ".".
pattern = r'(.+)_\d+\.jpg$'
regex_label_extractor = RegexLabelExtractor(pattern)
dataset_manager = DatasetManager(base_img_dir, paths, regex_label_extractor,
                                 tfms=val_transform,
                                 seed=42)
# DatasetManager's `tfms` is applied to both splits. Build the training split with
# the augmenting `train_transform` (otherwise it is never used), and keep the
# deterministic `val_transform` for validation.
train_dataset = PetsDataset(dataset_manager.train_data, tfms=train_transform)
valid_dataset = dataset_manager.valid_dataset

In [None]:
#collapse-output
import pandas as pd
# One row per class, in label-id order.
df = pd.DataFrame({'label_name': dataset_manager.label_manager.keys})
df

Unnamed: 0,label_name
0,Abyssinian
1,Bengal
2,Birman
3,Bombay
4,British_Shorthair
5,Egyptian_Mau
6,Maine_Coon
7,Persian
8,Ragdoll
9,Russian_Blue


In [None]:
import math
def plot_one_batch(batch, max_images=9, mean=None, std=None):
    """Show up to ``max_images`` images from the next batch of ``batch`` in a grid.

    Args:
        batch: iterator of ``(images, label_ids)`` pairs (e.g. ``generate_one_batch(dl)``).
            Only the first pair is consumed, and each ``images[i]`` is passed to
            ``transforms.ToPILImage``.
        max_images: maximum number of images to show; must be at least 1.
        mean, std: optional per-channel normalization statistics. When both are given,
            each image is un-normalized as ``x * std + mean`` before plotting.

    Titles read ``"<label_id>/<label_name>"``, looked up through the global
    ``dataset_manager``.

    Raises:
        ValueError: if ``max_images`` is less than 1.
    """
    if max_images < 1:
        raise ValueError(f'max_images must be >= 1, got {max_images}')
    # Near-square grid: floor(sqrt(n)) columns and enough rows to hold n images.
    ncols = int(math.sqrt(max_images))
    nrows = (max_images + ncols - 1) // ncols
    # squeeze=False keeps `axes` 2-D even for a 1x1 grid, so ravel() always works.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 10), squeeze=False)
    axes = axes.ravel()

    images, label_ids = next(batch)
    num_shown = min(max_images, len(images))
    for idx in range(num_shown):
        image = images[idx]
        if mean is not None and std is not None:
            # Assumes a (C, H, W) image tensor — TODO confirm for other inputs.
            image = (image * torch.as_tensor(std).view(-1, 1, 1)
                     + torch.as_tensor(mean).view(-1, 1, 1))
        if image.is_floating_point():
            # ToPILImage scales floats by 255 and casts to bytes; clamp so out-of-range
            # (e.g. still-normalized) values saturate instead of producing garbage colors.
            image = image.clamp(0, 1)
        label_id = int(label_ids[idx])
        ax = axes[idx]
        ax.imshow(transforms.ToPILImage()(image))
        ax.set_title(f'{label_id}/{dataset_manager.label_manager.label_for_id(label_id)}')
        ax.set_axis_off()
    # Blank out grid cells left empty when the batch has fewer images than the grid holds.
    for ax in axes[num_shown:]:
        ax.set_axis_off()
    plt.tight_layout()
    plt.show()

In [None]:
# The images plotted below are still mean/std-normalized, so their colors look off.
def generate_one_batch(dl):
    """Yield the batches of ``dl`` one at a time (a thin generator wrapper)."""
    yield from dl

train_dl = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=32,
)
preview_batches = generate_one_batch(train_dl)
plot_one_batch(preview_batches, max_images=20)

In [None]:
num_train_samples = len(train_dataset)
num_train_samples

5912

In [None]:
# Run configuration.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 32
NUM_EPOCHS = 18          # matches the 18-epoch training run below (was 5)
feature_extract = False  # feature-extraction switch; see the model-setup cell
SEED = 42
# Make torch-side randomness (new layer init, shuffling, random augmentations)
# repeatable; cuDNN kernels may still introduce small nondeterminism on GPU.
torch.manual_seed(SEED)
print(device)

cuda:0


In [None]:
# Create training and validation datasets
image_datasets = {"train": train_dataset,
                  "val": valid_dataset}
# Create training and validation dataloaders. Only the training split is shuffled:
# validation loss/accuracy are accumulated over the whole split, so order does not matter there.
dataloaders_dict = {
    split: torch.utils.data.DataLoader(split_dataset, batch_size=BATCH_SIZE,
                                       shuffle=(split == "train"))
    for split, split_dataset in image_datasets.items()
}

In [None]:
# Samples per split (the datasets' default repr only shows a memory address).
{split: len(split_dataset) for split, split_dataset in image_datasets.items()}

{'train': <__main__.PetsDataset at 0x7fe335cbc410>,
 'val': <__main__.PetsDataset at 0x7fe3353d6910>}

In [None]:
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of ``model`` when ``feature_extracting`` is truthy.

    With a falsy flag the model is left untouched; this helper only ever freezes
    parameters, never unfreezes them.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad = False

In [None]:
# Start from an ImageNet-pretrained ResNet-34 and replace its classifier head.
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in favor of
# `weights=models.ResNet34_Weights.IMAGENET1K_V1`; switch once the installed version allows.
model_ft = models.resnet34(pretrained=True)
# Pass the config flag rather than a literal, so feature-extraction mode actually
# freezes the backbone before the new head is attached.
set_parameter_requires_grad(model_ft, feature_extract)
num_ftrs = model_ft.fc.in_features
num_classes = len(dataset_manager.label_manager)  # one output per breed (37 here)
model_ft.fc = nn.Linear(num_ftrs, num_classes)  # freshly created layer: always trainable
model_ft = model_ft.to(device)

# Optimize exactly the parameters that still require gradients, and list them.
# When fine-tuning (nothing frozen) this is every parameter of the model.
print("Params to learn:")
params_to_update = []
for name, param in model_ft.named_parameters():
    if param.requires_grad:
        params_to_update.append(param)
        print("\t", name)


In [None]:
def _run_phase(model, dataloader, criterion, optimizer, is_train, is_inception, device):
    """Run one pass over ``dataloader`` and return ``(mean_loss, accuracy)`` as floats.

    When ``is_train`` is true the model is in train mode and ``optimizer`` steps once
    per batch. Otherwise the model is in eval mode and autograd is disabled.

    Raises:
        ValueError: if ``dataloader.dataset`` is empty (the averages are undefined).
    """
    num_samples = len(dataloader.dataset)
    if num_samples == 0:
        raise ValueError('cannot compute epoch statistics over an empty dataset')

    model.train(is_train)  # train(False) is equivalent to eval()
    running_loss = 0.0
    running_corrects = 0

    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        if is_train:
            optimizer.zero_grad()

        # Track history only while training.
        with torch.set_grad_enabled(is_train):
            if is_inception and is_train:
                # Inception returns an auxiliary output in train mode; its loss is added
                # with weight 0.4. From
                # https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
                outputs, aux_outputs = model(inputs)
                loss = criterion(outputs, labels) + 0.4 * criterion(aux_outputs, labels)
            else:
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            _, preds = torch.max(outputs, 1)

            if is_train:
                loss.backward()
                optimizer.step()

        # Scaled by batch size, assuming `criterion` returns a per-batch mean
        # (the CrossEntropyLoss default); the epoch average divides by the dataset size.
        running_loss += loss.item() * inputs.size(0)
        running_corrects += (preds == labels).sum().item()

    return running_loss / num_samples, running_corrects / num_samples


def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False,
                device=None):
    """Train ``model`` and restore the weights with the best validation accuracy.

    Adapted from
    https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html#run-training-and-validation-step

    Each epoch runs a ``'train'`` phase, then a ``'val'`` phase, and prints their loss
    and accuracy. After the last epoch, the weights from the epoch with the highest
    validation accuracy are loaded back into ``model``.

    Args:
        model: network to train (updated in place and also returned).
        dataloaders: dict with ``'train'`` and ``'val'`` DataLoaders; each must expose
            ``.dataset`` so per-sample averages can be computed.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over the parameters to update.
        num_epochs: number of train+val epochs.
        is_inception: set for Inception v3, whose train-mode forward also returns an
            auxiliary output.
        device: device the batches are moved to. Defaults to the device of the model's
            first parameter (CPU if the model has none).

    Returns:
        ``(model, val_acc_history)``, where the history holds one 0-d float64 CPU tensor
        per epoch with that epoch's validation accuracy.

    Raises:
        ValueError: if a phase's dataset is empty.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    since = time.time()
    val_acc_history = []
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and a validation phase.
        for phase in ['train', 'val']:
            epoch_loss, epoch_acc = _run_phase(
                model, dataloaders[phase], criterion, optimizer,
                is_train=(phase == 'train'), is_inception=is_inception, device=device)
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            if phase == 'val':
                # Snapshot the weights whenever validation accuracy improves.
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                # Kept as a 0-d tensor so code calling `h.cpu().numpy()` still works,
                # but stored on CPU so the history does not hold GPU memory.
                val_acc_history.append(torch.tensor(epoch_acc, dtype=torch.float64))

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    # Load the best model weights.
    model.load_state_dict(best_model_wts)
    return model, val_acc_history

In [None]:
# Cross-entropy on the class logits; SGD with momentum over the parameters collected above.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=params_to_update, momentum=0.9, lr=1e-4)

In [None]:
# train_model updates model_ft in place and returns that same object (with the
# best-validation weights loaded), plus the per-epoch validation accuracies.
model_ft2, hist2 = train_model(
    model_ft,
    dataloaders_dict,
    criterion,
    optimizer,
    num_epochs=18,
)

Epoch 0/17
----------
train Loss: 3.3576 Acc: 0.1563
val Loss: 2.8563 Acc: 0.3870

Epoch 1/17
----------
train Loss: 2.4686 Acc: 0.5438
val Loss: 2.0308 Acc: 0.6996

Epoch 2/17
----------
train Loss: 1.7650 Acc: 0.7556
val Loss: 1.4433 Acc: 0.8153

Epoch 3/17
----------
train Loss: 1.3039 Acc: 0.8353
val Loss: 1.0785 Acc: 0.8593

Epoch 4/17
----------
train Loss: 1.0217 Acc: 0.8596
val Loss: 0.8539 Acc: 0.8796

Epoch 5/17
----------
train Loss: 0.8321 Acc: 0.8813
val Loss: 0.6988 Acc: 0.8951

Epoch 6/17
----------
train Loss: 0.7146 Acc: 0.8928
val Loss: 0.6116 Acc: 0.9039

Epoch 7/17
----------
train Loss: 0.6185 Acc: 0.8970
val Loss: 0.5370 Acc: 0.9161

Epoch 8/17
----------
train Loss: 0.5469 Acc: 0.9073
val Loss: 0.4816 Acc: 0.9147

Epoch 9/17
----------
train Loss: 0.4898 Acc: 0.9198
val Loss: 0.4413 Acc: 0.9168

Epoch 10/17
----------
train Loss: 0.4498 Acc: 0.9222
val Loss: 0.4054 Acc: 0.9242

Epoch 11/17
----------
train Loss: 0.4022 Acc: 0.9285
val Loss: 0.3794 Acc: 0.9235

Ep

In [None]:
# Save only the learned weights (state_dict), not the whole pickled module.
# Another file name used for a different run: '/content/drive/MyDrive/model_2SGD_pets.pt'
model_save_path = '/content/drive/MyDrive/model_pets.pt'
torch.save(model_ft2.state_dict(), model_save_path)

***