In [13]:
import copy
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split

import training_configuration
from data.data_utils import pad_image, divide_3d_image_into_patches, get_padded_patches
from data_generation.config import original_image_shape, cubic_simple_dims
from data_generation.generate_3d import ImageGenerator
from data_generation.generate_utils import get_batch
from models.net_utils import calculate_jaccard_score, calculate_dice_scores, save_model
from models.net_utils import get_best_device, prepare_image_for_network_input, prepare_image_for_analysis
from models.net_visualizations import three_d_visualize_model_progress, display3DImageMaskTuple
from models.unet3D import UNet3D, dice_bce_loss
from server_specific.server_utils import get_patients

In [14]:
# Default location of the preprocessed patch files (machine-specific absolute path).
patches_folder = Path("/home/tu-philw/group/gecko/pweinmann/mip_local_unet/preprocessed_patches/")


def get_preprocessed_patches(folder=None):
    """Return the paths of all patch files in a folder, in random order.

    Args:
        folder: Directory to scan. When omitted, the module-level
            ``patches_folder`` is used (looked up at call time), so existing
            zero-argument calls behave exactly as before.

    Returns:
        list[Path]: every directory entry whose path string does not contain
        "ipynb_checkpoints", shuffled in place with ``random.shuffle``.
    """
    folder = patches_folder if folder is None else Path(folder)

    # Jupyter drops ".ipynb_checkpoints" entries into browsed folders; skip them.
    patch_fps = [fp for fp in folder.iterdir() if "ipynb_checkpoints" not in str(fp)]

    print("amt of detected patch files: ", len(patch_fps))
    random.shuffle(patch_fps)

    return patch_fps

def get_image_mask_from_patch_fp(patch_fp):
    """Load one preprocessed patch and return its image and mask arrays.

    Args:
        patch_fp: Path to a patch archive holding the entries "image" and
            "mask" (assumed to be an ``.npz`` file written with ``np.savez``
            -- TODO confirm against the preprocessing code).

    Returns:
        tuple: ``(image, mask)`` as read from the archive.
    """
    # np.load keeps the archive's file handle open until it is closed; the
    # context manager releases it right after both arrays have been read, so
    # looping over thousands of patches does not accumulate open handles.
    with np.load(patch_fp) as patch:
        image = patch["image"]
        mask = patch["mask"]

    return image, mask

In [15]:
# Collect the shuffled patch file paths once; the later cells iterate over this list.
preprocessed_patches = get_preprocessed_patches()

amt of detected patch files:  22667


In [None]:
# Collect the mean value of every image patch; the masks are not needed here.
amt_patches = len(preprocessed_patches)
means = []

for p_idx, preprocessed_patch in enumerate(preprocessed_patches):
    # "\r" keeps the progress counter on one continuously overwritten line.
    print(f"patch {p_idx}/{amt_patches}", end="\r")
    image_patch = get_image_mask_from_patch_fp(preprocessed_patch)[0]
    means.append(image_patch.mean())

patch 712/22667

In [None]:
# Histogram of the per-patch image means collected in `means`.
# `plt` is matplotlib.pyplot and must be imported in the notebook's import cell;
# without that import this cell raises NameError on a fresh kernel.
fig, ax = plt.subplots()
ax.hist(means, bins=10)
ax.set_title("distribution of per-patch image means")
ax.set_xlabel("bin means")
ax.set_ylabel("amount of elements in bin")
plt.show()

In [None]:
# `means` is a plain Python list, which has no `.mean()` method
# (AttributeError), so the average is computed with NumPy instead.
mean_of_means = np.mean(means)
print(mean_of_means)

In [4]:
# define which device is used for training;
# mps is not supported for 3d, so that backend is swapped for the cpu
device = get_best_device()
device = "cpu" if device == "mps" else device

# From here on newly created tensors land on the chosen device by default.
torch.set_default_device(device)
print(f"Using {device} device. Every tensor created will be by default on {device}")

Using cuda device. Every tensor created will be by default on cuda


In [8]:
print("----------------TRAINING-------------")

# Cubic patches: the same edge length is used along all three axes.
patch_size = training_configuration.PATCH_SIZE
block_shape = (patch_size,) * 3

# Thresholds at which the dice score is reported during training.
dice_thresholds = np.array([0.1, 0.25, 0.4, 0.5])

