In [None]:
import torch, torchvision
from torch.utils.data import DataLoader

import numpy as np

import matplotlib.pyplot as plt
%matplotlib widget
# use widget in vscode

plt.set_loglevel('critical')

from sklearn.metrics import roc_auc_score

from importlib import reload # while updating code this is important
import models 
from models import train_and_eval

from IPython.display import clear_output, display

import gc
import os
import time

# Selects which user's hard-coded data directory is used by later cells
# ('abenneck' picks the abenneck path; any other value falls back to the dtward path).
user = 'abenneck'

In [None]:
# currently I'm not using this
def cutout(I,val=0.5):
    """Overwrite a random axis-aligned rectangle of an image with a constant.

    The rectangle is chosen by drawing two row indices and two column indices
    uniformly at random, sorting each pair, and using them as slice bounds over
    the last two axes of ``I``. All leading axes are filled identically.

    Parameters
    ----------
    I : array-like supporting ``.shape`` and slice assignment
        Image whose last two axes are (rows, cols). It is modified IN PLACE.
    val : scalar
        (Default - 0.5); Value written into the selected rectangle.

    Returns
    -------
    I : same object as the input
        The input, returned for convenience after the in-place edit. The
        rectangle may be empty (both drawn indices equal), leaving ``I`` unchanged.
    """
    n_rows, n_cols = I.shape[-2], I.shape[-1]

    # randint's upper bound is exclusive and a slice stop is exclusive too, so the
    # bounds must be drawn from [0, n] (hence n + 1). Drawing from [0, n) would make
    # it impossible for the rectangle to ever include the last row or column.
    row_start, row_stop = np.sort(np.random.randint(0, n_rows + 1, (2,)))
    col_start, col_stop = np.sort(np.random.randint(0, n_cols + 1, (2,)))

    I[..., row_start:row_stop, col_start:col_stop] = val
    return I

### Download data from the CIFAR-10 dataset

In [None]:
# Directory where datasets are downloaded/cached; any user other than
# 'abenneck' falls back to the dtward location.
data_save_path = '/home/abenneck/rotnet_work/data' if user == 'abenneck' else '/home/dtward/data'

# Per-channel mean / std used to normalize the input images
normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Training-time pipeline: random horizontal flip, random 32x32 crop after 4 pixels
# of edge padding, conversion to a tensor, then normalization.
cifar_transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomCrop(32, padding=4, padding_mode='edge'),
    torchvision.transforms.ToTensor(),
    # torchvision.transforms.Lambda(cutout)  # optional cutout augmentation, currently disabled
    normalize,
])

# CIFAR-10 training split (downloaded into data_save_path if it is not there yet)
my_dataset = torchvision.datasets.CIFAR10(data_save_path, train=True, download=True, transform=cifar_transform)

# Data loader for the CIFAR-10 training data
my_loader = DataLoader(my_dataset, shuffle=True, batch_size=64, num_workers=8)

### Download testing data from the CIFAR-10 dataset

In [None]:
# my_dataset_test = torchvision.datasets.CIFAR10(    
#     data_save_path,
#     train=False,
#     download=True,
#     transform=torchvision.transforms.Compose([
#                                               torchvision.transforms.ToTensor(),
#                                               normalize,
#                                               ])
# )
# my_loader_test = DataLoader(my_dataset_test, batch_size=64, num_workers=8, shuffle=True)

In [None]:
# Replace the earlier `normalize` with a mean = 0.5, std = 0.5 normalization.
normalize = torchvision.transforms.Normalize(mean=0.5, std=0.5)
# 0.5, 0.5 is used in medmnist code

In [None]:
# PIP: Preferred Installer Program; One of the main methods for installing Python packages
# !pip install medmnist

# from medmnist import DermaMNIST as Dataset
from medmnist import BloodMNIST as Dataset
#from medmnist import OrganAMNIST as Dataset
#from medmnist import PathMNIST as Dataset
#from medmnist import TissueMNIST as Dataset # really big
#from medmnist import BreastMNIST as Dataset # much smaller

In [None]:
# The medmnist code normalizes with mean = std = 0.5, so `normalize` is redefined to match
normalize = torchvision.transforms.Normalize(mean=0.5, std=0.5)

