In [202]:
import ipywidgets as widgets
import matplotlib.pyplot as plt
from IPython.display import display
import numpy as np

from monai.transforms import *

In [203]:
labs_3d = True
if labs_3d:
    file_num = 'P5'
    image_path = f"/data/vr_heart/datasets/3D_Lab_Abnormal_Heart/Abnormal_Heart/{file_num}/Dicom"
    label_path = f"/data/vr_heart/datasets/3D_Lab_Abnormal_Heart/Abnormal_Heart/{file_num}/{file_num}_labels.nii.gz"
else:
    file_num = '1081'
    image_path = f"/data/vr_heart/datasets/ImageCHD_dicom/dicoms/ct_{file_num}_image"
    label_path = f"/data/vr_heart/datasets/ImageCHD_dicom/labels/ct_{file_num}_label.nii.gz"
original_image = Compose([LoadImage(image_only=True), EnsureChannelFirst(), Orientation("RAS")])(image_path)
original_label = LoadImage(image_only=True)(label_path)

In [204]:
original_image.shape, original_label.shape

(torch.Size([1, 512, 512, 317]), torch.Size([512, 512, 317]))

In [205]:
image_transforms = Compose([
    LoadImage(image_only=True),
    EnsureChannelFirst(),
    # AdjustContrast(gamma=1.45),
    # Rand3DElastic(sigma_range=(5,7), magnitude_range=(100,200), prob=0.7, padding_mode="reflection"),
    Orientation("RAS"),
    NormalizeIntensity(),
    # RandStdShiftIntensity(factors=0.5),
    # RandGaussianNoise(prob=1, mean=0, std=0.5),
    # RandGibbsNoise(prob=1, alpha=(0.8, 1)),
    # RandRicianNoise(prob=1, std=0.5, mean=1),
    RandKSpaceSpikeNoise(prob=1, intensity_range=(10, 12)),
    EnsureType(),
    ToTensor()
])

label_transforms = Compose([
    LoadImage(image_only=True),
    EnsureChannelFirst(),
    # Rand3DElastic(sigma_range=(5,7), magnitude_range=(100,200), prob=0.7, padding_mode="reflection"),
    Orientation("RAS"),
    EnsureType(),
    ToTensor()
])

In [206]:
transformed_image = image_transforms(image_path)
transformed_label = label_transforms(label_path)

In [207]:
original_image_np = original_image[0].cpu().numpy()
transformed_image_np = transformed_image[0].cpu().numpy()
label_np = transformed_label[0].cpu().numpy()

#sanity check for the shapes
print(original_image_np.shape)
print(transformed_image_np.shape)
print(label_np.shape)

print(max(transformed_image_np.flatten()))
print(min(transformed_image_np.flatten()))

(512, 512, 317)
(512, 512, 317)
(512, 512, 317)
6.922323
-0.9277008


In [208]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets

def visualize_slices_with_sliders(original, transformed, label):
    views = ['Sagittal', 'Coronal', 'Axial']
    
    max_slices = [original.shape[2], original.shape[1], original.shape[0]]

    sagittal_slider = widgets.IntSlider(min=0, max=max_slices[0]-1, step=1, value=original.shape[2]//2, description='Sagittal Slice')
    coronal_slider = widgets.IntSlider(min=0, max=max_slices[1]-1, step=1, value=original.shape[1]//2, description='Coronal Slice')
    axial_slider = widgets.IntSlider(min=0, max=max_slices[2]-1, step=1, value=original.shape[0]//2, description='Axial Slice')

    def update_plot(sagittal_slice, coronal_slice, axial_slice):
        fig, axes = plt.subplots(3, 3, figsize=(15, 15))
        slice_nums = [sagittal_slice, coronal_slice, axial_slice]
        
        for i, (view, slice_num) in enumerate(zip(views, slice_nums)):
            if view == 'Sagittal':
                axes[i, 0].imshow(np.rot90(original[:, :, slice_num]), cmap='gray')
                axes[i, 1].imshow(np.rot90(transformed[:, :, slice_num]), cmap='gray')
                axes[i, 2].imshow(np.rot90(label[:, :, slice_num]))
            elif view == 'Coronal':
                axes[i, 0].imshow(np.rot90(original[:, slice_num, :]), cmap='gray')
                axes[i, 1].imshow(np.rot90(transformed[:, slice_num, :]), cmap='gray')
                axes[i, 2].imshow(np.rot90(label[:, slice_num, :]))
            elif view == 'Axial':
                axes[i, 0].imshow(np.rot90(original[slice_num, :, :]), cmap='gray')
                axes[i, 1].imshow(np.rot90(transformed[slice_num, :, :]), cmap='gray')
                axes[i, 2].imshow(np.rot90(label[slice_num, :, :]))

            axes[i, 0].set_title(f'Original {view} View')
            axes[i, 1].set_title(f'Transformed {view} View')
            axes[i, 2].set_title(f'Label {view} View')
            axes[i, 0].axis('off')
            axes[i, 1].axis('off')
            axes[i, 2].axis('off')

        plt.tight_layout()
        plt.show()

    ui = widgets.VBox([sagittal_slider, coronal_slider, axial_slider])
    out = widgets.interactive_output(update_plot, 
                                     {'sagittal_slice': sagittal_slider, 
                                      'coronal_slice': coronal_slider, 
                                      'axial_slice': axial_slider})
    
    display(ui, out)

visualize_slices_with_sliders(original_image_np, transformed_image_np, label_np)


VBox(children=(IntSlider(value=158, description='Sagittal Slice', max=316), IntSlider(value=256, description='…

Output()