# Scalable solution - ResNet50 Pretrained weights CNN

---

In this section, we focus on implementing a highly scalable solution.
Scalability in data science means being well equipped to deal with a large influx of data. How might we approach a situation where we have 10x as much data, or a constant stream of new data? How can we make sure that our model processes this data as efficiently as possible?

---

We theorise that, for our task, a pretrained deep convolutional neural network (CNN) will perform best. Our training set contains around 5000 images, which we feel may not be enough to train a powerful deep CNN from scratch, since such a network would likely have far more parameters than we have training examples. To learn the very complex relationships between the MRI images and their classes, we would likely need a large CNN with many hidden layers.

We therefore conclude that using a pretrained model could be a successful way to benefit from a deep network while having only our small training set of MRI images.

ResNet-50 is a 50-layer convolutional neural network: 49 convolutional layers and one fully connected layer (counting layers with weights and excluding the 1x1 projection shortcuts), with a max-pooling layer after the first convolution and a global average pooling layer before the classifier. PyTorch's torchvision library lets us load a set of weights for this network pretrained on ImageNet.
"The ImageNet project is a large visual database designed for use in visual object recognition software research. More than 14 million images have been hand-annotated by the project to indicate what objects are pictured." [1]
Note that we have not accessed the ImageNet dataset to check whether it contains any of the brain MRI images used in this project. If ImageNet contained even a small subset of our test images, our test accuracy would be invalid. We continue anyway, since in a real commercial or research setting the dataset would likely be private and so would not run this risk; nevertheless, it is something to consider carefully.
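A check of this kind could start with exact duplicates. The sketch below assumes local copies of both image collections stored as directories (any pretraining-data directory is hypothetical here), hashes every file's bytes with SHA-256, and reports content found in both. It only catches byte-identical files; resized or re-encoded copies would need a perceptual or embedding-based comparison instead. The same functions can also check our own training and test folders for overlap.

```python
import hashlib
from pathlib import Path


def file_hashes(root):
    """Map SHA-256 digest -> list of files under `root` with exactly that content."""
    hashes = {}
    for path in sorted(Path(root).rglob("*")):
        if path.is_file():
            digest = hashlib.sha256(path.read_bytes()).hexdigest()
            hashes.setdefault(digest, []).append(path)
    return hashes


def exact_overlap(dir_a, dir_b):
    """Return {digest: (files in dir_a, files in dir_b)} for content present in both."""
    a, b = file_hashes(dir_a), file_hashes(dir_b)
    return {digest: (a[digest], b[digest]) for digest in a.keys() & b.keys()}


# Example with the notebook's own folders:
# overlap = exact_overlap("Data/Training", "Data/Testing")
# print(f"{len(overlap)} images appear in both folders")
```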

```python
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split
import numpy as np

data_dir = "Data/Training"
test_dir = "Data/Testing"

# Define batch size and image dimensions
# (torchvision's ImageNet weights were trained at 224x224; 244 also works
# because ResNet ends in an adaptive average pooling layer)
batch_size = 64
img_height = 244
img_width = 244

# Augmentations and transforms for the training set
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(360),
    transforms.RandomResizedCrop((img_height, img_width), scale=(0.8, 1.0)),
    transforms.ColorJitter(brightness=0.1, contrast=(0.8, 1.2)),
    transforms.ToTensor(),
    # Map each channel from [0, 1] to [-1, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Resize and normalise only, for the test set
transform_test = transforms.Compose([
    transforms.Resize((img_height, img_width)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Load datasets (ImageFolder applies the transforms each time an image is read)
dataset = torchvision.datasets.ImageFolder(data_dir, transform=transform_train)
testset = torchvision.datasets.ImageFolder(test_dir, transform=transform_test)

# Split into training and validation sets (80/20).
# Note: both subsets share `dataset`, and therefore transform_train,
# so validation images are randomly augmented too.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_ds, val_ds = random_split(dataset, [train_size, val_size])

# Create data loaders (2 worker processes load and augment images in the background)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

# Get class names (one per subfolder of data_dir)
class_names = dataset.classes
```

This code defines the augmentations and transforms we apply to the datasets. On the training set we use random rotations, crops, flips and changes to brightness and contrast, which help the model generalise and focus on important features inside the image rather than unimportant ones such as skull shape, orientation or head size. On the test set we apply only the basic transforms needed to standardise the image format (resizing and normalisation), so our test accuracy reflects performance on unaltered images.
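The `ToTensor` and `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` steps shared by both pipelines amount to a simple per-channel formula, x' = (x / 255 - 0.5) / 0.5, which maps 8-bit pixel values from [0, 255] to [-1, 1]. The NumPy sketch below re-implements that arithmetic for illustration; it is not torchvision's own code. For reference, torchvision's ImageNet-pretrained weights were trained with per-channel mean (0.485, 0.456, 0.406) and standard deviation (0.229, 0.224, 0.225), which could be passed to `Normalize` instead.

```python
import numpy as np


def to_tensor_and_normalize(img_uint8, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Mimic ToTensor + Normalize on an HxWxC uint8 image; returns a CxHxW float32 array."""
    # ToTensor: scale [0, 255] -> [0, 1] and move channels first
    x = img_uint8.astype(np.float32).transpose(2, 0, 1) / 255.0
    # Normalize: subtract the per-channel mean, then divide by the per-channel std
    mean = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
    return (x - mean) / std


# A 1x2 RGB image: one black pixel and one white pixel
img = np.array([[[0, 0, 0], [255, 255, 255]]], dtype=np.uint8)
out = to_tensor_and_normalize(img)
print(out.shape)                   # (3, 1, 2)
print(out[:, 0, 0], out[:, 0, 1])  # [-1. -1. -1.] [1. 1. 1.]
```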

We use `ImageFolder` to load each dataset with its transforms. `ImageFolder` applies the transforms lazily, each time an image is read, rather than processing the whole dataset up front. If we collect more data in the future, we can load it with the same transforms and concatenate it with the existing dataset, with no need to reprocess the data we already have.
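A minimal sketch of adding new data this way uses `torch.utils.data.ConcatDataset`. To keep it self-contained, small `TensorDataset`s stand in for the existing `ImageFolder` dataset and for a hypothetical `ImageFolder` over a new directory of images loaded with the same transform. One assumption to check with real folders: `ImageFolder` assigns class indices from the sorted subfolder names, so the new directory needs the same class subfolders for the labels to line up.

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

# Stand-ins: 10 "existing" images with label 0 and 6 "new" images with label 1.
# In the notebook these would be ImageFolder datasets built with the same transform.
existing = TensorDataset(torch.zeros(10, 3, 8, 8), torch.zeros(10, dtype=torch.long))
new_data = TensorDataset(torch.ones(6, 3, 8, 8), torch.ones(6, dtype=torch.long))

# ConcatDataset only indexes into its parts; nothing is copied or reprocessed
combined = ConcatDataset([existing, new_data])
loader = DataLoader(combined, batch_size=4, shuffle=True)

print(len(combined))           # 16
print(combined[12][1].item())  # 1 (index 12 falls in the second dataset)
n_seen = sum(labels.numel() for _, labels in loader)
print(n_seen)                  # 16
```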

We hold out 20% of the training data as a validation set and monitor it during training to make sure we do not overfit to the training data. The validation loss also drives our early stopping condition, which matters when training on a large amount of data is slow: we can stop training at about the optimal point rather than spending computational resources on additional epochs.
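One caveat in the loading code: `random_split` is applied to a single `ImageFolder` built with `transform_train`, so the validation images also receive random augmentations, which makes the validation loss noisier than necessary. A minimal sketch of one common fix is to build the dataset twice, once per transform, and index both with the same seeded permutation via `Subset`. `ToyImageFolder` and `split_with_transforms` are hypothetical names; the toy class stands in for `ImageFolder` so the example runs without image files.

```python
import torch
from torch.utils.data import Dataset, Subset


class ToyImageFolder(Dataset):
    """Hypothetical stand-in for ImageFolder: item i is (transform(i), i % 2)."""

    def __init__(self, n, transform):
        self.n = n
        self.transform = transform

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return self.transform(idx), idx % 2


def split_with_transforms(make_dataset, train_tf, eval_tf, train_frac=0.8, seed=0):
    """Split one folder into train/val subsets that use different transforms."""
    train_view = make_dataset(train_tf)  # same files, training augmentations
    eval_view = make_dataset(eval_tf)    # same files, deterministic transforms
    n = len(train_view)
    # Seeded permutation: the split is reproducible and both views agree on it
    perm = torch.randperm(n, generator=torch.Generator().manual_seed(seed)).tolist()
    n_train = int(train_frac * n)
    return Subset(train_view, perm[:n_train]), Subset(eval_view, perm[n_train:])


toy_train, toy_val = split_with_transforms(
    lambda tf: ToyImageFolder(10, tf),
    train_tf=lambda x: ("augmented", x),
    eval_tf=lambda x: ("plain", x),
)
print(len(toy_train), len(toy_val))  # 8 2
print(toy_val[0][0][0])              # plain
```

With the notebook's objects, the call would be `split_with_transforms(lambda tf: torchvision.datasets.ImageFolder(data_dir, transform=tf), transform_train, transform_test)`.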

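We wrap the data in DataLoaders with our chosen batch size. Processing one batch at a time bounds how much memory is in use at once during training, so the batch size can be chosen to fit the available GPU memory; it also controls how noisy (stochastic) the gradient estimates used by our optimiser are. The loaders additionally use two worker processes (`num_workers=2`), which load and augment images in the background while the model trains.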
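The batch size fixes two quantities that are easy to compute by hand: the number of optimiser steps per epoch, ceil(N / batch_size) with the DataLoader's default `drop_last=False`, and the size of one input batch in memory. The sketch below works both out for the settings used here (batch size 64, 3x244x244 float32 images). The training log later in this section reports 72 steps per epoch, which by this formula implies between 4,545 and 4,608 training images. Activations and gradients inside ResNet-50 take far more memory than the input batch, so the second figure is only a lower bound.

```python
import math


def steps_per_epoch(n_samples, batch_size):
    """Optimiser steps in one epoch when the last, smaller batch is kept."""
    return math.ceil(n_samples / batch_size)


def input_batch_bytes(batch_size, channels=3, height=244, width=244, bytes_per_value=4):
    """Memory for one float32 input batch of shape (B, C, H, W)."""
    return batch_size * channels * height * width * bytes_per_value


batch_size = 64
# 72 steps per epoch is consistent with any training-set size from 4,545 to 4,608
print(steps_per_epoch(4545, batch_size), steps_per_epoch(4608, batch_size))  # 72 72
print(steps_per_epoch(4544, batch_size), steps_per_epoch(4609, batch_size))  # 71 73
print(f"{input_batch_bytes(batch_size) / 2**20:.1f} MiB per input batch")    # 43.6 MiB per input batch
```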

---

```python
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from tqdm import tqdm

# Set device to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load ResNet-50 with ImageNet weights and replace the final fully connected
# layer so that it outputs one score per class
resnet50 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
num_classes = len(class_names)
resnet50.fc = nn.Linear(resnet50.fc.in_features, num_classes)
resnet50 = resnet50.to(device)

# If multiple GPUs are available, split each batch across them
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs")
    resnet50 = nn.DataParallel(resnet50)

criterion = nn.CrossEntropyLoss()
# Small learning rate: fine-tune the pretrained weights rather than overwrite them
optimizer = optim.Adam(resnet50.parameters(), lr=0.0001)
```

```
Using 2 GPUs
```

In this section we look for GPUs in the system. Our model, and CNNs in general, can be greatly sped up by GPU processing.
We go further by checking for multiple GPUs; if more than one is available, we wrap the model in `nn.DataParallel`, which replicates the model on each GPU, splits every mini-batch across them and gathers the outputs, so that all of our GPU power is used.
Setting our device to the GPU lets us move the model and each batch of tensors onto the GPU during training, so that the computation happens there.
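To put numbers on how `nn.DataParallel` divides the work: it scatters each input batch along dimension 0, giving each GPU a roughly equal slice. The sketch below uses `torch.chunk` to show how a batch of 64 would divide over different GPU counts; it illustrates the idea rather than reproducing DataParallel's internals, and `per_gpu_batch_sizes` is a hypothetical helper. One consequence is that layers such as BatchNorm compute their batch statistics per GPU, on 32 images each in our two-GPU setup.

```python
import torch


def per_gpu_batch_sizes(batch_size, n_gpus):
    """Sizes of the slices a batch splits into along dimension 0."""
    batch = torch.arange(batch_size)  # stand-in: only the batch dimension matters
    return [chunk.shape[0] for chunk in torch.chunk(batch, n_gpus, dim=0)]


for n_gpus in (2, 3, 4):
    print(n_gpus, per_gpu_batch_sizes(64, n_gpus))
# 2 [32, 32]
# 3 [22, 22, 20]
# 4 [16, 16, 16, 16]
```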

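Setting a low learning rate for the Adam optimiser (0.0001, ten times smaller than PyTorch's default of 0.001) takes advantage of the pretrained weights: every layer is still trained, but only with small fine-tuning adjustments. A larger learning rate could take steps big enough to destroy the useful features learned on ImageNet, making it harder to find a good minimum for this deep CNN.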
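Why the learning rate controls the size of the adjustments so directly for Adam: with bias correction, Adam's first update to a weight is lr * g / (|g| + eps), which is almost exactly lr in magnitude whatever the scale of the gradient g. The plain-Python sketch below implements one Adam step for a single scalar weight, using PyTorch's default `betas=(0.9, 0.999)` and `eps=1e-8`, to show this; `adam_step` is a hypothetical helper for illustration, not PyTorch's implementation.

```python
import math


def adam_step(param, grad, m, v, t, lr=1e-4, betas=(0.9, 0.999), eps=1e-8):
    """One Adam update for a scalar weight; returns (new_param, m, v)."""
    beta1, beta2 = betas
    m = beta1 * m + (1 - beta1) * grad       # first-moment (mean) estimate
    v = beta2 * v + (1 - beta2) * grad ** 2  # second-moment estimate
    m_hat = m / (1 - beta1 ** t)             # bias corrections for step t
    v_hat = v / (1 - beta2 ** t)
    new_param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return new_param, m, v


# First step from a pretrained weight of 0.3, for very different gradient sizes
for grad in (0.001, 0.5, 20.0):
    new_w, _, _ = adam_step(0.3, grad, m=0.0, v=0.0, t=1, lr=1e-4)
    print(grad, round(0.3 - new_w, 8))  # the change is about 1e-4 each time
```

With `lr=0.001` the same first step would be ten times larger.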

---

```python
def train_model(model, criterion, optimizer, train_loader, val_loader, num_epochs, device, early_stopping_patience=5):
    best_val_loss = float("inf")
    patience_counter = 0
    for epoch in range(num_epochs):
        # Switch back to training mode every epoch (validation below sets eval mode)
        model.train()
        running_loss = 0.0
        running_corrects = 0

        # tqdm progress bar over the training batches
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

        for inputs, labels in pbar:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            # Accumulate sums so the epoch averages are weighted by batch size
            batch_corrects = torch.sum(preds == labels).item()
            running_loss += loss.item() * inputs.size(0)
            running_corrects += batch_corrects

            # Show the current batch's loss and accuracy on the progress bar
            pbar.set_postfix(loss=loss.item(), acc=batch_corrects / inputs.size(0))

        epoch_loss = running_loss / len(train_loader.dataset)
        epoch_acc = running_corrects / len(train_loader.dataset)

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")

        # Validation
        if val_loader:
            model.eval()
            val_running_loss = 0.0
            val_running_corrects = 0

            with torch.no_grad():
                for inputs, labels in val_loader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    val_running_loss += loss.item() * inputs.size(0)
                    val_running_corrects += torch.sum(preds == labels).item()

            val_epoch_loss = val_running_loss / len(val_loader.dataset)
            val_epoch_acc = val_running_corrects / len(val_loader.dataset)

            print(f"Validation Loss: {val_epoch_loss:.4f}, Validation Accuracy: {val_epoch_acc:.4f}")

            # Early stopping: count consecutive epochs without a new best validation loss
            if val_epoch_loss < best_val_loss:
                best_val_loss = val_epoch_loss
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= early_stopping_patience:
                print(f"Early stopping triggered after {early_stopping_patience} epochs without improvement in validation loss.")
                break

    # Save the weights from the last epoch run. Unwrap nn.DataParallel first so the
    # saved keys have no "module." prefix and load directly into a plain ResNet-50.
    to_save = model.module if isinstance(model, nn.DataParallel) else model
    torch.save(to_save.state_dict(), "best_ResNet50_model.pth")
    return model


# Train the model
num_epochs = 15
trained_model = train_model(resnet50, criterion, optimizer, train_loader, val_loader, num_epochs, device, early_stopping_patience=5)
```

```
Epoch 1/15: 100%|██████████| 72/72 [00:53<00:00,  1.35it/s, acc=0.76, loss=0.776]
Epoch 1/15, Loss: 0.5038, Accuracy: 0.8387
Validation Loss: 0.2927, Validation Accuracy: 0.9046

Epoch 2/15: 100%|██████████| 72/72 [00:51<00:00,  1.39it/s, acc=0.96, loss=0.0742]
Epoch 2/15, Loss: 0.3047, Accuracy: 0.8941
Validation Loss: 0.2039, Validation Accuracy: 0.9291

Epoch 3/15: 100%|██████████| 72/72 [00:52<00:00,  1.38it/s, acc=1, loss=0.0614]
Epoch 3/15, Loss: 0.1445, Accuracy: 0.9470
Validation Loss: 0.1171, Validation Accuracy: 0.9589

Epoch 4/15: 100%|██████████| 72/72 [00:51<00:00,  1.39it/s, acc=1, loss=0.0282]
Epoch 4/15, Loss: 0.1211, Accuracy: 0.9569
Validation Loss: 0.1621, Validation Accuracy: 0.9396

Epoch 5/15: 100%|██████████| 72/72 [00:52<00:00,  1.38it/s, acc=0.96, loss=0.101]
Epoch 5/15, Loss: 0.0966, Accuracy: 0.9643
Validation Loss: 0.0985, Validation Accuracy: 0.9659

Epoch 6/15: 100%|██████████| 72/72 [00:51<00:00,  1.39it/s, acc=0.92, loss=0.097]
Epoch 6/15, Loss: 0.0769, Accuracy: 0.9698
Validation Loss: 0.1033, Validation Accuracy: 0.9694

Epoch 7/15: 100%|██████████| 72/72 [00:52<00:00,  1.36it/s, acc=0.96, loss=0.106]
Epoch 7/15, Loss: 0.0812, Accuracy: 0.9731
Validation Loss: 0.0748, Validation Accuracy: 0.9781

Epoch 8/15: 100%|██████████| 72/72 [00:52<00:00,  1.36it/s, acc=0.92, loss=0.257]
Epoch 8/15, Loss: 0.0649, Accuracy: 0.9783
Validation Loss: 0.0995, Validation Accuracy: 0.9650

Epoch 9/15: 100%|██████████| 72/72 [00:51<00:00,  1.40it/s, acc=0.96, loss=0.133]
Epoch 9/15, Loss: 0.0560, Accuracy: 0.9810
Validation Loss: 0.0637, Validation Accuracy: 0.9825

Epoch 10/15: 100%|██████████| 72/72 [00:52<00:00,  1.38it/s, acc=1, loss=0.00544]
Epoch 10/15, Loss: 0.0522, Accuracy: 0.9814
Validation Loss: 0.1237, Validation Accuracy: 0.9633

Epoch 11/15: 100%|██████████| 72/72 [00:51<00:00,  1.39it/s, acc=0.96, loss=0.105]
Epoch 11/15, Loss: 0.0509, Accuracy: 0.9847
Validation Loss: 0.0700, Validation Accuracy: 0.9773

Epoch 12/15: 100%|██████████| 72/72 [00:52<00:00,  1.38it/s, acc=1, loss=0.0077]
Epoch 12/15, Loss: 0.0473, Accuracy: 0.9831
Validation Loss: 0.0799, Validation Accuracy: 0.9799

Epoch 13/15: 100%|██████████| 72/72 [00:52<00:00,  1.38it/s, acc=1, loss=0.0158]
Epoch 13/15, Loss: 0.0468, Accuracy: 0.9836
Validation Loss: 0.0511, Validation Accuracy: 0.9834

Epoch 14/15: 100%|██████████| 72/72 [00:53<00:00,  1.35it/s, acc=1, loss=0.0153]
Epoch 14/15, Loss: 0.0330, Accuracy: 0.9891
Validation Loss: 0.0614, Validation Accuracy: 0.9816

Epoch 15/15: 100%|██████████| 72/72 [00:52<00:00,  1.36it/s, acc=1, loss=0.00215]
Epoch 15/15, Loss: 0.0692, Accuracy: 0.9766
Validation Loss: 0.0549, Validation Accuracy: 0.9869
```

We can now implement our training loop. We use a tqdm progress bar to track progress within each epoch and check that training is going to plan. The most notable feature we implement here for scalability is early stopping: the loop tracks the best validation loss seen so far, and if the validation loss fails to improve on it for a predefined number of consecutive epochs (here `early_stopping_patience = 5`), training stops. This helps guard against overfitting and avoids wasting computational resources on unnecessary epochs over the data. In this run the best validation loss (0.0511) came at epoch 13, so early stopping was never triggered and all 15 epochs ran.
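The patience rule can be replayed offline on the validation losses from the log above. The plain-Python sketch below re-implements the same counter logic (`replay_early_stopping` is a hypothetical helper, not part of the notebook). It confirms that with a patience of 5 the run never stops early and the best epoch is 13, and shows that a patience of 3 would have stopped after epoch 12, before the best epoch was reached.

```python
def replay_early_stopping(val_losses, patience):
    """Return (stop_epoch or None, best_epoch so far), with 1-indexed epochs."""
    best_loss = float("inf")
    best_epoch = None
    counter = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch, counter = loss, epoch, 0
        else:
            counter += 1  # one more epoch without a new best
        if counter >= patience:
            return epoch, best_epoch
    return None, best_epoch


# Validation losses for epochs 1-15, copied from the training log
val_losses = [0.2927, 0.2039, 0.1171, 0.1621, 0.0985, 0.1033, 0.0748, 0.0995,
              0.0637, 0.1237, 0.0700, 0.0799, 0.0511, 0.0614, 0.0549]

print(replay_early_stopping(val_losses, patience=5))  # (None, 13)
print(replay_early_stopping(val_losses, patience=3))  # (12, 9)
```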

Furthermore, at the end of training we save the final model's weights. This allows us to reload the model in the future to make predictions, or to continue training it at a later date, perhaps with new data.
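One detail matters when reloading: when the model is wrapped in `nn.DataParallel`, its `state_dict()` keys gain a `module.` prefix, so they no longer match a plain `resnet50()`. The sketch below saves the unwrapped module instead, using a small `nn.Linear` as a stand-in for ResNet-50; `unwrap` and the file name are hypothetical. It runs on CPU, where `nn.DataParallel` simply wraps the module.

```python
import os
import tempfile

import torch
import torch.nn as nn


def unwrap(model):
    """Return the underlying module if the model is wrapped in DataParallel."""
    return model.module if isinstance(model, nn.DataParallel) else model


net = nn.Linear(4, 2)           # stand-in for the fine-tuned ResNet-50
wrapped = nn.DataParallel(net)  # as in the notebook when several GPUs are found

print(list(wrapped.state_dict()))          # ['module.weight', 'module.bias']
print(list(unwrap(wrapped).state_dict()))  # ['weight', 'bias']

# Save the unwrapped weights, then reload them into a fresh, unwrapped model
path = os.path.join(tempfile.mkdtemp(), "model_weights.pth")  # hypothetical file name
torch.save(unwrap(wrapped).state_dict(), path)
fresh = nn.Linear(4, 2)
fresh.load_state_dict(torch.load(path, map_location="cpu"))
```

For the notebook, the equivalent would be saving `unwrap(trained_model).state_dict()` and later loading it into a `resnet50` whose `fc` layer has been replaced in the same way.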

---

```python
def predict_on_test_set(model, test_loader, device):
    model.eval()
    true_labels = []
    predicted_labels = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            # Collect true and predicted labels as NumPy values on the CPU
            true_labels.extend(labels.numpy())
            predicted_labels.extend(preds.cpu().numpy())

    return true_labels, predicted_labels

# Make predictions on the test set
true_labels, predicted_labels = predict_on_test_set(trained_model, test_loader, device)
```

Now that the model has been trained, with a fixed maximum number of epochs and an early stopping condition, we generate predictions on the test set. These predictions let us calculate standard metrics for multiclass classification.

```python
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Evaluation metrics (precision, recall and F1 are support-weighted averages over classes)
accuracy = accuracy_score(true_labels, predicted_labels)
precision = precision_score(true_labels, predicted_labels, average='weighted')
recall = recall_score(true_labels, predicted_labels, average='weighted')
f1 = f1_score(true_labels, predicted_labels, average='weighted')

# Print evaluation metrics
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1-score: {f1:.4f}")
```

```
Accuracy: 0.9825
Precision: 0.9829
Recall: 0.9825
F1-score: 0.9825
```

The model performs very well on the test set: 98.25% accuracy, with weighted precision, recall and F1-score all close to 0.98.
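The precision, recall and F1-score above use scikit-learn's `average='weighted'`: each class's score is weighted by its number of true instances (its support). This explains why the reported recall equals the accuracy exactly: weighted recall is the sum over classes of (n_c / N)(TP_c / n_c) = (sum of TP_c) / N, which is the accuracy. The plain-Python sketch below computes weighted precision and recall from scratch on a small hypothetical set of labels (not the notebook's data); `weighted_precision_recall` is an illustrative helper, not scikit-learn's code.

```python
from collections import Counter


def weighted_precision_recall(y_true, y_pred):
    """Support-weighted precision and recall, as with average='weighted'."""
    support = Counter(y_true)    # true instances per class
    predicted = Counter(y_pred)  # predictions per class
    tp = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    n = len(y_true)
    precision = sum(support[c] / n * (tp[c] / predicted[c] if predicted[c] else 0.0)
                    for c in support)
    recall = sum(support[c] / n * (tp[c] / support[c]) for c in support)
    return precision, recall


# Hypothetical labels for three classes
y_true = [0, 0, 0, 0, 1, 1, 1, 2, 2, 2]
y_pred = [0, 0, 0, 0, 1, 1, 0, 2, 2, 1]

precision, recall = weighted_precision_recall(y_true, y_pred)
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print(round(precision, 4), round(recall, 4), accuracy)  # 0.82 0.8 0.8
```

Weighted precision (0.82) differs from accuracy here, while weighted recall matches it, just as in the test-set results.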

# Summary

To summarise how we have made this implementation focus on scalability:
* Reusable, lazy preprocessing: the augmentations and transforms are defined once as `transforms.Compose` pipelines and applied by `ImageFolder` each time an image is loaded. In the future, new data can be loaded with the same pipelines and concatenated with the existing dataset (for example with `ConcatDataset`), so nothing has to be reprocessed up front.
* Batched DataLoaders: we load the data in mini-batches of a defined size. This bounds the memory needed per training step, and since images are read from disk as they are needed, memory use does not grow with the size of the dataset.
* Efficient use of multiple GPUs: using PyTorch's built-in `nn.DataParallel`, we detect multiple GPUs and parallelise the model so that both are used at the same time. For larger numbers of GPUs, or multiple machines, PyTorch recommends `DistributedDataParallel` instead.
* Mini-batch stochastic optimisation: Adam updates the weights after every mini-batch rather than once per pass over the data, giving many cheap, noisy gradient updates per epoch. This typically converges faster than full-batch gradient descent, and the advantage grows as the amount of data grows.
* Early stopping condition: we stop training when the validation loss has not improved on its best value for a set number of epochs, saving computational resources.
* Model checkpointing: we save the model weights at the end of training, which allows us to reload the model at a later date for new predictions or further training.

In the future we would like to research in more detail how the parallel GPUs work together in this setting. We assume that PyTorch's built-in parallelism is efficient, but it would be useful to know whether there are additional ways we could have set up our data to achieve even better performance with multiple GPUs.

## References

[1] Wikipedia contributors. "ImageNet." *Wikipedia, The Free Encyclopedia*, 27 Mar. 2023. Web. 3 May 2023.