In [None]:
# !pip install timm

In [None]:
import os
import albumentations as A
from albumentations.core.composition import Compose
from albumentations.pytorch import ToTensorV2
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
import cv2
import numpy as np

import timm
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import Dataset, DataLoader

# pytorch lightning
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning import Callback
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

from sklearn import metrics, model_selection
%matplotlib inline


In [None]:
class CFG:
    """Central configuration for the notebook (read as class attributes)."""
    seed = 42                # used for the train/valid split
    # Kept in sync with the backbone that is actually trained; this value also
    # names the CSV log directory.
    model_name = 'resnet18'
    pretrained = True
    img_size = 256           # square side length fed to the network
    num_classes = 5
    lr = 1e-4                # initial Adam learning rate
    # Floor of the cosine schedule (passed as eta_min). It must be below `lr`;
    # with the previous value of 1e-3 the schedule increased the learning rate.
    min_lr = 1e-6
    t_max = 20               # T_max of CosineAnnealingLR
    num_epochs = 10
    batch_size = 64
    accum = 1                # accumulate_grad_batches
    n_fold = 5
    precision = 16           # Trainer precision
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Display the device selected in the configuration.
CFG.device

In [None]:
# Load the training table; later cells read its `image_id` and `label` columns.
df = pd.read_csv('../input/cassava-leaf-disease-classification/train.csv')

In [None]:
# Peek at the first rows of the training table.
df.head()

In [None]:
# Number of rows per label value.
df.label.value_counts()

In [None]:
# Stratified 90/10 split on the label column; both parts get a fresh 0..n-1
# index so they can be addressed positionally afterwards.
df_train, df_valid = (
    part.reset_index(drop=True)
    for part in model_selection.train_test_split(
        df,
        test_size=0.1,
        random_state=CFG.seed,
        stratify=df.label.values,
    )
)

In [None]:
# Shapes of the train and validation tables after the split.
df_train.shape, df_valid.shape

In [None]:
# Directory holding the training images; file names come from `image_id`.
image_path = '../input/cassava-leaf-disease-classification/train_images'

train_image_paths = [
    os.path.join(image_path, image_id) for image_id in df_train['image_id'].values
]
valid_image_paths = [
    os.path.join(image_path, image_id) for image_id in df_valid['image_id'].values
]

In [None]:
# Sanity check: first few training image paths.
train_image_paths[:5]

In [None]:
# Label arrays aligned row-for-row with the image path lists.
train_targets, valid_targets = df_train['label'].values, df_valid['label'].values

