In [1]:
import copy
import torch
import torch.nn as nn
import numpy as np
import random

from pathlib import Path
import training_configuration
from data_generation.generate_3d import ImageGenerator
from data_generation.generate_utils import get_batch
from data_generation.config import original_image_shape, cubic_simple_dims

from models.net_utils import calculate_jaccard_score, calculate_dice_scores, save_model
from models.net_utils import get_best_device, prepare_image_for_network_input, prepare_image_for_analysis
from models.net_visualizations import three_d_visualize_model_progress, display3DImageMaskTuple

from models.unet3D import UNet3D, dice_bce_loss
from data.data_utils import pad_image, divide_3d_image_into_patches, get_padded_patches

from sklearn.model_selection import train_test_split
from server_specific.server_utils import get_patients

In [2]:
# Default location of the preprocessed training patches on the training server.
patches_folder = Path("/home/tu-philw/group/gecko/pweinmann/mip_local_unet/preprocessed_patches/")


def get_preprocessed_patches(folder=None):
    """Return the shuffled list of patch file paths found in a folder.

    Args:
        folder: Directory to scan (``str`` or ``Path``). Defaults to the
            module-level ``patches_folder`` when omitted.

    Returns:
        list[Path]: The direct children of ``folder`` whose path does not
        contain ``"ipynb_checkpoints"``, in random order.

    Raises:
        FileNotFoundError: If ``folder`` does not exist (raised by ``iterdir``).
    """
    folder = patches_folder if folder is None else Path(folder)

    # iterdir() yields entries in arbitrary, filesystem-dependent order. Sorting
    # first means a seeded `random` gives the same shuffle on every machine.
    patch_fps = sorted(
        file for file in folder.iterdir() if "ipynb_checkpoints" not in str(file)
    )

    print("amt of detected patch files: ", len(patch_fps))
    random.shuffle(patch_fps)

    return patch_fps

def get_image_mask_from_patch_fp(patch_fp):
    """Load the image and mask arrays stored in one patch file.

    Args:
        patch_fp: Path to a patch archive containing the keys ``"image"`` and
            ``"mask"`` (assumed to be an ``.npz`` file written with
            ``np.savez`` -- TODO confirm against the preprocessing script).

    Returns:
        tuple: ``(image, mask)`` as loaded from the archive.
    """
    # np.load on an .npz keeps the underlying zip file open until the returned
    # object is closed. This function is called once per patch, so close it
    # deterministically instead of waiting for garbage collection.
    with np.load(patch_fp) as patch:
        image = patch["image"]
        mask = patch["mask"]

    return image, mask

In [3]:
# Collect (and shuffle) the patch file paths once; the training cells below
# iterate over this list.
preprocessed_patches = get_preprocessed_patches()

amt of detected patch files:  22667


In [4]:
# Select the device used for training and make it torch's default device.
device = get_best_device()

# Fall back to the CPU on MPS, which is not supported for 3d.
device = "cpu" if device == "mps" else device

torch.set_default_device(device)
print(f"Using {device} device. Every tensor created will be by default on {device}")

Using cuda device. Every tensor created will be by default on cuda


In [5]:
print("----------------TRAINING-------------")

# Patch size comes from the shared training configuration module.
patch_size = training_configuration.PATCH_SIZE
# Cubic block: the same extent along all three axes.
block_shape = (patch_size,) * 3
# Thresholds passed to calculate_dice_scores during training.
dice_thresholds = np.array([0.1, 0.25, 0.4, 0.5])

def _format_train_log(patch_number, amt_of_patches, avg_train_loss, avg_dice_scores):
    """Build the one-line progress message for a finished logging window."""
    formatted_scores = ', '.join([f'({t:.2f}, {s:.6f})' for t, s in zip(dice_thresholds, avg_dice_scores)])
    return f"Patch number: {patch_number} / {amt_of_patches}, Train loss: {avg_train_loss:>8f}, Dice Scores: {formatted_scores}"


