In [None]:
!pip install tez
!pip install albumentation

In [None]:
import os

import albumentations
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import tez
from tez.datasets import ImageDataset
from tez.callbacks import EarlyStopping

import torch
import torch.nn as nn
import torchvision

from sklearn import metrics, model_selection


In [None]:
# Load the training labels table (the `image_id` and `label` columns are used below).
train_csv_path = '../input/cassava-leaf-disease-classification/train.csv'
dfx = pd.read_csv(train_csv_path)

In [None]:
# Preview the first rows of the training labels table.
dfx.head()

In [None]:
# Class balance: number of rows per distinct label value.
dfx.label.value_counts()

In [None]:
# Hold out 10% of the rows for validation. The split is stratified on the
# label column and uses a fixed seed so it is repeatable.
_split_parts = model_selection.train_test_split(
    dfx,
    stratify=dfx.label.values,
    test_size=0.1,
    random_state=42,
)

# Discard the shuffled original index so each split is indexed from 0.
df_train, df_valid = (part.reset_index(drop=True) for part in _split_parts)

In [None]:
# (rows, columns) of the training split.
df_train.shape

In [None]:
# (rows, columns) of the validation split.
df_valid.shape

In [None]:
# Directory holding the training images referenced by `image_id`.
image_path = "../input/cassava-leaf-disease-classification/train_images/"

# One full path per row, in the same order as the corresponding split.
train_image_paths = [
    os.path.join(image_path, file_name) for file_name in df_train["image_id"].values
]
valid_image_paths = [
    os.path.join(image_path, file_name) for file_name in df_valid["image_id"].values
]

In [None]:
# Peek at the first five training image paths.
train_image_paths[:5]

In [None]:
# Label arrays, aligned row-for-row with the image path lists built from the same splits.
train_targets, valid_targets = df_train["label"].values, df_valid["label"].values

In [None]:
# Training dataset without any augmentation, used below to inspect samples.
train_dataset = ImageDataset(
    image_paths=train_image_paths,
    targets=train_targets,
    augmentations=None,
)

In [None]:
# Inspect the first sample returned by the dataset.
train_dataset[0]

In [None]:
def plot_img(image_dict):
    """Print a sample's target and draw its image on a new 10x10 figure.

    Args:
        image_dict: Mapping with an ``"image"`` tensor and a ``"targets"``
            entry, e.g. one item obtained by indexing the dataset.
    """
    pixels, label = image_dict["image"], image_dict["targets"]
    print(label)
    plt.figure(figsize=(10, 10))
    # Move the first axis last and divide by 255 before display.
    # NOTE(review): presumably a channel-first image with 0-255 values — confirm.
    plt.imshow(pixels.permute(1, 2, 0) / 255)

In [None]:
# Display the first training sample together with its printed target.
plot_img(train_dataset[0])

In [None]:
# Training-time augmentation: a random crop resized to 256x256, followed by a
# random transpose and random horizontal / vertical flips.
train_aug = albumentations.Compose(
    [
        albumentations.RandomResizedCrop(256, 256),
        albumentations.Transpose(p=0.5),
        albumentations.HorizontalFlip(p=0.5),
        albumentations.VerticalFlip(p=0.5),
    ]
)

# Validation preprocessing must be deterministic. Random crops and flips here
# would make the validation accuracy (which early stopping monitors) change
# from run to run on identical weights, so only the fixed centre crop is kept.
valid_aug = albumentations.Compose(
    [
        albumentations.CenterCrop(256, 256, p=1.0),
    ]
)

# Rebuild the training dataset, this time with augmentation enabled.
train_dataset = ImageDataset(
    image_paths=train_image_paths,
    targets=train_targets,
    augmentations=train_aug,
)

valid_dataset = ImageDataset(
    image_paths=valid_image_paths,
    targets=valid_targets,
    augmentations=valid_aug,
)

In [None]:
# Display the first training sample again, now drawn from the augmented dataset.
plot_img(train_dataset[0])

