In [None]:
!pip install tez

In [None]:
import os

import albumentations
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import tez
from tez.datasets import ImageDataset
from tez.callbacks import EarlyStopping

import torch
import torch.nn as nn
from torch.nn import functional as F

import torchvision

from sklearn import metrics, model_selection, preprocessing

%matplotlib inline

In [None]:
# Load the labelled training table and hold out a stratified 10% for validation
# (stratifying on `label` keeps the class proportions equal in both parts).
dfx = pd.read_csv('../input/cassava-leaf-disease-classification/train.csv')
df_train, df_valid = (
    part.reset_index(drop=True)
    for part in model_selection.train_test_split(
        dfx,
        test_size=0.1,
        random_state=42,
        stratify=dfx.label.values,
    )
)

# All training and validation images live in a single directory, named by image_id.
image_path = "../input/cassava-leaf-disease-classification/train_images/"

# Full file paths, in the same row order as the corresponding labels below.
train_image_paths, valid_image_paths = (
    [os.path.join(image_path, image_id) for image_id in frame.image_id.values]
    for frame in (df_train, df_valid)
)

# Integer class labels aligned with the path lists above.
train_targets, valid_targets = df_train.label.values, df_valid.label.values

In [None]:
# Datasets without any augmentation, so raw samples can be inspected first.
# Both are rebuilt further down once the augmentation pipelines exist.
train_dataset, valid_dataset = (
    ImageDataset(image_paths=paths, targets=labels, augmentations=None)
    for paths, labels in (
        (train_image_paths, train_targets),
        (valid_image_paths, valid_targets),
    )
)


In [None]:
# Peek at one raw (unaugmented) sample: a dict holding an "image" tensor and its "targets" label.
train_dataset[0]

In [None]:
def plot_image(img_dict):
    """Display one dataset sample, with its label shown in the figure title.

    Args:
        img_dict: sample dict with an "image" tensor laid out channel-first
            (C, H, W) and a single-element "targets" tensor. The division by
            255 below assumes un-normalized 0-255 pixel values.
    """
    image_tensor = img_dict["image"]
    target = img_dict["targets"]
    fig, ax = plt.subplots(figsize=(10, 10))
    # matplotlib expects (H, W, C); scale to [0, 1] for float images.
    ax.imshow(image_tensor.permute(1, 2, 0) / 255)
    ax.set_title(f"label: {target.item()}")
    ax.axis("off")

In [None]:
# Visualise the first training sample as loaded, with no augmentation applied.
plot_image(train_dataset[0])

In [None]:
# Geometric augmentations only. Normalize is deliberately left out of this preview
# pipeline so plot_image (which divides by 255) can still display the result;
# the final training pipeline further down adds normalization.
train_aug = albumentations.Compose(
    transforms=[
        albumentations.RandomResizedCrop(256, 256),  # also fixes the output size to 256x256
        albumentations.Transpose(p=0.5),
        albumentations.HorizontalFlip(p=0.5),
        albumentations.VerticalFlip(p=0.5),
        albumentations.ShiftScaleRotate(p=0.5),
    ]
)

# Sizing is handled by RandomResizedCrop above, so no resize option is passed here.
train_dataset = ImageDataset(
    augmentations=train_aug,
    targets=train_targets,
    image_paths=train_image_paths,
)

In [None]:
# Same index again, now drawn through the random train_aug pipeline (crop, flips,
# shift/scale/rotate), so the picture changes on every call.
plot_image(train_dataset[0])

In [None]:
class LeafModel(tez.Model):
    """ResNet-18 classifier for the cassava leaf images, trained through tez.

    Args:
        num_classes: number of logits produced by the replaced final linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Pretrained torchvision backbone; only the classification head is replaced.
        self.convnet = torchvision.models.resnet18(pretrained=True)
        # Size the new head from the existing one rather than hard-coding 512.
        self.convnet.fc = nn.Linear(self.convnet.fc.in_features, num_classes)
        # Read by tez to decide how often the LR scheduler is stepped.
        self.step_scheduler_after = "epoch"

    def monitor_metrics(self, outputs, targets):
        """Return batch accuracy as {"accuracy": float}, or {} when targets is None.

        Args:
            outputs: logits of shape (batch, num_classes).
            targets: integer class labels of shape (batch,), or None.
        """
        if targets is None:
            return {}
        predicted = torch.argmax(outputs, dim=1).cpu().detach().numpy()
        actual = targets.cpu().detach().numpy()
        return {"accuracy": metrics.accuracy_score(actual, predicted)}

    def fetch_optimizer(self):
        """Adam over all parameters with a fixed initial learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def fetch_scheduler(self, step_size=3, gamma=0.7):
        """StepLR that multiplies the LR by ``gamma`` every ``step_size`` scheduler steps.

        Args:
            step_size: positive integer number of scheduler steps between decays
                (epochs here, given ``step_scheduler_after = "epoch"``).
            gamma: multiplicative decay factor.
        """
        # The original passed step_size=0.7: a float step size makes the StepLR
        # modulo check essentially never hit zero, so the LR never decayed.
        # 0.7 was presumably meant as gamma; step_size=3 is a tunable choice — TODO confirm.
        return torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=step_size, gamma=gamma
        )

    def forward(self, image, targets=None):
        """Run the backbone, plus loss and metrics when targets are supplied.

        Args:
            image: batch of channel-first images, shape (batch, C, H, W).
            targets: optional integer class labels of shape (batch,).

        Returns:
            (outputs, loss, metrics) when targets is given, else (outputs, None, None).

        Raises:
            ValueError: if ``image`` is not a rank-4 batch.
        """
        if image.dim() != 4:
            raise ValueError(f"expected a (batch, C, H, W) image batch, got shape {tuple(image.shape)}")

        outputs = self.convnet(image)
        if targets is None:
            return outputs, None, None

        loss = F.cross_entropy(outputs, targets)
        # Named so it does not shadow the sklearn `metrics` module used above.
        batch_metrics = self.monitor_metrics(outputs, targets)
        return outputs, loss, batch_metrics

