In [None]:
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import cv2
import torch
from torch import nn
import os
import sys
import albumentations as A
from albumentations.pytorch.transforms import ToTensor

In [None]:
# Make the local copy of pytorch-image-models importable as `timm`
# (presumably an offline copy attached as a Kaggle input — confirm).
package_path = '../input/pytorch-image-models/pytorch-image-models-master'
# Guard so re-running this cell does not keep appending duplicate entries.
if package_path not in sys.path:
    sys.path.append(package_path)

In [None]:
import timm
import random

SEED = 2484
# Seed every RNG we can reach from here so repeated runs are comparable.
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# cuDNN autotuning (benchmark=True) may select non-deterministic kernels and
# defeat `deterministic=True`, so it is disabled. Note: `torch.is_deterministic`
# is a query function, not a switch — assigning to it only clobbers it.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Prefer the GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

In [None]:
def get_tta(image_size=380, p=0.5, resize_size=428):
    """Build the random test-time-augmentation (TTA) pipeline.

    Each call of the returned transform does the following:
    - resizes the image to ``resize_size`` x ``resize_size`` (bicubic);
    - takes a random ``image_size`` square crop;
    - applies random 90-degree rotations, brightness/contrast jitter, one of
      several blurs and one of two geometric distortions;
    - converts the result to a tensor normalised with the ImageNet mean/std
      below.

    Args:
        image_size: side length of the square crop fed to the models.
        p: probability used for every random augmentation and for each
            ``OneOf`` group.
        resize_size: side length the image is resized to before cropping.
            It must be >= ``image_size``. The default 428 matches the original
            hard-coded value.

    Returns:
        An ``albumentations.Compose``. Call it as ``augs(image=img)['image']``.

    Raises:
        ValueError: if ``image_size`` exceeds ``resize_size``. ``RandomCrop``
            cannot crop a window larger than the resized image, and would
            otherwise fail only later, per sample, inside the DataLoader.
    """
    if image_size > resize_size:
        raise ValueError(
            f"image_size ({image_size}) must not exceed resize_size ({resize_size})"
        )
    imagenet_stats = {"mean": [0.485, 0.456, 0.406],
                      "std": [0.229, 0.224, 0.225]}
    augs = A.Compose(
        [
            A.Resize(resize_size, resize_size, cv2.INTER_CUBIC),
            A.RandomCrop(image_size, image_size),
            A.RandomRotate90(p=p),
            A.RandomBrightnessContrast(
                brightness_limit=0.2, contrast_limit=0.2, p=p),
            A.OneOf(
                [
                    A.MotionBlur(p=p),
                    A.MedianBlur(blur_limit=3, p=p),
                    A.Blur(blur_limit=3, p=p),
                    A.GaussianBlur(blur_limit=(3, 5), p=p)
                ],
                p=p,
            ),
            A.OneOf(
                [
                    A.OpticalDistortion(p=p),
                    A.GridDistortion(p=p)
                ],
                p=p,
            ),
            # NOTE(review): ToTensor is deprecated in newer albumentations
            # (A.Normalize + ToTensorV2 replaces it); kept for the pinned version.
            ToTensor(normalize=imagenet_stats),
        ]
    )
    return augs

In [None]:
class CassavaLeafDataset(Dataset):
    """Image dataset driven by a CSV listing, yielding ``(image, label)`` pairs.

    Row ``i`` of the CSV provides the image filename in its first column,
    resolved relative to ``root_dir``. It provides the value returned as the
    label in its second column.
    """

    DEFAULT_CSV = '../input/cassava-leaf-disease-classification/sample_submission.csv'

    def __init__(self, root_dir, transforms, csv_path=DEFAULT_CSV):
        """
        Args:
            root_dir: directory containing the image files.
            transforms: albumentations-style callable, invoked as
                ``transforms(image=img)['image']``.
            csv_path: CSV listing filenames and labels. Defaults to the
                competition's ``sample_submission.csv``.
        """
        self.root_dir = root_dir
        self.transform = transforms
        self.dataframe = pd.read_csv(csv_path)

    def __len__(self):
        return self.dataframe.shape[0]

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        # Positional access (column 0 = filename, column 1 = label). Integer
        # keys on a label-indexed Series are deprecated in pandas.
        img_name = os.path.join(self.root_dir, row.iloc[0])
        image = cv2.imread(img_name)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read image: {img_name}")
        # OpenCV loads BGR; convert to RGB before augmenting.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transform(image=image)['image']
        return image, row.iloc[1]

