# Imports

In [1]:
#%pip install medmnist
#%pip install einops
#%pip install torcheval
import albumentations as A
import numpy as np
import torch
from torchvision.transforms import Resize, ToTensor
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import accuracy_score, roc_auc_score, f1_score
from torch import nn

from tqdm import tqdm

import os
import pickle
import torch.optim as optim

from medmnist import PneumoniaMNIST, RetinaMNIST, ChestMNIST
import time

from einops import rearrange
from einops import repeat
from einops.layers.torch import Rearrange
from medmnist import PneumoniaMNIST
from medmnist import PneumoniaMNIST, RetinaMNIST, ChestMNIST
from random import random
from torch import Tensor
from torch import nn
from torch.utils.data import DataLoader, random_split
from torcheval.metrics import MulticlassAccuracy
from torchvision.transforms import Resize, ToTensor
from torchvision.transforms.functional import to_pil_image
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import os
import pickle
import tensorflow as tf
import time
import torch
import torch.nn.functional as F
import torch.optim as optim
import torcheval.metrics

# Preprocessing

In [2]:
class Compose(object):
    """Chain several single-argument transforms into one callable."""

    def __init__(self, transforms):
        # Iterable of callables; they are applied in the order given.
        self.transforms = transforms

    def __call__(self, image):
        """Pass ``image`` through every transform in turn and return the result."""
        result = image
        for transform in self.transforms:
            result = transform(result)
        return result

class MedMNISTDataset(Dataset):
    """Dataset wrapper adding transforms, on-the-fly augmentation and class balancing.

    Args:
        dataset: Callable invoked as
            ``dataset(split=dataset_type, download=True, size=img_size, as_rgb=True)``
            that returns an iterable, indexable collection of ``(image, label)``
            pairs (e.g. a MedMNIST dataset class).
        transform: Callable applied to each image. When ``augment_data`` is
            falsy it is called once per image as ``transform(image)`` at
            construction time. When ``augment_data`` is truthy it is called as
            ``transform(image=array)["image"]`` on every ``__getitem__``.
        dataset_type: Value forwarded as ``split``.
        img_size: Value forwarded as ``size``.
        nSamples: Requested total size after balancing (0 keeps the size of
            the largest class as the per-class target).
        augment_data: Enables per-access augmentation instead of the one-off
            transform.
        balance_classes: When true, ``BalanceClasses`` is run on construction.

    Note:
        ``num_classes`` is only set by ``BalanceClasses``.
    """

    def __init__(self, dataset, transform=None, dataset_type='train', img_size=224, nSamples=0,augment_data=False, balance_classes=False):
        self.dataset = dataset(split=dataset_type, download=True, size=img_size, as_rgb=True)
        self.transform = transform
        self.augment_data = augment_data
        self.increaseSize = nSamples
        # Without augmentation the transform is deterministic per image, so it
        # is applied once up front instead of on every access.
        if not augment_data:
            self.Transform()

        if balance_classes:
            self.BalanceClasses(nSamples, verbose=True)

    def Transform(self):
        """Materialise the dataset as a list, applying ``self.transform`` if set."""
        if self.transform is None:
            self.dataset = [(image, label) for image, label in self.dataset]
        else:
            self.dataset = [(self.transform(image), label) for image, label in self.dataset]

    def BalanceClasses(self, increaseSize=0, verbose=False):
        '''
        Balance the classes in the dataset by resampling the minority classes. For the best effect
        augmentation should be enabled. Otherwise the same image will be duplicated within the dataset.

        The class of a sample is the first entry of its label (``label[0]``).

        Args:
            increaseSize: Requested total number of samples. The per-class
                target is ``int(increaseSize / num_classes)`` when that is at
                least the size of the largest class, otherwise the size of
                the largest class.
            verbose: Print the class counts before and after balancing.

        Sets ``self.num_classes`` and replaces ``self.dataset`` with a list.
        '''
        print('Balancing classes...')

        # Walk the underlying dataset exactly once: keep the samples and
        # remember which positions belong to which class.
        samples = [(image, label) for image, label in self.dataset]
        indices_by_class = {}
        for idx, (_, label) in enumerate(samples):
            indices_by_class.setdefault(label[0], []).append(idx)

        num_samples = {cls: len(indices) for cls, indices in indices_by_class.items()}
        self.num_classes = len(num_samples)
        if verbose:
            print(f'Before balancing: {num_samples} | Num classes: {self.num_classes}')

        # Nothing to balance in an empty dataset.
        if not samples:
            self.dataset = samples
            if verbose:
                print('After balancing: ', {})
            return

        # Per-class target: the largest class, or more if a bigger total was requested.
        max_samples = max(num_samples.values())
        if increaseSize > 0:
            per_class = int(increaseSize / len(num_samples))
            if per_class >= max_samples:
                max_samples = per_class
                print(f'Increasing size to {max_samples} samples per class.')

        # Start from every original sample, then top up the minority classes
        # by drawing (with replacement) from their own members.
        balanced_dataset = list(samples)
        for class_indices in indices_by_class.values():
            num_to_add = max_samples - len(class_indices)
            if num_to_add <= 0:
                continue
            selected_indices = np.random.choice(class_indices, num_to_add, replace=True)
            balanced_dataset.extend(samples[idx] for idx in selected_indices)
        self.dataset = balanced_dataset

        # Control the balance
        if verbose:
            balanced_counts = {}
            for _, label in self.dataset:
                balanced_counts[label[0]] = balanced_counts.get(label[0], 0) + 1
            print('After balancing: ', balanced_counts)

    def __len__(self):
        """Return the number of samples currently held."""
        return len(self.dataset)

    def __getitem__(self, idx):
        '''
        Get an item from the dataset, if augmentation is enabled, augment the data.

        Images that are already tensors are returned unchanged. Otherwise the
        image is converted to an array, passed through ``self.transform`` (if
        one was given) and converted to a tensor when the transform did not
        already produce one.
        '''
        image, label = self.dataset[idx]
        if self.augment_data and not isinstance(image, torch.Tensor):
            image = np.asarray(image)
            if self.transform is not None:
                image = self.transform(image=image)["image"]
            if not isinstance(image, torch.Tensor):
                image = ToTensor()(image)
        return image, label

