In [None]:
import copy
import torch
import torch.nn as nn
import numpy as np

from data_generation.generate_3d import ImageGenerator
from data_generation.generate_utils import get_batch
from data_generation.config import original_image_shape, cubic_simple_dims

from matplotlib import pyplot as plt
from models.net_utils import calculate_jaccard_score, calculate_dice_score
from models.net_utils import get_best_device, prepare_image_for_network_input, prepare_image_for_analysis
from models.net_visualizations import three_d_visualize_model_progress, display3DImageMaskTuple

from models.cnnBinary import CNN_Binary
from data.data_utils import pad_image, divide_3d_image_into_patches, get_padded_patches

from sklearn.model_selection import train_test_split
from server_specific.server_utils import get_patients

In [2]:
patients = get_patients()
patients = np.array(patients)

train_patients = patients[0:700]
test_patients = patients[700:]

print("Train dataset:", train_patients.shape)
print("Test dataset:", test_patients.shape)


amt of detected_files:  1800
amt of patients:  800
Train dataset: (700,)
Test dataset: (100,)


In [3]:
# define which device is used for training
# todo replace with original image shape
default_image_shape = original_image_shape # only works for 3d

device = get_best_device()

if device == "mps":
    # mps is not supported for 3d
    device = "cpu"

torch.set_default_device(device)
print(f"Using {device} device. Every tensor created will be by default on {device}")

Using cuda device. Every tensor created will be by default on cuda


In [None]:
model = CNN_Binary(in_channels=1, num_classes=1)
model.to(device);

setting default functions for three dimensions


In [None]:
print("----------------TRAINING-------------")

def coronary_arteries_in_masks(mask_batch):
    present = []
    for mask in mask_batch:
        coronary_arteries_present = np.any(mask)

        present.append(coronary_arteries_present)
    
    present = torch.tensor(present, dtype=torch.float32)
    present = present[:, None]
    return present

In [None]:
padded_shape = (320, 512, 512)
patch_size = 64
block_shape = (patch_size, patch_size, patch_size)

def train_loop(model, loss_fn, optimizer, image, mask):
    model.train()
    
    avg_train_loss = 0

    image_patches, mask_patches = get_padded_patches(image, mask, patch_size)

    patch_shape = image_patches.shape
    amt_of_image_patches = patch_shape[0] * patch_shape[1] * patch_shape[2]
    # reconstructed_prediction_mask = np.zeros((padded_shape))
    
    patch_counter = 0
    for i in range(patch_shape[0]):
        for j in range(patch_shape[1]):
            for k in range(patch_shape[2]):
                patch_counter += 1
                
                current_image_patch = image_patches[i, j, k]
                current_mask_patch = mask_patches[i, j, k]

                
                    
                current_image_patch = prepare_image_for_network_input(current_image_patch)
                current_mask_patch = prepare_image_for_network_input(current_mask_patch)

                optimizer.zero_grad()

                current_prediction_patch = model(current_image_patch)
                loss = loss_fn(current_prediction_patch, coronary_arteries_in_masks(current_mask_patch))

                loss.backward()
                optimizer.step()

                train_loss = loss.item()
                avg_train_loss += train_loss

                print(f"Patch number: {patch_counter} / {amt_of_image_patches}, Train loss: {train_loss:>8f}", end="\r")
    
    avg_train_loss /= amt_of_image_patches

    print(f"For this patient: average test loss: {avg_train_loss:>8f}\n")
    
    return avg_train_loss

----------------TRAINING-------------


