In [16]:
import copy
import torch
import torch.nn as nn
import numpy as np
import random

from pathlib import Path
import training_configuration
from data_generation.generate_3d import ImageGenerator
from data_generation.generate_utils import get_batch
from data_generation.config import original_image_shape, cubic_simple_dims

from models.net_utils import calculate_jaccard_score, calculate_dice_scores, save_model
from models.net_utils import get_best_device, prepare_image_for_network_input, prepare_image_for_analysis
from models.net_visualizations import three_d_visualize_model_progress, display3DImageMaskTuple

from models.unet3D import UNet3D, dice_bce_loss, DICEBCE
from data.data_utils import pad_image, divide_3d_image_into_patches, get_padded_patches

from sklearn.model_selection import train_test_split
from server_specific.server_utils import get_patients

In [17]:
patches_folder = Path("/home/tu-philw/group/gecko/pweinmann/mip_local_unet/preprocessed_patches/")
def get_preprocessed_patches():
    patch_fps = list(patches_folder.iterdir())

    patch_fps = [file for file in patch_fps if "ipynb_checkpoints" not in str(file)]

    print("amt of detected patch files: ", len(patch_fps))
    random.shuffle(patch_fps)
    
    return patch_fps

def get_image_mask_from_patch_fp(patch_fp):
    patch = np.load(patch_fp)
    image = patch["image"]
    mask = patch["mask"].astype(np.bool_)

    return image, mask

In [18]:
preprocessed_patches = get_preprocessed_patches()

amt of detected patch files:  22667


In [32]:
def get_patient_idx(fp):
    return str(fp).split('/')[-1].split("_")[0]

def divide_in_train_test_split(preprocessed_patches):
    idxs = []
    for preprocessed_path in preprocessed_patches:
        idxs.append(str(preprocessed_path).split('/')[-1].split("_")[0])
    
    print(f"len idxs list: {len(idxs)}")
    indexes_list = list(set(idxs))
    print(f"len idxs listset: {len(indexes_list)}")
    train_idxs = indexes_list[0:740]
    val_idxs = indexes_list[740:760]
    test_idxs = indexes_list[740:800]
    return train_idxs, val_idxs, test_idxs

train_idxs, val_idxs, test_idxs = divide_in_train_test_split(preprocessed_patches)

train_fps = [fp for fp in preprocessed_patches if get_patient_idx(fp) in train_idxs]
test_fps = [fp for fp in preprocessed_patches if get_patient_idx(fp) in test_idxs]
val_fps = [fp for fp in preprocessed_patches if get_patient_idx(fp) in val_idxs]


len idxs list: 22667
len idxs listset: 800


In [None]:
# define which device is used for training
device = get_best_device()

if device == "mps":
    # mps is not supported for 3d
    device = "cpu"

torch.set_default_device(device)
print(f"Using {device} device. Every tensor created will be by default on {device}")

In [None]:
max_dice_threshold = 20000
total_weight = 1.5
min_bce_weight = 0.2

def get_appropriate_dice_weight(amt_positive_voxels):
    bce_weight = (-(total_weight - min_bce_weight)/max_dice_threshold) * amt_positive_voxels + 1.5
    dice_weight = total_weight - bce_weight
    
    return dice_weight, bce_weight

In [None]:
pos_voxel_threshold = 7000
max_lr_threshold = 0.1

def calculate_learning_rate(amt_positive_voxels, epoch):
    # make the dice loss an exponential function. 0.0001 if there are no pos voxels, 0.1 if above pos_voxel_threshold
    lr = 10 ** -(epoch) * 0.00001 * np.exp(amt_positive_voxels * (3*np.log(10)/pos_voxel_threshold))
    lr = min(lr, max_lr_threshold)
    
    # print(f"learning rate: {lr}, amt_positive_voxels: {amt_positive_voxels}")
    return lr

In [None]:
print("----------------TRAINING-------------")

patch_size = training_configuration.PATCH_SIZE
block_shape = (patch_size, patch_size, patch_size)
dice_thresholds = np.array([0.1, 0.25, 0.5])

