In [None]:
!pip install pretrainedmodels
# !pip install torchtoolbox
# !pip install torchviz
# !pip install efficientnet_pytorch
!git clone https://github.com/4uiiurz1/pytorch-auto-augment > /dev/null

VERSION = "20200516"  #@param ["1.5" , "20200516", "nightly"]
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION --apt-packages libomp5 libopenblas-dev

### Importing Dependencies

In [None]:
%autosave 30
import os
os.environ['XLA_USE_BF16'] = "1"
import sys
sys.path.insert(0, './pytorch-auto-augment')
import gc
gc.enable()
import time
import glob
import random
from datetime import datetime

import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from skimage import io
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import tqdm as tqdm
from PIL import Image

import torch
import torchvision
from torchvision import transforms, models
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SequentialSampler, BatchSampler, RandomSampler

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.data_parallel as dp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu

# from torchviz import make_dot
import albumentations as A
from auto_augment import AutoAugment, Cutout
from albumentations.pytorch.transforms import ToTensorV2
from catalyst.data.sampler import DistributedSamplerWrapper, BalanceClassSampler
import pretrainedmodels
# from efficientnet_pytorch import EfficientNet

import sklearn
from sklearn import metrics
from sklearn.model_selection import GroupKFold

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore")

### Configuration

In [None]:
def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN behaviour.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    # Only affects Python interpreters started after this point (e.g. subprocesses);
    # the hash seed of the current interpreter is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every visible GPU; safe (lazy) when CUDA is unavailable, as on TPU hosts.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN choose algorithms by timing them, which can differ between
    # runs and defeats the determinism requested on the line above.
    torch.backends.cudnn.benchmark = False

SEED = 2020
seed_everything(SEED)

### Data Preparation

In [None]:
# Build one row per (kind, image) pair. The Cover folder is listed once and its file
# names are reused for every kind; this assumes each stego folder contains files with
# the same names as Cover (AlaskaDataset reads f'{DATA_ROOT_PATH}/{kind}/{image_name}').
# Sorting makes the order independent of the filesystem listing, so together with the
# seeded random.shuffle below the fold assignment is reproducible.
cover_image_names = sorted(
    os.path.basename(path)
    for path in glob.glob('../input/alaska2-image-steganalysis/Cover/*.jpg')
)

dataset = [
    {'kind': kind, 'image_name': image_name, 'label': label}
    for label, kind in enumerate(['Cover', 'JMiPOD', 'JUNIWARD', 'UERD'])
    for image_name in cover_image_names
]
random.shuffle(dataset)
dataset = pd.DataFrame(dataset)

gkf = GroupKFold(n_splits=5)
dataset.loc[:, 'fold'] = 0
# Grouping by image_name keeps all rows that share a file name (the cover and its
# stego versions) in the same fold, so no image leaks between train and validation.
for fold_index, (_, val_index) in enumerate(
    gkf.split(X=dataset.index, y=dataset['label'], groups=dataset['image_name'])
):
    dataset.loc[dataset.iloc[val_index].index, 'fold'] = fold_index

### Augmentations

