In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
plt.style.use('bmh')
import os
import random

import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [None]:
OUTPUT_DIR = './'
image_size = 256  # side length (pixels) of the square crops fed to the model
batch_size = 32

# The TTA transforms are random, so seed every global RNG they may draw from to
# make the submission reproducible. The DataLoader is created with the default
# num_workers=0, so transforms run in this process and these seeds apply.
# NOTE(review): albumentations 1.x samples from Python `random` / NumPy; 2.x uses
# its own per-pipeline RNG (A.Compose(..., seed=...)) — confirm for the installed version.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)  # last statement returns None, so the cell displays nothing

In [None]:
# Prefer the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device

### Load Model

In [None]:
# Load the pickled model object. It is called directly later as `model(...)`,
# so the file holds a whole module rather than a state_dict.
# map_location remaps the stored tensors onto `device`, so a checkpoint saved on
# a GPU still loads on a CPU-only runtime. .to(device) guarantees the parameters
# live where the input batches are sent; .eval() puts it in inference mode.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
# NOTE(review): torch >= 2.6 defaults to weights_only=True, which rejects whole-module
# pickles; on such versions add weights_only=False (that argument needs torch >= 1.13).
model_path = '../input/cassava-balanced-ce-model/sgd_balanced_ce_aug.pt'
model = torch.load(model_path, map_location=device)
model = model.to(device).eval()

### Get Data

In [None]:
class CassavaDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs read from disk.

    Args:
        data_dir: directory containing the image files.
        ids: indexable sequence of file names inside ``data_dir``.
        labels: indexable sequence of labels, aligned index-for-index with ``ids``.
        transform: optional callable invoked as ``transform(image=...)`` that
            returns a dict with an ``'image'`` key (albumentations-style).

    Raises (from ``__getitem__``):
        FileNotFoundError: the image file does not exist.
        OSError: the file exists but OpenCV could not decode it.
    """

    def __init__(self, data_dir, ids, labels, transform=None):
        self.data_dir = data_dir
        self.ids = ids
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        path = os.path.join(self.data_dir, self.ids[idx])
        # cv2.imread does not raise on failure, it returns None; without these
        # checks the failure surfaces later as an opaque cv2.cvtColor error.
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Image file not found: {path}")
        image = cv2.imread(path)
        if image is None:
            raise OSError(f"cv2.imread could not decode image: {path}")
        # OpenCV decodes to BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            image = self.transform(image=image)['image']

        label = self.labels[idx]

        return (image, label)

In [None]:
# Test-time augmentation (TTA) pipeline: every pass over the test set draws a
# fresh random crop / flip / brightness-contrast jitter for each image.
# NOTE(review): RandomResizedCrop keeps its default scale range; presumably this
# mirrors the training augmentation — confirm against the training notebook.
transform = A.Compose(
    [
        # Geometric: random resized square crop to the network input size, then flips.
        A.RandomResizedCrop(image_size, image_size),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.25),
        A.Transpose(p=0.25),
        # Photometric: mild brightness / contrast jitter.
        A.RandomBrightnessContrast(
            p=0.5,
            brightness_limit=(-0.1, 0.1),
            contrast_limit=(-0.1, 0.1),
        ),
        # Per-channel ImageNet mean/std normalisation of 0-255 pixel values.
        A.Normalize(
            p=1.0,
            max_pixel_value=255.0,
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        # Convert the HWC array to a CHW torch tensor.
        ToTensorV2(p=1.0),
    ]
)

In [None]:
# sample_submission.csv lists the test image ids; its `label` column is replaced
# with the model's predictions at the end of the notebook.
test_dir = '../input/cassava-leaf-disease-classification/test_images'
test_df = pd.read_csv('../input/cassava-leaf-disease-classification/sample_submission.csv')
ids, labels = test_df['image_id'].to_numpy(), test_df['label'].to_numpy()
test_df

In [None]:
# Take ids/labels straight from test_df instead of the `ids`/`labels` globals,
# so this cell does not depend on globals that other cells may rebind (e.g. a
# `for inputs, labels in loader` loop), and re-running it is always safe.
test_dataset = CassavaDataset(
    test_dir,
    test_df['image_id'].values,
    test_df['label'].values,
    transform=transform,
)
# shuffle=False keeps predictions in the same row order as test_df, which the
# submission cell relies on when it assigns predictions positionally.
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

### Inference with TTA

In [None]:
# Converts logits to class probabilities, normalising along dim 1
# (the class axis for (batch, n_classes) outputs).
softmax = torch.nn.Softmax(dim=1)

In [None]:
# Test-time augmentation (TTA): push the whole test set through the model
# `num_inferences` times. `transform` is random, so each pass sees a different
# view of every image; `inferences` collects one probability array per pass.
num_inferences = 10
inferences = []

model.eval()  # set once up front; nothing in the loop switches back to train mode
with torch.no_grad():  # inference only: skip autograd bookkeeping
    for pass_idx in range(num_inferences):
        pass_probs = []
        # Bind the dataset's labels to a throwaway name: unpacking into `labels`
        # would silently overwrite the global `labels` array built from the
        # submission template, which other cells may read.
        for inputs, _batch_labels in test_loader:
            # NOTE(review): assumes the model returns (batch, n_classes) logits.
            probs = softmax(model(inputs.to(device)))
            pass_probs.append(probs.cpu().numpy())
        # Join this pass's batches into one array, rows in loader (= test_df) order.
        inferences.append(np.concatenate(pass_probs))

In [None]:
# Average the per-pass probability arrays (accumulated in float64), then take
# the most probable class for each test image.
preds = np.zeros(inferences[0].shape)
for pass_probs in inferences:
    preds += pass_probs
preds /= num_inferences
preds = list(preds.argmax(axis=1))

In [None]:
# Build the submission as a new frame instead of overwriting test_df['label']
# in place: test_df was displayed earlier and other cells read from it, so
# mutating it makes re-runs depend on execution order.
assert len(preds) == len(test_df), "expected exactly one prediction per test image"
submission_df = test_df.assign(label=preds)
submission_df.to_csv(OUTPUT_DIR + 'submission.csv', index=False)

In [None]:
# Read the written file back so the notebook shows exactly what gets submitted.
submission_path = OUTPUT_DIR + 'submission.csv'
pd.read_csv(submission_path)