## Importing necessary libraries

In [None]:
# !pip show matplotlib
# !pip show pandas
# !pip show scikit-learn
# !pip show seaborn

In [None]:
# !pip install matplotlib
# !pip install pandas
# !pip install scikit-learn
# !pip install seaborn
# !pip install torchinfo
# !pip install torchviz
# !pip install graphviz

In [None]:
# Importing PyTorch and its computer vision extension
import torch
import torchvision
import torchvision.models as models
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Importing plotting/visualization and numerical computation libraries
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from torchviz import make_dot
from graphviz import Source

# Importing utilities for random operations and file/directory handling
import random
import os
from copy import deepcopy

# Importing data handling utilities from PyTorch and scikit-learn
from torch.utils.data import Dataset, DataLoader, random_split, Subset
from sklearn.model_selection import train_test_split

# Importing specific data preprocessing and augmentation tools from torchvision
from torchinfo import summary
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor, Compose, Normalize
from torchvision.datasets import MNIST
from PIL import Image

# Importing feature extraction tools and evaluation metrics
from torchvision.models.feature_extraction import create_feature_extractor
from sklearn.metrics import precision_score, recall_score, roc_auc_score, confusion_matrix, multilabel_confusion_matrix, classification_report
from sklearn.preprocessing import label_binarize

In [None]:
# Fix every random number generator the notebook relies on, so that weight initialisation,
# dropout masks, DataLoader shuffling and random augmentations are reproducible on
# Restart & Run All. (The train/val/test split is seeded separately via random_state=42.)
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, CUDA included

# Use GPU if available, else use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

## Data Loading & Pre-processing

In [None]:
# ImageNet per-channel (RGB) statistics used by both pipelines.
# The same values are hard-coded in `show_image` to undo the normalization for display.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
IMAGE_SIZE = (224, 224)

# Original (non-augmented) transformation: fully deterministic, suitable for validation/testing
original_transform = transforms.Compose([
    # Resize input images to a fixed size of 224x224 pixels
    transforms.Resize(IMAGE_SIZE),
    # Convert the PIL image into a float tensor (channels-first, values in [0, 1])
    transforms.ToTensor(),
    # Standardize each channel: subtract the mean and divide by the standard deviation
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Augmented transformation for training: random horizontal flip and small random rotation
augmented_transform = transforms.Compose([
    # Resize input images to a fixed size of 224x224 pixels
    transforms.Resize(IMAGE_SIZE),
    # Flip horizontally with probability 0.5. With p=1.0 *every* image would be mirrored,
    # which is a deterministic transformation rather than augmentation.
    transforms.RandomHorizontalFlip(p=0.5),
    # Randomly rotate images by an angle drawn from [-15, +15] degrees
    transforms.RandomRotation(degrees=15),
    # Convert the augmented PIL image into a float tensor
    transforms.ToTensor(),
    # Standardize each channel exactly as in the original pipeline
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [None]:
# Read the images from the local `dataset` folder with no transform attached, so every
# item is a (PIL image, class index) pair that can be fed to either pipeline by hand.
# When running on Kaggle, use root='/kaggle/input/the-iqothnccd-lung-cancer-dataset' instead.
raw_dataset = datasets.ImageFolder(
    root='dataset',
    transform=None,
)

In [None]:
# Take the very first sample of the untransformed dataset: a raw PIL image and its class index
raw_img, label = raw_dataset[0]

# Push that same PIL image through both pipelines so the two results can be compared side by side
img_original_tensor, img_augmented_tensor = (
    original_transform(raw_img),
    augmented_transform(raw_img),
)

### Showing sample:

In [None]:
def show_image(img, title=None, ax=None, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo channel normalization on an image tensor and display it with matplotlib.

    Args:
        img: channels-first image tensor (C, H, W), e.g. the output of ToTensor + Normalize.
            It may live on any device or require grad; it is detached and copied to the CPU.
        title: optional title for the image.
        ax: optional matplotlib Axes to draw on. If None, the current pyplot figure is used.
        mean, std: per-channel statistics that were passed to Normalize (defaults: ImageNet).
    """
    # (C, H, W) tensor -> (H, W, C) NumPy array, the layout imshow expects.
    # detach()/cpu() make this work for GPU tensors and tensors that track gradients.
    img = img.detach().cpu().numpy().transpose((1, 2, 0))

    # Reverse Normalize: x_norm = (x - mean) / std  =>  x = x_norm * std + mean
    img = np.asarray(std) * img + np.asarray(mean)

    # Clip to the valid float range [0, 1] accepted by imshow
    img = np.clip(img, 0, 1)

    if ax is not None:
        # Draw on the supplied axis and hide its ticks/frame
        ax.imshow(img)
        if title is not None:
            ax.set_title(title)
        ax.axis('off')
    else:
        plt.imshow(img)
        if title is not None:
            plt.title(title)
        # Short pause so interactive backends get a chance to render the figure
        plt.pause(0.001)

In [None]:
# One row of two panels (12 x 6 inches): deterministic pipeline on the left, augmented on the right
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))

show_image(img_original_tensor, ax=axs[0], title='Original Image')
show_image(img_augmented_tensor, ax=axs[1], title='Augmented Image')

# Render the whole two-panel figure
plt.show()

### Data Preprocessing:

In [None]:
# Re-read the local `dataset` folder, this time with the augmentation pipeline attached.
# On Kaggle, use root='/kaggle/input/the-iqothnccd-lung-cancer-dataset' instead.
dataset = datasets.ImageFolder(
    root='dataset',
    transform=augmented_transform,
)

In [None]:
# Each entry of dataset.samples is a (file path, class index) tuple.
# Show the number of images plus a short preview instead of dumping every path.
len(dataset.samples), dataset.samples[:5]

In [None]:
# Split the images 70% train / 15% validation / 15% test. The split is stratified by class,
# so every split keeps the overall class proportions.
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size

# Class index of every image (ImageFolder stores (path, class_index) pairs in .samples)
targets = [class_index for _, class_index in dataset.samples]

# First carve the test set off all image indices ...
train_val_idx, test_idx = train_test_split(
    list(range(len(targets))),
    test_size=test_size / len(dataset),
    stratify=targets,
    random_state=42,  # reproducible split
)

# ... then split the remaining indices into training and validation sets
train_idx, val_idx = train_test_split(
    train_val_idx,
    test_size=val_size / (train_size + val_size),
    stratify=[targets[i] for i in train_val_idx],
    random_state=42,
)

# Validation and test images must not be randomly augmented: random flips/rotations there make
# the metrics noisy and non-reproducible. A second view of the same image folder, using the
# deterministic `original_transform`, therefore backs those two splits; only training is augmented.
# NOTE: this pairing assumes the training flip is random (p < 1), so unflipped images are part of
# the training distribution.
eval_dataset = datasets.ImageFolder(root=dataset.root, transform=original_transform)
assert eval_dataset.samples == dataset.samples, "train and eval views must index the same files"

# Creating subsets for each split
train_dataset = Subset(dataset, train_idx)
validation_dataset = Subset(eval_dataset, val_idx)
test_dataset = Subset(eval_dataset, test_idx)

In [None]:
# Subset objects have no informative repr, so report how many images each split holds
print(f"Train: {len(train_dataset)} images")
print(f"Validation: {len(validation_dataset)} images")
print(f"Test: {len(test_dataset)} images")

In [None]:
# Every split is served in mini-batches of 16 images. Only training data is shuffled; the
# validation and test order stays fixed, so their batches are identical from epoch to epoch.
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)
validation_loader = DataLoader(dataset=validation_dataset, batch_size=16, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=16, shuffle=False)

# A second, identically configured loader over the test split, reserved for inference
test_loader_for_inference = DataLoader(dataset=test_dataset, batch_size=16, shuffle=False)

In [None]:
# print(train_loader)
# testing_loop = 0

# for inputs, labels in train_loader:
#     print(f"Iteration number {testing_loop}")
#     print("Inputs")
#     print(inputs)
#     print("Size of inputs")
#     print(inputs.size())
#     print("Labels")
#     print(labels)
#     print("=========")
#     testing_loop += 1

## Defining CNN Model

In [None]:
class CNN_for_LungCancer(nn.Module):
    """Two-block convolutional classifier producing 3 class logits.

    Layout: [conv3x3 -> ReLU -> BatchNorm -> maxpool 2x2 -> dropout] x 2,
    then flatten -> fc1 -> ReLU -> dropout -> fc2.

    Args:
        dropout_rate: drop probability used by every dropout layer.
        fc_units: width of the hidden fully connected layer.
        image_size: side length of the square input images (default 224). Each of the two
            pooling layers halves it, so fc1 receives 64 * (image_size // 4) ** 2 features.
    """

    def __init__(self, dropout_rate, fc_units, image_size=224):
        # Initialize base class, nn.Module
        super(CNN_for_LungCancer, self).__init__()

        # Block 1: 3 input channels (RGB) -> 32 channels. A 3x3 kernel with padding=1 keeps H and W unchanged.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.act1 = nn.ReLU()

        # 2x2 max pooling with stride 2 halves H and W; this layer is shared by both blocks
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout1 = nn.Dropout(dropout_rate)

        # Block 2: 32 -> 64 channels, again size-preserving before pooling
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.act2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout_rate)

        # After two poolings the feature map is (64, image_size // 4, image_size // 4),
        # i.e. 64 x 56 x 56 for the default 224 x 224 input.
        pooled_side = image_size // 4
        self.fc1 = nn.Linear(in_features=64 * pooled_side * pooled_side, out_features=fc_units)

        # Non-linearity between the two FC layers. Without it, fc2(fc1(x)) collapses into one linear map.
        self.act_fc = nn.ReLU()

        # Second fully connected layer with 3 output features (benign, malignant and normal)
        self.fc2 = nn.Linear(fc_units, 3)

        # Dropout on the hidden FC activations
        self.dropout4 = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Map a (batch, 3, image_size, image_size) tensor to (batch, 3) class logits."""
        # Conv -> ReLU -> BatchNorm -> pool, then dropout (note that ReLU precedes BatchNorm here)
        x = self.dropout1(self.pool(self.bn1(self.act1(self.conv1(x)))))
        x = self.dropout2(self.pool(self.bn2(self.act2(self.conv2(x)))))

        # Flatten everything except the batch dimension for the fully connected head
        x = torch.flatten(x, 1)

        # Hidden FC layer with ReLU and dropout, then the output layer (raw class scores)
        x = self.dropout4(self.act_fc(self.fc1(x)))
        return self.fc2(x)

In [None]:
# Custom CNN with dropout probability 0.5 and a 64-unit hidden FC layer, placed on the selected device
cnn_model = CNN_for_LungCancer(fc_units=64, dropout_rate=0.5).to(device)
cnn_model

In [None]:
# Layer-by-layer summary for a batch of 16 RGB (3-channel) images of 224 x 224 pixels
summary(model=cnn_model, input_size=(16, 3, 224, 224))

## Defining CNN + LSTM Model:

In [None]:
class CNN_LSTM_for_LungCancer(nn.Module):
    """CNN feature extractor followed by an LSTM that reads the feature map row by row.

    The two conv blocks produce a (batch, 64, image_size // 4, image_size // 4) map. Each of its
    rows (64 channels x image_size // 4 columns, flattened) is one LSTM time step. The LSTM output
    at the last step feeds fc1 -> ReLU -> fc2, which produces 3 class logits.

    Args:
        dropout_rate: drop probability used by every dropout layer.
        fc_units: width of the hidden fully connected layer.
        lstm_units: LSTM hidden size.
        num_layers: number of stacked LSTM layers.
        image_size: side length of the square input images (default 224). It determines the
            LSTM input size, 64 * (image_size // 4).
    """

    def __init__(self, dropout_rate, fc_units, lstm_units, num_layers, image_size=224):
        super(CNN_LSTM_for_LungCancer, self).__init__()

        ## CNN Part

        # Block 1: 3 (RGB) -> 32 channels. A 3x3 kernel with padding=1 keeps H and W unchanged.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.act1 = nn.ReLU()

        # 2x2 max pooling with stride 2 halves H and W; shared by both blocks
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout1 = nn.Dropout(dropout_rate)

        # Block 2: 32 -> 64 channels
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.act2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout_rate)

        ## LSTM Part

        # Each time step is one feature-map row flattened to 64 channels x (image_size // 4) columns
        # (64 x 56 for 224 x 224 inputs); tensors are batch-first.
        self.lstm = nn.LSTM(input_size=64 * (image_size // 4), hidden_size=lstm_units,
                            num_layers=num_layers, batch_first=True)

        # Dropout on the LSTM output for regularization
        self.dropout3 = nn.Dropout(dropout_rate)

        # Classifier head: LSTM features -> fc_units -> 3 classes
        self.fc1 = nn.Linear(lstm_units, fc_units)
        # Non-linearity between the two FC layers. Without it, fc2(fc1(x)) collapses into one linear map.
        self.act_fc = nn.ReLU()
        self.fc2 = nn.Linear(fc_units, 3)

    def forward(self, x):
        """Map a (batch, 3, image_size, image_size) tensor to (batch, 3) class logits."""
        ## CNN Part: conv -> ReLU -> BatchNorm -> pool -> dropout, twice
        x = self.dropout1(self.pool(self.bn1(self.act1(self.conv1(x)))))
        x = self.dropout2(self.pool(self.bn2(self.act2(self.conv2(x)))))

        ## Prepare for LSTM
        # (batch, channels, height, width) -> (batch, height, channels, width):
        # image rows become the LSTM sequence dimension
        x = x.permute(0, 2, 1, 3).contiguous()
        batch_size, seq_len, channels, width = x.size()
        # Flatten each row's (channels, width) slice into one feature vector per time step
        x = x.view(batch_size, seq_len, channels * width)

        ## LSTM Part
        # The LSTM returns per-step outputs plus the final (h, c) state; only the outputs are used
        x, _ = self.lstm(x)

        # Classify from the output at the last time step (the last image row)
        x = self.dropout3(x[:, -1, :])
        x = self.act_fc(self.fc1(x))
        return self.fc2(x)

In [None]:
# CNN+LSTM hyperparameters: two stacked LSTM layers with 256 hidden units each
lstm_units, num_layers = 256, 2

# Instantiate the CNN+LSTM model on the selected device and display it
CNN_LSTM_model = CNN_LSTM_for_LungCancer(
    dropout_rate=0.25,
    fc_units=64,
    lstm_units=lstm_units,
    num_layers=num_layers,
).to(device)
CNN_LSTM_model

In [None]:
# Layer-by-layer summary for a batch of 16 RGB (3-channel) images of 224 x 224 pixels
summary(model=CNN_LSTM_model, input_size=(16, 3, 224, 224))

## Building the Training & Evaluation Function

In [None]:
def _forward_with_aux_loss(model, inputs, labels, criterion, aux_weight=0.4):
    """Run a training forward pass and return (main_logits, loss).

    Models with auxiliary classifiers (e.g. torchvision Inception v3 -> (logits, aux) or
    GoogLeNet -> (logits, aux2, aux1), in training mode) return a tuple. In that case each
    auxiliary loss is added with weight `aux_weight`. Plain models return a logits tensor.
    """
    outputs = model(inputs)
    if isinstance(outputs, tuple):
        main_logits, *aux_logits = outputs
        loss = criterion(main_logits, labels)
        for aux in aux_logits:
            if aux is not None:
                loss = loss + aux_weight * criterion(aux, labels)
        return main_logits, loss
    return outputs, criterion(outputs, labels)


def _evaluate(model, loader, criterion, device):
    """Evaluate `model` on `loader` in eval mode without gradients.

    Returns (mean per-batch loss, accuracy in %, true labels, predicted labels); the last two
    are NumPy arrays in loader order. For an empty loader, loss and accuracy are NaN.
    """
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    all_labels, all_predictions = [], []
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            total_loss += criterion(outputs, labels).item()
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())
    mean_loss = total_loss / len(loader) if len(loader) else float('nan')
    accuracy = 100 * correct / total if total else float('nan')
    return mean_loss, accuracy, np.array(all_labels), np.array(all_predictions)


def train_and_evaluate(model, model_name, train_loader, validation_loader, test_loader, epochs=10, lr=0.0001,
                       early_stopping_patience=5, lr_scheduler_patience=10, save_path=None):
    """Train `model` with Adam + ReduceLROnPlateau and validation-loss early stopping.

    Each epoch runs one optimisation pass over `train_loader`, then evaluates `validation_loader`.
    The mean validation loss drives the LR scheduler and early stopping. The test set is also
    evaluated every epoch for monitoring only; it is skipped in the epoch where early stopping
    fires. At the end, the weights with the lowest validation loss are restored into `model`.

    Args:
        model: nn.Module returning logits. In training mode it may instead return a tuple
            (main_logits, *aux_logits), as Inception v3 / GoogLeNet do with auxiliary classifiers.
        model_name: label used in log messages.
        train_loader, validation_loader, test_loader: sized iterables of (inputs, labels) batches.
        epochs: maximum number of epochs.
        lr: initial Adam learning rate.
        early_stopping_patience: stop after this many consecutive epochs without validation improvement.
        lr_scheduler_patience: epochs without improvement before the LR is multiplied by 0.1.
        save_path: optional file path. If given, the restored best weights are written there with torch.save.

    Returns:
        (train_accuracies, test_accuracies, validation_accuracies, train_losses, test_losses,
        validation_losses, all_test_labels, all_test_predictions). Accuracies are percentages and
        losses are mean per-batch losses, one entry per epoch in which they were computed.
        all_test_labels / all_test_predictions are NumPy arrays for the *restored best model*
        evaluated on `test_loader`.
    """
    # Per-epoch metric histories
    train_accuracies, test_accuracies, validation_accuracies = [], [], []
    train_losses, test_losses, validation_losses = [], [], []

    # Early-stopping state: best validation loss so far and a snapshot of the matching weights
    best_val_loss = float('inf')
    best_model_state = deepcopy(model.state_dict())
    epochs_no_improve = 0

    # Move the model to GPU if available, else CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    criterion = nn.CrossEntropyLoss()

    # Adam with L2 weight decay, using the AMSGrad variant
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4, betas=(0.9, 0.999),
                           eps=1e-8, amsgrad=True)
    print(optimizer)

    # Divide the LR by 10 when the validation loss plateaus. The scheduler's `verbose` flag is
    # deprecated in recent PyTorch, so LR reductions are logged manually below.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=lr_scheduler_patience,
                                  threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-8)

    for epoch in range(epochs):

        ## TRAINING
        model.train()
        train_loss, correct_train, total_train = 0.0, 0, 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs, loss = _forward_with_aux_loss(model, inputs, labels, criterion)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

        train_accuracy = 100 * correct_train / total_train
        train_accuracies.append(train_accuracy)
        train_losses.append(train_loss / len(train_loader))

        ## VALIDATION
        avg_val_loss, validation_accuracy, _, _ = _evaluate(model, validation_loader, criterion, device)
        validation_accuracies.append(validation_accuracy)
        validation_losses.append(avg_val_loss)

        # Step the scheduler on the mean validation loss and report any LR reduction
        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(avg_val_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f"[{model_name}] Reducing learning rate to {lr_after:.4e}.")

        ## Early stopping logic
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_model_state = deepcopy(model.state_dict())
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= early_stopping_patience:
                print(f'Early stopping triggered after {epoch + 1} epochs!')
                break

        ## TESTING (monitoring only; never used for model selection)
        avg_test_loss, test_accuracy, _, _ = _evaluate(model, test_loader, criterion, device)
        test_accuracies.append(test_accuracy)
        test_losses.append(avg_test_loss)

        print(f"Epoch {epoch+1}, Train Loss: {train_losses[-1]:.4f}, "
              f"Validation Loss: {avg_val_loss:.4f}, "
              f"Test Loss: {avg_test_loss:.4f}, "
              f"Train Accuracy: {train_accuracy:.2f}%, "
              f"Validation Accuracy: {validation_accuracy:.2f}%, "
              f"Test Accuracy: {test_accuracy:.2f}%")
        print("--------------------------------------------------------------")

    ## This is after the epoch loop
    print(f"Best validation loss obtained for particular epoch: {best_val_loss}")

    # Restore the weights from the epoch with the lowest validation loss
    model.load_state_dict(best_model_state)

    if save_path is not None:
        print(f"Saving {model_name} weights to {save_path}")
        torch.save(model.state_dict(), save_path)

    # Re-evaluate the restored best weights, so the returned labels/predictions describe the model
    # the caller gets back rather than the last epoch. This also works when no epoch ran.
    _, _, all_test_labels, all_test_predictions = _evaluate(model, test_loader, criterion, device)

    return train_accuracies, test_accuracies, validation_accuracies, train_losses, test_losses, validation_losses, all_test_labels, all_test_predictions

**Training and evaluating CNN Model:**

In [None]:
%%time

train_accuracies_cnn, test_accuracies_cnn, validation_accuracies_cnn, train_losses_cnn, test_losses_cnn, validation_losses_cnn, all_test_labels_cnn, all_test_predictions_cnn = train_and_evaluate(model = cnn_model, model_name = "Custom_CNN_Model", train_loader = train_loader, validation_loader = validation_loader, test_loader = test_loader, epochs=10, lr=1e-4)

In [None]:
# Print the per-epoch metric histories returned by train_and_evaluate for the CNN model
print(f"Training accuracies: {train_accuracies_cnn} \n")
print(f"Testing accuracies: {test_accuracies_cnn} \n")
print(f"Validation accuracies: {validation_accuracies_cnn} \n")
print(f"Training losses: {train_losses_cnn} \n")
print(f"Testing losses: {test_losses_cnn} \n")
print(f"Validation losses: {validation_losses_cnn} \n")

In [None]:
# Per-epoch accuracy (%) curves of the CNN model on the train, test and validation sets
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_accuracies_cnn, label='Train Accuracy')
ax.plot(test_accuracies_cnn, label='Test Accuracy')
ax.plot(validation_accuracies_cnn, label='Validation Accuracy')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy (%)')
ax.set_title('Training, Validation and Test Accuracies of CNN Model')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Per-epoch loss curves of the CNN model on the train, test and validation sets
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_losses_cnn, label='Train Loss')
ax.plot(test_losses_cnn, label='Test Loss')
ax.plot(validation_losses_cnn, label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training, Validation and Test Losses of CNN Model')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Per-class precision / recall / F1 of the CNN model on the test set
print(classification_report(y_true=all_test_labels_cnn, y_pred=all_test_predictions_cnn))

In [None]:
# Confusion matrix of the CNN's test predictions (rows: actual class, columns: predicted class)
cm = confusion_matrix(y_true=all_test_labels_cnn, y_pred=all_test_predictions_cnn)

# Display names for label indices 0, 1 and 2
class_names = ['Benign (Class 0)', 'Malignant (Class 1)', 'Normal (Class 2)']

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, cmap='Blues', fmt='g', xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
ax.set_title('Confusion Matrix')
plt.show()

**Training and evaluating CNN+LSTM Model:**

In [None]:
%%time

# Intended to train and evaluate the CNN+LSTM model for 10 epochs (lr=1e-4) on the shared loaders.
# NOTE(review): this call passes `cnn_model` with model_name "Custom_CNN_Model" -- exactly the same
# arguments as the plain-CNN training cell -- so it continues training the already-trained CNN
# instead of a CNN+LSTM model, and the "*_cnn_lstm" results below describe that CNN.
# TODO: pass the CNN+LSTM model instance (defined outside this view) and give it its own model_name.
train_accuracies_cnn_lstm, test_accuracies_cnn_lstm, validation_accuracies_cnn_lstm, train_losses_cnn_lstm, test_losses_cnn_lstm, validation_losses_cnn_lstm, all_test_labels_cnn_lstm, all_test_predictions_cnn_lstm = train_and_evaluate(model = cnn_model, model_name = "Custom_CNN_Model", train_loader = train_loader, validation_loader = validation_loader, test_loader = test_loader, epochs=10, lr=1e-4)

In [None]:
# Print the per-epoch metric histories returned by train_and_evaluate for the CNN+LSTM run
print(f"Training accuracies: {train_accuracies_cnn_lstm} \n")
print(f"Testing accuracies: {test_accuracies_cnn_lstm} \n")
print(f"Validation accuracies: {validation_accuracies_cnn_lstm} \n")
print(f"Training losses: {train_losses_cnn_lstm} \n")
print(f"Testing losses: {test_losses_cnn_lstm} \n")
print(f"Validation losses: {validation_losses_cnn_lstm} \n")

In [None]:
# Per-epoch accuracy (%) curves of the CNN+LSTM run on the train, test and validation sets
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_accuracies_cnn_lstm, label='Train Accuracy')
ax.plot(test_accuracies_cnn_lstm, label='Test Accuracy')
ax.plot(validation_accuracies_cnn_lstm, label='Validation Accuracy')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy (%)')
ax.set_title('Training, Validation and Test Accuracies of CNN+LSTM Model')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Per-epoch loss curves of the CNN+LSTM run on the train, test and validation sets
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_losses_cnn_lstm, label='Train Loss')
ax.plot(test_losses_cnn_lstm, label='Test Loss')
ax.plot(validation_losses_cnn_lstm, label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training, Validation and Test Losses of CNN + LSTM Model')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Per-class precision / recall / F1 of the CNN+LSTM run on the test set
print(classification_report(y_true=all_test_labels_cnn_lstm, y_pred=all_test_predictions_cnn_lstm))

In [None]:
# Confusion matrix of the CNN+LSTM run's test predictions (rows: actual, columns: predicted)
cm = confusion_matrix(y_true=all_test_labels_cnn_lstm, y_pred=all_test_predictions_cnn_lstm)

# Display names for label indices 0, 1 and 2
class_names = ['Benign (Class 0)', 'Malignant (Class 1)', 'Normal (Class 2)']

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, cmap='Blues', fmt='g', xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
ax.set_title('Confusion Matrix')
plt.show()

**Training and evaluating ResNet Model:**

Pre-trained models can be found at: https://pytorch.org/vision/0.9/models.html

In [None]:
# ResNet-152 initialised with torchvision's ImageNet-pretrained weights
resnet_model = models.resnet152(pretrained = True)

In [None]:
# Show the concrete torchvision class of the loaded model
print(resnet_model.__class__)

In [None]:
# Layer-by-layer torchinfo summary of ResNet-152 for a dummy input of shape
# (batch=16, channels=3 [RGB], height=224, width=224)
summary(model=resnet_model, input_size=(16, 3, 224, 224))

In [None]:
# Replace the ImageNet classification layer with a 3-way head (benign, malignant, normal)
# and move the adapted model to the selected device
num_features = resnet_model.fc.in_features
resnet_model.fc = nn.Linear(in_features=num_features, out_features=3)
resnet_model = resnet_model.to(device)

In [None]:
# Fine-tune ResNet-152 for 10 epochs (lr = 1e-4) and collect per-epoch metrics plus test predictions
(
    train_accuracies_resnet,
    test_accuracies_resnet,
    validation_accuracies_resnet,
    train_losses_resnet,
    test_losses_resnet,
    validation_losses_resnet,
    all_test_labels_resnet,
    all_test_predictions_resnet,
) = train_and_evaluate(
    model=resnet_model,
    model_name="ResNet_Model",
    train_loader=train_loader,
    validation_loader=validation_loader,
    test_loader=test_loader,
    epochs=10,
    lr=1e-4,
)

In [None]:
# Per-epoch accuracy (%) curves of the ResNet model on the train, test and validation sets
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_accuracies_resnet, label='Train Accuracy')
ax.plot(test_accuracies_resnet, label='Test Accuracy')
ax.plot(validation_accuracies_resnet, label='Validation Accuracy')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy (%)')
ax.set_title('Training, Validation and Test Accuracies of ResNet Model')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Per-epoch loss curves of the ResNet model on the train, test and validation sets
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_losses_resnet, label='Train Loss')
ax.plot(test_losses_resnet, label='Test Loss')
ax.plot(validation_losses_resnet, label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training, Validation and Test Losses of ResNet Model')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Per-class precision / recall / F1 of the ResNet model on the test set
print(classification_report(y_true=all_test_labels_resnet, y_pred=all_test_predictions_resnet))

In [None]:
# Confusion matrix of ResNet's test predictions (rows: actual class, columns: predicted class)
cm = confusion_matrix(y_true=all_test_labels_resnet, y_pred=all_test_predictions_resnet)

# Display names for label indices 0, 1 and 2
class_names = ['Benign (Class 0)', 'Malignant (Class 1)', 'Normal (Class 2)']

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, cmap='Blues', fmt='g', xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
ax.set_title('Confusion Matrix')
plt.show()

**Training and evaluating VGG Model:**

In [None]:
# VGG-19 (batch-norm variant) initialised with torchvision's ImageNet-pretrained weights
vgg19_bn = models.vgg19_bn(pretrained = True)

In [None]:
# Show the concrete torchvision class of the loaded model
print(vgg19_bn.__class__)

In [None]:
# Layer-by-layer torchinfo summary of VGG-19-BN for a dummy input of shape
# (batch=16, channels=3 [RGB], height=224, width=224)
summary(model=vgg19_bn, input_size=(16, 3, 224, 224))

In [None]:
# Swap the last classifier layer (index 6) for a 3-way head: benign, malignant, normal
num_classes = 3
vgg19_bn.classifier[6] = nn.Linear(
    in_features=vgg19_bn.classifier[6].in_features,
    out_features=num_classes,
)

In [None]:
# Fine-tune VGG-19-BN for 10 epochs (lr = 1e-4) and collect per-epoch metrics plus test predictions
(
    train_accuracies_vgg19_bn,
    test_accuracies_vgg19_bn,
    validation_accuracies_vgg19_bn,
    train_losses_vgg19_bn,
    test_losses_vgg19_bn,
    validation_losses_vgg19_bn,
    all_test_labels_vgg19_bn,
    all_test_predictions_vgg19_bn,
) = train_and_evaluate(
    model=vgg19_bn,
    model_name="VGGNet_Model",
    train_loader=train_loader,
    validation_loader=validation_loader,
    test_loader=test_loader,
    epochs=10,
    lr=1e-4,
)

In [None]:
# Per-epoch accuracy (%) curves of the VGG-19 model on the train, test and validation sets
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_accuracies_vgg19_bn, label='Train Accuracy')
ax.plot(test_accuracies_vgg19_bn, label='Test Accuracy')
ax.plot(validation_accuracies_vgg19_bn, label='Validation Accuracy')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy (%)')
ax.set_title('Training, Validation and Test Accuracies of VGG-19 Model')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Per-epoch loss curves of the VGG-19 model on the train, test and validation sets
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_losses_vgg19_bn, label='Train Loss')
ax.plot(test_losses_vgg19_bn, label='Test Loss')
ax.plot(validation_losses_vgg19_bn, label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training, Validation and Test Losses of VGG-19 Model')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Per-class precision / recall / F1 of the VGG-19 model on the test set
print(classification_report(y_true=all_test_labels_vgg19_bn, y_pred=all_test_predictions_vgg19_bn))

In [None]:
# Confusion matrix of VGG-19's test predictions (rows: actual class, columns: predicted class)
cm = confusion_matrix(y_true=all_test_labels_vgg19_bn, y_pred=all_test_predictions_vgg19_bn)

# Display names for label indices 0, 1 and 2
class_names = ['Benign (Class 0)', 'Malignant (Class 1)', 'Normal (Class 2)']

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, cmap='Blues', fmt='g', xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
ax.set_title('Confusion Matrix')
plt.show()

**Training and evaluating DenseNet 161 Model:**

In [None]:
# DenseNet-161 initialised with torchvision's ImageNet-pretrained weights
densenet_model = models.densenet161(pretrained = True)

In [None]:
# Show the concrete torchvision class of the loaded model
print(densenet_model.__class__)

In [None]:
# Layer-by-layer torchinfo summary of DenseNet-161 for a dummy input of shape
# (batch=16, channels=3 [RGB], height=224, width=224)
summary(model=densenet_model, input_size=(16, 3, 224, 224))

In [None]:
# Replace DenseNet's single-layer classifier with a 3-way head: benign, malignant, normal
num_classes = 3
densenet_model.classifier = nn.Linear(
    in_features=densenet_model.classifier.in_features,
    out_features=num_classes,
)

In [None]:
# Fine-tune DenseNet-161 for 10 epochs (lr = 1e-4) and collect per-epoch metrics plus test predictions
(
    train_accuracies_densenet_model,
    test_accuracies_densenet_model,
    validation_accuracies_densenet_model,
    train_losses_densenet_model,
    test_losses_densenet_model,
    validation_losses_densenet_model,
    all_test_labels_densenet_model,
    all_test_predictions_densenet_model,
) = train_and_evaluate(
    model=densenet_model,
    model_name="DenseNet_Model",
    train_loader=train_loader,
    validation_loader=validation_loader,
    test_loader=test_loader,
    epochs=10,
    lr=1e-4,
)

In [None]:
# Per-epoch accuracy (%) curves of the DenseNet model on the train, test and validation sets
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_accuracies_densenet_model, label='Train Accuracy')
ax.plot(test_accuracies_densenet_model, label='Test Accuracy')
ax.plot(validation_accuracies_densenet_model, label='Validation Accuracy')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy (%)')
ax.set_title('Training, Validation and Test Accuracies of Densenet-161 Model')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Per-epoch loss curves of the DenseNet model on the train, test and validation sets
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_losses_densenet_model, label='Train Loss')
ax.plot(test_losses_densenet_model, label='Test Loss')
ax.plot(validation_losses_densenet_model, label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training, Validation and Test Losses of Densenet-161 Model')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Per-class precision / recall / F1 of the DenseNet model on the test set
print(classification_report(y_true=all_test_labels_densenet_model, y_pred=all_test_predictions_densenet_model))

In [None]:
# Confusion matrix of DenseNet's test predictions (rows: actual class, columns: predicted class)
cm = confusion_matrix(y_true=all_test_labels_densenet_model, y_pred=all_test_predictions_densenet_model)

# Display names for label indices 0, 1 and 2
class_names = ['Benign (Class 0)', 'Malignant (Class 1)', 'Normal (Class 2)']

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, cmap='Blues', fmt='g', xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
ax.set_title('Confusion Matrix')
plt.show()

**Training and evaluating MobileNet V3 Large Model:**

In [None]:
# MobileNetV3-Large initialised with torchvision's ImageNet-pretrained weights
mobilenet_model = models.mobilenet_v3_large(pretrained = True)

In [None]:
# Show the concrete torchvision class of the loaded model
print(mobilenet_model.__class__)

In [None]:
# Layer-by-layer torchinfo summary of MobileNetV3-Large for a dummy input of shape
# (batch=16, channels=3 [RGB], height=224, width=224)
summary(model=mobilenet_model, input_size=(16, 3, 224, 224))

In [None]:
# Swap the last classifier layer (index 3) for a 3-way head: benign, malignant, normal
num_classes = 3
mobilenet_model.classifier[3] = nn.Linear(
    in_features=mobilenet_model.classifier[3].in_features,
    out_features=num_classes,
)

In [None]:
# Fine-tune MobileNetV3-Large for 10 epochs (lr = 1e-4) and collect per-epoch metrics plus test predictions
(
    train_accuracies_mobilenet_model,
    test_accuracies_mobilenet_model,
    validation_accuracies_mobilenet_model,
    train_losses_mobilenet_model,
    test_losses_mobilenet_model,
    validation_losses_mobilenet_model,
    all_test_labels_mobilenet_model,
    all_test_predictions_mobilenet_model,
) = train_and_evaluate(
    model=mobilenet_model,
    model_name="MobileNet_Model",
    train_loader=train_loader,
    validation_loader=validation_loader,
    test_loader=test_loader,
    epochs=10,
    lr=1e-4,
)

In [None]:
# Per-epoch accuracy (%) curves of the MobileNet model on the train, test and validation sets
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_accuracies_mobilenet_model, label='Train Accuracy')
ax.plot(test_accuracies_mobilenet_model, label='Test Accuracy')
ax.plot(validation_accuracies_mobilenet_model, label='Validation Accuracy')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy (%)')
ax.set_title('Training, Validation and Test Accuracies of MobileNet V3 Large Model')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Per-epoch loss curves of the MobileNet model on the train, test and validation sets
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_losses_mobilenet_model, label='Train Loss')
ax.plot(test_losses_mobilenet_model, label='Test Loss')
ax.plot(validation_losses_mobilenet_model, label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training, Validation and Test Losses of MobileNet V3 Large Model')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Per-class precision / recall / F1 of the MobileNet model on the test set
print(classification_report(y_true=all_test_labels_mobilenet_model, y_pred=all_test_predictions_mobilenet_model))

In [None]:
# Confusion matrix of MobileNet's test predictions (rows: actual class, columns: predicted class)
cm = confusion_matrix(y_true=all_test_labels_mobilenet_model, y_pred=all_test_predictions_mobilenet_model)

# Display names for label indices 0, 1 and 2
class_names = ['Benign (Class 0)', 'Malignant (Class 1)', 'Normal (Class 2)']

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, cmap='Blues', fmt='g', xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
ax.set_title('Confusion Matrix')
plt.show()

**Training and evaluating Wide ResNet-101-2 Model:**

In [None]:
# Wide ResNet-101-2 initialised with torchvision's ImageNet-pretrained weights
wide_resnet_model = models.wide_resnet101_2(pretrained = True)

In [None]:
# Show the concrete torchvision class of the loaded model
print(wide_resnet_model.__class__)

In [None]:
# Layer-by-layer torchinfo summary of Wide ResNet-101-2 for a dummy input of shape
# (batch=16, channels=3 [RGB], height=224, width=224)
summary(model=wide_resnet_model, input_size=(16, 3, 224, 224))

In [None]:
# Replace the final fully connected layer with a 3-way head: benign, malignant, normal
num_classes = 3
wide_resnet_model.fc = nn.Linear(
    in_features=wide_resnet_model.fc.in_features,
    out_features=num_classes,
)

In [None]:
# Fine-tune Wide ResNet-101-2 for 10 epochs (lr = 1e-4) and collect per-epoch metrics plus test predictions
(
    train_accuracies_wide_resnet_model,
    test_accuracies_wide_resnet_model,
    validation_accuracies_wide_resnet_model,
    train_losses_wide_resnet_model,
    test_losses_wide_resnet_model,
    validation_losses_wide_resnet_model,
    all_test_labels_wide_resnet_model,
    all_test_predictions_wide_resnet_model,
) = train_and_evaluate(
    model=wide_resnet_model,
    model_name="Wide_ResNet_Model",
    train_loader=train_loader,
    validation_loader=validation_loader,
    test_loader=test_loader,
    epochs=10,
    lr=1e-4,
)

In [None]:
# Per-epoch accuracy (%) curves of the Wide ResNet-101-2 model on the train, test and validation sets
plt.figure(figsize=(10, 6))
plt.plot(train_accuracies_wide_resnet_model, label='Train Accuracy')
plt.plot(test_accuracies_wide_resnet_model, label='Test Accuracy')
plt.plot(validation_accuracies_wide_resnet_model, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
# Title names the model actually plotted (there is no "Large" Wide ResNet variant)
plt.title('Training, Validation and Test Accuracies of Wide ResNet-101-2 Model')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
# Per-epoch loss curves of the Wide ResNet-101-2 model on the train, test and validation sets
plt.figure(figsize=(10, 6))
plt.plot(train_losses_wide_resnet_model, label='Train Loss')
plt.plot(test_losses_wide_resnet_model, label='Test Loss')
plt.plot(validation_losses_wide_resnet_model, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
# Title must name the Wide ResNet model plotted here (was copy-pasted from the MobileNet cell)
plt.title('Training, Validation and Test Losses of Wide ResNet-101-2 Model')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
# Per-class precision / recall / F1 of the Wide ResNet model on the test set
print(classification_report(y_true=all_test_labels_wide_resnet_model, y_pred=all_test_predictions_wide_resnet_model))

In [None]:
# Confusion matrix of Wide ResNet's test predictions (rows: actual class, columns: predicted class)
cm = confusion_matrix(y_true=all_test_labels_wide_resnet_model, y_pred=all_test_predictions_wide_resnet_model)

# Display names for label indices 0, 1 and 2
class_names = ['Benign (Class 0)', 'Malignant (Class 1)', 'Normal (Class 2)']

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, cmap='Blues', fmt='g', xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
ax.set_title('Confusion Matrix')
plt.show()

**Training and evaluating AlexNet model:**

In [None]:
# AlexNet initialised with torchvision's ImageNet-pretrained weights
alexnet_model = models.alexnet(pretrained = True)

In [None]:
# Show the concrete torchvision class of the loaded model
print(alexnet_model.__class__)

In [None]:
# Layer-by-layer torchinfo summary of AlexNet for a dummy input of shape
# (batch=16, channels=3 [RGB], height=224, width=224)
summary(model=alexnet_model, input_size=(16, 3, 224, 224))

In [None]:
# Swap the last classifier layer (index 6) for a 3-way head: benign, malignant, normal
num_classes = 3
alexnet_model.classifier[6] = nn.Linear(
    in_features=alexnet_model.classifier[6].in_features,
    out_features=num_classes,
)

In [None]:
# Fine-tune AlexNet for 10 epochs (lr = 1e-4) and collect per-epoch metrics plus test predictions.
# Every argument is passed by keyword; model_name uses `=` (the original `==` made this cell a SyntaxError).
(
    train_accuracies_alexnet_model,
    test_accuracies_alexnet_model,
    validation_accuracies_alexnet_model,
    train_losses_alexnet_model,
    test_losses_alexnet_model,
    validation_losses_alexnet_model,
    all_test_labels_alexnet_model,
    all_test_predictions_alexnet_model,
) = train_and_evaluate(
    model=alexnet_model,
    model_name="AlexNet_Model",
    train_loader=train_loader,
    validation_loader=validation_loader,
    test_loader=test_loader,
    epochs=10,
    lr=1e-4,
)

In [None]:
# Per-epoch accuracy (%) curves of the AlexNet model on the train, test and validation sets
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_accuracies_alexnet_model, label='Train Accuracy')
ax.plot(test_accuracies_alexnet_model, label='Test Accuracy')
ax.plot(validation_accuracies_alexnet_model, label='Validation Accuracy')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy (%)')
ax.set_title('Training, Validation and Test Accuracies of AlexNet Model')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Per-epoch loss curves of the AlexNet model on the train, test and validation sets
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_losses_alexnet_model, label='Train Loss')
ax.plot(test_losses_alexnet_model, label='Test Loss')
ax.plot(validation_losses_alexnet_model, label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training, Validation and Test Losses of AlexNet Model')
ax.legend()
ax.grid(True)
plt.show()

**Training and evaluating GoogleNet Model:**

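The model used in this section is Inception v3 (`models.inception_v3`), referred to below as "GoogleNet". It expects 299x299 input images (the original GoogLeNet uses 224x224). We therefore redefine the augmented transform with a 299x299 resize and reload the data.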

In [None]:
# Augmented transform for the Inception v3 ("GoogleNet") section: 299x299 inputs, flip + rotation
augmented_transform_GoogleNet = transforms.Compose([
    # Resize every image to 299 x 299 pixels, the input size Inception v3 expects
    transforms.Resize(size=(299, 299)),
    # p=1.0: EVERY image is mirrored horizontally (a deterministic flip, not a random one)
    transforms.RandomHorizontalFlip(p=1.0),
    # Rotate by an angle drawn from [-15, +15] degrees
    transforms.RandomRotation(degrees=15),
    # PIL image -> float tensor (channels first, values scaled to [0, 1])
    transforms.ToTensor(),
    # Per-channel (RGB) standardisation with the standard ImageNet mean / std
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
# Load the image folders under 'data' (one sub-folder per class) and apply
# augmented_transform_GoogleNet to every sample that is read from this dataset
dataset_GoogleNet = datasets.ImageFolder(
    root='data',
    transform=augmented_transform_GoogleNet,
)

In [None]:
# Stratified 70 / 15 / 15 train / validation / test split of dataset_GoogleNet
train_size_GoogleNet = int(0.7 * len(dataset_GoogleNet))
val_size_GoogleNet = int(0.15 * len(dataset_GoogleNet))
# The test split receives whatever remains after the other two sizes are rounded down
test_size_GoogleNet = len(dataset_GoogleNet) - train_size_GoogleNet - val_size_GoogleNet

# Class index of every sample, taken from ImageFolder's (path, class_index) pairs
targets_GoogleNet = [class_index for _path, class_index in dataset_GoogleNet.samples]

# Step 1: hold out the test indices, preserving class proportions (fixed seed for reproducibility)
train_val_idx_GoogleNet, test_idx_GoogleNet = train_test_split(
    range(len(targets_GoogleNet)),
    test_size=test_size_GoogleNet / len(dataset_GoogleNet),
    stratify=targets_GoogleNet,
    random_state=42,
)

# Step 2: divide the remaining indices into training and validation, again stratified by class
train_idx_GoogleNet, val_idx_GoogleNet = train_test_split(
    train_val_idx_GoogleNet,
    test_size=val_size_GoogleNet / (train_size_GoogleNet + val_size_GoogleNet),
    stratify=[targets_GoogleNet[i] for i in train_val_idx_GoogleNet],
    random_state=42,
)

# Index-based views onto the same underlying dataset.
# NOTE(review): all three subsets share dataset_GoogleNet, so validation and test images also pass
# through augmented_transform_GoogleNet (flip + random rotation) -- evaluation is therefore
# randomised; a deterministic eval transform may be intended. TODO confirm.
train_dataset_GoogleNet = Subset(dataset_GoogleNet, train_idx_GoogleNet)
validation_dataset_GoogleNet = Subset(dataset_GoogleNet, val_idx_GoogleNet)
test_dataset_GoogleNet = Subset(dataset_GoogleNet, test_idx_GoogleNet)

# Batches of 16; only the training loader reshuffles every epoch
train_loader_GoogleNet = DataLoader(train_dataset_GoogleNet, batch_size=16, shuffle=True)
validation_loader_GoogleNet = DataLoader(validation_dataset_GoogleNet, batch_size=16, shuffle=False)
test_loader_GoogleNet = DataLoader(test_dataset_GoogleNet, batch_size=16, shuffle=False)

In [None]:
# Inception v3 (used as this notebook's "GoogleNet" model) with ImageNet-pretrained weights.
# Note: this is torchvision's inception_v3, not the original GoogLeNet architecture.
googlenet_model = models.inception_v3(pretrained = True)

In [None]:
# Show the concrete torchvision class of the loaded model
print(googlenet_model.__class__)

In [None]:
# Layer-by-layer torchinfo summary of Inception v3 for a dummy input of shape
# (batch=16, channels=3 [RGB], height=299, width=299)
summary(model=googlenet_model, input_size=(16, 3, 299, 299))

In [None]:
# Modify the model for the 3 classes: benign, malignant, normal
num_classes = 3
googlenet_model.fc = nn.Linear(googlenet_model.fc.in_features, num_classes)

# inception_v3 loaded with pretrained weights has its auxiliary classifier enabled: in train mode
# its forward returns an InceptionOutputs(logits, aux_logits) namedtuple instead of a tensor, and
# the auxiliary head still predicts the 1000 ImageNet classes. Disabling the auxiliary branch makes
# the model return a single logits tensor in both train and eval mode, which is what a plain
# CrossEntropyLoss(outputs, labels) call expects.
googlenet_model.aux_logits = False
googlenet_model.AuxLogits = None

googlenet_model = googlenet_model.to(device)

In [None]:
# Fine-tune Inception v3 ("GoogleNet") for 10 epochs (lr = 1e-4) on the 299x299 loaders
(
    train_accuracies_googlenet_model,
    test_accuracies_googlenet_model,
    validation_accuracies_googlenet_model,
    train_losses_googlenet_model,
    test_losses_googlenet_model,
    validation_losses_googlenet_model,
    all_test_labels_googlenet_model,
    all_test_predictions_googlenet_model,
) = train_and_evaluate(
    model=googlenet_model,
    model_name="GoogleNet_Model",
    train_loader=train_loader_GoogleNet,
    validation_loader=validation_loader_GoogleNet,
    test_loader=test_loader_GoogleNet,
    epochs=10,
    lr=1e-4,
)

In [None]:
# Per-epoch accuracy (%) curves of the GoogleNet model on the train, test and validation sets
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_accuracies_googlenet_model, label='Train Accuracy')
ax.plot(test_accuracies_googlenet_model, label='Test Accuracy')
ax.plot(validation_accuracies_googlenet_model, label='Validation Accuracy')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy (%)')
ax.set_title('Training, Validation and Test Accuracies of GoogleNet Model')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Per-epoch loss curves of the GoogleNet model on the train, test and validation sets
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_losses_googlenet_model, label='Train Loss')
ax.plot(test_losses_googlenet_model, label='Test Loss')
ax.plot(validation_losses_googlenet_model, label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training, Validation and Test Losses of GoogleNet Model')
ax.legend()
ax.grid(True)
plt.show()

## Hyperparameter Search

In [None]:
# Search space for the random hyperparameter search:
#   learning_rate - learning rate handed to the Adam optimizer
#   dropout_rate  - passed to CNN_for_LungCancer as its dropout_rate (a model setting, not an optimizer one)
#   fc_units      - passed to CNN_for_LungCancer as fc_units, the number of output features/neurons
#                   of its fully connected layer
hyperparameters = dict(
    learning_rate=[1e-3, 1e-4, 1e-5],
    dropout_rate=[0.25, 0.5, 0.75],
    fc_units=[64, 128, 256],
)

In [None]:
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``; return (accuracy in %, mean per-batch loss).

    Raises:
        ValueError: if ``loader`` yields no samples (accuracy would be undefined).
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Predicted class = index of the highest logit; detached so no graph is kept for bookkeeping
        _, predicted = torch.max(outputs.detach(), 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        num_batches += 1
    if total == 0:
        raise ValueError("training loader yielded no samples")
    return 100 * correct / total, running_loss / num_batches


def _evaluate_on_loader(model, loader, criterion, device):
    """Gradient-free pass over ``loader``; return (accuracy in %, mean per-batch loss).

    The caller is responsible for putting ``model`` in eval mode first.

    Raises:
        ValueError: if ``loader`` yields no samples (accuracy would be undefined).
    """
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            running_loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            num_batches += 1
    if total == 0:
        raise ValueError("evaluation loader yielded no samples")
    return 100 * correct / total, running_loss / num_batches


def train_and_evaluate_new(model, train_loader, validation_loader, test_loader, epochs=10, lr=0.0001, save_path='best_model.pt'):
    """Train ``model`` with Adam (AMSGrad) and checkpoint the epoch with the best validation accuracy.

    Each epoch runs one training pass over ``train_loader`` followed by gradient-free passes
    over ``validation_loader`` and ``test_loader``. Whenever validation accuracy improves, the
    model's ``state_dict`` is written to ``save_path`` and that epoch's validation/test numbers
    are recorded. After at least one epoch, ``save_path`` always holds a checkpoint.

    Args:
        model: ``nn.Module`` producing class logits; moved to CUDA when available, else CPU.
        train_loader, validation_loader, test_loader: iterables of ``(inputs, labels)`` batches.
        epochs: number of training epochs.
        lr: Adam learning rate (weight decay 1e-4, AMSGrad enabled).
        save_path: file the best checkpoint is written to.

    Returns:
        dict with keys:
          best_validation_accuracy: highest validation accuracy (%) seen over all epochs.
          best_validation_loss: mean validation loss at that best-accuracy epoch
              (not necessarily the lowest validation loss seen).
          corresponding_test_accuracy / corresponding_test_loss: test metrics at that same epoch.
          final_train_accuracy / final_train_loss: training metrics of the last epoch.
          save_path: the ``save_path`` argument.

    Raises:
        ValueError: if any loader yields no samples.
    """
    # Use the GPU when available, otherwise fall back to the CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    criterion = nn.CrossEntropyLoss()

    # Adam with the AMSGrad variant and L2 weight decay; betas/eps spelled out explicitly
    optimizer = optim.Adam(model.parameters(),
                           lr=lr,
                           weight_decay=1e-4,
                           betas=(0.9, 0.999),
                           eps=1e-8,
                           amsgrad=True)

    # Start below any reachable accuracy so the first epoch always writes a checkpoint, even at
    # 0% validation accuracy; otherwise save_path could be returned without ever being created.
    best_validation_accuracy = float('-inf')

    metrics = {
        'best_validation_accuracy': 0,
        'best_validation_loss': float('inf'),
        'corresponding_test_accuracy': 0,
        'corresponding_test_loss': float('inf'),
        'final_train_accuracy': 0,
        'final_train_loss': float('inf'),
    }

    for epoch in range(epochs):
        train_accuracy, train_loss_avg = _train_one_epoch(model, train_loader, criterion, optimizer, device)

        # Eval mode (dropout off, batch-norm running statistics frozen) for validation and test
        model.eval()
        validation_accuracy, validation_loss_avg = _evaluate_on_loader(model, validation_loader, criterion, device)
        test_accuracy, test_loss_avg = _evaluate_on_loader(model, test_loader, criterion, device)

        if validation_accuracy > best_validation_accuracy:
            best_validation_accuracy = validation_accuracy
            torch.save(model.state_dict(), save_path)
            metrics.update({
                'best_validation_accuracy': validation_accuracy,
                'best_validation_loss': validation_loss_avg,
                'corresponding_test_accuracy': test_accuracy,
                'corresponding_test_loss': test_loss_avg,
            })

        # Training metrics always reflect the most recent epoch
        metrics.update({
            'final_train_accuracy': train_accuracy,
            'final_train_loss': train_loss_avg,
        })

        print(f"Epoch {epoch+1}, Train Loss: {train_loss_avg:.4f}, "
              f"Validation Loss: {validation_loss_avg:.4f}, "
              f"Test Loss: {test_loss_avg:.4f}, "
              f"Train Accuracy: {train_accuracy:.2f}%, "
              f"Validation Accuracy: {validation_accuracy:.2f}%, "
              f"Test Accuracy: {test_accuracy:.2f}%")
        print("--------------------------------------------------------------")

    metrics['save_path'] = save_path
    return metrics

In [None]:
def random_search(hyperparameters, num_trials, base_save_path='weights', epochs=10):
    """Random hyperparameter search over CNN_for_LungCancer.

    Each trial draws one value (via random.choice) from each of the
    'learning_rate', 'dropout_rate' and 'fc_units' lists, builds a fresh
    CNN_for_LungCancer, and trains/evaluates it with train_and_evaluate_new on
    the module-level train_loader / validation_loader / test_loader. The trial's
    best weights go to `<base_save_path>/best_model_trial_<k>.pt`.

    Args:
        hyperparameters: dict whose 'learning_rate', 'dropout_rate' and
            'fc_units' entries are sequences of candidate values.
        num_trials: number of random configurations to try.
        base_save_path: directory for per-trial weight files (created if missing).
        epochs: training epochs per trial (defaults to the previous fixed value, 10).

    Returns:
        (results, best_model_path): the per-trial metrics dicts in trial order
        (each also records the sampled 'learning_rate', 'dropout_rate' and
        'fc_units'), and the save path of the trial with the highest
        'best_validation_accuracy' ("" when num_trials is 0).
    """
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
    os.makedirs(base_save_path, exist_ok=True)

    results = []
    # -inf guarantees the first completed trial is recorded, even at 0% accuracy
    best_validation_accuracy = float('-inf')
    best_model_path = ""

    for trial in range(1, num_trials + 1):
        # Sample one configuration from the search space
        lr = random.choice(hyperparameters['learning_rate'])
        dropout = random.choice(hyperparameters['dropout_rate'])
        fc_units = random.choice(hyperparameters['fc_units'])

        save_path = os.path.join(base_save_path, f'best_model_trial_{trial}.pt')
        print(f"Trial {trial}: Training with lr={lr}, dropout={dropout}, fc_units={fc_units}")

        model = CNN_for_LungCancer(dropout_rate=dropout, fc_units=fc_units)
        metrics = train_and_evaluate_new(model, train_loader, validation_loader, test_loader,
                                         epochs=epochs, lr=lr, save_path=save_path)

        # Keep the configuration next to its scores so the winning trial can be reproduced
        metrics.update({
            'learning_rate': lr,
            'dropout_rate': dropout,
            'fc_units': fc_units,
        })
        results.append(metrics)

        if metrics['best_validation_accuracy'] > best_validation_accuracy:
            best_validation_accuracy = metrics['best_validation_accuracy']
            best_model_path = metrics['save_path']

    print(f"Best model saved at: {best_model_path}")
    return results, best_model_path

In [None]:
# Run 10 random-search trials; keep every trial's metrics and the best checkpoint path
results, best_model_path = random_search(hyperparameters=hyperparameters, num_trials=10)

### Visualizing accuracy and loss across all hyperparameter-search trials

In [None]:
# Gather the per-trial metrics returned by random_search into arrays (index i -> trial i+1)
train_accuracies = np.array([res['final_train_accuracy'] for res in results])
validation_accuracies = np.array([res['best_validation_accuracy'] for res in results])
test_accuracies = np.array([res['corresponding_test_accuracy'] for res in results])

train_losses = np.array([res['final_train_loss'] for res in results])
validation_losses = np.array([res['best_validation_loss'] for res in results])
test_losses = np.array([res['corresponding_test_loss'] for res in results])

trials = np.arange(1, len(results) + 1)

# Best trial (1-based) = highest best-validation accuracy; np.argmax keeps the first on ties
best_trial_acc = int(np.argmax(validation_accuracies)) + 1
best_acc = validation_accuracies[best_trial_acc - 1]

fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(14, 7))

# Left panel: accuracies, right panel: losses; each split keeps the same marker/colour in both
panel_specs = [
    (ax_acc, 'Acc', 'Accuracy Across Trials', 'Accuracy (%)',
     (train_accuracies, validation_accuracies, test_accuracies)),
    (ax_loss, 'Loss', 'Loss Across Trials', 'Loss',
     (train_losses, validation_losses, test_losses)),
]
split_styles = [('Train', 'o-', 'blue'), ('Validation', 's--', 'orange'), ('Test', '^-.', 'green')]

for ax, suffix, title, ylabel, series in panel_specs:
    for (split, fmt, color), values in zip(split_styles, series):
        ax.plot(trials, values, fmt, label=f'{split} {suffix}', color=color)
    ax.axvline(x=best_trial_acc, color='gray', linestyle='--', label=f'Best Trial #{best_trial_acc}')
    ax.set_title(title)
    ax.set_xlabel('Trial')
    ax.set_ylabel(ylabel)
    # Trials are whole numbers; without fixed ticks matplotlib may label 1.5, 2.5, ...
    ax.set_xticks(trials)
    ax.legend()

fig.tight_layout()
plt.show()

print(f"Best Trial: {best_trial_acc} with Validation Accuracy: {best_acc:.2f}%")

## Loading the model and running inference

In [None]:
# Augmented transformation with horizontal flip and random rotation
augmented_transform = transforms.Compose([
    
    # Resize input images to a fixed size of 224x224 pixels
    transforms.Resize((224, 224)),
    
    # Randomly flips the images horizontally with a probability of 1.0
    transforms.RandomHorizontalFlip(p=1.0), 
    
    # Randomly rotate images in the range of -15 to +15 degrees
    transforms.RandomRotation(degrees=15), 
    
    # Converts the augmented images into PyTorch tensors
    transforms.ToTensor(),
    
    # Standardize input data by subtracting the mean and dividing by the standard deviation along each channel (RGB)
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the dataset from local folder, applying transformations this time
dataset = datasets.ImageFolder(root='dataset', transform = augmented_transform)

# Splitting the dataset - 70% for training, 15% for validation and the rest 15% for testing
train_size = int(0.7 * len(dataset)) 
val_size = int(0.15 * len(dataset)) 
test_size = len(dataset) - train_size - val_size 

# Extracting labels from the dataset
targets = [s[1] for s in dataset.samples]  

# Splitting the dataset into 2 sets: train+val and test sets
train_val_idx, test_idx = train_test_split(
    # The array which we want to split
    range(len(targets)),
    # Specify proportion of the dataset that should be allocated for the test set 
    test_size= test_size/len(dataset), 
    # Ensure class distribution is preserved in both the training+validation and testing sets
    stratify=targets,
    # random seed for reproducibility
    random_state=42  
)

# Splitting the train+val further into 2 sets: training and validation sets
train_idx, val_idx = train_test_split(
    # The array which we want to split
    train_val_idx,
    # Specify proportion of the dataset that should be allocated for the validation set 
    test_size=val_size / (train_size + val_size),  
    # Ensure class distribution is preserved for both the training and validation set
    stratify=[targets[i] for i in train_val_idx],
    # random seed for reproducibility
    random_state=42  
)

# Creating subsets for each split
train_dataset = Subset(dataset, train_idx)
validation_dataset = Subset(dataset, val_idx)
test_dataset = Subset(dataset, test_idx)

# Creating data loaders that help iterate through the dataset in batches during training, validation and testing.
# Using a batch size of 16 everywhere and only shuffling during training and not validation or testing.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=16, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [None]:
# Rebuild the architecture with the hyperparameters the checkpoint was trained with
# (load_state_dict fails if dropout/fc_units do not match the saved layer shapes).
cnn_model_inference = CNN_for_LungCancer(dropout_rate=0.5, fc_units=64)

# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# torch.load unpickles the file: only load checkpoints from trusted sources.
state_dict = torch.load('model/lung_cancer_classification_cnn_model.pth', map_location=device)
cnn_model_inference.load_state_dict(state_dict)
cnn_model_inference.to(device)
cnn_model_inference.eval()

In [None]:
# resnet_model_inference = models.resnet152(pretrained=True)

# num_features = resnet_model_inference.fc.in_features
# resnet_model_inference.fc = nn.Linear(num_features, 3)
# resnet_model_inference = resnet_model_inference.to(device)
# resnet_model_inference.load_state_dict(torch.load('model/lung_cancer_classification_cnn_model.pth'))
# resnet_model_inference.eval() 

In [None]:
def testing_model(model, test_loader):
    
    # Move device to GPU if available, else CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    # Computes Loss Function using Cross Entropy Loss
    criterion = nn.CrossEntropyLoss()
    
    # Keeps track of testing loss
    test_loss = 0 
    # Keeps track of the number of correctly predicted samples during testing
    correct_test = 0
    # Total_test keeps track of the number of the total number of testing samples processed so far 
    total_test = 0
    
    # During testing, we don't need to utilize gradient
    with torch.no_grad():
        # For each input data and label in test dataset
        for inputs, labels in test_loader:
            # Move data to the right device
            inputs, labels = inputs.to(device), labels.to(device)
            # Trigger a forward pass through the neural network
            outputs = model(inputs)
            # Compare the outputs (predictions made by the model) and the labels (ground truth data) wrt to loss function
            loss = criterion(outputs, labels)
            # Accumulates the value of the loss incurred in the current batch to the overall testing loss.
            test_loss += loss.item()
            # Obtain the class predicted by the model
            _, predicted = torch.max(outputs.data, 1)
            # Update total_test to keep track of the total number of testing samples processed so far
            total_test += labels.size(0)
            # Update correct_test to keep track of the number of correctly predicted samples in the current batch
            correct_test += (predicted == labels).sum().item()

    # Calculate testing accuracy
    final_test_loss = test_loss / len(test_loader)
    test_accuracy = (correct_test/total_test) * 100
    
    print(f"Test Loss: {final_test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")
    print("--------------------------------------------------------------")
                                      
    return final_test_loss, test_accuracy

In [None]:
# Evaluate the CNN loaded in this section on the test split built above
# (the previous names `inference_model` / `test_loader_for_inference` are not defined in this section)
test_loss, test_accuracy = testing_model(model=cnn_model_inference, test_loader=test_loader)

In [None]:
# test_loss, test_accuracy = testing_model(model = resnet_model_inference, test_loader = test_loader)

## Appendix:

### Seed for reproducibility:

In [None]:
# # Function to set the seed
# def set_seed(seed=42):
#     """Sets various seed for reproducibility (especially useful when checking whether saved trained model matches with that loaded on inference)."""
    
    
#     random.seed(seed)
    
#     # Sets the seed for NumPy's random number generator
#     np.random.seed(seed)
    
#     # Sets the seed for generating random numbers in PyTorch on CPU and GPU.
#     torch.manual_seed(seed)
    
#     # Ensure seed is set for all CUDA devices for operations that use PyTorch with CUDA
#     torch.cuda.manual_seed_all(seed)
    
#     # Forces CuDNN to use deterministic algorithms
#     torch.backends.cudnn.deterministic = True
    
#     # Disables the CuDNN benchmarking feature, which dynamically finds the most efficient algorithms for your specific operations
#     torch.backends.cudnn.benchmark = False

# # Call set_seed at the beginning of your script with a particular seed, say 42. 
# set_seed(seed=42)

### Saving classification report:

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report
import numpy as np

def save_classification_report_image(report, filename):
    # Parse the classification report string
    lines = report.split('\n')
    data = [line.split() for line in lines[2:-3]]

    # Extract class names and metrics
    classes = [row[0] for row in data]
    metrics = np.array(data)[:, 1:].astype(float)

    # Plot classification report as a heatmap
    plt.figure(figsize=(8, 6))
    plt.imshow(metrics, cmap='viridis', aspect='auto')
    plt.title('Classification Report Heatmap')
    plt.xlabel('Metrics')
    plt.ylabel('Classes')
    plt.xticks(np.arange(len(classes)), ['Precision', 'Recall', 'F1-score', 'Support'])
    plt.yticks(np.arange(len(classes)), classes)
    plt.colorbar(label='Metric Value')

    # Save the plot as an image
    plt.savefig(filename)
    plt.close()

# Example classification report

# Save the classification report as an image
save_classification_report_image(report, 'classification_report.png')

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def save_classification_report_image(report, filename):
    # Parse the classification report string
    lines = report.strip().split('\n')

    # Extract class names and metrics
    data = [line.split() for line in lines[2:] if line.strip() and not line.strip().startswith(('micro', 'macro', 'weighted'))]  # Skip empty lines and macro/weighted averages

    classes = []
    metrics = []

    for row in data:
        if len(row) >= 5:
            classes.append(row[0])
            metrics.append(row[1:])

    # Convert metrics to numpy array
    metrics = np.array(metrics).astype(float)

    # Plot classification report as a heatmap
    plt.figure(figsize=(8, 6))
    plt.imshow(metrics, cmap='viridis', aspect='auto')
    plt.title('Classification Report Heatmap')
    plt.xlabel('Metrics')
    plt.ylabel('Classes')
    plt.xticks(np.arange(4), ['Precision', 'Recall', 'F1-score', 'Support'])
    plt.yticks(np.arange(len(classes)), classes)
    plt.colorbar(label='Metric Value')

    # Save the plot as an image
    plt.savefig(filename)
    plt.close()

# Sample report text, laid out like sklearn's classification_report output,
# used to demonstrate save_classification_report_image
report = """
             precision    recall  f1-score   support

     class 0       0.75      0.60      0.67        10
     class 1       0.80      0.89      0.84        19
     class 2       0.50      0.50      0.50         6

    accuracy                           0.76        35
   macro avg       0.68      0.66      0.67        35
weighted avg       0.75      0.76      0.75        35
"""

# Render the sample report as a heatmap and write it to disk
save_classification_report_image(report=report, filename='classification_report.png')