In [None]:
import dlcomp.augmentations as aug
from dlcomp.data_handling import load_train_dataset, AugmentedDataset
import matplotlib.pyplot as plt
import numpy as np
from imgaug.augmentables.heatmaps import HeatmapsOnImage
from imgaug import augmenters as iaa
import colorsys

In [None]:
# Paired training data; indexing gives an (x, y) pair, presumably (noisy input, clean target)
# judging by the file names below.
noisy_file, clean_file = 'data/train_noisy.npy', 'data/train_clean.npy'
ds = load_train_dataset(noisy_file, clean_file)

def plot_dataset(transform, start, n, m, dataset=None):
    """Show a grid of dataset samples next to augmented versions of them.

    Row ``i`` shows sample ``start + i``. Columns:

    * 0 -- input ``x``
    * 1 -- target ``y``
    * 2 -- ``x`` after ``transform`` (only when ``m > 0``)
    * 3 -- ``y`` after the *same* augmentation as column 2, carried through
      imgaug as a heatmap (only when ``m > 0``)
    * 4 .. 2 + m -- further independent augmentations of ``x`` alone

    Args:
        transform: imgaug-style augmenter. It is called as ``transform(image=x)``
            and as ``transform(image=x, heatmaps=hm)``; the latter must return an
            ``(image, heatmaps)`` pair.
        start: index of the first sample to show.
        n: number of rows (samples).
        m: number of augmentation columns. With ``m == 0`` only columns 0-1 are
            drawn and column 2 stays blank.
        dataset: indexable yielding ``(x, y)`` pairs. Defaults to the
            notebook-level ``ds``.

    Returns:
        None. The figure is left to the inline backend, so returning it would
        display it twice.
    """
    data = ds if dataset is None else dataset
    # Fetch each pair once so every column of a row is built from the same sample.
    samples = [data[start + i] for i in range(n)]

    # squeeze=False keeps `axs` 2-D even for n == 1, so axs[i, k] always works.
    fig, axs = plt.subplots(n, 3 + m, figsize=(2 * (3 + m), 2 * n), squeeze=False)
    # Hide every axis up front. This also leaves the unused column 2 blank when m == 0.
    for ax in axs.flat:
        ax.axis('off')

    titles = ['noisy', 'clean']
    if m > 0:
        titles += ['aug. noisy', 'aug. clean'] + ['aug. noisy'] * (m - 1)
    for ax, title in zip(axs[0], titles):
        ax.set_title(title, fontsize=8)

    for i, (x, y) in enumerate(samples):
        axs[i, 0].imshow(x)
        axs[i, 1].imshow(y)

    # Augmentations are drawn column by column (all rows of column j before j + 1).
    for j in range(m):
        for i, (x, y) in enumerate(samples):
            if j == 0:
                # Pass y along as a heatmap so imgaug applies the same geometric
                # augmentation to it as to x. Colour augmenters leave heatmaps as-is.
                hm = HeatmapsOnImage(y.astype('f4'), shape=x.shape, min_value=0, max_value=255)
                x_aug, hm_aug = transform(image=x, heatmaps=hm)
                axs[i, 2].imshow(x_aug)
                # Round before the uint8 cast because the float round trip can land
                # just below an integer and astype() truncates. Clip in case
                # interpolation overshoots the [0, 255] range.
                y_aug = np.clip(np.round(hm_aug.get_arr()), 0, 255).astype(np.uint8)
                axs[i, 3].imshow(y_aug)
            else:
                axs[i, 3 + j].imshow(transform(image=x))

In [None]:
# Number of samples and the shape of the first sample's x.
n_samples = len(ds)
first_input = ds[0][0]
print(f'N: {n_samples}')
print(f'shape: {first_input.shape}')

