In [None]:
# Download the PyTorch/XLA environment setup script, then run it for build 20210331
# together with the apt packages libomp5 and libopenblas-dev. All output is discarded.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py > /dev/null 2>&1
!python pytorch-xla-env-setup.py --version 20210331 --apt-packages libomp5 libopenblas-dev > /dev/null 2>&1

In [None]:
import sys
# Make the bundled pytorch-image-models source importable; presumably this is what `import timm` below resolves to.
sys.path.append("../input/timm-pytorch-image-models/pytorch-image-models-master/")

import platform
import numpy as np
import pandas as pd
import os
from tqdm.notebook import tqdm
import cv2
import random
import glob
import gc
from math import ceil
import albumentations as A
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, confusion_matrix
from skimage.exposure import exposure, equalize_hist,equalize_adapthist
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import mean_squared_error

import torch
import timm
import time
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.optim.lr_scheduler import ReduceLROnPlateau, LambdaLR, StepLR
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils
import warnings

# Silence all Python warnings and print NumPy floats without scientific notation.
warnings.simplefilter('ignore')
np.set_printoptions(suppress=True)

# XLA runtime knobs. NOTE(review): they are set after torch_xla has been imported
# above; confirm the runtime still picks them up at this point.
os.environ.update({
    'XLA_USE_BF16': "1",
    'XLA_TENSOR_ALLOCATOR_MAXSIZE': '100000000',
})

In [None]:
from torch.autograd import Variable

class bceFocalLoss(nn.Module):
    """Binary focal loss on raw logits.

    With p = sigmoid(x), each element contributes
        -t * (1 - p)**gamma * log(p) - (1 - t) * p**gamma * log(1 - p),
    evaluated in log space through ``F.logsigmoid`` for numerical stability.

    Args:
        gamma: focusing exponent (default 2).
    """

    def __init__(self, gamma=2):
        super().__init__()
        self.gamma = gamma

    def forward(self, input, target, reduction='mean'):
        """Focal loss of ``input`` logits against ``target``.

        With ``reduction='mean'`` the element-wise mean is multiplied by the size
        of the last input dimension, i.e. summed over that axis and averaged over
        the others. Any other ``reduction`` returns the flattened per-element loss.
        """
        last_dim = input.shape[-1]
        logits = input.view(-1).float()
        y = target.view(-1).float()

        log_p = F.logsigmoid(logits)    # log(p)
        log_q = F.logsigmoid(-logits)   # log(1 - p)
        positive = y * log_p * torch.exp(self.gamma * log_q)
        negative = (1.0 - y) * log_q * torch.exp(self.gamma * log_p)
        loss = -positive - negative

        if reduction == 'mean':
            return last_dim * loss.mean()
        return loss

class FocalLoss(nn.Module):
    """Multi-class focal loss on raw class scores.

    Per sample: -alpha[t] * (1 - p_t)**gamma * log(p_t), where p_t is the softmax
    probability of the target class t.

    Args:
        gamma: focusing exponent; 0 gives (optionally alpha-weighted) cross-entropy.
        alpha: optional class weights. A float/int ``a`` becomes ``[a, 1 - a]``,
            a list is used as one weight per class, anything else is used as given.
        size_average: if True return the mean over samples, otherwise the sum.
    """

    def __init__(self, gamma=0, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)):
            self.alpha = torch.Tensor([alpha, 1 - alpha])
        if isinstance(alpha, list):
            self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        """Compute the loss.

        Args:
            input: (N, C) scores, or (N, C, d1, ...) which is flattened to
                (N * d1 * ..., C).
            target: integer class indices, one per row of the flattened input.
        """
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)                         # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))    # N,H*W,C => N*H*W,C
        target = target.view(-1, 1)

        # Explicit class axis: calling log_softmax without `dim` is deprecated
        # (it warns, and picks dim=1 for 2-D input anyway).
        logpt = F.log_softmax(input, dim=1).gather(1, target).view(-1)
        # The modulating factor carries no gradient, as with the old
        # Variable(logpt.data.exp()).
        pt = logpt.detach().exp()

        if self.alpha is not None:
            # Match the scores' dtype/device locally instead of overwriting
            # self.alpha on every call.
            alpha = self.alpha.to(device=input.device, dtype=input.dtype)
            logpt = logpt * alpha.gather(0, target.view(-1))

        loss = -1 * (1 - pt) ** self.gamma * logpt
        return loss.mean() if self.size_average else loss.sum()

