In [None]:
# Install PyTorch/XLA for the selected release using the upstream env-setup script.
# VERSION is exposed as a notebook form field via the `#@param` annotation; the
# apt packages libomp5 and libopenblas-dev are forwarded to the script.
VERSION = "nightly"  #@param ["1.5" , "20200325", "nightly"]
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION --apt-packages libomp5 libopenblas-dev

In [None]:
# Third-party model zoos. efficientnet_pytorch provides the EfficientNet class used
# for training; pretrainedmodels is imported below as `pmodels` but is not used by
# the active model code.
!pip install pretrainedmodels
!pip install efficientnet_pytorch

In [None]:

import json
import sys
import gc
import os
import random
import time
import cv2

from contextlib import contextmanager
from pathlib import Path
from collections import defaultdict, Counter
from PIL import Image
import numpy as np
import pandas as pd
import scipy as sp

import sklearn.metrics
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import StratifiedKFold

from functools import partial
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils import model_zoo
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
import torchvision.models as models

from albumentations import (
    Compose, OneOf, Normalize, Resize, RandomResizedCrop, RandomCrop, HorizontalFlip, VerticalFlip, 
    RandomBrightness, RandomContrast, RandomBrightnessContrast, Rotate, ShiftScaleRotate, Cutout, IAAAdditiveGaussianNoise
)
from albumentations.pytorch import ToTensorV2
from albumentations import ImageOnlyTransform

import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.data_parallel as dp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils

from collections import OrderedDict
import math

import pretrainedmodels as pmodels
from efficientnet_pytorch import EfficientNet


def seed_torch(seed=777):
    """Seed the random number generators used by this script.

    Args:
        seed: value fed to Python's ``random``, NumPy, torch (CPU) and
            torch CUDA seeding, and exported as ``PYTHONHASHSEED``.
    """
    # Setting PYTHONHASHSEED at runtime only influences subprocesses started
    # afterwards; the current interpreter's hash seed is already fixed.
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True

class FlowerDataset(Dataset):
    """Image-classification dataset backed by the 104-flowers JPEG folders.

    Images are read from
    ``../input/104-flowers-garden-of-eden/jpeg-{IMG_FOLDER}/{id}``, where
    ``IMG_FOLDER`` is a module-level global read on every ``__getitem__`` call.

    Args:
        df: DataFrame with an ``id`` column of paths relative to the jpeg folder.
        labels: labels aligned *positionally* with ``df`` (accessed through
            ``.values``, so their index is ignored).
        transform: optional callable invoked as ``transform(image=...)`` that
            returns a mapping with an ``'image'`` entry (albumentations style).
    """

    def __init__(self, df, labels, transform=None):
        self.df = df
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        file_name = self.df['id'].values[idx]
        file_path = f'../input/104-flowers-garden-of-eden/jpeg-{IMG_FOLDER}/{file_name}'
        image = cv2.imread(file_path)
        # cv2.imread reports a missing/unreadable file by returning None instead of
        # raising; fail here with the path rather than with an opaque cvtColor error.
        if image is None:
            raise FileNotFoundError(f'could not read image: {file_path}')
        # OpenCV decodes to BGR; convert to the RGB order the normalisation expects.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            image = self.transform(image=image)['image']

        label = self.labels.values[idx]

        return image, label

    
def get_transforms(*, data):
    """Build the albumentations pipeline for one data split.

    Output size comes from the module-level ``HEIGHT`` / ``WIDTH`` globals.

    Args:
        data: ``'train'`` for the augmented training pipeline, or ``'valid'``
            for the deterministic resize-only pipeline.

    Returns:
        A ``Compose`` pipeline ending in ``Normalize`` followed by ``ToTensorV2``.

    Raises:
        ValueError: if ``data`` is neither ``'train'`` nor ``'valid'``.
    """
    if data not in ('train', 'valid'):
        # Previously an unknown split silently produced None, i.e. "no transform".
        raise ValueError(f"data must be 'train' or 'valid', got {data!r}")

    # Standard ImageNet mean/std, matching the ImageNet-pretrained backbone.
    # Shared by both splits so train and valid inputs are scaled identically.
    normalize = Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    if data == 'train':
        return Compose([
            RandomResizedCrop(HEIGHT, WIDTH),
            HorizontalFlip(p=0.5),
            VerticalFlip(p=0.5),
            ShiftScaleRotate(rotate_limit=30, p=0.5),
            Cutout(p=0.5, max_h_size=12, max_w_size=12, num_holes=6),
            normalize,
            ToTensorV2(),
        ])

    return Compose([
        Resize(HEIGHT, WIDTH),
        normalize,
        ToTensorV2(),
    ])
    

