In [None]:
import glob
import random
from datetime import datetime

import cv2
import numpy as np
import skimage.transform
from matplotlib import pyplot as plt
from scipy.ndimage.filters import gaussian_filter
from scipy.ndimage.interpolation import map_coordinates


In [None]:
    def gaussian_noise(self, img, mean=0, sigma=0.003, value_range=(0.0, 1.0)):
        """Return a noisy copy of ``img`` with additive Gaussian noise.

        Args:
            img: NumPy array holding the image. It is not modified.
            mean: Mean of the Gaussian noise.
            sigma: Standard deviation of the Gaussian noise.
            value_range: ``(low, high)`` bounds the noisy result is clamped to.
                Defaults to ``(0.0, 1.0)``, the range the original overflow
                masks tested against. Pass ``None`` to skip clamping, e.g. for
                standardised images that legitimately contain negative values.

        Returns:
            A new float array with the same shape as ``img``.
        """
        # One draw per pixel; ``img + noise`` allocates a new array, so the
        # caller's image is left untouched.
        noisy = img + np.random.normal(mean, sigma, img.shape)

        if value_range is None:
            return noisy

        # Clamp the *sum*. The earlier version wrote the bounds into the noise
        # array and then added it to the image, which pushed overflowing pixels
        # to ``img + 1.0`` and left underflowing pixels without any clamp.
        low, high = value_range
        return np.clip(noisy, low, high)

    def random_crop_resize(self, img, label, crop_size=500):
        """Cut the same random square window from ``img`` and ``label`` and
        resize both crops back to their original shapes.

        Args:
            img: Array indexed as ``[row, col, channel]``.
            label: Array indexed the same way; it is cropped with the window
                coordinates computed from ``img``.
            crop_size: Smallest allowed window side in pixels. The actual side
                is drawn uniformly from ``crop_size`` up to one pixel less than
                the shorter spatial side of ``img``.

        Returns:
            ``(img, label)`` after cropping and ``skimage.transform.resize``.

        Raises:
            ValueError: If ``crop_size`` is not smaller than the shorter
                spatial side of ``img``.
        """
        size_img = img.shape
        size_label = label.shape
        height, width = img.shape[:2]

        # Bound the window by the *shorter* side so it always fits on both
        # axes; bounding by the height alone broke on images narrower than tall.
        max_side = min(height, width) - 1
        if crop_size > max_side:
            raise ValueError(
                "Crop size should be less than image size "
                "(crop_size={}, image is {}x{})".format(crop_size, height, width)
            )
        side = random.randint(crop_size, max_side)

        # Column offset is drawn first, then the row offset.
        x = np.random.randint(width - side)
        y = np.random.randint(height - side)

        img = img[y : y + side, x : x + side, :]
        img = skimage.transform.resize(img, size_img)

        # NOTE(review): resize interpolates with its default settings, so a
        # discrete label map may come back with fractional values -- confirm
        # this is intended for the masks passed in.
        label = label[y : y + side, x : x + side, :]
        label = skimage.transform.resize(label, size_label)
        return img, label

    def affine_transform(self, image, label, alpha_affine=0.5, random_state=None):
        """Apply one random affine warp to both ``image`` and ``label``.

        Three reference points around the image centre are each jittered by up
        to ``alpha_affine`` pixels per coordinate; the affine map taking the
        original points to the jittered ones is applied to both arrays.

        Args:
            image: Array whose first two axes are the spatial axes.
            label: Array warped with the same matrix and output size.
            alpha_affine: Maximum absolute jitter of each reference coordinate.
            random_state: Optional ``np.random.RandomState``. When ``None`` a
                fresh, unseeded one is created, so results are not governed by
                ``np.random.seed``.

        Returns:
            ``(image, label)``. ``image`` always has a channel axis; ``label``
            is returned exactly as ``cv2.warpAffine`` produced it.
        """
        if random_state is None:
            random_state = np.random.RandomState(None)

        shape_size = image.shape[:2]
        center_square = np.float32(shape_size) // 2
        square_size = min(shape_size) // 3
        pts1 = np.float32(
            [
                center_square + square_size,
                [center_square[0] + square_size, center_square[1] - square_size],
                center_square - square_size,
            ]
        )
        pts2 = pts1 + random_state.uniform(
            -alpha_affine, alpha_affine, size=pts1.shape
        ).astype(np.float32)
        M = cv2.getAffineTransform(pts1, pts2)

        # warpAffine takes the output size as (width, height).
        dsize = shape_size[::-1]

        image = cv2.warpAffine(image, M, dsize, borderMode=cv2.BORDER_REFLECT_101)
        # Restore the channel axis only when the warp returned a 2-D array.
        # Appending it unconditionally turned multi-channel images into 4-D.
        if image.ndim == 2:
            image = image[..., np.newaxis]

        label = cv2.warpAffine(label, M, dsize, borderMode=cv2.BORDER_REFLECT_101)
        return image, label

    def elastic_transform(self, image, label, alpha, sigma, random_state=None):
        """Apply the same random elastic deformation to ``image`` and ``label``.

        Two displacement fields (one per spatial axis) are built from uniform
        noise in [-1, 1), smoothed with a Gaussian of width ``sigma`` and scaled
        by ``alpha``. Both arrays are then resampled at the displaced
        coordinates with linear interpolation and reflected borders.

        Args:
            image: 3-D array indexed ``[row, col, channel]``; its shape defines
                the shape of the displacement fields.
            label: 3-D array; the fields must broadcast against its shape.
            alpha: Scale factor applied to the smoothed displacement fields.
            sigma: Gaussian smoothing width passed to ``gaussian_filter``.
            random_state: Optional ``np.random.RandomState``. When ``None`` a
                fresh, unseeded one is created.

        Returns:
            ``(image, label)`` deformed, each with its original shape.
        """
        if random_state is None:
            random_state = np.random.RandomState(None)

        shape = image.shape

        def smooth_field():
            # Uniform noise in [-1, 1), blurred and scaled to a displacement.
            noise = random_state.rand(*shape) * 2 - 1
            return gaussian_filter(noise, sigma, mode="constant", cval=0) * alpha

        # Draw order matters for reproducibility: dx first, then dy.
        dx = smooth_field()
        dy = smooth_field()

        def warp(volume):
            vol_shape = volume.shape
            # "ij" indexing keeps the grids in the array's own axis order, so
            # non-square inputs work. The default "xy" order transposed the
            # first two axes and only broadcast for square images.
            rows, cols, chans = np.meshgrid(
                np.arange(vol_shape[0]),
                np.arange(vol_shape[1]),
                np.arange(vol_shape[2]),
                indexing="ij",
            )
            # dy displaces along axis 0, dx along axis 1; channels are untouched.
            coords = (
                np.reshape(rows + dy, (-1, 1)),
                np.reshape(cols + dx, (-1, 1)),
                np.reshape(chans, (-1, 1)),
            )
            warped = map_coordinates(volume, coords, order=1, mode="reflect")
            return warped.reshape(vol_shape)

        return warp(image), warp(label)



