# Imports and GPU-Check

First, let's import everything that we'll need:
* PyTorch and its Torchvision library (for image-related data preparation)
* Pandas for reading CSVs
* NumPy for storing arrays
* os for some path-related functions
* PyTorch Lightning for the magic :)

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import time
import torch.nn.functional as F
import torch.nn as nn
from torchvision import models
import pandas as pd
import numpy as np
import os

import pytorch_lightning as pl

Here we're just going to check for CUDA so that we can run on the GPU - I **highly recommend** using a GPU for image-related ML tasks. Training a CNN like ResNet50 involves a huge number of parallel matrix operations that GPUs handle far better than CPUs, so training on the CPU runs *wayyyy* too slowly to be productive.

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# The Briefest Data Exploration Ever
Let's now take a quick look at our data - as we can see, each row has an `image_id`, which is the image's **filename**, and a `label`, which is its **classification** (an integer class index from 0 to 4).

In [None]:
data = pd.read_csv("/kaggle/input/cassava-leaf-disease-classification/train.csv")
data.head()
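
Since `data.head()` only shows a few rows, a natural next check is how many images fall into each class, because a heavily imbalanced label distribution changes how an accuracy number should be read. Here is a minimal sketch using pandas' `value_counts` on a tiny made-up DataFrame (the ids and labels are hypothetical, not taken from `train.csv`); on the real data the same call would be `data["label"].value_counts()`.

```python
import pandas as pd

# Hypothetical stand-in for train.csv: these ids and labels are made up for illustration
toy = pd.DataFrame({
    "image_id": ["a.jpg", "b.jpg", "c.jpg", "d.jpg", "e.jpg"],
    "label": [3, 3, 1, 0, 3],
})

# Number of images per class, most frequent class first
counts = toy["label"].value_counts()
# Fraction of the dataset that each class makes up
fractions = toy["label"].value_counts(normalize=True)

print(counts)
print(fractions)
```
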

# Data Preparation: Dataset and DataModule Creation
Here's where we start getting to the good stuff.

First, we'll build a custom `CassavaDataset` class on top of PyTorch's `Dataset` class. A map-style `Dataset` must implement `__getitem__` and usually also implements `__len__` (which the default `DataLoader` relies on), and we'll add an `__init__` method to set each instance up. All we're doing in `__init__` is reading the correct `.csv` file depending on whether we need the training or test dataset, and then storing the image IDs and labels in arrays along with other information we'll need (the image directory and the transforms to apply to each image).

Then, in `__getitem__` we actually load the image at the requested index, apply any transforms, and return it together with its label.
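
To see what `__len__` and `__getitem__` are used for, here is a minimal sketch with a hypothetical toy dataset (`SquaresDataset` is made up for illustration): with default settings, a `DataLoader` uses `__len__` to know which indices exist, calls `__getitem__` for each index, and stacks the returned `(input, label)` pairs into batches.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: item i is (tensor([i]), i * i)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # The default sampler iterates over range(len(dataset))
        return self.n

    def __getitem__(self, idx):
        # Return one (input, label) pair, like CassavaDataset does
        return torch.tensor([float(idx)]), idx * idx


loader = DataLoader(SquaresDataset(10), batch_size=4)
batches = list(loader)
x0, y0 = batches[0]
print(len(batches), x0.shape, y0)
```

With 10 items and `batch_size=4`, the loader yields three batches; inputs are stacked into a `(4, 1)` tensor, integer labels are collated into a tensor, and the last batch holds only the 2 leftover items.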

In [None]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image

class CassavaDataset(Dataset):
    """ Cassava Dataset """
    
    def __init__(self, root_dir, transform=None, stage=None):
        if stage:
            # A stage (e.g. "test") was passed, so load the test images
            csv_output = pd.read_csv(os.path.join(root_dir, "sample_submission.csv"))
            self.images_dir = os.path.join(root_dir, "test_images")
        else:
            csv_output = pd.read_csv(os.path.join(root_dir, "train.csv"))
            self.images_dir = os.path.join(root_dir, "train_images")
        self.image_urls = np.asarray(csv_output["image_id"])
        self.labels = np.asarray(csv_output["label"])
        self.transform = transform
        
    def __len__(self):
        return len(self.image_urls)
    
    def __getitem__(self, idx):
        # Get and load image
        image_path = os.path.join(self.images_dir, self.image_urls[idx])
        # Force 3-channel RGB so every image has the same number of channels
        image = Image.open(image_path).convert("RGB")
        # Perform transforms if any
        if self.transform:
            image = self.transform(image)
        # Get label
        label = self.labels[idx]
        return image, label

This is where we start to use PyTorch Lightning. We'll build a `LightningDataModule` here, which lets us easily construct the PyTorch DataLoaders we'll use during training. In the `setup` method, we use the `CassavaDataset` class defined above to load `train.csv` and split it 70/30 into training and validation sets (with a fixed seed, so the split is identical every time `setup` runs), and we build the test set from the separate `test_images` folder. Each set is then returned as a DataLoader by its respective method.
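
The 70/30 split arithmetic and the role of the fixed seed can be checked on a toy stand-in: below, a plain list of ten items (hypothetical, in place of `CassavaDataset`) is split twice with the same seeded generator, the same way `setup` calls `random_split`.

```python
import math

import torch
from torch.utils.data import random_split

full = list(range(10))  # hypothetical stand-in for CassavaDataset
train_len = math.floor(len(full) * 0.7)  # 7
val_len = len(full) - train_len          # 3

# Two splits with identically seeded generators
train_a, val_a = random_split(full, [train_len, val_len], generator=torch.Generator().manual_seed(42))
train_b, val_b = random_split(full, [train_len, val_len], generator=torch.Generator().manual_seed(42))

print(list(train_a.indices), list(val_a.indices))
```

Because the generator is seeded, calling `setup` more than once (for example by both `trainer.tune` and `trainer.fit`) reproduces the same split, so validation images don't leak into the training set.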

In [None]:
from torch.utils.data import random_split
import math

class CassavaDataModule(pl.LightningDataModule):
    """ Cassava DataModule for Lightning """
    def __init__(self, root_dir, transform=None, batch_size=32):
        super().__init__()
        self.batch_size = batch_size
        self.root_dir = root_dir
        self.transform = transform
        
    def setup(self, stage=None):
        cassava_full = CassavaDataset(self.root_dir, self.transform)
        train_data_len = math.floor(len(cassava_full) * 0.7)
        val_data_len = len(cassava_full) - train_data_len
        # Create train and validation datasets
        self.cassava_train, self.cassava_val = random_split(cassava_full, [train_data_len, val_data_len], generator=torch.Generator().manual_seed(42))
        
        # Create test dataset
        self.cassava_test = CassavaDataset(self.root_dir, self.transform, stage="test")
        
    def train_dataloader(self):
        # Shuffle so each epoch sees the training images in a different order
        return DataLoader(self.cassava_train, batch_size=self.batch_size, shuffle=True)
    
    def val_dataloader(self):
        return DataLoader(self.cassava_val, batch_size=self.batch_size)
    
    def test_dataloader(self):
        return DataLoader(self.cassava_test, batch_size=self.batch_size)

Finally, we just need to define the transforms we'll be using, and then create our DataModule!

The transforms are as follows:
* Resize each image to 224x224, the input size that our pretrained ResNet model (defined later) was trained on
* Convert the image to a Tensor so that PyTorch can handle it (this also scales pixel values from 0-255 down to [0, 1])
* Normalize images using either standard normalization or ImageNet normalization

Both normalization techniques shift and rescale each color channel so that pixel values across images have a similar distribution, which helps training (read more [here](https://becominghuman.ai/image-data-pre-processing-for-neural-networks-498289068258)). Standard normalization uses a mean and standard deviation of 0.5 for every channel, which maps pixel values from [0, 1] to [-1, 1]. ImageNet normalization instead uses the per-channel mean and standard deviation of the ImageNet training images, so incoming images have a similar distribution to the images our pretrained ResNet saw during its original training (this is the normalization torchvision's pretrained models expect). Test both of these and see which one works best for you!
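
The arithmetic behind `transforms.Normalize` is a per-channel `(x - mean) / std`. This small NumPy sketch (the helper `normalize` is hypothetical, not a torchvision function) applies both sets of statistics to a black and a white pixel whose RGB values are already in [0, 1], as `ToTensor` would produce.

```python
import numpy as np


def normalize(pixel, mean, std):
    # Same per-channel formula Normalize applies: (x - mean) / std
    return (pixel - np.asarray(mean)) / np.asarray(std)


# Hypothetical single pixels: 3 channel values (R, G, B) in [0, 1]
black = np.array([0.0, 0.0, 0.0])
white = np.array([1.0, 1.0, 1.0])

standard = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
imagenet = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

# Standard stats map 0 -> -1 and 1 -> +1 in every channel
print(normalize(black, *standard), normalize(white, *standard))
# ImageNet stats give channel-specific ranges, e.g. red spans about -2.12 to 2.25
print(normalize(black, *imagenet), normalize(white, *imagenet))
```
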

In [None]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Standard normalization
    # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),  # ImageNet normalization
    ])

root_dir = "/kaggle/input/cassava-leaf-disease-classification/"
cassava_data = CassavaDataModule(root_dir, transform, batch_size=64)
cassava_data.setup()

# Model Initialization
Great - so we've processed our data and set up a DataModule and its DataLoaders, now what? It's time to construct our model class!

For the model, we'll define just 5 methods on a `LightningModule`. First we have `__init__`, where we set a few things that we'll use for training (the learning rate and the loss function), as well as construct the two parts of our model: a ResNet50 pretrained on ImageNet, plus an additional Linear layer. ResNet50's final layer outputs 1000 scores, one per ImageNet class; the Linear layer takes those 1000 values as features for the Cassava leaves and maps them to our 5 classes. Note that we don't freeze the ResNet50 weights, so they are fine-tuned along with the new layer during training.

The next few methods are used for different parts of training. First, `configure_optimizers` sets the optimizer we'll be using for training; here we've chosen Adam, a popular default that tends to work well without much tuning. Next, we define our `forward` method, which describes how we go from start to finish with our model: from the input image to a classification prediction. You might be wondering - shouldn't there be a backprop function then too? Thankfully, no: PyTorch's autograd computes the gradients automatically from the operations in `forward`, and PyTorch Lightning runs the backward pass and the optimizer step for us, so we don't have to worry about that at all :)
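
Roughly, the backpropagation that Lightning automates for every training batch comes down to the calls below. This is a simplified sketch with a hypothetical tiny model and random data standing in for `ImageNetModel` and a real batch; Lightning's actual loop also handles details such as device placement and logging.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical tiny model and batch (stand-ins for ImageNetModel and real images)
model = nn.Linear(4, 5)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(8, 4)          # 8 "images" with 4 features each
y = torch.randint(0, 5, (8,))  # 8 integer class labels in [0, 5)

before = model.weight.detach().clone()

optimizer.zero_grad()        # clear gradients left over from the previous step
loss = loss_fn(model(x), y)  # forward pass + loss, i.e. what training_step returns
loss.backward()              # autograd fills .grad for every parameter
optimizer.step()             # Adam updates the parameters using those gradients

changed = not torch.equal(before, model.weight)
print(loss.item(), changed)
```
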

Then, the final two methods describe the training and validation process. In both, we receive a batch of images, run them through `forward`, and calculate the loss between the predicted classes and the actual classes. In `training_step` we return this loss, and PyTorch Lightning uses it with our configured optimizer to run backpropagation and update the model. In `validation_step` we only log the loss as `val_loss`, since no weights are updated during validation.

In [None]:
import torchvision.models as models

class ImageNetModel(pl.LightningModule):
    def __init__(self, learning_rate=1e-3):
        super().__init__()
        
        # Set our learning rate
        self.learning_rate = learning_rate
        
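        num_target_classes = 5
        # ImageNet-pretrained ResNet50; its final layer outputs 1000 class scores.
        # (Newer torchvision versions prefer weights=models.ResNet50_Weights.DEFAULT.)
        # The weights are not frozen, and Lightning switches the whole module back
        # to train mode during fitting, so the backbone is fine-tuned as well.
        self.feature_extractor = models.resnet50(pretrained=True)
        self.feature_extractor.eval()
        
        # Map ResNet's 1000 ImageNet scores to our 5 cassava classes
        self.classifier = nn.Linear(1000, num_target_classes)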
        
        # Create loss function
        self.loss_fn = torch.nn.CrossEntropyLoss()
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer
    
    def forward(self, input_data):
        representations = self.feature_extractor(input_data)
        preds = self.classifier(representations)
        return preds
    
    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        predictions = self.forward(x)
        loss = self.loss_fn(predictions, y)
        return loss
    
    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        predictions = self.forward(x)
        loss = self.loss_fn(predictions, y)
        self.log('val_loss', loss)

## Initializing our Model & Trainer
Now that we've defined our model class, we're done with setup! Yes, it really was that easy.

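From here, it's time to create an instance of the model and an instance of PyTorch Lightning's `Trainer` (which is where the magic happens). Note that we set the `auto_lr_find` parameter to `True`: when we then call `trainer.tune`, Lightning runs a short learning-rate range test (a few training steps with an exponentially increasing learning rate), picks the learning rate at which the loss dropped most steeply, and writes it to our model's `learning_rate` attribute for use during training.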
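
The idea behind a learning-rate range test can be sketched in a few lines of NumPy: record the loss at each step while the learning rate grows exponentially, then suggest the learning rate where the loss was falling most steeply. The helpers, the assumed 1e-8 to 1 range, and the 7-step loss curve below are hypothetical, and the sketch omits details of Lightning's own implementation such as loss smoothing.

```python
import numpy as np


def exponential_lrs(min_lr, max_lr, num_steps):
    # Learning rates spaced evenly on a log scale from min_lr to max_lr
    return np.geomspace(min_lr, max_lr, num_steps)


def suggest_lr(lrs, losses):
    # Most negative slope of the loss curve = where the loss falls fastest
    slopes = np.gradient(np.asarray(losses))
    return lrs[int(np.argmin(slopes))]


# Made-up losses: flat, then a sharp drop, then divergence at large learning rates
lrs = exponential_lrs(1e-8, 1.0, 7)
losses = [1.0, 0.99, 0.9, 0.5, 0.45, 0.44, 2.0]
best = suggest_lr(lrs, losses)
print(best)
```
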

In [None]:
model = ImageNetModel()

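# `gpus` and `auto_lr_find` are PyTorch Lightning 1.x arguments (both removed in 2.0).
# max_epochs=1 trains for the single epoch described below (1.x defaults to 1000).
trainer = pl.Trainer(gpus=1 if torch.cuda.is_available() else 0, max_epochs=1, auto_lr_find=True)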

## Tune the Learning Rate

In [None]:
trainer.tune(model, cassava_data)

## Fitting our Model to Our Data
Now, we just fit the PyTorch LightningModule-based model to our previously-defined DataModule using the Trainer - and then we have our results! Here we only ran it for one epoch (which means that our model will see each image in the dataset only *once*), but you can definitely run it for more (and should!) to get better results.
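
To make "one epoch" concrete: one epoch is one pass over the training split, which means one optimizer step per batch. Here is a small sketch of that arithmetic, using the 70% split and `batch_size=64` from the DataModule above and a hypothetical dataset size (the helper `batches_per_epoch` is made up for illustration).

```python
import math


def batches_per_epoch(num_images, train_fraction=0.7, batch_size=64):
    # Same split arithmetic as CassavaDataModule.setup
    train_len = math.floor(num_images * train_fraction)
    # The last batch may be partial, so round up
    return math.ceil(train_len / batch_size)


# Hypothetical dataset size, for illustration only
print(batches_per_epoch(1000))
```

The last, partial batch still counts as a step because `DataLoader` does not drop it by default (`drop_last=False`).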

Additionally, the Trainer writes training logs and saves a checkpoint of our model at the end of each epoch (by default keeping only the most recent one), which allows us to easily inspect our results without having to mess around with the "administrative" side of ML. Isn't it great that we can just focus on the data and model without having to worry about tangential things like logging?

In [None]:
trainer.fit(model, cassava_data)

# Evaluating our Performance
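Finally, here's a simple way for us to check our model's accuracy at the end on the validation and test sets. Keep in mind that the test "labels" come from `sample_submission.csv`, which only contains placeholder labels (Kaggle keeps the real test labels hidden), so the test number is not a meaningful accuracy; the validation accuracy is the one to look at.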

Considering we only ran our model for one epoch, I'd say this is pretty good!
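
The accuracy check below picks, for each image, the class with the highest score and compares it with the true label (`torch.max(preds, 1)` returns both the maximum values and their indices, and the indices are the predicted classes). Here is the same logic in NumPy on made-up scores for four images.

```python
import numpy as np

# Hypothetical model scores (logits) for 4 images over 5 classes, plus true labels
logits = np.array([
    [2.0, 0.1, 0.3, 0.0, 0.5],
    [0.2, 0.1, 3.0, 0.4, 0.0],
    [0.0, 1.5, 0.2, 0.1, 2.5],
    [0.3, 0.2, 0.1, 4.0, 0.0],
])
labels = np.array([0, 2, 1, 3])

predicted = logits.argmax(axis=1)  # index of the highest score = predicted class
accuracy = (predicted == labels).sum() / len(labels)
print(predicted, accuracy)
```

Here the third image is misclassified (class 4 instead of 1), so the accuracy is 3/4.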

In [None]:
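def evaluate_results(loader):
    """Return the fraction of correctly classified images in `loader`."""
    # Make sure the model lives on the same device as the batches we send it
    model.to(device)
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            preds = model(x)
            # torch.max over dim 1 returns (max values, indices); the indices are the predicted classes
            _, predicted = torch.max(preds, 1)

            correct += (predicted == y).sum().item()
            total += len(y)
    return correct / total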

In [None]:
# Check the Validation Results
validation_loader = cassava_data.val_dataloader()
evaluate_results(validation_loader)

In [None]:
test_loader = cassava_data.test_dataloader()
evaluate_results(test_loader)

## Finding Our Saved Model
Finally, we can confirm that PyTorch Lightning did indeed save a checkpoint of our model after the first epoch:

In [None]:
for dirname, _, filenames in os.walk('/kaggle/working/lightning_logs/version_0/checkpoints'):
    for filename in filenames:
        print(os.path.join(dirname, filename))