In [None]:
class Config:
    """Hyper-parameters and factory helpers for training.

    The helpers are static methods, so they work both when called on the class
    (``Config.get_optimizer(...)``) and when called on an instance.
    """
    image_size = 384      # side length (pixels) of the square model input
    batch_size = 8 * 8    # global batch; `_mp_fn` divides it by the XLA world size
    epochs = 15
    seed = 2021
    lr = 5e-5 / 8         # base LR; `_mp_fn` multiplies it by the XLA world size
    workers = 8           # DataLoader worker processes
    drop_last = True      # passed to DataLoader(drop_last=...)

    @staticmethod
    def get_loss_fn():
        """Return the configured loss (MSE)."""
        #return nn.CrossEntropyLoss()
        #return FocalLoss(gamma=1.5)
        return nn.MSELoss()

    @staticmethod
    def get_optimizer(model, learning_rate):
        """Adam over all model parameters at ``learning_rate``."""
        return torch.optim.Adam(model.parameters(), lr=learning_rate)

    @staticmethod
    def get_scheduler(optimizer):
        """Halve the LR after 3 epochs without improvement of the monitored value."""
        # `verbose` is omitted: False was already the default, and newer PyTorch
        # releases deprecate the argument.
        return ReduceLROnPlateau(optimizer,
                                 mode='min',
                                 factor=0.5,
                                 patience=3,
                                 threshold=0.0001,
                                 min_lr=1e-6,
                                 eps=1e-08)
    

def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Also exports PYTHONHASHSEED. NOTE(review): that variable only influences
    interpreters started afterwards, not string hashing in this process.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
# Seed once up front so later users of the global RNGs (e.g. DataFrame.sample in the dataset) are repeatable.
seed_everything(Config.seed)

In [None]:
train = pd.read_csv("../input/simple-eda-using-pandas-profiling/train.csv", index_col=0)

# Columns used later: `Id` (image file stem), `Pawpularity` (target, divided by
# 100 in the dataset) and the metadata columns at positions 1..12.
assert {"Id", "Pawpularity"} <= set(train.columns), "unexpected train.csv schema"

# The stray `512,768,1024` expression that used to end this cell was its last
# expression, so it was displayed instead of the frame; removed.
print(train.shape)
train.head()

In [None]:
classes = train.columns[1:13]
# NOTE: despite its name this maps column position -> column name; the dataset
# reads metadata as row[category_name_to_id[i]].
category_name_to_id = dict(enumerate(classes))
len(category_name_to_id)

In [None]:
def get_train_transforms():
    """Albumentations training pipeline.

    1. ``A.RandomSizedCrop`` with crop height between 95% and 100% of
       ``Config.image_size``, output ``Config.image_size`` x ``Config.image_size``
       (always applied).
    2. ``A.Rotate`` by up to +/-5 degrees with probability 0.6.

    Horizontal/vertical flips are currently disabled.
    """
    size = Config.image_size
    crop = A.RandomSizedCrop(
        min_max_height=[int(0.95 * size), int(1.0 * size)],
        height=size,
        width=size,
        p=1.0,
    )
    rotate = A.Rotate(limit=5, p=0.6)
    return A.Compose([crop, rotate])