# ViT

In [3]:
class Attention(nn.Module):
    """Multi-head self-attention over a batch-first token sequence.

    Args:
        dim: Embedding size of every token.
        heads: Number of attention heads.
        dropout: Dropout probability handed to ``nn.MultiheadAttention``.
    """

    def __init__(self, dim, heads, dropout):
        super().__init__()
        self.heads = heads
        # batch_first=True: the ViT in this file passes (batch, tokens, dim).
        # With the default layout the first axis would be treated as the
        # sequence, i.e. attention would be computed across the batch.
        self.att = torch.nn.MultiheadAttention(embed_dim=dim, num_heads=heads, dropout=dropout, batch_first=True)
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)

    def forward(self, x):
        """Return the attended tokens for ``x`` (same shape as ``x``)."""
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        # Feed the projected query/key/value; the attention weights are not
        # used by any caller, so they are not computed.
        attn_output, _ = self.att(q, k, v, need_weights=False)
        return attn_output


class PatchEmbedding(nn.Module):
    """Cut an image batch into square patches and linearly embed each patch.

    Args:
        in_channels: Number of channels of the input images.
        patch_size: Side length of each square patch.
        embedding_size: Output size of the linear projection.
    """

    def __init__(self,
                 in_channels=3,
                 patch_size=8,
                 embedding_size=224
                 ):
        super().__init__()
        # Number of values in one flattened patch.
        flat_patch_dim = patch_size * patch_size * in_channels
        # (b, c, H, W) -> (b, patches, flattened patch)
        split_into_patches = Rearrange('b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size)
        embed_patch = nn.Linear(flat_patch_dim, embedding_size)
        self.projection = nn.Sequential(split_into_patches, embed_patch)

    def forward(self, x: Tensor) -> Tensor:
        """Return one embedding vector per patch of ``x``."""
        embedded = self.projection(x)
        return embedded


class Compose(object):
    """Chain torchvision-style and Albumentations-style transforms.

    Torchvision-style transforms are called as ``t(image)``.
    Albumentations-style transforms are called as ``t(image=image)`` and the
    ``"image"`` entry of the returned mapping is used.

    A transform is treated as Albumentations-style when it carries an
    ``is_albumentation`` attribute or when its class is defined in the
    ``albumentations`` package.
    """

    def __init__(self, transforms):
        # Iterable of callables; they are applied in the order given.
        self.transforms = transforms

    @staticmethod
    def _is_albumentation(transform):
        """Return True when ``transform`` must be called with ``image=...``."""
        if hasattr(transform, "is_albumentation"):
            return True
        module = getattr(type(transform), "__module__", "") or ""
        return module.split(".")[0] == "albumentations"

    def __call__(self, image):
        """Apply every transform in order and return the final image."""
        for t in self.transforms:
            if self._is_albumentation(t):
                # Albumentations takes keyword data and returns a dict.
                image = t(image=image)["image"]
            else:
                # Plain callable (e.g. a torchvision transform).
                image = t(image)
        return image


class PreNorm(nn.Module):
    """Apply layer normalisation to the input before calling a wrapped callable."""

    def __init__(self, dim, fn):
        super().__init__()
        # LayerNorm over the last axis of size `dim`.
        self.norm = nn.LayerNorm(dim)
        # Callable that receives the normalised input.
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(norm(x), **kwargs)``."""
        normalised = self.norm(x)
        return self.fn(normalised, **kwargs)


class FeedForward(nn.Sequential):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: Size of the input and output features.
        hidden_dim: Size of the intermediate features.
        dropout: Dropout probability used after both linear layers.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.):
        expand = nn.Linear(dim, hidden_dim)
        activation = nn.GELU()
        hidden_dropout = nn.Dropout(dropout)
        project_back = nn.Linear(hidden_dim, dim)
        output_dropout = nn.Dropout(dropout)
        # Order matters: it fixes both the computation and the sub-module indices.
        super().__init__(expand, activation, hidden_dropout, project_back, output_dropout)


class ResidualAdd(nn.Module):
    """Residual (skip) connection: returns ``fn(x) + x``."""

    def __init__(self, fn):
        super().__init__()
        # Callable whose output is added to its own input.
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(x, **kwargs) + x`` without modifying any tensor in place.

        An out-of-place add is used on purpose: an in-place ``+=`` would
        overwrite the output of ``fn`` (which is the caller's own tensor when
        ``fn`` returns its input) and is rejected by autograd for leaf tensors
        that require gradients.
        """
        return self.fn(x, **kwargs) + x


class ViT(nn.Module):
    """Vision Transformer classifier built from the blocks defined in this file.

    Args:
        channels: Number of channels in the input image (Grayscale = 1, RGB = 3).
        img_size: Side length used for both height and width.
        patch_size: Side length of the square patches the image is cut into.
        embedding_dim: Size of every token embedding.
        layers: Number of transformer blocks.
        out_dim: Number of outputs of the classification head.
        dropout: Dropout probability used in attention and feed-forward parts.
        heads: Number of attention heads per block.
    """

    def __init__(self, channels=3,
                 img_size=224,
                 patch_size=4,
                 embedding_dim=32,
                 layers=6,
                 out_dim=37,
                 dropout=0.1,
                 heads=2
                 ):
        super().__init__()

        # Bookkeeping attributes describing the expected input.
        self.channels = channels
        self.height = img_size
        self.width = img_size
        self.patch_size = patch_size
        self.n_layers = layers

        # Image -> sequence of patch embeddings.
        self.patch_embedding = PatchEmbedding(
            in_channels=channels,
            patch_size=patch_size,
            embedding_size=embedding_dim,
        )

        # Learnable position table (one slot per patch plus one for the class
        # token) and the class token itself.
        patches_per_side = img_size // patch_size
        num_patches = patches_per_side ** 2
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embedding_dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))

        # Encoder: each block is pre-norm attention followed by a pre-norm
        # feed-forward part, both wrapped in a residual connection.
        self.layers = nn.ModuleList([
            nn.Sequential(
                ResidualAdd(PreNorm(embedding_dim, Attention(embedding_dim, heads, dropout))),
                ResidualAdd(PreNorm(embedding_dim, FeedForward(embedding_dim, embedding_dim, dropout))),
            )
            for _ in range(layers)
        ])

        # Classification head applied to the class token.
        self.head = nn.Sequential(nn.LayerNorm(embedding_dim), nn.Linear(embedding_dim, out_dim))

    def forward(self, img):
        """Return the head output computed from the class token of ``img``."""
        tokens = self.patch_embedding(img)
        batch_size, num_tokens, _ = tokens.shape

        # Prepend one class token per sample, then add positional information.
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=batch_size)
        tokens = torch.cat([cls_tokens, tokens], dim=1)
        tokens += self.pos_embedding[:, :(num_tokens + 1)]

        for block in self.layers:
            tokens = block(tokens)

        # Only the class token (position 0) feeds the head.
        class_token = tokens[:, 0, :]
        return self.head(class_token)


