# TPU setup

In [None]:
# Install PyTorch/XLA into this notebook environment with the upstream env-setup script.
# VERSION picks the torch_xla build; the `#@param` annotation is a Colab form field.
# NOTE(review): the script is fetched from the xla `master` branch (unpinned) — consider
# pinning a commit for reproducibility.
VERSION = "nightly"  #@param ["1.5" , "20200325", "nightly"]
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION --apt-packages libomp5 libopenblas-dev

In [None]:
# https://www.kaggle.com/c/flower-classification-with-tpus/discussion/129820
import os
# Point XRT at the TPU worker (see the discussion linked above).
# NOTE(review): the address is hardcoded; presumably it matches the Kaggle TPU setup at the
# time this was written — confirm before reusing elsewhere.
os.environ['XRT_TPU_CONFIG'] = "tpu_worker;0;10.0.0.2:8470"

# Library

In [None]:
import os
import numpy as np 
import pandas as pd 
import json

In [None]:
# List the dataset root; presumably shows the available `jpeg-<size>` image folders
# (IMG_FOLDER below picks one of them).
os.listdir('../input/104-flowers-garden-of-eden/')

In [None]:
IMG_FOLDER = '512x512'
DATA_ROOT = f'../input/104-flowers-garden-of-eden/jpeg-{IMG_FOLDER}'

# One row per image in train/ and val/ (both are pooled and re-split by K-fold below):
#   id    -> path relative to DATA_ROOT, e.g. 'train/<class>/<file>'
#   class -> the class folder name (mapped to an integer label later)
# Rows are collected in a list and converted once; calling pd.concat inside the loop
# re-copies the growing frame for every class. os.listdir returns entries in arbitrary
# order, so both levels are sorted to make the row order (and hence the folds)
# reproducible across machines.
records = []
for split in ('train', 'val'):
    for _class in sorted(os.listdir(f'{DATA_ROOT}/{split}/')):
        for file_name in sorted(os.listdir(f'{DATA_ROOT}/{split}/{_class}/')):
            records.append({'id': f'{split}/{_class}/{file_name}', 'class': _class})

train = pd.DataFrame(records, columns=['id', 'class'])

In [None]:
# From https://www.kaggle.com/mgornergoogle/getting-started-with-100-flowers-on-tpu
# Index i of CLASSES is the integer label used for class i (104 classes: 0-103).
# NOTE(review): 'cyclamen ' and 'hippeastrum ' keep their trailing spaces; presumably they
# must match the class folder names exactly (see class_map below) — do not strip them
# without checking the dataset.
CLASSES = ['pink primrose',    'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea',     'wild geranium',     'tiger lily',           'moon orchid',              'bird of paradise', 'monkshood',        'globe thistle',         # 00 - 09
           'snapdragon',       "colt's foot",               'king protea',      'spear thistle', 'yellow iris',       'globe-flower',         'purple coneflower',        'peruvian lily',    'balloon flower',   'giant white arum lily', # 10 - 19
           'fire lily',        'pincushion flower',         'fritillary',       'red ginger',    'grape hyacinth',    'corn poppy',           'prince of wales feathers', 'stemless gentian', 'artichoke',        'sweet william',         # 20 - 29
           'carnation',        'garden phlox',              'love in the mist', 'cosmos',        'alpine sea holly',  'ruby-lipped cattleya', 'cape flower',              'great masterwort', 'siam tulip',       'lenten rose',           # 30 - 39
           'barberton daisy',  'daffodil',                  'sword lily',       'poinsettia',    'bolero deep blue',  'wallflower',           'marigold',                 'buttercup',        'daisy',            'common dandelion',      # 40 - 49
           'petunia',          'wild pansy',                'primula',          'sunflower',     'lilac hibiscus',    'bishop of llandaff',   'gaura',                    'geranium',         'orange dahlia',    'pink-yellow dahlia',    # 50 - 59
           'cautleya spicata', 'japanese anemone',          'black-eyed susan', 'silverbush',    'californian poppy', 'osteospermum',         'spring crocus',            'iris',             'windflower',       'tree poppy',            # 60 - 69
           'gazania',          'azalea',                    'water lily',       'rose',          'thorn apple',       'morning glory',        'passion flower',           'lotus',            'toad lily',        'anthurium',             # 70 - 79
           'frangipani',       'clematis',                  'hibiscus',         'columbine',     'desert-rose',       'tree mallow',          'magnolia',                 'cyclamen ',        'watercress',       'canna lily',            # 80 - 89
           'hippeastrum ',     'bee balm',                  'pink quill',       'foxglove',      'bougainvillea',     'camellia',             'mallow',                   'mexican petunia',  'bromelia',         'blanket flower',        # 90 - 99
           'trumpet creeper',  'blackberry lily',           'common tulip',     'wild rose']                                                                                                                                               # 100 - 103

