In [1]:
import torch
from data.data_utils import get_preprocessed_patches, get_all_patches_with_certain_idx
import training_configuration
from training_utils import get_train_test_val_patches, get_val_test_indexes, test_or_validate_model
from models.net_utils import get_best_device
from models.unet3D import UNet3D

In [2]:
# NOTE(review): block_size is not referenced by any other cell visible in this
# notebook; presumably it is the patch edge length -- confirm before relying on it.
block_size = 128

# Checkpoint evaluated by the cells below (average dice score of 0.78).
trained_model_path = "saved_models/3d_model20241216-101621.pth"

# Earlier checkpoint, kept for reference (average dice score of 0.76):
# trained_model_path = "saved_models/3d_model20241209-211318.pth"

In [3]:
# Pick the compute device through the project helper and make it the default
# target for every tensor created from here on.
device = get_best_device()

torch.set_default_device(device)
print(f"Using {device} device. Every tensor created will be by default on {device}")

model = UNet3D(in_channels=1, num_classes=1)

# map_location remaps the checkpoint's tensors onto the device selected above.
# Without it, torch.load tries to restore each tensor to the device it was saved
# from, which fails when that device is not available on this machine.
# weights_only=True restricts unpickling to tensors and plain containers.
model.load_state_dict(
    torch.load(trained_model_path, map_location=device, weights_only=True)
)
model.to(device)
# Switch to evaluation mode; the trailing semicolon suppresses the module repr
# that the notebook would otherwise display as the cell output.
model.eval();

Using cuda device. Every tensor created will be by default on cuda


# INFERENCE

In [4]:
# TODO: it would be better to do this with an inference pipeline.

# Read the configured patches folder once and reuse it for both loaders.
patches_dir = training_configuration.PATCHES_FOLDER

preprocessed_patches = get_preprocessed_patches(patches_folder=patches_dir)
val_idxs, test_idxs = get_val_test_indexes()

# Patches selected by the test indexes.
id_idx_patches_list = get_all_patches_with_certain_idx(test_idxs, preprocessed_patches)

# The first element of the returned triple is discarded (presumably the training
# split -- confirm against training_utils); only validation and test are kept.
_, val_idxs_patches, test_idxs_patches = get_train_test_val_patches(patches_folder=patches_dir)

amt of detected patch files:  22667
amt of detected patch files:  22667
patients for- training: 740, validation: 20, testing: 40
training patches: 20955, validation patches: 591, test patches: 1121


In [None]:
# Disabled threshold search. Everything between the triple quotes below is a
# plain string literal, so none of it is executed when this cell runs.
# The enclosed code sets up an Optuna study that samples `threshold` in
# [0.0, 0.5], scores each candidate with test_or_validate_model on
# val_idxs_patches, and maximizes the returned avg_dice_scores over 100 trials
# with n_jobs=2.
# To re-run the search, remove the enclosing triple quotes (optuna must be installed).
'''
import optuna

def objective(trial):
    threshold_sug = trial.suggest_float('threshold', 0.0, 0.5)
    avg_overlap_scores, avg_dice_scores, avg_dice_scores_b_pp = test_or_validate_model(val_idxs_patches, model, threshold = threshold_sug, visualize=False)

    return avg_dice_scores

study = optuna.create_study(direction='maximize', sampler=optuna.samplers.TPESampler())

# jobs = -1 to use all cores
study.optimize(objective, n_trials=100, n_jobs=2, show_progress_bar=False)

print("best dice value: ", study.best_value)
print("best threshold: ", study.best_params)
'''

In [None]:
# Best threshold found so far; in practice it barely makes a difference.
threshold = 0.204305

# Evaluate the loaded model on the test patches, with visualization enabled.
# The three returned values are reported by the following cell.
avg_overlap_scores, avg_dice_scores, avg_dice_scores_b_pp = test_or_validate_model(
    test_idxs_patches,
    model,
    threshold=threshold,
    visualize=True,
)

## Final values on the test set:

In [None]:
# Report the three test-set metrics, one per line.
# NOTE(review): the last label says "before pp" -- confirm that the overlap score
# returned by test_or_validate_model is really computed before post-processing.
print(
    f"dice scores on the test set before pp: {avg_dice_scores_b_pp}",
    f"dice scores on the test set after pp: {avg_dice_scores}",
    f"overlap scores before pp: {avg_overlap_scores}",
    sep="\n",
)