# Training split of the selected medmnist dataset (downloaded if needed);
# each image is converted to a tensor and then normalized.
train_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), normalize])
my_dataset = Dataset(split='train', download=True, transform=train_transform)

# Data loader for the training data
my_loader = DataLoader(my_dataset, shuffle=True, batch_size=128, num_workers=8)

In [None]:
# Validation split of the selected medmnist dataset, converted to normalized tensors
val_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), normalize])
my_dataset_val = Dataset(split='val', transform=val_transform)

# Data loader for the validation data
# NOTE(review): shuffling is kept as in the original; it is usually unnecessary for evaluation — confirm.
my_loader_val = DataLoader(my_dataset_val, shuffle=True, batch_size=128, num_workers=8)

In [None]:
# Test split of the selected medmnist dataset, converted to normalized tensors
test_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), normalize])
my_dataset_test = Dataset(split='test', transform=test_transform)

# Data loader for the test data
# NOTE(review): shuffling is kept as in the original; it is usually unnecessary for evaluation — confirm.
my_loader_test = DataLoader(my_dataset_test, shuffle=True, batch_size=128, num_workers=8)

# Gather every distinct label value that occurs in the test split and count them
# NOTE(review): a class absent from the test split would not be counted — confirm this cannot happen.
labels = set()
for _, batch_labels in my_loader_test:
    labels |= {li.squeeze().item() for li in batch_labels}
n_labels = len(labels)
print(n_labels)

# Shared settings for every training run below
nepochs = 5
device = 'cuda:0'

In [None]:
def test_model(net,my_loader, my_loader_val, my_loader_test, n0, device, nepochs, out_path='', verbose = False):
    """Train and evaluate a model (net) on the given train/val/test loaders, then optionally save the metrics.

    This is a thin wrapper around ``train_and_eval``: it runs training, annotates the
    returned metrics dictionary with bookkeeping values, frees cached memory and,
    if requested, writes the dictionary to disk.

    Parameters
    ----------
    net : torch.nn.Module
        A deep learning model defined using PyTorch's nn.Module class
    my_loader : torch.utils.data.DataLoader
        DataLoader that supplies the training data
    my_loader_val : torch.utils.data.DataLoader
        DataLoader that supplies the validation data
    my_loader_test : torch.utils.data.DataLoader
        DataLoader that supplies the testing data
    n0 : int
        Bookkeeping value stored unchanged under ``out['n0']``; it is not used for
        training here. Callers in this notebook pass the feature-map count they
        associate with ``net``.
    device : str
        The device on which torch computations will be performed
    nepochs : int
        The number of epochs to train the model
    out_path : str
        (Default - ''); /the/location/fname.npz where the output dictionary is saved.
        Missing parent directories are created. Nothing is saved when this is ''.
    verbose : bool
        (Default - False); If true, print the parameter count before training and
        key test metrics after training

    Returns
    -------
    out : dict
        The dictionary returned by ``train_and_eval`` with three entries set here:
        ``'n0'``, ``'n_parameters'`` and ``'net'`` (always ``None``).

    Notes
    -----
    ``np.savez(out_path, out)`` stores the dict as a pickled object array under the
    default key ``'arr_0'``; read it back with
    ``np.load(out_path, allow_pickle=True)['arr_0'].item()``.
    """
    # If verbose = True, print number of parameters in model
    if verbose:
        print(models.count_parameters(net))

    # Train net using provided data loaders and return dict of performance metrics
    out = train_and_eval(net,my_loader, my_loader_val, my_loader_test, device=device, nepochs=nepochs)

    # If verbose = True, print key test metrics at the epoch with the best validation
    # AUC and at the final epoch
    if verbose:
        auc_test, hard_auc_test, accuracy_test = out['auc_test'], out['hard_auc_test'], out['accuracy_test']
        # np.argmax works for lists and arrays alike (list.index(np.max(...)) only
        # works for lists) and matches how plot_evaluation picks the best epoch
        ind = int(np.argmax(out['auc_val']))
        print('best auc, hard auc, accuracy')
        print(auc_test[ind],hard_auc_test[ind], accuracy_test[ind])
        print('final auc, hard auc, accuracy')
        print(auc_test[-1],hard_auc_test[-1], accuracy_test[-1])

    # Annotate the metrics with bookkeeping values
    out['n0'] = n0
    out['n_parameters'] = models.count_parameters(net)
    # Always cleared, so whatever was stored under 'net' is never written to disk
    out['net'] = None

    # Clear local memory
    gc.collect()
    torch.cuda.empty_cache()

    # If out_path is provided, save outputs
    if out_path != '':
        # Create the target directory first so a missing folder cannot discard
        # the results of a finished training run
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        np.savez(out_path, out)

    return out

