In [None]:
# Download the PyTorch/XLA environment setup script and run it for release 1.7
# (its stdout is discarded).
# NOTE(review): presumably installs the torch/torch_xla 1.7 builds needed for TPU
# training — confirm against the script itself.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version 1.7 > /dev/null

In [None]:
# !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
# !python pytorch-xla-env-setup.py --version 1.7 --apt-packages libomp5 libopenblas-dev

!pip install pytorch-lightning == 1.8

!pip install omegaconf

# !git clone https://github.com/rwightman/pytorch-image-models

In [None]:
# Clone the timm (pytorch-image-models) sources. The import cell appends this
# directory to sys.path; because it is appended (not prepended), an already
# installed timm package would still take precedence on import.
!git clone https://github.com/rwightman/pytorch-image-models

In [None]:
import sys
sys.path.append('pytorch-image-models')
# sys.path.append("../input/mytorchlightning/pytorch-lightning")
# sys.path.append('../input/pytorch-image-models/pytorch-image-models-master')

import numpy as np  # linear algebra
import pandas as pd  # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import time
import random
from contextlib import contextmanager
from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold
from sklearn import model_selection
from collections import defaultdict, Counter
import sys
from typing import Tuple
import PIL
from omegaconf import OmegaConf
from torch.utils.data import Dataset
from pathlib import Path
from PIL import Image
from PIL.Image import Image as PILImage
from torch.utils.data.dataloader import DataLoader
import cv2
from sklearn.model_selection import train_test_split, StratifiedKFold
import albumentations as A
from sklearn.metrics import roc_auc_score

from albumentations import (
    Compose, OneOf, Normalize, Resize, RandomResizedCrop, RandomCrop, HorizontalFlip, VerticalFlip,
    RandomBrightness, RandomContrast, RandomBrightnessContrast, Rotate, ShiftScaleRotate, Cutout,
    IAAAdditiveGaussianNoise, Transpose
)
from albumentations.pytorch import ToTensorV2
from albumentations import ImageOnlyTransform
import matplotlib.pyplot as plt
import timm
from torchvision import models
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau
from torch import optim

import torch_xla.core.xla_model as xm


import torch_xla

In [None]:
from pytorch_lightning import LightningDataModule
from pytorch_lightning.core.lightning import LightningModule
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

In [None]:
# Competition data location; TRAIN_PATH mirrors base.train_path in the YAML config.
DATA_DIR = '../input/plant-pathology-2020-fgvc7'
TRAIN_PATH = f'{DATA_DIR}/images'

train = pd.read_csv(f'{DATA_DIR}/train.csv')
test = pd.read_csv(f'{DATA_DIR}/test.csv')

# Run id used in checkpoint directory / W&B run names. It is drawn before
# seed_torch() is called in main(), so it differs between sessions.
rand = random.randint(0, 100000)

train.head()

In [None]:
# Experiment configuration as YAML; it is parsed with OmegaConf.create() in the
# last cell and passed to main().
# - base.debug=True makes main() train on a 1000-row sample for a single epoch.
# - base.n_fold must agree with split.param.n_splits: main() iterates
#   range(base.n_fold) while the splitter produces n_splits folds.
# NOTE(review): model.epochs is 10 but scheduler.param.T_max is 6. If the
#   CosineAnnealingLR is stepped once per epoch (Lightning's default interval),
#   the LR reaches eta_min at epoch 6 and then rises again — confirm intended.
conf = """
base:
  train_path: '../input/plant-pathology-2020-fgvc7/images'
  print_freq: 100
  num_workers: 4
  seed: 42
  target_size: 4
  target_cols: [
      "healthy",
      "multiple_diseases",
      "rust",
      "scab"
  ]

  n_fold: 4
  trn_fold: [0]
  train: True
  debug: True
  oof: False
  tpu: True

split:
  name: "KFold"
  param: {
           "n_splits": 4,
           "shuffle": True,
           "random_state": 1212
  }

model:
  model_name: "tf_efficientnet_b0_ns"
  size: 224  # 480
  batch_size: 128
  pretrained: true
  epochs: 10

loss:
  name: "BCEWithLogitsLoss"
  param: {}

optimizer:
  name: "AdamW"
  param: {
           "lr": 5e-3,
           "weight_decay": 1e-6,
           "amsgrad": False
  }

scheduler:
  name: "CosineAnnealingLR"
  param: {
            "T_max": 6,
            "eta_min": 0,
            "last_epoch": -1
  }
wandb:
  use: false
  project: "kaggle-tpu"
  name: "1"
  tags: []
"""