In [None]:
class CassavaLeafDataset(Dataset):
    """Dataset that reads images from disk with OpenCV and pairs them with targets.

    Args:
        image_paths: sequence of image file paths.
        targets: sequence of targets, indexed in parallel with ``image_paths``.
        transform: optional callable invoked as ``transform(image=img)`` that
            returns a mapping with an ``'image'`` entry (the albumentations
            calling convention). When ``None`` the decoded image is used as is.

    Each item is a dict with:
        ``'image'``: float32 array, axes moved from (H, W, C) to (C, H, W) and
            values divided by 255.
        ``'targets'``: the matching entry of ``targets``.
    """

    def __init__(self, image_paths, targets, transform=None):
        self.image_paths = image_paths
        self.targets = targets
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, item):
        path = self.image_paths[item]
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals a missing/unreadable file by returning None;
            # fail here with the offending path instead of inside cvtColor.
            raise FileNotFoundError(f'Could not read image: {path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # The transform is optional (default None), so only apply it when set.
        if self.transform is not None:
            image = self.transform(image=image)['image']

        # HWC -> CHW, then scale by 1/255.
        image = np.transpose(image, (2, 0, 1)).astype(np.float32)
        image = image / 255.0

        return {
            'image': image,
            'targets': self.targets[item]
        }
        

In [None]:
# Training-time pipeline: random crop resized to the network input size,
# followed by random flips and transposes.
train_aug = A.Compose([
        A.RandomResizedCrop(height=CFG.img_size, width=CFG.img_size),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Transpose(p=0.5)
])

# Validation pipeline: deterministic, so the monitored validation loss is
# comparable between epochs (random flips were removed). The output size
# follows CFG.img_size instead of a hard-coded 256.
valid_aug = A.Compose([
        A.CenterCrop(height=CFG.img_size, width=CFG.img_size, p=1.0),
        A.Resize(CFG.img_size, CFG.img_size),
])

In [None]:
# def get_transform(phase: str):
#     if phase == 'train':
#         return Compose([
#             A.RandomResizedCrop(height=CFG.img_size, width=CFG.img_size),
#             A.HorizontalFlip(p=0.5),
#             A.ShiftScaleRotate(p=0.5),
#             A.RandomBrightnessContrast(p=0.5),
#             A.Normalize(),
#             ToTensorV2(),
#         ])
#     else:
#         return Compose([
#             A.Resize(height=CFG.img_size, width=CFG.img_size),
#             A.Normalize(),
#             ToTensorV2(),
#         ])

In [None]:
# Wrap the path/label lists in datasets, each with its own augmentation pipeline.
train_dataset = CassavaLeafDataset(train_image_paths, train_targets, transform=train_aug)
valid_dataset = CassavaLeafDataset(valid_image_paths, valid_targets, transform=valid_aug)


In [None]:
# valid_dataset[0]['image']

In [None]:
def plot_image(image_dict):
    """Show one dataset sample and print its displayed shape and target.

    Args:
        image_dict: mapping with an ``'image'`` entry (array-like image) and a
            ``'targets'`` entry, as returned by ``CassavaLeafDataset``.

    The dataset emits channel-first images already divided by 255, so the
    image is moved back to channel-last for matplotlib and is only rescaled
    when its values exceed 1 (i.e. it still looks like raw 0..255 pixels).
    """
    image = np.asarray(image_dict['image'])
    target = image_dict['targets']

    # imshow expects (H, W) or (H, W, C); convert a (C, H, W) image.
    channel_first = (
        image.ndim == 3
        and image.shape[0] in (1, 3, 4)
        and image.shape[-1] not in (1, 3, 4)
    )
    if channel_first:
        image = np.transpose(image, (1, 2, 0))

    # Avoid the double division that rendered already-scaled images black.
    if image.size and image.max() > 1.0:
        image = image / 255.0

    plt.figure(figsize=(5, 5))
    print(image.shape)
    print(target)
    plt.imshow(image)

In [None]:
# plot_image(valid_dataset[1])

In [None]:
# plot_image(train_dataset[1])

In [None]:
# Data loaders: the training loader shuffles and drops the last incomplete
# batch; the validation loader keeps dataset order.

train_loader = DataLoader(
    train_dataset,
    batch_size=CFG.batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=2,
    pin_memory=True,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=CFG.batch_size,
    shuffle=False,
    num_workers=2,
)

In [None]:
# # Defining model
# class CustomResNet(nn.Module):
#     def __init__(self, model_name='resnet18', pretrained=False):
#         super().__init__()
#         self.model = timm.create_model(model_name, pretrained=pretrained)
#         in_features = self.model.get_classifier().in_features
#         self.model.fc = nn.Linear(in_features, CFG.num_classes)
        
#     def forward(self, x):
#         x = self.model(x)
#         return x

In [None]:
#     def __init__(self):
#         super().__init__()
#         self.efficient_net = EfficientNet.from_name('efficientnet-b5')
#         self.efficient_net.load_state_dict(torch.load(PRETRAINED_PATH))
# #         self.efficient_net=EfficientNet.from_pretrained('efficientnet-b3',num_classes=CLASSES)
#         in_features=self.efficient_net._fc.in_features
#         self.efficient_net._fc=nn.Linear(in_features,CLASSES)
    
#     def forward(self,x):
#         out=self.efficient_net(x)
#         return out

In [None]:
class CassavaLeafModel(pl.LightningModule):
    """LightningModule that fine-tunes a timm backbone as a CFG.num_classes classifier.

    Args:
        model_name: timm architecture name passed to ``timm.create_model``.
            Defaults to ``'resnet18'``.
        pretrained: forwarded to ``timm.create_model``. Defaults to ``True``.

    Batches are dicts with an ``'image'`` entry (network input) and a
    ``'targets'`` entry (class indices for ``nn.CrossEntropyLoss``).
    """

    def __init__(self, model_name='resnet18', pretrained=True):
        super().__init__()
        # Store the constructor arguments with the checkpoint so that
        # load_from_checkpoint rebuilds the same architecture.
        self.save_hyperparameters()
        # Let timm size the classification head; unlike assigning to `.fc`,
        # this does not depend on the head attribute name of the architecture.
        self.model = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=CFG.num_classes,
        )
        self.metric = pl.metrics.F1(num_classes=CFG.num_classes)
        self.criterion = nn.CrossEntropyLoss()
        self.lr = CFG.lr

    def forward(self, x, *args, **kwargs):
        """Return the backbone's logits for input ``x``."""
        return self.model(x)

    def configure_optimizers(self):
        """Adam on the backbone parameters with a cosine-annealed learning rate."""
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=CFG.t_max, eta_min=CFG.min_lr
        )
        return {'optimizer': self.optimizer, 'lr_scheduler': self.scheduler}

    def _shared_step(self, batch):
        """Compute the loss and F1 score for one batch (train and validation)."""
        image = batch['image']
        target = batch['targets']
        output = self.model(image)
        loss = self.criterion(output, target)
        # The metric is fed hard predictions (arg-max over the class axis).
        score = self.metric(output.argmax(1), target)
        return loss, score

    def training_step(self, batch, batch_idx):
        """Log epoch-level train loss, F1 and current learning rate; return the loss."""
        loss, score = self._shared_step(batch)
        logs = {
            'train_loss': loss,
            'train_f1': score,
            'lr': self.optimizer.param_groups[0]['lr'],
        }
        self.log_dict(logs, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log epoch-level validation loss and F1; return the loss."""
        loss, score = self._shared_step(batch)
        logs = {
            'valid_loss': loss,
            'valid_f1': score,
        }
        self.log_dict(logs, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

In [None]:
# Seed Python, NumPy and PyTorch before the model is built so that weight
# initialisation and the subsequent training run are repeatable.
seed_everything(CFG.seed)
cassava_model = CassavaLeafModel()

In [None]:
logger = CSVLogger(save_dir='logs/', name=CFG.model_name)
# Log only the user-defined settings: skip the dunder entries of the class
# namespace and stringify values that are not plain scalars (e.g. the device).
cfg_hparams = {
    key: value if isinstance(value, (bool, int, float, str)) else str(value)
    for key, value in vars(CFG).items()
    if not key.startswith('__')
}
logger.log_hyperparams(cfg_hparams)

# Keep the checkpoint with the lowest validation loss, plus the last one.
checkpoint_callback = ModelCheckpoint(monitor='valid_loss',
                                     save_top_k=1,
                                     save_last=True,
                                     save_weights_only=True,
                                     filename='checkpoint/{epoch:02d}-{valid_loss:.4f}-{valid_f1:.4f}',
                                     verbose=False,
                                     mode='min')

# The checkpoint callback is passed via `callbacks`; handing a ModelCheckpoint
# instance to `checkpoint_callback=` is deprecated.
trainer = Trainer(max_epochs=CFG.num_epochs,
                 gpus=1,
                 accumulate_grad_batches=CFG.accum,
                 precision=CFG.precision,
                 callbacks=[checkpoint_callback],
                 logger=logger,
                 weights_summary='top',
)

In [None]:
# Train on train_loader and validate on valid_loader (passed positionally:
# model, train dataloader, validation dataloaders).
trainer.fit(cassava_model, train_loader, valid_loader)

In [None]:
test_df_ = pd.read_csv("../input/cassava-leaf-disease-classification/sample_submission.csv")
# Separate name so the training image directory (`image_path`) is not
# overwritten when this cell runs.
test_image_dir = "../input/cassava-leaf-disease-classification/test_images/"
test_image_paths = [os.path.join(test_image_dir, x) for x in test_df_.image_id.values]
# fake targets
test_targets = test_df_.label.values


# Same deterministic crop/resize as validation. No A.Normalize here: the
# dataset already divides pixels by 255 and the training pipeline applies no
# mean/std normalisation, so normalising only at test time would feed the
# network inputs on a different scale than it was trained on.
test_aug = A.Compose([
            A.CenterCrop(CFG.img_size, CFG.img_size, p=1.),
            A.Resize(CFG.img_size, CFG.img_size),
], p=1.)

test_dataset = CassavaLeafDataset(
    image_paths=test_image_paths,
    targets=test_targets,
    transform=test_aug,
)

test_loader = DataLoader(test_dataset, batch_size=CFG.batch_size)

In [None]:
# Fetch the first test sample's image array for manual inspection.
test_img = test_dataset[0]['image']
# test_img

In [None]:
# torch.from_numpy(test_img)

In [None]:
# cassava_model

In [None]:
# Path of the best checkpoint recorded by the ModelCheckpoint callback.
best_checkpoints = trainer.checkpoint_callback.best_model_path
# load_from_checkpoint is a classmethod: call it on the class so no throwaway
# model instance is built first.
pretrained_model = CassavaLeafModel.load_from_checkpoint(checkpoint_path=best_checkpoints)
pretrained_model = pretrained_model.to(CFG.device)
# Inference mode: eval() switches layer behaviour, freeze() disables gradients.
pretrained_model.eval()
pretrained_model.freeze()

In [None]:
# Predict a class index for every test image and write the submission file.
fin_out = []
# Send inputs to whichever device the model lives on rather than assuming CUDA.
model_device = next(pretrained_model.parameters()).device
with torch.no_grad():
    for data in test_loader:
        y_hat = pretrained_model(data["image"].to(model_device))
        y_hat = torch.argmax(y_hat, dim=1)
        fin_out.extend(y_hat.cpu().numpy().tolist())
test_df_["label"] = fin_out
test_df_[["image_id","label"]].to_csv("submission.csv",index=False)
test_df_.head()