In [1]:
import copy
import torch
import torch.nn as nn
import numpy as np

from data_generation.generate_3d import ImageGenerator
from data_generation.generate_utils import get_batch
from data_generation.config import original_image_shape, cubic_simple_dims

from models.net_utils import calculate_jaccard_score, calculate_dice_score, save_model
from models.net_utils import get_best_device, prepare_image_for_network_input, prepare_image_for_analysis
from models.net_visualizations import three_d_visualize_model_progress, display3DImageMaskTuple

from models.unet3D import UNet3D, dice_bce_loss
from data.data_utils import pad_image, divide_3d_image_into_patches, get_padded_patches

from sklearn.model_selection import train_test_split
from server_specific.server_utils import get_patients

In [2]:
patients = get_patients()
patients = np.array(patients)

train_patients = patients[0:700]
test_patients = patients[700:]

print("Train dataset:", train_patients.shape)
print("Test dataset:", test_patients.shape)

amt of detected_files:  1800
amt of patients:  800
Train dataset: (700,)
Test dataset: (100,)


In [3]:
# define which device is used for training
# todo replace with original image shape
default_image_shape = original_image_shape # only works for 3d

device = get_best_device()

if device == "mps":
    # mps is not supported for 3d
    device = "cpu"

torch.set_default_device(device)
print(f"Using {device} device. Every tensor created will be by default on {device}")

Using cuda device. Every tensor created will be by default on cuda


In [4]:
print("setting default functions for three dimensions")
image_generator = ImageGenerator(default_image_shape)

default_image_mask_visulization_function = display3DImageMaskTuple
default_model_progress_visualization_function = three_d_visualize_model_progress

model = UNet3D(in_channels=1, num_classes=1)
model.to(device);

setting default functions for three dimensions


In [5]:
image, mask = image_generator.get_3DImage()
print("image_shape: ", image.shape)

image_shape:  (275, 512, 512)


In [6]:
print("----------------TRAINING-------------")

additional_score_tuples = [("jaccard score", calculate_jaccard_score), ("dice score", calculate_dice_score)] # ("hausdorff distance", calculate_hausdorff_distance)

padded_shape = (320, 512, 512)
patch_size = 64
block_shape = (patch_size, patch_size, patch_size)

custom_threshold = 0.5

def train_loop(model, loss_fn, optimizer, image, mask):
    model.train()
    
    avg_train_loss = 0
    avg_jaccard_score = 0
    avg_dice_score = 0

    image_patches, mask_patches = get_padded_patches(image, mask, patch_size)

    patch_shape = image_patches.shape
    amt_of_image_patches = patch_shape[0] * patch_shape[1] * patch_shape[2]
    amt_of_image_patches_with_arteries = amt_of_image_patches
    # reconstructed_prediction_mask = np.zeros((padded_shape))
    
    patch_counter = 0
    for i in range(patch_shape[0]):
        for j in range(patch_shape[1]):
            for k in range(patch_shape[2]):
                patch_counter += 1
                
                current_image_patch = image_patches[i, j, k]
                current_mask_patch = mask_patches[i, j, k]

                if np.all(current_mask_patch == 0):
                    # print("skipping patch")
                    amt_of_image_patches_with_arteries -= 1
                    continue
                    
                current_image_patch = prepare_image_for_network_input(current_image_patch)
                current_mask_patch = prepare_image_for_network_input(current_mask_patch)

                optimizer.zero_grad()

                current_prediction_patch = model(current_image_patch)
                loss = loss_fn(current_prediction_patch, current_mask_patch)
                
                additional_metrics = []
                
                for name, additional_score_function in additional_score_tuples:
                    score = additional_score_function(current_mask_patch, current_image_patch, threshold=custom_threshold)
                    additional_metrics.append((name, score))

                loss.backward()
                optimizer.step()

                train_loss = loss.item()
                avg_train_loss += train_loss
                avg_jaccard_score += calculate_jaccard_score(current_mask_patch, current_prediction_patch)
                avg_dice_score += calculate_dice_score(current_mask_patch, current_prediction_patch)
                
                train_log = f"Patch number: {patch_counter} / {amt_of_image_patches}, Train loss: {train_loss:>8f}"
                for name, score in additional_metrics:
                    train_log = train_log + f" | {name}: {score}"

                print(train_log, end="\r")
    
    avg_train_loss /= amt_of_image_patches_with_arteries
    avg_jaccard_score /= amt_of_image_patches_with_arteries
    avg_dice_score /= amt_of_image_patches_with_arteries

    print(f"For this patient: average train loss: {avg_train_loss:8f}  | average Jaccard Score: {avg_jaccard_score:8f} | average Dice Score: {avg_dice_score:8f}\n")
    
    return avg_train_loss

----------------TRAINING-------------


