# 1. Import Statements and Constant Declarations (e.g. config variables, paths)

In [None]:
import copy
import os
import time
import random
import shutil
from typing import Type, Tuple
from abc import ABC

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms

# Directory on the mounted Google Drive in which trained models are persisted.
STORAGE_DIR = '/content/drive/My Drive/MyCNN'
# File the best model across all training repeats is saved to (see get_best).
STORAGE_MODEL = os.path.join(STORAGE_DIR, 'best-model')
# File train_model saves the currently best model of a single run to.
STORAGE_MODEL_TEMP = STORAGE_MODEL + '.tmp'
# Target file of the ONNX export.
STORAGE_ONNX = os.path.join(STORAGE_DIR, "kitchen_classifier.onnx")

# Directory the image archives are extracted to and partitioned in.
DATA_DIR = '/content/utensil-images'
# Batch size used by all dataloaders.
BATCH_SIZE = 16
# Edge length (in pixels) images are resized / cropped to; train_model treats
# the value 299 as a marker for an Inception network.
INPUT_SIZE = 224

In [None]:
# Mount Google Drive so that STORAGE_DIR (located below /content/drive) is
# reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

# 2. Retrieve data from github repository




In [None]:
# git-lfs is needed for the dataset
!apt-get install git-lfs
!git lfs install

# Sub-sampling rate handed to move_files by partition_image_data: only every
# KEEP_FREQ-th entry of a class directory is moved into the training set.
KEEP_FREQ = 2

def move_files(path: str, base_dir: str, keep_frequency: int = 0) -> None:
    """
    Splits the content of a single class directory into a training and a
    validation part.

    The entry named ``test`` found in ``<base_dir>/<path>`` is moved as a
    whole to ``<base_dir>/validation/<path>``, replacing whatever is stored
    there already. Of the remaining entries every ``keep_frequency``-th one
    is moved to ``<base_dir>/train/<path>``; all other entries are left
    untouched in the class directory.

    :param path: Name of the class directory, relative to ``base_dir``
    :param base_dir: Directory containing the class directories
    :param keep_frequency: Sub-sampling rate for the training data; a falsy
        value (the default ``0``) moves every entry to the training set
    """
    base_in = os.path.join(base_dir, path)
    train_dir = os.path.join(base_dir, 'train', path)
    validation_dir = os.path.join(base_dir, 'validation', path)

    skipped = 0
    # Sorted, as the order returned by os.listdir is arbitrary and the
    # sub-sampling would otherwise select different files on every run.
    for file_name in sorted(os.listdir(base_in)):
        source = os.path.join(base_in, file_name)

        if file_name == 'test':
            print(validation_dir)
            shutil.rmtree(validation_dir, ignore_errors=True)
            # Make sure the parent exists so the move is a plain rename.
            os.makedirs(os.path.dirname(validation_dir), exist_ok=True)
            shutil.move(source, validation_dir)
            continue

        if keep_frequency:
            skipped += 1
            if skipped != keep_frequency:
                continue
            skipped = 0

        os.makedirs(train_dir, exist_ok=True)
        shutil.move(source, os.path.join(train_dir, file_name))

def partition_image_data(path: str) -> int:
    """
    Partitions every class directory found directly in ``path`` via
    ``move_files`` (using the ``KEEP_FREQ`` sub-sampling rate) and removes
    the original class directories afterwards.

    :param path: Directory containing one sub-directory per class
    :return: The number of class directories that were found
    """
    class_dirs = []
    for entry in os.scandir(path):
        if entry.is_dir():
            class_dirs.append(entry.name)
    print(class_dirs)

    for class_dir in class_dirs:
        move_files(class_dir, path, KEEP_FREQ)

    # Whatever was not moved by move_files is discarded together with the
    # class directory itself.
    for class_dir in class_dirs:
        shutil.rmtree(os.path.join(path, class_dir))

    return len(class_dirs)

# Start from a fresh clone of the dataset repository and an empty DATA_DIR.
!rm -rf kitchen-utensils/
!git clone https://github.com/BeeblebroxIV/kitchen-utensils.git
shutil.rmtree(DATA_DIR, ignore_errors=True)

