In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
if 'google.colab' in sys.modules:
    %pip install -q skorch
    !git clone https://github.com/youyinnn/medical_imaging_imbalancing.git
    %cd /content/medical_imaging_imbalancing/src/isic
sys.path.insert(0, os.path.abspath('..'))

In [None]:
from utils.resnet_model import ResNet18, ResNet50, ResNet34
import utils.skorch_trainer as isic_skorch_trainer
from isic_dataset import ISIC_2018
import torchvision.transforms as transforms

from utils.skorch_trainer import ResetedSkorchLRScheduler
from sklearn.metrics import top_k_accuracy_score
from torchvision.transforms import v2


import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# Uncomment to fix torch's global RNG seed for repeatable runs.
# torch.manual_seed(42);

# ISIC 2018 splits at several input resolutions. The default resolution is whatever
# ISIC_2018 defines (the "_256" experiment names suggest 256 — confirm in isic_dataset).
isic_train = ISIC_2018(verbose=1)
isic_val = ISIC_2018(split='val', verbose=1)

isic_train_128 = ISIC_2018(size=(128, 128))
isic_val_128 = ISIC_2018(split='val', size=(128, 128))

isic_train_32 = ISIC_2018(size=(32, 32))
isic_val_32 = ISIC_2018(split='val', size=(32, 32))

# `(448)` is just the int 448, not a tuple; pass an explicit (H, W) pair like the
# other resolutions so the images are resized to a square 448x448.
isic_train_448 = ISIC_2018(size=(448, 448))
isic_val_448 = ISIC_2018(split='val', size=(448, 448))

# Training split with CutMix enabled.
isic_train_cutmixed = ISIC_2018(cut_mixed=True)

### 18_256

In [None]:
# ResNet18 baseline with no learning-rate scheduler callback (callbacks is empty).
max_epochs, lr = 150, 0.01

net_18_256_no_lrsc = isic_skorch_trainer.net_fit(
    isic_skorch_trainer.net_def(
        ResNet18,
        net_name='18_256_no_lrsc',
        classes=[torch.tensor, torch.tensor],
        classifier_kwargs={
            'lr': lr,
            # presumably the 7 ISIC 2018 diagnosis classes
            'module__output_features': 7,
            'train_split': isic_val,
            'callbacks': [],
        },
    ),
    isic_train,
    None,  # third positional arg of net_fit — meaning defined in utils.skorch_trainer
    max_epochs,
)

: 

In [None]:
# ResNet18 with a StepLR schedule managed by ResetedSkorchLRScheduler.
max_epochs, lr = 150, 0.01

net_18_256 = isic_skorch_trainer.net_fit(
    isic_skorch_trainer.net_def(
        ResNet18,
        net_name='18_256',
        classes=[torch.tensor, torch.tensor],
        classifier_kwargs={
            'lr': lr,
            'module__output_features': 7,
            'train_split': isic_val,
            'callbacks': [
                # Positional args are presumably: metric to monitor, metrics to track,
                # and an integer period/patience — see utils.skorch_trainer.
                ResetedSkorchLRScheduler(
                    'valid_f1',
                    ['valid_acc', 'valid_f1'],
                    10,
                    policy='StepLR',
                    step_size=7,
                    last_epoch=-1,
                ),
            ],
        },
    ),
    isic_train,
    None,
    max_epochs,
)

In [None]:
# ResNet18 trained with geometric augmentation: on each call RandomChoice applies
# either a mild perspective warp (itself applied with p=0.8) or a random rotation.
max_epochs, lr = 100, 0.01

t1 = v2.Compose([
    v2.RandomChoice([
        v2.RandomPerspective(distortion_scale=0.1, p=0.8),
        v2.RandomRotation(degrees=(0, 360)),
    ]),
])

net_18_256_t1 = isic_skorch_trainer.net_fit(
    isic_skorch_trainer.net_def(
        ResNet18,
        net_name='18_256_t1',
        classes=[torch.tensor, torch.tensor],
        classifier_kwargs={
            'lr': lr,
            'module__output_features': 7,
            'train_split': isic_val,
            'callbacks': [],
        },
    ),
    ISIC_2018(transform=t1),
    None,
    max_epochs,
)

In [None]:
# ResNet18 trained with photometric augmentation: on each call RandomChoice applies
# exactly one of the transforms listed below.
max_epochs = 100
lr = 0.01

