In [1]:
import copy
import torch
import torch.nn as nn
import numpy as np

import training_configuration
from data_generation.generate_3d import ImageGenerator
from data_generation.generate_utils import get_batch
from data_generation.config import original_image_shape, cubic_simple_dims

from models.net_utils import calculate_jaccard_score, calculate_dice_score, save_model
from models.net_utils import get_best_device, prepare_image_for_network_input, prepare_image_for_analysis
from models.net_visualizations import three_d_visualize_model_progress, display3DImageMaskTuple

from models.unet3D import UNet3D, dice_bce_loss
from data.data_utils import pad_image, divide_3d_image_into_patches, get_padded_patches

from sklearn.model_selection import train_test_split
from server_specific.server_utils import get_patients

In [2]:
# Load every patient and split them by position: the first 700 are used for
# training, everything after that is held out as the test pool.
patients = np.array(get_patients())
train_patients, test_patients = patients[:700], patients[700:]

print("Train dataset:", train_patients.shape)
print("Test dataset:", test_patients.shape)

amt of detected_files:  1800
amt of patients:  800
Train dataset: (700,)
Test dataset: (100,)


In [3]:
# Alias for the original volume shape (per the original note, this only works for 3d).
default_image_shape = original_image_shape

# Ask the project helper for a device; MPS does not support the 3D operations
# used here, so it is replaced by the CPU.
device = get_best_device()
device = "cpu" if device == "mps" else device

# Every tensor created from now on is allocated on `device` unless stated otherwise.
torch.set_default_device(device)
print(f"Using {device} device. Every tensor created will be by default on {device}")

Using cuda device. Every tensor created will be by default on cuda


In [4]:
print("----------------TRAINING-------------")

patch_size = training_configuration.PATCH_SIZE
block_shape = (patch_size, patch_size, patch_size)

def train_loop(model, loss_fn, optimizer, image, mask, min_foreground_voxels=10000):
    """Train `model` for one pass over the patches of a single image/mask pair.

    The image and mask are split into padded patches with `get_padded_patches`
    (edge length `patch_size`). Patches whose mask has fewer than
    `min_foreground_voxels` non-zero voxels are skipped; every remaining patch
    gets exactly one optimizer step.

    Args:
        model: network being trained; switched to train mode here.
        loss_fn: callable `(prediction, target) -> loss tensor`.
        optimizer: optimizer over `model`'s parameters.
        image: image volume (presumably a numpy array — TODO confirm).
        mask: mask volume matching `image`.
        min_foreground_voxels: skip threshold on non-zero mask voxels per patch.

    Returns:
        Mean training loss over the patches that were trained on, or
        ``float("nan")`` if every patch was skipped (previously this case
        raised ZeroDivisionError and aborted the whole training run).
    """
    model.train()

    total_train_loss = 0.0
    total_dice_score = 0.0
    trained_patch_count = 0

    image_patches, mask_patches = get_padded_patches(image, mask, patch_size)

    # The first three axes index the patch grid.
    grid_shape = image_patches.shape[:3]
    amt_of_image_patches = grid_shape[0] * grid_shape[1] * grid_shape[2]

    # np.ndindex walks the grid in the same (i, j, k) order as three nested loops.
    for patch_number, (i, j, k) in enumerate(np.ndindex(*grid_shape), start=1):
        current_mask_patch = mask_patches[i, j, k]
        # Skip patches with too little foreground to be worth a gradient step.
        if np.count_nonzero(current_mask_patch) < min_foreground_voxels:
            continue
        current_image_patch = image_patches[i, j, k]

        current_image_patch = prepare_image_for_network_input(current_image_patch.astype(np.float16))
        current_mask_patch = prepare_image_for_network_input(current_mask_patch.astype(np.float16))

        optimizer.zero_grad()
        current_prediction_patch = model(current_image_patch)
        loss = loss_fn(current_prediction_patch, current_mask_patch)
        loss.backward()
        optimizer.step()

        train_loss = loss.item()
        total_train_loss += train_loss
        # The Dice score is only logged; detach so that, if calculate_dice_score
        # returns a tensor, accumulating it does not keep the autograd graph alive.
        total_dice_score += calculate_dice_score(current_mask_patch, current_prediction_patch.detach())
        trained_patch_count += 1

        train_log = f"Patch number: {patch_number} / {amt_of_image_patches}, Train loss: {train_loss:>8f}"
        print(train_log, end="\r")

    if trained_patch_count == 0:
        print(f"For this patient: no patch has at least {min_foreground_voxels} foreground voxels; nothing was trained.\n")
        return float("nan")

    avg_train_loss = total_train_loss / trained_patch_count
    avg_dice_score = total_dice_score / trained_patch_count

    print(f"For this patient: average train loss: {avg_train_loss:8f} | average Dice Score: {avg_dice_score:8f}\n")

    return avg_train_loss

