<a href="https://colab.research.google.com/github/ricglz/CE888_activities/blob/main/assignment/Project_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch torchvision skorch timm

Collecting git+https://github.com/khornlund/pytorch-balanced-sampler
  Cloning https://github.com/khornlund/pytorch-balanced-sampler to /tmp/pip-req-build-vem7zrgb
  Running command git clone -q https://github.com/khornlund/pytorch-balanced-sampler /tmp/pip-req-build-vem7zrgb
Collecting skorch
[?25l  Downloading https://files.pythonhosted.org/packages/18/c7/2f6434f9360c91a4bf14ae85f634758e5dacd3539cca4266a60be9f881ae/skorch-0.9.0-py3-none-any.whl (125kB)
[K     |████████████████████████████████| 133kB 16.8MB/s 
[?25hCollecting timm
[?25l  Downloading https://files.pythonhosted.org/packages/22/c6/ba02d533cec7329323c7d7a317ab49f673846ecef202d4cc40988b6b7786/timm-0.3.4-py3-none-any.whl (244kB)
[K     |████████████████████████████████| 245kB 29.2MB/s 
Building wheels for collected packages: pytorch-balanced-sampler
  Building wheel for pytorch-balanced-sampler (setup.py) ... [?25l[?25hdone
  Created wheel for pytorch-balanced-sampler: filename=pytorch_balanced_sampler-0.0.1-py2.py3-

## Preparations

Before we begin, let's mount Google Drive so that we can read the data from it later on:

---

In [None]:
from google.colab import drive

# Mount Google Drive at /content/gdrive, then point drive_path at the
# "MyDrive" folder inside the mount.
drive_path = '/content/gdrive'
drive.mount(drive_path, force_remount=False)
drive_path = f'{drive_path}/MyDrive'

Mounted at /content/gdrive


Next, we will seed every random number generator we use to make the runs as deterministic as possible.

In [None]:
import torch
import random
import numpy as np

# One fixed seed shared by every random number generator used below.
seed = 42

# Python's and NumPy's generators.
random.seed(seed)
np.random.seed(seed)

# PyTorch generators (default generator and CUDA).
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Restrict cuDNN to deterministic algorithms.
torch.backends.cudnn.deterministic = True

## Gather the dataset

For this we will create both our training dataset _(which will later be split into the actual training and validation sets)_ and our testing dataset.

PyTorch also lets us apply transformations such as resizing and normalization. The normalization uses [the mean and std of the ImageNet dataset](https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/constants.py).

----

In [None]:
import torchvision.transforms as T
from os import path

# Root folder of the dataset inside the mounted drive.
data_dir = path.join(drive_path, 'Flame')

# Shared preprocessing steps: fixed 254x254 resize and per-channel
# normalization with the given mean / std values.
resize = T.Resize(size=(254, 254))
normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])

In [None]:
# Plain pipeline (no augmentation): resize, convert to tensor, normalize.
transforms = T.Compose([resize, T.ToTensor(), normalize])

# Training pipeline: the same steps plus random horizontal / vertical flips,
# applied after the resize and before the tensor conversion.
train_transforms = T.Compose([
    resize,
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    T.ToTensor(),
    normalize,
])

In [None]:
import torchvision.datasets as datasets
# Training images live in <data_dir>/Training and use the augmenting pipeline.
train_ds = datasets.ImageFolder(
    root=path.join(data_dir, 'Training'),
    transform=train_transforms,
)
len(train_ds)

29264

In [None]:
# Test images live in <data_dir>/Test and use the non-augmenting pipeline.
test_ds = datasets.ImageFolder(
    root=path.join(data_dir, 'Test'),
    transform=transforms,
)
len(test_ds)

8617

## Create modular model 

---

In [None]:
from torch.nn import Linear, Module
import timm

# Checkpoint file path of the most recently created model; it is assigned
# (as a global) by create_model().
f_params = None

class PretrainedModel(Module):
    """Pretrained timm backbone configured with a single output unit.

    Parameters
    ----------
    model : str
        Short architecture alias; must be one of the keys of
        ``_MODEL_NAMES`` (``'rexnet'`` or ``'efficientnet'``).

    Raises
    ------
    ValueError
        If ``model`` is not a known alias.
    """

    # Short alias -> timm model identifier.
    _MODEL_NAMES = {
        'rexnet': 'rexnet_200',
        'efficientnet': 'tf_efficientnet_b8',
    }

    def __init__(self, model='rexnet'):
        super().__init__()
        model_name = self.get_model_name(model)
        # num_classes=1 requests a single-output head from timm.
        self.model = timm.create_model(
            model_name, pretrained=True, num_classes=1)

    def get_model_name(self, general_model):
        """Return the timm identifier registered for ``general_model``.

        Raises
        ------
        ValueError
            If the alias is unknown. Failing here gives a clear message
            instead of handing an empty model name to timm.
        """
        try:
            return self._MODEL_NAMES[general_model]
        except KeyError:
            supported = ', '.join(sorted(self._MODEL_NAMES))
            raise ValueError(
                f'Unknown model {general_model!r}; expected one of: '
                f'{supported}') from None

    def forward(self, x):
        """Apply the backbone to ``x`` and squeeze the last dimension.

        Presumably the backbone output is ``(batch, 1)``, which makes the
        result one value per sample -- TODO confirm.
        """
        return self.model(x).squeeze(-1)

## Defining the API

---

### Callbacks

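The following cell defines the callbacks shared by every model: a freezer for the pretrained layers, an early stopping callback and a learning-rate scheduler _(a progress bar callback is also created)_.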

In [None]:
from skorch.callbacks import EarlyStopping, Freezer, LRScheduler, ProgressBar

def is_top_layer(x):
    """Return True when parameter name ``x`` has none of the head prefixes.

    Despite its name, the predicate is True for every name that does NOT
    start with ``model.fc``, ``model._fc``, ``model.head`` or
    ``model.classifier``.
    """
    head_prefixes = (
        'model.fc',
        'model._fc',
        'model.head',
        'model.classifier',
    )
    return not x.startswith(head_prefixes)
# Callback instances created once and reused below.
# Freezer receives the predicate that selects parameters by name.
freezer = Freezer(is_top_layer)
early_stopping = EarlyStopping(patience=3)
# StepLR policy: learning rate scaled by gamma=0.9 with step_size=1.
scheduler = LRScheduler(policy='StepLR', step_size=1, gamma=0.9)
progress_bar = ProgressBar()

### Classifier class

In [None]:
from torch import float64
from skorch.classifier import NeuralNetBinaryClassifier
from skorch.utils import to_tensor, to_numpy
import sklearn.metrics as sk_metrics 
import numpy as np

class MyClassifier(NeuralNetBinaryClassifier):
    """Binary skorch classifier that evaluates the loss in float64.

    Module outputs and targets are cast to ``float64`` before the loss is
    computed. The class also adds helpers that score a labelled dataset
    (accuracy, F1, AUC and confusion matrix).
    """

    def infer(self, x, **fit_params):
        """Run the module on ``x`` and return its output as float64.

        ``x`` is moved to ``self.device`` first. When it is a dict it is
        merged with ``fit_params`` and passed as keyword arguments.
        """
        x = to_tensor(x, device=self.device)
        if isinstance(x, dict):
            x_dict = self._merge_x_and_fit_params(x, fit_params)
            output = self.module_(**x_dict)
        else:
            output = self.module_(x, **fit_params)
        return output.to(device=self.device, dtype=float64)

    def train_step_single(self, Xi, yi, **fit_params):
        """Forward/backward pass for one training batch.

        Returns a dict with the batch ``'loss'`` and the predictions
        ``'y_pred'``.
        """
        self.module_.train()
        y_pred = self.infer(Xi, **fit_params)
        # Targets must match the float64 predictions for the loss.
        yi = yi.to(device=self.device, dtype=float64)
        loss = self.get_loss(y_pred, yi, X=Xi, training=True)
        loss.backward()
        return {'loss': loss, 'y_pred': y_pred}

    def validation_step(self, Xi, yi, **fit_params):
        """Evaluate one validation batch without tracking gradients.

        Returns a dict with the batch ``'loss'`` and the predictions
        ``'y_pred'``.
        """
        self.module_.eval()
        # Validation never calls backward(), so skip building the autograd
        # graph to save memory and time.
        with torch.no_grad():
            y_pred = self.infer(Xi, **fit_params)
            yi = yi.to(device=self.device, dtype=float64)
            loss = self.get_loss(y_pred, yi, X=Xi, training=False)
        return {'loss': loss, 'y_pred': y_pred}

    def _get_y_values(self, X):
        """Return ``(y_true, y_pred)`` label arrays for the dataset ``X``.

        ``X`` is iterated through ``self.get_iterator`` and must yield
        ``(images, labels)`` batches. The predicted label of a sample is the
        index of the largest value along dimension 1 of the module output
        after the prediction nonlinearity has been applied.
        """
        y_true, y_pred = [], []
        nonlinearity = self._get_predict_nonlinearity()
        # Put the module in inference mode and disable autograd for
        # evaluation.
        self.module_.eval()
        with torch.no_grad():
            for images, labels in self.get_iterator(X):
                images = images.to(self.device)
                outputs = nonlinearity(self.module_(images))
                _, predicted = torch.max(outputs, 1)
                y_true.append(to_numpy(labels))
                y_pred.append(to_numpy(predicted))
        y_true = np.concatenate(y_true)
        y_pred = np.concatenate(y_pred)
        return y_true, y_pred

    def score(self, X):
        """Return the ROC AUC of the predicted labels on dataset ``X``."""
        # NOTE(review): the AUC is computed from hard 0/1 predictions, not
        # from probabilities -- confirm this is intended.
        y_true, y_pred = self._get_y_values(X)
        return sk_metrics.roc_auc_score(y_true, y_pred)

    def scores(self, X):
        """Return ``(accuracy, confusion_matrix, f1, auc)`` for dataset ``X``.

        All four metrics are computed from the same predicted labels.
        """
        y_true, y_pred = self._get_y_values(X)
        accuracy = sk_metrics.accuracy_score(y_true, y_pred)
        confusion_matrix = sk_metrics.confusion_matrix(y_true, y_pred)
        f1 = sk_metrics.f1_score(y_true, y_pred)
        auc = sk_metrics.roc_auc_score(y_true, y_pred)
        return accuracy, confusion_matrix, f1, auc

    def print_and_plot_scores(self, X):
        """Print accuracy, F1 and AUC for ``X`` and plot the confusion matrix."""
        accuracy, confusion_matrix, f1, auc = self.scores(X)
        print(f'Accuracy: {accuracy}')
        print(f'F1 Score: {f1}')
        print(f'AUC: {auc}')
        # Assumes label 0 is 'Fire' and label 1 is 'No_Fire' -- TODO confirm
        # against the dataset's class-to-index mapping.
        disp = sk_metrics.ConfusionMatrixDisplay(
          confusion_matrix, display_labels=['Fire', 'No_Fire'])
        disp.plot()

### Classifier helper functions

The next cells define helper functions to easily create, fit and evaluate different types of CNN architectures.

In [None]:
from torch.optim import Adam
from skorch.callbacks import Checkpoint
from skorch.dataset import CVSplit

def create_model(module_model):
    """Build an unfitted ``MyClassifier`` around ``PretrainedModel``.

    Parameters
    ----------
    module_model : str
        Architecture alias forwarded to ``PretrainedModel`` as its ``model``
        argument; it is also used in the checkpoint file name.

    Returns
    -------
    MyClassifier
        Classifier configured with the Adam optimizer, the shared callbacks
        and a 20% validation split seeded with ``seed``.

    Notes
    -----
    Sets the global ``f_params`` to the checkpoint path
    ``<drive_path>/Models/best_<module_model>.pt``.
    """
    global f_params

    f_params = path.join(drive_path, f'Models/best_{module_model}.pt')
    checkpoint = Checkpoint(f_params=f_params, monitor='valid_acc_best')
    callbacks = [checkpoint, freezer, early_stopping, scheduler]
    # Fall back to the CPU so the notebook still runs on a runtime without
    # a GPU instead of failing when the module is moved to 'cuda'.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    return MyClassifier(
        PretrainedModel,
        module__model=module_model,
        optimizer=Adam,
        lr=2e-3,
        batch_size=32,
        max_epochs=10,
        iterator_train__shuffle=True,
        iterator_train__num_workers=16,
        iterator_valid__shuffle=True,
        iterator_valid__num_workers=16,
        train_split=CVSplit(0.2, random_state=seed),
        callbacks=callbacks,
        device=device
    )

In [None]:
def create_and_fit(model_name):
    """Create the classifier for ``model_name`` and fit it on ``train_ds``.

    Parameters
    ----------
    model_name : str
        Architecture alias passed to ``create_model``.

    Returns
    -------
    The fitted classifier.
    """
    net = create_model(model_name)
    # skorch classifiers declare ``y`` as a required argument of ``fit``, so
    # omitting it raises a TypeError. The labels are supplied by the dataset
    # itself, hence ``y=None``.
    net.fit(train_ds, y=None)
    return net

## Models results

---

### Rexnet

In [None]:
# Build and train the classifier registered under the 'rexnet' alias.
rexnet = create_and_fit('rexnet')

TypeError: ignored

In [None]:
# Print accuracy, F1 and AUC on the test set and plot the confusion matrix.
rexnet.print_and_plot_scores(test_ds)

### EfficientNet

In [None]:
# Build and train the classifier registered under the 'efficientnet' alias.
efficientnet = create_and_fit('efficientnet')

In [None]:
# Print accuracy, F1 and AUC on the test set and plot the confusion matrix.
efficientnet.print_and_plot_scores(test_ds)