In [None]:
test_transforms = get_tta(image_size=380)
test_data = CassavaLeafDataset(
    root_dir='../input/cassava-leaf-disease-classification/test_images',
    transforms=test_transforms,
)
# shuffle stays False: predictions are concatenated in loader order and then
# written row-by-row into the sample_submission frame.
test_loader = DataLoader(
    test_data,
    batch_size=16,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

In [None]:
# Submission template: its row order is the order predictions are written in.
df = pd.read_csv('../input/cassava-leaf-disease-classification/sample_submission.csv')
len(df)

In [None]:
def get_probabilities(model, loader=None, epochs=10, apply_softmax=False, target_device=None):
    """Run ``model`` over ``loader`` ``epochs`` times and average the outputs.

    When the loader applies random transforms (test-time augmentation), each
    pass sees a differently augmented copy of every image. Averaging the
    passes smooths the prediction.

    Args:
        model: ``torch.nn.Module`` mapping an image batch to class scores.
            Scores are assumed to have shape ``(batch, n_classes)``.
        loader: iterable of ``(images, labels)`` batches. Defaults to the
            notebook's ``test_loader``. Labels are ignored. The loader must
            iterate in a fixed order so rows line up across passes.
        epochs: number of passes to average (>= 1).
        apply_softmax: if True, convert each pass's scores to probabilities
            before averaging. The default False keeps the original behaviour
            of averaging raw model outputs (logits).
        target_device: device to run on. Defaults to the notebook's ``device``.

    Returns:
        Tensor of shape ``(n_samples, n_classes)`` on ``target_device``.

    Raises:
        ValueError: if ``epochs < 1`` or the loader yields no batches.
    """
    if loader is None:
        loader = test_loader
    if target_device is None:
        target_device = device
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    model.eval()
    model.to(target_device)
    total = None
    with torch.no_grad():
        for _ in range(epochs):
            # Collect batches and concatenate once; repeated torch.cat is quadratic.
            batch_outputs = []
            for images, _labels in loader:
                outputs = model(images.to(target_device))
                if apply_softmax:
                    outputs = torch.softmax(outputs, dim=1)
                batch_outputs.append(outputs)
            if not batch_outputs:
                raise ValueError("loader yielded no batches")
            pass_outputs = torch.cat(batch_outputs, dim=0)
            total = pass_outputs if total is None else total + pass_outputs
    return total / epochs

In [None]:
# ResNeSt-50d: replace the classification head with a 5-output layer, then load fine-tuned weights.
resnest = timm.create_model('resnest50d_1s4x24d', pretrained=False)
resnest.fc = nn.Linear(in_features=resnest.fc.in_features, out_features=5, bias=True)
# map_location lets a checkpoint saved from CUDA tensors load when `device` fell back to CPU.
resnest.load_state_dict(torch.load('../input/mycassavamodels/resnest50.pth', map_location=device))
resnest_probs = get_probabilities(resnest)

In [None]:
# Xception-65: replace the classification head with a 5-output layer, then load fine-tuned weights.
xception = timm.create_model('xception65', pretrained=False)
xception.head.fc = nn.Linear(in_features=xception.head.fc.in_features, out_features=5, bias=True)
# map_location lets a checkpoint saved from CUDA tensors load when `device` fell back to CPU.
xception.load_state_dict(torch.load('../input/mycassavamodels/xception65.pth', map_location=device))
xception_probs = get_probabilities(xception)

In [None]:
# EfficientNet-B4 (noisy student): replace the classifier with a 5-output layer, then load fine-tuned weights.
efficientb4 = timm.create_model('tf_efficientnet_b4_ns', pretrained=False)
efficientb4.classifier = nn.Linear(in_features=efficientb4.classifier.in_features, out_features=5, bias=True)
# map_location lets a checkpoint saved from CUDA tensors load when `device` fell back to CPU.
efficientb4.load_state_dict(torch.load('../input/mycassavamodels/efficientnetB4.pth', map_location=device))
efficientb4_probs = get_probabilities(efficientb4)

In [None]:
# Equal-weight ensemble: mean of the three models' TTA-averaged outputs.
model_outputs = (resnest_probs, efficientb4_probs, xception_probs)
res = sum(model_outputs) / len(model_outputs)
pred = res.argmax(dim=1)
print(pred)

In [None]:
# One predicted class index per row, written in sample_submission order.
labels = pred.tolist()  # plain Python ints, copied off the device
df['label'] = labels
df.to_csv('./submission.csv', index=False)
df.head()