# Tell torch_xla which TPU worker to connect to.
os.environ['XRT_TPU_CONFIG'] = "tpu_worker;0;10.0.0.2:8470"

# Bare expression mid-cell: nothing is displayed, but it raises immediately
# if the dataset directory is not mounted.
os.listdir('../input/104-flowers-garden-of-eden/')

SEED = 777
seed_torch(SEED)

N_CLASSES = 104
HEIGHT = WIDTH = 256  # training/validation image size in pixels

IMG_FOLDER = '512x512'  # selects the jpeg-<IMG_FOLDER> resolution directory
train, valid = pd.DataFrame(), pd.DataFrame()

def _list_split(split):
    """Return one row per image found under ``jpeg-{IMG_FOLDER}/{split}/<class>/``.

    Columns are ``id`` (path relative to the jpeg folder, e.g.
    ``train/<class>/<file>``) and ``class`` (the class sub-directory name).
    The returned frame has a fresh RangeIndex.

    Listings are sorted so that row order does not depend on filesystem
    enumeration order. This keeps the seeded sampling and shuffling
    downstream reproducible across machines.
    """
    split_dir = f'../input/104-flowers-garden-of-eden/jpeg-{IMG_FOLDER}/{split}/'
    ids, classes = [], []
    for class_name in sorted(os.listdir(split_dir)):
        for file_name in sorted(os.listdir(f'{split_dir}{class_name}/')):
            ids.append(f'{split}/{class_name}/{file_name}')
            classes.append(class_name)
    # Build the frame once instead of growing it with pd.concat per class,
    # which copies every accumulated row on each iteration.
    return pd.DataFrame({'id': ids, 'class': classes})


train = _list_split('train')
valid = _list_split('val')

# Class names indexed by label id. The trailing spaces in 'cyclamen ' and
# 'hippeastrum ' are kept as-is; presumably they must match the dataset's folder
# names exactly. Verify before stripping them.
CLASSES = ['pink primrose',    'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea',     'wild geranium',     'tiger lily',           'moon orchid',              'bird of paradise', 'monkshood',        'globe thistle',         # 00 - 09
           'snapdragon',       "colt's foot",               'king protea',      'spear thistle', 'yellow iris',       'globe-flower',         'purple coneflower',        'peruvian lily',    'balloon flower',   'giant white arum lily', # 10 - 19
           'fire lily',        'pincushion flower',         'fritillary',       'red ginger',    'grape hyacinth',    'corn poppy',           'prince of wales feathers', 'stemless gentian', 'artichoke',        'sweet william',         # 20 - 29
           'carnation',        'garden phlox',              'love in the mist', 'cosmos',        'alpine sea holly',  'ruby-lipped cattleya', 'cape flower',              'great masterwort', 'siam tulip',       'lenten rose',           # 30 - 39
           'barberton daisy',  'daffodil',                  'sword lily',       'poinsettia',    'bolero deep blue',  'wallflower',           'marigold',                 'buttercup',        'daisy',            'common dandelion',      # 40 - 49
           'petunia',          'wild pansy',                'primula',          'sunflower',     'lilac hibiscus',    'bishop of llandaff',   'gaura',                    'geranium',         'orange dahlia',    'pink-yellow dahlia',    # 50 - 59
           'cautleya spicata', 'japanese anemone',          'black-eyed susan', 'silverbush',    'californian poppy', 'osteospermum',         'spring crocus',            'iris',             'windflower',       'tree poppy',            # 60 - 69
           'gazania',          'azalea',                    'water lily',       'rose',          'thorn apple',       'morning glory',        'passion flower',           'lotus',            'toad lily',        'anthurium',             # 70 - 79
           'frangipani',       'clematis',                  'hibiscus',         'columbine',     'desert-rose',       'tree mallow',          'magnolia',                 'cyclamen ',        'watercress',       'canna lily',            # 80 - 89
           'hippeastrum ',     'bee balm',                  'pink quill',       'foxglove',      'bougainvillea',     'camellia',             'mallow',                   'mexican petunia',  'bromelia',         'blanket flower',        # 90 - 99
           'trumpet creeper',  'blackberry lily',           'common tulip',     'wild rose']                                                                                                                                               # 100 - 103

