References for loading DICOM images: https://www.kaggle.com/raddar/convert-dicom-to-np-array-the-correct-way
and https://www.kaggle.com/tanlikesmath/siim-covid-19-detection-a-simple-eda

# Classification Training Script
This script is written to be easy to modify, so it can support different types of experiments.

In [None]:
!pip install timm

In [None]:
import glob
import os
import time
import random

import numpy as np  # linear algebra!
import pandas as pd  # data processing, CSV file I/O (e.g. pd.read_csv)
import PIL

from sklearn.model_selection import GroupKFold, StratifiedKFold
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from sklearn.metrics import roc_auc_score

import cv2
import matplotlib.pyplot as plt
import seaborn as sns

from tqdm import tqdm
from tqdm.contrib.concurrent import process_map

import torch
import torchvision
from torch.utils.data.dataset import Dataset
import torch.cuda.amp as amp

import timm

import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut

import albumentations as A
from albumentations.pytorch import ToTensorV2

# Competition data location and the directory where resized JPG caches are written.
DATA_DIR = "/kaggle/input/siim-covid19-detection"
RESIZE_DIR = "/kaggle/working/"

# Model selection; MODEL_NAME is passed to timm.create_model.
MODEL_TYPE = "4_CLASS"
MODEL_NAME = "vit_small_r26_s32_384"

# Image size as passed to PIL (width, height); also names the cache directory.
SIZE = (384, 384)

# Cross-validation and training configuration.
NUM_CLASSES = 4
FOLDS = 5
BATCHSIZE = 48
SEED = 420

In [None]:
!mkdir "/kaggle/working/train_{SIZE[0]}x{SIZE[1]}"
!tar -xzf "/kaggle/input/jpeg-and-archive-{SIZE[0]}/train_{SIZE[0]}x{SIZE[1]}.tar.gz" -C "/kaggle/working/train_{SIZE[0]}x{SIZE[1]}" .