# Optimiser

In [4]:
class ViT_Optimiser:
    """Train, validate, test and checkpoint a ``ViT`` on a MedMNIST dataset class.

    Best validation results per dataset are kept in a pickle file next to the
    model checkpoints (``Transformer/Models``). A checkpoint is written
    whenever the validation accuracy beats the stored one.

    Args:
        dataset: Dataset class handed to ``MedMNISTDataset``; its ``__name__``
            is used as the key for checkpoints and stored performance.
        img_size: Image side length requested from the dataset.
        augment_data: Use the Albumentations augmentation pipeline for the
            training split.
    """

    MODEL_DIRECTORY = 'Transformer/Models'
    PERFORMANCE_FILE = 'Performance.pkl'

    def __init__(self, dataset, img_size=224, augment_data=False):
        self.device = self._SelectDevice()
        print(f' > Using device: {self.device}\n')
        if self.device.type == "cpu":
            # Only warn: blocking on input() here would hang non-interactive runs.
            print("WARNING: MPS not available, using CPU instead.")

        # Parameters
        self.img_size = int(img_size)
        self.augment_data = augment_data

        # Load training and validation data
        self.LoadDatasets(dataset)
        self.dataset = dataset
        self.LoadPerformance()
        print(self.modelPerformance)
        dataset_name = str(dataset.__name__)
        if dataset_name not in self.modelPerformance:
            self.modelPerformance[dataset_name] = {'Training': {'Accuracy': 0, 'F1': 0}, 'Validation': {'Accuracy': 0, 'F1': 0}, 'Model': 'ViT', 'Loss function': 'CrossEntropyLoss'}
            print(self.modelPerformance)
            self.SavePerformance()

        # Define model; resume from a checkpoint when a usable one exists.
        self.model = ViT(out_dim=self.num_classes).to(self.device)
        try:
            self.LoadModel(dataset.__name__)
        except (OSError, RuntimeError, pickle.UnpicklingError) as err:
            # Missing file, unreadable file, or weights that do not fit this model.
            print(f"No model found, training new model... ({err})")

        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)

        self.trainingCriterion = nn.CrossEntropyLoss()

    @staticmethod
    def _SelectDevice():
        """Return the preferred torch device: MPS, then CUDA, then CPU."""
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device("mps")
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    @staticmethod
    def _FlattenLabels(labels):
        """Return the labels as a 1-D long tensor with one entry per sample.

        Reshaping by batch size (instead of ``squeeze()``) keeps the batch
        axis even when a batch holds a single sample.
        """
        return labels.reshape(labels.shape[0]).long()

    @staticmethod
    def _MeanLoss(losses):
        """Return the mean of ``losses`` or NaN when no batch was processed."""
        return float(np.mean(losses)) if losses else float('nan')

    def LoadDatasets(self, dataset):
        """Build the train/validation/test datasets and their data loaders.

        Sets ``training``, ``validation``, ``test``, the matching ``*_loader``
        attributes and ``num_classes`` (taken from the balanced training set).
        """
        if self.augment_data:
            print("Augmenting data...")
            trainingTransformer = A.Compose([
                    A.Rotate(limit=30, p=0.5),              # Rotate the image by up to 30 degrees with a probability of 0.5
                    A.RandomScale(scale_limit=0.2, p=0.5),  # Randomly scale the image by up to 20% with a probability of 0.5
                    A.RandomBrightnessContrast(p=0.5),      # Randomly adjust brightness and contrast with a probability of 0.5
                    A.GaussianBlur(p=0.5),                  # Apply Gaussian blur with a probability of 0.5
                    A.HorizontalFlip(p=0.5),                # Flip the image horizontally with a probability of 0.5
                    A.VerticalFlip(p=0.5),                  # Flip the image vertically with a probability of 0.5
                    A.GridDistortion(p=0.5),                # Apply grid distortion with a probability of 0.5
                    A.Resize(height=self.img_size, width=self.img_size),   # Resize the image to the desired size
                    A.Normalize(),                          # Normalize the image
                    ])
        else:
            trainingTransformer = Compose([
                Resize((self.img_size, self.img_size)),
                ToTensor()]
                )
        standardTransformer = Compose([Resize((self.img_size, self.img_size)), ToTensor()])
        self.training = MedMNISTDataset(dataset, transform=trainingTransformer, dataset_type='train', img_size=self.img_size, augment_data=self.augment_data, balance_classes=True)
        self.train_loader = DataLoader(self.training, batch_size=32, shuffle=True)
        self.num_classes = self.training.num_classes

        # Evaluation metrics do not depend on sample order, so no shuffling.
        self.validation = MedMNISTDataset(dataset, transform=standardTransformer, dataset_type='val', img_size=self.img_size)
        self.validation_loader = DataLoader(self.validation, batch_size=32, shuffle=False)

        self.test = MedMNISTDataset(dataset, transform=standardTransformer, dataset_type='test', img_size=self.img_size)
        self.test_loader = DataLoader(self.test, batch_size=32, shuffle=False)

    def _EvaluateLoader(self, loader):
        """Run the model over ``loader`` without updating it.

        The model is put in eval mode and gradients are disabled.

        Returns:
            ``(mean_loss, performance)`` where ``performance`` is the dict
            produced by ``EvaluatePerformance``.
        """
        self.model.eval()
        losses, outputs, targets = [], [], []
        with torch.no_grad():
            for inputs, labels in loader:
                inputs = inputs.to(self.device)
                labels = self._FlattenLabels(labels).to(self.device)
                output = self.model(inputs)
                losses.append(self.trainingCriterion(output, labels).item())
                outputs += output.cpu().tolist()
                targets += labels.cpu().tolist()
        return self._MeanLoss(losses), self.EvaluatePerformance(outputs, targets)

    def RunOptimiser(self, epochs):
        """Train for ``epochs`` epochs, validating after each, then report test results.

        After every epoch the model and its scores are saved when the
        validation accuracy exceeds the stored value for this dataset.
        """
        dataset_name = str(self.dataset.__name__)
        print(f"Running optimiser for {epochs} epochs on {dataset_name} dataset...")

        for epoch in range(epochs):
            # Optimise model on training data
            self.model.train()
            training_losses, training_outputs, training_targets = [], [], []
            for inputs, labels in tqdm(self.train_loader, desc=f"Epoch {epoch+1}", total=len(self.train_loader)):
                inputs = inputs.to(self.device)
                labels = self._FlattenLabels(labels).to(self.device)
                self.optimizer.zero_grad()
                output = self.model(inputs)
                loss = self.trainingCriterion(output, labels)
                loss.backward()
                self.optimizer.step()
                training_losses.append(loss.item())
                training_outputs += output.detach().cpu().tolist()
                training_targets += labels.cpu().tolist()

            trainingPerformance = self.EvaluatePerformance(training_outputs, training_targets)

            # Run model over validation data
            validation_loss, validationPerformance = self._EvaluateLoader(self.validation_loader)

            print(f'\nEpoch {epoch+1}/{epochs}\n')
            print(f'   Training set:\n')
            print(f'      - Loss: {self._MeanLoss(training_losses):.2f} | Accuracy: {trainingPerformance["Accuracy"]:.2f} | F1: {trainingPerformance["F1"]:.2f}\n')
            print(f'   Validation set:\n')
            print(f'      - Loss: {validation_loss:.2f} | Accuracy: {validationPerformance["Accuracy"]:.2f} | F1: {validationPerformance["F1"]:.2f}\n')

            # If model has better performance on validation set than previous runs, save model
            if validationPerformance["Accuracy"] > self.modelPerformance[dataset_name]['Validation']['Accuracy']:
                self.modelPerformance[dataset_name]['Training'] = trainingPerformance
                self.modelPerformance[dataset_name]['Validation'] = validationPerformance
                self.SavePerformance()
                self.SaveModel(self.dataset.__name__)

        # Final check on the held-out test split (loss is averaged over all batches).
        test_loss, testPerformance = self._EvaluateLoader(self.test_loader)
        print(f'   Test set:')
        print(f'      - Loss: {test_loss:.2f} | Accuracy: {testPerformance["Accuracy"]:.2f} | F1: {testPerformance["F1"]:.2f}')

    def EvaluatePerformance(self, output, labels):
        """Compute accuracy and macro F1 from per-class scores and true labels.

        Args:
            output: Sequence of per-sample score rows; the predicted class is
                the index of the largest score in each row.
            labels: Sequence of true class indices.

        Returns:
            ``{'Accuracy': ..., 'F1': ...}``; both are 0.0 when ``output`` is empty.
        """
        if len(output) == 0:
            return {'Accuracy': 0.0, 'F1': 0.0}
        predictions = np.argmax(np.asarray(output), axis=1)
        targets = np.asarray(labels)

        accuracy = accuracy_score(targets, predictions)
        f1 = f1_score(targets, predictions, average='macro')
        return {'Accuracy': accuracy, 'F1': f1}

    def SaveModel(self, filename):
        """Write the model weights to ``Transformer/Models/<filename>.pth``."""
        print("Saving model...")
        os.makedirs(self.MODEL_DIRECTORY, exist_ok=True)
        filepath = os.path.join(self.MODEL_DIRECTORY, filename + '.pth')

        torch.save(self.model.state_dict(), filepath)
        print(f"Model saved to '{filepath}'.")

    def LoadModel(self, filename):
        """Load model weights from ``Transformer/Models/<filename>.pth``.

        The weights are mapped onto ``self.device`` so a checkpoint written
        on another device type can still be read.
        """
        print("Loading model...")
        path = os.path.join(self.MODEL_DIRECTORY, filename + '.pth')
        self.model.load_state_dict(torch.load(path, map_location=self.device))
        print(f"Model loaded from '{path}'.")

    def SavePerformance(self):
        """Pickle ``self.modelPerformance``, creating the directory if needed."""
        os.makedirs(self.MODEL_DIRECTORY, exist_ok=True)
        path = os.path.join(self.MODEL_DIRECTORY, self.PERFORMANCE_FILE)
        with open(path, 'wb') as file:
            pickle.dump(self.modelPerformance, file)

    def LoadPerformance(self):
        """Read ``self.modelPerformance`` from disk; start empty when no file exists."""
        path = os.path.join(self.MODEL_DIRECTORY, self.PERFORMANCE_FILE)
        try:
            # NOTE: unpickling can execute code; only load files written by SavePerformance.
            with open(path, 'rb') as file:
                self.modelPerformance = pickle.load(file)
        except FileNotFoundError:
            self.modelPerformance = {}


