In [1]:
import matplotlib.pyplot as plt
import numpy as np
import h5py
from ipywidgets import widgets
%matplotlib widget

def mask_overlay(image, mask, color=(0., 1., 0.), weight=.8):
    """
    Helper function to visualize mask on the top of the aneurysm
    """
    mask = np.dstack((mask, mask, mask)) * np.array(color)
    mask = mask.astype(np.uint8)
    weighted_sum = image * weight + mask * (1 - weight) # cv2.addWeighted(image, 1 - weight, mask, weight, 0.,  dtype=cv2.CV_32F)
    img = image.copy()
    ind = mask[:, :, 1] > 0
    img[ind] = weighted_sum[ind]
    return img


class ImageSlicer(object):

    def __init__(self, ax, img, mask):
        self.ax = ax

        # convert to numpy array
        self.image_np = img
        self.mask_np = mask

        # get number of slices
        _, _, self.slices = self.image_np.shape
        self.ind = self.slices // 2

        # plot image with mask overlay
        self.image_plt = self.ax.imshow(self.overlay)
        self._draw()

    @property
    def overlay(self):
        # get image and mask slice
        image = self.image_np[:, :, self.ind]
        image = image / np.max(image)
        image = np.dstack((image, image, image))
        mask = self.mask_np[:, :, self.ind]

        # create masked overlay
        return mask_overlay(image, mask)

    def onscroll(self, event):

        # get new slice number
        self.ind = event['new']
        with open("file.txt", "a") as f:
            f.write(f"{self.ind}\n")
        self.update()

    def update(self):

        # draw overlay
        self.image_plt.set_data(self.overlay)
        self._draw()

    def _draw(self):
        try:
            self.image_plt.axes.figure.canvas.draw()
        except Exception as e:
            with open("file.txt", "a") as f:
                f.write(f"Exception {e}\n")
            pass


def plot3d(img, mask):

    #
    figure, ax = plt.subplots(1, 1)
    tracker = ImageSlicer(ax, img, mask)

    #
    int_slider = widgets.IntSlider(
        value=tracker.ind,
        min=0,
        max=tracker.slices,
        step=1,
        description='Slice',
        continuous_update=True
    )
    int_slider.observe(tracker.onscroll, 'value')

    return figure, int_slider

def load_case(i):
    f = h5py.File(f"{datapath}/{files[i]}", "r")
    return f["raw"][:], f["label"][:]

In [3]:
import yaml
import os
from pytorch3dunet.augment.transforms import Transformer

config_path = "./train-configs/aug_tests/perlin.yml"

config = yaml.safe_load(open(config_path, 'r'))

transform_config = config["loaders"]["train"]["transformer"]
datapath = config["loaders"]["train"]["file_paths"][0]

transformer = Transformer(transform_config, {})
raw_transform = transformer.raw_transform()
label_transform = transformer.label_transform()

files = os.listdir(datapath)
print(files)

['A038_R.h5', 'A051_R.h5']


In [16]:
import scipy

In [15]:
def augment( m):
        m[m<0.4]=0
        kernel = np.asarray([[[0.125,0.25,125],
                    [0.25,0.5,0.25],
                    [0.125,0.25,0.125]],[[0.25,0.5,0.25],
                    [0.5,1,0.5],
                    [0.25,0.5,0.25]],
                    [[0.125,0.25,125],
                    [0.25,0.5,0.25],
                    [0.125,0.25,0.125]]])
        print(kernel.shape)
        print(m.shape)
        scipy.ndimage.convolve(m,kernel)
        return m
%reload_ext autoreload
%autoreload 2
i = 0

img, mask = load_case(i)
img /= img.max()

img_t = augment(img)
mask_t = label_transform(mask)

fig, slider = plot3d(img_t, mask_t)
slider

(3, 3, 3)
(220, 256, 256)


ValueError: object too deep for desired array