<a href="https://colab.research.google.com/github/utsavnandi/Kaggle-SIIM-ISIC-Melanoma-Classification/blob/master/TPU_SIIM_ISIC_Melanoma_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## One-time


In [1]:
# import os
# assert os.environ['COLAB_TPU_ADDR']
# VERSION = "nightly"  #@param ["1.5" , "20200516", "nightly"]
# !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
# !python pytorch-xla-env-setup.py --version $VERSION # --apt-packages libomp5 libopenblas-dev


In [2]:
# %%time
# !pip uninstall kaggle -y
# !pip install kaggle==1.5.6 -q
# !pip install -U catalyst -q
# !pip install -U git+https://github.com/albu/albumentations -q
# !pip install -U git+https://github.com/rwightman/pytorch-image-models -q
# !pip install -U git+https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer -q
# !pip install -U git+https://github.com/PyTorchLightning/pytorch-lightning -q
# !pip install neptune-client -q
# !mkdir ~/.kaggle/
# !cp ./kaggle.json  ~/.kaggle/kaggle.json
# !chmod 600 ~/.kaggle/kaggle.json
# !kaggle datasets download -d shonenkov/melanoma-merged-external-data-512x512-jpeg
# !unzip melanoma-merged-external-data-512x512-jpeg.zip -d ./data/
# !rm melanoma-merged-external-data-512x512-jpeg.zip
# !kaggle competitions download siim-isic-melanoma-classification -f sample_submission.csv
# !kaggle competitions download siim-isic-melanoma-classification -f test.csv
# !kaggle competitions download siim-isic-melanoma-classification -f train.csv
# !unzip train.csv -d ./data/
# !mv ./test.csv ./data/
# !mv ./sample_submission.csv ./data/
# !rm train.csv.zip
# !mkdir ./logs/

## Setup

In [3]:
import os
#os.environ['XLA_USE_BF16'] = "0"
import gc
import time
import datetime
import random
from getpass import getpass
import numpy as np
import cv2
import pandas as pd
import matplotlib.pyplot as plt
#import seaborn as sns
##from tqdm.notebook import tqdm
from google.colab import auth
from google.cloud import storage

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, roc_curve

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms, models

import torch_xla
import torch_xla.core.xla_model as xm
#import torch_xla.utils.serialization as xser
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu
from torch.utils.data.distributed import DistributedSampler

# from pytorch_lightning.core.lightning import LightningModule
# from pytorch_lightning import Trainer
# #from pytorch_lightning.loggers import TensorBoardLogger
# from pytorch_lightning.metrics.functional import auroc

from ranger import Ranger
from catalyst.data.sampler import DistributedSamplerWrapper, BalanceClassSampler
import timm

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

seed_everything(43)



numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject



In [4]:
DATA_DIR = '/content/data/'

In [5]:
df_train = pd.read_csv(DATA_DIR+'folds.csv')
df_test = pd.read_csv(DATA_DIR+'test.csv').rename(columns={'image_name':'image_id'})
sample_submission = pd.read_csv(DATA_DIR+'sample_submission.csv')

In [6]:
df_train['fold'].value_counts()

0    12219
1    12106
2    12072
4    12048
3    12042
Name: fold, dtype: int64

In [7]:
fold_no = 1
X_train = df_train[df_train['fold'] != fold_no][[col for col in df_train.columns if col != 'target']]
y_train = df_train[df_train['fold'] != fold_no][[col for col in df_train.columns if col == 'target']]
X_val = df_train[df_train['fold'] == fold_no][[col for col in df_train.columns if col != 'target']]
y_val = df_train[df_train['fold'] == fold_no][[col for col in df_train.columns if col == 'target']]

In [8]:
print('X_train', X_train.shape)
print('y_train', y_train.shape)
print('X_val', X_val.shape)
print('y_val', y_val.shape)

X_train (48381, 8)
y_train (48381, 1)
X_val (12106, 8)
y_val (12106, 1)


In [9]:
print('Train target distribution: ')
print(y_train['target'].value_counts())
print('Val target distribution: ')
print(y_val['target'].value_counts())

Train target distribution: 
0    43997
1     4384
Name: target, dtype: int64
Val target distribution: 
0    11011
1     1095
Name: target, dtype: int64


##  Dataset