In [None]:
class inputds(Dataset):
    """PetFinder dataset returning (image tensor, metadata labels, scaled Pawpularity).

    Each image is read from the folder built for ``Config.image_size``, pasted
    at a random offset along its shorter side onto a zero canvas of shape
    (Config.image_size, Config.image_size, 3), scaled to [0, 1] and returned
    channel-first.

    Args:
        df: frame with an ``Id`` column (image file stem), a ``Pawpularity``
            column and the metadata columns named in ``category_name_to_id``.
            It is shuffled once here using the global NumPy RNG.
        augments: if truthy, ``get_train_transforms()`` is stored on
            ``self.augments``. NOTE(review): ``__getitem__`` does not apply it,
            so train and validation samples are currently processed identically.
    """

    def __init__(self, df, augments=True):
        super().__init__()
        self.df = df.sample(frac=1).reset_index(drop=True)
        self.augments = get_train_transforms() if augments else None

    @staticmethod
    def _pad_to_square(img, size):
        """Paste ``img`` (h, w, c) onto a float64 zero canvas of shape (size, size, 3).

        The shorter side is placed at a random offset (Python ``random``); the
        longer side starts at 0. Square images take the ``h <= w`` branch.
        Assumes the longer side is at most ``size`` (larger images make the
        slice assignment fail). TODO: confirm the pre-resized folder guarantees this.
        """
        h, w, _ = img.shape
        canvas = np.zeros((size, size, 3))
        if h > w:
            left = random.randint(0, size - w)
            canvas[0:h, left:left + w] = img
        else:
            top = random.randint(0, size - h)
            canvas[top:top + h, 0:w] = img
        return canvas

    def __getitem__(self, idx):
        """Return ``(image, labels, score)`` for row ``idx``.

        image: float32 tensor (3, size, size) in [0, 1], in OpenCV's BGR channel
            order. NOTE(review): ImageNet-pretrained backbones usually expect RGB;
            confirm inference uses the same order before changing it.
        labels: numpy array with one entry per column in ``category_name_to_id``.
        score: ``Pawpularity / 100``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        row = self.df.iloc[idx]
        size = Config.image_size
        path = f"../input/petfinder-imgds/image_size{size}/image_size{size}/{row.Id}.jpg"
        imagepad = cv2.imread(path, cv2.IMREAD_COLOR)
        if imagepad is None:
            # cv2.imread reports missing/unreadable files by returning None.
            raise FileNotFoundError(f"cv2.imread could not read {path}")

        image = self._pad_to_square(imagepad, size)
        # Augmentation intentionally left disabled, as in the original notebook:
        # image = self.augments(image=image)["image"]
        image = torch.tensor(image / 255.0, dtype=torch.float).permute(2, 0, 1)
        labels = np.array([row[category_name_to_id[i]] for i in range(len(category_name_to_id))])
        score = row.Pawpularity / 100
        return image, labels, score

    def __len__(self):
        return len(self.df)

In [None]:
# Disabled visual sanity check: the code is wrapped in a string literal so it
# does not run. Un-quote it to plot the first 21 samples of the training dataset.
"""
ds =  inputds(train,augments=True)
for i,(image,labels, score) in enumerate(ds):
    print(image.shape,labels.shape,score)
    plt.imshow(image.permute(1,2,0))
    plt.show()
    if i==20:
        break