# The name -> label mapping built from this list needs exactly one label per unique name.
assert len(CLASSES) == 104 and len(set(CLASSES)) == len(CLASSES), 'CLASSES must list 104 unique names'

In [None]:
# Map each class folder name to its integer label (its index in CLASSES).
class_map = {name: label for label, name in enumerate(CLASSES)}

# `Series.map` silently turns names missing from `class_map` into NaN, which would only
# surface later as a cryptic loss error. Check first, without mutating `train`. This also
# catches re-running this cell: the column then already holds integers, which are not keys.
unknown_classes = sorted(set(train['class']).difference(class_map), key=str)
if unknown_classes:
    raise ValueError(
        f'{len(unknown_classes)} class value(s) not found in CLASSES '
        f'(was this cell already run?): {unknown_classes[:5]}')

train['class'] = train['class'].map(class_map)

In [None]:
# Peek at the frame: 'id' is a path relative to jpeg-{IMG_FOLDER}/, 'class' the integer label.
train.head()

# Data loading: submission template

In [None]:
# Test ids to predict. They are taken from a public notebook's submission file, with its
# 'label' column dropped so that _run() can fill in our own predictions.
# (Alternative source: '../input/flower-classification-with-tpus/sample_submission.csv'.)
submission_path = '../input/getting-started-with-100-flowers-on-tpu/submission.csv'
submission = pd.read_csv(submission_path).drop('label', axis=1)
submission.head()

# Library

In [None]:
# ====================================================
# Library
# ====================================================

import sys

import gc
import os
import random
import time
from contextlib import contextmanager
from pathlib import Path
from collections import defaultdict, Counter

import cv2
from PIL import Image
import numpy as np
import pandas as pd
import scipy as sp

import sklearn.metrics
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold

from functools import partial
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
import torchvision.models as models

from albumentations import (
    Compose, OneOf, Normalize, Resize, RandomResizedCrop, RandomCrop, HorizontalFlip, VerticalFlip, 
    RandomBrightness, RandomContrast, RandomBrightnessContrast, Rotate, ShiftScaleRotate, Cutout, IAAAdditiveGaussianNoise
)
from albumentations.pytorch import ToTensorV2
from albumentations import ImageOnlyTransform


#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device

In [None]:
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.data_parallel as dp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils

# Utils

In [None]:
# ====================================================
# Utils
# ====================================================

@contextmanager
def timer(name):
    """Context manager that logs '[name] start' on entry and the elapsed time on exit.

    Messages go through the module-level LOGGER at INFO level. The elapsed wall-clock
    time is rounded to whole seconds. If the wrapped block raises, the 'done' line is
    not logged.
    """
    started_at = time.time()
    LOGGER.info(f'[{name}] start')
    yield
    elapsed = time.time() - started_at
    LOGGER.info(f'[{name}] done in {elapsed:.0f} s.')

    
def init_logger(log_file='train.log'):
    """Return the 'Flower' logger, emitting DEBUG+ records to stderr and to ``log_file``.

    Safe to call repeatedly (e.g. when the notebook cell is re-run). Handlers attached by
    a previous call are closed and removed first, so each message is emitted once rather
    than once per call.

    Args:
        log_file: path of the log file (FileHandler's default append mode).

    Returns:
        The configured ``logging.Logger`` named 'Flower'.
    """
    from logging import getLogger, DEBUG, FileHandler, Formatter, StreamHandler

    log_format = '%(asctime)s %(levelname)s %(message)s'

    logger = getLogger('Flower')
    logger.setLevel(DEBUG)
    # getLogger returns the same object on every call; drop stale handlers so they don't stack.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    stream_handler = StreamHandler()
    stream_handler.setLevel(DEBUG)
    stream_handler.setFormatter(Formatter(log_format))

    file_handler = FileHandler(log_file)
    file_handler.setFormatter(Formatter(log_format))

    logger.addHandler(stream_handler)
    logger.addHandler(file_handler)

    return logger