# 18 layer resnets from medmnist

## First the actual resnet

In [None]:
# Baseline ResNet18; n0 = 64 is the feature-map count recorded with its results
n0 = 64
out_path = '/home/abenneck/rotnet_work/outputs/metrics/resnet18_out.npz'
net = models.ResNet18(n1=n_labels)
resnet18_out = test_model(
    net, my_loader, my_loader_val, my_loader_test,
    n0=n0, device=device, nepochs=nepochs, out_path=out_path, verbose=True,
)

### Load the output data

In [None]:
# Reload the metrics dictionary saved at out_path (set by the preceding training cell).
# np.savez stored the dict as a pickled 0-d object array under the default key 'arr_0'.
# NOTE: allow_pickle=True can execute arbitrary code while loading; only use it on
# files written by this notebook.
# The context manager closes the underlying .npz file handle once the dict is extracted.
with np.load(out_path, allow_pickle=True) as saved:
    out = saved['arr_0'].item()

# Show which metrics are available
list(out)

## Next the rotnet with the same number of feature maps

In [None]:
# RotNet18 built with its default width; n0 = 63 is the value recorded with its results
n0 = 63
out_path = '/home/abenneck/rotnet_work/outputs/metrics/rotnet18_out.npz'
net = models.RotNet18(n1=n_labels)
rotnet18_out = test_model(
    net, my_loader, my_loader_val, my_loader_test,
    n0=n0, device=device, nepochs=nepochs, out_path=out_path, verbose=True,
)

In [None]:
# Reflection-equivariant variant of the "default" RotNet18.
# Per the original note the default width is 63 (close to the ResNet's 64, but less).
n0 = 63
out_path = '/home/abenneck/rotnet_work/outputs/metrics/refnet18_out.npz'
net = models.RotNet18(n1=n_labels, reflection=True)
refnet18_out = test_model(
    net, my_loader, my_loader_val, my_loader_test,
    n0=n0, device=device, nepochs=nepochs, out_path=out_path, verbose=True,
)

## Next the rotnet with the same number of independent feature maps

In [None]:
# The 18-layer architecture is the one used in medmnist.
# Match independent features instead of feature maps: 32 scalars + 32 vectors
# (medmnist's ResNet uses 64 in total), which the original note counts as 96 feature maps.
# This gives far fewer parameters.
n0 = 96
out_path = '/home/abenneck/rotnet_work/outputs/metrics/rotnet18_n096_out.npz'
net = models.RotNet18(n0=n0, n1=n_labels)
rotnet18_n096_out = test_model(
    net, my_loader, my_loader_val, my_loader_test,
    n0=n0, device=device, nepochs=nepochs, out_path=out_path, verbose=True,
)

In [None]:
# Same 96-feature-map RotNet18 configuration, now with reflection = True
n0 = 96
out_path = '/home/abenneck/rotnet_work/outputs/metrics/refnet18_n096_out.npz'
net = models.RotNet18(n0=n0, n1=n_labels, reflection=True)
refnet18_n096_out = test_model(
    net, my_loader, my_loader_val, my_loader_test,
    n0=n0, device=device, nepochs=nepochs, out_path=out_path, verbose=True,
)

## Last the rotnet with the same number of parameters

In [None]:
# The 18-layer architecture is the one used in medmnist.
# Original notes: n0 = 126 performs best but gives about 11 million parameters, while
# the ResNet has only about 10 million, so a somewhat smaller width is preferable.
# 123 would be pretty close; 120 is the value used here.
# (05/31/24): the original code used n0 = 126; changed to 120 based on the comments and file name.
n0 = 120
out_path = '/home/abenneck/rotnet_work/outputs/metrics/rotnet18_n0120_out.npz'
net = models.RotNet18(n0=n0, n1=n_labels)
rotnet18_n0120_out = test_model(
    net, my_loader, my_loader_val, my_loader_test,
    n0=n0, device=device, nepochs=nepochs, out_path=out_path, verbose=True,
)

