In [1]:
import torch
from torchvision import transforms
from utils import mypreprocess, util_functions, eff_unet2, dataset2d
from torch.utils.data import DataLoader
import os
import matplotlib.pyplot as plt
from collections import defaultdict
from tqdm import tqdm
import matplotlib.pyplot as plt
import cv2

In [2]:
model_path = 'results/2023-12-14_13-43/best_model.pt' 
model = eff_unet2.EffUNet(in_channels=1, classes=1)
model.load_state_dict(torch.load(model_path))

<All keys matched successfully>

In [8]:
IMG_SIZE = 512
BATCH_SIZE = 1
base_path = '/scratch/student/sinaziaee/datasets/2d_dataset/'
train_dir = os.path.join(base_path, 'training')
valid_dir = os.path.join(base_path, 'validation')
test_dir = os.path.join(base_path, 'testing')

inference_transformer = transforms.Compose([transforms.ToTensor(), 
                                            transforms.Resize((IMG_SIZE, IMG_SIZE), antialias=True)])

test_dataset = dataset2d.SegmentationDataset(input_root=f'{test_dir}/images/',target_root=f'{test_dir}/labels/',
                               transform_input= inference_transformer, transform_target=inference_transformer, with_path=True)
inference_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [9]:
print(f"dataset_size: {len(inference_loader.dataset)}, batch: {next(iter(inference_loader))[0].shape}")

dataset_size: 1063, batch: torch.Size([1, 1, 512, 512])


In [10]:
for idx, batch in tqdm(enumerate(inference_loader), total=len(inference_loader)):
    each1, each2, path = batch
    path = path[0]
    img_idx = path[-13:-9]
    slice_idx = path[-8:-4]
    print(img_idx, slice_idx)
    break

  0%|          | 0/1063 [00:00<?, ?it/s]

0072 0125





In [11]:
# Save predictions only
device = torch.device('cuda:0')



def save_prediction(prediction, path, filename, output_size=(512, 512)):
    """Save the prediction mask as a PNG file with a specific size."""
    # Resize prediction to the desired output size
    resized_prediction = cv2.resize(prediction, output_size, interpolation=cv2.INTER_NEAREST)

    plt.figure(figsize=(6, 6))
    plt.imshow(resized_prediction, cmap='gray')
    plt.axis('off')
    plt.savefig(os.path.join(path, filename), bbox_inches='tight', pad_inches=0)
    plt.close()

total_iou = 0.0
total_dice = 0.0
num_images = 0

# Define the 2D predictions folder
predictions_folder = 'preds/2d_predictions'
os.makedirs(predictions_folder, exist_ok=True)
model = model.to(device)
# Process each image and mask in the dataset
for idx, (images, true_masks, path) in tqdm(enumerate(inference_loader)):
    path = path[0]
    images = images.to(device)
    true_masks = true_masks.to(device)
    image_index = path[-13:-9]
    slice_index = path[-8:-4]

    # Run inference
    with torch.no_grad():
        logits_mask = model(images)
        pred_mask = torch.sigmoid(logits_mask)
        # Ensure we are extracting a single 2D slice
        pred_mask_cpu = (pred_mask > 0.5).float().cpu().numpy()[0, 0, :, :]

    # Calculate IoU and Dice Score
    iou = util_functions.calculate_IoU(pred_mask, true_masks)
    dice_score = util_functions.dice_coefficient2_modified(true_masks, pred_mask)

    total_iou += iou.item()
    total_dice += dice_score
    num_images += 1

    # Save the prediction mask
    save_filename = f'image_{image_index}_{slice_index}.png'
    save_prediction(pred_mask_cpu, predictions_folder, save_filename)

# Calculate and print the overall IoU and Dice Score
overall_iou = total_iou / num_images
overall_dice = total_dice / num_images
print(f"Overall IoU: {overall_iou:.2f}, Overall Dice Score: {overall_dice:.2f}")

1063it [00:59, 17.89it/s]

Overall IoU: 0.45, Overall Dice Score: 0.47





In [12]:
import os
import numpy as np
import nibabel as nib
from skimage import io
from tqdm import tqdm

def load_and_combine_slices_to_nifti(png_folder, output_folder):
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    # Dictionary to hold slices for each image
    image_slices = {}

    # Iterate over PNG files and group slices by image
    for file in sorted(os.listdir(png_folder)):
        if file.endswith('.png'):
            parts = file.split('_')
            image_index = int(parts[1])
            slice_index = int(parts[2].split('.')[0])
            slice_path = os.path.join(png_folder, file)
            slice_data = io.imread(slice_path, as_gray=True)

            if image_index not in image_slices:
                image_slices[image_index] = {}
            image_slices[image_index][slice_index] = slice_data

    # Process each image group
    for image_index in image_slices:
        slices = [image_slices[image_index][i] for i in sorted(image_slices[image_index])]
        image_3d = np.stack(slices, axis=-1)

        # Save as NIfTI file
        nifti_path = os.path.join(output_folder, f'image_{image_index:04d}.nii.gz')
        nifti_img = nib.Nifti1Image(image_3d, np.eye(4))
        nib.save(nifti_img, nifti_path)

    print("3D NIfTI files created.")

# Example usage
png_folder = 'preds/2d_predictions'
output_folder = 'preds/3d_predictions'
load_and_combine_slices_to_nifti(png_folder, output_folder)

3D NIfTI files created.