LOG_FILE = 'train.log'
# Shared logger used by timer() and by the single-process branch of the training loop.
LOGGER = init_logger(log_file=LOG_FILE)


def seed_torch(seed=777):
    """Seed Python's `random`, NumPy and PyTorch (CPU + current CUDA device) RNGs.

    Also sets ``torch.backends.cudnn.deterministic`` and exports PYTHONHASHSEED. The
    latter only affects Python subprocesses started afterwards, not the running
    interpreter's hashing.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True

SEED = 777  # global seed for random / numpy / torch
seed_torch(seed=SEED)

# Dataset

In [None]:
N_CLASSES = 104


def _read_rgb_image(file_path):
    """Read ``file_path`` with OpenCV and return it converted from BGR to RGB order.

    Raises:
        FileNotFoundError: if OpenCV cannot read/decode the file. ``cv2.imread`` signals
            failure by returning ``None``; without this check the failure would surface
            as an opaque error from ``cv2.cvtColor`` with no file name.
    """
    image = cv2.imread(file_path)
    if image is None:
        raise FileNotFoundError(f'could not read image: {file_path}')
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


class TrainDataset(Dataset):
    """Labelled flower images for training / validation.

    Args:
        df: frame whose ``id`` column holds paths relative to ``jpeg-{IMG_FOLDER}/``
            (e.g. ``train/<class>/<file>``).
        labels: per-row labels, read positionally via ``labels.values[idx]`` (so the
            index need not match ``df``'s); presumably a pandas Series.
        transform: optional albumentations pipeline, called as ``transform(image=...)``.

    Each item is ``(image, label)``; ``image`` is the RGB array, or the pipeline's output.
    """

    def __init__(self, df, labels, transform=None):
        self.df = df
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        file_name = self.df['id'].values[idx]
        image = _read_rgb_image(f'../input/104-flowers-garden-of-eden/jpeg-{IMG_FOLDER}/{file_name}')

        if self.transform:
            image = self.transform(image=image)['image']

        label = self.labels.values[idx]

        return image, label


class TestDataset(Dataset):
    """Unlabelled test images; ``df['id']`` holds bare ids resolved to ``test/<id>.jpeg``.

    Args:
        df: frame with an ``id`` column (e.g. the submission template).
        transform: optional albumentations pipeline, called as ``transform(image=...)``.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        file_name = self.df['id'].values[idx]
        image = _read_rgb_image(f'../input/104-flowers-garden-of-eden/jpeg-{IMG_FOLDER}/test/{file_name}.jpeg')

        if self.transform:
            image = self.transform(image=image)['image']

        return image

# Transforms

In [None]:
HEIGHT = 256
WIDTH = 256


def get_transforms(*, data):
    """Return the albumentations pipeline for one stage.

    Args:
        data: keyword-only; one of 'train' (random crop, flips, shift/scale/rotate,
            cutout), 'valid' (plain resize) or 'test-tta' (random resized crop).

    Every pipeline ends with ImageNet mean/std normalisation followed by ToTensorV2.
    """
    assert data in ('train', 'valid', 'test-tta')

    # Shared tail: normalise with the standard ImageNet statistics, then convert to a tensor.
    to_model_input = [
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]

    if data == 'train':
        augmentations = [
            RandomResizedCrop(HEIGHT, WIDTH),
            HorizontalFlip(p=0.5),
            VerticalFlip(p=0.5),
            ShiftScaleRotate(rotate_limit=30, p=0.5),
            Cutout(p=0.5, max_h_size=12, max_w_size=12, num_holes=6),
        ]
    elif data == 'valid':
        augmentations = [Resize(HEIGHT, WIDTH)]
    else:  # 'test-tta'
        augmentations = [RandomResizedCrop(HEIGHT, WIDTH)]

    return Compose(augmentations + to_model_input)

# Train/validation split (stratified K-fold)

In [None]:
from sklearn.model_selection import StratifiedKFold

DEBUG = False  # True: work on a 1,000-row random sample for quick iteration
N_FOLD = 4

if DEBUG:
    folds = train.sample(n=1000, random_state=42).reset_index(drop=True).copy()
else:
    folds = train.copy()

# Stratify on the label so each fold gets roughly the same class mix.
train_labels = folds['class'].values
kf = StratifiedKFold(n_splits=N_FOLD, shuffle=True, random_state=42)
fold_of_row = np.empty(len(folds), dtype=int)
for fold_id, (_, holdout_rows) in enumerate(kf.split(folds.values, train_labels)):
    fold_of_row[holdout_rows] = fold_id
# `folds` has a 0..n-1 index (reset here in DEBUG mode, or when `train` was built),
# so positional assignment matches label-based .loc assignment.
folds['fold'] = fold_of_row

folds.head()

# Model

In [None]:
# https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/senet.py

from collections import OrderedDict
import math


class SEModule(nn.Module):
    """Squeeze-and-Excitation gate that rescales each channel of its input.

    Global average pooling squeezes every channel to one value. A 1x1-conv bottleneck
    (channels -> channels // reduction -> channels) with ReLU and a sigmoid turns those
    values into per-channel weights in (0, 1), which multiply the input.
    """

    def __init__(self, channels, reduction):
        super().__init__()
        squeezed = channels // reduction
        # Attribute names are kept: they are the state_dict keys the checkpoints use.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(channels, squeezed, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(squeezed, channels, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        gate = self.avg_pool(x)
        gate = self.relu(self.fc1(gate))
        gate = self.sigmoid(self.fc2(gate))
        return x * gate


class Bottleneck(nn.Module):
    """
    Shared `forward()` for the SE bottleneck variants.

    Subclasses provide conv1-3, bn1-3, relu, se_module and downsample (``None`` when the
    identity shortcut fits). The SE-gated main branch is added to the shortcut and passed
    through a final ReLU.
    """
    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        return self.relu(self.se_module(out) + shortcut)


class SEBottleneck(Bottleneck):
    """
    SENet154 bottleneck: 1x1 conv to 2*planes, strided grouped 3x3 conv to 4*planes,
    1x1 conv to 4*planes, then an SE gate.
    """
    expansion = 4

    def __init__(self, inplanes, planes, groups, reduction, stride=1,
                 downsample=None):
        super().__init__()
        mid_planes = planes * 2
        out_planes = planes * 4
        # Creation order and attribute names follow the reference implementation: the
        # names are state_dict keys and the order fixes the random-initialisation sequence.
        self.conv1 = nn.Conv2d(inplanes, mid_planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_planes)
        self.conv2 = nn.Conv2d(mid_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, groups=groups, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.conv3 = nn.Conv2d(out_planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.se_module = SEModule(out_planes, reduction=reduction)
        self.downsample = downsample
        self.stride = stride


class SEResNetBottleneck(Bottleneck):
    """
    SE-ResNet bottleneck (Caffe flavour): the stride sits on the first 1x1 conv, not on
    the 3x3 conv as in torchvision's ResNet. Outputs ``planes * 4`` channels.
    """
    expansion = 4

    def __init__(self, inplanes, planes, groups, reduction, stride=1,
                 downsample=None):
        super().__init__()
        out_planes = planes * 4
        # Creation order and attribute names follow the reference implementation: the
        # names are state_dict keys and the order fixes the random-initialisation sequence.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1,
                               groups=groups, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.se_module = SEModule(out_planes, reduction=reduction)
        self.downsample = downsample
        self.stride = stride


class SEResNeXtBottleneck(Bottleneck):
    """
    ResNeXt (type C) bottleneck with an SE gate.

    The inner width is ``floor(planes * (base_width / 64)) * groups``. The grouped 3x3
    conv carries the stride, and the block outputs ``planes * 4`` channels.
    """
    expansion = 4

    def __init__(self, inplanes, planes, groups, reduction, stride=1,
                 downsample=None, base_width=4):
        super().__init__()
        width = math.floor(planes * (base_width / 64)) * groups
        out_planes = planes * 4
        # Creation order and attribute names follow the reference implementation: the
        # names are state_dict keys and the order fixes the random-initialisation sequence.
        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
                               padding=1, groups=groups, bias=False)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = nn.Conv2d(width, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.se_module = SEModule(out_planes, reduction=reduction)
        self.downsample = downsample
        self.stride = stride


class SENet(nn.Module):
    """SENet-family backbone (stem + four bottleneck stages + pooled linear head).

    Args:
        block: bottleneck class used in every stage; its ``expansion`` times ``planes``
            gives each stage's output channels.
        layers: number of blocks in each of the four stages (layer1..layer4).
        groups: groups for the grouped 3x3 convs inside each block.
        reduction: SE-module channel reduction ratio.
        dropout_p: dropout probability before ``last_linear``; ``None`` disables dropout.
        inplanes: channels produced by the stem (layer0).
        input_3x3: if True the stem is three 3x3 convs, otherwise a single 7x7 conv.
        downsample_kernel_size, downsample_padding: conv settings of the projection
            shortcut in layer2..layer4 (layer1 always uses kernel 1 / padding 0).
        num_classes: output size of ``last_linear``.
    """

    def __init__(self, block, layers, groups, reduction, dropout_p=0.2,
                 inplanes=128, input_3x3=True, downsample_kernel_size=3,
                 downsample_padding=1, num_classes=1000):
        super(SENet, self).__init__()
        # Running channel count; _make_layer reads and updates it stage by stage.
        self.inplanes = inplanes
        if input_3x3:
            layer0_modules = [
                ('conv1', nn.Conv2d(3, 64, 3, stride=2, padding=1,
                                    bias=False)),
                ('bn1', nn.BatchNorm2d(64)),
                ('relu1', nn.ReLU(inplace=True)),
                ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1,
                                    bias=False)),
                ('bn2', nn.BatchNorm2d(64)),
                ('relu2', nn.ReLU(inplace=True)),
                ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1,
                                    bias=False)),
                ('bn3', nn.BatchNorm2d(inplanes)),
                ('relu3', nn.ReLU(inplace=True)),
            ]
        else:
            layer0_modules = [
                ('conv1', nn.Conv2d(3, inplanes, kernel_size=7, stride=2,
                                    padding=3, bias=False)),
                ('bn1', nn.BatchNorm2d(inplanes)),
                ('relu1', nn.ReLU(inplace=True)),
            ]
        # To preserve compatibility with Caffe weights `ceil_mode=True`
        # is used instead of `padding=1`.
        layer0_modules.append(('pool', nn.MaxPool2d(3, stride=2,
                                                    ceil_mode=True)))
        # Named submodules (layer0.conv1, ...) so state_dict keys match the checkpoints.
        self.layer0 = nn.Sequential(OrderedDict(layer0_modules))
        self.layer1 = self._make_layer(
            block,
            planes=64,
            blocks=layers[0],
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=1,
            downsample_padding=0
        )
        self.layer2 = self._make_layer(
            block,
            planes=128,
            blocks=layers[1],
            stride=2,
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=downsample_kernel_size,
            downsample_padding=downsample_padding
        )
        self.layer3 = self._make_layer(
            block,
            planes=256,
            blocks=layers[2],
            stride=2,
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=downsample_kernel_size,
            downsample_padding=downsample_padding
        )
        self.layer4 = self._make_layer(
            block,
            planes=512,
            blocks=layers[3],
            stride=2,
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=downsample_kernel_size,
            downsample_padding=downsample_padding
        )
        # Fixed 7x7 window: matches the 7x7 layer4 map of a 224x224 input. Other input
        # sizes need an adaptive pool (CustomSEResNeXt replaces this with AdaptiveAvgPool2d(1)).
        self.avg_pool = nn.AvgPool2d(7, stride=1)
        self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None
        self.last_linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
                    downsample_kernel_size=1, downsample_padding=0):
        """Build one stage of ``blocks`` blocks; only the first may stride/project."""
        downsample = None
        # Projection shortcut whenever spatial size or channel count changes.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=downsample_kernel_size, stride=stride,
                          padding=downsample_padding, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, groups, reduction, stride,
                            downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups, reduction))

        return nn.Sequential(*layers)

    def features(self, x):
        """Run the stem and the four stages; returns the layer4 feature map."""
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

    def logits(self, x):
        """Pool, (optionally) drop out, flatten to (batch, features) and apply last_linear."""
        x = self.avg_pool(x)
        if self.dropout is not None:
            x = self.dropout(x)
        x = x.view(x.size(0), -1)
        x = self.last_linear(x)
        return x

    def forward(self, x):
        x = self.features(x)
        x = self.logits(x)
        return x