"""

In [None]:
from sklearn.model_selection import KFold, StratifiedKFold
RANDOM_STATE = 35

# Five shuffled folds over the rows of `train`. Fold k uses train_indexs[k]
# for training and test_indexs[k] for validation (positional row indices).
kfold = KFold(n_splits=5, random_state=RANDOM_STATE, shuffle=True)
#skfold = StratifiedKFold(n_splits=5, random_state=RANDOM_STATE, shuffle=True)
splits = kfold.split(train)
train_indexs, test_indexs = [], []
for fold_train, fold_valid in splits:
    print(fold_train.shape, fold_valid.shape)
    train_indexs.append(fold_train)
    test_indexs.append(fold_valid)

In [None]:
from timm.models.efficientnet import *
class Net0(nn.Module):
    """EfficientNet-B3 image backbone fused with tabular metadata into one score.

    image -> tf_efficientnet_b3 (1000-wide output, matching rlogit's input)
          -> dropout -> Linear(1000, 128)
          -> concat with ``dense`` (n_meta features) -> Linear(128 + n_meta, 64)
          -> Linear(64, 1)

    Args:
        n_meta: width of the metadata vector. The default 12 gives the original
            140-wide ``fc1``, so existing state_dicts still load.
        pretrained: whether timm loads pretrained backbone weights.

    NOTE(review): there is no non-linearity between rlogit, fc1 and fc2, so the
    head after the dropout is effectively a single affine map.
    """

    def __init__(self, n_meta=12, pretrained=True):
        super().__init__()
        self.eff = tf_efficientnet_b3(pretrained=pretrained, drop_rate=0.3, drop_path_rate=0.2, in_chans=3)
        self.rlogit = nn.Linear(1000, 128)
        self.dropout = nn.Dropout(0.1)
        self.fc1 = nn.Linear(128 + n_meta, 64)
        self.fc2 = nn.Linear(64, 1)

    def forward(self, image, dense):
        """Score a batch.

        Args:
            image: backbone input batch (the dataset yields (3, H, W) per sample).
            dense: (batch, n_meta) metadata features.

        Returns:
            (batch, 1) raw score.
        """
        features = self.rlogit(self.dropout(self.eff(image)))
        fused = torch.cat([features, dense], dim=1)
        return self.fc2(self.fc1(fused))

In [None]:
# Built once in the parent process. `_mp_fn`, run in the workers that
# xmp.spawn forks, moves this instance to each core's XLA device.
model_ = Net0()

In [None]:
class Record:
    '''
    Collects labels and sigmoid-activated predictions over the batches of one epoch.
    '''
    def __init__(self):
        self.labels = []
        self.preds = []

    def update(self, cur_labels, cur_logits):
        """Detach, move to CPU and store one batch.

        Args:
            cur_labels: target tensor for the batch.
            cur_logits: raw model outputs; stored after ``sigmoid``.
        """
        self.labels.append(cur_labels.detach().cpu().numpy())
        self.preds.append(cur_logits.sigmoid().detach().cpu().numpy())

    def get_labels(self):
        """Stored labels concatenated along axis 0 (shape (n,) for 1-D batch labels)."""
        return np.concatenate(self.labels)

    def get_preds(self):
        """Stored predictions concatenated along axis 0: (n, k), e.g. (n, 1) for a one-output model."""
        return np.concatenate(self.preds, axis=0)

    @staticmethod
    def get_acc(confusion_mat):
        """Accuracy in percent of a square confusion matrix, rounded to 2 decimals.

        Works for any number of classes (it used to be hard-wired to 4 via np.eye(4)).
        """
        confusion_mat = np.asarray(confusion_mat)
        return round(np.trace(confusion_mat) / np.sum(confusion_mat) * 100, 2)

    @staticmethod
    def get_rmse(preds, labels):
        """RMSE after multiplying both inputs by 100, undoing the dataset's Pawpularity / 100 scaling."""
        return np.sqrt(mean_squared_error(preds * 100, labels * 100))

In [None]:
class Trainer:
    """Runs per-core training and validation epochs for a PyTorch/XLA model.

    Args:
        model: module called as ``model(images, metadata)``; output column 0 is
            used as the prediction during training.
        optimizer: stepped through ``xm.optimizer_step``.
        loss_fn: criterion applied as ``loss_fn(preds[:, 0], targets)``.
        maskloss_fn: stored for interface compatibility; not used by either loop.
        device: device every batch tensor is moved to.
    """

    def __init__(self, model, optimizer, loss_fn, maskloss_fn, device):
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.maskloss_fn = maskloss_fn
        self.device = device

    def _to_device(self, images, meta, targets):
        """Move one (images, metadata, targets) batch to ``self.device`` as float tensors."""
        return (
            images.to(self.device, dtype=torch.float),
            meta.to(self.device, dtype=torch.float),
            targets.to(self.device, dtype=torch.float),
        )

    def train_one_cycle(self, train_loader):
        """Run one training epoch and return the sample-weighted mean loss (float).

        The per-step loss is read back through ``xm.add_step_closure`` instead of
        ``.item()`` inside the loop. ``.item()`` forces the lazy XLA graph to
        execute in the middle of every step (an extra device sync, and the
        forward pass is recomputed inside the step graph). Closures run at the
        next ``xm.mark_step()``; one is issued after the loop so the last step is
        always collected.

        NOTE(review): the optimised quantity is ``sqrt(loss_fn(...))``; with the
        BCEWithLogitsLoss passed by `_mp_fn` this is the square root of a BCE,
        which looks like a leftover from an MSE/RMSE setup. Confirm intent.

        The model is left in eval mode on return.
        """
        self.model.train()
        step_losses = []  # host floats: loss * batch size, appended by step closures
        total_nums = 0

        def _record(step_loss, batch_size):
            step_losses.append(step_loss.item() * batch_size)

        for xtrain, xlabel, ys in train_loader:
            xtrain, xlabel, ys = self._to_device(xtrain, xlabel, ys)

            self.optimizer.zero_grad()
            preds = self.model(xtrain, xlabel)
            loss = torch.sqrt(self.loss_fn(preds[:, 0], ys))

            batch_size = ys.size(0)
            total_nums += batch_size
            xm.add_step_closure(_record, args=(loss.detach(), batch_size))

            loss.backward()
            xm.optimizer_step(self.optimizer)

        # Flush any step closures still pending after the final iteration.
        xm.mark_step()
        self.model.eval()
        return sum(step_losses) / total_nums

    def valid_one_cycle(self, valid_loader):
        """Run one prediction pass without gradients.

        Returns:
            ``(labels, preds)`` numpy arrays collected by ``Record``: targets and
            the sigmoid of the model output, concatenated over batches.
        """
        self.model.eval()
        record = Record()
        with torch.no_grad():
            for xval, xlabel, ys in valid_loader:
                xval, xlabel, ys = self._to_device(xval, xlabel, ys)
                record.update(ys, self.model(xval, xlabel))
        return record.get_labels(), record.get_preds()

In [None]:
def checkpoint_loss(path):
    """Return the validation loss encoded in a ``pretrained_model_{fold}_{loss}.bin`` file name."""
    stem = os.path.basename(path)[:-len(".bin")]
    return float(stem.rsplit("_", 1)[1])


def prune_fold_checkpoints(fold, keep=2, directory="."):
    """Delete every checkpoint of ``fold`` except the ``keep`` with the lowest loss.

    Sorting is numeric on the parsed loss. A plain lexicographic sort of the
    names would rank "..._17.5.bin" before "..._9.8.bin" and could delete the
    best checkpoint.
    """
    pattern = os.path.join(directory, f"pretrained_model_{fold}_*.bin")
    for path in sorted(glob.glob(pattern), key=checkpoint_loss)[keep:]:
        os.remove(path)


def make_distributed_loader(dataset, shuffle, drop_last, num_replicas, rank):
    """Build a per-core ``(sampler, loader)`` pair covering 1/num_replicas of ``dataset``.

    The per-core batch size is ``Config.batch_size / num_replicas``, so the global
    batch stays ``Config.batch_size``.
    """
    sampler = DistributedSampler(
        dataset,
        num_replicas=num_replicas,
        rank=rank,
        shuffle=shuffle,
    )
    loader = DataLoader(
        dataset,
        batch_size=int(Config.batch_size / num_replicas),
        sampler=sampler,
        drop_last=drop_last,
        num_workers=Config.workers,
    )
    return sampler, loader


def _mp_fn(rank, flags):
    '''
    Train and validate one fold on a single XLA core (target of ``xmp.spawn``).

    Args:
        rank: process index passed by ``xmp.spawn``; rank 0 prunes checkpoints.
        flags: dict whose ``'fold'`` entry is the fold index used in checkpoint names.

    Reads the module-level ``model_``, ``train_set`` and ``valid_set`` that the
    parent process defines before forking.
    '''
    torch.set_default_tensor_type('torch.FloatTensor')

    # Common seed on every core so initialisation and traced graphs match.
    torch.manual_seed(Config.seed)

    device = xm.xla_device()
    model = model_.to(device)

    world_size = xm.xrt_world_size()
    ordinal = xm.get_ordinal()
    train_sampler, train_loader = make_distributed_loader(
        train_set, shuffle=True, drop_last=Config.drop_last,
        num_replicas=world_size, rank=ordinal,
    )
    # Validation keeps its final partial batch so no sample is dropped from the
    # RMSE. This costs at most one extra XLA compile for the smaller batch shape.
    _, valid_loader = make_distributed_loader(
        valid_set, shuffle=False, drop_last=False,
        num_replicas=world_size, rank=ordinal,
    )

    # LR scaled linearly with the number of cores.
    optimizer = Config.get_optimizer(model, Config.lr * world_size)
    #loss_fn = Config.get_loss_fn()
    loss_fn = nn.BCEWithLogitsLoss()
    maskloss_fn = bceFocalLoss(gamma=1.5)
    scheduler = Config.get_scheduler(optimizer)

    trainer = Trainer(
        model=model,
        optimizer=optimizer,
        loss_fn=loss_fn,
        maskloss_fn=maskloss_fn,
        device=device,
    )

    for epoch in range(Config.epochs):
        xm.master_print(f"{'-'*30} EPOCH: {epoch+1}/{Config.epochs} {'-'*30}")

        # DistributedSampler derives its shuffle from seed + epoch. Without
        # set_epoch, every epoch would replay the same sample order.
        train_sampler.set_epoch(epoch)
        para_loader = pl.ParallelLoader(train_loader, [device])
        train_loss = trainer.train_one_cycle(para_loader.per_device_loader(device))
        train_loss_avg = xm.mesh_reduce('train_loss_reduce', train_loss, lambda alist: sum(alist) / len(alist))
        xm.master_print(f"Train Loss: {train_loss_avg:.4f}")

        para_loader = pl.ParallelLoader(valid_loader, [device])
        valid_labels, valid_preds = trainer.valid_one_cycle(para_loader.per_device_loader(device))
        # Gather every core's shard (including DistributedSampler's padding
        # duplicates) so all cores compute the same fold-level RMSE.
        valid_labels_concat = xm.mesh_reduce('valid_labels_concat', valid_labels, lambda alist: np.concatenate(alist))
        valid_preds_concat = xm.mesh_reduce('valid_preds_concat', valid_preds, lambda alist: np.concatenate(alist, axis=0))
        valid_loss_avg = Record.get_rmse(valid_preds_concat, valid_labels_concat)
        xm.master_print(f"Valid Loss: {valid_loss_avg:.4f}")

        savename = f"pretrained_model_{flags['fold']}_{valid_loss_avg}.bin"
        xm.master_print(f"saveweight:{savename}")
        xm.save(model.state_dict(), savename)
        if rank == 0:
            prune_fold_checkpoints(flags['fold'], keep=2)

        scheduler.step(valid_loss_avg / 100)

In [None]:
# Train every fold; each xmp.spawn forks 8 workers that run `_mp_fn`.
for fold in range(5):
    FLAGS = {'fold': fold}
    train_data = train.iloc[train_indexs[fold], :]
    valid_data = train.iloc[test_indexs[fold], :]
    print(f"fold:{fold},Training on {train_data.shape[0]} samples and Validation on {valid_data.shape[0]} samples")

    # `_mp_fn` reads these module-level datasets inside the forked workers.
    train_set = inputds(df=train_data, augments=True)
    valid_set = inputds(df=valid_data, augments=False)
    xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')