# Tests

In [5]:
def RunViT_Test():
    """Smoke-test the ViT building blocks and the full model.

    Each stage prints the shape it produced and asserts that it matches the
    expected shape, so a "passed" message is only printed after a real check.
    The patch-embedding stage downloads PneumoniaMNIST at 224x224.

    Raises:
        AssertionError: If any stage produces an unexpected shape.
    """
    def check_shape(tensor, expected):
        # Print the shape, then fail loudly when it is not the expected one.
        print(tensor.shape)
        actual = tuple(tensor.shape)
        assert actual == expected, f'Expected shape {expected}, got {actual}'

    # Test normalisation class / layer
    print('\nTesting normalisation PreNorm class...')
    norm = PreNorm(64, Attention(64, 8, 0.1))
    check_shape(norm(torch.ones(10, 64, 64)), (10, 64, 64))
    print('Normalisation test passed.\n\n\n')

    # Test feed forward class
    print('Testing feed forward class...')
    ff = FeedForward(64, 128)
    check_shape(ff(torch.ones(10, 64, 64)), (10, 64, 64))
    print('Feed forward test passed.\n\n\n')

    # Test residual attention class
    print('Testing residual attention class...')
    residual_att = ResidualAdd(Attention(64, 8, 0.))
    check_shape(residual_att(torch.ones(10, 64, 64)), (10, 64, 64))
    print('Residual attention test passed.\n\n\n')

    # Test patch embedding
    print('Testing patch embedding...')
    to_tensor = [Resize((224, 224)), ToTensor()]
    dataset = PneumoniaMNIST(split='train', download=True, size=224, as_rgb=True,transform=Compose(to_tensor))
    sample_datapoint = torch.unsqueeze(dataset[0][0], 0)
    print("Initial shape: ", sample_datapoint.shape)
    print(sample_datapoint)
    embedding = PatchEmbedding()(sample_datapoint)
    print("Patches shape: ", embedding.shape)
    # Default PatchEmbedding: 8x8 patches -> (224 / 8) ** 2 = 784 tokens of size 224.
    assert tuple(embedding.shape) == (1, 784, 224), f'Unexpected patch embedding shape {tuple(embedding.shape)}'
    print('Patch embedding test passed.\n\n\n')

    # Test ViT model
    print('Testing ViT model...')
    model = ViT(out_dim=5)
    print(model)
    prediction = model(torch.rand(1, 3, 224, 224))
    print(prediction)
    assert tuple(prediction.shape) == (1, 5), f'Unexpected ViT output shape {tuple(prediction.shape)}'
    print('ViT model test passed.')