def initialize_pretrained_model(model, num_classes, settings):
    """Download weights from ``settings['url']`` into ``model`` and attach input metadata.

    Args:
        model: module whose ``load_state_dict`` accepts the downloaded checkpoint.
        num_classes: must equal ``settings['num_classes']`` (the checkpoint's head size).
        settings: dict with keys ``url``, ``num_classes``, ``input_space``, ``input_size``,
            ``input_range``, ``mean`` and ``std``.

    Raises:
        AssertionError: if ``num_classes`` disagrees with the checkpoint.
    """
    # `model_zoo` is not imported anywhere in the visible notebook; importing it here
    # fixes the NameError this function used to raise.
    from torch.utils import model_zoo

    assert num_classes == settings['num_classes'], \
        'num_classes should be {}, but is {}'.format(
            settings['num_classes'], num_classes)
    model.load_state_dict(model_zoo.load_url(settings['url']))
    model.input_space = settings['input_space']
    model.input_size = settings['input_size']
    model.input_range = settings['input_range']
    model.mean = settings['mean']
    model.std = settings['std']


def se_resnext50_32x4d(num_classes=1000, pretrained='imagenet'):
    """Build SE-ResNeXt-50 (32x4d): stage depths [3, 4, 6, 3], 32 groups, SE reduction 16.

    Args:
        num_classes: output size of ``last_linear``.
        pretrained: key into ``pretrained_settings['se_resnext50_32x4d']``, or ``None`` for
            random initialisation (load local weights yourself, as CustomSEResNeXt does).

    Raises:
        NameError: if ``pretrained`` is not None but no ``pretrained_settings`` table is
            defined. It is not defined in the visible notebook, so the default
            ``'imagenet'`` cannot work here.
    """
    if pretrained is not None and 'pretrained_settings' not in globals():
        # Fail before building the network, with a message that says what to do.
        raise NameError(
            "pretrained_settings is not defined; call se_resnext50_32x4d(pretrained=None) "
            "and load local weights with load_state_dict instead")
    model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16,
                  dropout_p=None, inplanes=64, input_3x3=False,
                  downsample_kernel_size=1, downsample_padding=0,
                  num_classes=num_classes)
    if pretrained is not None:
        settings = pretrained_settings['se_resnext50_32x4d'][pretrained]
        initialize_pretrained_model(model, num_classes, settings)
    return model


