### Online usage

In [None]:
# Online (requires internet access): uncomment ONE of the install options below
# !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
# !python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev
## !pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp37-cp37m-linux_x86_64.whl

### Offline usage

In [None]:
# Offline: install the torch_xla 1.7 wheel from the local ../input directory (no internet needed)
!pip install -U ../input/torchxla17wheel/torch_xla-1.7-cp37-cp37m-linux_x86_64.whl

## Dataset

In [None]:
import os
import cv2
import time
import numpy as np
import pandas as pd
from tqdm import tqdm
import gc; gc.collect()
import matplotlib.pyplot as plt

In [None]:
import torch
import torch.nn as nn
import torch.utils.data as D
import torch.nn.functional as F

import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.core.xla_model as xm
import torch_xla.utils.utils as xu
import torch_xla.utils.serialization as xser
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

import torchvision
from torchvision import transforms as T
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split

In [None]:
# All competition files live under one input directory; derive every path from it.
_hpa_root = "../input/hpa-single-cell-image-classification"
train_df_path = f"{_hpa_root}/train.csv"
train_images_path = f"{_hpa_root}/train"
test_images_path = f"{_hpa_root}/test"
sample_df_path = f"{_hpa_root}/sample_submission.csv"

In [None]:
# XLA runtime settings, exported through the process environment in one go.
# NOTE(review): presumably these must be in place before run() first touches
# the XLA device -- confirm against the torch_xla documentation.
os.environ.update({
    'XLA_USE_BF16': "1",
    'XLA_TENSOR_ALLOCATOR_MAXSIZE': '100000000',
})

In [None]:
# Load the training table (the datasets below read its 'ID' and 'Label' columns)
# and preview the first rows.
train_df=pd.read_csv(train_df_path)
train_df.head()

In [None]:
# Per-channel mean / std used for normalisation (these are the values commonly
# used for ImageNet-style preprocessing).
_CHANNEL_MEAN = (0.485, 0.456, 0.406)
_CHANNEL_STD = (0.229, 0.224, 0.225)

# Array -> tensor conversion followed by channel-wise normalisation.
Transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
])