In [None]:
    def data_augment(self, img, mask, chance=0.5):
        """Run the random augmentation pipeline on an image/mask pair.

        Steps, in order: horizontal flip (fixed probability 0.5), random crop
        and resize, random affine warp, elastic deformation, Gaussian noise.
        Every step after the flip is applied independently with probability
        ``chance``. Noise is added to the image only.

        Args:
            img: Image array.
            mask: Mask array; receives the same geometric transforms as ``img``.
            chance: Probability of applying each step after the flip.

        Returns:
            ``(img, mask)`` after augmentation.
        """

        def with_channel_axis(arr):
            # Some steps can hand back a 2-D array; put the channel axis back.
            if len(arr.shape) == 2:
                return arr[..., np.newaxis]
            return arr

        # flip l/r
        if random.uniform(0, 1) < 0.5:
            img = with_channel_axis(cv2.flip(img, 1))
            mask = with_channel_axis(cv2.flip(mask, 1))

        # random crop and resize
        if random.uniform(0, 1) < chance:
            img, mask = self.random_crop_resize(img, mask)
            img = with_channel_axis(img)
            # This branch used to assign to an undefined ``label`` and raised
            # UnboundLocalError whenever the cropped mask came back 2-D.
            mask = with_channel_axis(mask)

        # random affine transformation
        if random.uniform(0, 1) < chance:
            img, mask = self.affine_transform(img, mask, alpha_affine=20)
            img = with_channel_axis(img)
            mask = with_channel_axis(mask)

        # random elastic deformation with one of three (alpha, sigma) presets
        if random.uniform(0, 1) < chance:
            args = random.choice(((1201, 10), (1501, 12), (991, 8)))
            img, mask = self.elastic_transform(img, mask, *args)

        # random Gaussian noise (image only)
        if random.uniform(0, 1) < chance:
            sigma = random.choice(np.arange(0.1, 0.3, 0.02))
            img = self.gaussian_noise(img, mean=0, sigma=sigma)

        return img, mask