def GPUAccessTest():
    '''
    Report which devices TensorFlow and PyTorch can see, then create a small
    random tensor on the chosen PyTorch device ("mps" when available, else "cpu").
    Might only work on macOS.
    '''
    # TensorFlow: visible devices and version
    visible_devices = tf.config.list_physical_devices()
    print(f"\nTensorFlow has access to the following devices:\n{visible_devices}")
    print(f" > TensorFlow version: {tf.__version__}")

    # PyTorch: version and MPS backend status
    mps_backend = torch.backends.mps
    print(f' > Pytorch version', torch.__version__)
    print(f' > Is MPS built? {mps_backend.is_built()}')
    print(f' > Is MPS available? {mps_backend.is_available()}')

    if mps_backend.is_available():
        device = "mps"
    else:
        device = "cpu"
    print(f' > Using device: {device}\n')

    # Allocating on the device proves it is actually usable.
    x = torch.rand(size=(3,4)).to(device)
    print(f" > Tensor: {x}")

def RunOptimisationTest(dataset, augment,balance_classes, epochs=2):
    '''
    Build a ViT optimiser for the given dataset at 224x224 and train it.

    dataset: dataset class handed to ViT_Optimiser.
    augment: forwarded as augment_data.
    balance_classes: accepted but not used by this function.
    epochs: number of epochs to run.
    '''
    trainer = ViT_Optimiser(dataset, img_size=224, augment_data=augment)
    trainer.RunOptimiser(epochs)

