# 3D classification based on DenseNet

Based on a tutorial by the MONAI Consortium: https://github.com/Project-MONAI/tutorials/blob/main/3d_classification/densenet_training_array.ipynb

## Setup imports

In [None]:
import logging
import os
import glob
import sys
import shutil
import tempfile
import datetime
import socket
import functools

from tqdm import tqdm

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import nibabel as nib

import torch
from torch.utils.tensorboard import SummaryWriter

from sklearn.metrics import f1_score, accuracy_score, balanced_accuracy_score, roc_auc_score

import monai
from monai.apps import download_and_extract
from monai.config import print_config
from monai.data import DataLoader, ImageDataset
from monai.transforms import (
    EnsureChannelFirst,
    Compose,
    RandRotate90,
    Resize,
    ScaleIntensity,
)

# Pin host memory in the DataLoaders only when a CUDA device is available.
pin_memory = torch.cuda.is_available()
# cuDNN benchmark mode is left disabled here.
torch.backends.cudnn.benchmark = False #torch.cuda.is_available() # Set this to true if the code fails
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Send INFO-level log records to stdout, then call print_config() (imported from monai.config).
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
print_config()

## Helper functions to get predictions/confusion matrices

In [None]:
def get_predicted_values(model, loader):
    """Run ``model`` over every batch of ``loader`` and collect outputs and labels.

    When ``model`` is a ``torch.nn.Module`` it is switched to evaluation mode
    for the duration of the call, so layers such as dropout and batch
    normalisation behave as they should at inference time and do not update
    their running statistics. The mode the model was in on entry is restored
    before returning, so callers that manage ``train()``/``eval()`` themselves
    are unaffected.

    Args:
        model: Callable applied to each image batch as ``model(images)``.
        loader: Iterable of batches where ``batch[0]`` holds the images and
            ``batch[1]`` the labels; both must support ``.to(device)``.

    Returns:
        Tuple ``(outputs, labels)`` of NumPy arrays: the raw model outputs and
        the labels, each concatenated over all batches along the first axis.

    Raises:
        ValueError: If ``loader`` yields no batches (``np.concatenate`` cannot
            join an empty list).
    """
    t_model_outputs = []
    t_test_labels = []

    # Remember the current mode so it can be restored afterwards.
    is_module = isinstance(model, torch.nn.Module)
    was_training = model.training if is_module else False
    if is_module:
        model.eval()

    try:
        # No gradients are needed for inference.
        with torch.no_grad():
            for test_data in loader:
                test_images, test_labels = test_data[0].to(device), test_data[1].to(device)
                val_outputs = model(test_images)
                t_model_outputs.append(val_outputs.cpu().detach().numpy())
                t_test_labels.append(test_labels.cpu().detach().numpy())
    finally:
        if is_module:
            model.train(was_training)

    conf_model_outputs = np.concatenate(t_model_outputs)
    conf_test_labels = np.concatenate(t_test_labels)
    return conf_model_outputs, conf_test_labels

def get_cm(conf_model_outputs, conf_test_labels, num_classes):
    """Build a count confusion matrix from class predictions and true labels.

    Args:
        conf_model_outputs: Predicted class ids, one per sample (array or
            any sequence convertible with ``np.asarray``).
        conf_test_labels: True class ids, aligned with ``conf_model_outputs``.
        num_classes: Number of classes. Ids that do not compare equal to a
            value in ``range(num_classes)`` are not counted.

    Returns:
        Float array of shape ``(num_classes, num_classes)`` where entry
        ``[i, j]`` is the number of samples whose true class is ``i`` and
        whose predicted class is ``j``.
    """
    predictions = np.asarray(conf_model_outputs)
    true_labels = np.asarray(conf_test_labels)

    conf_matrix = np.zeros((num_classes, num_classes))
    for i in range(num_classes):
        # Select the predictions of class i once instead of once per column.
        predicted_for_class_i = predictions[true_labels == i]
        for j in range(num_classes):
            # count_nonzero counts in native code; the builtin sum() would
            # iterate over the boolean array element by element.
            conf_matrix[i, j] = np.count_nonzero(predicted_for_class_i == j)
    return conf_matrix

## Set up data

In [None]:
# Dataset root; "~" and environment variables in the path are expanded.
root_dir = os.path.expanduser(os.path.expandvars("~/data/medicaldecathlon/"))
# Folder of the Task10_Colon dataset inside the root.
data_dir = os.path.join(root_dir, "Task10_Colon")
# CSV files with the training and test labels, expected inside data_dir.
train_dataset_frailty_path = os.path.join(data_dir,"train_clean.csv")
test_dataset_frailty_path = os.path.join(data_dir,"test_clean.csv")
print(root_dir)