----------------TRAINING-------------


In [5]:
# Build a fresh model and optimizer; re-running this cell resets all training progress.
# Seed torch first so the weight initialisation is reproducible across kernel restarts.
SEED = 42
torch.manual_seed(SEED)

model = UNet3D(in_channels=1, num_classes=1)
model.to(device)

loss_fn = dice_bce_loss
learning_rate = 0.1
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
epochs = 2

In [6]:
# Always evaluate on the same fixed subset of test patients so results are comparable across runs.
amt_test_patients = 10
current_test_patients = test_patients[0:amt_test_patients]

# Train patients are processed in consecutive groups ("tuples") of this size.
patients_per_tuple = 7
# Ceiling division so a trailing partial group is still trained on.
amt_tuples = -(-len(train_patients) // patients_per_tuple)

try:
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        # NOTE(review): no test patient is evaluated in this loop yet, despite this message.
        print(f"tuple of {patients_per_tuple} train patients, 1 test patient")
        for i in range(amt_tuples):
            print(f"processing tuple {i} from {amt_tuples} tuples")

            current_train_patients = train_patients[i * patients_per_tuple:(i + 1) * patients_per_tuple]

            average_train_losses = []
            average_test_losses = []  # TODO: fill once evaluation on current_test_patients is added
            for current_train_idx, current_train_patient in enumerate(current_train_patients):
                print(f"current train patient: {current_train_idx + 1} / {len(current_train_patients)}")
                current_train_image, current_train_mask = current_train_patient.get_preprocessed_image_mask_tuple()
                current_average_train_loss = train_loop(model, loss_fn, optimizer, current_train_image, current_train_mask)
                average_train_losses.append(current_average_train_loss)

    print("Done!")
except KeyboardInterrupt:
    # Manual stop: keep the partially trained model instead of losing the run.
    # Any other exception now propagates with its real traceback instead of being
    # misreported as an interruption.
    print("Keyboard interruption.")
finally:
    # Leave the model in eval mode whether training finished, was interrupted, or failed.
    model.eval()

Epoch 1
-------------------------------
tuple of 7 train patients, 1 test patient
processing tuple 0 from 100 tuples
current train patient: 1 / 7
For this patient: average train loss: 1.248214 | average Dice Score: 0.011983

current train patient: 2 / 7
For this patient: average train loss: 1.191321 | average Dice Score: 0.007877

current train patient: 3 / 7
For this patient: average train loss: 1.153168 | average Dice Score: 0.000956

current train patient: 4 / 7
For this patient: average train loss: 1.129499 | average Dice Score: 0.000000

current train patient: 5 / 7
For this patient: average train loss: 1.104601 | average Dice Score: 0.000000

current train patient: 6 / 7
For this patient: average train loss: 1.086791 | average Dice Score: 0.000000

current train patient: 7 / 7
For this patient: average train loss: 1.072593 | average Dice Score: 0.000000

processing tuple 1 from 100 tuples
current train patient: 1 / 7
For this patient: average train loss: 1.068395 | average Dice S

In [7]:
# Section banner separating the training output from the inference part.
print("-" * 6 + "INFERENCE" + "-" * 8)

------INFERENCE--------


In [8]:
# Persist the trained model only when explicitly requested, instead of toggling a
# commented-out call. Off by default so re-running the notebook does not write a
# new checkpoint.
# NOTE(review): the captured "model saved at ..." output below this cell is stale.
SAVE_MODEL = False
if SAVE_MODEL:
    save_model(model)

model saved at: saved_models/3d_model20241124-115339.pth