In [None]:
class HPADataset(Dataset):
    """Labelled dataset that builds a 3-channel image from four channel files.

    Each sample is read from ``<path>/<ID>_{red,yellow,green,blue}.png`` and
    combined into ``(red/2 + yellow/2, green/2 + yellow/2, blue)``, resized to
    ``img_size`` x ``img_size`` and divided by 255.

    Args:
        path: Directory containing the per-channel PNG files.
        df: DataFrame with an ``'ID'`` column and a ``'Label'`` column holding
            ``'|'``-separated integer class indices.
        img_size: Side length (pixels) of the square output image.
        Transform: Callable applied to the image array before it is returned.
        num_classes: Length of the multi-hot target. When ``None`` (default)
            the value is looked up in the module-level ``FLAGS['NUM_CLASSES']``
            at access time, which keeps the original call sites working.
    """

    def __init__(self, path, df, img_size, Transform, num_classes=None):
        self.path = path
        self.df = df
        self.img_ids = df['ID'].values
        self.labels = df['Label'].values
        self.img_size = img_size
        self.transform = Transform
        self.num_classes = num_classes

    def _read_channel(self, ID, colour):
        """Read one channel file, failing loudly if it cannot be loaded."""
        file_path = self.path + '/' + ID + '_' + colour + '.png'
        channel = cv2.imread(file_path, cv2.IMREAD_UNCHANGED)
        if channel is None:
            # cv2.imread signals a missing/unreadable file by returning None;
            # without this check the failure surfaces later as a TypeError.
            raise FileNotFoundError(f'Could not read image: {file_path}')
        return channel

    def _get_image(self, ID):
        """Return the resized, 0-1 scaled composite image for ``ID``."""
        R = self._read_channel(ID, 'red')
        Y = self._read_channel(ID, 'yellow')
        G = self._read_channel(ID, 'green')
        B = self._read_channel(ID, 'blue')
        # Yellow is folded half-and-half into the red and green planes.
        img = np.stack((
                R/2 + Y/2,
                G/2 + Y/2,
                B),-1)

        img = cv2.resize(img, (self.img_size, self.img_size))
        img = np.divide(img, 255)
        return img

    def _encode_labels(self, label):
        """Turn a ``'3|7'``-style label string into a float multi-hot vector."""
        n_classes = self.num_classes
        if n_classes is None:
            # Resolved lazily so FLAGS may be defined after this class.
            n_classes = FLAGS['NUM_CLASSES']
        class_ids = [int(token) for token in label.split('|')]
        target = np.zeros(n_classes, dtype='float')
        # Index assignment keeps the target binary even if a class id repeats.
        target[class_ids] = 1.0
        return target

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(transformed image, multi-hot target)`` for row ``index``."""
        x = self._get_image(self.img_ids[index])
        x = self.transform(x)
        y = self._encode_labels(self.labels[index])
        return x, y

In [None]:
class HPATestDataset(Dataset):
    """Dataset that serves green-channel images, mainly for inference.

    Only ``<path>/<ID>_green.png`` is read (as a 3-channel image via the
    default ``cv2.imread`` mode), resized to ``img_size`` x ``img_size`` and
    divided by 255.

    NOTE(review): HPADataset builds its input from the red/yellow/green/blue
    files, whereas this class feeds the green file only. Confirm that this
    difference between training and inference preprocessing is intended.

    Args:
        path: Directory containing the PNG files. If the string contains
            ``"train"`` the dataset returns multi-hot targets, otherwise it
            returns the image ID as the second element.
        df: DataFrame with an ``'ID'`` column; a ``'Label'`` column of
            ``'|'``-separated class indices is needed only in "train" mode.
        img_size: Side length (pixels) of the square output image.
        Transform: Callable applied to the image array before it is returned.
        num_classes: Length of the multi-hot target. When ``None`` (default)
            the module-level ``FLAGS['NUM_CLASSES']`` is used.
    """

    def __init__(self, path, df, img_size, Transform, num_classes=None):
        self.path = path
        self.df = df
        self.img_ids = df['ID'].values
        # Labels are only consulted in "train" mode; tolerate frames without them.
        self.labels = df['Label'].values if 'Label' in df.columns else None
        self.img_size = img_size
        self.transform = Transform
        self.num_classes = num_classes

    def _get_image(self, ID):
        """Return the resized, 0-1 scaled green-channel image for ``ID``."""
        file_path = self.path + '/' + ID + '_green.png'
        data_file = cv2.imread(file_path)
        if data_file is None:
            # cv2.imread returns None instead of raising on a bad path.
            raise FileNotFoundError(f'Could not read image: {file_path}')

        img = cv2.resize(data_file, (self.img_size, self.img_size))
        X = img/255.

        return X

    def _encode_labels(self, label):
        """Turn a ``'3|7'``-style label string into a float multi-hot vector."""
        n_classes = self.num_classes
        if n_classes is None:
            # Resolved lazily so FLAGS may be defined after this class.
            n_classes = FLAGS['NUM_CLASSES']
        class_ids = [int(token) for token in label.split('|')]
        target = np.zeros(n_classes, dtype='float')
        # Index assignment keeps the target binary even if a class id repeats.
        target[class_ids] = 1.0
        return target

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(image, target)`` in "train" mode, else ``(image, ID)``."""
        x = self._get_image(self.img_ids[index])
        x = self.transform(x)
        if "train" in self.path:
            return x, self._encode_labels(self.labels[index])
        return x, self.img_ids[index]

In [None]:
# Hold out 20% of the training rows for evaluation; the fixed random_state
# keeps the split identical across runs.
train_split, eval_split = train_test_split(train_df, test_size=0.2, random_state=42)

## Model

In [None]:
def get_model(num_classes=19):
    """Build a ResNet-50 whose final layer emits ``num_classes`` logits.

    Args:
        num_classes: Number of output logits (defaults to 19, the value the
            notebook previously hard-coded).

    Returns:
        The torchvision ResNet-50 with its ``fc`` layer replaced.
    """
    model = torchvision.models.resnet50()
    # Take the input width from the backbone instead of hard-coding 2048.
    model.fc = nn.Linear(model.fc.in_features, num_classes, bias=True)
    return model

# Module-level model instance: run() moves it to the XLA device and the
# submission cells load weights into it and pass it to test().
resnet50 = get_model()

In [None]:
def graph_losses(losses):
    """Show one loss curve per phase that has recorded values.

    Args:
        losses: Mapping with ``'train'`` and ``'eval'`` keys, each holding a
            sequence of loss values. Phases with an empty sequence are skipped.
    """
    line_styles = {'train': 'r--', 'eval': 'b--'}
    for phase in ('train', 'eval'):
        history = losses[phase]
        if history:
            # x axis is simply 1..len(history)
            positions = range(1, len(history) + 1)
            plt.plot(positions, history, line_styles[phase])
            plt.legend([f'{phase.capitalize()} Loss'])
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.show()