In [None]:
# The 18-layer architecture is the one used in medmnist.
# For the reflection-equivariant net the original note says a width of n0 = 159 can be used.
# Reference points from the original notes: a RotNet18 with n0 = 126 has about
# 11 million parameters and the ResNet about 10 million.
n0 = 159
out_path = '/home/abenneck/rotnet_work/outputs/metrics/refnet18_n0159_out.npz'
net = models.RotNet18(n0=n0, n1=n_labels, reflection=True)
refnet18_n0159_out = test_model(
    net, my_loader, my_loader_val, my_loader_test,
    n0=n0, device=device, nepochs=nepochs, out_path=out_path, verbose=True,
)

# Now the 20 layer ones from cifar

## First the resnet

In [None]:
# ResNet20 as described in the ResNet paper for CIFAR-10.
# Per the original note: 16 channels at its input layer and about 267K parameters.
n0 = 16
out_path = '/home/abenneck/rotnet_work/outputs/metrics/resnet20_out.npz'
net = models.ResNet20(n1=n_labels)
resnet20_out = test_model(
    net, my_loader, my_loader_val, my_loader_test,
    n0=n0, device=device, nepochs=nepochs, out_path=out_path, verbose=True,
)

## Now the same number of feature maps

In [None]:
# The 20-layer architecture comes from the CIFAR experiments.
# A width of 15 is the default; per the original note this tends to be too small to do a good job.
# (05/31/24): Original script did not pass n0, so n0 was 16
n0 = 15
out_path = '/home/abenneck/rotnet_work/outputs/metrics/rotnet20_out.npz'
net = models.RotNet20(n0=n0, n1=n_labels)
rotnet20_out = test_model(
    net, my_loader, my_loader_val, my_loader_test,
    n0=n0, device=device, nepochs=nepochs, out_path=out_path, verbose=True,
)

In [None]:
# The 20-layer architecture comes from the CIFAR experiments.
# A width of 15 is the default; per the original note this tends to be too small to do a good job.
# Here reflection equivariance is enabled as well, which uses far fewer parameters.
# (05/31/24): Original script did not pass n0, so n0 was 16
n0 = 15
out_path = '/home/abenneck/rotnet_work/outputs/metrics/refnet20_out.npz'
net = models.RotNet20(n0=n0, n1=n_labels, reflection=True)
refnet20_out = test_model(
    net, my_loader, my_loader_val, my_loader_test,
    n0=n0, device=device, nepochs=nepochs, out_path=out_path, verbose=True,
)

## Now the same number of independent feature maps

In [None]:
# The 20-layer architecture comes from the CIFAR experiments.
# Rather than matching parameters, match the number of scalar + vector components:
# ResNet20 uses 16 channels, so use 8 scalars and 8 vectors, i.e. 24 channels.
n0 = 24
out_path = '/home/abenneck/rotnet_work/outputs/metrics/rotnet20_n024_out.npz'
net = models.RotNet20(n0=n0, n1=n_labels)
rotnet20_n024_out = test_model(
    net, my_loader, my_loader_val, my_loader_test,
    n0=n0, device=device, nepochs=nepochs, out_path=out_path, verbose=True,
)

In [None]:
# The 20-layer architecture comes from the CIFAR experiments.
# Rather than matching parameters, match the number of scalar + vector components:
# ResNet20 uses 16 channels, so use 8 scalars and 8 vectors, i.e. 24 channels.
# This run enables reflection equivariance.
n0 = 24
out_path = '/home/abenneck/rotnet_work/outputs/metrics/refnet20_n024_out.npz'
net = models.RotNet20(n0=n0, n1=n_labels, reflection=True)
refnet20_n024_out = test_model(
    net, my_loader, my_loader_val, my_loader_test,
    n0=n0, device=device, nepochs=nepochs, out_path=out_path, verbose=True,
)

## Now the same number of parameters

In [None]:
# The 20-layer architecture comes from the CIFAR experiments.
# Extra channels are added so that the parameter count matches.
# Original note: the target is fewer than about 260K parameters.
n0 = 30
out_path = '/home/abenneck/rotnet_work/outputs/metrics/rotnet20_n030_out.npz'
net = models.RotNet20(n0=n0, n1=n_labels)
rotnet20_n030_out = test_model(
    net, my_loader, my_loader_val, my_loader_test,
    n0=n0, device=device, nepochs=nepochs, out_path=out_path, verbose=True,
)