In [None]:
def get_train_transforms():
    """Training pipeline for a numpy image: PIL conversion, random H/V flips, then a
    tensor normalised with the ImageNet channel mean/std."""
    augmentations = [
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        # AutoAugment() (imported at the top of the notebook) is currently disabled.
    ]
    to_normalised_tensor = [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose([transforms.ToPILImage()] + augmentations + to_normalised_tensor)

def get_valid_transforms():
    """Deterministic evaluation pipeline: numpy image -> PIL -> tensor -> ImageNet mean/std."""
    steps = (
        transforms.ToPILImage(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    )
    return transforms.Compose(list(steps))

### Dataset

In [None]:
def one_hot(size, target):
    """Return a float32 vector of length ``size`` holding 1.0 at index ``target``, 0.0 elsewhere."""
    encoded = torch.zeros(size, dtype=torch.float32)
    encoded[target] = 1.0
    return encoded

# Root of the ALASKA2 images; AlaskaDataset reads f'{DATA_ROOT_PATH}/{kind}/{image_name}'
# (the trailing slash yields a doubled '/' in that path, which POSIX treats as one separator).
# NOTE(review): the Data Preparation cell lists files via '../input/alaska2-image-steganalysis/';
# presumably the same directory when the working dir is /kaggle/working — confirm.
DATA_ROOT_PATH = '/kaggle/input/alaska2-image-steganalysis/'

class AlaskaDataset(Dataset):
    """Map-style dataset yielding ``(image, one_hot_target)`` pairs for ALASKA2 images.

    Args:
        kinds: per-sample sub-folder names under DATA_ROOT_PATH (e.g. 'Cover', 'UERD').
        image_names: per-sample file names; must expose ``.shape[0]`` (e.g. a numpy array).
        labels: per-sample integer class labels, encoded as length-4 one-hot targets.
        transforms: optional callable applied to the resized RGB image array.
    """

    def __init__(self, kinds, image_names, labels, transforms=None):
        super().__init__()
        self.kinds = kinds
        self.image_names = image_names
        self.labels = labels
        self.transforms = transforms

    def __getitem__(self, index: int):
        """Load, convert to RGB, resize to 331x331, transform, and pair with a one-hot target.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        kind, image_name, label = self.kinds[index], self.image_names[index], self.labels[index]
        path = f'{DATA_ROOT_PATH}/{kind}/{image_name}'
        image = cv2.imread(path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread reports missing/unreadable files by returning None instead of raising.
            raise FileNotFoundError(f'Could not read image: {path}')
        # OpenCV decodes to BGR, while the ImageNet Normalize constants in the transforms and the
        # pretrained backbone expect RGB channel order.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # The third positional parameter of cv2.resize is `dst`, so the interpolation flag
        # must be passed by keyword to take effect.
        image = cv2.resize(image, (331, 331), interpolation=cv2.INTER_AREA)
        if self.transforms:
            image = self.transforms(image)
        target = one_hot(4, label)
        return image, target

    def __len__(self) -> int:
        """Number of samples, taken from the length of ``image_names``."""
        return self.image_names.shape[0]

    def get_labels(self):
        """Return the labels as a plain list (used by BalanceClassSampler)."""
        return list(self.labels)

### Fold Datasets

In [None]:
fold_number = 0
# NOTE(review): SERIAL_EXEC is not referenced anywhere else in the visible notebook.
SERIAL_EXEC = xmp.MpSerialExecutor()

# The held-out fold is validation; every other fold is training.
is_validation_row = dataset['fold'] == fold_number
train_frame = dataset[~is_validation_row]
validation_frame = dataset[is_validation_row]

train_dataset = AlaskaDataset(
    kinds=train_frame.kind.values,
    image_names=train_frame.image_name.values,
    labels=train_frame.label.values,
    transforms=get_train_transforms(),
)

validation_dataset = AlaskaDataset(
    kinds=validation_frame.kind.values,
    image_names=validation_frame.image_name.values,
    labels=validation_frame.label.values,
    transforms=get_valid_transforms(),
)

### Model

In [None]:
class SE_ResNext50_32x4d(nn.Module):
    """SE-ResNeXt50 (32x4d) backbone with a 4-way linear head and multi-sample dropout.

    Args:
        pretrained: when not None (any value, e.g. 'imagenet'), backbone weights are loaded
            from the local checkpoint file below; the pretrainedmodels constructor itself is
            always called with ``pretrained=None``.
    """

    def __init__(self, pretrained=None):
        super(SE_ResNext50_32x4d, self).__init__()
        self.model = pretrainedmodels.__dict__['se_resnext50_32x4d'](pretrained=None)
        if pretrained is not None:
            # https://www.kaggle.com/abhishek/pretrained-model-weights-pytorch - Download
            # map_location='cpu' lets the checkpoint load regardless of the device it was saved
            # from (e.g. CUDA tensors on a TPU host without CUDA).
            state_dict = torch.load(
                '../input/pretrained-model-weights-pytorch/se_resnext50_32x4d-a260b3a4.pth',
                map_location='cpu',
            )
            self.model.load_state_dict(state_dict)
        self.dropout = nn.Dropout(p=0.1)
        # NOTE(review): high_dropout is never used in forward(); kept so the module layout is unchanged.
        self.high_dropout = nn.Dropout(p=0.5)
        self.classifier = nn.Linear(in_features=2048, out_features=4)

    def forward(self, images):
        """Map a 4-D image batch to 4-class logits.

        The pooled backbone features go through the dropout + classifier head 5 times with
        independent dropout masks, and the 5 logit tensors are averaged.
        """
        batch_size, _, _, _ = images.shape
        features = self.model.features(images)
        avg_pool = F.adaptive_avg_pool2d(features, 1).reshape(batch_size, -1)
        logits = torch.mean(
            torch.stack(
                [self.classifier(self.dropout(avg_pool)) for _ in range(5)],
                dim=0,
            ),
            dim=0,
        )
        return logits

### Metrics

In [None]:
class AverageMeter:
    """Tracks the latest value and the sample-weighted running mean of a scalar."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything recorded so far."""
        self.val = self.avg = self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` samples and refresh the running average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

In [None]:
def alaska_weighted_auc(y_true, y_valid):
    """Weighted ROC area over the TPR bands [0, 0.4] (weight 2) and [0.4, 1] (weight 1).

    In each band, the ROC points whose TPR lies strictly inside the band are taken.
    The curve is then padded horizontally at the band's upper TPR out to FPR = 1,
    and the area above the band's lower TPR is integrated. The weighted sum is divided
    by the total weighted band area (1.4), so the result lies in [0, 1].

    Args:
        y_true: binary labels, 1 being the positive class.
        y_valid: scores, higher meaning more likely positive.

    Returns:
        The weighted, normalised area as a float.
    """
    tpr_thresholds = [0.0, 0.4, 1.0]
    weights = [2, 1]
    fpr, tpr, _ = metrics.roc_curve(y_true, y_valid, pos_label=1)
    areas = np.array(tpr_thresholds[1:]) - np.array(tpr_thresholds[:-1])
    normalization = np.dot(areas, weights)
    competition_metric = 0
    for idx, weight in enumerate(weights):
        y_min = tpr_thresholds[idx]
        y_max = tpr_thresholds[idx + 1]
        mask = (y_min < tpr) & (tpr < y_max)
        if mask.any():
            x_start = fpr[mask][-1]
        else:
            # The curve jumps over the whole band (e.g. tied or perfectly separated scores):
            # it sits at >= y_max from the first point that reaches y_max. roc_curve always ends
            # at TPR 1.0, so such a point exists. Previously fpr[mask][-1] raised IndexError here.
            x_start = fpr[np.argmax(tpr >= y_max)]
        x_padding = np.linspace(x_start, 1, 100)
        x = np.concatenate([fpr[mask], x_padding])
        y = np.concatenate([tpr[mask], [y_max] * len(x_padding)])
        y = y - y_min
        score = metrics.auc(x, y)
        competition_metric += score * weight
    return competition_metric / normalization
        
class RocAucMeter:
    """Accumulates cover-vs-stego targets and scores and tracks alaska_weighted_auc over them."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Restart accumulation from one negative and one positive placeholder sample."""
        # Presumably seeded with both classes so roc_curve never sees a single class.
        self.y_true = np.array([0, 1])
        self.y_pred = np.array([0.5, 0.5])
        self.score = 0

    def update(self, y_true, y_pred):
        """Append a batch and recompute the metric over everything seen since reset().

        Args:
            y_true: one-hot targets; the argmax is clipped so class 0 (Cover) -> 0, any other -> 1.
            y_pred: logits; the stego score is 1 - softmax probability of class 0.
        """
        binary_true = np.clip(y_true.cpu().numpy().argmax(axis=1), 0, 1).astype(int)
        probabilities = nn.functional.softmax(y_pred, dim=1).data.cpu().numpy()
        stego_score = 1 - probabilities[:, 0]
        self.y_true = np.concatenate([self.y_true, binary_true])
        self.y_pred = np.concatenate([self.y_pred, stego_score])
        self.score = alaska_weighted_auc(self.y_true, self.y_pred)

    @property
    def avg(self):
        """Most recently computed metric value (0 right after reset)."""
        return self.score

### Loss

In [None]:
class LabelSmoothing(nn.Module):
    """Cross-entropy against probability (e.g. one-hot) targets, label-smoothed in training.

    Training-mode per-sample loss:
    ``confidence * CE(x, target) + smoothing * mean_c(-log_softmax(x)_c)``.
    Eval-mode loss is plain cross-entropy with no smoothing.

    Args:
        smoothing: weight of the uniform term; ``confidence = 1 - smoothing``.
    """

    def __init__(self, smoothing=0.05):
        super(LabelSmoothing, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, x, target):
        """Return the batch-mean loss.

        Args:
            x: logits with classes on the last dimension.
            target: per-class probabilities with the same shape as ``x`` (such as the one-hot
                vectors built by AlaskaDataset); in eval mode integer class indices are also accepted.
        """
        if not self.training and not torch.is_floating_point(target):
            # Integer class indices: the standard cross-entropy handles them directly.
            return torch.nn.functional.cross_entropy(x, target)
        x = x.float()
        target = target.float()
        logprobs = torch.nn.functional.log_softmax(x, dim=-1)
        nll_loss = (-logprobs * target).sum(-1)
        if not self.training:
            # F.cross_entropy accepts probability targets only from PyTorch 1.10 on; computing it
            # explicitly keeps validation working with float one-hot targets on older versions.
            return nll_loss.mean()
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()

### Engine

In [None]:
class Engine:
    """Per-process train/validate loop for a model on an XLA device.

    Owns an AdamW optimizer, a ReduceLROnPlateau scheduler (mode='max', stepped on the
    validation score when ``config.epoch_scheduler`` is set), the LabelSmoothing criterion,
    and a log file under ``./TPU_SE_ResNext50_32x4d``.

    Args:
        model: ``nn.Module`` mapping image batches to class logits.
        device: XLA device for this process.
        config: object exposing lr, n_epochs, verbose, log_step, step_scheduler,
            epoch_scheduler and metrics_debug.
    """

    def __init__(self, model, device, config):
        self.config = config
        self.model = model
        self.device = device
        self.model.to(self.device)
        # NOTE(review): every group from _param_groups sets its own 'lr', which overrides the
        # default lr passed here, so the world-size scaling of config.lr has no effect.
        self.optimizer = optim.AdamW(self._param_groups(self.model),
                                     lr=self.config.lr*xm.xrt_world_size())
        scheduler_params = dict(mode='max',
                                factor=0.8,
                                patience=2,
                                verbose=False,
                                threshold=0.0001,
                                threshold_mode='abs',
                                cooldown=0,
                                min_lr=1e-8,
                                eps=1e-08
                            )
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                              **scheduler_params)
        self.criterion = LabelSmoothing().to(self.device)

        self.epoch = 0
        self.best_score = 0
        self.best_loss = 10**5

        self.folder = 'TPU_SE_ResNext50_32x4d'
        self.base_dir = f'./{self.folder}'
        self.log_path = f'{self.base_dir}/log.txt'
        time.sleep(1)
        # Every spawned process builds an Engine; exist_ok avoids the check-then-create race
        # in which a second process's makedirs raises FileExistsError.
        os.makedirs(self.base_dir, exist_ok=True)
        xm.master_print(f'Engine Prepared. Device is {self.device}')

    @staticmethod
    def _param_groups(model):
        """Partition ``model``'s parameters into three AdamW groups (no parameter appears twice).

        1. Non-classifier params whose name contains a no-decay pattern: lr 1e-4, weight decay 0.
        2. All other non-classifier params: lr 1e-4, weight decay 1e-3.
        3. Params whose name contains 'classifier': lr 1e-3, optimizer-default weight decay.
        """
        # PyTorch parameter names are lowercase ('conv.bias', 'bn.bias'); the previous 'Bias'
        # pattern never matched anything.
        no_decay = ['LayerNorm.weight', 'LayerNorm.bias', 'bias']
        named_params = list(model.named_parameters())
        backbone = [(name, param) for name, param in named_params if 'classifier' not in name]
        return [
            {
                'params': [param for name, param in backbone
                           if any(nd in name for nd in no_decay)],
                'lr': 1e-4,
                'weight_decay': 0.00,
            },
            {
                'params': [param for name, param in backbone
                           if not any(nd in name for nd in no_decay)],
                'lr': 1e-4,
                'weight_decay': 0.001,
            },
            {
                'params': [param for name, param in named_params if 'classifier' in name],
                'lr': 1e-3,
            },
        ]

    def train(self, train_loader):
        """Run one training epoch.

        Returns:
            (AverageMeter of the sample-weighted loss, RocAucMeter of the score) for this process.
        """
        tracker = xm.RateTracker()
        self.model.train()
        total_loss = AverageMeter()
        total_score = RocAucMeter()
        start_time = time.time()
        for step, (images, labels) in enumerate(train_loader):
            if self.config.verbose and step!=0:
                if step%self.config.log_step==0:
                    print(f'[xla:{xm.get_ordinal()}]({step}) \
                          Train Step={step}/{len(train_loader)} \
                          Rate={tracker.rate():.2f} \
                          GlobalRate={tracker.global_rate():.2f} \
                          Total Loss={total_loss.avg:.3f} \
                          RoC Auc Score={total_score.avg:.3f} \
                          Total Time={time.time()-start_time:.2f}secs',
                          end='\r',
                          flush=True
                         )
            batch_size, _, _, _ = images.shape
            # .to() instead of torch.tensor(tensor), which copy-constructs and warns.
            images = images.to(self.device, dtype=torch.float32)
            targets = labels.to(self.device, dtype=torch.float32)
            self.optimizer.zero_grad()
            logits = self.model(images)
            loss = self.criterion(logits, targets)
            loss.backward()
            xm.optimizer_step(self.optimizer)
            total_score.update(targets, logits)
            total_loss.update(loss.detach().item(), batch_size)
            # Without add() the tracker's rate()/global_rate() always report 0.
            tracker.add(batch_size)
            if self.config.step_scheduler:
                # NOTE(review): ReduceLROnPlateau.step() requires a metrics value, so this call
                # raises TypeError if step_scheduler is enabled.
                self.scheduler.step()
        return total_loss, total_score

    def validation(self, val_loader):
        """Evaluate on ``val_loader`` without gradients; returns (loss meter, score meter)."""
        tracker = xm.RateTracker()
        self.model.eval()
        total_loss = AverageMeter()
        total_score = RocAucMeter()
        start_time = time.time()
        for step, (images, labels) in enumerate(val_loader):
            if self.config.verbose and step!=0:
                if step%self.config.log_step==0:
                    # Uses val_loader; the previous reference to train_loader was undefined here.
                    print(f'[xla:{xm.get_ordinal()}]({step}) \
                          Validation Step={step}/{len(val_loader)} \
                          Rate={tracker.rate():.2f} \
                          GlobalRate={tracker.global_rate():.2f} \
                          Total Loss={total_loss.avg:.3f} \
                          RoC Auc Score={total_score.avg:.3f} \
                          Total Time={time.time()-start_time:.2f}secs',
                          end='\r',
                          flush=True
                         )
            with torch.no_grad():
                batch_size, _, _, _ = images.shape
                images = images.to(self.device, dtype=torch.float32)
                targets = labels.to(self.device, dtype=torch.float32)
                logits = self.model(images)
                loss = self.criterion(logits, targets)
                total_loss.update(loss.detach().item(), batch_size)
                total_score.update(targets, logits)
                tracker.add(batch_size)
        return total_loss, total_score

    def fit(self, train_loader, val_loader):
        """Train and validate for ``config.n_epochs`` epochs, logging and checkpointing."""
        for n_epoch in range(self.config.n_epochs):
            if self.config.verbose:
                lr1, lr2 = self.optimizer.param_groups[0]['lr'], self.optimizer.param_groups[-1]['lr']
                timestamp = datetime.utcnow().isoformat()
                self.log(f'\n{timestamp}\nLR Backbone:{lr1} LR Head: {lr2}')

            tracker = xm.RateTracker()
            start_time = time.time()
            para_loader = pl.ParallelLoader(train_loader, [self.device])
            total_loss, total_score = self.train(para_loader.per_device_loader(self.device))
            # AverageMeter.count is the number of samples this process just trained on.
            tracker.add(total_loss.count)
            self.log(
                f'[TRAIN RESULT]: Epoch={self.epoch+1} \
                Rate={tracker.rate():.2f} \
                GlobalRate={tracker.global_rate():.2f} \
                Total Loss={total_loss.avg:.3f} \
                RoC Auc Score={total_score.avg:.3f} \
                Total Time={time.time()-start_time:.2f}secs')

            tracker = xm.RateTracker()
            start_time = time.time()
            para_loader = pl.ParallelLoader(val_loader, [self.device])
            total_loss, total_score = self.validation(para_loader.per_device_loader(self.device))
            tracker.add(total_loss.count)
            self.log(
                f'[VALIDATION RESULT]: Epoch={self.epoch+1} \
                Rate={tracker.rate():.2f} \
                GlobalRate={tracker.global_rate():.2f} \
                Total Loss={total_loss.avg:.3f} \
                RoC Auc Score={total_score.avg:.3f} \
                Total Time={time.time()-start_time:.2f}secs')

            if self.config.epoch_scheduler:
                self.scheduler.step(metrics=total_score.avg)

            if n_epoch%20==0:
                self.save(f'{self.base_dir}/checkpoint-{str(self.epoch).zfill(3)}epoch.bin')

            if self.config.metrics_debug:
                xm.master_print(met.metrics_report(), flush=True)

            self.epoch+=1

    def save(self, path):
        """Put the model in eval mode and save its state_dict to ``path`` with ``xm.save``."""
        self.model.eval()
        xm.save(self.model.state_dict(), path)

    def log(self, message):
        """Print ``message`` on the master process (if verbose) and append it to ``log_path``."""
        if self.config.verbose:
            xm.master_print(message)
        with open(self.log_path, 'a+') as logger:
            # fd= by keyword: with a master_print(*args, fd=...) signature, a positional file
            # object would be printed to stdout instead of being used as the destination.
            xm.master_print(f'{message}\n', fd=logger)

In [None]:
class Config:
    """Static hyper-parameters read by Engine and by the per-process loaders in _mp_fn."""

    # Optimisation.
    # NOTE(review): every Engine param group sets its own 'lr', so `lr` is currently only the
    # (overridden) AdamW default — confirm intent.
    lr = 1e-4
    n_epochs = 5

    # Data loading, per spawned process.
    batch_size = 32
    num_workers = 4

    # Scheduler stepping: per batch (step_scheduler) and/or per epoch on the validation score.
    step_scheduler = False
    epoch_scheduler = True

    # Logging: progress every `log_step` steps when verbose; XLA metrics report each epoch.
    verbose = True
    log_step = 1
    metrics_debug = True

In [None]:
def _mp_fn(rank, flags):
    """Per-process entry point for ``xmp.spawn``.

    Builds this process's shard of the train/validation loaders, constructs an Engine on
    the process's XLA device with a freshly built model, and runs ``Engine.fit``.

    Args:
        rank: process index from ``xmp.spawn``; process 0 sleeps 1 second before fitting.
        flags: unused (forwarded from ``xmp.spawn(args=(FLAGS,))``).
    """
    # xm.set_rng_state(SEED)  # left disabled
    device = xm.xla_device()
    num_replicas = xm.xrt_world_size()
    ordinal = xm.get_ordinal()
    shared_loader_kwargs = dict(
        batch_size=Config.batch_size,
        pin_memory=False,
        num_workers=Config.num_workers,
    )

    # Class-balanced (downsampled) sampling, sharded across replicas.
    train_sampler = DistributedSamplerWrapper(
        sampler=BalanceClassSampler(labels=train_dataset.get_labels(), mode="downsampling"),
        num_replicas=num_replicas,
        rank=ordinal,
        shuffle=True
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=train_sampler,
        drop_last=True,
        **shared_loader_kwargs,
    )

    # Validation is sharded in a fixed order and keeps the last partial batch.
    validation_sampler = torch.utils.data.distributed.DistributedSampler(
        validation_dataset,
        num_replicas=num_replicas,
        rank=ordinal,
        shuffle=False
    )
    validation_loader = torch.utils.data.DataLoader(
        validation_dataset,
        sampler=validation_sampler,
        drop_last=False,
        **shared_loader_kwargs,
    )

    engine = Engine(model=SE_ResNext50_32x4d(pretrained='imagenet'), device=device, config=Config)
    if rank == 0:
        time.sleep(1)
    engine.fit(train_loader, validation_loader)

In [None]:
# Fork 8 processes running _mp_fn (presumably one per TPU core); FLAGS is forwarded unused.
FLAGS = {}
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')