def SaveModelTest():
    '''
    Train a ViT on RetinaMNIST for two epochs and save its weights to
    Transformer/Models/RetinaModel.pth.
    '''
    save_path = 'Transformer/Models/RetinaModel.pth'

    trainer = ViT_Optimiser(RetinaMNIST, augment_data=True)
    trainer.RunOptimiser(2)
    model = trainer.model

    # Make sure the target directory exists before writing the checkpoint.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(model.state_dict(), save_path)
    # Report the device the trainer actually used and the real save location.
    print(f'Model saved on device {trainer.device}. At location: {save_path}.')

def LoadModelTest():
    '''
    Load the weights saved at Transformer/Models/RetinaModel.pth into a fresh
    ViT and continue training it on RetinaMNIST for two epochs.
    '''
    # Build the trainer first: it knows the device and the number of classes
    # the checkpointed model was created with.
    trainer = ViT_Optimiser(RetinaMNIST)

    model = ViT(out_dim=trainer.num_classes).to(trainer.device)
    state = torch.load('Transformer/Models/RetinaModel.pth', map_location=trainer.device)
    model.load_state_dict(state)
    print('Model loaded.')

    trainer.model = model
    # The trainer's optimiser was created for its previous model; rebind it so
    # training updates the model that was just loaded.
    trainer.optimizer = optim.Adam(model.parameters(), lr=0.001)
    trainer.RunOptimiser(2)
    print('Model loaded and tested.')

def IntegratedSaveLoadTest(mode='save'):
    '''
    Test saving and loading a model using the integrated functions.

    mode: 'save' trains on RetinaMNIST and saves the model; 'load' loads the
          saved model and trains it further. Matching is case-insensitive.

    Raises ValueError for any other mode.
    '''
    requested = str(mode).lower()
    if requested == 'save':
        trainer = ViT_Optimiser(RetinaMNIST, augment_data=False, img_size=224)
        trainer.RunOptimiser(2)
        trainer.SaveModel(trainer.dataset.__name__)
        print('Model saved succesfully.')
    elif requested == 'load':
        trainer = ViT_Optimiser(RetinaMNIST, augment_data=True)
        trainer.LoadModel(trainer.dataset.__name__)
        trainer.RunOptimiser(2)
        print('Model loaded succesfully.')
    else:
        # An unknown mode used to do nothing at all; make the mistake visible.
        raise ValueError(f"Unknown mode {mode!r}; expected 'save' or 'load'.")

# Run

In [1]:
#RunOptimisationTest(PneumoniaMNIST, True, True, 150)