Get list of images

In [None]:
# Sorted paths of all *.nii.gz files in the imagesTr folder.
train_images = sorted(
    glob.glob(os.path.join(data_dir, "imagesTr", "*.nii.gz")))
# Show the first five paths as the cell output.
train_images[:5]

In [None]:
# Sorted paths of all *.nii.gz files in the imagesTs folder.
test_image_paths = sorted(
    glob.glob(os.path.join(data_dir, "imagesTs", "*.nii.gz")))
# Show the first five paths as the cell output.
test_image_paths[:5]

Get training and testing labels. You should move these from this repo's `data` folder to the folder that contains the image folders above.

In [None]:
# Columns of the label CSVs that are cast to int after rows with missing values are dropped
int_cols = ["Skeletal Muscle Wasting","Fat Excess","Bone density","Aortic Calcium","Liver fat","Pancreatic fat","Total Score"]

In [None]:
# Training labels indexed by PatientID; rows with any missing value are
# dropped and the columns listed in int_cols are cast to int.
df_labels = pd.read_csv(train_dataset_frailty_path, index_col="PatientID").dropna().astype({col:int for col in int_cols})
df_labels.head()

In [None]:
# Test labels indexed by PatientID; rows with any missing value are dropped
# and the columns listed in int_cols are cast to int.
df_labels_test = pd.read_csv(test_dataset_frailty_path, index_col="PatientID").dropna().astype({col:int for col in int_cols})
df_labels_test.head()

Change values in the dataframes from string to class ids (Not used here)

In [None]:
# Replace the textual risk categories with integer class ids, in place,
# first in the training table and then in the test table.
for frame in (df_labels, df_labels_test):
    for category_name, class_id in (("LOW", 0), ("MEDIUM", 1), ("HIGH", 2)):
        frame.loc[frame["Risk Category"] == category_name, "Risk Category"] = class_id

In [None]:
def get_id_from_filepath(fpath):
    """Return the integer patient ID encoded in an image file name.

    The ID is the text between the first underscore of the base name and the
    next dot (or the end of the name), e.g. ``colon_001.nii.gz`` gives ``1``.
    """
    file_name = os.path.basename(fpath)
    after_prefix = file_name.split("_")[1]
    id_text, _, _ = after_prefix.partition(".")
    return int(id_text)

In [None]:
# Display the column dtypes of the training label table.
df_labels.dtypes

In [None]:
# Label columns the network is trained to predict, one classification task each.
label_types = ["Skeletal Muscle Wasting", "Fat Excess", "Bone density", "Aortic Calcium", "Liver fat", "Pancreatic fat"]
# Weight of each task's loss in the training loops, in the same order as label_types.
label_type_weights = [2,1,2,1,1,1]

In [None]:
# Images whose third axis is shorter than this are excluded when the data lists are built.
MIN_SLICES = 64

In [None]:
# Pair every training image with its patient's label vector. Images are kept
# only if the patient has a row in df_labels and the volume's third axis has
# at least MIN_SLICES entries.
# The size is read from the image header (``img.shape``); calling
# ``get_fdata()`` would decode the whole voxel array just to look at its shape.
data_dicts = [
    {"image": image_name, "label": df_labels.loc[patient_id,label_types].astype(int)}
    for image_name,patient_id in zip(train_images,map(get_id_from_filepath,train_images))
    if patient_id in df_labels.index and nib.load(image_name).shape[2]>=MIN_SLICES
]

In [None]:
# Pair every test image with its patient's label vector. Images are kept only
# if the patient has a row in df_labels_test and the volume's third axis has
# at least MIN_SLICES entries.
# The size is read from the image header (``img.shape``); calling
# ``get_fdata()`` would decode the whole voxel array just to look at its shape.
test_data_dicts = [
    {"image": image_name, "label": df_labels_test.loc[patient_id,label_types].astype(int)}
    for image_name,patient_id in zip(test_image_paths,map(get_id_from_filepath,test_image_paths))
    if patient_id in df_labels_test.index and nib.load(image_name).shape[2]>=MIN_SLICES
]
len(test_data_dicts)

In [None]:
# IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
images = np.array([d["image"] for d in data_dicts])
# 2 binary labels for gender classification: man or woman
labels = np.array([d["label"] for d in data_dicts])

