This notebook was motivated by the ResNet paper

https://arxiv.org/pdf/1512.03385.pdf

which is cited below as [2] (ResNetPaper).
Implementation: Oleh Bakumenko, University of Duisburg-Essen

# Imports

In [None]:
import sys
sys.path.append("../")
import os
import numpy as np
import time
import matplotlib.pyplot as plt
import torch, torch.nn as nn
import torchvision, torchvision.transforms as tt
from torch.multiprocessing import Manager
torch.multiprocessing.set_sharing_strategy("file_system")
from pathlib import Path

from utility import utils as uu
from utility.eval import evaluate_classifier_model
from utility.confusion_matrix import calculate_confusion_matrix


from utility.trainLoopClassifier import training_loop
from utility.plotImageModel import *


# Data augmentations

Data augmentation is a technique used to artificially increase the size of a dataset by transforming existing data points to create new, similar instances. This can help prevent overfitting in machine learning models, as well as improve their ability to generalize to unseen data. Common types of data augmentation include flipping, rotation, scaling, and adding noise to images.
We can build the augmentation pipeline with the torchvision.transforms module.


In [None]:
data_augments = torchvision.transforms.Compose([ 
    torchvision.transforms.RandomHorizontalFlip(p = .5),
    torchvision.transforms.RandomVerticalFlip(p = .5),
    torchvision.transforms.ColorJitter(brightness=(0.5,1.5), contrast=(1), hue=(-0.1,0.1)),
    #torchvision.transforms.RandomCrop((224, 224)), 
    ])


Load the dataset from utils

In [None]:
cur_path = Path("plots_and_graphs.ipynb")
parent_dir = cur_path.parent.absolute()
masterThesis_folder = str(parent_dir.parent.absolute())+'/'
data_dir = masterThesis_folder+"data/Clean_LiTS/"

cache_me = False
if cache_me is True:
    cache_mgr = Manager()
    cache_mgr.data = cache_mgr.dict()
    cache_mgr.cached = cache_mgr.dict()
    for k in ["train", "val", "test"]:
        cache_mgr.data[k] = cache_mgr.dict()
        cache_mgr.cached[k] = False
# function from utils, credit: Institute for Artificial Intelligence in Medicine. url: https://mml.ikim.nrw/
# dataset outputs a tensor image (dimensions [1,256,256]) and a tensor target (0, 1 or 2)

ds = uu.LiTS_Classification_Dataset(
    data_dir=data_dir,
    transforms=data_augments,
    verbose=True,
    cache_data=cache_me,
    cache_mgr=(cache_mgr if cache_me is True else None),
    debug=True,
)

# Hyperparameters

In [None]:
# Default settings
batch_size = 32
epochs = 50
device = ("cuda" if torch.cuda.is_available() else "cpu")
time_me  = True

In [None]:
# Dataloader
dl = torch.utils.data.DataLoader(
    dataset = ds, 
    batch_size = batch_size, 
    num_workers = 4, 
    shuffle = True, 
    drop_last = False, 
    pin_memory = True,
    persistent_workers = (not cache_me),
    prefetch_factor = 1
    )

ResNet (Residual Network) is a deep neural network architecture introduced in 2015 in [2]. It was designed to address the issue of vanishing gradients in very deep networks. ResNet is named as such because it utilizes residual connections (skip connections), which enable the flow of gradients from earlier layers to later layers, even in very deep networks.

The residual connections in ResNet involve adding the input of a layer to the output of a layer that is several layers deeper. This allows the network to more easily learn identity functions. This design helps prevent the issue of vanishing gradients and enables ResNet to train much deeper networks than was previously possible. This architecture has shown significant improvements in benchmarks compared to the earlier AlexNet model.
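
As a short sketch in the notation of [2], and ignoring the activation, a residual block with input $x$ and learned mapping $F$ computes the following (the Jacobian form is added here for illustration):

$$
y = F(x) + x, \qquad \frac{\partial y}{\partial x} = \frac{\partial F}{\partial x} + I
$$

The identity term $I$ gives the gradient a direct path through the block, even when $\partial F / \partial x$ is small. If $F(x) = 0$, the block reduces exactly to the identity function.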

The original ResNet was used in the ImageNet Challenge to classify 1000 classes. However, in our exercise, we only use 3 classes:
0: Image does not include the liver.
1: Liver is visible.
2: Liver is visible and a lesion is visible.

# ResNet 50

Compared with ResNet34, ResNet50 introduces a new structure called the bottleneck block, which consists of a sequence of [convolution -> batch normalization -> activation] repeated three times. The first 1x1 convolution reduces the number of channels, the 3x3 convolution operates on this reduced representation, and the final 1x1 convolution expands the number of channels again.
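
To illustrate why the bottleneck is cheap, the following sketch counts the convolution weights of one bottleneck block with the channel sizes 256 -> 64 -> 64 -> 256 used in this notebook. The comparison with two plain 3x3 convolutions at 256 channels is a hypothetical alternative chosen only for illustration, and `conv_weights` is a helper introduced here.

```python
def conv_weights(c_in, c_out, kernel_size):
    """Number of weights in a bias-free 2D convolution."""
    return c_in * c_out * kernel_size * kernel_size

bottleneck = (
    conv_weights(256, 64, 1)    # 1x1: reduce 256 -> 64 channels
    + conv_weights(64, 64, 3)   # 3x3 on the reduced representation
    + conv_weights(64, 256, 1)  # 1x1: expand 64 -> 256 channels
)
# Hypothetical alternative: two plain 3x3 convolutions at full width
two_plain_3x3 = 2 * conv_weights(256, 256, 3)

print(bottleneck)     # 69632
print(two_plain_3x3)  # 1179648
```

Under these assumptions the bottleneck needs roughly 17 times fewer convolution weights than the plain variant.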

To allow customization, the block takes the number of input channels (num_chans_in), the number of channels between the convolutions (n_chans_between), and the number of output channels (num_chans_out). Except in the first block of each stage, the number of input channels equals the number of output channels.

Downsampling is handled by a generalized solution. Instead of using two different blocks as in ResNet34, the block takes a boolean "downsample" and a "stride". At the beginning of resblocks3-5, where the resolution is halved, we set "downsample" to True and "stride" to 2. In the first block of resblocks2, "downsample" is True as well, because the number of channels changes from 64 to 256, but the stride stays at 1.
If no downsampling is needed, the skip path is set to nn.Identity().
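
The effect of the strides can be checked with a little arithmetic. The sketch below traces the spatial size of a 256x256 input through the network, using the kernel sizes, strides and paddings from the model code in this notebook and the usual floor rule for the output size of a convolution or pooling layer. The helper `conv_out` is introduced here for illustration only.

```python
def conv_out(size, kernel_size, stride, padding):
    """Spatial output size of a convolution or pooling layer (floor rule)."""
    return (size + 2 * padding - kernel_size) // stride + 1

size = 256                                                  # input height/width
size = conv_out(size, kernel_size=7, stride=2, padding=1)   # conv1
size = conv_out(size, kernel_size=3, stride=2, padding=1)   # pool2
sizes = {"stem": size}

stages = [("resblocks2", 1), ("resblocks3", 2), ("resblocks4", 2), ("resblocks5", 2)]
for name, stride in stages:
    # 3x3 convolution on the main path of the first block in the stage
    main = conv_out(size, kernel_size=3, stride=stride, padding=1)
    # 1x1 convolution on the skip path of the same block
    shortcut = conv_out(size, kernel_size=1, stride=stride, padding=0)
    assert main == shortcut  # both paths must agree before the addition
    size = main
    sizes[name] = size

print(sizes)
# {'stem': 63, 'resblocks2': 63, 'resblocks3': 32, 'resblocks4': 16, 'resblocks5': 8}
```

With these settings the main path and the skip path always produce the same spatial size, so the addition in the block is well defined.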

A few words about the torch.nn.init part:
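By default, PyTorch draws convolution weights from a uniform distribution, while batch norm layers start with weight 1 and bias 0. Here, the convolution weights are instead drawn from a normal distribution (Kaiming initialization for ReLU), and the batch norm weights are set to 0.5 with zero bias. This helps the model backpropagate gradients in early epochs.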

Tests were conducted on smaller models with 18, 34, and 50 layers, indicating that for adaptive optimizers, weight and bias initialization has minimal effect on model performance or convergence.

In contrast, the ResNet 152 model with the default uniform initialization converged poorly after 15 epochs, with very high error and low accuracy. Although the normal initialization improved performance, the model still required hyperparameter tuning and a better optimizer.
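
For orientation, the following sketch computes the standard deviation that Kaiming normal initialization uses for a ReLU network in the default fan-in mode, namely sqrt(2 / fan_in) with fan_in = in_channels * kernel_height * kernel_width. The helper `kaiming_normal_std` is hypothetical and only reproduces this formula; the channel sizes are examples taken from the blocks in this notebook.

```python
import math

def kaiming_normal_std(in_channels, kernel_size):
    """Std of Kaiming normal init for ReLU, fan-in mode: sqrt(2 / fan_in)."""
    fan_in = in_channels * kernel_size * kernel_size
    return math.sqrt(2.0 / fan_in)

std_1x1 = kaiming_normal_std(in_channels=256, kernel_size=1)  # 1x1 conv, 256 input channels
std_3x3 = kaiming_normal_std(in_channels=64, kernel_size=3)   # 3x3 conv, 64 input channels

print(f"{std_1x1:.4f}")  # 0.0884
print(f"{std_3x3:.4f}")  # 0.0589
```

Layers with a larger fan-in receive smaller initial weights, which keeps the variance of the activations roughly constant from layer to layer.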

In [None]:
# ResBlockBottleneck Class
#       - constructs a block [conv -> batch_norm -> activation]*3, which we will stack in the network
# Input:    int: num_chans_in, int: n_chans_between, int: num_chans_out
#           boolean: downsample = False, set True for the first block of a stage (1x1 conv on the skip path)
#           int: stride = 1, set 2 to halve the spatial resolution
# Output:   nn.Module computing block(x) + downsample(x)
class ResBlockBottleneck(nn.Module):
    def __init__(self, num_chans_in,n_chans_between,num_chans_out, downsample = False, stride = 1):
        super().__init__()
        self.conv1 = nn.Conv2d(num_chans_in, n_chans_between, kernel_size=1, padding=0, bias=False)
        self.batch_norm1 = nn.BatchNorm2d(num_features=n_chans_between)
        # ReLU has no parameters, so a single instance is reused after every batch norm
        self.relu = torch.nn.ReLU()
        self.conv2 = nn.Conv2d(n_chans_between, n_chans_between, kernel_size=3, stride= stride, padding=1, bias=False)
        self.batch_norm2 = nn.BatchNorm2d(num_features=n_chans_between)
        self.conv3 = nn.Conv2d(n_chans_between, num_chans_out, kernel_size=1, padding=0, bias=False)
        self.batch_norm3 = nn.BatchNorm2d(num_features=num_chans_out)

        torch.nn.init.kaiming_normal_(self.conv1.weight,
                                      nonlinearity='relu')
        torch.nn.init.kaiming_normal_(self.conv2.weight,
                                      nonlinearity='relu')
        torch.nn.init.kaiming_normal_(self.conv3.weight,
                                      nonlinearity='relu')

        torch.nn.init.constant_(self.batch_norm1.weight, 0.5)
        torch.nn.init.zeros_(self.batch_norm1.bias)

        torch.nn.init.constant_(self.batch_norm2.weight, 0.5)
        torch.nn.init.zeros_(self.batch_norm2.bias)

        torch.nn.init.constant_(self.batch_norm3.weight, 0.5)
        torch.nn.init.zeros_(self.batch_norm3.bias)

        if downsample:
            self.downsample = nn.Sequential(
                nn.Conv2d(num_chans_in, num_chans_out, kernel_size=1,padding=0,stride=stride, bias=False),
                nn.BatchNorm2d(num_features=num_chans_out),
                nn.ReLU()
            )
        else:
            self.downsample = nn.Identity()

    def forward(self, x):
        out = self.conv1(x)
        out = self.batch_norm1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.batch_norm2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.batch_norm3(out)
        out = self.relu(out)
        return out + self.downsample(x)

# ResNetMLMed50 Class
#       - constructs a ResNet50 as described in [2, Table 1].
# Input:    Tensor: [Batch,1,Height,Width]
# Output:   Tensor: [Batch,3]
class ResNetMLMed50(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels = 1, out_channels = 64, kernel_size =7, stride =2, padding=1)
        self.batch_norm1 = nn.BatchNorm2d(num_features=64)
        self.relu = torch.nn.ReLU()
        self.pool2 = torch.nn.MaxPool2d(kernel_size = 3, stride = 2,padding=1)
        # Note: `n * [block]` would repeat the SAME module instance and share its weights,
        # so each repeated block is created separately in a list comprehension.
        self.resblocks2 = nn.Sequential(
            ResBlockBottleneck(num_chans_in = 64,n_chans_between =64 ,num_chans_out=256,downsample= True),
            *[ResBlockBottleneck(num_chans_in = 256,n_chans_between=64,num_chans_out= 256) for _ in range(2)])
        self.resblocks3 = nn.Sequential(
            ResBlockBottleneck(num_chans_in = 256, n_chans_between=128, num_chans_out= 512, downsample=True, stride=2),
            *[ResBlockBottleneck(num_chans_in = 512,n_chans_between=128,num_chans_out= 512) for _ in range(3)])
        self.resblocks4 = nn.Sequential(
            ResBlockBottleneck(num_chans_in = 512, n_chans_between=256, num_chans_out= 1024,downsample=True, stride=2),
            *[ResBlockBottleneck(num_chans_in = 1024,n_chans_between=256,num_chans_out= 1024) for _ in range(5)])
        self.resblocks5 = nn.Sequential(
            ResBlockBottleneck(num_chans_in = 1024, n_chans_between=512, num_chans_out= 2048,downsample=True, stride=2),
            *[ResBlockBottleneck(num_chans_in = 2048,n_chans_between=512,num_chans_out= 2048) for _ in range(2)])
        self.avgpool6 = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.fc = nn.Linear(in_features=2048, out_features=3, bias=True)



    def forward(self, x):
        out_1 = self.conv1(x)
        out_1 = self.batch_norm1(out_1)
        out_1 = self.relu(out_1)
        out_1 = self.pool2(out_1)

        out_2 = self.resblocks2(out_1)

        out_3 = self.resblocks3(out_2)

        out_4= self.resblocks4(out_3)

        out_5= self.resblocks5(out_4)

        out_6 = self.avgpool6(out_5)

        out_6= self.fc(torch.flatten(out_6, start_dim=1))
        return out_6

In [None]:
model = ResNetMLMed50()
model = model.to(device)

In [None]:
learning_rate = 1e-5
run_name = "ResNet50_fixed_time_lre5"

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
criterion = nn.CrossEntropyLoss()

In [None]:
uu.csv_logger(
        logfile = f"../logs/{run_name}_hyperparams.csv",
        content = {"learning_rate": learning_rate, "batch_size": batch_size, "epochs": epochs},
        first= True,
        overwrite= True)

In [None]:
mod_step = 5000
wantToPrint = False
stop_bool = False
eval_test_10_min = False
eval_test_15_min = False
eval_test_20_min = False
skip_test_10_min = False
skip_test_15_min = False
skip_test_20_min = False

Modified training loop: the start time is saved, and the elapsed time is computed at the beginning of each epoch. Once the elapsed time exceeds 10, 15, or 20 minutes, the corresponding test-evaluation flag is set to True. Because the check runs only at epoch boundaries, each evaluation takes place at the start of the first epoch after its threshold has passed. To ensure that each test evaluation happens only once, the matching skip flag is set to True afterwards.

During the test evaluation, the dataset mode is switched to "test", the model is switched to evaluation mode, and the test accuracy, loss, confusion matrix, and per-class accuracy are calculated and saved.

The same procedure is repeated for each of the three time thresholds. After the 20-minute evaluation, training stops.
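
The flag logic can be summarized in a compact form. The following is a minimal sketch, not the code used for the runs in this notebook: `due_checkpoints` is a hypothetical helper, and the elapsed times are made-up values standing in for the epoch start times.

```python
def due_checkpoints(elapsed_seconds, done, thresholds_min=(10, 15, 20)):
    """Return thresholds (in minutes) that have passed and were not evaluated yet."""
    due = [t for t in thresholds_min if elapsed_seconds > t * 60 and t not in done]
    done.update(due)  # remember them so that each evaluation happens only once
    return due

done = set()
# Hypothetical elapsed times (in seconds) at the start of consecutive epochs
log = [due_checkpoints(t, done) for t in (0, 400, 800, 1000, 1300)]
print(log)  # [[], [], [10], [15], [20]]
```

If a single epoch spans more than one threshold, all pending thresholds become due at the next epoch start, which mirrors the behavior of the separate boolean flags.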

In [None]:
train_start = time.time()

num_steps = len(ds.file_names['train'])//batch_size

for epoch in range(epochs):
    time_elapsed = time.time() - train_start
    print(f"Time_elapsed: {time_elapsed/60 :.2f} min")
    if time_elapsed > 10*60:
        eval_test_10_min = True
    if time_elapsed > 15*60:
        eval_test_15_min = True
    if time_elapsed > 20*60:
        eval_test_20_min = True

    if stop_bool:
        print('Stop time')
        break

    if eval_test_10_min and not skip_test_10_min:
        print('Evaluate after first 10 min')
        with torch.no_grad():
            ds.set_mode("test")
            model.eval()
            test_accuracy, avg_test_loss = evaluate_classifier_model(model = model, dataloader = dl, device = device)
            confusion_matrix, acc = calculate_confusion_matrix(model=model, dataloader=dl, device=device)
            torch.save(confusion_matrix, f = 'confusion_matr_' + run_name+ '_10min' + '.pt')
            print(f"Evaluate after first 10 min: Test Loss: {avg_test_loss:.4f}\t Test Accuracy: {test_accuracy:.4f}, Confusion Matrix: \n{confusion_matrix}, Per-class Accuracy: {acc}")
            uu.csv_logger(
                logfile = f"../logs/{run_name}_test.csv",
                content = {"epoch": epoch,"test_phase": 1, "test_loss": avg_test_loss, "test_accuracy": test_accuracy, "time_elapsed": time_elapsed})
        skip_test_10_min = True

    if eval_test_15_min and not skip_test_15_min:
        print('Evaluate after first 15 min')
        with torch.no_grad():
            ds.set_mode("test")
            model.eval()
            test_accuracy, avg_test_loss = evaluate_classifier_model(model = model, dataloader = dl, device = device)
            confusion_matrix, acc = calculate_confusion_matrix(model=model, dataloader=dl, device=device)
            torch.save(confusion_matrix, f = 'confusion_matr_' + run_name+ '_15min' + '.pt')
            print(f"Evaluate after first 15 min: Test Loss: {avg_test_loss:.4f}\t Test Accuracy: {test_accuracy:.4f}, Confusion Matrix: \n{confusion_matrix}, Per-class Accuracy: {acc}")
            uu.csv_logger(
                logfile = f"../logs/{run_name}_test.csv",
                content = {"epoch": epoch,"test_phase": 2, "test_loss": avg_test_loss, "test_accuracy": test_accuracy, "time_elapsed": time_elapsed})
        skip_test_15_min = True

    if eval_test_20_min and not skip_test_20_min:
        print('Evaluate after first 20 min')
        with torch.no_grad():
            ds.set_mode("test")
            model.eval()
            test_accuracy, avg_test_loss = evaluate_classifier_model(model = model, dataloader = dl, device = device)
            confusion_matrix, acc = calculate_confusion_matrix(model=model, dataloader=dl, device=device)
            torch.save(confusion_matrix, f = 'confusion_matr_' + run_name+ '_20min' + '.pt')
            print(f"Evaluate after first 20 min: Test Loss: {avg_test_loss:.4f}\t Test Accuracy: {test_accuracy:.4f}, Confusion Matrix: \n{confusion_matrix}, Per-class Accuracy: {acc}")
            uu.csv_logger(
                logfile = f"../logs/{run_name}_test.csv",
                content = {"epoch": epoch,"test_phase": 3, "test_loss": avg_test_loss, "test_accuracy": test_accuracy, "time_elapsed": time_elapsed})
        stop_bool=True
        break

    # Go to train mode
    ds.set_mode("train")
    model.train()

    # Train loop
    for step, (data, targets) in enumerate(dl):

        # Manually drop last batch (this is for example relevant with BatchNorm)
        if step == num_steps - 1 and (epoch > 0 or ds.cache_data is False):
            continue

        # Train step: zero gradients, forward pass, compute loss, backward pass, optimizer step
        optimizer.zero_grad()
        data, targets = data.to(device), targets.to(device)
        predictions = model(data)
        loss = criterion(predictions, targets)
        loss.backward()
        optimizer.step()

    # Go to eval mode
    ds.set_mode("val")
    model.eval()

    # Validation loop
    val_accuracy, avg_val_loss = evaluate_classifier_model(model = model, dataloader = dl, device = device)
    print(f"Epoch [{epoch+1}/{epochs}]\t Val Loss: {avg_val_loss:.4f}\t Val Accuracy: {val_accuracy:.4f}")
    uu.csv_logger(
        logfile = f"../logs/{run_name}_val.csv",
        content = {"epoch": epoch, "val_loss": avg_val_loss, "val_accuracy": val_accuracy},
        first = (epoch == 0),
        overwrite = (epoch == 0)
            )

---

In [None]:
run_name = "ResNet50_fixed_time_lre4"
learning_rate = 1e-4

In [None]:
del model

In [None]:
model = ResNetMLMed50()
model = model.to(device)

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
criterion = nn.CrossEntropyLoss()

In [None]:
run_name

In [None]:
stop_bool = False
eval_test_10_min = False
eval_test_15_min = False
eval_test_20_min = False
skip_test_10_min = False
skip_test_15_min = False
skip_test_20_min = False

In [None]:
train_start = time.time()

num_steps = len(ds.file_names['train'])//batch_size

for epoch in range(epochs):
    time_elapsed = time.time() - train_start
    print(f"Time_elapsed: {time_elapsed/60 :.2f} min")
    if time_elapsed > 10*60:
        eval_test_10_min = True
    if time_elapsed > 15*60:
        eval_test_15_min = True
    if time_elapsed > 20*60:
        eval_test_20_min = True

    if stop_bool:
        print('Stop time')
        break

    if eval_test_10_min and not skip_test_10_min:
        print('Evaluate after first 10 min')
        with torch.no_grad():
            ds.set_mode("test")
            model.eval()
            test_accuracy, avg_test_loss = evaluate_classifier_model(model = model, dataloader = dl, device = device)
            confusion_matrix, acc = calculate_confusion_matrix(model=model, dataloader=dl, device=device)
            torch.save(confusion_matrix, f = 'confusion_matr_' + run_name+ '_10min' + '.pt')
            print(f"Evaluate after first 10 min: Test Loss: {avg_test_loss:.4f}\t Test Accuracy: {test_accuracy:.4f}, Confusion Matrix: \n{confusion_matrix}, Per-class Accuracy: {acc}")
            uu.csv_logger(
                logfile = f"../logs/{run_name}_test.csv",
                content = {"epoch": epoch,"test_phase": 1, "test_loss": avg_test_loss, "test_accuracy": test_accuracy, "time_elapsed": time_elapsed})
        skip_test_10_min = True

    if eval_test_15_min and not skip_test_15_min:
        print('Evaluate after first 15 min')
        with torch.no_grad():
            ds.set_mode("test")
            model.eval()
            test_accuracy, avg_test_loss = evaluate_classifier_model(model = model, dataloader = dl, device = device)
            confusion_matrix, acc = calculate_confusion_matrix(model=model, dataloader=dl, device=device)
            torch.save(confusion_matrix, f = 'confusion_matr_' + run_name+ '_15min' + '.pt')
            print(f"Evaluate after first 15 min: Test Loss: {avg_test_loss:.4f}\t Test Accuracy: {test_accuracy:.4f}, Confusion Matrix: \n{confusion_matrix}, Per-class Accuracy: {acc}")
            uu.csv_logger(
                logfile = f"../logs/{run_name}_test.csv",
                content = {"epoch": epoch,"test_phase": 2, "test_loss": avg_test_loss, "test_accuracy": test_accuracy, "time_elapsed": time_elapsed})
        skip_test_15_min = True

    if eval_test_20_min and not skip_test_20_min:
        print('Evaluate after first 20 min')
        with torch.no_grad():
            ds.set_mode("test")
            model.eval()
            test_accuracy, avg_test_loss = evaluate_classifier_model(model = model, dataloader = dl, device = device)
            confusion_matrix, acc = calculate_confusion_matrix(model=model, dataloader=dl, device=device)
            torch.save(confusion_matrix, f = 'confusion_matr_' + run_name+ '_20min' + '.pt')
            print(f"Evaluate after first 20 min: Test Loss: {avg_test_loss:.4f}\t Test Accuracy: {test_accuracy:.4f}, Confusion Matrix: \n{confusion_matrix}, Per-class Accuracy: {acc}")
            uu.csv_logger(
                logfile = f"../logs/{run_name}_test.csv",
                content = {"epoch": epoch,"test_phase": 3, "test_loss": avg_test_loss, "test_accuracy": test_accuracy, "time_elapsed": time_elapsed})
        stop_bool=True
        break

    # Go to train mode
    ds.set_mode("train")
    model.train()

    # Train loop
    for step, (data, targets) in enumerate(dl):

        # Manually drop last batch (this is for example relevant with BatchNorm)
        if step == num_steps - 1 and (epoch > 0 or ds.cache_data is False):
            continue

        # Train step: zero gradients, forward pass, compute loss, backward pass, optimizer step
        optimizer.zero_grad()
        data, targets = data.to(device), targets.to(device)
        predictions = model(data)
        loss = criterion(predictions, targets)
        loss.backward()
        optimizer.step()

    # Go to eval mode
    ds.set_mode("val")
    model.eval()

    # Validation loop
    val_accuracy, avg_val_loss = evaluate_classifier_model(model = model, dataloader = dl, device = device)
    print(f"Epoch [{epoch+1}/{epochs}]\t Val Loss: {avg_val_loss:.4f}\t Val Accuracy: {val_accuracy:.4f}")
    uu.csv_logger(
        logfile = f"../logs/{run_name}_val.csv",
        content = {"epoch": epoch, "val_loss": avg_val_loss, "val_accuracy": val_accuracy},
        first = (epoch == 0),
        overwrite = (epoch == 0)
            )

---

In [None]:
run_name = "ResNet50_fixed_lre3"
learning_rate = 1e-3

In [None]:
del model

In [None]:
model = ResNetMLMed50()
model = model.to(device)

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
criterion = nn.CrossEntropyLoss()

In [None]:
stop_bool = False
eval_test_10_min = False
eval_test_15_min = False
eval_test_20_min = False
skip_test_10_min = False
skip_test_15_min = False
skip_test_20_min = False

In [None]:
train_start = time.time()

num_steps = len(ds.file_names['train'])//batch_size

for epoch in range(epochs):
    time_elapsed = time.time() - train_start
    print(f"Time_elapsed: {time_elapsed/60 :.2f} min")
    if time_elapsed > 10*60:
        eval_test_10_min = True
    if time_elapsed > 15*60:
        eval_test_15_min = True
    if time_elapsed > 20*60:
        eval_test_20_min = True

    if stop_bool:
        print('Stop time')
        break

    if eval_test_10_min and not skip_test_10_min:
        print('Evaluate after first 10 min')
        with torch.no_grad():
            ds.set_mode("test")
            model.eval()
            test_accuracy, avg_test_loss = evaluate_classifier_model(model = model, dataloader = dl, device = device)
            confusion_matrix, acc = calculate_confusion_matrix(model=model, dataloader=dl, device=device)
            torch.save(confusion_matrix, f = 'confusion_matr_' + run_name+ '_10min' + '.pt')
            print(f"Evaluate after first 10 min: Test Loss: {avg_test_loss:.4f}\t Test Accuracy: {test_accuracy:.4f}, Confusion Matrix: \n{confusion_matrix}, Per-class Accuracy: {acc}")
            uu.csv_logger(
                logfile = f"../logs/{run_name}_test.csv",
                content = {"epoch": epoch,"test_phase": 1, "test_loss": avg_test_loss, "test_accuracy": test_accuracy, "time_elapsed": time_elapsed})
        skip_test_10_min = True

    if eval_test_15_min and not skip_test_15_min:
        print('Evaluate after first 15 min')
        with torch.no_grad():
            ds.set_mode("test")
            model.eval()
            test_accuracy, avg_test_loss = evaluate_classifier_model(model = model, dataloader = dl, device = device)
            confusion_matrix, acc = calculate_confusion_matrix(model=model, dataloader=dl, device=device)
            torch.save(confusion_matrix, f = 'confusion_matr_' + run_name+ '_15min' + '.pt')
            print(f"Evaluate after first 15 min: Test Loss: {avg_test_loss:.4f}\t Test Accuracy: {test_accuracy:.4f}, Confusion Matrix: \n{confusion_matrix}, Per-class Accuracy: {acc}")
            uu.csv_logger(
                logfile = f"../logs/{run_name}_test.csv",
                content = {"epoch": epoch,"test_phase": 2, "test_loss": avg_test_loss, "test_accuracy": test_accuracy, "time_elapsed": time_elapsed})
        skip_test_15_min = True

    if eval_test_20_min and not skip_test_20_min:
        print('Evaluate after first 20 min')
        with torch.no_grad():
            ds.set_mode("test")
            model.eval()
            test_accuracy, avg_test_loss = evaluate_classifier_model(model = model, dataloader = dl, device = device)
            confusion_matrix, acc = calculate_confusion_matrix(model=model, dataloader=dl, device=device)
            torch.save(confusion_matrix, f = 'confusion_matr_' + run_name+ '_20min' + '.pt')
            print(f"Evaluate after first 20 min: Test Loss: {avg_test_loss:.4f}\t Test Accuracy: {test_accuracy:.4f}, Confusion Matrix: \n{confusion_matrix}, Per-class Accuracy: {acc}")
            uu.csv_logger(
                logfile = f"../logs/{run_name}_test.csv",
                content = {"epoch": epoch,"test_phase": 3, "test_loss": avg_test_loss, "test_accuracy": test_accuracy, "time_elapsed": time_elapsed})
        stop_bool=True
        break

    # Go to train mode
    ds.set_mode("train")
    model.train()

    # Train loop
    for step, (data, targets) in enumerate(dl):

        # Manually drop last batch (this is for example relevant with BatchNorm)
        if step == num_steps - 1 and (epoch > 0 or ds.cache_data is False):
            continue

        # Train step: zero gradients, forward pass, compute loss, backward pass, optimizer step
        optimizer.zero_grad()
        data, targets = data.to(device), targets.to(device)
        predictions = model(data)
        loss = criterion(predictions, targets)
        loss.backward()
        optimizer.step()

    # Go to eval mode
    ds.set_mode("val")
    model.eval()

    # Validation loop
    val_accuracy, avg_val_loss = evaluate_classifier_model(model = model, dataloader = dl, device = device)
    print(f"Epoch [{epoch+1}/{epochs}]\t Val Loss: {avg_val_loss:.4f}\t Val Accuracy: {val_accuracy:.4f}")
    uu.csv_logger(
        logfile = f"../logs/{run_name}_val.csv",
        content = {"epoch": epoch, "val_loss": avg_val_loss, "val_accuracy": val_accuracy},
        first = (epoch == 0),
        overwrite = (epoch == 0)
            )