In [None]:
class LeafModel(tez.Model):
    """ResNet-18 image classifier wrapped as a ``tez.Model``.

    The backbone's final fully connected layer (``fc``) is replaced by a linear
    layer with ``num_classes`` outputs.

    Args:
        num_classes: Number of target classes (width of the output layer).
        pretrained: Forwarded to ``torchvision.models.resnet18``.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        self.convnet = torchvision.models.resnet18(pretrained=pretrained)
        # Size the new head from the backbone instead of hard-coding 512.
        self.convnet.fc = nn.Linear(self.convnet.fc.in_features, num_classes)
        self.step_scheduler_after = "epoch"

    def loss(self, outputs, targets):
        """Return the cross-entropy loss, or ``None`` when no targets are given."""
        if targets is None:
            return None
        return nn.CrossEntropyLoss()(outputs, targets)

    def monitor_metrics(self, outputs, targets):
        """Return ``{"accuracy": ...}`` computed from logits and true labels."""
        predicted = torch.argmax(outputs, dim=1).cpu().detach().numpy()
        expected = targets.cpu().detach().numpy()
        return {"accuracy": metrics.accuracy_score(expected, predicted)}

    # Backward-compatible alias for the original (misspelled) method name.
    monitor_matrics = monitor_metrics

    def fetch_optimizer(self):
        """Return an Adam optimizer over all model parameters (lr=1e-3)."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def fetch_scheduler(self):
        """Return a StepLR scheduler bound to ``self.optimizer``.

        The method was previously spelled ``fetch_sheduler``, so the
        framework's ``fetch_scheduler`` hook was never overridden and this
        scheduler was never created.
        """
        # NOTE(review): the original passed 0.7 as ``step_size``, which must be
        # an integer number of epochs. 0.7 is treated here as the decay factor
        # (gamma), applied every epoch — confirm the intended schedule.
        return torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=1, gamma=0.7)

    # Backward-compatible alias for the original (misspelled) method name.
    fetch_sheduler = fetch_scheduler

    def forward(self, image, targets=None):
        """Run the backbone and, when targets are given, compute loss and metrics.

        Returns:
            A tuple ``(outputs, loss, metrics)``; ``loss`` and ``metrics`` are
            ``None`` when ``targets`` is ``None``.
        """
        outputs = self.convnet(image)
        if targets is not None:
            loss = self.loss(outputs, targets)
            mon_metrics = self.monitor_metrics(outputs, targets)
            return outputs, loss, mon_metrics
        return outputs, None, None

In [None]:
# One output unit per distinct label value found in the full training table.
n_classes = dfx.label.nunique()
model = LeafModel(num_classes=n_classes, pretrained=True)

In [None]:
# Smoke-test the forward pass on one training sample. The sample is fetched
# once, so the image is loaded and augmented a single time and `img` / `y`
# come from the same dataset access.
sample = train_dataset[10]
img = sample["image"]
y = sample["targets"]

# Add a leading batch dimension of size 1 to both inputs before calling the model.
model(img.unsqueeze(0), y.unsqueeze(0))

In [None]:
# Early-stopping callback: monitors "valid_accuracy" in "max" mode with a
# patience of 2, and is given "model.bin" as its model path.
es = EarlyStopping(
    monitor="valid_accuracy",
    mode="max",
    patience=2,
    model_path="model.bin",
)

# Train for up to 10 epochs on the GPU with fp16 enabled.
model.fit(
    train_dataset,
    valid_dataset=valid_dataset,
    train_bs=32,
    valid_bs=64,
    epochs=10,
    device='cuda',
    fp16=True,
    callbacks=[es],
)

In [None]:
# The sample submission lists the test image ids.
# NOTE(review): its `label` column is presumably only a placeholder target so
# the dataset can be constructed — confirm.
test_dfx = pd.read_csv('../input/cassava-leaf-disease-classification/sample_submission.csv')
image_path = "../input/cassava-leaf-disease-classification/test_images/"
test_targets = test_dfx.label.values

test_image_paths = [
    os.path.join(image_path, x) for x in test_dfx.image_id.values
]

# Inference preprocessing must be deterministic. Random crops and flips would
# make the submitted predictions differ between runs on identical weights, so
# only the fixed centre crop is applied.
test_aug = albumentations.Compose(
    [
        albumentations.CenterCrop(256, 256, p=1.0),
    ]
)

test_dataset = ImageDataset(
    image_paths=test_image_paths,
    targets=test_targets,
    augmentations=test_aug,
)

In [None]:
# NOTE(review): inference uses the weights currently held by `model` (its state
# when training ended), which may differ from whatever the EarlyStopping
# callback wrote to "model.bin" — confirm which weights are intended.
pred = model.predict(
    test_dataset,
    batch_size=64,
    n_jobs=-1,
)

# Stack the per-batch outputs in one call instead of growing an array with
# repeated vstack calls, which re-copies all earlier batches each iteration.
final_preds = np.vstack(list(pred))

# Predicted class = index of the largest score in each row.
final_preds = final_preds.argmax(axis=1)
test_dfx["label"] = final_preds
test_dfx.to_csv("submission.csv", index=False)