test_images = np.array([d["image"] for d in test_data_dicts])
# 2 binary labels for gender classification: man or woman
test_labels = np.array([d["label"] for d in test_data_dicts])


# Represent labels in one-hot format for binary classifier training,
# BCEWithLogitsLoss requires target to have same shape as input
#labels = torch.nn.functional.one_hot(torch.as_tensor(labels)).float()
num_labels = [int(max(labels[:,l]))+1 for l in range(labels.shape[1])] if isinstance(labels, np.ndarray) and len(labels.shape)>1 else int(max(labels))+1
num_labels_cumsum = np.cumsum(num_labels).tolist()
num_labels, num_labels_cumsum, labels[:5], test_labels[:5]

In [None]:
# Collect the shape of every training volume from its header. ``img.shape``
# does not load the voxel data, unlike ``get_fdata().shape``.
shapes = [nib.load(datapoint["image"]).shape for datapoint in data_dicts]
np_shapes = np.stack(shapes)
# Smallest extent found along each axis over the whole training set.
minshapes = np.min(np_shapes, axis=0)
minshapes

In [None]:
# Preferred edge length of the cubic volume fed to the network.
original_crop = 96
# Use the preferred size unless the smallest image extent in the training set
# is smaller, in which case shrink the cube to that extent.
newcrop = min(original_crop, int(min(minshapes)))
if newcrop < 3:
    # The previous countdown loop stopped at 3 and silently left crop_shapes
    # undefined below that; fail here with an explicit message instead.
    raise ValueError(
        f"Smallest image extent ({newcrop}) is below the minimum supported crop size of 3"
    )
crop_shapes = tuple([newcrop]*3)
crop_shapes

In [None]:
# Fraction of the samples held out for validation, and the resulting count
# (rounded down).
val_pct = 0.2
val_split = int(val_pct*len(labels))

# TODO: Stratified split
#possible_labels = sorted(list(set(labels)))
#proportion_in_labels = np.array([sum(labels==i)/len(labels) for i in possible_labels])
#val_per_labels = [int(l*val_split) for l in proportion_in_labels]
#proportion_in_labels, val_per_labels
val_split

In [None]:
# Per-sample weights are currently disabled; the inverse-frequency variant is kept as a trailing comment.
label_weigths = None#1/proportion_in_labels[labels]

In [None]:
# TODO: Stratified split
#val_idx = np.concatenate([np.random.choice([i for i, l in enumerate(labels) if l==p], c, replace=False) for p,c in zip(possible_labels, val_per_labels)])
# Draw val_split distinct sample positions at random for validation;
# every remaining position goes to training.
val_idx = np.random.choice(len(labels), val_split, replace=False)
in_val = np.isin(np.arange(len(labels)), val_idx)
in_train = np.logical_not(in_val)
train_idx = np.flatnonzero(in_train)
val_idx

In [None]:
# Start from a batch size of 4 and increase it until neither the training set
# nor the validation set leaves a final batch with a single sample.
batch_size = 4#3
while (len(images)-val_split)%batch_size==1 or val_split%batch_size==1:
    batch_size +=1
    print("Changing batch size so that no batch has size 1")
batch_size

In [None]:
# Define transforms
# Training: intensity scaling, channel-first layout, resize to crop_shapes and
# a random 90-degree rotation as augmentation.
train_transforms = Compose([ScaleIntensity(), EnsureChannelFirst(), Resize(crop_shapes), RandRotate90()])

# Validation: the same preprocessing without the random rotation.
val_transforms = Compose([ScaleIntensity(), EnsureChannelFirst(), Resize(crop_shapes)])

# Define nifti dataset, data loader
# Sanity check over all images: fetch one batch and print its type and shapes.
check_ds = ImageDataset(image_files=images, labels=labels, transform=train_transforms)
check_loader = DataLoader(check_ds, batch_size=3, num_workers=2, pin_memory=pin_memory)

im, label = monai.utils.misc.first(check_loader)
print(type(im), im.shape, label, label.shape)

# create a training data loader
train_ds = ImageDataset(image_files=images[train_idx].tolist(), labels=labels[train_idx], transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=pin_memory)

# create a validation data loader
val_ds = ImageDataset(image_files=images[val_idx].tolist(), labels=labels[val_idx], transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=2, pin_memory=pin_memory)

In [None]:
# Number of images available before the train/validation split.
len(images)