In [None]:
# The 20-layer architecture comes from the CIFAR experiments.
# Extra channels are added so that the parameter count matches.
# Original note: the target is fewer than about 260K parameters.
# This run enables reflection equivariance.
n0 = 39
out_path = '/home/abenneck/rotnet_work/outputs/metrics/refnet20_n039_out.npz'
net = models.RotNet20(n0=n0, n1=n_labels, reflection=True)
refnet20_n039_out = test_model(
    net, my_loader, my_loader_val, my_loader_test,
    n0=n0, device=device, nepochs=nepochs, out_path=out_path, verbose=True,
)

# plot the data here

In [None]:
# Show the keys stored in the most recently loaded metrics dictionary
list(out)

In [None]:
# Metric keys of interest: every measure for every data split
keys = [f'{measure}_{split}'
        for measure in ('accuracy', 'auc', 'hard_auc')
        for split in ('train', 'val', 'test')]

# Each plot label is paired with the metrics dictionary it describes.
# Notes from the original: rotnet18 defaults to 63 channels, so a slightly different
# width (e.g. 66) would not be very different; rotnet20 has 15 channels by default.
model_results = [
    # 18-layer networks
    ('resnet18_64', resnet18_out),
    ('rotnet18_63', rotnet18_out),
    ('rotnet18_96', rotnet18_n096_out),
    ('rotnet18_120', rotnet18_n0120_out),
    ('refnet18_63', refnet18_out),
    ('refnet18_96', refnet18_n096_out),
    ('refnet18_159', refnet18_n0159_out),
    # 20-layer networks
    ('resnet20_16', resnet20_out),
    ('rotnet20_15', rotnet20_out),
    ('rotnet20_24', rotnet20_n024_out),
    ('rotnet20_30', rotnet20_n030_out),
    ('refnet20_15', refnet20_out),
    ('refnet20_24', refnet20_n024_out),
    ('refnet20_39', refnet20_n039_out),
]
outputs = [result for _, result in model_results]
names = [name for name, _ in model_results]



In [None]:
def plot_evaluation(names,outputs,measure):
    """Bar-plot one test metric for every model in two side-by-side panels.

    The left panel shows, for each model, the test value of ``measure`` at the
    epoch where that model's validation value of ``measure`` was highest. The
    right panel shows the test value at the last epoch. Bars are colored by
    model family: red if the name contains 'rot', magenta if it contains 'ref',
    blue otherwise. Both panels share the same axes.

    Parameters
    ----------
    names : list of str
        One label per model, used as the bar labels and to pick the bar color
    outputs : list of dict
        One metrics dictionary per model, in the same order as ``names``. Each must
        provide the sequences ``f'{measure}_val'`` and ``f'{measure}_test'``.
    measure : str
        Metric prefix, e.g. 'accuracy', 'auc' or 'hard_auc'

    Returns
    -------
    fig : matplotlib.figure.Figure
    ax : array of matplotlib.axes.Axes
        The two panels (best-on-validation, last iteration)

    Raises
    ------
    ValueError
        If ``names`` and ``outputs`` differ in length
    """
    if len(names) != len(outputs):
        raise ValueError(f'names ({len(names)}) and outputs ({len(outputs)}) must have the same length')

    # Build both keys directly instead of deriving the validation key by
    # replacing 'test' in the test key, which breaks for measures containing 'test'
    test_key = f'{measure}_test'
    val_key = f'{measure}_val'
    rot_angle = 60

    colors = ['r' if 'rot' in n else ('m' if 'ref' in n else 'b') for n in names]

    # Gather the data before creating the figure so a bad key does not leave an
    # empty figure behind
    best = [out[test_key][np.argmax(out[val_key])] for out in outputs]
    last = [out[test_key][-1] for out in outputs]

    fig,ax = plt.subplots(1,2,sharey=True,sharex=True)

    panels = [(ax[0], best, 'best on validation'),
              (ax[1], last, 'last iteration')]
    for axis, values, caption in panels:
        axis.bar(names, values, color=colors)
        # Rotate the model names so long labels do not overlap
        for tick in axis.get_xticklabels():
            tick.set_rotation(rot_angle)
            tick.set_horizontalalignment('right')
        axis.set_title(f'{measure.upper()}, {caption}')

    # The y axis is shared, so a single set_ylim covering both panels is enough
    all_values = best + last
    ax[0].set_ylim(np.min(all_values)-0.01, np.max(all_values)+0.01)

    # Leave room at the bottom for the rotated tick labels
    fig.subplots_adjust(bottom=0.3)
    fig.supxlabel('Model Name')
    fig.supylabel('Performance Metric')

    return fig,ax