def train_loop(model, loss_fn, optimizer, patch_fps, min_foreground_voxels=10000, log_every=20):
    """Run one training pass over the given patch files.

    Each patch is loaded with ``get_image_mask_from_patch_fp``. Patches whose
    mask has fewer than ``min_foreground_voxels`` non-zero entries are skipped.
    The remaining patches are converted with ``prepare_image_for_network_input``
    and used for one optimizer step each. The average loss and dice scores are
    printed after every ``log_every`` trained patches, and once more at the end
    for a trailing, incomplete window.

    Args:
        model: Network to train; switched to training mode on entry.
        loss_fn: Callable ``loss_fn(prediction, mask)`` returning a scalar loss
            tensor.
        optimizer: Optimizer stepping the parameters of ``model``.
        patch_fps: Sequence of patch file paths.
        min_foreground_voxels: Minimum number of non-zero mask entries a patch
            needs in order to be trained on.
        log_every: Number of trained patches per logging window (>= 1).

    Raises:
        ValueError: If ``log_every`` is smaller than 1.
    """
    if log_every < 1:
        raise ValueError("log_every must be a positive integer")

    model.train()

    # Running sums for the current logging window.
    window_loss = 0.0
    window_dice = np.zeros(dice_thresholds.shape)
    window_count = 0
    last_patch_number = None

    amt_of_patches = len(patch_fps)
    for patch_number, patch_fp in enumerate(patch_fps):
        print(f"{patch_number} / {amt_of_patches}", end="\r")
        image_patch, mask_patch = get_image_mask_from_patch_fp(patch_fp)

        # Skip patches whose mask has too few non-zero entries.
        if np.count_nonzero(mask_patch) < min_foreground_voxels:
            continue

        image_patch = prepare_image_for_network_input(image_patch)
        mask_patch = prepare_image_for_network_input(mask_patch)

        optimizer.zero_grad()

        patch_pred = model(image_patch)
        loss = loss_fn(patch_pred, mask_patch)

        loss.backward()
        optimizer.step()

        window_loss += loss.item()

        # One score per entry of dice_thresholds (it is added element-wise to
        # the running sum). The metric needs no gradients, so pass a detached
        # prediction to keep the autograd graph out of it.
        window_dice += calculate_dice_scores(mask_patch, patch_pred.detach(), thresholds=dice_thresholds)

        window_count += 1
        last_patch_number = patch_number

        if window_count == log_every:
            print(_format_train_log(patch_number, amt_of_patches, window_loss / log_every, window_dice / log_every))

            window_loss = 0.0
            window_dice.fill(0)
            window_count = 0

    # Report the trailing window too, otherwise its statistics are lost.
    if window_count > 0:
        print(_format_train_log(last_patch_number, amt_of_patches, window_loss / window_count, window_dice / window_count))

----------------TRAINING-------------


In [6]:
# Training hyper-parameters for the run below.
epochs = 2
loss_fn = dice_bce_loss

# Fresh single-channel-in / single-class-out UNet3D, moved to the selected device.
model = UNet3D(in_channels=1, num_classes=1)
model.to(device)

# Plain SGD over all model parameters.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

In [None]:
# When True, run a single pass without the epoch loop or interrupt handling,
# so that errors surface directly with their full traceback.
debugging = True

if debugging:
    train_loop(model, loss_fn, optimizer, preprocessed_patches)
else:
    try:
        for t in range(epochs):
            print(f"Epoch {t+1}\n-------------------------------")

            train_loop(model, loss_fn, optimizer, preprocessed_patches)
        print("Done!")
    except KeyboardInterrupt:
        # Only a manual stop is handled here. A bare `except:` would also hide
        # real training errors behind this message.
        print("Keyboard interruption.")
    finally:
        # Leave training mode whether training finished or was interrupted,
        # so the cells that follow see the model in eval mode.
        model.eval()

15 / 22667

In [None]:
# Section banner marking the start of the inference part of the notebook.
print("------INFERENCE--------")

In [None]:
# Only hand the model to save_model after a real (non-debug) training run.
# NOTE(review): save_model is a project helper; presumably it writes the model to disk -- confirm.
if not debugging:
    save_model(model)