In [None]:
def test_loop(model, loss_fn, image, mask):
    model.eval()

    test_loss = 0

    image_patches, mask_patches = get_padded_patches(image, mask, patch_size)

    patch_shape = image_patches.shape
    amt_of_image_patches = patch_shape[0] * patch_shape[1] * patch_shape[2]
    
    patch_counter = 0
    for i in range(patch_shape[0]):
        for j in range(patch_shape[1]):
            for k in range(patch_shape[2]):
                patch_counter += 1
                current_image_patch = image_patches[i, j, k]
                current_mask_patch = mask_patches[i, j, k]
                    
                current_image_patch = prepare_image_for_network_input(current_image_patch)
                current_mask_patch = prepare_image_for_network_input(current_mask_patch)

                with torch.no_grad():
                    current_prediction_patch = model(current_image_patch)
                    patch_loss = loss_fn(current_prediction_patch, coronary_arteries_in_masks(current_mask_patch)).item()
                    test_loss += patch_loss

    test_loss /= amt_of_image_patches

    print(f"Patch number: {patch_counter} / {amt_of_image_patches}, average test loss: {test_loss:>8f}\n", end="\r")

    # returning for patience
    return test_loss

In [None]:
# resetting the model
model = CNN_Binary(in_channels=1, num_classes=1)
model.to(device)

# print("model prediction at initialization: ")
# default_model_progress_visualization_function(model, image_generator.get_3DImage)

# running it
# param initialization for patience
best_loss = float('inf')  
best_model_weights = None  
patience_base_value = 8
patience = patience_base_value

loss_fn = nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
epochs = 2

In [9]:
try:
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        print("tuple of 7 train patients, 1 test patient")
        for i in range(100):
            print(f"processing tuple {i} from 100 tuples")
            
            current_train_patients = train_patients[i*7:(i+1)*7]
            current_test_patients = test_patients[i]
            
            average_train_losses = []
            average_test_losses = []
            for current_train_idx, current_train_patient in enumerate(current_train_patients):
                print(f"current train patient: {current_train_idx + 1} / 7")
                current_train_image, current_train_mask = current_train_patient.get_image_mask_tuple()
                current_average_train_loss = train_loop(model, loss_fn, optimizer, current_train_image, current_train_mask)
                average_train_losses.append(current_average_train_loss)
            # testing for each epoch to track the models performance during training.

            current_test_image, current_test_mask = current_train_patient.get_image_mask_tuple()
            test_loss = test_loop(model, loss_fn, current_test_image, current_test_mask)
            average_test_losses.append(test_loss)

            if test_loss < best_loss:
                best_loss = test_loss
                best_model_weights = copy.deepcopy(model.state_dict())
                patience = patience_base_value
            else:
                patience -= 1
                if patience <= 0:
                    break
            print("patience: ", patience)
    print("Done!")
except KeyboardInterrupt:
    print("training interrupted by the user")
    model.eval()

Epoch 1
-------------------------------
tuple of 7 train patients, 1 test patient
processing tuple 0 from 100 tuples
current train patient: 1 / 7
For this patient: average test loss: 1.165038  | average Jaccard Score: 0.003749 | average Dice Score: 0.00699910709565955

current train patient: 2 / 7
For this patient: average test loss: 1.007009  | average Jaccard Score: 0.052802 | average Dice Score: 0.08112673726541626

current train patient: 3 / 7
For this patient: average test loss: 0.977706  | average Jaccard Score: 0.080685 | average Dice Score: 0.1201979578457924

current train patient: 4 / 7
For this patient: average test loss: 0.953297  | average Jaccard Score: 0.096339 | average Dice Score: 0.1457657385186291

current train patient: 5 / 7
For this patient: average test loss: 0.965267  | average Jaccard Score: 0.084299 | average Dice Score: 0.13409543383947705

current train patient: 6 / 7
For this patient: average test loss: 0.902576  | average Jaccard Score: 0.129627 | average 

In [10]:
print("------INFERENCE--------")

'''
for i in range(10):
    mask, pred = default_model_progress_visualization_function(model, get_image_fct=image_generator.get_3DImage)
    print("jaccard score for above image: ", calculate_jaccard_score(mask, pred))
'''

------INFERENCE--------


'\nfor i in range(10):\n    mask, pred = default_model_progress_visualization_function(model, get_image_fct=image_generator.get_3DImage)\n    print("jaccard score for above image: ", calculate_jaccard_score(mask, pred))\n'

In [11]:
torch.save(model.state_dict(), 'model_weights.pth')
