In [None]:
from models.inference_pipeline import CCTAPipeline
from models.unet3D import UNet3D, dice_bce_loss
from models.net_utils import calculate_jaccard_score, calculate_dice_scores
import torch
from models.net_utils import get_best_device, calculate_dice_scores, calculate_overlap
from server_specific.server_utils import get_patients
from data.data_utils import get_preprocessed_patches, get_all_patches_with_certain_idx, combine_preprocessed_patches
import numpy as np
from data_generation.generate_3d import visualize3Dimage
import training_configuration
from training_utils import get_val_test_indexes
from training_utils import get_train_test_val_patches

In [2]:
# trained_model_path = "saved_models/3d_model20241120-163318.pth"
# trained_model_path = "saved_models/3d_model20241124-104231.pth"

current_trained_model = "saved_models/3d_model_128_bce_fixed_2_dez.pth"
block_size = 128

In [3]:
device = get_best_device()

torch.set_default_device(device)
print(f"Using {device} device. Every tensor created will be by default on {device}")

model = UNet3D(in_channels=1, num_classes=1)

model.load_state_dict(torch.load(current_trained_model, weights_only=True))
model.to(device)
model.eval()

Using cuda device. Every tensor created will be by default on cuda


UNet3D(
  (down_convolution_1): DownSample(
    (conv): DoubleConv3D(
      (conv_op): Sequential(
        (0): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): ReLU(inplace=True)
        (2): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (3): ReLU(inplace=True)
      )
    )
    (pool): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (down_convolution_2): DownSample(
    (conv): DoubleConv3D(
      (conv_op): Sequential(
        (0): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): ReLU(inplace=True)
        (3): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (4): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (5): ReLU(inplace=True)
      )
    )
    (pool): MaxP

In [None]:



print("-----INFERENCE_______")
# it would be better to do this with an inference pipeline

preprocessed_patches = get_preprocessed_patches(patches_folder = training_configuration.PATCHES_FOLDER)
val_idxs, test_idxs = get_val_test_indexes()
test_idxs_patches = get_all_patches_with_certain_idx(test_idxs, preprocessed_patches)

_, val_idxs_patches, test_idxs_patches = get_train_test_val_patches(patches_folder = training_configuration.PATCHES_FOLDER)

-----INFERENCE_______
amt of detected patch files:  22667
amt of detected_files:  1800
amt of patients:  800
validation patients: ['b123f9', 'a4ecdd', '90b0ce', '0f7854', '455ae2', 'aea921', '929f00', '86ab8c', '46db1c', '96061e', '01ce4b', 'bbb965', 'a20751', '718127', '5cf4e2', '82a93b', '4d0198', '72440a', 'a7f0b8', 'c4ed8e', '131b0c']
testing patients: ['131b0c', 'abbb92', '34e2ad', '9b15fe', 'a21855', '482d96', '73475c', '62bfa2', '73daa9', '2d1007', 'aee484', '2c7d54', '3837e1', 'c22e1a', '888204', '768b84', '5480ab', '3c1528', 'aaf01d', '5e5c74', 'b7a568', '23c657', '3d34f1', '79bf08', '1c49f2', '188c1f', '48f89b', '8e4676', 'bcaf44', '43d244', '3bd625', '3b86df', '7559ca', '8d23cf', '9197e4', '28ac59', '392a52', '3055e0', '41e521', '501a4e', '6aac0c']


In [5]:
block_size = 128

In [None]:
predictions = []
for test_idx_patches in test_idxs_patches:
    reconstructed_mask, reconstructed_prediction = combine_preprocessed_patches(test_idx_patches, model)
    
    dice_scores = calculate_dice_scores(reconstructed_mask, reconstructed_prediction, thresholds = [0.25, 0.5])
    overlap_scores = calculate_overlap(reconstructed_mask, reconstructed_prediction, thresholds = [0.25, 0.5])
    
    print(f"dice scores: {dice_scores}")
    print(f"overlap scores: {overlap_scores}")
    
    # visualize3Dimage(reconstructed_mask)
    # visualize3Dimage(reconstructed_prediction)
    

(384, 384, 384)
dice scores: [np.float64(0.3636632349730447), np.float64(0.3718563142840384)]
overlap scores: [np.float64(0.7918495114387689), np.float64(0.788236498742028)]
(512, 512, 384)
dice scores: [np.float64(0.23699681122770253), np.float64(0.24486214622579922)]
overlap scores: [np.float64(0.8645936707464991), np.float64(0.861616495754769)]
(384, 384, 256)
dice scores: [np.float64(0.3616464098972516), np.float64(0.37580697809371827)]
overlap scores: [np.float64(0.8405046910384989), np.float64(0.838058071821417)]
(384, 384, 256)
dice scores: [np.float64(0.25113211409513253), np.float64(0.26227298023938483)]
overlap scores: [np.float64(0.8907207366743404), np.float64(0.8887905082344608)]
(384, 384, 256)
dice scores: [np.float64(0.2797932534603468), np.float64(0.2936161761785323)]
overlap scores: [np.float64(0.8489272754281348), np.float64(0.8471022642870785)]
(384, 384, 256)
dice scores: [np.float64(0.352305449659702), np.float64(0.358611212104029)]
overlap scores: [np.float64(0.7