In [None]:
# Download the PyTorch/XLA environment setup script and run it, requesting the
# `nightly` version and the apt packages libomp5 and libopenblas-dev.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev

In [None]:
# Install the package that provides `efficientnet_pytorch.EfficientNet`, imported further down.
pip install efficientnet_pytorch

In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import matplotlib.pyplot as plt 


In [None]:
import warnings
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.data_parallel as dp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm 
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils
import warnings
import gc

warnings.filterwarnings("ignore")

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from skimage import io, transform
from torchvision import models, transforms
import torch.optim as optim
from torch.optim import lr_scheduler
import time
import copy 
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split as ttp
from skimage.filters import threshold_otsu
from skimage.color import rgb2gray
import cv2 as cv 
import pickle
import random 
import albumentations
from transformers import AdamW, get_linear_schedule_with_warmup, get_constant_schedule

In [None]:
class melanoma_dataset(Dataset):
    """Image dataset driven by a table of image names and labels.

    Rows are taken from ``df`` when it is non-empty, otherwise from the CSV
    read from ``csv_file``. Column 0 holds the image file stem (``.jpg`` is
    appended), column 1 selects the source folder in training mode (the
    string ``'-1'`` picks ``root_dir[1]``, anything else ``root_dir[0]``) and
    column 3 holds the target.

    Parameters
    ----------
    root_dir : sequence
        Two image folders, indexed as ``root_dir[0]`` and ``root_dir[1]``.
        Not used when ``train`` is false.
    transform : callable or None
        Called as ``transform(image=img)``; its ``'image'`` entry is used.
        When falsy the image is used as read.
    df : pandas.DataFrame, optional
        In-memory table. ``None`` (the default) means "no frame given".
    csv_file : str or False
        Path of a CSV to read; only consulted at item time when ``df`` is empty.
    train : bool
        When true, items are ``(image, target)``; otherwise images are read
        from the module-level ``TEST_FOLDER`` and only the image is returned.

    Raises
    ------
    ValueError
        If neither a non-empty ``df`` nor a ``csv_file`` is supplied.
    """

    def __init__(self, root_dir, transform, df = None, csv_file = False, train = True):
        # Build a fresh frame per instance rather than sharing one default object.
        self.df = pd.DataFrame() if df is None else df

        if csv_file:
            self.csv = pd.read_csv(csv_file)
        elif self.df.empty:
            # Fail here with a clear message instead of an AttributeError on
            # the first len()/indexing call.
            raise ValueError('melanoma_dataset needs a non-empty df or a csv_file')

        self.directory = root_dir
        self.transform = transform
        self.train = train

    def _table(self):
        """Return the table in use: ``df`` if it has rows, else the loaded CSV."""
        return self.df if not self.df.empty else self.csv

    def __getitem__(self, idx):
        """Load one image as a float32 array with the channel axis first.

        Returns ``(image, target)`` in training mode and ``image`` otherwise.
        Raises ``FileNotFoundError`` when the image file cannot be read.
        """
        tab = self._table()

        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Test rows need only the file stem, so the source-folder column and
        # root_dir are consulted in training mode only.
        if not self.train:
            directory = TEST_FOLDER
        elif tab.iloc[idx, 1] == '-1':
            directory = self.directory[1]
        else:
            directory = self.directory[0]

        img_name = os.path.join(directory, tab.iloc[idx, 0]) + '.jpg'
        # NOTE(review): cv.imread yields BGR channel order; the mean/std used by
        # the notebook's Normalize transforms look like RGB statistics - confirm
        # whether a BGR->RGB conversion is wanted before changing trained models.
        img = cv.imread(img_name)
        if img is None:
            # cv.imread signals an unreadable/missing file by returning None.
            raise FileNotFoundError(f'could not read image: {img_name}')

        if self.transform:
            img = self.transform(image = img)['image']
        # HWC -> CHW, as expected by the convolutional model.
        image = np.transpose(img, (2, 0, 1)).astype(np.float32)

        if self.train:
            return image, tab.iloc[idx, 3]
        return image

    def __len__(self):
        """Number of rows in the table in use."""
        return len(self._table())

In [None]:
from efficientnet_pytorch import EfficientNet 