def reduce_fn(vals):
    """Return the arithmetic mean of ``vals`` (passed to xm.mesh_reduce in run())."""
    total = 0
    for value in vals:
        total = total + value
    return total / len(vals)
        
def run(epochs=20, validate_every=2):
    """Train the module-level ``resnet50`` on the current XLA device.

    Reads ``FLAGS`` (``IMG_SIZE``, ``BATCH_SIZE``, ``LR``), ``train_split``,
    ``eval_split`` and ``train_images_path`` from module scope. Whenever an
    evaluation epoch improves on the best mean eval loss so far, the weights
    are written to ``model_best.pth``. The per-epoch mean losses are plotted
    at the end via ``graph_losses``.

    Args:
        epochs: Number of training epochs.
        validate_every: Run an evaluation phase every this many epochs.
    """
    device = xm.xla_device()
    # Measured locally so the function does not depend on a notebook global.
    run_start = time.time()

    # Init DataLoader
    Transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

    train_dataset = HPADataset(train_images_path, train_split, FLAGS['IMG_SIZE'], Transform)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)
    train_loader = DataLoader(train_dataset, batch_size=FLAGS['BATCH_SIZE'], sampler=train_sampler, shuffle=False)

    eval_dataset = HPADataset(train_images_path, eval_split, FLAGS['IMG_SIZE'], Transform)
    eval_sampler = torch.utils.data.distributed.DistributedSampler(
          eval_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=False)
    eval_loader = DataLoader(eval_dataset, batch_size=FLAGS['BATCH_SIZE'], sampler=eval_sampler, shuffle=False)

    loaders = {'train': train_loader, 'eval': eval_loader}

    # Initialize model
    model = resnet50.to(device)
    # Learning rate is scaled by the number of participating replicas.
    learning_rate = FLAGS['LR'] * xm.xrt_world_size()
    optimizer = torch.optim.AdamW(model.parameters(),
                      lr=learning_rate, weight_decay=5e-4)
    criterion = nn.BCEWithLogitsLoss()

    # Start from +inf so the first evaluated epoch always counts as an
    # improvement (a start value of 0.0 can never be beaten by a BCE loss).
    best_loss = float('inf')
    # One mean loss per completed epoch and phase.
    running_losses = {'train': [], 'eval': []}

    for epoch in range(1, epochs + 1):
        phases = ['train']
        if epoch % validate_every == 0:
            phases.append('eval')

        # Without this the distributed sampler reuses the same shuffle order
        # in every epoch.
        train_sampler.set_epoch(epoch)

        for phase in phases:
            is_train = phase == 'train'
            if is_train:
                model.train()
            else:
                model.eval()
            gc.collect() # prevent OOM problems
            # Iterate the loader that belongs to this phase (the eval phase
            # must not run over the training data).
            para_loader = pl.ParallelLoader(loaders[phase], [device])
            gc.collect()

            xm.master_print("Epoch {}/{}".format(epoch, epochs))
            loader = para_loader.per_device_loader(device)
            step_losses = []
            for idx, (imgs, labels) in enumerate(tqdm(loader)):
                xm.master_print(f'Phase: {phase}, current step: {idx}')
                imgs, labels = imgs.float().to(device), labels.float().to(device)
                if is_train:
                    optimizer.zero_grad()
                # No autograd bookkeeping is needed while evaluating.
                with torch.set_grad_enabled(is_train):
                    outputs = model(imgs)
                    loss = criterion(outputs, labels)
                loss_reduced = xm.mesh_reduce('loss_reduce', loss, reduce_fn)
                step_losses.append(loss_reduced.item())
                if is_train:
                    loss.backward()
                    xm.optimizer_step(optimizer)

            # Mean over THIS epoch only; guard against an empty loader.
            mean_loss = float(np.mean(step_losses)) if step_losses else float('nan')
            running_losses[phase].append(mean_loss)
            if phase == 'eval':
                xm.master_print("Eval running loss: ", running_losses['eval'])
                # mean_loss comes from mesh-reduced values, so every process
                # takes the same branch here.
                if mean_loss < best_loss:
                    best_loss = mean_loss
                    xm.save(model.state_dict(), 'model_best.pth')
            xm.master_print(epoch, mean_loss, best_loss, (time.time() - run_start) / 60)
    graph_losses(running_losses)

