This notebook was motivated by

[2] Kaiming He et al. ‘Deep Residual Learning for Image Recognition’. In: CoRR abs/1512.03385 (2015). arXiv: 1512.03385. url: http://arxiv.org/abs/1512.03385.

Implementation: Oleh Bakumenko, University of Duisburg-Essen

# Imports

In [None]:
import sys
#sys.path.append("/datashare/MLCourse/Course_Materials") # Preferentially import from the datashare.
sys.path.append("../") # Otherwise, import from the local folder's parent folder, where your stuff lives.

import os
import numpy as np
import time
import matplotlib.pyplot as plt
import torch, torch.nn as nn
import torchvision, torchvision.transforms as tt
from torch.multiprocessing import Manager
torch.multiprocessing.set_sharing_strategy("file_system")
from pathlib import Path

from utility import utils as uu
from utility.eval import evaluate_classifier_model
from utility.confusion_matrix import calculate_confusion_matrix


from utility.trainLoopClassifier import training_loop
from utility.plotImageModel import *


# Data augmentations

Data augmentation is a technique used to artificially increase the size of a dataset by transforming existing data points to create new, similar instances. This can help prevent overfitting in machine learning models, as well as improve their ability to generalize to unseen data. Common types of data augmentation include flipping, rotation, scaling, and adding noise to images.
We can build the list of augmentations with the `torchvision.transforms` module.


In [None]:
data_augments = torchvision.transforms.Compose([ 
    torchvision.transforms.RandomHorizontalFlip(p = .5),
    torchvision.transforms.RandomVerticalFlip(p = .5),
    torchvision.transforms.ColorJitter(brightness=(0.5,1.5), contrast=(1), hue=(-0.1,0.1)),
    #torchvision.transforms.RandomCrop((224, 224)), 
    ])


Load the dataset with the helper class from `utility.utils`.

In [None]:
cur_path = Path("plots_and_graphs.ipynb")
parent_dir = cur_path.parent.absolute()
masterThesis_folder = str(parent_dir.parent.absolute())+'/'
data_dir = masterThesis_folder+"data/Clean_LiTS/"

cache_me = False
if cache_me is True:
    cache_mgr = Manager()
    cache_mgr.data = cache_mgr.dict()
    cache_mgr.cached = cache_mgr.dict()
    for k in ["train", "val", "test"]:
        cache_mgr.data[k] = cache_mgr.dict()
        cache_mgr.cached[k] = False
# function from utils, credit: Institute for Artificial Intelligence in Medicine. url: https://mml.ikim.nrw/
# dataset outputs a tensor image (dimensions [1,256,256]) and a tensor target (0, 1 or 2)

ds = uu.LiTS_Classification_Dataset(
    data_dir=data_dir,
    transforms=data_augments,
    verbose=True,
    cache_data=cache_me,
    cache_mgr=(cache_mgr if cache_me is True else None),
    debug=True,
)

# Hyperparameters

In [None]:
# Default settings
batch_size = 32
epochs = 50
device = ("cuda" if torch.cuda.is_available() else "cpu")
time_me  = True

The `torch.utils.data.DataLoader` is a utility class in PyTorch that handles batching, shuffling, and parallel loading of the data for training. We specify the dataset, the batch size (here 32), and whether the data should be shuffled before each epoch. Further parameters, such as `num_workers`, `pin_memory`, and `prefetch_factor`, customize the loading process.

In [None]:
# Dataloader
dl = torch.utils.data.DataLoader(
    dataset = ds, 
    batch_size = batch_size, 
    num_workers = 4, 
    shuffle = True, 
    drop_last = False, 
    pin_memory = True,
    persistent_workers = (not cache_me),
    prefetch_factor = 1
    )
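
The batching behaviour can be illustrated on a toy dataset. This is a minimal sketch, assuming 100 random tensors wrapped in a `TensorDataset` instead of the real dataset; the names `toy_ds` and `toy_dl` are made up for the example.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset: 100 single-channel 8x8 "images" with labels 0, 1 or 2 (assumption).
images = torch.rand(100, 1, 8, 8)
labels = torch.randint(0, 3, (100,))
toy_ds = TensorDataset(images, labels)

toy_dl = DataLoader(toy_ds, batch_size=32, shuffle=True, drop_last=False)

# Collect the size of every batch produced in one pass over the data.
batch_sizes = [data.shape[0] for data, targets in toy_dl]
print(batch_sizes)  # three full batches plus one partial batch of 4
```

With `drop_last = False` the final, smaller batch is kept. Very small batches can be a problem for batch normalisation during training, which is why the last batch is sometimes dropped.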

# ResNet 34

It is strongly recommended to read this section alongside Table 1 (page 5) and Figure 5 (page 6) of the ResNet paper [2].

Implementing the normal ResNet Block = [conv -> batch_norm -> activation] *2

At the beginning of each new layer (Table 1, left column) the image size is reduced using a convolution with kernel size 1 and a stride of 2 (a so-called projection). This feature was generalised in the implementation of ResNet 50 below. Both variations are included as examples.

First we build the blocks. Note the downsampling operation in `ResBlockDimsReduction`: it is needed because the input $x$ has different dimensions than the output. If this is not clear, try `print(out.shape)`.
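
The shape mismatch can be checked in isolation. The following is a minimal sketch, assuming a dummy feature map of size [2, 64, 56, 56]; the variable names `main` and `projection` are chosen for the example only.

```python
import torch
import torch.nn as nn

x = torch.rand(2, 64, 56, 56)  # dummy feature map: [batch, channels, height, width]

# Main path: a 3x3 convolution with stride 2 halves the spatial size.
main = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False)
# Shortcut path: a 1x1 convolution with stride 2 projects x to the same shape.
projection = nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False)

out = main(x)
shortcut = projection(x)
print(x.shape, out.shape, shortcut.shape)

# The shapes agree, so the skip connection out + shortcut is well defined,
# whereas out + x would fail because x still has 64 channels and size 56x56.
result = out + shortcut
```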

The class `ResNetMLMed34` inherits from `torch.nn.Module`, so we need to write the `__init__()` and `forward()` methods. Using Table 1 and Figure 5 from the ResNet paper, we define each of the parts `resblocks2` to `resblocks5`; the indexing is the same as in Table 1, so one can compare the number of blocks, the kernel sizes and the number of channels.
Do not forget to put the downsampling block first in each of `resblocks3` to `resblocks5`.


A few words about the `torch.nn.init` part:
By default, PyTorch initialises the convolution weights randomly (batch norm layers start with weight 1 and bias 0). Initialising the convolution weights from a suitably scaled normal distribution (here `kaiming_normal_`) helps the model backpropagate gradients in the early epochs.
For smaller models such as the 34- and 50-layer networks, tests showed that this initialisation has almost no impact on the performance or convergence of the model.

For ResNet 152, on the other hand, the model with default initialisation did not converge after 15 epochs and showed very poor error and accuracy rates. With the custom initialisation it was still not great, but might be improved by tuning the hyperparameters and using a better optimizer.
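
The effect of the initialisation calls can be inspected directly on a single layer. This is a minimal sketch, assuming a 64-channel 3x3 convolution as in the first residual stage; for `kaiming_normal_` with `nonlinearity='relu'` and the default fan-in mode, the weights are drawn with standard deviation sqrt(2 / fan_in).

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)

conv = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(num_features=64)

# He (Kaiming) normal initialisation: zero mean, std = sqrt(2 / fan_in) for ReLU.
nn.init.kaiming_normal_(conv.weight, nonlinearity="relu")
fan_in = 64 * 3 * 3  # input channels times kernel height times kernel width
expected_std = math.sqrt(2.0 / fan_in)
empirical_std = conv.weight.std().item()
print(f"expected std {expected_std:.4f}, empirical std {empirical_std:.4f}")

# Constant initialisation of the batch norm scale and zero shift.
nn.init.constant_(bn.weight, 0.5)
nn.init.zeros_(bn.bias)
```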

In [None]:
# ResBlock Class
#       - constructs a block [conv -> batch_norm -> activation] *2, which we will stack in the network
# Input:    int: n_chans - number channels
# Output:   nn.Sequential() block

class ResBlock(nn.Module):
    def __init__(self, n_chans):
        super().__init__()
        self.conv1 = nn.Conv2d(n_chans, n_chans, kernel_size=3, padding=1, bias= False)
        self.batch_norm1 = nn.BatchNorm2d(num_features=n_chans)
        self.relu = torch.nn.ReLU()
        self.conv2 = nn.Conv2d(n_chans, n_chans, kernel_size=3, padding=1, bias= False)
        self.batch_norm2 = nn.BatchNorm2d(num_features=n_chans)

        torch.nn.init.kaiming_normal_(self.conv1.weight,
                                      nonlinearity='relu')
        torch.nn.init.kaiming_normal_(self.conv2.weight,
                                      nonlinearity='relu')

        torch.nn.init.constant_(self.batch_norm1.weight, 0.5)
        torch.nn.init.zeros_(self.batch_norm1.bias)

        torch.nn.init.constant_(self.batch_norm2.weight, 0.5)
        torch.nn.init.zeros_(self.batch_norm2.bias)

    def forward(self, x):
        out = self.conv1(x)
        out = self.batch_norm1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.batch_norm2(out)
        out = self.relu(out)
        return out + x # this sum realises the skip connection


# ResBlockDimsReduction Class
#       - constructs a first block in the layer
#       - [conv -> batch_norm -> activation] *2
#       - downsampling performed with stride 2
# Input:    int: num_chans_in; int:num_chans_out
# Output:   nn.Sequential() block

class ResBlockDimsReduction(nn.Module):
    def __init__(self, num_chans_in, num_chans_out):
        super().__init__()
        self.conv1 = nn.Conv2d(num_chans_in, num_chans_out, kernel_size=3, stride=2,padding=1,bias= False)
        self.batch_norm1 = nn.BatchNorm2d(num_features=num_chans_out)
        self.relu = torch.nn.ReLU()
        self.conv2 = nn.Conv2d(num_chans_out, num_chans_out, kernel_size=3, padding=1, bias= False)
        self.batch_norm2 = nn.BatchNorm2d(num_features=num_chans_out)

        torch.nn.init.kaiming_normal_(self.conv1.weight,
                                      nonlinearity='relu')
        torch.nn.init.kaiming_normal_(self.conv2.weight,
                                      nonlinearity='relu')
        torch.nn.init.constant_(self.batch_norm1.weight, 0.5)
        torch.nn.init.zeros_(self.batch_norm1.bias)
        torch.nn.init.constant_(self.batch_norm2.weight, 0.5)
        torch.nn.init.zeros_(self.batch_norm2.bias)

        self.downsample = nn.Sequential(
            nn.Conv2d(num_chans_in, num_chans_out, kernel_size=1, stride=2,bias= False),
            nn.BatchNorm2d(num_features=num_chans_out),
            nn.ReLU()
        )


    def forward(self, x):
        out = self.conv1(x)
        out = self.batch_norm1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.batch_norm2(out)
        out = self.relu(out)
        # input and output dimensions do not match, so we need to project x into the dimensions of out
        x = self.downsample(x)
        return out + x

# ResNetMLMed34 Class
#       - constructs a ResNet34 as described [2, Table 1].
# Input:    Tensor: [Batch,1,Height,Width]
# Output:   Tensor: [Batch,3]
class ResNetMLMed34(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels = 1, out_channels = 64, kernel_size =7, stride =2, padding=1, bias= False)
        self.batch_norm1 = nn.BatchNorm2d(num_features=64)
        self.pool2 = torch.nn.MaxPool2d(kernel_size = 3, stride = 2)
        self.relu = torch.nn.ReLU()

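        # Build each stage with a list comprehension: multiplying a list (3 * [ResBlock(...)])
        # would repeat one and the same module object, so all blocks would share their weights.
        self.resblocks2 = nn.Sequential(
            *[ResBlock(n_chans=64) for _ in range(3)])
        self.resblocks3 = nn.Sequential(ResBlockDimsReduction(num_chans_in=64, num_chans_out=128),
            *[ResBlock(n_chans=128) for _ in range(3)])
        self.resblocks4 = nn.Sequential(ResBlockDimsReduction(num_chans_in=128, num_chans_out=256),
            *[ResBlock(n_chans=256) for _ in range(5)])
        self.resblocks5 = nn.Sequential(ResBlockDimsReduction(num_chans_in=256, num_chans_out=512),
            *[ResBlock(n_chans=512) for _ in range(2)])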
        self.avgpool6 = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.fc = nn.Linear(in_features=512, out_features=3, bias=True)


    def forward(self, x):

        out_1 = self.conv1(x)
        out_1 = self.batch_norm1(out_1)
        out_1 = self.relu(out_1)

        out_1 = self.pool2(out_1)

        out_2 = self.resblocks2(out_1)

        out_3 = self.resblocks3(out_2)

        out_4 = self.resblocks4(out_3)

        out_5 = self.resblocks5(out_4)

        out_6 = self.avgpool6(out_5)

        out_6= self.fc(torch.flatten(out_6, start_dim=1))

        return out_6
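
One detail worth checking when stacking blocks in `nn.Sequential`: multiplying a Python list repeats references to a single module object, whereas a list comprehension creates independent modules. The following is a minimal sketch, using a small linear layer as a stand-in for a residual block; `count_parameters` is a helper written for this example.

```python
import torch.nn as nn

def count_parameters(module):
    # nn.Module.parameters() yields each shared parameter only once.
    return sum(p.numel() for p in module.parameters())

# List multiplication: three references to the same layer, so the weights are shared.
shared = nn.Sequential(*(3 * [nn.Linear(4, 4)]))
# List comprehension: three separate layers, each with its own weights.
independent = nn.Sequential(*[nn.Linear(4, 4) for _ in range(3)])

print(count_parameters(shared))       # parameters of a single layer
print(count_parameters(independent))  # three times as many
print(shared[0] is shared[1])         # the very same object
```

Counting the parameters of a model in this way is a quick sanity check that every block has its own weights.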

In [None]:
model = ResNetMLMed34()
model = model.to(device)

In [None]:
for step, (data, targets) in enumerate(dl):
    data, targets = data.to(device), targets.to(device)
    if step ==1:
        break

In [None]:
learning_rate = 1e-5
run_name = "ResNet34_fixed_time_lre5"

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
criterion = nn.CrossEntropyLoss()

In [None]:
uu.csv_logger(
        logfile = f"../logs/{run_name}_hyperparams.csv",
        content = {"learning_rate": learning_rate, "batch_size": batch_size, "epochs": epochs},
        first= True,
        overwrite= True)

In [None]:
mod_step = 5000
wantToPrint = False
stop_bool = False
eval_test_10_min = False
eval_test_15_min = False
eval_test_20_min = False
skip_test_10_min = False
skip_test_15_min = False
skip_test_20_min = False

Modified training loop: The starting time is saved, and the time elapsed is calculated at the beginning of each epoch. If the elapsed time is greater than 10, 15, or 20 minutes, the boolean flag for test evaluation is set to True. However, it is important to ensure that the test evaluation happens only once. Therefore, after the calculation, the boolean flag for skipping is set to True.

During the test evaluation, the dataset mode is switched to "test", the model is switched to evaluation mode, and the test accuracy, loss, confusion matrix, and per-class accuracy are calculated and saved.
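
How a confusion matrix and the per-class accuracy relate can be seen on a handful of hand-written labels. This is a minimal sketch and not the implementation of `calculate_confusion_matrix` from the utility module; the function name and the convention (rows are true classes, columns are predicted classes) are assumptions made for the example.

```python
import torch

def confusion_matrix_from_predictions(targets, predictions, num_classes):
    # Rows index the true class, columns the predicted class (assumed convention).
    matrix = torch.zeros(num_classes, num_classes, dtype=torch.long)
    for t, p in zip(targets.tolist(), predictions.tolist()):
        matrix[t, p] += 1
    return matrix

# Hand-written example labels for three classes (0, 1, 2).
targets = torch.tensor([0, 0, 1, 1, 2, 2, 2])
predictions = torch.tensor([0, 1, 1, 1, 2, 0, 2])

cm = confusion_matrix_from_predictions(targets, predictions, num_classes=3)
# Per-class accuracy: correct predictions (diagonal) divided by samples per true class.
per_class_acc = cm.diag().float() / cm.sum(dim=1).float()
# Overall accuracy: all correct predictions divided by all samples.
overall_acc = cm.diag().sum().item() / cm.sum().item()
print(cm)
print(per_class_acc, overall_acc)
```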

The same procedure is carried out once for each of the three time thresholds. After the 20-minute evaluation, training stops.
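
The once-per-threshold logic can also be written compactly with a set of finished thresholds instead of separate boolean flags. This is a minimal sketch of that idea, not the loop used in this notebook; `due_checkpoints` and the simulated elapsed times are hypothetical.

```python
def due_checkpoints(elapsed_seconds, thresholds_min, already_done):
    """Return the thresholds (in minutes) that have been passed but not yet evaluated."""
    return [m for m in thresholds_min
            if elapsed_seconds > m * 60 and m not in already_done]

thresholds_min = [10, 15, 20]   # same thresholds as described above
done = set()
evaluated_at = []

# Simulated elapsed times (seconds) at the start of successive epochs (hypothetical values).
for epoch, elapsed in enumerate([0, 7 * 60, 11 * 60, 14 * 60, 16 * 60, 21 * 60]):
    for minutes in due_checkpoints(elapsed, thresholds_min, done):
        done.add(minutes)                       # guarantees each evaluation runs only once
        evaluated_at.append((epoch, minutes))   # the test evaluation would happen here
    if 20 in done:
        break                                   # stop after the final checkpoint

print(evaluated_at)
```

In a real run, `elapsed` would be `time.time() - train_start`. Because the check happens only at the start of an epoch, each evaluation takes place at the first epoch boundary after its threshold.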

In [None]:
train_start = time.time()

num_steps = len(ds.file_names['train'])//batch_size

for epoch in range(epochs):
    time_elapsed = time.time() - train_start
    print(f"Time_elapsed: {time_elapsed/60 :.2f} min")
    if time_elapsed > 10*60:
        eval_test_10_min = True
    if time_elapsed > 15*60:
        eval_test_15_min = True
    if time_elapsed > 20*60:
        eval_test_20_min = True

    if stop_bool:
        print('Stop time')
        break

    if eval_test_10_min and not skip_test_10_min:
        print('Evaluate after first 10 min')
        with torch.no_grad():
            ds.set_mode("test")
            model.eval()
            test_accuracy, avg_test_loss = evaluate_classifier_model(model = model, dataloader = dl, device = device)
            confusion_matrix, acc = calculate_confusion_matrix(model=model, dataloader=dl, device=device)
            torch.save(confusion_matrix, f = 'confusion_matr_' + run_name+ '_10min' + '.pt')
            print(f"Evaluate after first 10 min: Test Loss: {avg_test_loss:.4f}\t Test Accuracy: {test_accuracy:.4f}, Confusion Matrix: \n{confusion_matrix}, Per-class Accuracy: {acc}")
            uu.csv_logger(
                logfile = f"../logs/{run_name}_test.csv",
                content = {"epoch": epoch,"test_phase": 1, "test_loss": avg_test_loss, "test_accuracy": test_accuracy, "time_elapsed": time_elapsed})
        skip_test_10_min = True

    if eval_test_15_min and not skip_test_15_min:
        print('Evaluate after first 15 min')
        with torch.no_grad():
            ds.set_mode("test")
            model.eval()
            test_accuracy, avg_test_loss = evaluate_classifier_model(model = model, dataloader = dl, device = device)
            confusion_matrix, acc = calculate_confusion_matrix(model=model, dataloader=dl, device=device)
            torch.save(confusion_matrix, f = 'confusion_matr_' + run_name+ '_15min' + '.pt')
            print(f"Evaluate after first 15 min: Test Loss: {avg_test_loss:.4f}\t Test Accuracy: {test_accuracy:.4f}, Confusion Matrix: \n{confusion_matrix}, Per-class Accuracy: {acc}")
            uu.csv_logger(
                logfile = f"../logs/{run_name}_test.csv",
                content = {"epoch": epoch,"test_phase": 2, "test_loss": avg_test_loss, "test_accuracy": test_accuracy, "time_elapsed": time_elapsed})
        skip_test_15_min = True

    if eval_test_20_min and not skip_test_20_min:
        print('Evaluate after first 20 min')
        with torch.no_grad():
            ds.set_mode("test")
            model.eval()
            test_accuracy, avg_test_loss = evaluate_classifier_model(model = model, dataloader = dl, device = device)
            confusion_matrix, acc = calculate_confusion_matrix(model=model, dataloader=dl, device=device)
            torch.save(confusion_matrix, f = 'confusion_matr_' + run_name+ '_20min' + '.pt')
            print(f"Evaluate after first 20 min: Test Loss: {avg_test_loss:.4f}\t Test Accuracy: {test_accuracy:.4f}, Confusion Matrix: \n{confusion_matrix}, Per-class Accuracy: {acc}")
            uu.csv_logger(
                logfile = f"../logs/{run_name}_test.csv",
                content = {"epoch": epoch,"test_phase": 3, "test_loss": avg_test_loss, "test_accuracy": test_accuracy, "time_elapsed": time_elapsed})
        stop_bool=True
        break

    # Go to train mode
    ds.set_mode("train")
    model.train()

    # Train loop
    for step, (data, targets) in enumerate(dl):

        # Manually drop last batch (this is for example relevant with BatchNorm)
        if step == num_steps - 1 and (epoch > 0 or ds.cache_data is False):
            continue

        # Train loop: Zero gradients, forward step, evaluate, log, backward step
        optimizer.zero_grad()
        data, targets = data.to(device), targets.to(device)
        predictions = model(data)
        loss = criterion(predictions, targets)
        loss.backward()
        optimizer.step()

    # Go to eval mode
    ds.set_mode("val")
    model.eval()

    # Validation loop
    val_accuracy, avg_val_loss = evaluate_classifier_model(model = model, dataloader = dl, device = device)
    print(f"Epoch [{epoch+1}/{epochs}]\t Val Loss: {avg_val_loss:.4f}\t Val Accuracy: {val_accuracy:.4f}")
    uu.csv_logger(
        logfile = f"../logs/{run_name}_val.csv",
        content = {"epoch": epoch, "val_loss": avg_val_loss, "val_accuracy": val_accuracy},
        first = (epoch == 0),
        overwrite = (epoch == 0)
            )

---

In [None]:
run_name = "ResNet34_fixed_time_lre4"
learning_rate = 1e-4

In [None]:
del model

In [None]:
model = ResNetMLMed34()
model = model.to(device)

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
criterion = nn.CrossEntropyLoss()

In [None]:
stop_bool = False
eval_test_10_min = False
eval_test_15_min = False
eval_test_20_min = False
skip_test_10_min = False
skip_test_15_min = False
skip_test_20_min = False

In [None]:
train_start = time.time()

num_steps = len(ds.file_names['train'])//batch_size

for epoch in range(epochs):
    time_elapsed = time.time() - train_start
    print(f"Time_elapsed: {time_elapsed/60 :.2f} min")
    if time_elapsed > 10*60:
        eval_test_10_min = True
    if time_elapsed > 15*60:
        eval_test_15_min = True
    if time_elapsed > 20*60:
        eval_test_20_min = True

    if stop_bool:
        print('Stop time')
        break

    if eval_test_10_min and not skip_test_10_min:
        print('Evaluate after first 10 min')
        with torch.no_grad():
            ds.set_mode("test")
            model.eval()
            test_accuracy, avg_test_loss = evaluate_classifier_model(model = model, dataloader = dl, device = device)
            confusion_matrix, acc = calculate_confusion_matrix(model=model, dataloader=dl, device=device)
            torch.save(confusion_matrix, f = 'confusion_matr_' + run_name+ '_10min' + '.pt')
            print(f"Evaluate after first 10 min: Test Loss: {avg_test_loss:.4f}\t Test Accuracy: {test_accuracy:.4f}, Confusion Matrix: \n{confusion_matrix}, Per-class Accuracy: {acc}")
            uu.csv_logger(
                logfile = f"../logs/{run_name}_test.csv",
                content = {"epoch": epoch,"test_phase": 1, "test_loss": avg_test_loss, "test_accuracy": test_accuracy, "time_elapsed": time_elapsed})
        skip_test_10_min = True

    if eval_test_15_min and not skip_test_15_min:
        print('Evaluate after first 15 min')
        with torch.no_grad():
            ds.set_mode("test")
            model.eval()
            test_accuracy, avg_test_loss = evaluate_classifier_model(model = model, dataloader = dl, device = device)
            confusion_matrix, acc = calculate_confusion_matrix(model=model, dataloader=dl, device=device)
            torch.save(confusion_matrix, f = 'confusion_matr_' + run_name+ '_15min' + '.pt')
            print(f"Evaluate after first 15 min: Test Loss: {avg_test_loss:.4f}\t Test Accuracy: {test_accuracy:.4f}, Confusion Matrix: \n{confusion_matrix}, Per-class Accuracy: {acc}")
            uu.csv_logger(
                logfile = f"../logs/{run_name}_test.csv",
                content = {"epoch": epoch,"test_phase": 2, "test_loss": avg_test_loss, "test_accuracy": test_accuracy, "time_elapsed": time_elapsed})
        skip_test_15_min = True

    if eval_test_20_min and not skip_test_20_min:
        print('Evaluate after first 20 min')
        with torch.no_grad():
            ds.set_mode("test")
            model.eval()
            test_accuracy, avg_test_loss = evaluate_classifier_model(model = model, dataloader = dl, device = device)
            confusion_matrix, acc = calculate_confusion_matrix(model=model, dataloader=dl, device=device)
            torch.save(confusion_matrix, f = 'confusion_matr_' + run_name+ '_20min' + '.pt')
            print(f"Evaluate after first 20 min: Test Loss: {avg_test_loss:.4f}\t Test Accuracy: {test_accuracy:.4f}, Confusion Matrix: \n{confusion_matrix}, Per-class Accuracy: {acc}")
            uu.csv_logger(
                logfile = f"../logs/{run_name}_test.csv",
                content = {"epoch": epoch,"test_phase": 3, "test_loss": avg_test_loss, "test_accuracy": test_accuracy, "time_elapsed": time_elapsed})
        stop_bool=True
        break

    # Go to train mode
    ds.set_mode("train")
    model.train()

    # Train loop
    for step, (data, targets) in enumerate(dl):

        # Manually drop last batch (this is for example relevant with BatchNorm)
        if step == num_steps - 1 and (epoch > 0 or ds.cache_data is False):
            continue

        # Train loop: Zero gradients, forward step, evaluate, log, backward step
        optimizer.zero_grad()
        data, targets = data.to(device), targets.to(device)
        predictions = model(data)
        loss = criterion(predictions, targets)
        loss.backward()
        optimizer.step()

    # Go to eval mode
    ds.set_mode("val")
    model.eval()

    # Validation loop
    val_accuracy, avg_val_loss = evaluate_classifier_model(model = model, dataloader = dl, device = device)
    print(f"Epoch [{epoch+1}/{epochs}]\t Val Loss: {avg_val_loss:.4f}\t Val Accuracy: {val_accuracy:.4f}")
    uu.csv_logger(
        logfile = f"../logs/{run_name}_val.csv",
        content = {"epoch": epoch, "val_loss": avg_val_loss, "val_accuracy": val_accuracy},
        first = (epoch == 0),
        overwrite = (epoch == 0)
            )

---

In [None]:
run_name = "ResNet34_fixed_time_rerun_1004_lre3"
learning_rate = 1e-3

In [None]:
del model

In [None]:
model = ResNetMLMed34()
model = model.to(device)

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
criterion = nn.CrossEntropyLoss()

In [None]:
stop_bool = False
eval_test_10_min = False
eval_test_15_min = False
eval_test_20_min = False
skip_test_10_min = False
skip_test_15_min = False
skip_test_20_min = False

In [None]:
train_start = time.time()

num_steps = len(ds.file_names['train'])//batch_size

for epoch in range(epochs):
    time_elapsed = time.time() - train_start
    print(f"Time_elapsed: {time_elapsed/60 :.2f} min")
    if time_elapsed > 10*60:
        eval_test_10_min = True
    if time_elapsed > 15*60:
        eval_test_15_min = True
    if time_elapsed > 20*60:
        eval_test_20_min = True

    if stop_bool:
        print('Stop time')
        break

    if eval_test_10_min and not skip_test_10_min:
        print('Evaluate after first 10 min')
        with torch.no_grad():
            ds.set_mode("test")
            model.eval()
            test_accuracy, avg_test_loss = evaluate_classifier_model(model = model, dataloader = dl, device = device)
            confusion_matrix, acc = calculate_confusion_matrix(model=model, dataloader=dl, device=device)
            torch.save(confusion_matrix, f = 'confusion_matr_' + run_name+ '_10min' + '.pt')
            print(f"Evaluate after first 10 min: Test Loss: {avg_test_loss:.4f}\t Test Accuracy: {test_accuracy:.4f}, Confusion Matrix: \n{confusion_matrix}, Per-class Accuracy: {acc}")
            uu.csv_logger(
                logfile = f"../logs/{run_name}_test.csv",
                content = {"epoch": epoch,"test_phase": 1, "test_loss": avg_test_loss, "test_accuracy": test_accuracy, "time_elapsed": time_elapsed})
        skip_test_10_min = True

    if eval_test_15_min and not skip_test_15_min:
        print('Evaluate after first 15 min')
        with torch.no_grad():
            ds.set_mode("test")
            model.eval()
            test_accuracy, avg_test_loss = evaluate_classifier_model(model = model, dataloader = dl, device = device)
            confusion_matrix, acc = calculate_confusion_matrix(model=model, dataloader=dl, device=device)
            torch.save(confusion_matrix, f = 'confusion_matr_' + run_name+ '_15min' + '.pt')
            print(f"Evaluate after first 15 min: Test Loss: {avg_test_loss:.4f}\t Test Accuracy: {test_accuracy:.4f}, Confusion Matrix: \n{confusion_matrix}, Per-class Accuracy: {acc}")
            uu.csv_logger(
                logfile = f"../logs/{run_name}_test.csv",
                content = {"epoch": epoch,"test_phase": 2, "test_loss": avg_test_loss, "test_accuracy": test_accuracy, "time_elapsed": time_elapsed})
        skip_test_15_min = True

    if eval_test_20_min and not skip_test_20_min:
        print('Evaluate after first 20 min')
        with torch.no_grad():
            ds.set_mode("test")
            model.eval()
            test_accuracy, avg_test_loss = evaluate_classifier_model(model = model, dataloader = dl, device = device)
            confusion_matrix, acc = calculate_confusion_matrix(model=model, dataloader=dl, device=device)
            torch.save(confusion_matrix, f = 'confusion_matr_' + run_name+ '_20min' + '.pt')
            print(f"Evaluate after first 20 min: Test Loss: {avg_test_loss:.4f}\t Test Accuracy: {test_accuracy:.4f}, Confusion Matrix: \n{confusion_matrix}, Per-class Accuracy: {acc}")
            uu.csv_logger(
                logfile = f"../logs/{run_name}_test.csv",
                content = {"epoch": epoch,"test_phase": 3, "test_loss": avg_test_loss, "test_accuracy": test_accuracy, "time_elapsed": time_elapsed})
        stop_bool=True
        break

    # Go to train mode
    ds.set_mode("train")
    model.train()

    # Train loop
    for step, (data, targets) in enumerate(dl):

        # Manually drop last batch (this is for example relevant with BatchNorm)
        if step == num_steps - 1 and (epoch > 0 or ds.cache_data is False):
            continue

        # Train loop: Zero gradients, forward step, evaluate, log, backward step
        optimizer.zero_grad()
        data, targets = data.to(device), targets.to(device)
        predictions = model(data)
        loss = criterion(predictions, targets)
        loss.backward()
        optimizer.step()

    # Go to eval mode
    ds.set_mode("val")
    model.eval()

    # Validation loop
    val_accuracy, avg_val_loss = evaluate_classifier_model(model = model, dataloader = dl, device = device)
    print(f"Epoch [{epoch+1}/{epochs}]\t Val Loss: {avg_val_loss:.4f}\t Val Accuracy: {val_accuracy:.4f}")
    uu.csv_logger(
        logfile = f"../logs/{run_name}_val.csv",
        content = {"epoch": epoch, "val_loss": avg_val_loss, "val_accuracy": val_accuracy},
        first = (epoch == 0),
        overwrite = (epoch == 0)
            )