In [None]:
# Number of samples in the training and validation datasets.
len(train_ds), len(val_ds)

In [None]:
# Number of batches per pass over the training and validation loaders.
len(train_loader), len(val_loader)

In [None]:
# Map a short metric name to its scoring function. The name is the function's
# __name__ (the wrapped function's for a functools.partial) with everything
# from "_score" onwards removed.
metrics_fns = {
    (s.func.__name__ if isinstance(s, functools.partial) else s.__name__).split("_score")[0]: s
    for s in [functools.partial(f1_score, average="micro"), accuracy_score, balanced_accuracy_score]
}
# Metric used in the training loops to decide which checkpoint is the best one.
val_metric = "balanced_accuracy"
metrics_fns

In [None]:
# Create DenseNet121, CrossEntropyLoss and Adam optimizer
# One network predicts every label type at once: its output holds the logits
# of all tasks side by side, so out_channels is the total number of classes.
model = monai.networks.nets.DenseNet(spatial_dims=3, in_channels=1, out_channels=np.sum(num_labels)).to(device)

#loss_function = [torch.nn.CrossEntropyLoss(torch.tensor(1/proportion_in_labels, device=device, dtype=torch.float32)) for _ in label_types]
# One cross-entropy loss per label type.
loss_function = [torch.nn.CrossEntropyLoss() for _ in label_types]
# loss_function = torch.nn.BCEWithLogitsLoss()  # also works with this data

optimizer = torch.optim.Adam(model.parameters(), 1e-4)

# start a typical PyTorch training
val_interval = 2
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
writer = SummaryWriter(f"multitask_runs/{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}_{socket.gethostname()}")
max_epochs = 256

# Slice boundaries of each label type inside the model output:
# label k uses the columns num_labels_ranges[k]:num_labels_ranges[k+1].
num_labels_ranges = [0]+num_labels_cumsum

# Number of optimisation steps per epoch. The loader keeps the final partial
# batch, so len(train_loader) is the true step count; the floor division
# len(train_ds) // batch_size undercounts it, which made the TensorBoard
# global steps of consecutive epochs collide.
epoch_len = len(train_loader)

for epoch in range(max_epochs):
    # Ctrl-C / "interrupt kernel" stops training but still runs the code
    # after the loop (saving the last model and closing the writer).
    try:
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        step = 0

        for batch_data in train_loader:
            step += 1
            b_inputs, b_labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(b_inputs)
            # Weighted average of the per-task losses; each task only sees
            # its own slice of the output and its own label column.
            loss = 0
            for l_idx,(loss_fn,loss_w,s,e) in enumerate(zip(loss_function, label_type_weights, num_labels_ranges[:-1],num_labels_ranges[1:])):
                loss += loss_fn(outputs[:,s:e], b_labels[:,l_idx])*loss_w
            loss /= sum(label_type_weights)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)

        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()

            v_pred_raw, v_label = get_predicted_values(model, val_loader)
            # Predicted class per label type: arg-max inside that type's slice.
            v_pred = np.stack([
                    v_pred_raw[:,s:e].argmax(axis=1)
                    for (s,e) in zip(num_labels_ranges[:-1],num_labels_ranges[1:])
                ],
                axis=1,
            )

            # One list of per-label scores for every metric in metrics_fns.
            all_metrics_per_label = {
                metric: [] for metric in metrics_fns
            }

            for l_idx, l in enumerate(label_types):
                for metric in all_metrics_per_label:
                    metric_value = metrics_fns[metric](v_label[:,l_idx], v_pred[:,l_idx])
                    writer.add_scalar(f"{l} {metric}",metric_value, epoch + 1)
                    all_metrics_per_label[metric].append(metric_value)
                # Rows are true classes, columns are predicted classes.
                cm_counts = get_cm(v_pred[:,l_idx], v_label[:,l_idx], num_labels[l_idx])
                # Row-normalised matrix; a class with no validation samples
                # yields a 0/0 row, i.e. NaN entries.
                cm_pct = cm_counts/cm_counts.sum(axis=1,keepdims=True)
                print(cm_pct)
                for i in range(num_labels[l_idx]):
                    for j in range(num_labels[l_idx]):
                        writer.add_scalar(f"{l} count l{i}_p{j}",cm_counts[i,j], epoch + 1)
                        writer.add_scalar(f"{l} pct l{i}_p{j}",cm_pct[i,j], epoch + 1)

            # Model selection uses val_metric averaged over the label types.
            metric = np.mean(all_metrics_per_label[val_metric])
            metric_values.append(metric)

            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), "multitask_best_metric_model_classification3d_array.pth")
                print("saved new best metric model")

            print(f"Current epoch: {epoch+1} current {val_metric}: {metric:.4f} ")
            print(f"Best {val_metric}: {best_metric:.4f} at epoch {best_metric_epoch}")
            for metric in all_metrics_per_label:
                writer.add_scalar(f"val_{metric}", np.mean(all_metrics_per_label[metric]), epoch + 1)
    except KeyboardInterrupt:
        break