In [None]:
def seed_torch(seed=42):
    """Seed every RNG this notebook touches so runs are repeatable.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and current
    CUDA device), exports ``PYTHONHASHSEED`` and asks cuDNN for deterministic
    kernels.

    Args:
        seed: integer seed shared by all generators.
    """
    # PYTHONHASHSEED only affects interpreters started after this point
    # (e.g. spawned worker processes), not the current one.
    os.environ['PYTHONHASHSEED'] = str(seed)
    for reseed in (random.seed, np.random.seed):
        reseed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

In [None]:
def get_transforms(img_size, data):
    """Build the albumentations pipeline for a dataset split.

    Both splits resize to ``img_size`` x ``img_size``, apply ImageNet mean/std
    normalization and convert to a tensor. The ``'train'`` split additionally
    applies a mild RandomResizedCrop (scale 0.85-1.0) and a horizontal flip.

    Args:
        img_size: output height and width in pixels.
        data: ``'train'`` or ``'valid'``.

    Returns:
        An albumentations ``Compose`` pipeline.

    Raises:
        ValueError: if ``data`` is not ``'train'`` or ``'valid'``. Previously an
            unknown value silently produced ``None`` (i.e. no transform at all).
    """
    if data not in ('train', 'valid'):
        raise ValueError(f"unknown transform split {data!r}; expected 'train' or 'valid'")

    pipeline = [Resize(img_size, img_size)]
    if data == 'train':
        pipeline += [
            RandomResizedCrop(img_size, img_size, scale=(0.85, 1.0)),
            HorizontalFlip(p=0.5),
        ]
    pipeline += [
        Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        ToTensorV2(),
    ]
    return Compose(pipeline)