In [None]:
# Build an EfficientNet-B6 from pretrained weights and replace its `_fc` layer
# with a single-output linear layer.
# NOTE(review): 2304 is presumably the width of B6's final feature vector - confirm.
mx = EfficientNet.from_pretrained('efficientnet-b6')
mx._fc = nn.Linear(2304, 1)

In [None]:
# Per-channel normalisation statistics shared by all three pipelines.
# NOTE(review): these look like the usual ImageNet RGB statistics - confirm the
# channel order matches the images produced by the dataset class.
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]
# Training pipeline: normalise, then random shift/scale/rotate and a flip
# applied with probability 0.5.
transform_train = albumentations.Compose([
    albumentations.Normalize(mean, std, always_apply = True),
    albumentations.ShiftScaleRotate(),
    albumentations.Flip(p=0.5)
])

# Validation pipeline: normalisation only.
transform_valid = albumentations.Compose([
    albumentations.Normalize(mean, std, always_apply = True),
])

# Test pipeline: normalisation only (same steps as validation).
transform_test = albumentations.Compose([
    albumentations.Normalize(mean, std, always_apply = True),
])


In [None]:
# Input locations, relative to the notebook's working directory.
# DIR0 / DIR1 are the two training image folders handed to the dataset class.
DIR0 = '../input/jpeg-melanoma-384x384/train'
DIR1 = '../input/jpeg-isic2019-384x384/train'
# One training table per fold, keyed by fold number 0-4.
FOLD_CSVS = {
    fold_id: f'../input/combined-train/train{fold_id}.csv'
    for fold_id in range(5)
}
# Test images and the table listing them.
TEST_FOLDER = '../input/jpeg-melanoma-384x384/test'
TEST_CSV = '../input/combined-train/test.csv'
MODELS_PATH = '../input/tpu-models/'

In [None]:
#have to add in datasets/for fold line 

In [None]:
def loss_fn(outputs, targets):
    """Binary cross-entropy on raw logits.

    ``targets`` is reshaped to a single column (``view(-1, 1)``) before being
    compared with ``outputs``; the result is whatever a default-constructed
    ``BCEWithLogitsLoss`` returns for that pair.
    """
    column_targets = targets.view(-1, 1)
    criterion = torch.nn.BCEWithLogitsLoss()
    return criterion(outputs, column_targets)

In [None]:
def reduce_fn(vals):
    """Return the arithmetic mean of ``vals`` (used as the reducer for ``xm.mesh_reduce``)."""
    total = sum(vals)
    count = len(vals)
    return total / count


# Parameters of the epoch-wise learning-rate curve computed by get_lr below.
lr_start, lr_max, lr_min = 0.000005, 0.000000125 * 8 * 4, 0.00000005
# Ramp length (epochs), hold length (epochs) and per-epoch decay factor.
lr_ramp_ep, lr_sus_ep, lr_decay = 5, 0, 0.8

def get_lr(epoch):
    """Learning rate for ``epoch`` on a ramp / hold / exponential-decay curve.

    * epoch 0 returns the fixed value 0.001;
    * epochs below ``lr_ramp_ep`` interpolate linearly from ``lr_start``
      towards ``lr_max``;
    * the next ``lr_sus_ep`` epochs return ``lr_max`` (an empty range while
      ``lr_sus_ep`` is 0);
    * later epochs decay from ``lr_max`` towards ``lr_min`` by a factor of
      ``lr_decay`` per epoch.

    Note: later cells of this notebook define ``get_lr`` again, which
    replaces this definition once they run.
    """
    if epoch == 0:
        return 0.001

    if epoch < lr_ramp_ep:
        return (lr_max - lr_start) / lr_ramp_ep * epoch + lr_start

    if epoch < lr_ramp_ep + lr_sus_ep:
        return lr_max

    return (lr_max - lr_min) * lr_decay**(epoch - lr_ramp_ep - lr_sus_ep) + lr_min