In [7]:
def test_loop(model, loss_fn, image, mask):
    model.eval()

    test_loss = 0
    jaccard_score = 0
    dice_score = 0

    image_patches, mask_patches = get_padded_patches(image, mask, patch_size)

    patch_shape = image_patches.shape
    amt_of_image_patches = patch_shape[0] * patch_shape[1] * patch_shape[2]
    amt_of_image_patches_with_arteries = amt_of_image_patches
    # reconstructed_prediction_mask = np.zeros((padded_shape))
    
    patch_counter = 0
    for i in range(patch_shape[0]):
        for j in range(patch_shape[1]):
            for k in range(patch_shape[2]):
                patch_counter += 1
                current_image_patch = image_patches[i, j, k]
                current_mask_patch = mask_patches[i, j, k]
                
                if np.all(current_mask_patch == 0):
                    # print("skipping patch")
                    amt_of_image_patches_with_arteries -= 1
                    continue
                    
                current_image_patch = prepare_image_for_network_input(current_image_patch)
                current_mask_patch = prepare_image_for_network_input(current_mask_patch)

                with torch.no_grad():
                    current_prediction_patch = model(current_image_patch)
                    patch_loss = loss_fn(current_prediction_patch, current_mask_patch).item()
                    test_loss += patch_loss

                    patch_jaccard_score = calculate_jaccard_score(current_mask_patch, current_prediction_patch, threshold=custom_threshold)
                    patch_dice_score = calculate_dice_score(current_mask_patch, current_prediction_patch, threshold=custom_threshold)

                    jaccard_score += patch_jaccard_score
                    dice_score += patch_dice_score

                    # print(f"Patch number: {patch_counter} / {amt_of_image_patches}, Test loss: {patch_loss:>8f}  | Jaccard Score: {patch_jaccard_score:>8f} | Dice Score: {patch_dice_score:>8f}\n", end="\r")

    test_loss /= amt_of_image_patches_with_arteries
    jaccard_score /= amt_of_image_patches_with_arteries
    dice_score /= amt_of_image_patches_with_arteries

    print(f"Patch number: {patch_counter} / {amt_of_image_patches}, average test loss: {test_loss:8f}  | average Jaccard Score: {jaccard_score:8f} | average Dice Score: {dice_score:8f}\n", end="\r")

    # default_model_progress_visualization_function(model, image_generator.get_3DImage)
    
    # to make sure that the plot gets displayed during training
    # plt.pause(0.001)

    # returning for patience
    return test_loss

In [8]:
# resetting the model
model = UNet3D(in_channels=1, num_classes=1)
model.to(device)

# print("model prediction at initialization: ")
# default_model_progress_visualization_function(model, image_generator.get_3DImage)

# running it
# param initialization for patience
best_loss = float('inf')  
best_model_weights = None  
patience_base_value = 6
patience = patience_base_value

# loss_fn = nn.BCELoss()
loss_fn = dice_bce_loss
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
epochs = 2

In [9]:
# overriding test patients to always test on the same patients
amt_test_patients = 10
current_test_patients = test_patients[0:amt_test_patients]

try:
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        print("tuple of 7 train patients, 1 test patient")
        for i in range(100):
            print(f"processing tuple {i} from 100 tuples")
            
            current_train_patients = train_patients[i*7:(i+1)*7]
            # current_test_patients = test_patients[i]
            
            average_train_losses = []
            average_test_losses = []
            for current_train_idx, current_train_patient in enumerate(current_train_patients):
                print(f"current train patient: {current_train_idx + 1} / 7")
                current_train_image, current_train_mask = current_train_patient.get_preprocessed_image_mask_tuple()
                current_average_train_loss = train_loop(model, loss_fn, optimizer, current_train_image, current_train_mask)
                average_train_losses.append(current_average_train_loss)
            # testing for each epoch to track the models performance during training.
            
            avg_test_loss = 0
            for test_patient_idx, current_test_patient in enumerate(current_test_patients):
                current_test_image, current_test_mask = current_test_patient.get_preprocessed_image_mask_tuple()
                current_test_loss = test_loop(model, loss_fn, current_test_image, current_test_mask)
                avg_test_loss += current_test_loss
            
            avg_test_loss /= amt_test_patients
            print(f"average test loss on the {amt_test_patients} first test patients: {avg_test_loss}")
            
            if avg_test_loss < best_loss:
                print(f"previous best loss: {best_loss}")
                best_loss = avg_test_loss
                best_model_weights = copy.deepcopy(model.state_dict())
                patience = patience_base_value
            else:
                print(f"current best loss: {best_loss}")
                patience -= 1
                if patience <= 0:
                    break
            print("patience: ", patience)
    print("Done!")
except KeyboardInterrupt:
    print("training interrupted by the user")
    model.eval()

Epoch 1
-------------------------------
tuple of 7 train patients, 1 test patient
processing tuple 0 from 100 tuples
current train patient: 1 / 7
For this patient: average train loss: 1.170848  | average Jaccard Score: 0.000672 | average Dice Score: 0.001330659946632

current train patient: 2 / 7
training interrupted by the user


In [None]:
print("------INFERENCE--------")

'''
for i in range(10):
    mask, pred = default_model_progress_visualization_function(model, get_image_fct=image_generator.get_3DImage)
    print("jaccard score for above image: ", calculate_jaccard_score(mask, pred))
'''

In [None]:
save_model(model)