def se_resnext101_32x4d(num_classes=1000, pretrained='imagenet'):
    """Build SE-ResNeXt-101 (32x4d): stage depths [3, 4, 23, 3], 32 groups, SE reduction 16.

    Args:
        num_classes: output size of ``last_linear``.
        pretrained: key into ``pretrained_settings['se_resnext101_32x4d']``, or ``None`` for
            random initialisation (load local weights yourself, as CustomSEResNeXt does).

    Raises:
        NameError: if ``pretrained`` is not None but no ``pretrained_settings`` table is
            defined. It is not defined in the visible notebook, so the default
            ``'imagenet'`` cannot work here.
    """
    if pretrained is not None and 'pretrained_settings' not in globals():
        # Fail before building the network, with a message that says what to do.
        raise NameError(
            "pretrained_settings is not defined; call se_resnext101_32x4d(pretrained=None) "
            "and load local weights with load_state_dict instead")
    model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16,
                  dropout_p=None, inplanes=64, input_3x3=False,
                  downsample_kernel_size=1, downsample_padding=0,
                  num_classes=num_classes)
    if pretrained is not None:
        settings = pretrained_settings['se_resnext101_32x4d'][pretrained]
        initialize_pretrained_model(model, num_classes, settings)
    return model

In [None]:
pretrained_path = {'se_resnext50_32x4d': '../input/pytorch-se-resnext/se_resnext50_32x4d-a260b3a4.pth',
                   'se_resnext101_32x4d': '../input/pytorch-se-resnext/se_resnext101_32x4d-3b2fe3d8.pth',}