In [None]:
def train_loop_fn(dataloader, model, optimizer, device, scheduler):
    """Run one training pass over ``dataloader``.

    For every batch the inputs (``d[0]``) and targets (``d[1]``) are moved to
    ``device`` as float32, the loss from ``loss_fn`` is back-propagated and
    the optimizer is stepped through ``xm.optimizer_step``. ``scheduler``,
    when not ``None``, is stepped once per batch. Every 50th batch the loss is
    averaged across processes with ``xm.mesh_reduce`` and printed by the
    master process.

    The model is put in training mode on entry and in eval mode on exit.
    Returns ``None``.
    """
    model.train()
    for bi, d in enumerate(dataloader):
        inputs = d[0].to(device, dtype = torch.float32)
        targets = d[1].to(device, dtype = torch.float32)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = loss_fn(outputs, targets)

        # Logging only: the update below must run for every batch, not just
        # the batches that are reported.
        if bi % 50 == 0:
            loss_reduced = xm.mesh_reduce('loss_reduce', loss, reduce_fn)
            xm.master_print(f'bi ={bi}, TRAIN_loss = {loss_reduced}')

        loss.backward()
        xm.optimizer_step(optimizer)
        if scheduler is not None:
            scheduler.step()
    model.eval()
    
def eval_loop_fn(data_loader, model, device,scheduler = None):
    """Run the model over ``data_loader`` and collect predictions and targets.

    The model is switched to eval mode and gradients are disabled for the
    whole pass. For every batch the inputs (``d[0]``) and targets (``d[1]``)
    are moved to ``device`` as float32; the sigmoid of the model output and
    the targets are appended to Python lists.

    Parameters
    ----------
    data_loader : iterable of batches indexable as ``d[0]``, ``d[1]``.
    model : the network to evaluate.
    device : device the batches are moved to.
    scheduler : unused; kept so existing calls keep working.

    Returns
    -------
    (fin_outputs, fin_targets) : two lists built with ``tolist()`` from the
        per-batch sigmoid outputs and targets, in loader order.
    """
    fin_targets = []
    fin_outputs = []

    model.eval()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for d in data_loader:
            inputs = d[0].to(device, dtype=torch.float32)
            targets = d[1].to(device, dtype=torch.float32)

            outputs = model(inputs)

            fin_targets.extend(targets.cpu().numpy().tolist())
            fin_outputs.extend(outputs.sigmoid().cpu().numpy().tolist())

    # One collection after the pass instead of one per batch.
    gc.collect()
    return fin_outputs, fin_targets

# Build an EfficientNet-B6 from pretrained weights with a single-output `_fc` layer.
# NOTE(review): this repeats the construction from an earlier cell, and the next
# cell rebinds `mx` again - confirm whether this copy is still needed.
mx = EfficientNet.from_pretrained('efficientnet-b6')
mx._fc = nn.Linear(2304, 1) 

In [None]:
# Rebuild the network, load the weights stored in model_0.pth into it and
# publish it as `mx`, the module-level model that _run moves to the XLA device.
model = EfficientNet.from_pretrained('efficientnet-b6')
model._fc = nn.Linear(2304, 1)
model.load_state_dict(torch.load('../input/tpu-models-pretrained/model_0.pth'))
mx = model

# Batch size passed to the training and validation DataLoaders in _run.
BATCH_SIZE = 4
def get_lr(epoch):
    """Return the hand-picked learning rate for ``epoch``.

    The table covers epochs 0-7 only; any other value raises ``KeyError``.

    This cell previously defined ``get_lr`` twice in a row; the first
    definition (a 13-epoch table scaled by ``xm.xrt_world_size()``) was
    replaced by this one as soon as the cell ran, so only the effective
    definition is kept.
    """
    schedule = {
        0: 0.000005,
        1: 0.00001,
        2: 0.0001,
        3: 0.0001,
        4: 0.001,
        5: 0.0001,
        6: 0.00005,
        7: 0.00001,
    }
    return schedule[epoch]