logging_frequency = 48
val_frequency = 100

def train_loop(model, loss_fn, optimizer, patch_fps, epoch):
    model.train()
    
    avg_train_loss = 0
    avg_dice_scores = np.zeros(dice_thresholds.shape)
    
    processed_patch_counter = 0
    
    combined_mask = []
    combined_pred = []
    
    amt_of_patches = len(patch_fps)
    for patch_number, patch_fp in enumerate(patch_fps):
        print(f"{patch_number} / {amt_of_patches}", end="\r")
        image_patch, mask_patch = get_image_mask_from_patch_fp(patch_fp)
        
        amt_positive_voxels = np.count_nonzero(mask_patch)
        dynamic_lr = calculate_learning_rate(amt_positive_voxels, epoch)
        
        dynamic_loss_weights = get_appropriate_dice_weight(amt_positive_voxels)
        
        DICEBCE.dice_weight = dynamic_loss_weights[0]
        DICEBCE.bce_weight = dynamic_loss_weights[1]
        
        # Set the learning rate in the optimizer
        for param_group in optimizer.param_groups:
            param_group['lr'] = dynamic_lr
        
        image_patch = prepare_image_for_network_input(image_patch)
        mask_patch = prepare_image_for_network_input(mask_patch)

        optimizer.zero_grad()

        patch_pred = model(image_patch)
        loss = loss_fn(patch_pred, mask_patch)

        loss.backward()
        optimizer.step()

        train_loss = loss.item()
        avg_train_loss += train_loss
        
        if (patch_number + 1) % logging_frequency == 0:
            avg_train_loss /= logging_frequency
            
            train_log = f"Patch number: {patch_number} / {amt_of_patches}, Train loss: {avg_train_loss:>8f}"
            print(train_log)
            
            avg_train_loss = 0
            
        if (patch_number + 1) % val_frequency == 0:
            # val_loop(
            pass

In [None]:
# resetting the model
model = UNet3D(in_channels=1, num_classes=1)
model.to(device)

# running it
loss_fn = DICEBCE(1,0.5)

# the lr does not matter here, it is set depending on the amt of positive voxels
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

epochs = 3

In [None]:
try:
    for epoch in range(epochs):
        print(f"Epoch {epoch+1}\n-------------------------------")

        train_loop(model, loss_fn, optimizer, preprocessed_patches, epoch)
    print("Done!")
except:
    print("this better have been a keyboard interrupt")
    model.eval()

Patch number: 47 / 22667, Train loss: 1.286232, Dice Scores: (0.10, 0.003208), (0.25, 0.003208), (0.50, 0.004844)
Patch number: 95 / 22667, Train loss: 1.224598, Dice Scores: (0.10, 0.001662), (0.25, 0.001677), (0.50, 0.000000)
Patch number: 143 / 22667, Train loss: 1.170175, Dice Scores: (0.10, 0.001465), (0.25, 0.001320), (0.50, 0.000000)
Patch number: 191 / 22667, Train loss: 1.083066, Dice Scores: (0.10, 0.001621), (0.25, 0.000366), (0.50, 0.000000)
Patch number: 239 / 22667, Train loss: 1.036649, Dice Scores: (0.10, 0.000600), (0.25, 0.000192), (0.50, 0.000000)
Patch number: 287 / 22667, Train loss: 1.022834, Dice Scores: (0.10, 0.001331), (0.25, 0.000542), (0.50, 0.000000)
Patch number: 335 / 22667, Train loss: 1.011399, Dice Scores: (0.10, 0.003706), (0.25, 0.000656), (0.50, 0.000000)
Patch number: 383 / 22667, Train loss: 1.010887, Dice Scores: (0.10, 0.005520), (0.25, 0.000386), (0.50, 0.000000)
Patch number: 431 / 22667, Train loss: 1.003744, Dice Scores: (0.10, 0.021236), (0

In [None]:
print("------INFERENCE--------")

In [None]:
debugging = True

if not debugging:
    save_model(model)