In [None]:
import time


# Run configuration shared by the datasets, the training loop and inference.
EPOCHS = 1 # For testing purposes
FLAGS = {
    'LR': 1e-4,
    'IMG_SIZE': 256,
    'BATCH_SIZE': 64,
    'NUM_CLASSES': 19,
}

# Wall-clock reference taken when this cell is executed.
start_time = time.time()

# Start training processes
def _mp_fn(rank, flags):
    """Per-process entry point: run training for ``EPOCHS`` epochs.

    Both ``rank`` and ``flags`` are accepted but not used; run() reads its
    settings from module scope.
    """
    run(EPOCHS)

# xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=1, start_method='fork')

### Submission
Run inference with a checkpoint saved after one epoch of training and build the submission file.

In [None]:
# Load the saved weights from the ../input directory and copy them into the
# module-level model used for inference.
state_dict = xser.load('../input/model-epoch-1pth/model_epoch_1.pth')
resnet50.load_state_dict(state_dict)

In [None]:
# Derive the test image IDs from the green-channel file names.
# The suffix is removed by slicing: str.rstrip() strips a *set of characters*,
# not a suffix, so it must not be used for this.
_GREEN_SUFFIX = '_green.png'
test_ids = sorted(
    name[:-len(_GREEN_SUFFIX)]
    for name in os.listdir(test_images_path)
    if name.endswith(_GREEN_SUFFIX)
)  # sorted: os.listdir() order is arbitrary

# 'Label' is a zero-filled placeholder column; it is not read for the test path.
test_df = pd.DataFrame({'ID': test_ids, 'Label': np.zeros(len(test_ids))})
test_dataset = HPATestDataset(test_images_path, test_df, FLAGS['IMG_SIZE'], Transform)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [None]:
def test(model, loader=None):
    """Predict the top class for every test image and build submission rows.

    Args:
        model: Network that returns one logit per class.
        loader: Iterable yielding ``(image batch, image IDs)``. Defaults to
            the module-level ``test_loader``.

    Returns:
        A list of ``[ID, image width, image height, prediction string]`` rows,
        where the prediction string is ``"<class> <confidence> <mask>"``.

    Raises:
        FileNotFoundError: If a green-channel test image cannot be read.
    """
    if loader is None:
        loader = test_loader

    # Fixed encoded mask strings selected by the original image height.
    # NOTE(review): presumably each encodes a mask for that image size --
    # confirm against the competition's submission format.
    mask_by_height = {
        2048: 'eNoLCAgIMAEABJkBdQ==',
        1728: 'eNoLCAjJNgIABNkBkg==',
    }
    fallback_mask = 'eNoLCAgIsAQABJ4Beg=='

    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Move the model once instead of on every iteration.
    model.to(DEVICE)
    model.eval()

    s_ls = []
    with torch.no_grad():
        for image, fname in loader:
            image = image.float().to(DEVICE)
            logits = model(image)
            # run() trains with BCEWithLogitsLoss, i.e. one independent
            # sigmoid per class, so the matching confidence is
            # sigmoid(logit) rather than a softmax over classes. The top
            # class itself is the same either way.
            prob = torch.sigmoid(logits)
            p, top_class = prob.topk(1, dim=1)

            # Walk the batch so any batch size works, not just 1.
            for name, cls, conf in zip(fname, top_class[:, 0].tolist(), p[:, 0].tolist()):
                # Re-read the original file to recover its full-size shape.
                img_path = test_images_path + "/" + name + '_green.png'
                img = cv2.imread(img_path)
                if img is None:
                    raise FileNotFoundError(f'Could not read image: {img_path}')

                height, width = img.shape[0], img.shape[1]
                mask = mask_by_height.get(height, fallback_mask)
                sp = ' '.join(str(e) for e in [cls, conf]) + ' ' + mask
                s_ls.append([name, width, height, sp])
    return s_ls
            
# Run inference over the test loader; each row is [ID, width, height, prediction string].
results = test(resnet50)

In [None]:
# Assemble the prediction rows into the submission table and display it.
sub = pd.DataFrame.from_records(results, columns=['ID', 'ImageWidth', 'ImageHeight', 'PredictionString'])
sub

In [None]:
# Write the submission file to the working directory without the index column.
sub.to_csv('submission.csv', index=False)