def train_loop(model, loss_fn, optimizer, scheduler, patch_fps,
               min_mask_voxels=10000, log_interval=20):
    """Run one optimisation pass over the given patch files.

    Each patch is loaded from disk, skipped when its mask has too few non-zero
    voxels, and otherwise used for a single optimiser step. After every
    ``log_interval`` trained patches the windowed mean loss is passed to
    ``scheduler.step`` and printed together with the windowed dice scores.
    A trailing window with fewer than ``log_interval`` patches is not reported.

    Args:
        model: Network to train; it is switched to training mode.
        loss_fn: Callable ``loss_fn(prediction, target)`` returning a scalar
            tensor that supports ``backward()``.
        optimizer: Optimiser holding the model parameters.
        scheduler: Object whose ``step(metric)`` receives the windowed mean loss.
        patch_fps: Sequence of patch file paths.
        min_mask_voxels: Patches whose mask has fewer non-zero voxels than this
            are skipped. Defaults to the previously hard-coded 10000.
        log_interval: Number of trained patches per reporting/scheduler window.
            Defaults to the previously hard-coded 20.

    Raises:
        ValueError: if ``log_interval`` is smaller than 1.
    """
    if log_interval < 1:
        raise ValueError("log_interval must be at least 1")

    model.train()

    window_loss = 0
    window_dice_scores = np.zeros(dice_thresholds.shape)

    processed_patch_counter = 0

    amt_of_patches = len(patch_fps)
    for patch_number, patch_fp in enumerate(patch_fps):
        print(f"{patch_number} / {amt_of_patches}", end="\r")
        image_patch, mask_patch = get_image_mask_from_patch_fp(patch_fp)

        # Skip patches with (almost) no foreground in the mask.
        if np.count_nonzero(mask_patch) < min_mask_voxels:
            continue

        image_patch = prepare_image_for_network_input(image_patch)
        mask_patch = prepare_image_for_network_input(mask_patch)

        optimizer.zero_grad()

        patch_pred = model(image_patch)
        loss = loss_fn(patch_pred, mask_patch)

        loss.backward()
        optimizer.step()

        window_loss += loss.item()

        # Metrics only: no autograd bookkeeping is needed for them. The result
        # is accumulated elementwise, so it is presumably one score per entry
        # of dice_thresholds -- TODO confirm against calculate_dice_scores.
        with torch.no_grad():
            window_dice_scores += calculate_dice_scores(mask_patch, patch_pred, thresholds=dice_thresholds)

        processed_patch_counter += 1
        if processed_patch_counter % log_interval == 0:
            window_loss /= log_interval
            window_dice_scores /= log_interval

            scheduler.step(window_loss)
            formatted_scores = ', '.join([f'({t:.2f}, {s:.6f})' for t, s in zip(dice_thresholds, window_dice_scores)])
            train_log = f"Patch number: {patch_number} / {amt_of_patches}, Train loss: {window_loss:>8f}, Dice Scores: {formatted_scores}"
            print(train_log)

            # Start a new reporting window.
            window_loss = 0
            window_dice_scores.fill(0)

----------------TRAINING-------------


In [9]:
# Training setup: number of epochs and the loss used by the training loop.
epochs = 2
loss_fn = dice_bce_loss

# resetting the model: every run of this cell builds a new UNet3D instance
model = UNet3D(in_channels=1, num_classes=1)
model.to(device)

# Plain SGD; the learning rate is cut by a factor of 10 once the monitored
# loss has stopped decreasing for more than `patience` scheduler steps.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=6
)

In [10]:
# With `debugging` enabled a single pass is run directly, so any error surfaces
# with its full traceback; otherwise the configured number of epochs is trained.
debugging = True

if debugging:
    train_loop(model, loss_fn, optimizer, scheduler, preprocessed_patches)
else:
    try:
        for t in range(epochs):
            print(f"Epoch {t+1}\n-------------------------------")

            train_loop(model, loss_fn, optimizer, scheduler, preprocessed_patches)
        print("Done!")
    except KeyboardInterrupt:
        # Only a manual interrupt is handled here. A bare `except:` would also
        # swallow genuine failures (out of memory, shape mismatches, ...) and
        # mislabel them as "Keyboard interruption."; those now propagate.
        print("Keyboard interruption.")
        model.eval()

Patch number: 341 / 22667, Train loss: 0.305021, Dice Scores: (0.10, 0.729422), (0.25, 0.729114), (0.40, 0.728553), (0.50, 0.727898)
Patch number: 530 / 22667, Train loss: 0.310938, Dice Scores: (0.10, 0.723755), (0.25, 0.722857), (0.40, 0.722115), (0.50, 0.721419)
Patch number: 842 / 22667, Train loss: 0.366079, Dice Scores: (0.10, 0.671942), (0.25, 0.670363), (0.40, 0.668125), (0.50, 0.665965)
Patch number: 1141 / 22667, Train loss: 0.321844, Dice Scores: (0.10, 0.714178), (0.25, 0.713467), (0.40, 0.712080), (0.50, 0.710153)
Patch number: 1449 / 22667, Train loss: 0.253448, Dice Scores: (0.10, 0.779204), (0.25, 0.779638), (0.40, 0.778945), (0.50, 0.778316)
Patch number: 1829 / 22667, Train loss: 0.246181, Dice Scores: (0.10, 0.786322), (0.25, 0.786911), (0.40, 0.786066), (0.50, 0.785086)
Patch number: 2079 / 22667, Train loss: 0.321343, Dice Scores: (0.10, 0.710631), (0.25, 0.710933), (0.40, 0.710460), (0.50, 0.710043)
Patch number: 2283 / 22667, Train loss: 0.312172, Dice Scores: (0

In [None]:
# Section banner marking the start of the inference part of the notebook.
print("------INFERENCE--------")

In [None]:
# Only real (non-debugging) runs hand the model to the project's save_model helper
# (presumably writes it to disk -- TODO confirm in models.net_utils).
if not debugging:
    save_model(model)