In [None]:
import tensorflow as tf
# Ask TensorFlow which GPUs it can see; the cell stops here if there are none.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
assert len(physical_devices) > 0, "Not enough GPU hardware devices available"
# Turn on memory growth for the first GPU only.
# NOTE(review): set_memory_growth is documented to return None, so `config`
# presumably holds None -- confirm nothing downstream reads it.
config = tf.config.experimental.set_memory_growth(physical_devices[0], True)

In [None]:
#import train as train

import loss as loss
import unet as unet
import random
import paths as paths
import numpy as np
import glob

# Seed the stdlib and NumPy global RNGs with 42 (the "meaning of life").
# NOTE: affine_transform and elastic_transform build their own
# np.random.RandomState(None) when no random_state is passed, so these seeds
# do not make those two steps repeatable.
random.seed(42)  
np.random.seed(42)

In [None]:
# Constants used to standardise the loaded inputs: (x - DATA_MEAN) / DATA_STD.
DATA_MEAN = 191.46748269704375
DATA_STD = 369.2190429494859
# BATCH_SIZE and OUTPUT_CHANNELS are not referenced by the cells shown here.
BATCH_SIZE = 1
OUTPUT_CHANNELS = 3
# Root directory handed to paths.get_patient_paths.
# NOTE(review): absolute, machine-specific path -- must be edited to run elsewhere.
DATA_PATH = "/home/matthew/masters_code/dataset_prostate_cleaned/"

In [None]:
# Patient entries found under DATA_PATH, sorted so the slices below see a
# stable order.
patient_paths = paths.get_patient_paths(DATA_PATH)
patient_paths.sort()

# One list of files per patient: everything under <patient>/img and <patient>/mask.
img_paths = [glob.glob(path + "/img/*") for path in patient_paths]
mask_paths = [glob.glob(path + "/mask/*") for path in patient_paths]