class CustomSEResNeXt(nn.Module):
    """SE-ResNeXt backbone initialised from local ImageNet weights, with an N_CLASSES head.

    Args:
        model_name: 'se_resnext50_32x4d' or 'se_resnext101_32x4d'. It selects both the
            architecture and the checkpoint file in ``pretrained_path``.
    """

    def __init__(self, model_name='se_resnext50_32x4d'):
        assert model_name in ('se_resnext50_32x4d', 'se_resnext101_32x4d')
        super().__init__()

        # Build the architecture that matches `model_name`. The original always built the
        # 50-layer net, so the 101-layer checkpoint could never be loaded.
        builders = {
            'se_resnext50_32x4d': se_resnext50_32x4d,
            'se_resnext101_32x4d': se_resnext101_32x4d,
        }
        self.model = builders[model_name](pretrained=None)
        # Load on CPU; the caller moves the module to the XLA device afterwards.
        self.model.load_state_dict(torch.load(pretrained_path[model_name], map_location='cpu'))
        # Adaptive pooling makes the head resolution-independent: 256x256 inputs give an
        # 8x8 layer4 map, which the stock AvgPool2d(7) would not reduce to 1x1.
        self.model.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.model.last_linear = nn.Linear(self.model.last_linear.in_features, N_CLASSES)

    def forward(self, x):
        return self.model(x)