os.mkdir(DATA_DIR)
# The extraction target below is the same path as DATA_DIR.
!tar -xf kitchen-utensils/kitchen-utensils.tar.gz --directory "/content/utensil-images"
# Every directory extracted so far counts as one class; partitioning creates
# the 'train' and 'validation' folders inside DATA_DIR.
NUM_CLASSES = partition_image_data(DATA_DIR)
# These archives are extracted only after partitioning, so their content is
# not treated as class directories by partition_image_data.
!tar -xf kitchen-utensils/background-test.tar.gz --directory "/content/utensil-images"
!tar -xf kitchen-utensils/distance-test.tar.gz --directory "/content/utensil-images"

# 3. Setup dataloaders as well as functions for training / visualisation

In [None]:
# Precomputed for dataset
# (per-channel statistics handed to transforms.Normalize below)
mean = [0.4399, 0.4211, 0.3609]
std = [0.1220, 0.1210, 0.1228]

# Maps the name of every entry of DATA_DIR to the transform pipeline applied
# to the images loaded from it.
data_transforms = {}

# Every entry of DATA_DIR gets the same deterministic pipeline:
# resize, center crop to INPUT_SIZE x INPUT_SIZE, tensor conversion and
# normalisation.
for f in os.listdir(DATA_DIR):
    data_transforms[f] = transforms.Compose([
        transforms.Resize(INPUT_SIZE),
        transforms.CenterCrop(INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

# The training data shares the pipeline registered for 'crop540'; this
# raises a KeyError if DATA_DIR has no entry of that name.
# NOTE(review): 'crop540' is presumably one of the extracted test sets -
# confirm against the archives.
data_transforms['train'] = data_transforms['crop540']

# The string literal below is never evaluated as code; it holds a disabled
# augmentation pipeline (random resized crop + horizontal flip).
'''transforms.Compose([
        transforms.RandomResizedCrop(INPUT_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
])'''


# One ImageFolder dataset and one shuffling dataloader per entry of
# data_transforms, all keyed by the folder name.
image_datasets = {x: datasets.ImageFolder(os.path.join(DATA_DIR, x),
                                          transform)
                  for x, transform in data_transforms.items()}
dataloaders = {x: torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE,
                                             shuffle=True, num_workers=8)
              for x, dataset in image_datasets.items()}
# Number of images per dataset; used to average loss and accuracy.
dataset_sizes = {x: len(dataset) for x, dataset in image_datasets.items()}
# Class labels as derived by ImageFolder from the training folder.
class_names = image_datasets['train'].classes

# Use the first CUDA device when available, the CPU otherwise.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def imshow(inp, title=None):
    """Imshow for Tensor.

    Reverts the normalisation applied by the data transforms and plots the
    image on the current matplotlib axes.

    :param inp: Normalised image tensor in channel-first layout (it is
        transposed to channel-last for plotting)
    :param title: Optional title set on the plot
    """
    image = inp.numpy().transpose((1, 2, 0))
    # Undo transforms.Normalize with the very same module-level dataset
    # statistics (mean / std) the images were normalised with; the ImageNet
    # statistics used before distorted the displayed colours.
    image = np.array(std) * image + np.array(mean)
    image = np.clip(image, 0, 1)
    plt.imshow(image)
    if title is not None:
        plt.title(title)

# Get a batch of training data
# (this also verifies that the training dataloader yields data)
inputs, classes = next(iter(dataloaders['train']))

# make sure that we can store the model
os.makedirs(STORAGE_DIR, exist_ok=True)
print(STORAGE_DIR)

def train_model(model: torch.nn.modules.Module, criterion, optimizer,
                scheduler, num_epochs: int) -> Tuple[torch.nn.modules.Module, float]:
    """
    Trains a pytorch Model for a certain number of epoch, stopping early if
    no improvement was made in the last 10 epochs.

    Every improvement of the validation accuracy is persisted to
    STORAGE_MODEL_TEMP.

    :param model: The model to be trained; expected to reside on ``device``
    :param criterion: Loss function called as ``criterion(outputs, labels)``
    :param optimizer: Optimizer updating the model's parameters
    :param scheduler: Learning rate scheduler, stepped once per epoch after
        the training phase
    :param num_epochs: Maximum number of epochs to train for
    :return: The best model trained during the process (measured by the
        validation accuracy) as well as its accuracy on the validation set
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    best_epoch = 0
    impr_counter = 0
    stop = False

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'validation']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                # Counts validation runs since the last improvement; reset
                # below whenever the accuracy improves.
                impr_counter += 1
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    if INPUT_SIZE == 299 and phase == 'train':  # inception
                        # In training mode the network additionally returns
                        # auxiliary outputs whose loss is added with a
                        # weight of 0.4.
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4*loss2
                    else:
                        outputs = model(inputs)
                        # Only computed here; recomputing it after the
                        # branch would discard the auxiliary loss above.
                        loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy and store the model
            if phase == 'validation' and epoch_acc > best_acc:
                impr_counter = 0
                best_acc = epoch_acc
                best_epoch = epoch
                print('Storing')
                torch.save(model, STORAGE_MODEL_TEMP)
                best_model_wts = copy.deepcopy(model.state_dict())

            if impr_counter > 9:
                print('No improvement in last 10 iterations; aborting')
                stop = True

        if stop:
            break
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f} at epoch {}'.format(best_acc, best_epoch))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, best_acc


def visualize_model(model: torch.nn.modules.Module, num_images=6):
    """
    Plots the predictions of a pytorch model for the first ``num_images``
    images served by the validation dataloader.

    The model is switched to evaluation mode for the predictions; the mode
    it was in before is restored afterwards.
    """
    previous_mode = model.training
    model.eval()
    shown = 0
    plt.figure()

    with torch.no_grad():
        for batch, targets in dataloaders['validation']:
            batch = batch.to(device)
            targets = targets.to(device)

            # Index of the highest score per image is the predicted class
            predictions = torch.max(model(batch), 1)[1]

            for idx in range(batch.size(0)):
                shown += 1
                axis = plt.subplot(num_images // 2, 2, shown)
                axis.axis('off')
                axis.set_title(
                    'predicted: {}'.format(class_names[predictions[idx]]))
                imshow(batch.cpu().data[idx])

                if shown == num_images:
                    model.train(mode=previous_mode)
                    return
        model.train(mode=previous_mode)


def get_confusion_matrix(model: torch.nn.modules.Module,
                         data: str = 'validation') -> torch.Tensor:
    """
    Retrieves the confusion matrix of a trained pytorch model instance for
    the given data set.

    The model is evaluated in evaluation mode and without gradient
    tracking; the mode it was in before is restored afterwards.

    :param model: The model to be evaluated; expected to reside on ``device``
    :param data: Key of the dataloader providing the evaluation data
    :return: A NUM_CLASSES x NUM_CLASSES matrix in which entry ``[t, p]``
        counts the samples of true class ``t`` predicted as class ``p``
    """
    confusion_matrix = torch.zeros(NUM_CLASSES, NUM_CLASSES)
    was_training = model.training
    model.eval()
    try:
        # Pure inference: no autograd graph is needed
        with torch.no_grad():
            for inputs, classes in dataloaders[data]:
                outputs = model(inputs.to(device))
                _, preds = torch.max(outputs, 1)
                # Plain Python ints are used for indexing, as the matrix
                # lives on the CPU while the predictions may not.
                for t, p in zip(classes.view(-1).tolist(),
                                preds.view(-1).tolist()):
                    confusion_matrix[t, p] += 1
    finally:
        model.train(mode=was_training)
    return confusion_matrix

# 4. Declare unified classifier classes for the ResNet50, DenseNet161 and VGG16 models

In [None]:
# Report whether a CUDA device is available to this runtime.
print("Using GPU: ", torch.cuda.is_available())


class Model(ABC):
    """
    Abstract class offering a unified interface for easily training various
    different models requiring different setup steps.

    Subclasses provide the network via ``_get_model`` and install the
    classification head for NUM_CLASSES classes via ``_setup_classifier``.
    """
    def __init__(self, for_transfer_learning=True):
        """
        :param for_transfer_learning: If true, the parameters of the network
            returned by ``_get_model`` are frozen (see
            ``__setup_for_transfer``) before the classifier is set up
        """
        self._model = self._get_model()

        if for_transfer_learning:
            self.__setup_for_transfer()

        # Runs after the freezing step, so layers created here stay
        # trainable.
        self._setup_classifier()
        self._model = self._model.to(device)

        self.__optimizer = optim.SGD(self._model.parameters(), lr=0.0001,
                               momentum=0.8)
        # Reduces the learning rate by a factor of 10 every 5 scheduler steps
        self.__scheduler = lr_scheduler.StepLR(self.__optimizer,
                                               step_size=5,
                                               gamma=0.1)

    def _get_model(self):
        """Returns the network to be wrapped; implemented by subclasses."""
        raise NotImplementedError

    def __setup_for_transfer(self):
        """Freezes all but the last parameter tensor of the network."""
        for param in list(self._model.parameters())[:-1]:
            param.requires_grad = False

    def _setup_classifier(self):
        """Installs the classification head; implemented by subclasses."""
        raise NotImplementedError()

    def optimize(self, num_epochs: int = 1) -> float:
        """
        Trains the wrapped network using a cross entropy loss.

        :param num_epochs: Maximum number of epochs to train for; the
            default of 1 matches the previously hard-coded value
        :return: The best validation accuracy as reported by train_model
        """
        self._model = self._model.to(device)
        criterion = nn.CrossEntropyLoss()
        self._model, best_acc = train_model(self._model, criterion,
                                            self.__optimizer, self.__scheduler,
                                            num_epochs=num_epochs)
        return best_acc

    @property
    def model(self):
        """The wrapped pytorch network."""
        return self._model

class VGG16(Model):
    """
    VGG16 network with pretrained weights whose classifier is replaced by a
    freshly initialised one predicting NUM_CLASSES classes.
    """

    def _get_model(self):
        return torchvision.models.vgg16(pretrained=True)

    def _setup_classifier(self):
        # The head has to be installed on the wrapped network: assigning it
        # to this wrapper object leaves the pretrained 1000-class classifier
        # in place, so the network would never predict NUM_CLASSES classes.
        self._model.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, NUM_CLASSES),
        )


class ResNet50(Model):
    """
    ResNet50 network with pretrained weights whose final fully connected
    layer is replaced by one predicting NUM_CLASSES classes.
    """

    def _get_model(self):
        return torchvision.models.resnet50(pretrained=True)

    def _setup_classifier(self):
        # The new layer takes over the input width of the layer it replaces
        self._model.fc = nn.Linear(self._model.fc.in_features, NUM_CLASSES)


class DenseNet161(Model):
    """
    DenseNet161 network with pretrained weights whose classifier layer is
    replaced by one predicting NUM_CLASSES classes.
    """

    def _get_model(self):
        return torchvision.models.densenet161(pretrained=True)

    def _setup_classifier(self):
        # The new layer takes over the input width of the layer it replaces
        self._model.classifier = nn.Linear(
            self._model.classifier.in_features, NUM_CLASSES)


# 5. Declare utility functions

In [None]:
def get_best(model_cls=ResNet50, nr_repeats=10) -> torch.nn.modules.Module:
    """
    Repeatedly instantiates and trains a model class (with transfer
    learning enabled) and returns the network reaching the highest
    accuracy. Each new best network is stored under STORAGE_MODEL.

    :param model_cls: The class used to instantiate model instances
    :param nr_repeats: How often the model is meant to be instantiated
        and trained before picking the model with the highest accuracy
    :return: The best network, or ``None`` if no instance was trained
    """
    top_accuracy = -1
    top_network = None
    for _ in range(nr_repeats):
        candidate = model_cls(True)
        accuracy = candidate.optimize()
        if not accuracy > top_accuracy:
            continue

        # New best instance: remember and persist it right away
        top_accuracy, top_network = accuracy, candidate.model
        torch.save(top_network, STORAGE_MODEL)

    return top_network


def get_average_performance(model_cls: Type[Model], nr_repeats=10) -> float:
    """
    Determines the mean accuracy a model class reaches when it is
    instantiated (without transfer learning setup) and trained repeatedly.

    :param model_cls: The class used to instantiate model instances
    :param nr_repeats: How often the model is meant to be instantiated,
        trained and evaluated
    :return: The accuracies returned by ``optimize``, averaged over all
        repeats
    """
    accuracies = (model_cls(False).optimize() for _ in range(nr_repeats))
    return sum(accuracies) / nr_repeats

def get_average_confusion(model_cls: Type[Model], nr_repeats=10,
                          dataset='validation') -> torch.tensor:
    """
    Instantiates and trains a specific model class repeatedly, averaging
    the confusion matrices the trained instances yield on a dataset.

    :param model_cls: The class used to instantiate model instances
    :param nr_repeats: How often the model is meant to be instantiated,
        trained and evaluated
    :param dataset: The dataset to be used; should be a folder name found in
        the DATA_DIR directory
    :return: The average confusion matrix of dimension
        NUM_CLASSES x NUM_CLASSES
    """
    summed = torch.zeros(NUM_CLASSES, NUM_CLASSES)
    for _ in range(nr_repeats):
        instance = model_cls(False)
        accuracy = instance.optimize()
        matrix = get_confusion_matrix(instance.model, data=dataset)
        print(accuracy, matrix)
        summed += matrix

    return summed / nr_repeats


def get_test_confusion(model: torch.nn.modules.Module,
                       exclude_prefix: str = 'crop') -> torch.tensor:
    """
    Evaluates an already trained pytorch model on the available test data
    sets, i.e. on every dataloader except 'train', 'validation' and those
    whose name starts with ``exclude_prefix``.

    :param model: The model meant to be evaluated
    :param exclude_prefix: Prefix of datasets that are meant to be excluded
        from the evaluation
    :return: The sum of the confusion matrices of all evaluated data sets
    """
    matrices = [
        get_confusion_matrix(model, name)
        for name in dataloaders
        if not (name.startswith(exclude_prefix)
                or name in ['train', 'validation'])
    ]
    return sum(matrices)


def _f1_from_confusion(conf: torch.Tensor) -> np.ndarray:
    """Computes the per-class F1 scores from a confusion matrix whose rows
    are the true and whose columns are the predicted classes."""
    conf = conf.detach().cpu().float()
    true_pos = torch.diagonal(conf)
    # F1 = 2PR / (P + R) = 2TP / (#actual + #predicted); this form needs a
    # single guarded division only.
    support = conf.sum(dim=1) + conf.sum(dim=0)
    f1 = torch.zeros_like(true_pos)
    seen = support > 0
    # Classes that neither occur nor are predicted keep a score of 0
    f1[seen] = 2 * true_pos[seen] / support[seen]
    return f1.numpy()


def get_test_f1(model: torch.nn.modules.Module,
                exclude_prefix: str = 'crop') -> np.ndarray:
    """
    Evaluates an already trained pytorch model on the available test data
    sets by acquiring the F1 scores of the individual classes for each
    individual test data set, i.e. for every dataloader except 'train',
    'validation' and those whose name starts with ``exclude_prefix``.

    :param model: The model meant to be evaluated
    :param exclude_prefix: Prefix of datasets that are meant to be excluded
        from the evaluation
    :return: A matrix of shape NUM_CLASSES x <number of test sets>: each
        column features the F1-scores for a particular test scenario
    """
    f1_scores = []
    for dataset in dataloaders:
        if dataset.startswith(exclude_prefix) or dataset in ('train', 'validation'):
            continue

        f1_scores.append(_f1_from_confusion(get_confusion_matrix(model, dataset)))

    return np.array(f1_scores).transpose()


# 6. Get best network from multiple tries

In [None]:
# Train 10 ResNet50 instances and keep the most accurate network (it is also
# stored under STORAGE_MODEL by get_best), then evaluate that network on the
# test data sets.
model = get_best(ResNet50, 10)
print(get_test_f1(model))
print(get_test_confusion(model))

# 7. Export Model to onnx format for integration in OpenCV

In [None]:
# Load the stored network onto the device selected for this runtime and
# create the dummy input on that same device, so that the export also works
# on runtimes without a GPU.
model = torch.load(STORAGE_MODEL, map_location=device)
# A single random image of the expected input size is used as example
# input for the export.
dummy_input = torch.randn(1, 3, INPUT_SIZE, INPUT_SIZE, device=device)
torch.onnx.export(model, dummy_input,
                  STORAGE_ONNX,
                  opset_version=11)
print('Finished')