# Classifying the dermamnist data set (initial version: 01)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ubern-mia/bender/episode03/dermamnist_v1_initial.ipynb)

This is a naive version of training a simple 4-layer CNN.

First, we (optionally install and) import the libraries we depend on, and choose the type of compute resource that is available.

In [None]:
%pip install medmnist torch torchvision tqdm matplotlib sklearn

In [None]:
import os

import medmnist
from medmnist import INFO

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms

from tqdm import tqdm
import matplotlib.pyplot as plt

from sklearn.metrics import classification_report

# Define the torch.device you will use: CUDA if PyTorch reports it as
# available, otherwise the CPU.
# NOTE(review): the functions visible in this notebook never reference
# `device` (no `.to(device)` calls) — confirm whether moving the model and
# the batches to it was intended.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

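Next, we define a helper function to load the data from the medmnist collection (we use dermamnist, but it should be easy to switch to any of the other flavors available). To do this, change the 'data_flag' variable to, for example, "pathmnist". Note that 'data_flag' is also set in 'evaluate_model', and that the CNN defined below has 7 outputs (the number of dermamnist classes), so both need to be adapted when switching to another data set.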

In [None]:
def load_datasets(flag):
    """
    load_datasets loads the dermamnist data.

    The training split is always loaded; 'flag' selects the second split.
    Both splits are downloaded if necessary, converted to tensors and padded
    by 2 pixels on every border.

    Args:
        flag: takes two options:
            'train': loads the training set as first output, validation set as second.
            'test' : loads the training set as first output, test set as second.

    Returns:
        A tuple (data_train, data_next) of medmnist dataset objects.

    Raises:
        ValueError: if 'flag' is neither 'train' nor 'test'.
    """

    # Validate first, so that a typo fails with a clear message before any
    # download is attempted (it used to surface as an UnboundLocalError on
    # the return statement).
    second_split = {"train": "val", "test": "test"}
    if flag not in second_split:
        raise ValueError(f"flag must be 'train' or 'test', got {flag!r}.")

    data_flag = "dermamnist"
    download = True
    info = INFO[data_flag]

    # Look up the dataset class by the name stored in the medmnist INFO table.
    DataClass = getattr(medmnist, info["python_class"])

    transform_medmnist = transforms.Compose([transforms.ToTensor(), transforms.Pad(2)])

    data_train = DataClass(
        split="train", transform=transform_medmnist, download=download
    )
    data_next = DataClass(
        split=second_split[flag], transform=transform_medmnist, download=download
    )

    return data_train, data_next

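Next, we create a simple class called CNN, which holds our model architecture. In this version, it is a simple four-layer convolutional network made of four blocks of Conv + Batch Norm + ReLU layers, all with 64 filters. The first convolution uses a 5-by-5 kernel and the other three use 3-by-3 kernels; the last convolution has a stride of 2. A global average pooling layer and a single linear layer then produce the class scores.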

In [None]:
class CNN(nn.Module):
    """
    A simple 4 layered CNN to run classification on dermamnist.

    The network consists of four Conv -> BatchNorm -> ReLU blocks with 64
    filters each, a global average pooling and one linear layer that maps the
    64 pooled features to the class scores (logits).

    Args:
        in_channels: number of channels of the input images (default: 3).
        num_classes: number of output classes (default: 7).
    """

    def __init__(self, in_channels: int = 3, num_classes: int = 7):
        """
        Definition of layers in the CNN.
        """

        super().__init__()
        width = 64  # number of filters used by every convolution

        # The layer order inside the Sequential containers determines the
        # state_dict keys ('features.0.weight', ...), so it must stay as is
        # for previously saved checkpoints to load.
        # The shape comments below are (channels, height, width) for a
        # 32-by-32 input image.
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, width, (5, 5), padding=2, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(),
            # (64, 32, 32)
            nn.Conv2d(width, width, (3, 3), padding=1, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(),
            # (64, 32, 32)
            nn.Conv2d(width, width, (3, 3), padding=1, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(),
            # (64, 32, 32)
            nn.Conv2d(width, width, (3, 3), padding=1, stride=2, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(),
            # (64, 16, 16)
        )

        # Average every feature map down to a single value: (64, 1, 1).
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.classifier = nn.Sequential(nn.Linear(width, num_classes))

    def forward(self, in_tensor):
        """
        Forward pass through the CNN.

        Args:
            in_tensor: batch of images of shape (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, num_classes) with the unnormalized class scores.
        """

        in_tensor = self.features(in_tensor)
        in_tensor = self.avgpool(in_tensor)
        # (N, 64, 1, 1) -> (N, 64)
        in_tensor = torch.flatten(in_tensor, 1)
        return self.classifier(in_tensor)

With these components of our model set up, we then write the training and test loops. The training loop here includes setting up the model hyperparameters and the loss function. It also logs the training loss every 50 iterations and the validation accuracy at the end of each epoch. Both curves are saved as .png files, so that they can be analyzed afterwards for performance and potential issues.

In [None]:
def train(output_path: str = None, batch_size: int = 8, num_epochs: int = 100):
    """
    Model training loop, including setting up hyperparameters.

    Trains a freshly initialized CNN on the training split with SGD and a
    cross-entropy loss, evaluates it on the validation split at the end of
    every epoch and keeps the weights of the best validation accuracy.

    Args:
        output_path: folder that receives 'best_model.pt', 'train_loss.png'
            and 'val_acc.png'; it is created if it does not exist. If None,
            a message is printed and nothing is done.
        batch_size: batch size of the training and validation loaders.
        num_epochs: number of passes over the training set.

    Returns:
        None.
    """

    if output_path is None:
        print("output_path needs to be setup. Exiting.")
        return

    os.makedirs(output_path, exist_ok=True)

    data_train, data_val = load_datasets("train")
    # Only the training loader shuffles its batches each epoch.
    loader_train = DataLoader(data_train, batch_size=batch_size, shuffle=True)
    loader_val = DataLoader(data_val, batch_size=batch_size, shuffle=False)

    # Define the model architecture.
    model = CNN()

    # Compute the number of parameters of the model
    num_params = sum(p.numel() for p in model.parameters())
    print(f"Number of parameters: {num_params}")

    # Setup the model training hyperparameters.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.000005, momentum=0.5)
    loss_function = torch.nn.CrossEntropyLoss()

    # Iteration counter and the logged [value, iteration] pairs.
    it = 0
    train_loss = []
    val_acc = []

    # Keep track of the best performance reached so far.
    best_accuracy = 0
    model_file = os.path.join(output_path, "best_model.pt")

    # Number of iterations required in one epoch
    epoch_length = len(loader_train)

    # Repeat training the given number of epochs
    for epoch in range(num_epochs):

        print(f"Starting epoch {epoch + 1}/{num_epochs}...")

        # Run one epoch
        for batch in tqdm(loader_train):

            it += 1

            # Evaluation switches the model to eval mode, so the training
            # state is (re)set at every iteration.
            model.train()
            inputs = batch[0]
            # reshape(-1) instead of squeeze(): squeeze() would also drop the
            # batch dimension of a batch holding a single sample, which the
            # loss function then rejects.
            labels = batch[1].reshape(-1).long()

            # Zero your gradients for every batch!
            optimizer.zero_grad()

            # Make predictions for this batch
            outputs = model(inputs)

            # Compute the loss and its gradients
            loss = loss_function(outputs, labels)
            loss.backward()

            # Adjust learning weights
            optimizer.step()

            # Log the training loss once every 50 iterations. Only the Python
            # float is stored, so that the tensor and its graph can be freed.
            if (it % 50) == 0:
                train_loss.append([loss.item(), it])

            # Run validation and save the model once every epoch.
            # This lives inside the inner loop so that validation could be
            # run more than once per epoch by changing the condition.
            if (it % epoch_length) == 0:

                # No forward pass on validation data is done here: the model
                # is still in training mode at this point, so it would update
                # the batch norm statistics with validation samples.
                metrics = evaluate_model(model, loader_val)

                accuracy = metrics["accuracy"]
                val_acc.append([accuracy, it])

                if accuracy > best_accuracy:
                    best_accuracy = accuracy
                    # Save the model to a `model_file`.
                    torch.save(model.state_dict(), model_file)

                print(f"Current accuracy is {accuracy}, and best is: {best_accuracy}.")

    plt.figure()
    plt.plot(
        [step for _, step in train_loss],
        [value for value, _ in train_loss],
    )
    plt.xlabel("Iterations")
    plt.ylabel("Training loss")
    plt.grid()
    plt.savefig(os.path.join(output_path, "train_loss.png"))

    plt.figure()
    plt.plot([step for _, step in val_acc], [value for value, _ in val_acc])
    plt.xlabel("Iterations")
    plt.ylabel("Validation accuracy")
    plt.grid()
    plt.savefig(os.path.join(output_path, "val_acc.png"))

The test function is simpler: it requires a model that has already been trained, and uses the same 'load_datasets' function written earlier to load the test set and evaluate the generalization capacity of our model. The evaluation is done by another helper function called 'evaluate_model', which is reused in the training loop as well! :-)

In [None]:
def test(model_path: str=None):
    """
    Test model after training.

    Loads the weights stored in 'best_model.pt' inside 'model_path' into a
    new CNN, prints the parameter shapes and the parameter count, and
    evaluates the model on the test split.

    Args:
        model_path: folder that contains 'best_model.pt'. If None, a message
            is printed and nothing is done.

    Returns:
        None.
    """

    # Check the argument first, so that no model is built when there is
    # nothing to load.
    if model_path is None:
        print("model_path needs to be specified, which includes 'best_model.pt'.")
        return

    model = CNN()

    # map_location="cpu" makes the checkpoint loadable on a machine without
    # the device it was saved from.
    state_dict = torch.load(
        os.path.join(model_path, "best_model.pt"), map_location="cpu"
    )
    model.load_state_dict(state_dict)

    # Print model's state_dict
    print("Model's state_dict:")
    for name, tensor in model.state_dict().items():
        print(name, "\t", tensor.size())

    num_params = sum(p.numel() for p in model.parameters())
    print(f"Number of parameters: {num_params}")

    # Only the second output (the test split) is needed here.
    _, data_test = load_datasets("test")
    loader_test = DataLoader(data_test, batch_size=8, shuffle=False)

    metrics = evaluate_model(model, loader_test)
    print(f"Test accuracy is: {metrics['accuracy']}")

The following function evaluates our model (both during validation and testing): it prints out the classification report and returns the accuracy of the model. Note that it is very important to set the model to 'eval()' mode (so that the batch norm layers use their running statistics instead of the statistics of the current batch) and to use the torch.no_grad() decorator (so that no gradients are tracked), without which bad things can happen ;-).

In [None]:
@torch.no_grad()
def evaluate_model(model: nn.Module, loader: DataLoader):
    """
    Evaluate model while training.

    Runs the model in eval mode over all batches of 'loader', prints a
    per-class classification report and computes the overall accuracy.

    Args:
        model: the network to evaluate; it is left in eval mode.
        loader: data loader yielding (images, labels) batches.

    Returns:
        A dict with the single key 'accuracy' (fraction of correct predictions).

    Raises:
        ValueError: if 'loader' yields no samples.
    """

    data_flag = "dermamnist"
    info = INFO[data_flag]

    # Class ids and names in matching order.
    # NOTE(review): assumes the keys of info["label"] are the class ids as
    # strings ("0", "1", ...), as in the medmnist INFO table — confirm.
    class_ids = [int(key) for key in info["label"]]
    class_names = list(info["label"].values())

    # Evaluate the model with the given data loader.
    model.eval()

    # Feed the images on whatever device the model lives on.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    correct = 0
    total = 0
    metrics = {}

    label_list = []
    pred_list = []

    for data in loader:
        images, labels = data[0], data[1]
        # reshape(-1) instead of squeeze(): squeeze() would turn the labels
        # of a single-sample batch into a 0-d tensor without a size(0).
        labels = labels.reshape(-1).long()

        # calculate outputs by running images through the network
        outputs = model(images.to(model_device))
        # the class with the highest energy is what we choose as prediction
        predicted = outputs.argmax(dim=1).cpu()
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        pred_list.extend(predicted.tolist())
        label_list.extend(labels.tolist())

    if total == 0:
        raise ValueError("Cannot evaluate the model: the loader yielded no samples.")

    # Passing 'labels' keeps the report aligned with 'target_names' even if a
    # class occurs neither in the labels nor in the predictions;
    # zero_division=0 reports 0 for classes that are never predicted without
    # emitting a warning.
    print(
        classification_report(
            label_list,
            pred_list,
            labels=class_ids,
            target_names=class_names,
            digits=4,
            zero_division=0,
        )
    )

    metrics["accuracy"] = correct / total
    return metrics

Finally, choose a mode to run these functions: if 'training_mode' is True, the model is trained on the training set we load from medmnist; otherwise, it is tested on the test set to evaluate its generalization capability. Change 'output_path' to any writable folder on your computer: the best model and the plots of the training loss and validation accuracy are saved there. The metrics while training are otherwise printed on the terminal/console.

In [None]:
# True: train a new model; False: test the model stored in output_path.
training_mode = True
# Folder that is handed to the selected function.
output_path = "./dermamnist_v1"

# Pick the entry point that matches the mode and call it with the folder.
(train if training_mode else test)(output_path)