# Cassava PyTorch XLA/TPU starter (GPU inference)

### If you found this helpful, please give it an upvote!

This is a GPU inference kernel for my [PyTorch TPU starter kernel](https://www.kaggle.com/tanlikesmath/cassava-pytorch-xla-tpu-starter-training) (TPU notebook inference is not allowed in this competition).

## Installs & Imports

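The cell below installs the [timm](https://github.com/rwightman/pytorch-image-models) library from a local wheel. We use it to define our model architecture; the trained weights themselves are loaded from the training kernel's checkpoints.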

In [None]:
!pip install ../input/pytorch-image-models/timm-0.3.1-py3-none-any.whl > /dev/null

Here are all of our imports!

In [None]:
import gc
import os
import time
import torch
import albumentations

import numpy as np
import pandas as pd

import cv2
from PIL import Image

import torch.nn as nn
from sklearn import metrics
from sklearn import model_selection
from torch.nn import functional as F
from torch.optim import Adam

import timm

from tqdm.notebook import tqdm

import warnings
warnings.filterwarnings("ignore")

## Definitions

Now let's define the functions and variables needed for inference.

These are the flags for inference.

In [None]:
# Inference configuration shared by the cells below.
FLAGS = dict(
    num_folds=5,               # one checkpoint per fold is loaded and averaged
    model='resnext50_32x4d',   # timm architecture name
    model_path='../input/cassava-pytorch-xla-tpu-starter-training/',
    batch_size=32,
    epochs=10,                 # epoch count embedded in the checkpoint filenames
    num_workers=4,
)

Here, I define a model class for the timm models.

In [None]:
# Using Ross Wightman's timm package
class TimmModels(nn.Module):
    """A timm backbone whose last child module is replaced by a fresh linear head.

    The backbone's direct children are re-assembled into an ``nn.Sequential``
    stored as ``self.m``, so parameter names take the form ``m.<index>.*``.
    Checkpoints loaded into this class must use that same layout.
    """

    def __init__(self, model_name, pretrained=True, num_classes=5):
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=pretrained)
        # NOTE(review): assumes the last child is the Linear classifier
        # (it must expose ``in_features``) -- true for the configured model,
        # confirm before switching architectures.
        *body, old_head = backbone.children()
        new_head = nn.Linear(
            in_features=old_head.in_features,
            out_features=num_classes,
            bias=True,
        )
        self.m = nn.Sequential(*body, new_head)

    def forward(self, image):
        """Return the head's raw outputs for a batch of ``image`` tensors."""
        return self.m(image)


Here, I define a class for the PyTorch Dataset (taken from @abhishek's amazing [Tez package](https://github.com/abhishekkrthakur/tez)).

In [None]:
# Image Dataset class taken from Abhishek's tez package

class ImageDataset:
    """Map-style dataset that loads image files and returns tensors.

    Each item is a dict ``{"image": Tensor, "targets": Tensor}``. The image
    is read with PIL or OpenCV (``backend``), optionally resized, passed
    through ``augmentations`` if given, and, when ``channel_first`` is set,
    transposed from HWC to CHW and cast to float32.
    """

    def __init__(
        self,
        image_paths,
        targets,
        resize,
        augmentations=None,
        backend="pil",
        channel_first=True,
    ):
        """
        :param image_paths: list of paths to images
        :param targets: sequence indexable by item (numpy array or list)
        :param resize: ``(height, width)`` tuple, or None to keep the file's size
        :param augmentations: albumentations-style callable, invoked as
            ``augmentations(image=array)`` and returning a dict with ``"image"``
        :param backend: ``"pil"`` or ``"cv2"``
        :param channel_first: transpose HWC -> CHW and cast to float32
        """
        self.image_paths = image_paths
        self.targets = targets
        self.resize = resize
        self.augmentations = augmentations
        self.backend = backend
        self.channel_first = channel_first

    def __len__(self):
        return len(self.image_paths)

    def _read_image(self, path):
        """Load ``path`` as a numpy array, resized if ``self.resize`` is set.

        :raises NotImplementedError: for an unknown ``self.backend``
        :raises FileNotFoundError: if OpenCV cannot read ``path``
        """
        if self.backend == "pil":
            # The context manager closes the file handle as soon as pixels are read.
            with Image.open(path) as img:
                if self.resize is not None:
                    # PIL takes (width, height); self.resize is (height, width).
                    img = img.resize(
                        (self.resize[1], self.resize[0]), resample=Image.BILINEAR
                    )
                return np.array(img)
        if self.backend == "cv2":
            image = cv2.imread(path)
            if image is None:
                # cv2.imread signals failure by returning None rather than raising.
                raise FileNotFoundError(f"Could not read image: {path}")
            # OpenCV decodes to BGR; convert so both backends yield RGB.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            if self.resize is not None:
                image = cv2.resize(
                    image,
                    (self.resize[1], self.resize[0]),
                    interpolation=cv2.INTER_CUBIC,
                )
            return image
        raise NotImplementedError("Backend not implemented")

    def __getitem__(self, item):
        targets = self.targets[item]
        image = self._read_image(self.image_paths[item])
        # Augment only when a pipeline is configured (same rule for both backends).
        if self.augmentations is not None:
            image = self.augmentations(image=image)["image"]
        if self.channel_first:
            image = np.transpose(image, (2, 0, 1)).astype(np.float32)
        return {
            "image": torch.tensor(image),
            "targets": torch.tensor(targets),
        }

## Inference code

Let's start predicting! To do so, we start by initializing the model.

In [None]:
# Build the architecture only; the per-fold checkpoints loaded later supply the weights.
MX = TimmModels(
    FLAGS['model'],
    pretrained=False,
    num_classes=5,
)

Let's now define our inference function.

In [None]:
def single_model_inference_fn(data_loader, model, device):
    """Run ``model`` over every batch of ``data_loader`` and stack the outputs.

    Args:
        data_loader: iterable of dict batches; only the ``'image'`` entry is used.
        model: torch module; it is switched to eval mode and moved to ``device``.
        device: torch device to run inference on.

    Returns:
        np.ndarray (float64) of raw model outputs, rows in loader order.
        Returns an empty 1-D array if the loader yields no batches.
    """
    model.eval()
    model.to(device)
    chunks = []
    # Inference only: no_grad skips building the autograd graph, which would
    # otherwise hold every batch's activations in device memory.
    with torch.no_grad():
        for batch in tqdm(data_loader):
            images = batch['image'].to(device)
            outputs = model(images)
            chunks.append(outputs.cpu().numpy())

    if not chunks:
        return np.array([])
    # float64 keeps the dtype the previous list-based accumulation produced.
    return np.concatenate(chunks).astype(np.float64)

In [None]:
def tta(num_times, data_loader, model, device):
    """Average ``num_times`` inference passes over ``data_loader``.

    With random augmentations in the dataset, each pass sees a different view
    of every image; averaging the passes gives test-time augmentation.

    Args:
        num_times: number of passes to average; must be >= 1.
        data_loader, model, device: forwarded to ``single_model_inference_fn``.

    Returns:
        np.ndarray: element-wise mean of the per-pass outputs.

    Raises:
        ValueError: if ``num_times`` < 1 (there would be nothing to average).
    """
    if num_times < 1:
        raise ValueError(f"num_times must be >= 1, got {num_times}")

    final_preds = None
    for _ in range(num_times):
        temp_preds = single_model_inference_fn(data_loader, model, device)
        if final_preds is None:
            # Each pass returns a fresh array, so accumulating into it in place is safe.
            final_preds = temp_preds
        else:
            final_preds += temp_preds

    final_preds /= num_times
    return final_preds

In [None]:
# Prefer the GPU, but fall back to CPU so the notebook still runs without one.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


# Standard ImageNet channel mean/std, applied by Normalize at the end of test_aug.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

# Random test-time augmentations: every pass over test_loader produces a
# different random view of each image, and tta() averages those passes.
test_aug = albumentations.Compose(
    [
        albumentations.Resize(256, 256, p=1.0),
        albumentations.Transpose(p=0.5),
        albumentations.HorizontalFlip(p=0.5),
        albumentations.VerticalFlip(p=0.5),
        albumentations.ShiftScaleRotate(p=0.5),
        albumentations.HueSaturationValue(
            hue_shift_limit=0.2,
            sat_shift_limit=0.2,
            val_shift_limit=0.2,
            p=0.5
        ),
        albumentations.RandomBrightnessContrast(
            brightness_limit=(-0.1,0.1),
            contrast_limit=(-0.1, 0.1),
            p=0.5
        ),
        albumentations.CoarseDropout(p=0.5),
        albumentations.Cutout(p=0.5),
        # Normalize must run last: the colour transforms above expect raw
        # uint8 pixels (or floats in [0, 1]) and clip/distort normalized values.
        albumentations.Normalize(
            mean,
            std,
            max_pixel_value=255.0,
            always_apply=True
        ),
    ]
)

df_test = pd.read_csv("../input/cassava-leaf-disease-classification/sample_submission.csv")
test_image_paths = "../input/cassava-leaf-disease-classification/test_images/"

test_images = df_test.image_id.values.tolist()
test_images = [
    os.path.join(test_image_paths, i) for i in test_images
]

test_dataset = ImageDataset(
    image_paths=test_images,
    targets=[0]*len(test_images),  # dummy labels: inference reads only 'image'
    resize=None,  # resizing is handled by test_aug
    augmentations=test_aug,
)


test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=FLAGS['batch_size'],
    shuffle=False,  # prediction rows must stay aligned with df_test rows
    num_workers=FLAGS['num_workers'],
    drop_last=False)

In [None]:
# Ensemble: average the TTA predictions of every fold's checkpoint.
final_preds = None
for fold in range(FLAGS['num_folds']):
    checkpoint_path = os.path.join(
        FLAGS['model_path'],
        f"xla_trained_model_{FLAGS['epochs']}_epochs_fold_{fold}.pth",
    )
    # map_location='cpu' lets a checkpoint saved from any device load on this
    # host; load_state_dict then copies the weights onto MX's own device.
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    MX.load_state_dict(state_dict)

    fold_preds = tta(num_times=5, data_loader=test_loader, model=MX, device=device)

    if final_preds is None:
        final_preds = fold_preds
    else:
        final_preds += fold_preds
final_preds /= FLAGS['num_folds']

In [None]:
# Predicted class = index of the highest averaged score in each row.
final_preds = final_preds.argmax(axis=1)
# Item assignment always targets the column; attribute assignment would
# silently create an instance attribute instead if 'label' did not exist.
df_test['label'] = final_preds

In [None]:
# Write the submission file without the DataFrame index column.
df_test.to_csv(path_or_buf="submission.csv", index=False)

Now, **WE ARE DONE!**

If you enjoyed this kernel, please give it an upvote. If you have any questions or suggestions, please leave a comment!

Make sure to check out the training kernel [here](https://www.kaggle.com/tanlikesmath/cassava-pytorch-xla-tpu-starter-training).

Also, check out my [related kernel](https://www.kaggle.com/tanlikesmath/the-ultimate-pytorch-tpu-tutorial-jigsaw-xlm-r) with more detailed information on PyTorch XLA/TPU training.