# Built with v2.Compose to match the other pipelines in this notebook, which are
# composed entirely from torchvision.transforms.v2 transforms.
t2 = v2.Compose([
    v2.RandomChoice([
        # FIXME: ColorJitter() with no arguments has no brightness/contrast/saturation/hue
        # range, so this branch returns the image unchanged. Pass explicit jitter ranges if
        # colour augmentation was intended (doing so changes the '18_256_t2' experiment).
        v2.ColorJitter(),
        v2.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5.)),
    ]),
])

net_18_256_t2 = isic_skorch_trainer.net_fit(
    isic_skorch_trainer.net_def(
        ResNet18,
        net_name='18_256_t2',
        classes=[torch.tensor, torch.tensor],
        classifier_kwargs=dict(
            lr=lr,
            module__output_features=7,
            train_split=isic_val,
            callbacks=[],
        ),
    ),
    ISIC_2018(transform=t2),
    None,
    max_epochs,
)

### 34_256

In [None]:
# ResNet34 with a StepLR schedule managed by ResetedSkorchLRScheduler.
# NOTE(review): the final net_fit argument (max epochs in the other cells) is 0 here,
# so this presumably runs no training epochs — perhaps to reuse a saved net; confirm
# against utils.skorch_trainer.net_fit.
net_2 = isic_skorch_trainer.net_fit(
    isic_skorch_trainer.net_def(
        ResNet34,
        net_name='34_256',
        classes=[torch.tensor, torch.tensor],
        classifier_kwargs={
            'lr': 0.01,
            'module__output_features': 7,
            'train_split': isic_val,
            'callbacks': [
                ResetedSkorchLRScheduler(
                    'valid_f1',
                    ['valid_acc', 'valid_f1'],
                    20,
                    policy='StepLR',
                    step_size=7,
                    last_epoch=-1,
                ),
            ],
        },
    ),
    isic_train,
    None,
    0,
)

### CutMix

In [None]:
# ResNet18 on the CutMix training set, without an LR scheduler callback.
max_epochs, lr = 100, 0.01

# Define the net first (net_def also receives cut_mixed=True), then fit it on the
# CutMix dataset.
net_18_256_cutmix_no_lrsc = isic_skorch_trainer.net_def(
    ResNet18,
    net_name='18_256_cutmix_no_lrsc',
    classes=[torch.tensor, torch.tensor],
    classifier_kwargs={
        'lr': lr,
        'module__output_features': 7,
        'train_split': isic_val,
        'callbacks': [],
    },
    cut_mixed=True,
)

net_1_for_cutmixed = isic_skorch_trainer.net_fit(
    net_18_256_cutmix_no_lrsc,
    isic_train_cutmixed,
    None,
    max_epochs,
)

### CutOut

In [None]:
# ResNet18 trained with CutOut (p=0.7) supplied through ISIC_2018's transform argument.
max_epochs, lr = 100, 0.01
from isic_dataset import CutOut

cutout = v2.Compose([CutOut(p=0.7)])

net_18_256_cutout = isic_skorch_trainer.net_fit(
    isic_skorch_trainer.net_def(
        ResNet18,
        net_name='18_256_cutout',
        classes=[torch.tensor, torch.tensor],
        classifier_kwargs={
            'lr': lr,
            'module__output_features': 7,
            'train_split': isic_val,
            'callbacks': [],
        },
    ),
    ISIC_2018(transform=cutout),
    None,
    max_epochs,
)

### No Pre-training

In [None]:
# ResNet18 with module__weights=None forwarded to the module — presumably training
# from random initialisation instead of pretrained weights; confirm in utils.resnet_model.
max_epochs, lr = 150, 0.01

net_18_256_no_pret = isic_skorch_trainer.net_fit(
    isic_skorch_trainer.net_def(
        ResNet18,
        net_name='18_256_no_pret',
        classes=[torch.tensor, torch.tensor],
        classifier_kwargs={
            'lr': lr,
            'module__output_features': 7,
            'module__weights': None,
            'train_split': isic_val,
            'callbacks': [],
        },
    ),
    isic_train,
    None,
    max_epochs,
)

### Visualization

In [None]:
import matplotlib.pyplot as plt

# Index of the training sample to inspect.
SAMPLE_IDX = 87

