In [9]:
import matplotlib.pyplot as plt
import numpy as np
import h5py
from ipywidgets import widgets
%matplotlib widget

def mask_overlay(image, mask, color=(0., 1., 0.), weight=.8):
    """
    Helper function to visualize mask on the top of the aneurysm
    """
    mask = np.dstack((mask, mask, mask)) * np.array(color)
    mask = mask.astype(np.uint8)
    weighted_sum = image * weight + mask * (1 - weight) # cv2.addWeighted(image, 1 - weight, mask, weight, 0.,  dtype=cv2.CV_32F)
    img = image.copy()
    ind = mask[:, :, 1] > 0
    img[ind] = weighted_sum[ind]
    return img


class ImageSlicer(object):

    def __init__(self, ax, img, mask):
        self.ax = ax

        # convert to numpy array
        self.image_np = img
        self.mask_np = mask

        # get number of slices
        _, _, self.slices = self.image_np.shape
        self.ind = self.slices // 2

        # plot image with mask overlay
        self.image_plt = self.ax.imshow(self.overlay)
        self._draw()

    @property
    def overlay(self):
        # get image and mask slice
        image = self.image_np[:, :, self.ind]
        image = image / np.max(image)
        image = np.dstack((image, image, image))
        mask = self.mask_np[:, :, self.ind]

        # create masked overlay
        return mask_overlay(image, mask)

    def onscroll(self, event):

        # get new slice number
        self.ind = event['new']
        with open("file.txt", "a") as f:
            f.write(f"{self.ind}\n")
        self.update()

    def update(self):

        # draw overlay
        self.image_plt.set_data(self.overlay)
        self._draw()

    def _draw(self):
        try:
            self.image_plt.axes.figure.canvas.draw()
        except Exception as e:
            with open("file.txt", "a") as f:
                f.write(f"Exception {e}\n")
            pass


def plot3d(img, mask):

    #
    figure, ax = plt.subplots(1, 1)
    tracker = ImageSlicer(ax, img, mask)

    #
    int_slider = widgets.IntSlider(
        value=tracker.ind,
        min=0,
        max=tracker.slices,
        step=1,
        description='Slice',
        continuous_update=True
    )
    int_slider.observe(tracker.onscroll, 'value')

    return figure, int_slider

def load_case(i):
    f = h5py.File(f"{datapath}/{files[i]}", "r")
    return f["raw"][:], f["label"][:]

In [12]:
import yaml
import os
from pytorch3dunet.augment.transforms import Transformer

config_path = "./train-configs/aug_tests/perlin.yml"

config = yaml.safe_load(open(config_path, 'r'))

transform_config = config["loaders"]["train"]["transformer"]
datapath = config["loaders"]["train"]["file_paths"][0]

transformer = Transformer(transform_config, {})
raw_transform = transformer.raw_transform()
label_transform = transformer.label_transform()

files = os.listdir(datapath)
print(files)

['A054_masks.h5', 'A009_masks.h5', 'A050_masks.h5', 'A108_masks.h5', 'A012_masks.h5', 'A018_masks.h5', 'A008_masks.h5', 'A098_masks.h5', 'A080_masks.h5', 'A013_masks.h5', 'A127_masks.h5', 'A084_masks.h5', 'A119_masks.h5', 'A115_masks.h5', 'A138_masks.h5', 'A121_masks.h5', 'A010_masks.h5', 'A082_masks.h5', 'A097_masks.h5', 'A057_masks.h5', 'A096_R_masks.h5', 'A078_R_masks.h5', 'A130_L_masks.h5', 'A024_masks.h5', 'A060_masks.h5', 'A071_masks.h5', 'A074_masks.h5', 'A003_masks.h5', 'A046_masks.h5', 'A134_masks.h5', 'A087_masks.h5', 'A093_masks.h5', 'A091_R_masks.h5', 'A129_masks.h5', 'A114_masks.h5', 'A067_masks.h5', 'A088_masks.h5', 'A137_masks.h5', 'A023_R_masks.h5', 'A001_masks.h5', 'A079_masks.h5', 'A100_masks.h5', 'A095_masks.h5', 'PA3_masks.h5', 'A059_L_masks.h5', 'A085_masks.h5', 'A064_masks.h5', 'A062_L_masks.h5', 'A051_R_masks.h5', 'A017_L_masks.h5', 'A096_L_masks.h5', 'A005_masks.h5', 'A006_masks.h5', 'A078_L_masks.h5', 'A049_masks.h5', 'A089_R_masks.h5', 'A126_masks.h5', 'A136_m

In [13]:
i = 0

img, mask = load_case(i)
img /= img.max()

img_t = raw_transform(img)
mask_t = label_transform(mask)

fig, slider = plot3d(img_t, mask_t)
slider

(256, 256, 220)


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


IntSlider(value=110, description='Slice', max=220)