In [None]:
# Make results reproducible
def seed_everything(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


seed_everything(seed=SEED)

## Clean column names in CSVs
We clean up a few column names in `train_image_level.csv` and `train_study_level.csv` so that the two CSVs can be merged into one.
We also rename columns to simpler names for later use.

In [None]:
train_images_df = pd.read_csv(os.path.join(DATA_DIR, "train_image_level.csv"))
train_study_df = pd.read_csv(os.path.join(DATA_DIR, "train_study_level.csv"))

# Suffix the image-level study UIDs with "_study" so they line up with the
# study-level ids they are merged on below (presumably the study CSV's id
# format; verify against the data).
train_images_df["StudyInstanceUID"] = train_images_df["StudyInstanceUID"].add("_study")

# Keep only the first word of each study-level header, giving the short class
# names used as NAME_TO_LABEL_MAP keys later on.
train_study_df.columns = [header.split(" ")[0] for header in train_study_df.columns]

# Give both frames a shared "study_id" join key.
train_study_df = train_study_df.rename(columns={"id": "study_id"})
train_images_df = train_images_df.rename(columns={"StudyInstanceUID": "study_id"})

## Label Map Creation
At prediction time we need these labels in string form. These maps can also be modified to train other kinds of models, such as a binary classifier.

In [None]:
# Study-level column name -> integer class index. Dict order also decides which
# column wins in get_str_label / get_int_label when several are set.
if NUM_CLASSES == 4:
    NAME_TO_LABEL_MAP = {"Negative": 0, "Typical": 1, "Indeterminate": 2, "Atypical": 3}
else:
    # Fail here rather than with a NameError deep inside the label helpers.
    raise ValueError(
        "No label map defined for NUM_CLASSES={}; add one to this cell.".format(
            NUM_CLASSES
        )
    )


def get_str_label(row):
    """Return the name of the first class column that is truthy in ``row``.

    Columns are checked in NAME_TO_LABEL_MAP order. Returns None when no class
    column is set.
    """
    return next((name for name in NAME_TO_LABEL_MAP if row[name]), None)


def get_int_label(row):
    """Return the integer index of the first class column that is truthy in ``row``.

    Columns are checked in NAME_TO_LABEL_MAP order. Returns None when no class
    column is set.
    """
    for name, index in NAME_TO_LABEL_MAP.items():
        if row[name]:
            return index
    return None

## Group K Fold
We create a `fold` column to be used for cross validation.
We use the `study_id` column created above as the grouping key and split the data into the number of folds given by the `FOLDS` variable.

In [None]:
train_study_df["int_label"] = train_study_df.apply(get_int_label, axis=1)
train_study_df["str_label"] = train_study_df.apply(get_str_label, axis=1)

# Every row starts unassigned (-1). Each split's held-out rows are then labelled
# with that split's index. Positional indices from split() match the
# RangeIndex produced by read_csv.
fold_ids = np.full(len(train_study_df), -1, dtype=np.int64)
splitter = GroupKFold(n_splits=FOLDS)
study_groups = train_study_df["study_id"].tolist()
for fold_idx, (_, holdout_rows) in enumerate(
    splitter.split(train_study_df, groups=study_groups)
):
    fold_ids[holdout_rows] = fold_idx
train_study_df["fold"] = fold_ids

### Plotting Group Counts
We see from the plots below that each group is fairly balanced. The `Typical` category is the most common followed by `Negative`.

In [None]:
sns.catplot(data=train_study_df, kind="count", x="str_label", col="fold")

## Clean Study level data
1. As per the recommendations made [here](https://www.kaggle.com/c/siim-covid19-detection/discussion/246597), for studies with more than one image it appears that only one of the images has bounding boxes.

2. As per the recommendation made [here](https://www.kaggle.com/c/siim-covid19-detection/discussion/240250#1351079), the label is retained only for the image with bounding boxes, since the other images were not looked at by the annotators.

3. For studies that have more than one image but no bounding boxes, it is unclear which image was examined, so all images in those studies are retained.

In [None]:
train_samples_df = train_images_df.merge(
    train_study_df, how="inner", on="study_id"
).reset_index(drop=True)

# Per study, count the image rows ("id") and the rows with a non-null "boxes"
# value (DataFrame.count skips nulls), with the studies that have the most
# images first.
box_and_images_counts_df = (
    train_samples_df.groupby("study_id")[["id", "boxes"]]
    .count()
    .sort_values(by="id", ascending=False)
    .reset_index()
    .rename(columns={"id": "id_count", "boxes": "boxes_count"})
)

In [None]:
sns.countplot(x="id_count", data=box_and_images_counts_df)

In [None]:
# Attach each study's image/box counts to every image row of that study.
train_samples_df = train_samples_df.merge(
    box_and_images_counts_df, on="study_id", how="inner"
)

In [None]:
train_samples_df = train_samples_df.sort_values(
    by=["id_count", "boxes_count"], ascending=False
)
train_samples_df.head(5)

## Dicom helpers

In [None]:
def resize(array, size, keep_ratio=False, resample=PIL.Image.LANCZOS):
    """Wrap ``array`` in a PIL image and resize it to ``size``.

    Adapted from https://www.kaggle.com/xhlulu/vinbigdata-process-and-resize-to-image

    Args:
        array: pixel data accepted by ``PIL.Image.fromarray``.
        size: target size; its first two items are passed to PIL as (width, height).
        keep_ratio: if True, shrink with PIL's in-place ``thumbnail`` instead of
            an exact ``resize``.
        resample: PIL resampling filter.

    Returns:
        The resulting ``PIL.Image.Image``.
    """
    image = PIL.Image.fromarray(array)
    target = (size[0], size[1])
    if not keep_ratio:
        return image.resize(target, resample)
    image.thumbnail(target, resample)
    return image


def dicom2array(path, size, voi_lut=True, fix_monochrome=True):
    """Read a DICOM file and return it as a resized 8-bit PIL image.

    Args:
        path: path to the ``.dcm`` file.
        size: target size forwarded to ``resize``.
        voi_lut: apply the file's VOI LUT (if the device supplied one) to get a
            "human-friendly" view of the raw data.
        fix_monochrome: invert MONOCHROME1 images, which otherwise look inverted.

    Returns:
        PIL image from ``resize`` of the min-max scaled uint8 pixels.
    """
    # dcmread is the supported reader; read_file was a deprecated alias that
    # pydicom 3.0 removed.
    dicom = pydicom.dcmread(path)
    data = apply_voi_lut(dicom.pixel_array, dicom) if voi_lut else dicom.pixel_array
    data = np.asarray(data, dtype=np.float64)

    # getattr: tolerate files that omit the PhotometricInterpretation tag.
    photometric = getattr(dicom, "PhotometricInterpretation", None)
    if fix_monochrome and photometric == "MONOCHROME1":
        data = data.max() - data

    # Min-max scale to [0, 255]. A constant image would otherwise divide 0 by 0
    # and cast NaNs to uint8; leave it all zeros instead.
    data -= data.min()
    peak = data.max()
    if peak > 0:
        data /= peak
    data = (data * 255).astype(np.uint8)
    return resize(data, size)

## Caching Images

It is imperative that the simple logic below executes without errors, because caching improves training speed very significantly. It merely converts each DICOM into a JPG and caches it. Even fully expanded, a 512x512 image dataset is only a few hundred MB.

In [None]:
%%time
# One JPG per training DICOM, named after the DICOM file's basename.
TRAIN_DATA_PATH = os.path.join(RESIZE_DIR, "train_{}x{}".format(SIZE[0], SIZE[1]))


def persist_image(path):
    """Convert one DICOM file to a resized JPG in TRAIN_DATA_PATH; return the JPG path."""
    im = dicom2array(path, SIZE)
    fname = os.path.basename(os.path.splitext(path)[-2])
    jpg_fname = os.path.join(TRAIN_DATA_PATH, "{}.jpg".format(fname))
    im.save(jpg_fname)
    return jpg_fname


# Only trust the cache if it actually holds JPGs. The directory can exist but be
# empty (e.g. `mkdir` succeeded earlier while the archive extraction failed). In
# that case get_img_path would return None for every image.
if glob.glob(os.path.join(TRAIN_DATA_PATH, "*.jpg")):
    print("{} Exists".format(TRAIN_DATA_PATH))
else:
    print("Creating Training dir at {}".format(TRAIN_DATA_PATH))
    os.makedirs(TRAIN_DATA_PATH, exist_ok=True)
    filenames = glob.glob(os.path.join(DATA_DIR, "train/*/*/*.dcm"))
    process_map(persist_image, filenames, max_workers=8, chunksize=1)


In [None]:
# Capture the line count of `ls -l` on the cache directory via IPython's `!` capture.
jpg_counts = !ls -l {TRAIN_DATA_PATH} | wc -l
# Disabled sanity check. NOTE(review): the "- 2" offset is unexplained (`ls -l`
# adds one "total" line); confirm the expected count before re-enabling.
# assert int(jpg_counts[0]) - 2 == train_images_df.shape[0]

In [None]:
# Display the captured `ls -l | wc -l` output.
jpg_counts

### Cleaning Study Level data
As recommended at the beginning of this section, for studies with multiple images we retain only the image with bounding boxes, so that the data is clean. This filtering is relied on by the Dataset class created below.

In [None]:
def keep_row(row):
    """Decide whether an image row survives the study-level cleaning.

    A row is kept when its study has no boxed images at all (every image is
    retained), or when the study has boxed images and this row carries a
    non-null ``boxes`` value.

    Returns:
        Python bool.
    """
    if row["boxes_count"] == 0:
        return True
    return bool(row["boxes_count"]) and not pd.isna(row["boxes"])

In [None]:
def get_img_path(row):
    """Return the cached JPG path for the image in ``row``, or None if it is missing.

    The last 6 characters of ``row["id"]`` are dropped (presumably an "_image"
    suffix) to get the basename under which the caching step saved the JPG.
    """
    img_id = row["id"][:-6]
    path = os.path.join(TRAIN_DATA_PATH, "{}.jpg".format(img_id))
    return path if os.path.exists(path) else None

In [None]:
train_samples_df = train_samples_df.assign(
    path=train_samples_df.apply(get_img_path, axis=1)
)

In [None]:
train_samples_df = train_samples_df.assign(
    keep_row=train_samples_df.apply(keep_row, axis=1)
)

In [None]:
# .copy() makes this an independent frame rather than a slice of the old one.
# The in-place sort/set_index below would otherwise raise SettingWithCopyWarning.
train_samples_df = train_samples_df[train_samples_df["keep_row"]].copy()
train_samples_df.head(5)

In [None]:
# Positional access: after the earlier sort and keep_row filtering, the index
# labels are no longer 0..n-1, so the row labelled 0 may have been dropped.
train_samples_df["path"].iloc[0]

In [None]:
train_samples_df = train_samples_df.sort_values("study_id").set_index("study_id")

In [None]:
# Studies with several images and no bounding boxes (all their images were kept).
train_samples_df.query("id_count > 1 and boxes_count == 0").head(5)

In [None]:
# Train on the GPU when one is visible, otherwise fall back to the CPU.
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dev

## Dataset Class
For the remaining multi-image studies (those without bounding boxes), we retain all images and randomly choose which image to train on. At prediction time we intend to average the predictions over a study's images; this is not yet implemented, so instead we produce a prediction from a randomly chosen image within the study.

In [None]:
class XRayDatasetFromDF(Dataset):
    """Study-level X-ray classification dataset backed by a DataFrame.

    ``df`` is expected to be indexed by study id, with a ``path`` column (a
    cached ``.jpg`` or a DICOM file) and an ``int_label`` column. Each item is
    one study. When a study has several rows, one of its images is chosen at
    random each time the study is fetched.

    Args:
        df: image rows indexed by study id.
        train: selects ``DATA_DIR/train`` vs ``DATA_DIR/test`` for
            ``path_suffix`` (kept for callers; not read by ``__getitem__``).
        augment: apply flips, shift-scale-rotate and brightness/contrast
            jitter, then normalization and tensor conversion.
        normalize: without ``augment``, apply only normalization and tensor
            conversion. With neither flag, the raw RGB uint8 array is returned.
        size: output size used when a DICOM path must be decoded on the fly.
    """

    # ImageNet channel statistics used for normalization.
    _MEAN = (0.485, 0.456, 0.406)
    _STD = (0.229, 0.224, 0.225)

    def __init__(self, df, train=True, augment=True, normalize=False, size=(384, 384)):
        self.df = df
        self.name_to_label_map = {
            "Negative": 0,
            "Typical": 1,
            "Indeterminate": 2,
            "Atypical": 3,
        }
        # The index repeats a study id once per kept image. De-duplicate so each
        # study is one item; otherwise multi-image studies would be drawn several
        # times per epoch and counted several times in validation metrics.
        self.study_ids = df.index.unique().sort_values()
        self.path_suffix = (
            os.path.join(DATA_DIR, "train") if train else os.path.join(DATA_DIR, "test")
        )
        self._augment = augment
        self._normalize = normalize
        self._size = size
        self._transform_list = []

        if self._augment:
            self._transform_list.extend(
                [
                    A.VerticalFlip(p=0.5),
                    A.HorizontalFlip(p=0.5),
                    A.ShiftScaleRotate(
                        scale_limit=0.20,
                        rotate_limit=10,
                        shift_limit=0.1,
                        p=0.5,
                        border_mode=cv2.BORDER_CONSTANT,
                        value=0,
                    ),
                    A.RandomBrightnessContrast(p=0.5),
                    A.Normalize(mean=self._MEAN, std=self._STD),
                    ToTensorV2(),
                ]
            )
        elif self._normalize:  # evaluation / test mode
            self._transform_list.extend(
                [
                    A.Normalize(mean=self._MEAN, std=self._STD),
                    ToTensorV2(),
                ]
            )
        self._transforms = A.Compose(self._transform_list)

    def __len__(self):
        return len(self.study_ids)

    def assign_label(self, row):
        """Return the index of the first truthy class column in ``row`` (None if none)."""
        for k in self.name_to_label_map:
            if row[k]:
                return self.name_to_label_map[k]
        return None

    def _load_rgb(self, path, study_id):
        """Load ``path`` as an RGB uint8 array, raising FileNotFoundError on failure."""
        if not isinstance(path, str):
            # get_img_path yields None when the cached JPG is missing.
            raise FileNotFoundError("No image path for study {}".format(study_id))
        if path.endswith(".jpg"):
            img = cv2.imread(path)
            if img is None:
                raise FileNotFoundError("Could not read image {}".format(path))
            return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # dicom2array returns a PIL image built from a 2-D uint8 array (X-ray
        # DICOMs assumed single-channel); make it a 3-channel numpy array.
        gray = np.asarray(dicom2array(path, size=self._size))
        return cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)

    def __getitem__(self, idx):
        study_id = self.study_ids[idx]
        study_imgs = self.df.loc[study_id]

        if study_imgs.ndim == 1:
            # A unique study id makes .loc return that single row as a Series.
            row = study_imgs
        else:
            # Several images for this study: pick one at random.
            row = study_imgs.sample(1).iloc[0]

        img = self._load_rgb(row["path"], study_id)
        img = self._transforms(image=img)["image"]
        return img, row["int_label"]

In [None]:
class AverageMeter:
    """Keep the latest value and a running weighted mean of a scalar metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, running mean, weighted sum and total weight."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and recompute the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [None]:
def train_epoch(
    epoch,
    step,
    dataloader,
    batchsize,
    model,
    optimizer,
    loss_fn,
    log_every=10,
    scaler=None,
):
    """Run one training epoch over ``dataloader``.

    Args:
        epoch: epoch index, used only in the progress-bar description.
        step: global step counter at the start of the epoch; incremented per batch.
        dataloader: yields ``(images, targets)`` batches.
        batchsize: nominal batch size. Kept for compatibility; the averages are
            weighted by each batch's actual size.
        model, optimizer, loss_fn: the network, its optimizer and the loss.
        log_every: unused; kept for backward compatibility.
        scaler: a GradScaler. When given, the forward pass runs under autocast
            and gradients are scaled. When None, full precision is used.

    Returns:
        ``(mean_accuracy, mean_loss, step)`` where ``mean_loss`` is the
        unscaled loss and ``step`` is the updated global step.
    """
    enable_autocast = scaler is not None
    loss_avg = AverageMeter()
    acc_avg = AverageMeter()

    model.train()
    loader = tqdm(dataloader, total=len(dataloader))
    time_now = time.time()

    for data, targets in loader:
        optimizer.zero_grad()
        data = data.to(dev)
        targets = targets.to(dev)
        n = targets.size(0)

        with amp.autocast(enabled=enable_autocast):
            output = model(data)
            loss = loss_fn(output, targets)

        if enable_autocast:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # Track the unscaled loss. The GradScaler-scaled value is multiplied by
        # the current loss scale, so it is not comparable to the validation loss.
        loss_avg.update(loss.item(), n)

        time_spent = time.time() - time_now
        time_now = time.time()
        preds = output.argmax(dim=1)
        acc_avg.update((preds == targets).sum().item() / n, n)

        loader.set_description(
            "Training Epoch : {}, Time Spent {:.5g}, Step {}".format(
                epoch, time_spent, step
            )
        )
        loader.set_postfix(loss=loss_avg.avg, acc=acc_avg.avg)
        step += 1

    return acc_avg.avg, loss_avg.avg, step

In [None]:
def apk(actual, predicted, k=10):
    """Average precision at ``k`` for a single query.

    Only the first ``k`` predictions count, and a repeated prediction earns
    nothing after its first occurrence. The sum of precisions at each hit is
    divided by ``min(len(actual), k)``. An empty ``actual`` scores 0.0.
    """
    top_k = predicted[:k] if len(predicted) > k else predicted

    hits = 0.0
    total = 0.0
    for rank, guess in enumerate(top_k, start=1):
        if guess not in actual or guess in top_k[: rank - 1]:
            continue
        hits += 1.0
        total += hits / rank

    if not actual:
        return 0.0
    return total / min(len(actual), k)


def mean_average_precision(actual, predicted, k=10):
    """Mean of ``apk`` over paired lists of ground truths and predictions."""
    per_sample = [apk(truth, guesses, k) for truth, guesses in zip(actual, predicted)]
    return np.mean(per_sample)

In [None]:
def valid_epoch(epoch, step, dataloader, batchsize, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and print classification metrics.

    Args:
        epoch, step: only used in the printed score lines.
        dataloader: yields ``(images, targets)`` batches.
        batchsize: nominal batch size. Kept for compatibility; each batch's loss
            is weighted by its actual size, so a short final batch does not skew
            the mean.
        model, loss_fn: network and loss to evaluate.

    Returns:
        ``(map_at_2, macro_ovo_auc, mean_loss, probs, targets)`` where ``probs``
        is the (N, C) array of softmax outputs and ``targets`` the (N,) labels.
    """
    model.eval()
    loss_avg = AverageMeter()
    probs_list = []
    targets_list = []

    with torch.no_grad():
        for data, targets in tqdm(dataloader, total=len(dataloader)):
            data = data.to(dev)
            targets = targets.to(dev)

            outputs = model(data)
            loss = loss_fn(outputs, targets)
            loss_avg.update(loss.item(), targets.size(0))

            probs_list.append(outputs.softmax(dim=1))
            targets_list.append(targets)

    probs = torch.cat(probs_list).cpu().numpy()
    targets = torch.cat(targets_list).cpu().numpy()
    preds = probs.argmax(axis=1)

    print(classification_report(targets, preds))
    print(confusion_matrix(targets, preds))
    auc = roc_auc_score(targets, probs, average="macro", multi_class="ovo")

    # Two most probable classes per sample, best first, scored against the
    # single true label.
    top2 = np.argsort(-probs, axis=1)[:, :2]
    mapk = mean_average_precision(targets[:, np.newaxis].tolist(), top2.tolist(), k=2)
    print("MAP@2 Score at Epoch {} and Step {}: {}".format(epoch, step, mapk))
    print("AUC Score at Epoch {} and Step {}: {}".format(epoch, step, auc))
    return mapk, auc, loss_avg.avg, probs, targets

In [None]:
def _save_checkpoint(model, file_path, epoch, val_mapk, val_auc, probs, targets):
    """Save the model weights plus the validation results of ``epoch`` to ``file_path``."""
    # DataParallel exposes the wrapped network as ``.module``. Save that so the
    # weights load into a bare model; an unwrapped model is saved as-is.
    network = getattr(model, "module", model)
    torch.save(
        {
            "epoch": epoch,
            "map_at_2": val_mapk,
            "auc": val_auc,
            "probs": probs,
            "targets": targets,
            "state_dict": network.state_dict(),
        },
        file_path,
    )


def train_model(
    model,
    loss_fn,
    epochs,
    batchsize,
    optimizer,
    scheduler,
    save_path,
    train_dl,
    validation_dl,
    use_mp=True,
):
    """Train for ``epochs`` epochs, validating and checkpointing after each one.

    The best checkpoints by validation MAP@2 and by validation AUC are written
    to ``save_path/best_map_at_2.pth`` and ``save_path/best_auc.pth``.

    Args:
        model, loss_fn, optimizer: network, loss and optimizer.
        epochs: number of epochs.
        batchsize: training batch size; validation uses twice this.
        scheduler: optional; stepped with the validation loss each epoch.
        save_path: checkpoint directory (created if missing).
        train_dl, validation_dl: training / validation data loaders.
        use_mp: use mixed precision via a GradScaler.

    Returns:
        List of ``(epoch, value, metric_name)`` tuples for plotting.
    """
    scaler = amp.GradScaler() if use_mp else None
    os.makedirs(save_path, exist_ok=True)

    train_vs_val = []
    mapk_valid_best = -float("inf")
    auc_valid_best = -float("inf")
    step = 0

    for epoch in range(epochs):
        train_acc_avg, train_loss_avg, step = train_epoch(
            epoch, step, train_dl, batchsize, model, optimizer, loss_fn, scaler=scaler
        )

        time_now = time.time()
        # The validation loader is built with batch_size=BATCHSIZE * 2.
        val_mapk, val_auc, val_loss_avg, probs, targets = valid_epoch(
            epoch, step, validation_dl, batchsize * 2, model, loss_fn
        )
        time_spent = time.time() - time_now
        train_vs_val.extend(
            [
                (epoch, val_loss_avg, "Validation Loss"),
                (epoch, val_mapk, "Validation MAP@2"),
                (epoch, val_auc, "Validation AUC"),
                (epoch, train_loss_avg, "Training Loss"),
                (epoch, train_acc_avg, "Training Accuracy"),
            ]
        )

        if scheduler:
            scheduler.step(val_loss_avg)
            print(
                "Setting Learning Rate to: {:.6f}".format(
                    optimizer.param_groups[-1]["lr"]
                )
            )

        if mapk_valid_best < val_mapk:
            print("Found Model with best Map@2 {} at epoch {}".format(val_mapk, epoch))
            _save_checkpoint(
                model,
                os.path.join(save_path, "best_map_at_2.pth"),
                epoch,
                val_mapk,
                val_auc,
                probs,
                targets,
            )
            mapk_valid_best = val_mapk
        if auc_valid_best < val_auc:
            print("Found Model with best AUC {} at epoch {}".format(val_auc, epoch))
            _save_checkpoint(
                model,
                os.path.join(save_path, "best_auc.pth"),
                epoch,
                val_mapk,
                val_auc,
                probs,
                targets,
            )
            auc_valid_best = val_auc

        print(
            "{}, Epoch : {}, Step : {}, Validation Loss : {:.5f}, Run Time : {:.5g}".format(
                time.strftime("%Y-%m-%d %H:%M:%S"),
                epoch,
                step,
                val_loss_avg,
                time_spent,
            )
        )

    return train_vs_val

In [None]:
def find_lr(
    model,
    optimizer,
    training_dl,
    batch_size,
    loss_fn,
    init_value=1e-8,
    final_value=10,
    beta=0.98,
    update_every=10,
    num_epochs=3,
    use_mp=True,
):
    """Learning-rate range test: grow the LR geometrically per batch and record the loss.

    The LR starts at ``init_value`` and is multiplied after every batch so it
    would reach ``final_value`` after ``num_epochs`` passes over ``training_dl``.
    The test stops early ("boom") once the smoothed loss exceeds twice the best
    smoothed loss seen.

    Args:
        model, optimizer, loss_fn: network, optimizer and loss to probe.
        training_dl: yields ``(images, targets)`` batches (shuffle in the loader).
        batch_size: unused; kept for backward compatibility.
        init_value, final_value: first and (nominal) last learning rate.
        beta: smoothing factor of the bias-corrected exponential moving average.
        update_every: print the current LR every this many batches.
        num_epochs: passes over ``training_dl``.
        use_mp: run the forward pass under autocast with a GradScaler.

    Returns:
        ``(log_lrs, losses, smoothed_losses)``: log10 of each LR used, the raw
        batch loss (detached CPU tensor) and the smoothed loss per batch.
    """
    scaler = amp.GradScaler() if use_mp else None
    enable_autocast = scaler is not None
    # Move batches to wherever the model's parameters live.
    device = next(model.parameters()).device

    # Number of LR increments. Clamp to 1 so a single batch does not divide by zero.
    num_samples = max(num_epochs * len(training_dl) - 1, 1)
    print(num_samples)
    multiplier = (final_value / init_value) ** (1 / num_samples)
    lr = init_value
    for group in optimizer.param_groups:
        group["lr"] = lr

    avg_loss = 0.0
    best_loss = 0.0
    batch_num = 0
    losses = []
    log_lrs = []
    smoothed_losses = []

    model.train()
    for _ in range(num_epochs):
        for imgs, target in training_dl:
            optimizer.zero_grad()
            batch_num += 1
            imgs = imgs.to(device)
            target = target.to(device)

            with amp.autocast(enabled=enable_autocast):
                loss = loss_fn(model(imgs), target)

            if enable_autocast:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()

            # Bias-corrected exponential moving average of the loss.
            avg_loss = beta * avg_loss + (1 - beta) * loss.item()
            smoothed_loss = avg_loss / (1 - beta ** batch_num)
            # Stop if the loss is exploding.
            if batch_num > 1 and smoothed_loss > 2 * best_loss:
                print("boom")
                return log_lrs, losses, smoothed_losses
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss

            losses.append(loss.detach().cpu())
            log_lrs.append(np.log10(lr))
            smoothed_losses.append(smoothed_loss)

            lr *= multiplier
            for group in optimizer.param_groups:
                group["lr"] = lr
            if batch_num % update_every == 0:
                print("Setting LR to be: {:.5g}".format(lr))
    return log_lrs, losses, smoothed_losses

In [None]:
# model = timm.create_model(model_name="tf_efficientnetv2_s", pretrained=True, in_chans=3)
# model.classifier = torch.nn.Linear(in_features=model.classifier.in_features, out_features=NUM_CLASSES)

# model = timm.create_model(
#     model_name="vit_small_r26_s32_384", pretrained=True, in_chans=3
# )
# model.head = torch.nn.Linear(
#     in_features=model.head.in_features, out_features=NUM_CLASSES
# )

# model = torch.nn.DataParallel(model)
# model = model.to(dev)
# optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-2)
# # optimizer = torch.optim.SGD(model.parameters(), lr=3e-4, momentum=0.9, weight_decay=1e-3)
# loss_fn = torch.nn.CrossEntropyLoss().to(dev)

In [None]:
# training_ds = XRayDatasetFromDF(train_samples_df, augment=True)
# training_dl = torch.utils.data.DataLoader(dataset=training_ds,
#                                           batch_size=BATCHSIZE,
#                                           pin_memory=True,
#                                           num_workers=8,
#                                           drop_last=True,
#                                           shuffle=True,
#                                           prefetch_factor=8)

In [None]:
# log_lrs, losses, smoothed_losses = find_lr(model, optimizer, training_dl, BATCHSIZE, loss_fn)

In [None]:
# i = 0
# j = 6

# plt.plot(log_lrs[i:-j], smoothed_losses[i:-j], c="r")
# plt.plot(log_lrs[i:-j], losses[i:-j], c="b")

In [None]:
for fold in range(FOLDS):
    print("Training Fold {}".format(fold))
    model = timm.create_model(model_name=MODEL_NAME, pretrained=True, in_chans=3)
    # Replace the classification head with one sized for NUM_CLASSES.
    model.head = torch.nn.Linear(
        in_features=model.head.in_features, out_features=NUM_CLASSES
    )

    model = torch.nn.DataParallel(model)
    model = model.to(dev)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
    # `verbose` is omitted: it defaults to off, and newer PyTorch versions warn
    # that it is deprecated.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.1, patience=5, min_lr=1e-7
    )

    loss_fn = torch.nn.CrossEntropyLoss().to(dev)

    training_ds = XRayDatasetFromDF(
        train_samples_df[train_samples_df["fold"] != fold], augment=True
    )
    validation_ds = XRayDatasetFromDF(
        train_samples_df[train_samples_df["fold"] == fold],
        augment=False,
        normalize=True,
    )
    print("{} train len {} val len".format(len(training_ds), len(validation_ds)))

    training_dl = torch.utils.data.DataLoader(
        dataset=training_ds,
        batch_size=BATCHSIZE,
        pin_memory=True,
        num_workers=8,
        drop_last=True,
        shuffle=True,
        prefetch_factor=6,
    )
    validation_dl = torch.utils.data.DataLoader(
        dataset=validation_ds,
        batch_size=BATCHSIZE * 2,
        pin_memory=True,
        num_workers=8,
        drop_last=False,
        prefetch_factor=8,
    )

    print(
        "{} training data loader size {} validation dataloader size".format(
            len(training_dl), len(validation_dl)
        )
    )

    train_vs_val = train_model(
        model=model,
        loss_fn=loss_fn,
        epochs=20,
        optimizer=optimizer,
        scheduler=scheduler,
        batchsize=BATCHSIZE,
        save_path="{}-{}".format(MODEL_NAME, fold),
        train_dl=training_dl,
        validation_dl=validation_dl,
    )

    fold_report = pd.DataFrame.from_records(
        data=train_vs_val, columns=["Epoch", "Loss", "Type"]
    )
    # Give each fold its own figure; otherwise every fold's curves are drawn onto
    # the same axes.
    plt.figure()
    ax = sns.lineplot(data=fold_report, x="Epoch", y="Loss", hue="Type")
    ax.set_title("Fold {}".format(fold))
    plt.show()

    # Release this fold's model and optimizer before the next fold allocates new ones.
    del model, optimizer, scheduler
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

In [None]:
# Show GPU state (memory, utilisation) after the cross-validation run.
!nvidia-smi