In [None]:
# ====================================================
# Dataset
# ====================================================
class TrainDataset(Dataset):
    """Image dataset backed by a dataframe of ``image_id`` rows with optional multi-label targets.

    Each item is read from ``{cfg.base.train_path}/{image_id}.jpg`` with OpenCV,
    converted from BGR to RGB and, when ``transform`` is given, passed through it
    as ``transform(image=...)['image']``.

    Args:
        cfg: config exposing ``base.train_path`` and ``base.target_cols``.
        df: dataframe with an ``image_id`` column and, unless ``inference`` is
            True, one column per entry of ``cfg.base.target_cols``.
        transform: optional albumentations-style pipeline.
        inference: if True, ``__getitem__`` returns only the image, and the
            target columns may be absent from ``df`` (``labels`` is then None).
    """

    def __init__(self, cfg, df, transform=None, inference=False):
        self.df = df
        self.cfg = cfg
        self.file_names = df['image_id'].values
        target_cols = list(cfg.base.target_cols)
        # Inference frames (e.g. a test split) need not carry label columns;
        # only require them when labels can actually be returned.
        if inference and not set(target_cols).issubset(df.columns):
            self.labels = None
        else:
            self.labels = df[target_cols].values
        self.transform = transform
        self.inference = inference

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``image`` in inference mode, otherwise ``(image, float32 label vector)``."""
        file_path = f'{self.cfg.base.train_path}/{self.file_names[idx]}.jpg'
        image = cv2.imread(file_path)
        # cv2.imread reports a missing/unreadable file by returning None instead of raising.
        if image is None:
            raise FileNotFoundError(f'could not read image: {file_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            image = self.transform(image=image)['image']

        if self.inference:
            return image
        label = torch.tensor(self.labels[idx]).float()
        return image, label


In [None]:
class CHIZUDataModule(LightningDataModule):
    """LightningDataModule serving train/validation loaders built from dataframes.

    When ``cfg.base.tpu`` is set, each loader shards its dataset with a
    ``DistributedSampler`` over the XLA world size/ordinal and uses DataLoader's
    default worker count. Otherwise a shuffled (train) or ordered (valid)
    loader with ``num_workers`` workers and pinned memory is returned; only the
    training loader drops the last incomplete batch.

    Args:
        cfg: config; ``base.tpu`` is read here, the rest is consumed by
            ``TrainDataset``.
        train_df: dataframe for the training split.
        val_df: dataframe for the validation split.
        aug_p: stored only; not read by this class.
        val_pct: stored only; not read by this class.
        img_sz: image size handed to ``get_transforms``.
        batch_size: batch size of both loaders.
        num_workers: DataLoader workers (non-TPU path only).
        fold_id: stored only; not read by this class.
    """

    def __init__(
            self,
            cfg,
            train_df,
            val_df,
            aug_p: float = 0.5,
            val_pct: float = 0.2,
            img_sz: int = 224,
            batch_size: int = 64,
            num_workers: int = 4,
            fold_id: int = 0,
    ):
        super().__init__()
        self.cfg = cfg
        self.train_df, self.val_df = train_df, val_df
        self.aug_p, self.val_pct = aug_p, val_pct
        self.img_sz = img_sz
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.fold_id = fold_id

    def _make_loader(self, dataset, is_train):
        """Wrap ``dataset`` in a DataLoader suited to the configured device type."""
        if not self.cfg.base.tpu:
            return DataLoader(
                dataset,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                shuffle=is_train,
                pin_memory=True,
                drop_last=is_train,
            )
        # One shard per TPU core; shuffling is delegated to the sampler.
        shard_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=is_train,
        )
        return DataLoader(dataset, batch_size=self.batch_size, sampler=shard_sampler)

    def train_dataloader(self):
        dataset = TrainDataset(self.cfg, self.train_df, transform=get_transforms(self.img_sz, data="train"))
        return self._make_loader(dataset, is_train=True)

    def val_dataloader(self):
        dataset = TrainDataset(self.cfg, self.val_df, transform=get_transforms(self.img_sz, data="valid"))
        return self._make_loader(dataset, is_train=False)

In [None]:
# Name -> class registries for components not available from the standard
# libraries. get_criterion and get_split consult their registry only after
# torch.nn / sklearn.model_selection, whereas get_optimizer consults
# __OPTIMIZERS__ before torch.optim. All three are empty here; the commented
# entries show the expected "Name": class format.
__CRITERIONS__ = {
    # "BCEFocalLoss": BCEFocalLoss
}


__SPLITS__ = {
    # "MultilabelStratifiedKFold": MultilabelStratifiedKFold
}

__OPTIMIZERS__ = {
    # "AdaBelief": AdaBelief,
    # "RAdam": torch_optimizer.RAdam
}

def get_criterion(cfg):
    """Instantiate the loss named ``cfg.loss.name`` with keyword args ``cfg.loss.param``.

    ``torch.nn`` is searched first, then the local ``__CRITERIONS__`` registry.

    Raises:
        NotImplementedError: if the name is found in neither place.
    """
    loss_name = cfg.loss.name
    if hasattr(nn, loss_name):
        loss_cls = getattr(nn, loss_name)
    else:
        loss_cls = __CRITERIONS__.get(loss_name)
        if loss_cls is None:
            raise NotImplementedError
    return loss_cls(**cfg.loss.param)

        
def get_optimizer(cfg, model):
    """Build the optimizer named ``cfg.optimizer.name`` over ``model.parameters()``.

    The local ``__OPTIMIZERS__`` registry takes precedence over ``torch.optim``.
    An unknown name surfaces as the ``AttributeError`` from the ``torch.optim`` lookup.

    Args:
        cfg: config with ``optimizer.name`` and ``optimizer.param`` (keyword args).
        model: module whose parameters are optimized.
    """
    opt_name = cfg.optimizer.name
    opt_cls = __OPTIMIZERS__.get(opt_name)
    if opt_cls is None:
        opt_cls = getattr(optim, opt_name)
    return opt_cls(model.parameters(), **cfg.optimizer.param)


def get_scheduler(cfg, optimizer):
    """Build the ``torch.optim.lr_scheduler`` class named ``cfg.scheduler.name``.

    Args:
        cfg: config with ``scheduler.name`` and ``scheduler.param`` (keyword args).
        optimizer: optimizer the scheduler drives.

    Returns:
        The scheduler instance, or None when ``cfg.scheduler.name`` is None.
    """
    scheduler_name = cfg.scheduler.name
    if scheduler_name is not None:
        scheduler_cls = getattr(optim.lr_scheduler, scheduler_name)
        return scheduler_cls(optimizer, **cfg.scheduler.param)
    return None


def get_split(cfg):
    """Instantiate the CV splitter named ``cfg.split.name`` with keyword args ``cfg.split.param``.

    ``sklearn.model_selection`` is searched first, then the local ``__SPLITS__`` registry.

    Raises:
        NotImplementedError: if the name is found in neither place.
    """
    split_name = cfg.split.name
    if hasattr(model_selection, split_name):
        splitter_cls = getattr(model_selection, split_name)
    else:
        splitter_cls = __SPLITS__.get(split_name)
        if splitter_cls is None:
            raise NotImplementedError
    return splitter_cls(**cfg.split.param)

In [None]:
def sigmoid(x):
    """Element-wise logistic function ``1 / (1 + exp(-x))`` for NumPy arrays or scalars."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator


In [None]:
class CHIZUModel(LightningModule):
    """Lightning wrapper around a timm backbone for multi-label classification.

    The backbone's classification head is replaced by a ``Linear`` layer with
    ``cfg.base.target_size`` outputs. ``forward`` returns raw logits (no
    sigmoid), so the configured loss is expected to take logits
    (``BCEWithLogitsLoss`` in this notebook's YAML).

    Args:
        cfg: OmegaConf config; reads ``model.pretrained``, ``base.target_size``
            and the ``optimizer`` / ``scheduler`` / ``loss`` sections.
        model_name: timm model identifier.
    """

    def __init__(self, cfg, model_name="resnext50_32x4d"):
        super().__init__()

        self.cfg = cfg
        self.wd = 1e-6  # NOTE(review): not read here; weight decay comes from cfg.optimizer.param.
        self.model_name = model_name
        self.model = timm.create_model(model_name, pretrained=cfg.model.pretrained)

        if "efficient" in self.model_name:
            # timm EfficientNets expose their head as `classifier`.
            self.model.classifier = nn.Linear(self.model.num_features, cfg.base.target_size)
        else:
            # Assumes a ResNet-style `fc` head (as in the default resnext50_32x4d).
            n_features = self.model.fc.in_features
            self.model.fc = nn.Linear(n_features, cfg.base.target_size)

        self.optimizer = get_optimizer(cfg, self.model)
        self.scheduler = get_scheduler(cfg, self.optimizer)
        self.criterion = get_criterion(cfg)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return raw logits of shape (batch, target_size)."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        return self.criterion(y_hat, y)

    def validation_step(self, batch, batch_idx):
        """Return ``(loss, logits, targets)``; logits/targets are host float32 NumPy arrays."""
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        # Cast to float32 before `.numpy()`, which cannot convert bfloat16/half-precision
        # outputs produced under mixed precision.
        return loss, y_hat.detach().float().cpu().numpy(), y.detach().float().cpu().numpy()

    def validation_epoch_end(self, input_):
        """Log per-target ROC AUC (``"1-auc"``, ``"2-auc"``, ...), their mean and the mean loss.

        ``input_`` is the list of ``(loss, logits, targets)`` tuples from
        ``validation_step``. A target whose AUC is undefined (e.g. only one class
        present) contributes 0 to the mean.

        NOTE(review): with multi-device training (tpu_cores=8 plus a
        DistributedSampler) each process only sees its own shard, and these values
        are logged without sync_dist, so they are presumably per-shard metrics.
        """
        if len(input_) == 0:
            return

        losses = np.array([float(loss.cpu()) for loss, _, _ in input_])
        logits = np.concatenate([y_hat for _, y_hat, _ in input_], axis=0)
        targets = np.concatenate([y for _, _, y in input_], axis=0)

        n_targets = logits.shape[1]
        mean_auc = 0.0
        for j in range(n_targets):
            # ROC AUC depends only on score ranking, so raw logits are used directly;
            # a float sigmoid would saturate large logits to 1.0 and create artificial ties.
            try:
                auc = roc_auc_score(targets[:, j], logits[:, j])
            except ValueError:
                auc = 0
            mean_auc += auc / n_targets
            self.log(f"{j + 1}-auc", auc, prog_bar=True)

        self.log("valid_loss", losses.mean(), prog_bar=True)
        self.log("valid auc", mean_auc, prog_bar=True)

    def configure_optimizers(self):
        """Return the optimizer, plus the scheduler when one is configured."""
        # Returning [None] as the scheduler list is rejected by Lightning, so a
        # missing scheduler (cfg.scheduler.name = null) returns just the optimizer.
        if self.scheduler is None:
            return self.optimizer
        return [self.optimizer], [self.scheduler]