# Element 0 of a dataset item is the image. Assumes it is a CHW tensor on a 0-255 scale
# (hence the transpose to HWC and the /255) — TODO confirm against ISIC_2018.__getitem__.
# For the CutMix dataset the image sits one level deeper: isic_train_cutmixed[i][0][0].
sample_image = isic_train[SAMPLE_IDX][0]

fig, ax = plt.subplots()
ax.imshow(sample_image.numpy().transpose(1, 2, 0) / 255)
ax.set_title(f'isic_train[{SAMPLE_IDX}]')
ax.axis('off')
plt.show()

In [None]:
# Scratch check: draw one integer in [0, 10) from torch's global RNG and display it.
int(torch.randint(low=0, high=10, size=(1,))[0])

### Demo

In [None]:

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from sklearn.datasets import make_classification

# Synthetic demo data: 1000 samples x 20 features, 10 of them informative.
X, y = make_classification(1000, 20, n_informative=10, random_state=0)
# float32 features for the torch Linear layers, int64 class targets.
X = X.astype(np.float32)
y = y.astype(np.int64)
# Quick sanity expression; it is not the last line of the cell, so it is not displayed.
X.shape, y.shape, y.mean()

class ClassifierModule(nn.Module):
    """Small MLP demo classifier: input -> dense0 -> nonlin -> dropout -> dense1 (10) -> 2-way softmax.

    Args:
        num_units: width of the first hidden layer.
        nonlin: activation applied after ``dense0``.
        dropout: dropout probability applied after the first hidden layer.
        input_units: number of input features; defaults to 20, matching the
            20-feature demo data generated above.
    """

    def __init__(
            self,
            num_units=10,
            nonlin=F.relu,
            dropout=0.5,
            input_units=20,
    ):
        super().__init__()
        self.num_units = num_units
        self.nonlin = nonlin

        self.dense0 = nn.Linear(input_units, num_units)
        # Only the Dropout module is kept under this name; the float is just its probability.
        self.dropout = nn.Dropout(dropout)
        self.dense1 = nn.Linear(num_units, 10)
        self.output = nn.Linear(10, 2)

    def forward(self, X, **kwargs):
        """Return softmax class probabilities of shape (batch, 2).

        Probabilities rather than logits are returned; skorch's NeuralNetClassifier
        is documented to take their log itself with its default NLLLoss criterion.
        """
        X = self.nonlin(self.dense0(X))
        X = self.dropout(X)
        X = F.relu(self.dense1(X))
        return F.softmax(self.output(X), dim=-1)

In [None]:
# Alternative run through the project trainer on the 32x32 ISIC data (kept for reference):
# net_2 = isic_skorch_trainer.net_fit(
#     isic_skorch_trainer.net_def(
#         ClassifierModule, 
#         net_name = 'eee',
#         classes=[torch.tensor, torch.tensor],    
#         classifier_kwargs = dict(
#             lr = 0.001,
#             train_split = isic_val_32,
#             callbacks = [
#                 ResetedSkorchLRScheduler('valid_f1', ['valid_acc', 'valid_f1', 'valid_loss'], 3,
#                     policy='StepLR', step_size=3, last_epoch = -1)]
#             )
#         ), 
# isic_train_32, None, 30)
from skorch import NeuralNetClassifier
from skorch.callbacks import LRScheduler, EpochScoring, PrintLog
from sklearn.metrics import top_k_accuracy_score

# Pick the first available accelerator instead of hard-coding 'mps', which fails on
# machines without Apple-silicon GPU support (CUDA or CPU-only runtimes such as Colab).
if torch.cuda.is_available():
    demo_device = 'cuda'
elif torch.backends.mps.is_available():
    demo_device = 'mps'
else:
    demo_device = 'cpu'

net = NeuralNetClassifier(
    ClassifierModule,
    max_epochs=50,
    lr=0.01,
    device=demo_device,
    callbacks=[
        # Macro-F1 on the validation split, logged as 'valid_f1' (higher is better).
        EpochScoring(scoring='f1_macro', name='valid_f1',
                     lower_is_better=False),
        # EpochScoring(scoring=lambda net, X, y:top_k_accuracy_score(y, net.predict(X), k=1), name='valid_top_k_acc',
        #              lower_is_better=False),
        ResetedSkorchLRScheduler('valid_f1', ['valid_acc', 'valid_f1'], 10,
                                 policy='StepLR', step_size=5, last_epoch=-1),
    ],
)

# Training the network
net.fit(X, y)