In [10]:
class MelanomaDataset(Dataset):
    
    def __init__(self, df, labels, istrain=False, transforms=None):
        super().__init__()
        self.image_id = df['image_id'].values
        self.transforms = transforms
        self.labels = labels.values
        self.neg_indices = np.where(self.labels==0)[0]
        self.pos_indices = np.where(self.labels==1)[0]
        self.istrain = istrain

    def __len__(self):
        return len(self.image_id)

    def __getitem__(self, index):
        if torch.is_tensor(index):
            index = index.tolist()
        
        image, target = self.load_image(index)

        if self.transforms:
            image = self.transforms(image=image)['image']

        #if self.istrain == False:
        #      return image, target

        #rng = np.random.choice([1, 2])
        #if rng == 1:
        #    image, target = self.cutmix(image, target, alpha=1)
        #elif rng == 2:
        #    image, target = self.mixup(image, target, alpha=1) 
        ##else:
        ##    target = [target]

        return image, target#, rng
    
    def load_image(self, index):
        if torch.is_tensor(index):
            index = index.tolist()
        image_name = DATA_DIR + f'512x512-dataset-melanoma/512x512-dataset-melanoma/{self.image_id[index]}.jpg'
        image = cv2.imread(image_name, cv2.IMREAD_COLOR).astype(np.uint8)
        target = self.labels[index].astype(np.float32)
        return image, target

    def cutmix(self, data, target, alpha=1):
        rand_index = self.get_rand_index()
        random_image, random_target = self.load_image(rand_index)
        if self.transforms:
            random_image = self.transforms(image=random_image)['image']
        lam = np.random.beta(alpha, alpha)
        bbx1, bby1, bbx2, bby2 = self.rand_bbox(data.size(), lam)
        data[:, bbx1:bbx2, bby1:bby2] = random_image[ :, bbx1:bbx2, bby1:bby2]
        # adjust lambda to exactly match pixel ratio
        lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (data.size()[-1] * data.size()[-2]))
        targets = [target, random_target, lam]
        return data, targets

    def mixup(self, data, target, alpha=1):
        rand_index = self.get_rand_index()
        random_image, random_target = self.load_image(rand_index)
        if self.transforms:
            random_image = self.transforms(image=random_image)['image']
        lam = np.random.beta(alpha, alpha)
        data = data * lam + random_image * (1 - lam)
        targets = [target, random_target, lam]
        return data, targets

    def get_rand_index(self):
        if random.random()>0.5:
            rand_index = np.random.choice(self.pos_indices)
        else:
            rand_index = np.random.choice(self.neg_indices)
        return rand_index

    def rand_bbox(self, size, lam):
        W = size[1]
        H = size[2]
        cut_rat = np.sqrt(1. - lam)
        cut_w = np.int(W * cut_rat)
        cut_h = np.int(H * cut_rat)
        cx = np.random.randint(W)
        cy = np.random.randint(H)
        bbx1 = np.clip(cx - cut_w // 2, 0, W)
        bby1 = np.clip(cy - cut_h // 2, 0, H)
        bbx2 = np.clip(cx + cut_w // 2, 0, W)
        bby2 = np.clip(cy + cut_h // 2, 0, H)
        return bbx1, bby1, bbx2, bby2

def get_datasets():
    datasets = {}
    datasets['train'] = MelanomaDataset(
        X_train, y_train, istrain=True, transforms=get_train_transforms()
    )
    datasets['valid'] = MelanomaDataset(
        X_val, y_val, istrain=False, transforms=get_valid_transforms()
    )
    return datasets


## Augmentations

In [11]:
#%%writefile augmentations.txt
# Reference IMG_SIZE
# B0 - 224
# B1 - 240
# B2 - 260
# B3 - 300
# B4 - 380
# B5 - 456
# B6 - 520
# B7 - 600
# B8 - 672
# L2 NS - 475
# L2 - 800
# Transforms
IMG_SIZE = 224
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

def get_train_transforms(p=1.0):
    return A.Compose([
        A.Resize(IMG_SIZE, IMG_SIZE, interpolation=2, always_apply=True, p=1),
        A.RandomResizedCrop(
            IMG_SIZE, IMG_SIZE, scale=(0.8, 1.2), interpolation=2, p=0.33
        ),
        A.Flip(p=0.33),
        A.Transpose(p=0.33),
        #A.OneOf([
        #    A.MedianBlur(blur_limit=3, p=0.5),
        #    A.Blur(blur_limit=3, p=0.5),
        #], p=0.5),
        A.ShiftScaleRotate(
            interpolation=2,
            shift_limit=0.0625, scale_limit=0.15, 
            rotate_limit=15, p=0.3
        ),
        #A.OneOf([
        #    A.OpticalDistortion(p=0.3),
        #    A.GridDistortion(p=.1),
        #    A.IAAPiecewiseAffine(p=0.3),
        #], p=0.5),
        #A.OneOf([
        #    A.CLAHE(clip_limit=2),
        #    A.IAASharpen(),
        #    A.IAAEmboss(),
        #    A.RandomBrightnessContrast(),            
        #], p=0.5),
        A.HueSaturationValue(
            hue_shift_limit=20, sat_shift_limit=30, 
            val_shift_limit=20, p=0.33
        ),
        A.MultiplicativeNoise(
            multiplier=[0.75, 1.25], 
            elementwise=True, p=0.33
        ),
        A.Normalize(mean, std, max_pixel_value=255.0, always_apply=True),
        ToTensorV2(p=1.0),
    ], p=p)

def get_valid_transforms():
    return A.Compose([
        A.Resize(IMG_SIZE, IMG_SIZE, interpolation=2, always_apply=True, p=1),
        A.Normalize(mean, std, max_pixel_value=255.0, always_apply=True),
        ToTensorV2(p=1.0),
    ])


## Model

In [12]:
class ResNet18(nn.Module): 

    def __init__(self):
        super().__init__()
        self.model = models.resnet18(pretrained=True)
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features, 1)
    
    def forward(self, x):
        x = self.model(x)
        return x

class Tf_efficientnet_b0_ns(nn.Module):

    def __init__(self):
        super().__init__()
        self.model = timm.create_model('tf_efficientnet_b0_ns', pretrained=True)
        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Linear(in_features, 1)

    def forward(self, x):
        return self.model(x)

class Tf_efficientnet_b3_ns(nn.Module):

    def __init__(self):
        super().__init__()
        self.model = timm.create_model('tf_efficientnet_b3_ns', pretrained=True)
        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Linear(in_features, 1)

    def forward(self, x):
        return self.model(x)

class Tf_efficientnet_b3_ns_Mod(nn.Module):

    def __init__(self):
        super().__init__()
        self.model = timm.create_model('tf_efficientnet_b3_ns', pretrained=True)
        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Linear(in_features, int(in_features/2))
        self.bn_1 = nn.BatchNorm1d(int(in_features/2))
        self.relu_1 = nn.ReLU()
        self.drop_1 = nn.Dropout(0.2)
        self.fc_2 = nn.Linear(int(in_features/2), 1)

    def forward(self, x):
        x = self.model(x)
        x = self.bn_1(x)
        x = self.relu_1(x)
        x = self.drop_1(x)
        x = self.fc_2(x)
        return x

class Tf_efficientnet_b3_ns_Mod_v2(nn.Module):

    def __init__(self):
        super().__init__()
        self.model = timm.create_model('tf_efficientnet_b3_ns', pretrained=True)
        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Linear(in_features, in_features)
        self.relu_1 = nn.ReLU()
        self.bn_1 = nn.BatchNorm1d(in_features)
        self.fc_1 = nn.Linear(in_features, int(in_features/2))
        self.bn_2 = nn.BatchNorm1d(int(in_features/2))
        self.fc_2 = nn.Linear(int(in_features/2), 1)

    def forward(self, x):
        x = self.model(x)
        x = self.relu_1(x)
        x = self.bn_1(x)
        x = self.fc_1(x)
        x = self.bn_2(x)
        x = self.fc_2(x)
        return x

class Gluon_seresnext101_32x4d(nn.Module):

    def __init__(self):
        super().__init__()
        self.model = timm.create_model('gluon_seresnext101_32x4d', pretrained=True)
        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Linear(in_features, 1)

    def forward(self, x):
        return self.model(x)

class Gluon_seresnext50_32x4d(nn.Module):

    def __init__(self):
        super().__init__()
        self.model = timm.create_model('gluon_seresnext50_32x4d', pretrained=True)
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features, 1)

    def forward(self, x):
        return self.model(x)

## Custom Losses

In [13]:
class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, logits=True, reduce=True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduce = reduce

    def forward(self, inputs, targets):
        if self.logits:
            BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        else:
            BCE_loss = F.binary_cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss

def ohem_loss(cls_pred, cls_target, rate):
    ohem_cls_loss = F.binary_cross_entropy_with_logits(cls_pred, cls_target, reduction='none')
    batch_size = cls_pred.size(0)
    sorted_ohem_loss, idx = torch.sort(ohem_cls_loss, descending=True)
    keep_num = min(sorted_ohem_loss.size()[0], int(batch_size*rate))
    if keep_num < sorted_ohem_loss.size()[0]:
        keep_idx_cuda = idx[:keep_num]
        ohem_cls_loss = ohem_cls_loss[keep_idx_cuda]
    cls_loss = ohem_cls_loss.sum() / keep_num
    return cls_loss

def criterion(y_pred, y_true):
    return nn.BCEWithLogitsLoss()(y_pred, y_true)

def focal_criterion(y_pred, y_true):
    return FocalLoss(alpha=(43997/4384))(y_pred, y_true)


## Train script

In [14]:

class RocAucMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.y_true = np.array([0,1])
        self.y_pred = np.array([0.5,0.5])
        self.score = 0

    def update(self, y_true, y_pred):
        y_true = y_true.cpu().numpy()
        y_pred = torch.flatten(torch.sigmoid(y_pred)).data.cpu().numpy()
        self.y_true = np.append(self.y_true, y_true)
        self.y_pred = np.append(self.y_pred, y_pred)
        self.score = roc_auc_score(self.y_true, self.y_pred)

    @property
    def avg(self):
        return self.score
        

In [15]:

exp_description = '''
Tf_efficientnet_b0_ns with base head,
Extra Data
Focal loss
mixup only
RandomSampler,
changed aug,
imsize 300
'''

#SERIAL_EXEC = xmp.MpSerialExecutor()
#WRAPPED_MODEL = xmp.MpModelWrapper(Tf_efficientnet_b0_ns())
WRAPPED_MODEL = Tf_efficientnet_b0_ns()

def train_model():

    best_score = 0.0
    datasets = SERIAL_EXEC.run(get_datasets)
    #datasets = get_datasets()

    #sampler
    #labels_vcount = y_train['target'].value_counts()
    #class_counts = [labels_vcount[0].astype(np.float32), labels_vcount[1].astype(np.float32)]
    #num_samples = sum(class_counts)
    #class_weights = [num_samples/class_counts[i] for i in range(len(class_counts))]
    #weights = [class_weights[y_train['target'].values[i]] for i in range(int(num_samples))]
    #wrsampler = WeightedRandomSampler(torch.DoubleTensor(weights), int(num_samples))
    #BalanceClassSampler(labels=y_train['target'].values, mode="downsampling"),
    
    #train_sampler = DistributedSamplerWrapper(
    #    sampler=wrsampler,
    #    num_replicas=xm.xrt_world_size(),
    #    rank=xm.get_ordinal(),
    #    shuffle=True
    #)
    
    train_sampler =  DistributedSampler(
        datasets['train'],
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True
    )
    validation_sampler = DistributedSampler(
        datasets['valid'],
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False
    )
    train_loader = DataLoader(
        datasets['train'],
        batch_size=FLAGS['batch_size'], 
        num_workers=FLAGS['num_workers'],
        sampler=train_sampler,
        drop_last=True,
    )
    val_loader = DataLoader(
        datasets['valid'],
        batch_size=FLAGS['batch_size'],
        num_workers=FLAGS['num_workers'],
        sampler=validation_sampler,
        drop_last=True
    )
    
    device = xm.xla_device()
    model = WRAPPED_MODEL.to(device)
    #optimizer = Ranger(
    #    model.parameters(),
    #    lr=FLAGS['learning_rate'] * xm.xrt_world_size(), 
    #    alpha=0.5, k=6, N_sma_threshhold=5,
    #    weight_decay=FLAGS['weight_decay']
    #)
    optimizer = torch.optim.AdamW(
        model.parameters(), 
        lr=FLAGS['learning_rate'] * xm.xrt_world_size(), 
        weight_decay=FLAGS['weight_decay']
    )

    criterion = focal_criterion
    #criterion = nn.BCEWithLogitsLoss()

    def train_one_epoch(loader):
        model.train()
        running_loss = 0.0
        max_idx = 0.0
        for idx, (images, targets) in enumerate(loader):
            images = images.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            y_pred = model(images.float())
            loss = criterion(y_pred, targets)
            running_loss += float(loss)
            loss.backward()
            xm.optimizer_step(optimizer)
            max_idx = float(idx)
            if idx % FLAGS['log_steps'] == 0 and idx !=0:
                #print('[xla:{}]({}) Loss={:.5f} Time={}'.format(
                #    xm.get_ordinal(), idx, loss.item(), time.asctime(time.localtime())), flush=True)
                xm.master_print('({}) Loss={:.5f} Time={}'.format(
                    idx, loss.item(), time.asctime(time.localtime())))
        #return running_loss/(max_idx+1)

    def val_one_epoch(loader):
        model.eval()
        running_loss = 0.0
        total_samples = 0
        max_idx = 0.0
        #roc_auc_scores = RocAucMeter()
        with torch.no_grad():
            for idx, (images, targets) in enumerate(loader):
                images = images.to(device)
                targets = targets.to(device)
                y_pred = model(images.float())
                loss = criterion(y_pred, targets)
                running_loss += float(loss)
                max_idx = float(idx)
                total_samples += images.size()[0]
                #roc_auc_scores.update(targets, y_pred)

        #return running_loss/(max_idx+1), roc_auc_scores

    for epoch in range(0, FLAGS['num_epochs']):

        xm.master_print('-'*27 + f'Epoch #{epoch+1} started' + '-'*27)
        xm.master_print(f'epoch start time: {time.asctime(time.localtime())}')

        para_loader = pl.ParallelLoader(train_loader, [device])
        train_one_epoch(para_loader.per_device_loader(device)) # train_loss = 
        xm.master_print("finished training epoch {}".format(epoch+1))
        #xm.master_print(f'average loss for epoch #{epoch+1} : {train_loss}')
        #gc.collect()
        
        if (epoch+1) % 1 == 0:
            para_loader = pl.ParallelLoader(val_loader, [device])
            val_one_epoch(para_loader.per_device_loader(device)) # val_loss, auc_score = 
            xm.master_print("finished validating epoch {}".format(epoch+1))
            #xm.master_print(f'roc_auc_score: {auc_score.avg:.5f}')
            #xm.master_print(f'average loss for val epoch: {val_loss}')
            #gc.collect()
            #if auc_score.avg > best_score:
            #    best_score = auc_score.avg
                #xm.master_print(f'new best score: {best_score}')
                #if epoch > 3:
                
            xm.save(model.state_dict(), f'./model_{epoch+1}.pth')
                # xser.save(
                #     model.state_dict(), 
                #     f'./model_epoch_{epoch+1}.pth', 
                #     master_only=True
                # )
                #xser.save(model.state_dict(), f"model.bin", master_only=True)
                #gc.collect()
                #model.load_state_dict(xser.load(f"./model_epoch_{epoch+1}.pth"))
            xm.master_print(f'saved model...')
        
        xm.master_print(f'epoch end time: {time.asctime(time.localtime())}')
        xm.master_print('-'*28 + f'Epoch #{epoch+1} ended' + '-'*28)
    
    #return best_score


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth" to /root/.cache/torch/hub/checkpoints/tf_efficientnet_b0_ns-c0e6a31c.pth


## Train

In [None]:
FLAGS = {}
FLAGS['batch_size'] = 32
FLAGS['num_workers'] = 0
FLAGS['learning_rate'] = 2e-4
FLAGS['num_epochs'] = 10
FLAGS['weight_decay'] = 1e-4
FLAGS['log_steps'] = 40
FLAGS['img_size'] = IMG_SIZE
FLAGS['loss'] = 'BCE'
FLAGS['optimizer'] = 'AdamW'
FLAGS['exp_name'] = 'Tf_efficientnet_b0_ns'
FLAGS['fold'] = 1
FLAGS['num_cores'] = 8

def _mp_fn(rank, flags):
  global FLAGS
  #global X_train, y_train, X_val, y_val
  FLAGS = flags
  torch.set_default_tensor_type('torch.FloatTensor')
  #best_score = 
  train_model()

xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8,
          start_method='fork')


---------------------------Epoch #1 started---------------------------
epoch start time: Wed Jun 24 05:32:05 2020
(40) Loss=0.50114 Time=Wed Jun 24 05:33:51 2020
(80) Loss=1.07857 Time=Wed Jun 24 05:34:43 2020
(120) Loss=0.54557 Time=Wed Jun 24 05:35:38 2020
(160) Loss=1.00449 Time=Wed Jun 24 05:36:28 2020
finished training epoch 1
finished validating epoch 1
saved model...
epoch end time: Wed Jun 24 05:38:04 2020
----------------------------Epoch #1 ended----------------------------
---------------------------Epoch #2 started---------------------------
epoch start time: Wed Jun 24 05:38:04 2020
(40) Loss=0.42083 Time=Wed Jun 24 05:39:07 2020
(80) Loss=0.65260 Time=Wed Jun 24 05:39:56 2020
(120) Loss=0.41624 Time=Wed Jun 24 05:40:48 2020
(160) Loss=0.74269 Time=Wed Jun 24 05:41:41 2020
finished training epoch 2
finished validating epoch 2
saved model...
epoch end time: Wed Jun 24 05:43:16 2020
----------------------------Epoch #2 ended----------------------------
----------------------

## Lightning

In [None]:
# FLAGS = {}
# FLAGS['batch_size'] = 32
# FLAGS['num_workers'] = 8
# FLAGS['learning_rate'] = 3e-4
# FLAGS['num_epochs'] = 2
# FLAGS['weight_decay'] = 1e-4
# FLAGS['img_size'] = IMG_SIZE
# FLAGS['loss'] = 'Focal'
# FLAGS['optimizer'] = 'AdamW'
# FLAGS['exp_name'] = 'Tf_efficientnet_b0_ns'
# FLAGS['fold'] = 1
# FLAGS['num_cores'] = 8

# class LitModel(LightningModule):

#     def __init__(self):
#         super().__init__()
#         self.model = timm.create_model('tf_efficientnet_b0_ns', pretrained=True)
#         in_features = self.model.classifier.in_features
#         self.model.classifier = nn.Linear(in_features, 1)

#     def forward(self, x):
#         return self.model(x)

#     def training_step(self, batch, batch_idx):
#         image, target = batch
#         y_pred = self(image.float())
#         loss = focal_criterion(y_pred, target)
#         tensorboard_logs = {'train_loss': loss}
#         return {'loss': loss, 'log': tensorboard_logs}

#     def validation_step(self, batch, batch_idx):
#         image, target = batch
#         y_pred = self(image.float())
#         loss = focal_criterion(y_pred, target)
#         score = auroc(y_pred.item().long(), target.item().long())
#         return {'val_loss': loss, 'score': score}
    
#     def validation_epoch_end(self, outputs):
#         avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
#         avg_score = torch.stack([x['score'] for x in outputs]).mean()
#         tensorboard_logs = {'val_loss': avg_loss, 'avg_score': avg_score}
#         return {'val_loss': avg_loss, 'avg_score': avg_score, 'log': tensorboard_logs}

#     def configure_optimizers(self):
#         return torch.optim.AdamW(
#             self.parameters(), lr=FLAGS['learning_rate'],
#             weight_decay=FLAGS['weight_decay']
#         )

#     def train_dataloader(self):
#         train_ds = MelanomaDataset(
#             X_train, y_train, istrain=True, transforms=get_train_transforms()
#         )
#         train_loader = DataLoader(
#             train_ds,
#             batch_size=FLAGS['batch_size'], 
#             num_workers=FLAGS['num_workers'],
#             shuffle=True
#         )
#         return train_loader
    
#     def val_dataloader(self):
#         val_ds = MelanomaDataset(
#             X_val, y_val, istrain=False, transforms=get_valid_transforms()
#         )
#         val_loader = DataLoader(
#             val_ds,
#             batch_size=FLAGS['batch_size'], 
#             num_workers=FLAGS['num_workers'],
#             shuffle=False,
#         )
#         return val_loader


In [None]:
# model = LitModel()

# trainer = Trainer(
#     tpu_cores=8, max_epochs=FLAGS['num_epochs'],
#     checkpoint_callback=False,
#     row_log_interval=20
# )

# trainer.fit(model)

In [None]:
# %load_ext tensorboard
# %tensorboard --logdir ./lightning_logs

In [None]:
#!rm -rf /content/lightning_logs