# Importing Libraries

In [None]:
import os
import sys
sys.path.append("../input/efficientnetpytorch") #for efficient model

#Basic libraries to read and split the CSV data
import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedKFold

#albumentations works on cv2 (NumPy) images, whereas torchvision transforms use PIL
import cv2
import albumentations as A
from albumentations.pytorch.transforms import ToTensor

#PyTorch - deep learning framework
import torch 
from torch import nn
from torch.utils.data import Dataset,DataLoader
import torch.nn.functional as F

#pytorch-lightning on top of PyTorch framework
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning.callbacks import ModelCheckpoint 

#EfficientNet backbone for transfer learning
from efficientnet_pytorch import EfficientNet

# Constants

In [None]:
# Folder holding the training images; image file names are appended to this prefix.
IMAGES_DIRS = "../input/cassava-leaf-disease-classification/train_images/"
# CSV file read during fold preparation (columns used: image_id, label).
TRAIN_FILE = "../input/cassava-leaf-disease-classification/train.csv"
# Local EfficientNet-b3 weight file loaded into the backbone.
PRETRAINED_PATH = "../input/resources-for-google-landmark-recognition-2020/efficientnet-b3-5fb5a3c3.pth"
# Mini-batch size used by the train and validation data loaders.
BATCH_SIZE = 40
# Side length (in pixels) images are resized to before any further transform.
IMG_SIZE = 512
# Number of outputs of the replacement classification layer.
CLASSES = 5

# Lightning Computation Module (Research code)