# Train

In [None]:
from sklearn.metrics import f1_score

In [None]:
def _train_one_epoch(para_train_loader, model, optimizer, criterion, device):
    """Train ``model`` for one pass over this process's shard; return the mean batch loss."""
    model.train()
    loss_sum = 0.
    n_batches = 0

    optimizer.zero_grad()
    for i, (images, labels) in enumerate(para_train_loader.per_device_loader(device)):
        images = images.to(device)
        labels = labels.to(device)

        y_preds = model(images)
        loss = criterion(y_preds, labels)

        if i % 40 == 0:
            xm.master_print(f'[train] i={i}, loss={loss}')

        loss.backward()
        # torch_xla's optimizer step (reduces gradients across replicas before stepping).
        xm.optimizer_step(optimizer)
        optimizer.zero_grad()

        loss_sum += loss.item()
        n_batches += 1

    # Average over the batches actually seen. The original divided by len(train_loader),
    # a variable captured from the enclosing function; this returns 0 for an empty loader.
    return loss_sum / n_batches if n_batches else 0.


def _evaluate(para_valid_loader, model, criterion, device):
    """Return ``(mean validation loss, macro-F1)`` on this process's validation shard."""
    model.eval()
    loss_sum = 0.
    n_batches = 0
    preds = []
    valid_labels = []

    with torch.no_grad():
        for images, labels in para_valid_loader.per_device_loader(device):
            images = images.to(device)
            labels = labels.to(device)

            y_preds = model(images)

            preds.append(y_preds.argmax(1).to('cpu').numpy())
            valid_labels.append(labels.to('cpu').numpy())

            loss_sum += criterion(y_preds, labels).item()
            n_batches += 1

    avg_val_loss = loss_sum / n_batches if n_batches else 0.
    score = f1_score(np.concatenate(valid_labels), np.concatenate(preds), average='macro')
    return avg_val_loss, score


def _predict_test(model, test_loader, device, num_tta):
    """Return test-set logits of shape (n_test, n_classes), averaged over ``num_tta`` passes."""
    model.eval()
    tta_logits = []

    for _ in range(num_tta):
        batch_logits = []
        with torch.no_grad():
            for i, images in enumerate(test_loader):
                batch_logits.append(model(images.to(device)).to('cpu').numpy())
                if i % 20 == 0:
                    print(f'[inference] i={i}... done')
        # A single linear-time concatenate; the original flattened with sum(lists, []),
        # which is quadratic in the number of test images.
        tta_logits.append(np.concatenate(batch_logits))

    return np.mean(tta_logits, axis=0)