In [None]:
# Draw one two-panel comparison figure per performance measure
measures = ['accuracy', 'auc', 'hard_auc']
for measure in measures:
    fig, ax = plot_evaluation(names, outputs, measure=measure)

# comparing
There are a few ways we could compare models on "equal footing":

1. We could use the same number of feature maps.
1. We could use the same number of parameters
1. We could count vector components as one feature and then use the same number of feature maps.

# TODO
Make sure outputs are saved. File names should encode the architecture, the dataset the model was trained on, and which repeat it is.

In [None]:
def evaluate_all_models(Dataset, data_path, outdir, rep, device='cuda:0', nepochs=5):
    """Train, evaluate, save and plot every ResNet / RotNet variant on one dataset.

    Builds train/val/test loaders for ``Dataset``, trains each of the 14 model
    configurations in turn with ``test_model`` (which saves one .npz per model),
    and finally saves one comparison figure per performance measure.

    Parameters
    ----------
    Dataset : class
        Dataset class called as ``Dataset(root=..., download=True, split=..., transform=...)``
        with splits 'train', 'val' and 'test' (the medmnist classes are used in this
        notebook). Its ``__name__`` is embedded in every output file name.
    data_path : str
        Directory passed to ``Dataset`` as ``root``
    outdir : str
        Output directory. Metrics go to ``<outdir>/metrics`` and figures to
        ``<outdir>/figures``; both are created if missing.
    rep : int
        Repeat index, embedded in the metrics file names as ``r{rep}``
    device : str
        (Default - 'cuda:0'); The device on which torch computations will be performed
    nepochs : int
        (Default - 5); The number of epochs to train each model

    Returns
    -------
    None
        Results are written to disk only.
    """
    # Normalization with mean = std = 0.5, as used in the medmnist code
    normalize = torchvision.transforms.Normalize(mean=0.5, std=0.5)
    transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), normalize])

    def make_loader(split):
        # Build the dataset for one split and wrap it in a data loader
        dataset = Dataset(root=data_path, download=True, split=split, transform=transform)
        return DataLoader(dataset, batch_size=128, num_workers=8, shuffle=True)

    my_loader = make_loader('train')
    my_loader_val = make_loader('val')
    my_loader_test = make_loader('test')

    # Count the distinct label values that occur in the test split
    label_values = set()
    for _, batch_labels in my_loader_test:
        label_values.update(li.squeeze().item() for li in batch_labels)
    n_labels = len(label_values)

    # Create both output directories up front (including a missing outdir) so a
    # bad path fails before any training rather than after it
    outdir_metrics = os.path.join(outdir,'metrics')
    outdir_fig = os.path.join(outdir,'figures')
    os.makedirs(outdir_metrics, exist_ok=True)
    os.makedirs(outdir_fig, exist_ok=True)

    # One row per model, listed in training order:
    # (file stem, plot label, class in `models`, n0 recorded with the results, extra constructor kwargs)
    # Rows with empty kwargs rely on the class's default width.
    experiments_18 = [
        ('resnet18',       'resnet18_64',  'ResNet18', 64,  {}),
        ('rotnet18',       'rotnet18_63',  'RotNet18', 63,  {}),
        ('refnet18',       'refnet18_63',  'RotNet18', 63,  {'reflection': True}),
        ('rotnet18_n096',  'rotnet18_96',  'RotNet18', 96,  {'n0': 96}),
        ('refnet18_n096',  'refnet18_96',  'RotNet18', 96,  {'n0': 96, 'reflection': True}),
        ('rotnet18_n0120', 'rotnet18_120', 'RotNet18', 120, {'n0': 120}),
        ('refnet18_n0159', 'refnet18_159', 'RotNet18', 159, {'n0': 159, 'reflection': True}),
    ]
    experiments_20 = [
        ('resnet20',       'resnet20_16',  'ResNet20', 16,  {}),
        ('rotnet20',       'rotnet20_15',  'RotNet20', 15,  {'n0': 15}),
        ('refnet20',       'refnet20_15',  'RotNet20', 15,  {'n0': 15, 'reflection': True}),
        ('rotnet20_n024',  'rotnet20_24',  'RotNet20', 24,  {'n0': 24}),
        ('refnet20_n024',  'refnet20_24',  'RotNet20', 24,  {'n0': 24, 'reflection': True}),
        ('rotnet20_n030',  'rotnet20_30',  'RotNet20', 30,  {'n0': 30}),
        ('refnet20_n039',  'refnet20_39',  'RotNet20', 39,  {'n0': 39, 'reflection': True}),
    ]

    # Train and evaluate all models
    results = {}      # file stem -> metrics dictionary
    plot_labels = {}  # file stem -> label shown in the figures
    start = time.time()
    for depth, experiments in (('18', experiments_18), ('20', experiments_20)):
        print(f'Starting {depth} Layer Networks')
        for stem, plot_label, class_name, n0, extra_kwargs in experiments:
            # Each model is built right before it is trained
            net = getattr(models, class_name)(n1=n_labels, **extra_kwargs)
            out_path = os.path.join(outdir_metrics, f'{stem}_{Dataset.__name__}_r{rep}_out.npz')
            results[stem] = test_model(net, my_loader, my_loader_val, my_loader_test, n0, device, nepochs, out_path)
            plot_labels[stem] = plot_label
    print(f'Finished model training in {time.time() - start:.2f}s')

    # ======================================
    # ===== Compare model performances =====
    # ======================================

    # The figures group the rotnets before the refnets within each depth, which
    # differs from the training order above
    plot_order = ['resnet18', 'rotnet18', 'rotnet18_n096', 'rotnet18_n0120',
                  'refnet18', 'refnet18_n096', 'refnet18_n0159',
                  'resnet20', 'rotnet20', 'rotnet20_n024', 'rotnet20_n030',
                  'refnet20', 'refnet20_n024', 'refnet20_n039']
    outputs = [results[stem] for stem in plot_order]
    names = [plot_labels[stem] for stem in plot_order]
    measures = ['accuracy','auc','hard_auc']

    for measure in measures:
        fig,ax = plot_evaluation(names, outputs, measure)
        # NOTE(review): the figure name does not include `rep`, so figures from
        # different repeats overwrite each other — confirm whether that is intended.
        fig.savefig(os.path.join(outdir_fig,f'{Dataset.__name__}_{measure}_performance.png'))
    