In [None]:
# Are the clean targets close to greyscale? Compare the colour channels of one image pairwise.
clean_img = ds[0][1].astype('f8')
channel_pairs = [
    ('MSE Red vs Green: ', 0, 1),
    ('MSE Red vs Blue: ', 0, 2),
    ('MSE Green vs Blue: ', 1, 2),
]
for label, c1, c2 in channel_pairs:
    # Mean of squared differences. Taking sqrt inside the mean would give the
    # mean absolute error instead.
    print(label, np.mean((clean_img[:, :, c1] - clean_img[:, :, c2]) ** 2))

# Pixels below 1 in *every* channel are treated as (presumably black) background
# and excluded as whole pixels. Masking channel by channel would drop e.g. the
# zero G/B of a pure-red pixel and bias the mean towards saturated colours.
clean_img[(clean_img < 1).all(axis=-1)] = np.nan
print('Mean Color:', np.nanmean(clean_img, axis=(0, 1)))

print('-' * 50)

# Distribution of the per-image mean foreground colour in HLS space.
N_COLOR_STATS = min(500, len(ds))  # don't index past the end of a smaller dataset
clean_imgs = np.stack([ds[i][1].astype('f8') for i in range(N_COLOR_STATS)])
clean_imgs[(clean_imgs < 1).all(axis=-1)] = np.nan
mean_colors = np.nanmean(clean_imgs, axis=(1, 2))
# An image that is entirely background has an all-NaN mean, so drop it.
mean_colors = mean_colors[~np.isnan(mean_colors).any(axis=1)]
# colorsys expects r, g, b in [0, 1], but the images are on a 0-255 scale.
hls_means = np.array(
    [colorsys.rgb_to_hls(*(color / 255.0)) for color in mean_colors]
).reshape(-1, 3)  # keep 2-D even if no image survived the filter

for k, (title, xlabel) in enumerate([
    ('Mean Hue', 'hue (0-1)'),
    ('Mean Lightness', 'lightness (0-1)'),
    ('Mean Saturation', 'saturation (0-1)'),
]):
    fig, ax = plt.subplots()
    ax.hist(hls_means[:, k])
    ax.set(title=title, xlabel=xlabel, ylabel='images')
    plt.show()

In [None]:
# Grid settings shared by the plot_dataset calls below:
# first sample index, number of rows, number of augmentation columns.
start, n, m = 40, 10, 4

In [None]:
# Preview the project's `baseline` augmentation.
plot_dataset(transform=aug.baseline, start=start, n=n, m=m)

In [None]:
# Preview the project's `weak` augmentation.
plot_dataset(transform=aug.weak, start=start, n=n, m=m)

In [None]:
# Hue shift only, over imgaug's full AddToHue range.
hue_shift = iaa.AddToHue((-255, 255))
plot_dataset(hue_shift, start, n, m)

In [None]:
# Mild random zoom (±10%) and shift (±10%), filling borders by mirroring.
affine_only = iaa.Affine(
    scale={"x": (0.9, 1.1), "y": (0.9, 1.1)},
    translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)},
    mode='symmetric',
)
plot_dataset(affine_only, start, n, m)

In [None]:
def no_shape_change():
    """Build the flip + mild-affine + colour-jitter augmentation pipeline.

    The steps run in order:
    1. random horizontal and vertical flips (each with p=0.5);
    2. a random ±10% zoom and ±10% shift with mirrored borders;
    3. saturation scaling in [0.2, 1.3];
    4. a hue shift in [-255, 255].

    Returns:
        A fresh ``iaa.Sequential`` augmenter.
    """
    geometric = [
        iaa.Fliplr(0.5),
        iaa.Flipud(0.5),
        iaa.Affine(
            scale={"x": (0.9, 1.1), "y": (0.9, 1.1)},
            translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)},
            mode='symmetric',
        ),
    ]
    photometric = [
        iaa.MultiplySaturation((0.2, 1.3)),
        iaa.AddToHue((-255, 255)),
    ]
    return iaa.Sequential(geometric + photometric)

# Preview the combined pipeline defined above.
plot_dataset(transform=no_shape_change(), start=start, n=n, m=m)