print(f"Training completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
torch.save(model.state_dict(), "multitask_last_model_classification3d_array.pth")
writer.close()

# Occlusion sensitivity
One method for trying to visualise why the network made a given prediction is occlusion sensitivity. We occlude part of the image, and see how the probability of a given prediction changes. We then iterate over the image, moving the occluded portion as we go, and in doing so we build up a sensitivity map detailing which areas were the most important in making the decision.

#### Bounds
If we were to test the occlusion centred on all voxels in our image, we would have to do `torch.prod(im.shape) = 96^3 = ~1e6` predictions. We can use the bounding box to only do the estimations in a region of interest, for example over one slice.

To do this, we simply give the bounding box as `(minC,maxC,minD,maxD,minH,maxH,minW,maxW)`. We can use `-1` for any value to use its full extent (`0` and `im.shape-1` for min's and max's, respectively).

#### Output
The output image in this example will look fairly bad, since our network hasn't been trained for very long. Training for longer should improve the quality of the occlusion map.

In [None]:
# create a test data loader (uses the validation transforms, i.e. no augmentation)
test_ds = ImageDataset(image_files=test_images, labels=test_labels, transform=val_transforms)
test_loader = DataLoader(test_ds, batch_size=2, num_workers=2, pin_memory=pin_memory)

# re-create the training data loader with the current train_idx and batch_size
train_ds = ImageDataset(image_files=images[train_idx].tolist(), labels=labels[train_idx], transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=pin_memory)

In [None]:
# Number of test images.
len(test_images)

In [None]:
# Evaluate the current model on the training loader and print per-label metrics.
# NOTE(review): train_loader applies train_transforms (including RandRotate90),
# so these scores are computed on randomly augmented inputs.
v_pred_raw, v_label = get_predicted_values(model, train_loader)
# Split the model output into one slice per label type and take the arg-max
# inside each slice to get that label's predicted class id.
v_pred = np.stack([
        v_pred_raw[:,s:e].argmax(axis=1)
        for (s,e) in zip(num_labels_ranges[:-1],num_labels_ranges[1:])
    ],
    axis=1,
)

# One list of per-label scores for every metric in metrics_fns.
all_metrics_per_label = {
    metric: [] for metric in metrics_fns
}

for l_idx, l in enumerate(label_types):
    for metric in all_metrics_per_label:
        metric_value = metrics_fns[metric](v_label[:,l_idx], v_pred[:,l_idx])
        print(f"train {l} {metric}",metric_value)
        all_metrics_per_label[metric].append(metric_value)
    # Rows of the confusion matrix are true classes, columns are predictions.
    cm_counts = get_cm(v_pred[:,l_idx], v_label[:,l_idx], num_labels[l_idx])
    # Row-normalise; a class with no samples produces a 0/0 (NaN) row.
    cm_pct = cm_counts/cm_counts.sum(axis=1,keepdims=True)
    print(cm_counts)
    print(cm_pct)

# Average each metric over the label types.
for metric in all_metrics_per_label:
    print(f"train_{metric}", np.mean(all_metrics_per_label[metric]))

In [None]:
# Evaluate the current model on the validation loader and print per-label metrics.
v_pred_raw, v_label = get_predicted_values(model, val_loader)
# Split the model output into one slice per label type and take the arg-max
# inside each slice to get that label's predicted class id.
v_pred = np.stack([
        v_pred_raw[:,s:e].argmax(axis=1)
        for (s,e) in zip(num_labels_ranges[:-1],num_labels_ranges[1:])
    ],
    axis=1,
)

# One list of per-label scores for every metric in metrics_fns.
all_metrics_per_label = {
    metric: [] for metric in metrics_fns
}

for l_idx, l in enumerate(label_types):
    for metric in all_metrics_per_label:
        metric_value = metrics_fns[metric](v_label[:,l_idx], v_pred[:,l_idx])
        print(f"val {l} {metric}",metric_value)
        all_metrics_per_label[metric].append(metric_value)
    # Rows of the confusion matrix are true classes, columns are predictions.
    cm_counts = get_cm(v_pred[:,l_idx], v_label[:,l_idx], num_labels[l_idx])
    # Row-normalise; a class with no samples produces a 0/0 (NaN) row.
    cm_pct = cm_counts/cm_counts.sum(axis=1,keepdims=True)
    print(cm_counts)
    print(cm_pct)

# Average each metric over the label types.
for metric in all_metrics_per_label:
    print(f"val_{metric}", np.mean(all_metrics_per_label[metric]))

In [None]:
# Evaluate the current model on the test loader and print per-label metrics.
v_pred_raw, v_label = get_predicted_values(model, test_loader)
# Split the model output into one slice per label type and take the arg-max
# inside each slice to get that label's predicted class id.
v_pred = np.stack([
        v_pred_raw[:,s:e].argmax(axis=1)
        for (s,e) in zip(num_labels_ranges[:-1],num_labels_ranges[1:])
    ],
    axis=1,
)

# One list of per-label scores for every metric in metrics_fns.
all_metrics_per_label = {
    metric: [] for metric in metrics_fns
}

for l_idx, l in enumerate(label_types):
    for metric in all_metrics_per_label:
        metric_value = metrics_fns[metric](v_label[:,l_idx], v_pred[:,l_idx])
        print(f"test {l} {metric}",metric_value)
        all_metrics_per_label[metric].append(metric_value)
    # Rows of the confusion matrix are true classes, columns are predictions.
    cm_counts = get_cm(v_pred[:,l_idx], v_label[:,l_idx], num_labels[l_idx])
    # Row-normalise; a class with no samples produces a 0/0 (NaN) row.
    cm_pct = cm_counts/cm_counts.sum(axis=1,keepdims=True)
    print(cm_counts)
    print(cm_pct)

# Average each metric over the label types.
for metric in all_metrics_per_label:
    print(f"test_{metric}", np.mean(all_metrics_per_label[metric]))

In [None]:
# Replace the model weights with the checkpoint saved by the training loop
# whenever the validation metric improved.
model.load_state_dict(torch.load("multitask_best_metric_model_classification3d_array.pth"))

In [None]:
# Evaluate the model as currently loaded on the training loader and print
# per-label metrics.
# NOTE(review): train_loader applies train_transforms (including RandRotate90),
# so these scores are computed on randomly augmented inputs.
v_pred_raw, v_label = get_predicted_values(model, train_loader)
# Split the model output into one slice per label type and take the arg-max
# inside each slice to get that label's predicted class id.
v_pred = np.stack([
        v_pred_raw[:,s:e].argmax(axis=1)
        for (s,e) in zip(num_labels_ranges[:-1],num_labels_ranges[1:])
    ],
    axis=1,
)

# One list of per-label scores for every metric in metrics_fns.
all_metrics_per_label = {
    metric: [] for metric in metrics_fns
}

for l_idx, l in enumerate(label_types):
    for metric in all_metrics_per_label:
        metric_value = metrics_fns[metric](v_label[:,l_idx], v_pred[:,l_idx])
        print(f"train {l} {metric}",metric_value)
        all_metrics_per_label[metric].append(metric_value)
    # Rows of the confusion matrix are true classes, columns are predictions.
    cm_counts = get_cm(v_pred[:,l_idx], v_label[:,l_idx], num_labels[l_idx])
    # Row-normalise; a class with no samples produces a 0/0 (NaN) row.
    cm_pct = cm_counts/cm_counts.sum(axis=1,keepdims=True)
    print(cm_counts)
    print(cm_pct)

# Average each metric over the label types.
for metric in all_metrics_per_label:
    print(f"train_{metric}", np.mean(all_metrics_per_label[metric]))

In [None]:
# Evaluate the model as currently loaded on the validation loader and print
# per-label metrics.
v_pred_raw, v_label = get_predicted_values(model, val_loader)
# Split the model output into one slice per label type and take the arg-max
# inside each slice to get that label's predicted class id.
v_pred = np.stack([
        v_pred_raw[:,s:e].argmax(axis=1)
        for (s,e) in zip(num_labels_ranges[:-1],num_labels_ranges[1:])
    ],
    axis=1,
)

# One list of per-label scores for every metric in metrics_fns.
all_metrics_per_label = {
    metric: [] for metric in metrics_fns
}

for l_idx, l in enumerate(label_types):
    for metric in all_metrics_per_label:
        metric_value = metrics_fns[metric](v_label[:,l_idx], v_pred[:,l_idx])
        print(f"val {l} {metric}",metric_value)
        all_metrics_per_label[metric].append(metric_value)
    # Rows of the confusion matrix are true classes, columns are predictions.
    cm_counts = get_cm(v_pred[:,l_idx], v_label[:,l_idx], num_labels[l_idx])
    # Row-normalise; a class with no samples produces a 0/0 (NaN) row.
    cm_pct = cm_counts/cm_counts.sum(axis=1,keepdims=True)
    print(cm_counts)
    print(cm_pct)

# Average each metric over the label types.
for metric in all_metrics_per_label:
    print(f"val_{metric}", np.mean(all_metrics_per_label[metric]))

In [None]:
# Evaluate the model as currently loaded on the test loader and print
# per-label metrics.
v_pred_raw, v_label = get_predicted_values(model, test_loader)
# Split the model output into one slice per label type and take the arg-max
# inside each slice to get that label's predicted class id.
v_pred = np.stack([
        v_pred_raw[:,s:e].argmax(axis=1)
        for (s,e) in zip(num_labels_ranges[:-1],num_labels_ranges[1:])
    ],
    axis=1,
)

# One list of per-label scores for every metric in metrics_fns.
all_metrics_per_label = {
    metric: [] for metric in metrics_fns
}

for l_idx, l in enumerate(label_types):
    for metric in all_metrics_per_label:
        metric_value = metrics_fns[metric](v_label[:,l_idx], v_pred[:,l_idx])
        print(f"test {l} {metric}",metric_value)
        all_metrics_per_label[metric].append(metric_value)
    # Rows of the confusion matrix are true classes, columns are predictions.
    cm_counts = get_cm(v_pred[:,l_idx], v_label[:,l_idx], num_labels[l_idx])
    # Row-normalise; a class with no samples produces a 0/0 (NaN) row.
    cm_pct = cm_counts/cm_counts.sum(axis=1,keepdims=True)
    print(cm_counts)
    print(cm_pct)

# Average each metric over the label types.
for metric in all_metrics_per_label:
    print(f"test_{metric}", np.mean(all_metrics_per_label[metric]))

## Full Training

In [None]:
# Merge the training and test data into one pool for the final training run.
full_images = np.concatenate([images, test_images])
full_labels = np.concatenate([labels, test_labels])

full_images.shape, full_labels.shape

In [None]:
# Fraction of the merged data held out for validation, and the resulting
# count (rounded down).
val_pct = 0.1
val_split = int(val_pct*len(full_labels))

# TODO: Stratified split
#possible_labels = sorted(list(set(labels)))
#proportion_in_labels = np.array([sum(labels==i)/len(labels) for i in possible_labels])
#val_per_labels = [int(l*val_split) for l in proportion_in_labels]
#proportion_in_labels, val_per_labels
val_split

In [None]:
# Per-sample weights are currently disabled; the inverse-frequency variant is kept as a trailing comment.
label_weigths = None#1/proportion_in_labels[labels]

In [None]:
# TODO: Stratified split
#val_idx = np.concatenate([np.random.choice([i for i, l in enumerate(labels) if l==p], c, replace=False) for p,c in zip(possible_labels, val_per_labels)])
# Draw val_split distinct sample positions at random for validation;
# every remaining position goes to training.
val_idx = np.random.choice(len(full_labels), val_split, replace=False)
in_val = np.isin(np.arange(len(full_labels)), val_idx)
in_train = np.logical_not(in_val)
train_idx = np.flatnonzero(in_train)
val_idx

In [None]:
# Start from a batch size of 4 and increase it until neither the training set
# nor the validation set leaves a final batch with a single sample.
batch_size = 4#3
while (len(full_images)-val_split)%batch_size==1 or val_split%batch_size==1:
    batch_size +=1
    print("Changing batch size so that no batch has size 1")
batch_size

In [None]:
# create a training data loader (merged data, shuffled, with augmentation)
train_ds = ImageDataset(image_files=full_images[train_idx].tolist(), labels=full_labels[train_idx], transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=pin_memory)

# create a validation data loader (merged data, validation transforms)
val_ds = ImageDataset(image_files=full_images[val_idx].tolist(), labels=full_labels[val_idx], transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=2, pin_memory=pin_memory)

In [None]:
# Create DenseNet121, CrossEntropyLoss and Adam optimizer
# One network predicts every label type at once: its output holds the logits
# of all tasks side by side, so out_channels is the total number of classes.
model = monai.networks.nets.DenseNet(spatial_dims=3, in_channels=1, out_channels=np.sum(num_labels)).to(device)

#loss_function = [torch.nn.CrossEntropyLoss(torch.tensor(1/proportion_in_labels, device=device, dtype=torch.float32)) for _ in label_types]
# One cross-entropy loss per label type.
loss_function = [torch.nn.CrossEntropyLoss() for _ in label_types]
# loss_function = torch.nn.BCEWithLogitsLoss()  # also works with this data

optimizer = torch.optim.Adam(model.parameters(), 1e-4)

# start a typical PyTorch training
val_interval = 2
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
writer = SummaryWriter(f"full_multitask_runs/{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}_{socket.gethostname()}")
max_epochs = 256

# Slice boundaries of each label type inside the model output:
# label k uses the columns num_labels_ranges[k]:num_labels_ranges[k+1].
num_labels_ranges = [0]+num_labels_cumsum

# Number of optimisation steps per epoch. The loader keeps the final partial
# batch, so len(train_loader) is the true step count; the floor division
# len(train_ds) // batch_size undercounts it, which made the TensorBoard
# global steps of consecutive epochs collide.
epoch_len = len(train_loader)

for epoch in range(max_epochs):
    # Ctrl-C / "interrupt kernel" stops training but still runs the code
    # after the loop (saving the last model and closing the writer).
    try:
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        step = 0

        for batch_data in train_loader:
            step += 1
            b_inputs, b_labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(b_inputs)
            # Weighted average of the per-task losses; each task only sees
            # its own slice of the output and its own label column.
            loss = 0
            for l_idx,(loss_fn,loss_w,s,e) in enumerate(zip(loss_function, label_type_weights, num_labels_ranges[:-1],num_labels_ranges[1:])):
                loss += loss_fn(outputs[:,s:e], b_labels[:,l_idx])*loss_w
            loss /= sum(label_type_weights)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)

        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()

            v_pred_raw, v_label = get_predicted_values(model, val_loader)
            # Predicted class per label type: arg-max inside that type's slice.
            v_pred = np.stack([
                    v_pred_raw[:,s:e].argmax(axis=1)
                    for (s,e) in zip(num_labels_ranges[:-1],num_labels_ranges[1:])
                ],
                axis=1,
            )

            # One list of per-label scores for every metric in metrics_fns.
            all_metrics_per_label = {
                metric: [] for metric in metrics_fns
            }

            for l_idx, l in enumerate(label_types):
                for metric in all_metrics_per_label:
                    metric_value = metrics_fns[metric](v_label[:,l_idx], v_pred[:,l_idx])
                    writer.add_scalar(f"{l} {metric}",metric_value, epoch + 1)
                    all_metrics_per_label[metric].append(metric_value)
                # Rows are true classes, columns are predicted classes.
                cm_counts = get_cm(v_pred[:,l_idx], v_label[:,l_idx], num_labels[l_idx])
                # Row-normalised matrix; a class with no validation samples
                # yields a 0/0 row, i.e. NaN entries.
                cm_pct = cm_counts/cm_counts.sum(axis=1,keepdims=True)
                print(cm_pct)
                for i in range(num_labels[l_idx]):
                    for j in range(num_labels[l_idx]):
                        writer.add_scalar(f"{l} count l{i}_p{j}",cm_counts[i,j], epoch + 1)
                        writer.add_scalar(f"{l} pct l{i}_p{j}",cm_pct[i,j], epoch + 1)

            # Model selection uses val_metric averaged over the label types.
            metric = np.mean(all_metrics_per_label[val_metric])
            metric_values.append(metric)

            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), "full_multitask_best_metric_model_classification3d_array.pth")
                print("saved new best metric model")

            print(f"Current epoch: {epoch+1} current {val_metric}: {metric:.4f} ")
            print(f"Best {val_metric}: {best_metric:.4f} at epoch {best_metric_epoch}")
            for metric in all_metrics_per_label:
                writer.add_scalar(f"val_{metric}", np.mean(all_metrics_per_label[metric]), epoch + 1)
    except KeyboardInterrupt:
        break

print(f"Training completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
torch.save(model.state_dict(), "full_multitask_last_model_classification3d_array.pth")
writer.close()