In [None]:
import sys

# Directories holding the tez and efficientnet-pytorch sources; both are
# imported further down, so they must be on the module search path first.
tez_path = '../input/tez-lib/'
effnet_path = '../input/efficientnet-pytorch/'

# Appended (not prepended), tez first, so existing entries keep priority.
sys.path.extend([tez_path, effnet_path])

In [None]:
import os
import albumentations
import pandas as pd
import numpy as np

import tez
from tez.datasets import ImageDataset

import torch
import torch.nn as nn
from torch.nn import functional as F

from efficientnet_pytorch import EfficientNet

In [None]:
class LeafModel(tez.Model):
    """EfficientNet-B4 backbone followed by dropout and a linear classifier.

    Args:
        num_classes: number of output units of the final linear layer.

    The submodule attribute names (``effnet``, ``dropout``, ``out``) are
    presumably what a saved checkpoint is keyed on, so they must not be
    renamed -- TODO confirm against the checkpoint being loaded.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Built by architecture name only; weights are expected to be
        # supplied separately (e.g. via a later ``load`` call).
        self.effnet = EfficientNet.from_name("efficientnet-b4")
        self.dropout = nn.Dropout(0.1)
        # 1792 must equal the channel count of the pooled backbone features.
        self.out = nn.Linear(1792, num_classes)
        self.step_scheduler_after = "epoch"

    def forward(self, image, targets=None):
        """Return ``(outputs, None, None)`` for a 4-D image batch.

        ``targets`` is accepted but never used; the second and third
        elements of the returned tuple are always ``None``.
        """
        # Four-way unpacking doubles as a check that `image` is 4-D.
        n_samples, _channels, _height, _width = image.shape

        feature_map = self.effnet.extract_features(image)
        # Global average pool to 1x1, then flatten to (n_samples, features).
        pooled = F.adaptive_avg_pool2d(feature_map, 1)
        flat = pooled.reshape(n_samples, -1)
        logits = self.out(self.dropout(flat))
        return logits, None, None

In [None]:
# augmentations taken from: https://www.kaggle.com/khyeh0719/pytorch-efficientnet-baseline-inference-tta
# The pipeline is random on purpose: each pass over the test set sees a
# different crop/flip/colour jitter of every image.
_tta_steps = [
    # Geometric transforms.
    albumentations.RandomResizedCrop(256, 256),
    albumentations.Transpose(p=0.5),
    albumentations.HorizontalFlip(p=0.5),
    albumentations.VerticalFlip(p=0.5),
    # Colour transforms.
    albumentations.HueSaturationValue(
        hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5
    ),
    albumentations.RandomBrightnessContrast(
        brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5
    ),
    # Always applied; the mean/std look like the standard ImageNet RGB
    # statistics for 0-255 pixel inputs.
    albumentations.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
        max_pixel_value=255.0,
        p=1.0,
    ),
]
test_aug = albumentations.Compose(_tta_steps, p=1.)

In [None]:
# The sample submission lists every test image id, so it doubles as the
# index of images to predict on.
image_path = "../input/cassava-leaf-disease-classification/test_images/"
dfx = pd.read_csv("../input/cassava-leaf-disease-classification/sample_submission.csv")

test_image_paths = []
for image_id in dfx["image_id"].values:
    test_image_paths.append(os.path.join(image_path, image_id))

# fake targets: placeholder labels taken from the sample submission, passed
# only because the dataset is constructed with a `targets` argument.
test_targets = dfx["label"].values

test_dataset = ImageDataset(
    image_paths=test_image_paths,
    targets=test_targets,
    augmentations=test_aug,
    resize=None,
)

In [None]:
# The training labels are read only to size the classifier head.
train_dfx = pd.read_csv("../input/cassava-leaf-disease-classification/train.csv")
n_classes = train_dfx["label"].nunique()

model = LeafModel(num_classes=n_classes)
model.load("../input/leafmodel/model.bin")

In [None]:
# Test-time augmentation: test_aug is random, so each pass over test_dataset
# scores differently augmented copies of the images. Run several passes and
# average the raw model outputs element-wise.
N_TTA_RUNS = 5  # used for both the loop count and the divisor below

final_preds = None
for _ in range(N_TTA_RUNS):
    # model.predict is iterated as a sequence of per-batch arrays. Collect
    # them and stack once, rather than re-stacking (and re-copying) the
    # growing array after every batch.
    batch_outputs = list(
        model.predict(test_dataset, batch_size=32, n_jobs=-1, device="cuda")
    )
    run_preds = np.vstack(batch_outputs)  # raises ValueError if no batches
    if final_preds is None:
        final_preds = run_preds
    else:
        final_preds += run_preds
final_preds /= N_TTA_RUNS

In [None]:
# Collapse the averaged per-class scores to one class index per row.
# NOTE: this rebinds final_preds, so the cell is not safe to run twice.
final_preds = np.argmax(final_preds, axis=1)

In [None]:
# Overwrite the placeholder labels with the predicted class indices.
dfx["label"] = final_preds

In [None]:
# Write the predictions to disk without the DataFrame index column.
submission_file = "submission.csv"
dfx.to_csv(submission_file, index=False)