In [None]:
def train_loop(cfg, folds, fold):
    """Train a CHIZUModel on one cross-validation fold.

    Rows whose ``fold`` column equals ``fold`` form the validation split; all
    other rows are used for training.

    Args:
        cfg: OmegaConf config for the run.
        folds: dataframe with a ``fold`` column (assigned by ``main``).
        fold: id of the validation fold.

    Returns:
        The fitted ``pl.Trainer``.
    """
    train_folds = folds[folds['fold'] != fold].reset_index(drop=True)
    valid_folds = folds[folds['fold'] == fold].reset_index(drop=True)

    data_module = CHIZUDataModule(
        cfg,
        train_folds,
        valid_folds,
        aug_p=0.5,
        img_sz=cfg.model.size,
        batch_size=cfg.model.batch_size,
        num_workers=cfg.base.num_workers,
    )
    model = CHIZUModel(
        cfg,
        model_name=cfg.model.model_name,
    )

    # NOTE(review): no `monitor` is set, so `mode` has no effect and the callback
    # keeps the latest checkpoint — confirm intended.
    checkpoint_callback = ModelCheckpoint(
        dirpath=f'../exp2/{rand}',
        filename=f"fold-{fold}",
        mode='min',
    )

    logger = True  # Lightning's default logger, unless W&B is enabled below.
    if cfg.wandb.use:
        # Imported here: the notebook's import cell imports neither name, and they
        # are only needed when W&B logging is enabled.
        import wandb
        from pytorch_lightning.loggers import WandbLogger

        run_name = f"{cfg.wandb.name}-fold-{fold}-{rand}"
        run_tags = list(cfg.wandb.tags) + [str(rand)]
        wandb.init(
            name=run_name,
            project=cfg.wandb.project,
            tags=run_tags,
            reinit=True
        )
        logger = WandbLogger(
            name=run_name,
            project=cfg.wandb.project,
            tags=run_tags
        )
        logger.log_hyperparams(OmegaConf.to_container(cfg, resolve=True))
        logger.log_hyperparams({"rand": rand, "fold": fold})

    if cfg.base.tpu:
        # NOTE(review): checkpoint_callback is only attached on the GPU path; on TPU
        # Lightning's default checkpointing applies — confirm intended.
        trainer = pl.Trainer(
            tpu_cores=8,
            max_epochs=cfg.model.epochs,
            precision=16,
            logger=logger,
        )
    else:
        trainer = pl.Trainer(
            gpus=-1,
            max_epochs=cfg.model.epochs,
            gradient_clip_val=0.1,
            precision=16,
            logger=logger,
            callbacks=[checkpoint_callback]
        )

    trainer.fit(model=model, datamodule=data_module)
    return trainer


In [None]:
def main(cfg):
    """Assign CV folds to the training data and train each fold listed in ``cfg.base.trn_fold``.

    In debug mode (``cfg.base.debug``) at most 1000 rows are sampled and training
    runs for one epoch. The debug override is applied to a deep copy, so the
    caller's ``cfg`` is left unchanged.

    Args:
        cfg: OmegaConf config for the run.

    Raises:
        ValueError: if the splitter leaves some rows without a validation fold.
    """
    import copy

    cfg = copy.deepcopy(cfg)
    seed_torch(seed=cfg.base.seed)

    folds = train.copy()
    if cfg.base.debug:
        n_debug = min(1000, len(folds))  # `sample` raises if n exceeds the row count
        folds = folds.sample(n=n_debug, random_state=cfg.base.seed).reset_index(drop=True)
        cfg.model.epochs = 1

    target_cols = list(cfg.base.target_cols)
    # Positional fold ids; every row must be assigned to exactly one validation split.
    fold_ids = np.full(len(folds), -1, dtype=int)
    splitter = get_split(cfg)
    for n, (_, val_index) in enumerate(splitter.split(folds, folds[target_cols])):
        fold_ids[val_index] = n
    if (fold_ids < 0).any():
        raise ValueError("the configured splitter left some rows without a validation fold")
    folds['fold'] = fold_ids

    for fold in range(cfg.base.n_fold):
        if fold in cfg.base.trn_fold:
            train_loop(cfg, folds, fold)


In [None]:
# Parse the YAML configuration defined above and launch training.
run_cfg = OmegaConf.create(conf)
main(run_cfg)