In [None]:
def _run(fold):
    """Train on one fold inside a single XLA process and save the best weights.

    Reads the fold's table from ``FOLD_CSVS``, splits it into training and
    validation parts with ``train_test_split`` (``random_state=0``), trains
    the module-level model ``mx`` for ``EPOCHS`` epochs and, after each
    epoch, computes the validation ROC AUC averaged over all processes with
    ``xm.mesh_reduce``. The weights of the epoch with the highest averaged
    AUC are written to ``model_{fold}.pth`` via ``xm.save``.

    Parameters
    ----------
    fold : key into ``FOLD_CSVS``; also used in the output file name.

    Returns ``None``.
    """
    EPOCHS = 50
    model_path = f'model_{fold}.pth'
    CSV = pd.read_csv(FOLD_CSVS[fold])
    # The zeros array only fills train_test_split's second positional slot;
    # its two output pieces are discarded.
    train_df, valid_df, _, _ = ttp(CSV, np.zeros(len(CSV)), random_state = 0)
    train_dataset = melanoma_dataset([DIR0,DIR1], transform_train, train_df)
    valid_dataset = melanoma_dataset([DIR0,DIR1], transform_valid, valid_df)

    # Samplers shard each dataset across the XLA processes.
    train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)

    valid_sampler = torch.utils.data.distributed.DistributedSampler(
          valid_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=False)

    # Both loaders are loop-invariant, so they are built once.
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        sampler=train_sampler,
        drop_last=True,
        num_workers=2,
    )

    valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=BATCH_SIZE,
        sampler=valid_sampler,
        drop_last=False,
        num_workers=2
    )

    device = xm.xla_device()
    model = mx.to(device)
    xm.master_print('done loading model')

    xm.master_print('training on train dataset')

    lr = 1e-7 * xm.xrt_world_size()
    num_train_steps = int(len(train_dataset) / BATCH_SIZE / xm.xrt_world_size() * EPOCHS)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # NOTE(review): the warm-up length is 0.0005 % of the total step count,
    # which makes warm-up very short - confirm this fraction is intended.
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=.000005*num_train_steps,
        num_training_steps=num_train_steps
    )

    xm.master_print(f'num_training_steps = {num_train_steps}, world_size={xm.xrt_world_size()}')

    best_auc = 0
    best_model_wts = None
    for epoch in range(EPOCHS):
        # Without this the DistributedSampler reuses the same shuffle order
        # in every epoch.
        train_sampler.set_epoch(epoch)

        # Report the rate the optimizer is actually using, since the
        # scheduler changes it after every batch.
        current_lr = optimizer.param_groups[0]['lr']
        xm.master_print(f'learning rate for epoch : {epoch} is {current_lr}')
        ts = time.time()
        gc.collect()

        para_loader = pl.ParallelLoader(train_data_loader, [device])
        gc.collect()
        # call training loop:
        train_loop_fn(para_loader.per_device_loader(device), model, optimizer, device, scheduler=scheduler)
        del para_loader

        para_loader = pl.ParallelLoader(valid_data_loader, [device])
        gc.collect()
        # call evaluation loop
        o, t = eval_loop_fn(para_loader.per_device_loader(device), model, device)
        del para_loader
        gc.collect()

        # AUC of this process's shard, then the mean over all processes.
        auc = roc_auc_score(np.array(t) >= 0.5, o)
        auc_reduced = xm.mesh_reduce('auc_reduce',auc,reduce_fn)
        if auc_reduced > best_auc:
            best_auc = auc_reduced
            best_model_wts = copy.deepcopy(model.state_dict())
        xm.master_print(f'{time.time()-ts}----AUC = {auc_reduced}')
        gc.collect()

    if best_model_wts is None:
        # No epoch exceeded the initial best_auc of 0; save the final weights
        # rather than failing on an unbound name.
        best_model_wts = model.state_dict()
    xm.save(best_model_wts, model_path)

In [None]:
# Run a garbage-collection pass before the cell that spawns the training processes.
gc.collect()

In [None]:
import time

# Start training processes
def _mp_fn(rank, flags):
    """Entry point executed in every process started by ``xmp.spawn``.

    Parameters
    ----------
    rank : process index supplied by ``xmp.spawn`` (not used here).
    flags : dict; ``flags['fold']`` selects the fold to train. When the key
        is absent the module-level ``fold`` is used, as before.
    """
    _run(flags['fold'] if 'fold' in flags else fold)

FLAGS={}
start_time = time.time()
for fold in range(5):
    # The model is built in the parent process and published as `mx`; the
    # forked workers read it from there inside _run.
    model = EfficientNet.from_pretrained('efficientnet-b6')
    model._fc = nn.Linear(2304, 1)
    # NOTE(review): every fold starts from the same model_0.pth weights -
    # confirm that this is intended rather than a per-fold checkpoint.
    model.load_state_dict(torch.load('../input/tpu-models-pretrained/model_0.pth'))
    mx = model
    # Hand the fold to the workers explicitly instead of relying on the loop
    # variable being visible as a global.
    FLAGS['fold'] = fold
    xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')