# Split sizes are counted in patients, not files: 15% validation and 10% test
# (both rounded down), with the remainder used for training.
valid = int(len(img_paths) * 0.15 // 1)
test = int(len(img_paths) * 0.1 // 1)
train = int(len(img_paths) - valid - test)

# Training set: the first `train` patients (presumably merged into one flat
# list by paths.flatten_list -- confirm).
train_inputs = paths.flatten_list(img_paths[0:train])
train_truths = paths.flatten_list(mask_paths[0:train])

# NOTE(review): inputs and truths are sorted independently in every split;
# pairing them by index presumably relies on img/ and mask/ files sorting into
# the same order -- confirm.
train_inputs.sort()
train_truths.sort()

# Validation set: the next `valid` patients.
valid_inputs = paths.flatten_list(img_paths[train:train+valid])
valid_truths = paths.flatten_list(mask_paths[train:train+valid])

valid_inputs.sort()
valid_truths.sort()

# Test set: every remaining patient.
test_inputs = paths.flatten_list(img_paths[train+valid:])
test_truths = paths.flatten_list(mask_paths[train+valid:])

test_inputs.sort()
test_truths.sort()

In [None]:
# Load every test file with np.load and stack the results into one array each.
# NOTE: this rebinds test_inputs/test_truths from path lists to arrays, so
# re-running this cell without re-running the split cell would pass arrays
# (not paths) to np.load.
test_inputs = np.array([np.load(array) for array in test_inputs])
test_truths = np.array([np.load(array) for array in test_truths])

In [None]:
# Standardise the inputs with the dataset constants.
# NOTE: test_inputs is rebound, so running this cell twice standardises twice.
test_inputs = (test_inputs - DATA_MEAN) / DATA_STD

In [None]:
# Display the shapes of the loaded inputs and ground-truth arrays.
test_inputs.shape, test_truths.shape

In [None]:
# Work on a copy of the first test input so test_inputs stays untouched.
img = test_inputs[0].copy()

In [None]:
from matplotlib import pyplot as plt

In [None]:
# Show the first channel of the sample image.
plt.imshow(img[...,0])

In [None]:
# Draw a reference grid on a fresh copy of the first test input so that the
# geometric augmentations in the following cells are easy to see.
dx = dy = 64
img = test_inputs[0].copy()
print(img.shape)

# Grid lines take the brightest value present in the image.
grid_color = img.max()

# Paint every dy-th column, then every dx-th row, across all channels.
img[:, ::dy, :] = grid_color
img[::dx, :, :] = grid_color

# Display the first channel of the gridded image.
plt.imshow(img[..., 0], cmap="gray")
plt.show()

In [None]:
# Demo: random crop + resize. The method is called as a plain function, so the
# leading 1 is only a placeholder for `self`; img doubles as its own label.
img_m, label = random_crop_resize(1, img, img)
plt.imshow(img_m[...,0], cmap='gray')

In [None]:
# Demo: Gaussian noise with a sigma picked at random from
# np.arange(0.1, 0.3, 0.02); the leading 1 is a placeholder for `self`.
sigma = random.choice(np.arange(0.1, 0.3, 0.02))
img_m = gaussian_noise(1, img, sigma=sigma)
plt.imshow(img_m[...,0])

In [None]:
# Demo: horizontal flip (flip code 1).
# NOTE(review): the result is shown without selecting a channel, which
# presumably relies on cv2.flip returning a 2-D array here -- confirm.
img_m = cv2.flip(img, 1)
plt.imshow(img_m[...], cmap='gray')

In [None]:
# Demo: random affine warp with up to 20 px of jitter per reference coordinate;
# the leading 1 is a placeholder for `self` and img doubles as its own label.
img_m, mask = affine_transform(1, img, img, alpha_affine=20)
plt.imshow(img_m[...,0], cmap='gray')

In [None]:
# Demo: elastic deformation with one of three (alpha, sigma) presets chosen at
# random; the leading 1 is a placeholder for `self`.
args = random.choice(((1201, 10), (1501, 12), (991, 8)))
img_m, mask = elastic_transform(1, img, img, *args)
plt.imshow(img_m[...,0], cmap='gray')

In [None]:
# Rebuild the gridded reference image from a fresh copy of the first test
# input and show it at a larger size with the default colour map.
dx = dy = 64
img = test_inputs[0].copy()
print(img.shape)

# Grid lines take the brightest value present in the image.
grid_color = img.max()

# Paint every dy-th column, then every dx-th row, across all channels.
img[:, ::dy, :] = grid_color
img[::dx, :, :] = grid_color

# Display the first channel on a 12 x 12 inch figure.
plt.figure(figsize=(12, 12))
plt.imshow(img[..., 0])
plt.show()

In [None]:
# Demo: chain every augmentation on the gridded image. In each call the
# leading 1 is a placeholder for `self`, and the image is passed as its own
# label; the returned label/mask values are not used afterwards.

# 1) random crop + resize
img_m, label = random_crop_resize(1, img, img)

# 2) Gaussian noise with a randomly chosen sigma
sigma = random.choice(np.arange(0.1, 0.3, 0.02))
img_m = gaussian_noise(1, img_m, sigma=sigma)

# 3) horizontal flip
img_m = cv2.flip(img_m, 1)

# 4) random affine warp
img_m, mask = affine_transform(1, img_m, img_m, alpha_affine=20)


# 5) elastic deformation with a randomly chosen (alpha, sigma) preset
args = random.choice(((1201, 10), (1501, 12), (991, 8)))
img_m, mask = elastic_transform(1, img_m, img_m, *args)

# Show the first channel of the fully augmented image.
plt.figure(figsize = (12,12))
plt.imshow(img_m[...,0])

In [None]:
# Reset `img` to a clean gridded copy of the first test input; the figure cell
# below uses it as the source image for every panel.
dx = dy = 64
img = test_inputs[0].copy()
print(img.shape)

# Grid lines take the brightest value present in the image.
grid_color = img.max()

# Paint every dy-th column, then every dx-th row, across all channels.
img[:, ::dy, :] = grid_color
img[::dx, :, :] = grid_color

# Display the first channel of the gridded image.
plt.imshow(img[..., 0], cmap="gray")
plt.show()

In [None]:
from matplotlib import pyplot as plt
import cv2
from skimage import exposure


# Summary figure: three rows of independent random draws (labelled A-C), four
# columns showing crop, affine, elastic and all three combined. Saved to
# augment.png. In every augmentation call the leading 1 is a placeholder for
# `self` and `img` is passed as its own label.
fig, axs = plt.subplots(
    nrows=3, ncols=4, sharex=True, sharey=True, squeeze=True, figsize=(12, 9)
)
for index in range(3):
    ax = axs[index]

    # Strip tick labels and ticks from every panel in this row.
    for a in ax:
        a.set_xticklabels([])
        a.set_yticklabels([])
        a.set_xticks([])
        a.set_yticks([])

    # Single augmentations, each applied to the original gridded image.
    img_c, _ = random_crop_resize(1, img, img, 470)
    img_a, _ = affine_transform(1, img, img, alpha_affine=20)

    args = random.choice(((1201, 10), (1501, 12), (991, 8)))
    img_e, _ = elastic_transform(1, img, img, *args)

    # Combined: crop, then affine, then elastic with a fresh preset.
    img_all, _ = random_crop_resize(1, img, img, 470)
    img_all, _ = affine_transform(1, img_all, img_all, alpha_affine=20)
    args = random.choice(((1201, 10), (1501, 12), (991, 8)))
    img_all, _ = elastic_transform(1, img_all, img_all, *args)

    # One panel per result, first channel only.
    for a, panel in zip(ax, (img_c, img_a, img_e, img_all)):
        a.imshow(panel[..., 0], cmap="gray")

plt.tight_layout()

# Column titles on the top row.
title_font = {"fontsize": 23, "fontweight": "medium"}
for a, title in zip(axs[0], ("Crop", "Affine", "Elastic", "Combined")):
    a.set_title(title, fontdict=title_font)

# Row letters to the left of the first panel in each row.
for row_axes, letter in zip(axs, "ABC"):
    first = row_axes[0]
    first.text(-0.2, 0.5, letter, size=23, ha="center", transform=first.transAxes)

plt.savefig("augment.png", bbox_inches="tight")