In [None]:
# Size the head from the data: labels are integer class ids (the cross-entropy loss
# in LeafModel requires that), so the class count is max label + 1.
num_classes = int(dfx.label.max()) + 1
model = LeafModel(num_classes=num_classes)

In [None]:
# Smoke-test a single forward pass before training.
# Fetch the sample once: each index lookup re-reads and re-augments the image.
sample = train_dataset[0]
image = sample["image"].unsqueeze(0)  # add a batch dimension
target = sample["targets"].unsqueeze(0)

# eval() + no_grad(): keep this check from updating the pretrained BatchNorm running
# statistics (the image here is not normalized) or building an autograd graph.
# The previous train/eval mode is restored even if the forward pass raises.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        smoke_outputs = model(image, target)
finally:
    model.train(was_training)

smoke_outputs

In [None]:
def _imagenet_normalize():
    """Return a fresh Normalize transform with the standard ImageNet channel statistics."""
    return albumentations.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
        max_pixel_value=255.0,
        p=1.0,
    )


# Training: random geometric augmentation, then normalization for the pretrained backbone.
train_aug = albumentations.Compose(
    transforms=[
        albumentations.RandomResizedCrop(256, 256),
        albumentations.Transpose(p=0.5),
        albumentations.HorizontalFlip(p=0.5),
        albumentations.VerticalFlip(p=0.5),
        albumentations.ShiftScaleRotate(p=0.5),
        _imagenet_normalize(),
    ],
    p=1.,
)

# Validation: deterministic 256x256 centre crop, then the same normalization.
# The Resize is a no-op after a 256x256 CenterCrop; it is kept as in the original.
# NOTE(review): training sees RandomResizedCrop views of the whole image, while validation
# sees a native-resolution centre patch — confirm this scale difference is intended.
valid_aug = albumentations.Compose(
    transforms=[
        albumentations.CenterCrop(256, 256, p=1.),
        albumentations.Resize(256, 256),
        _imagenet_normalize(),
    ],
    p=1.,
)

train_dataset, valid_dataset = (
    ImageDataset(image_paths=paths, targets=labels, augmentations=aug)
    for paths, labels, aug in (
        (train_image_paths, train_targets, train_aug),
        (valid_image_paths, valid_targets, valid_aug),
    )
)

In [None]:
# Choose the device at run time so the notebook also runs on CPU-only kernels;
# mixed precision (fp16) is only requested when CUDA is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Early stopping on validation accuracy (higher is better) with patience 2; the
# callback checkpoints to model.bin. "valid_accuracy" relies on the "accuracy" key
# returned by LeafModel.monitor_metrics.
es = EarlyStopping(
    monitor="valid_accuracy", model_path="model.bin", patience=2, mode="max"
)
model.fit(
    train_dataset,
    valid_dataset=valid_dataset,
    train_bs=32,
    valid_bs=64,
    device=device,
    epochs=5,
    callbacks=[es],
    fp16=device == "cuda",
)

In [None]:
# The sample submission lists the test image ids. Its label column only holds
# placeholder ("fake") values, but ImageDataset expects a target for every image.
test_dfx = pd.read_csv("../input/cassava-leaf-disease-classification/sample_submission.csv")

# Separate name so the training image directory (`image_path`) is not silently rebound.
test_image_dir = "../input/cassava-leaf-disease-classification/test_images/"
test_image_paths = [os.path.join(test_image_dir, x) for x in test_dfx.image_id.values]
test_targets = test_dfx.label.values

# Test images must get exactly the same deterministic preprocessing as validation
# images, so reuse that pipeline instead of maintaining a copy of it.
test_aug = valid_aug

test_dataset = ImageDataset(
    image_paths=test_image_paths,
    targets=test_targets,
    augmentations=test_aug,
)


In [None]:
# EarlyStopping checkpointed to model.bin during fit(). Nothing above reloads it,
# so the in-memory model presumably still holds the final-epoch weights; restore
# the checkpoint before predicting.
best_model_path = "model.bin"
if os.path.exists(best_model_path):
    model.load(best_model_path, device="cuda" if torch.cuda.is_available() else "cpu")

preds = model.predict(test_dataset, batch_size=32, n_jobs=-1)

# predict() yields one array of logits per batch; collect them and stack once,
# rather than re-allocating with vstack on every iteration.
batch_logits = list(preds)
if not batch_logits:
    raise ValueError("model.predict produced no batches; is the test dataset empty?")
final_preds = np.vstack(batch_logits).argmax(axis=1)

submission = test_dfx.assign(label=final_preds)
submission.to_csv("submission.csv", index=False)