if len(CLASSES) != N_CLASSES:
    raise ValueError(f'CLASSES has {len(CLASSES)} entries, expected N_CLASSES={N_CLASSES}')

class_map = {name: idx for idx, name in enumerate(CLASSES)}


def _encode_classes(df, split_name):
    """Replace the class-name strings in ``df['class']`` with integer ids, in place.

    Raises:
        ValueError: if a class folder name has no entry in ``CLASSES``.
            ``Series.map`` would otherwise silently turn it into a NaN label.
    """
    unknown = sorted(set(df['class']) - set(class_map))
    if unknown:
        raise ValueError(f'{split_name}: class folders not in CLASSES: {unknown}')
    df['class'] = df['class'].map(class_map)


_encode_classes(train, 'train')
_encode_classes(valid, 'valid')

DEBUG = False

if not DEBUG:
    trn_folds = train.copy()
    val_folds = valid.copy()
else:
    # Quick-iteration mode: a fixed-seed 1000-row sample of each split
    # (train is sampled first, then valid).
    trn_folds, val_folds = (
        split_df.sample(n=1000, random_state=42).reset_index(drop=True).copy()
        for split_df in (train, valid)
    )

trn_folds.head()
val_folds.head()

def _run():
    """Train and validate the flower classifier on this process's XLA device.

    Intended to run in every process started by ``xmp.spawn``. Uses the
    module-level ``trn_folds`` / ``val_folds`` DataFrames, ``N_CLASSES``,
    ``FlowerDataset`` and ``get_transforms``. After each epoch the model is
    checkpointed to ``resnet_fold.pth`` whenever the validation macro-F1
    beats the best score seen so far.
    """

    def train_loop_fn(para_train_loader, model, optimizer, criterion, device):
        """Run one training epoch; return the loss averaged over ``len(train_loader)`` batches."""
        model.train()
        avg_loss = 0.
        n_batches = len(train_loader)

        optimizer.zero_grad()

        for i, (images, labels) in enumerate(para_train_loader.per_device_loader(device)):
            images = images.to(device)
            labels = labels.to(device)

            y_preds = model(images)
            loss = criterion(y_preds, labels)

            if i % 40 == 0:
                xm.master_print(f'[train] i={i}, loss={loss}')

            loss.backward()
            # Syncs gradients across replicas, then steps the optimizer.
            xm.optimizer_step(optimizer)
            optimizer.zero_grad()

            avg_loss += loss.item() / n_batches

        return avg_loss

    def eval_loop_fn(para_valid_loader, model, criterion, device, scheduler):
        """Evaluate one pass over this replica's validation shard.

        Steps ``scheduler`` on the mean validation loss and returns
        ``(avg_val_loss, macro_f1)``. Both metrics cover only the shard this
        process sees; they are not reduced across replicas.
        """
        model.eval()
        avg_val_loss = 0.
        n_batches = len(valid_loader)
        preds = []
        valid_labels = []

        for images, labels in para_valid_loader.per_device_loader(device):
            images = images.to(device)
            labels = labels.to(device)

            with torch.no_grad():
                y_preds = model(images)

            preds.append(y_preds.argmax(1).to('cpu').numpy())
            valid_labels.append(labels.to('cpu').numpy())

            loss = criterion(y_preds, labels)
            avg_val_loss += loss.item() / n_batches

        scheduler.step(avg_val_loss)

        preds = np.concatenate(preds)
        valid_labels = np.concatenate(valid_labels)

        score = f1_score(valid_labels, preds, average='macro')

        return avg_val_loss, score

    device = xm.xla_device()
    world_size = xm.xrt_world_size()
    rank = xm.get_ordinal()

    # A global batch of 128 is split across replicas; the LR is scaled linearly
    # with the replica count.
    batch_size = 128 // world_size
    n_epochs = 20
    lr = 1e-4 * world_size

    xm.master_print(f'device: {device}')
    xm.master_print(f'world_size: {world_size}')
    xm.master_print(f'batch_size: {batch_size}')
    xm.master_print(f'n_epochs: {n_epochs}')
    xm.master_print(f'lr: {lr}')

    train_dataset = FlowerDataset(trn_folds.reset_index(drop=True),
                                  trn_folds['class'],
                                  transform=get_transforms(data='train'))

    valid_dataset = FlowerDataset(val_folds.reset_index(drop=True),
                                  val_folds['class'],
                                  transform=get_transforms(data='valid'))

    # Each replica iterates over its own shard of each split.
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        valid_dataset, num_replicas=world_size, rank=rank, shuffle=False)

    # DataLoader worker processes are only used for single-process runs.
    num_workers = 4 if world_size == 1 else 0

    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler,
                              drop_last=True, num_workers=num_workers)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, sampler=valid_sampler,
                              drop_last=False, num_workers=num_workers)

    xm.master_print(f"Train for {len(train_loader)} steps per epoch")

    # Size the classifier head for the flower classes. By default from_pretrained
    # keeps the 1000-way ImageNet head, which wastes 896 logits here.
    model = EfficientNet.from_pretrained('efficientnet-b0', num_classes=N_CLASSES).to(device)

    optimizer = Adam(model.parameters(), lr=lr, amsgrad=False)
    scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.75, patience=1, verbose=True, eps=1e-6)

    criterion = nn.CrossEntropyLoss()
    best_score = 0.

    for epoch in range(n_epochs):

        start_time = time.time()

        # Without this call, DistributedSampler reuses the same permutation every epoch.
        train_sampler.set_epoch(epoch)

        para_train_loader = pl.ParallelLoader(train_loader, [device])
        avg_loss = train_loop_fn(para_train_loader, model, optimizer, criterion, device)

        para_valid_loader = pl.ParallelLoader(valid_loader, [device])
        avg_val_loss, score = eval_loop_fn(para_valid_loader, model, criterion, device, scheduler)

        elapsed = time.time() - start_time

        xm.master_print(f'  Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s')
        xm.master_print(f'  Epoch {epoch+1} - f1_score: {score}')

        if score > best_score:
            best_score = score
            xm.master_print(f'  Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
            # xm.save moves XLA tensors to CPU before writing, so the file loads
            # without XLA, and only the master ordinal writes. Plain torch.save
            # on XLA tensors does neither.
            # NOTE(review): the file name says "resnet" although the model is
            # EfficientNet-B0; kept unchanged for whatever loads this checkpoint.
            xm.save({'model': model.state_dict(),
                     'score': best_score},
                    'resnet_fold.pth')

def _mp_fn(rank, flags):
    """Per-process entry point passed to ``xmp.spawn``.

    Args:
        rank: process index supplied by ``xmp.spawn``. It is unused here because
            ``_run`` queries the ordinal itself.
        flags: the object forwarded through ``xmp.spawn(args=...)`` (unused).
    """
    torch.set_default_tensor_type('torch.FloatTensor')
    # _run returns nothing; the former `a = _run()` binding was dead.
    _run()

# Launch training in a single process (nprocs=1). The 'fork' start method lets the
# child inherit the DataFrames and helpers already built in this interpreter.
FLAGS = {}
xmp.spawn(
    _mp_fn,
    args=(FLAGS,),
    nprocs=1,
    start_method='fork',
)