def _run():
    """Train one SE-ResNeXt-50 per fold on TPU and write fold-averaged test predictions.

    For every fold in ``folds``:
      - train for ``n_epochs``;
      - keep the epoch with the best validation macro-F1 in ``se_resnext_fold{FOLD}.pth``;
      - reload that checkpoint and predict the test set.
    The fold-averaged logits are argmaxed into ``submission['label']`` and written to
    ``submission.csv`` by the master process.
    """
    device = xm.xla_device()
    world_size = xm.xrt_world_size()

    batch_size = int(128 / world_size)  # 128 = global batch size across replicas
    n_epochs = 20
    lr = 1e-4 * world_size  # LR scaled linearly with the number of replicas

    xm.master_print(f'device: {device}')
    xm.master_print(f'world_size: {world_size}')
    xm.master_print(f'batch_size: {batch_size}')
    xm.master_print(f'n_epochs: {n_epochs}')
    xm.master_print(f'lr: {lr}')

    NUM_TTA = 1
    ENSEMBLE_WEIGHTS = {'se_resnext': 1}
    # Epoch logs go to LOGGER (stream + log file) in single-process runs, else to master stdout.
    log = LOGGER.debug if world_size == 1 else xm.master_print
    n_jobs = 4 if world_size == 1 else 0

    # The test set is identical for every fold, so build it once. It has no
    # DistributedSampler: every process predicts the full test set in order, which keeps
    # the rows aligned with `submission`. With one process this is the same order the
    # previous DistributedSampler(shuffle=False) produced.
    test_dataset = TestDataset(submission, transform=get_transforms(data='valid'))
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             drop_last=False, num_workers=n_jobs)

    fold_logits = []
    for FOLD in range(N_FOLD):

        xm.master_print(f"FOLD: {FOLD}")

        trn_idx = folds[folds['fold'] != FOLD].index
        val_idx = folds[folds['fold'] == FOLD].index

        train_dataset = TrainDataset(folds.loc[trn_idx].reset_index(drop=True),
                                     folds.loc[trn_idx]['class'],
                                     transform=get_transforms(data='train'))
        valid_dataset = TrainDataset(folds.loc[val_idx].reset_index(drop=True),
                                     folds.loc[val_idx]['class'],
                                     transform=get_transforms(data='valid'))

        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset,
                                                                        num_replicas=world_size,
                                                                        rank=xm.get_ordinal(),
                                                                        shuffle=True)
        valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset,
                                                                        num_replicas=world_size,
                                                                        rank=xm.get_ordinal(),
                                                                        shuffle=False)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler,
                                  drop_last=True, num_workers=n_jobs)
        valid_loader = DataLoader(valid_dataset, batch_size=batch_size, sampler=valid_sampler,
                                  drop_last=False, num_workers=n_jobs)

        xm.master_print(f"Train for {len(train_loader)} steps per epoch")

        model = CustomSEResNeXt(model_name='se_resnext50_32x4d').to(device)
        optimizer = Adam(model.parameters(), lr=lr, amsgrad=False)
        scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.75, patience=1, verbose=True, eps=1e-6)
        criterion = nn.CrossEntropyLoss()

        ckpt_path = f'se_resnext_fold{FOLD}.pth'
        # -inf rather than 0. guarantees epoch 1 is saved. Otherwise, if F1 never exceeded
        # 0, the reload below would find no file, or a stale one from an earlier run.
        best_score = -np.inf

        for epoch in range(n_epochs):

            start_time = time.time()

            avg_loss = _train_one_epoch(pl.ParallelLoader(train_loader, [device]),
                                        model, optimizer, criterion, device)
            avg_val_loss, score = _evaluate(pl.ParallelLoader(valid_loader, [device]),
                                            model, criterion, device)
            scheduler.step(avg_val_loss)

            elapsed = time.time() - start_time

            log(f'  Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s')
            log(f'  Epoch {epoch+1} - f1_score: {score}')

            if score > best_score:
                best_score = score
                log(f'  Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
                # xm.save moves XLA tensors to CPU before writing (master ordinal only), so
                # the checkpoint reads back with a plain torch.load in every mode.
                xm.save({'model': model.state_dict(),
                         'score': best_score},
                        ckpt_path)

        # inference with this fold's best checkpoint
        model.load_state_dict(torch.load(ckpt_path, map_location='cpu')['model'])
        fold_logits.append(_predict_test(model, test_loader, device, NUM_TTA))

        # Free this fold's model/optimizer before the next fold builds new ones.
        del model, optimizer, scheduler
        gc.collect()

    # Fold average of raw logits, weighted for the (single-model) ensemble, then arg-max.
    ensemble_proba = ENSEMBLE_WEIGHTS['se_resnext'] * np.mean(fold_logits, axis=0)
    predictions = ensemble_proba.argmax(1)

    submission['label'] = predictions.astype(int)
    # Only one process writes the file.
    if xm.is_master_ordinal():
        submission.to_csv('submission.csv', index=False)

In [None]:
# Start training processes
def _mp_fn(rank, flags):
    """Per-process entry point for ``xmp.spawn``: use float32 default tensors, then train.

    ``rank`` and ``flags`` are supplied by ``xmp.spawn`` but unused here; ``_run`` asks
    ``xm`` for the ordinal itself.
    """
    torch.set_default_tensor_type('torch.FloatTensor')
    _run()

FLAGS = {}  # forwarded to _mp_fn as `flags`; currently unused there
# nprocs=1: a single training process, presumably making xm.xrt_world_size() == 1 inside
# _run (so the single-process logging/DataLoader-worker branches are taken).
# start_method='fork' lets the child reuse objects already built in this notebook
# (folds, submission, ...).
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=1, start_method='fork')