In [None]:
!pip install tez

In [None]:
import os
import albumentations
import tez
import pandas as pd
import matplotlib.pyplot as plt


from tez.datasets import ImageDataset
from tez.callbacks import EarlyStopping

import torch
import torch.nn as nn

import torchvision

from sklearn import metrics, model_selection

%matplotlib inline

In [None]:
dfx = pd.read_csv('../input/cassava-leaf-disease-classification/train.csv')

In [None]:
dfx.head()

In [None]:
dfx.label.value_counts()

In [None]:
df_train, df_valid = model_selection.train_test_split(dfx, test_size=0.1,random_state=42, stratify=dfx.label.values)

In [None]:
df_train = df_train.reset_index(drop=True)
df_valid = df_valid.reset_index(drop=True)

In [None]:
print (df_train.shape)
print (df_valid.shape)

In [None]:
image_path = '../input/cassava-leaf-disease-classification/train_images/'

train_images_path = [os.path.join(image_path, x) for x in df_train.image_id.values]
valid_images_path = [os.path.join(image_path, x) for x in df_valid.image_id.values]

In [None]:
print (train_images_path[:5])

In [None]:
train_targets = df_train.label.values
valid_targets = df_valid.label.values

In [None]:
train_aug = albumentations.Compose([
            albumentations.RandomResizedCrop(256, 256),
            albumentations.Transpose(p=0.5),
            albumentations.HorizontalFlip(p=0.5),
            albumentations.VerticalFlip(p=0.5),
            albumentations.ShiftScaleRotate(p=0.5),
            albumentations.Normalize(
                mean=[0.485, 0.456, 0.406], 
                std=[0.229, 0.224, 0.225], 
                max_pixel_value=255.0, 
                p=1.0
            )], p=1.)
      
valid_aug = albumentations.Compose([
            albumentations.CenterCrop(256, 256, p=1.),
            albumentations.Resize(256, 256),
            albumentations.Normalize(
                mean=[0.485, 0.456, 0.406], 
                std=[0.229, 0.224, 0.225], 
                max_pixel_value=255.0, 
                p=1.0
            )], p=1.)


train_dataset = ImageDataset(image_paths=train_images_path, 
                             targets=train_targets, resize=None, augmentations=train_aug)

valid_dataset = ImageDataset(image_paths=valid_images_path, 
                             targets=valid_targets, resize=None, augmentations=valid_aug)

In [None]:
train_dataset[0]

In [None]:
class LeafModel(tez.Model):
    """ResNet-18 image classifier packaged as a ``tez.Model``.

    The torchvision ResNet-18 keeps its convolutional trunk; its final fully
    connected layer is replaced by a new linear head with ``num_classes``
    outputs.

    Args:
        num_classes: Number of logits produced by the classification head.
        pretrained: Forwarded to ``torchvision.models.resnet18``.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        self.convnet = torchvision.models.resnet18(pretrained=pretrained)
        # Size the new head from the layer it replaces rather than hard-coding
        # the feature width.
        in_features = self.convnet.fc.in_features
        self.convnet.fc = nn.Linear(in_features, num_classes)
        # Scheduler stepping granularity; presumably read by tez's training
        # loop — confirm against the installed tez version.
        self.step_scheduler_after = "epoch"

    def loss(self, output, target=None):
        """Return the cross-entropy loss, or ``None`` when no target is given."""
        if target is None:
            return None
        return nn.CrossEntropyLoss()(output, target)

    def monitor_metrics(self, output, targets):
        """Return ``{"accuracy": ...}`` for a batch of logits and integer labels."""
        predictions = torch.argmax(output, dim=1).detach().cpu().numpy()
        labels = targets.detach().cpu().numpy()
        # scikit-learn's signature is (y_true, y_pred).
        acc = metrics.accuracy_score(labels, predictions)
        return {
            "accuracy": acc
        }

    def fetch_optimizer(self):
        """Return an Adam optimizer (lr=1e-3) over the network parameters."""
        opt = torch.optim.Adam(self.convnet.parameters(), lr=1e-3)
        return opt

    def fetch_scheduler(self):
        """Return a cosine-annealing warm-restart scheduler on ``self.optimizer``."""
        sch = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=10, T_mult=1, eta_min=1e-6, last_epoch=-1
        )
        return sch

    def forward(self, image, targets=None):
        """Run the network and return ``(logits, loss, metrics)``.

        ``loss`` and ``metrics`` are ``None`` when ``targets`` is not supplied.
        """
        output = self.convnet(image)
        if targets is not None:
            loss = self.loss(output, targets)
            mon_metrics = self.monitor_metrics(output, targets)
            return output, loss, mon_metrics
        return output, None, None

In [None]:
model = LeafModel(num_classes=dfx.label.nunique(), pretrained=True)

In [None]:
x = train_dataset[0]["image"]
y = train_dataset[0]["targets"]
model(x.unsqueeze(0), y.unsqueeze(0))

In [None]:
es = EarlyStopping(
        monitor="valid_accuracy",
        model_path = "model.bin",
        patience = 2,
        mode = "max"
    )

In [None]:
model.fit(train_dataset, 
          valid_dataset=valid_dataset, 
          train_bs=32, 
          valid_bs=64,
          device = "cuda",
          callbacks = [es],
          fp16 = True,
          epochs = 15
         )

In [None]:
model.save("model.bin")

In [None]:
#final_model = model.load('model.bin')

In [None]:
# test_aug = albumentations.Compose(
#     [
#             albumentations.RandomResizedCrop(256,256),
#             albumentations.Transpose(p=0.5),
#             albumentations.HorizontalFlip(p=0.5),
#             albumentations.VerticalFlip(p=0.5)
#     ]
# )

In [None]:
# df_test = pd.read_csv('../input/cassava-leaf-disease-classification/sample_submission.csv')

In [None]:
# test_path = '../input/cassava-leaf-disease-classification/test_images'

In [None]:
# test_image_paths = [os.path.join(test_path, x) for x in df_test.image_id]

In [None]:
# print(test_image_paths)

In [None]:
# test_aug = albumentations.Compose(
#     [
        
#         albumentations.CenterCrop(256,256, p=1.0),
#         albumentations.Resize(256,256),
#         albumentations.Transpose(p=0.5),
#         albumentations.HorizontalFlip(p=0.5),
#         albumentations.VerticalFlip(p=0.5)
        
#     ]

# )
# test_targets = df_test.label.values

In [None]:
# test_dataset = ImageDataset(image_paths=test_image_paths, 
#                              targets=test_targets, resize=None, augmentations=test_aug)

In [None]:
# test_dataset[0]

In [None]:
# preds = model.predict(test_dataset, batch_size=32, n_jobs=-1, device="cuda")
# final_preds = None
# for p in preds:
#     if final_preds is None:
#         final_preds = p
#     else:
#         final_preds = np.vstack((final_preds, p))
# final_preds = final_preds.argmax(axis=1)
# df_test.label = final_preds
# df_test.to_csv("submission.csv", index=False)

In [None]:
# df_test

In [None]:
### Autoencoder -- commit
### Commit this notebook as well