In [None]:
class CassavaEfficientNet(pl.LightningModule):
  """EfficientNet-b3 classifier wrapped as a LightningModule.

  The backbone is created with ``EfficientNet.from_name`` and its weights are
  read from ``PRETRAINED_PATH``; the final fully connected layer is then
  replaced by a fresh ``nn.Linear`` with ``CLASSES`` outputs.

  Batches are expected to be dicts with the input under ``"x"`` and the
  target under ``"y"``.

  Args:
    lr: learning rate handed to the Adam optimizer (default ``1e-4``).
  """

  def __init__(self, lr=1e-4):
    super().__init__()
    self.lr = lr
    self.efficient_net = EfficientNet.from_name('efficientnet-b3')
    # With internet access, from_name + load_state_dict can be replaced by
    # EfficientNet.from_pretrained('efficientnet-b3', num_classes=CLASSES).
    # map_location="cpu" materialises the weights on the CPU first, so loading
    # does not depend on the device the weight file was saved from.
    state_dict = torch.load(PRETRAINED_PATH, map_location="cpu")
    self.efficient_net.load_state_dict(state_dict)
    in_features = self.efficient_net._fc.in_features
    self.efficient_net._fc = nn.Linear(in_features, CLASSES)

  def forward(self, x):
    """Return the raw class scores (logits) produced by the network for ``x``."""
    return self.efficient_net(x)

  def configure_optimizers(self):
    """Optimise every parameter of the module with Adam at ``self.lr``."""
    return torch.optim.Adam(self.parameters(), lr=self.lr)

  def _shared_step(self, batch):
    """Forward ``batch`` and return ``(cross-entropy loss, accuracy)``."""
    x, y = batch["x"], batch["y"]
    y_hat = self(x)
    loss = F.cross_entropy(y_hat, y)
    acc = accuracy(y_hat, y)
    return loss, acc

  def training_step(self, batch, batch_idx):
    """Compute and log the training metrics; return the loss to optimise."""
    loss, acc = self._shared_step(batch)
    # on_step=False / on_epoch=True: report only the per-epoch aggregate
    # to the progress bar and the logger, not every single step.
    self.log("train_acc", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
    self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
    return loss

  def validation_step(self, batch, batch_idx):
    """Compute and log the validation metrics for one batch."""
    loss, acc = self._shared_step(batch)
    # on_step / on_epoch are left at their defaults here.
    self.log("val_acc", acc, prog_bar=True, logger=True)
    self.log("val_loss", loss, prog_bar=True, logger=True)

## Train Dataset Loader 

In [None]:
class CassavaDataset(Dataset):
  """Labelled image dataset that yields ``{"x": image, "y": label}`` samples.

  Args:
    path: directory that contains the image files.
    image_ids: indexable sequence of image file names.
    labels: indexable sequence of labels, aligned with ``image_ids``.
    transform: callable invoked as ``transform(image=img)`` that returns a
      mapping holding the processed image under the ``"image"`` key
      (the albumentations convention).

  Raises:
    ValueError: if ``image_ids`` and ``labels`` differ in length.
  """

  def __init__(self,path,image_ids,labels,transform):
    super().__init__()
    if len(image_ids) != len(labels):
      raise ValueError(
          f"image_ids ({len(image_ids)}) and labels ({len(labels)}) must have the same length")
    self.image_ids = image_ids
    self.labels = labels
    self.path = path
    self.transform = transform

  def __len__(self):
    """Number of samples (one per image id)."""
    return len(self.image_ids)

  def __getitem__(self,item):
    """Load, colour-convert and transform the image at position ``item``.

    Raises:
      FileNotFoundError: if the image file is missing or cannot be decoded.
    """
    image_id = str(self.image_ids[item])
    image_path = os.path.join(self.path, image_id)
    img = cv2.imread(image_path)
    if img is None:
      # cv2.imread does not raise on a missing/corrupt file, it returns None;
      # fail here with the offending path instead of inside cvtColor.
      raise FileNotFoundError(f"Could not read image: {image_path}")
    # OpenCV decodes to BGR channel order; convert to RGB before transforming.
    img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
    # albumentations transforms return a dictionary with "image" as key
    transformed = self.transform(image=img)
    return {
        "x":transformed["image"],
        "y":self.labels[item],
    }

# Lightning Data Module

In [None]:

class CassavaDataModule(pl.LightningDataModule):
  """Builds stratified folds from ``TRAIN_FILE`` and serves train/val loaders.

  ``prepare_data`` writes the fold assignment to ``FOLDS_FILE``; ``setup``
  reads it back and holds out one fold for validation.

  Args:
    fold: index of the fold used for validation (default ``1``).
    seed: ``random_state`` of the shuffle that precedes the fold split, so
      the fold assignment is identical on every run (default ``42``).

  Raises:
    ValueError: if ``fold`` is not in ``range(N_SPLITS)``.
  """

  N_SPLITS = 5
  FOLDS_FILE = "train_folds.csv"

  def __init__(self, fold=1, seed=42):
    super().__init__()
    if not 0 <= fold < self.N_SPLITS:
      raise ValueError(f"fold must be in [0, {self.N_SPLITS - 1}], got {fold}")
    self.fold = fold
    self.seed = seed
    # NOTE(review): training crops to 318x318 while validation keeps the full
    # IMG_SIZE resolution - confirm this train/val size mismatch is intended.
    self.train_transform = A.Compose([A.Resize(IMG_SIZE,IMG_SIZE),
                                      A.RandomCrop(318,318),
                                      A.HorizontalFlip(),
                                      A.VerticalFlip(),
                                      A.ShiftScaleRotate(),
                                      A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                                      ToTensor()])
    self.test_transform = A.Compose([A.Resize(IMG_SIZE,IMG_SIZE),
                                     A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                                     ToTensor()])

  def prepare_data(self):
    """Shuffle the training CSV, assign stratified folds and save the result."""
    df = pd.read_csv(TRAIN_FILE)
    df["kfold"] = -1
    # Fixed random_state -> same shuffle, hence same folds, on every run.
    df = df.sample(frac=1, random_state=self.seed).reset_index(drop=True)
    stratify = StratifiedKFold(n_splits=self.N_SPLITS)
    splits = stratify.split(X=df.image_id.values, y=df.label.values)
    for i, (_, v_idx) in enumerate(splits):
      df.loc[v_idx, "kfold"] = i
    df.to_csv(self.FOLDS_FILE, index=False)

  def setup(self,stage=None):
    """Create the train and validation datasets from the saved fold file."""
    dfx = pd.read_csv(self.FOLDS_FILE)
    is_val = dfx["kfold"] == self.fold
    train = dfx.loc[~is_val]
    val = dfx.loc[is_val]
    self.train_dataset = CassavaDataset(IMAGES_DIRS,
                                        image_ids = train.image_id.values,
                                        labels = train.label.values,
                                        transform = self.train_transform)
    self.valid_dataset = CassavaDataset(IMAGES_DIRS,
                                        image_ids = val.image_id.values,
                                        labels = val.label.values,
                                        transform = self.test_transform)

  def train_dataloader(self):
    """Shuffled loader over the training folds."""
    return DataLoader(self.train_dataset,
                      batch_size=BATCH_SIZE,
                      num_workers=4,
                      shuffle=True)

  def val_dataloader(self):
    """Loader over the held-out validation fold."""
    return DataLoader(self.valid_dataset,
                      batch_size=BATCH_SIZE,
                      num_workers=4)

Saving the best model (monitored on `val_loss`) as a *.ckpt* checkpoint

In [None]:
# Checkpoint callback keyed on the logged "val_loss" metric; the file name
# embeds the epoch number and the loss value (4 decimals).
model_checkpoint = ModelCheckpoint(
    filename="{epoch}_{val_loss:.4f}",
    monitor="val_loss",
    verbose=True,
)

# Finally- Trainer

In [None]:
# Seed the Python, NumPy and PyTorch random number generators before any
# model or data object is created, so repeated runs are comparable.
pl.seed_everything(42)

dm = CassavaDataModule()
cassava_model = CassavaEfficientNet()

#CPU:default,GPU:gpus,TPU:tpu_cores
trainer = pl.Trainer(gpus=-1,
                     max_epochs=6,
                     callbacks=[model_checkpoint])
trainer.fit(model=cassava_model,
            datamodule=dm)

# Manual snapshot of the trainer state as it is right after fit() returns.
# This is not necessarily the best epoch - the best checkpoint is the one
# tracked by `model_checkpoint`.
trainer.save_checkpoint("cassava_efficient_net.ckpt")

# INFERENCE 

## Setting Inference Data Loader

In [None]:
# Folder holding the images to run inference on; file names are appended to this prefix.
TEST_IMAGE_DIRS = "../input/cassava-leaf-disease-classification/test_images/"
# The sample submission is read for its image_id column, which lists the test images.
test = pd.read_csv("../input/cassava-leaf-disease-classification/sample_submission.csv")


#this is inference dataset object, as it does not have labels
class CassavaTestData(Dataset):
  """Unlabelled image dataset used for inference; yields ``{"x": image}``.

  Args:
    path: directory that contains the image files.
    image_ids: indexable sequence of image file names.
    transform: optional callable invoked as ``transform(image=img)`` that
      returns a mapping with the processed image under ``"image"``. When
      omitted, images are resized to ``IMG_SIZE``, normalised and converted
      to tensors.
  """

  def __init__(self,path,image_ids,transform=None):
    super().__init__()
    self.image_ids = image_ids
    self.path = path
    if transform is None:
      transform = A.Compose([A.Resize(IMG_SIZE,IMG_SIZE),
                             A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                             ToTensor()])
    self.transform = transform

  def __len__(self):
    """Number of images to predict on."""
    return len(self.image_ids)

  def __getitem__(self,item):
    """Load, colour-convert and transform the image at position ``item``.

    Raises:
      FileNotFoundError: if the image file is missing or cannot be decoded.
    """
    image_id = str(self.image_ids[item])
    image_path = os.path.join(self.path, image_id)
    image = cv2.imread(image_path)
    if image is None:
      # cv2.imread does not raise on a missing/corrupt file, it returns None;
      # fail here with the offending path instead of inside cvtColor.
      raise FileNotFoundError(f"Could not read image: {image_path}")
    # OpenCV decodes to BGR channel order; convert to RGB before transforming.
    image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
    image = self.transform(image=image)
    return {
        "x":image["image"],
    }
# Inference dataset over every image id listed in `test`, served in batches
# of 32. No shuffle argument is passed, and the predictions are later written
# back to `test` row by row.
test_dataset = CassavaTestData(TEST_IMAGE_DIRS, test["image_id"].values)
test_loader = DataLoader(dataset=test_dataset, batch_size=32)

## Freeze trained model and predict

In [None]:
# Restore the weights from the checkpoint the callback recorded as best,
# then put the model on the GPU in inference mode.
best_checkpoints = trainer.checkpoint_callback.best_model_path
pretrained_model = CassavaEfficientNet.load_from_checkpoint(checkpoint_path=best_checkpoints).to("cuda")
pretrained_model.eval()
pretrained_model.freeze()

# Collect one predicted class index per test image, batch by batch.
fin_out = []
for data in test_loader:
    logits = pretrained_model(data["x"].to("cuda"))
    y_hat = logits.argmax(dim=1)
    fin_out += y_hat.detach().cpu().numpy().tolist()

# Write the predictions next to their image ids and export the submission.
test["label"] = fin_out
test[["image_id","label"]].to_csv("submission.csv",index=False)
test.head()

In [None]:
#uncomment below lines to view the tensorboard
#%load_ext tensorboard
#%tensorboard --logdir ./lightning_logs