In [None]:
# Choose the medmnist dataset to evaluate (exactly one import active):
# from medmnist import DermaMNIST as Dataset
# from medmnist import BloodMNIST as Dataset
from medmnist import OrganAMNIST as Dataset
# from medmnist import PathMNIST as Dataset
# from medmnist import TissueMNIST as Dataset # really big
# from medmnist import BreastMNIST as Dataset # much smaller

# Per-user dataset directory; any user other than 'abenneck' falls back to the dtward location
data_save_path = '/home/abenneck/rotnet_work/data' if user == 'abenneck' else '/home/dtward/data'

# NOTE(review): the output directory is the abenneck path regardless of `user` — confirm this is intended.
outdir = '/home/abenneck/rotnet_work/outputs/'

# Train every model variant on the chosen dataset; this is repeat number 0
evaluate_all_models(Dataset, data_path=data_save_path, outdir=outdir, rep=0)

In [None]:
# Rebuild the notebook-level loaders for the currently selected Dataset
data_path = data_save_path

# Normalization with mean = std = 0.5, as used in the medmnist code
normalize = torchvision.transforms.Normalize(mean=0.5, std=0.5)
tensor_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), normalize])

# Training data
my_dataset = Dataset(root=data_path, split='train', download=True, transform=tensor_transform)
my_loader = DataLoader(my_dataset, shuffle=True, batch_size=128, num_workers=8)

# Validation data
my_dataset_val = Dataset(root=data_path, split='val', download=True, transform=tensor_transform)
my_loader_val = DataLoader(my_dataset_val, shuffle=True, batch_size=128, num_workers=8)

# Test data
my_dataset_test = Dataset(root=data_path, split='test', download=True, transform=tensor_transform)
my_loader_test = DataLoader(my_dataset_test, shuffle=True, batch_size=128, num_workers=8)