In [7]:
import ipywidgets as widgets
import matplotlib.pyplot as plt
from IPython.display import display
import numpy as np

from monai.transforms import *

In [8]:
image_path = "/data/vr_heart/datasets/ImageCHD_dicom/dicoms/ct_1001_image"
label_path = "/data/vr_heart/datasets/ImageCHD_dicom/labels/ct_1001_label.nii.gz"
original_image = LoadImage(image_only=True)(image_path)
label = LoadImage(image_only=True)(label_path)

In [9]:
transforms = Compose([
    LoadImage(image_only=True),
    EnsureType(),
    AdjustContrast(gamma=1.75),
    # Affine(shear_params=[0.1,0.1], mode="bilinear", padding_mode="reflection"),
    Rand3DElastic(sigma_range=(5,7), magnitude_range=(50,150), prob=1, padding_mode="reflection"),
    NormalizeIntensity(),
    RandFlip()
])

In [10]:
transformed_image = transforms(image_path)
transformed_label = transforms(label_path)

In [11]:
original_image_np = original_image.cpu().numpy()
transformed_image_np = transformed_image.cpu().numpy()
label_np = transformed_label

#sanity check for the shapes
print(original_image_np.shape)
print(transformed_image_np.shape)
print(label_np.shape)

print(max(transformed_image_np.flatten()))
print(min(transformed_image_np.flatten()))

(512, 512, 221)
(512, 512, 221)
torch.Size([512, 512, 221])
14.019986
-1.164882


In [12]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets

def visualize_slices_with_sliders(original, transformed, label):
    views = ['Sagittal', 'Coronal', 'Axial']
    
    max_slices = [original.shape[2], original.shape[1], original.shape[0]]

    sagittal_slider = widgets.IntSlider(min=0, max=max_slices[0]-1, step=1, description='Sagittal Slice')
    coronal_slider = widgets.IntSlider(min=0, max=max_slices[1]-1, step=1, description='Coronal Slice')
    axial_slider = widgets.IntSlider(min=0, max=max_slices[2]-1, step=1, description='Axial Slice')

    def update_plot(sagittal_slice, coronal_slice, axial_slice):
        fig, axes = plt.subplots(3, 3, figsize=(15, 15))
        slice_nums = [sagittal_slice, coronal_slice, axial_slice]
        
        for i, (view, slice_num) in enumerate(zip(views, slice_nums)):
            if view == 'Sagittal':
                axes[i, 0].imshow(np.rot90(original[:, :, slice_num]), cmap='gray')
                axes[i, 1].imshow(np.rot90(transformed[:, :, slice_num]), cmap='gray')
                axes[i, 2].imshow(np.rot90(label[:, :, slice_num]))
            elif view == 'Coronal':
                axes[i, 0].imshow(np.rot90(original[:, slice_num, :]), cmap='gray')
                axes[i, 1].imshow(np.rot90(transformed[:, slice_num, :]), cmap='gray')
                axes[i, 2].imshow(np.rot90(label[:, slice_num, :]))
            elif view == 'Axial':
                axes[i, 0].imshow(np.rot90(original[slice_num, :, :]), cmap='gray')
                axes[i, 1].imshow(np.rot90(transformed[slice_num, :, :]), cmap='gray')
                axes[i, 2].imshow(np.rot90(label[slice_num, :, :]))

            axes[i, 0].set_title(f'Original {view} View')
            axes[i, 1].set_title(f'Transformed {view} View')
            axes[i, 2].set_title(f'Label {view} View')
            axes[i, 0].axis('off')
            axes[i, 1].axis('off')
            axes[i, 2].axis('off')

        plt.tight_layout()
        plt.show()

    ui = widgets.VBox([sagittal_slider, coronal_slider, axial_slider])
    out = widgets.interactive_output(update_plot, 
                                     {'sagittal_slice': sagittal_slider, 
                                      'coronal_slice': coronal_slider, 
                                      'axial_slice': axial_slider})
    
    display(ui, out)

visualize_slices_with_sliders(original_image_np, transformed_image_np, label_np)


VBox(children=(IntSlider(value=0, description='Sagittal Slice', max=220), IntSlider(value=0, description='Coro…

Output()