In [1]:
# Render matplotlib figures inline, below the cell that creates them.
%matplotlib inline

In [2]:
from keras.models import Model
from keras.callbacks import TensorBoard
from keras.models import load_model

import numpy as np

import matplotlib.pyplot as plt

from skimage.measure import compare_psnr
from skimage.util import view_as_windows, crop
from skimage.data import imread
from skimage.io import imshow
from skimage.measure import compare_psnr
from skimage import img_as_float

from sklearn.model_selection import train_test_split

from patchify import patchify, unpatchify

from network import network

Using TensorFlow backend.
  return f(*args, **kwds)


# Load data

In [3]:
# Archive holding the clean image stack (read below via its "images" key).
# For .npz files np.load returns a lazily-read NpzFile, not a plain array.
DATASET_PATH = "data/images.npz"
dataset = np.load(DATASET_PATH)

# Generate Gaussian noise

In [4]:
# Standard deviation of the zero-mean Gaussian noise added to every image.
# NOTE(review): 0.067 (~17/255) only makes sense if image intensities are in
# [0, 1]; confirm the dtype/range of the arrays stored in data/images.npz.
NOISE_LEVEL = 0.067

In [5]:
# Clean images: the ground truth that noise is added to and PSNR is measured against.
IMAGES_KEY = "images"
origins = dataset[IMAGES_KEY]

In [6]:
# Zero-mean Gaussian noise with std NOISE_LEVEL, one sample per pixel of every
# image (same shape as `origins`). A dedicated seeded generator makes the noise,
# and therefore the training data and the reported PSNR, reproducible under
# Restart & Run All. Previously the unseeded global RNG gave different data on every run.
NOISE_SEED = 0
noise_rng = np.random.RandomState(NOISE_SEED)
noises = noise_rng.normal(0, NOISE_LEVEL, origins.shape)

In [7]:
# Split clean images and their noise fields with the same shuffled indices, so
# each image keeps its own noise. train_test_split defaults to holding out 25%
# for test. A fixed random_state makes the split reproducible across kernel restarts.
SPLIT_SEED = 0
origin_train, origin_test, noise_train, noise_test = train_test_split(
    origins, noises, random_state=SPLIT_SEED
)

In [8]:
# Noisy observations: each clean image plus the noise field it was split with.
corrupted_train = np.add(origin_train, noise_train)
corrupted_test = np.add(origin_test, noise_test)

# Split into patches

In [9]:
# Patch geometry: square PATCH_SIZE x PATCH_SIZE windows taken every STEP pixels.
# With STEP much smaller than PATCH_SIZE, interior pixels land in roughly
# (PATCH_SIZE / STEP) ** 2 = 100 patches. The patch arrays are therefore about
# 100x larger than the images, which is the likely cause of the MemoryError below.
PATCH_SIZE = 50
STEP = 5
PATCH_WINDOW = (PATCH_SIZE, PATCH_SIZE)

# Shape of one image's patch grid, presumably (rows, cols, PATCH_SIZE, PATCH_SIZE).
# All images are assumed to share origins[0]'s size.
FULL_PATCH_MATRIX_SHAPE = patchify(origins[0], PATCH_WINDOW, STEP).shape
PATCH_MATRIX_SHAPE = FULL_PATCH_MATRIX_SHAPE[:2]


In [10]:
def patchify_batch(data, patch_size=None, step=None, dtype=None):
    """Cut every image in ``data`` into overlapping square patches and stack them.

    Each image goes through ``patchify`` with a ``(patch_size, patch_size)``
    window and stride ``step``. Each image's patch grid is flattened in row-major
    order, and the images are stacked one after another. The patches of
    ``data[0]`` therefore occupy the first rows of the result.

    Parameters
    ----------
    data : sequence of images
        Images to split, e.g. a stacked array indexed along its first axis.
        Presumably 2-D (grayscale), to match the 2-D patch window.
    patch_size : int, optional
        Side length of each square patch. Defaults to the global ``PATCH_SIZE``.
    step : int, optional
        Stride between neighbouring patches. Defaults to the global ``STEP``.
    dtype : numpy dtype, optional
        dtype of the result. Defaults to the common (promoted) dtype of the
        patch grids, matching the previous ``np.concatenate`` behaviour. Pass
        e.g. ``np.float32`` to halve memory use.

    Returns
    -------
    numpy.ndarray
        Shape ``(total_patches, patch_size, patch_size, 1)``. The trailing
        length-1 axis is presumably the channel axis the network expects.

    Raises
    ------
    ValueError
        If ``data`` contains no images.
    """
    if patch_size is None:
        patch_size = PATCH_SIZE
    if step is None:
        step = STEP
    window = (patch_size, patch_size)

    # NOTE(review): patchify is expected to return strided views (it wraps
    # skimage's view_as_windows). If so, holding all grids here costs almost no
    # memory, and the only real copy is the single write into `out` below.
    grids = [patchify(image, window, step) for image in data]
    if not grids:
        raise ValueError("patchify_batch() needs at least one image")

    counts = [grid.shape[0] * grid.shape[1] for grid in grids]
    if dtype is None:
        # Pairwise promotion avoids np.result_type's argument-count limit.
        dtype = grids[0].dtype
        for grid in grids[1:]:
            dtype = np.promote_types(dtype, grid.dtype)

    # Allocate the result once. The old per-image copies + np.concatenate kept
    # two full copies alive at peak, in a cell that already hits MemoryError.
    out = np.empty((sum(counts), patch_size, patch_size, 1), dtype=dtype)
    start = 0
    for grid, count in zip(grids, counts):
        # `out` is C-contiguous, so this slice reshapes to a view. The grid is
        # copied straight into place in the same row-major patch order that
        # grid.reshape(-1, patch_size, patch_size) would produce.
        out[start:start + count].reshape(grid.shape)[...] = grid
        start += count
    return out

In [11]:
# Inputs: noisy patches. Targets: the matching noise patches. This is residual
# learning, so a clean patch is recovered as input - predicted noise.
#
# X and Y are deliberately kept in the same units as each other and as the clean
# images. Dividing X by its own maximum while leaving Y unscaled causes three problems:
#   - it breaks X - Y == clean, because the network would see rescaled images
#     but be trained against unscaled noise;
#   - train and test would get different scale factors;
#   - the reconstruction/PSNR cells would compare rescaled output against the
#     unscaled originals.
# If rescaling is needed, apply ONE factor computed on the training set to X and
# Y alike, and undo it before comparing with the originals.
# Skipping the division also avoids a second full-size copy of the patch arrays,
# which matters here because this cell has hit a MemoryError.
trainX = patchify_batch(corrupted_train)
testX = patchify_batch(corrupted_test)
trainY = patchify_batch(noise_train)
testY = patchify_batch(noise_test)

MemoryError: 

# Build network

In [None]:
# Input/output tensors of the denoising network defined in network.py.
# NOTE(review): `input` shadows the Python builtin of the same name.
# NOTE(review): if network(30, 30)'s arguments are the patch height/width, they
# do not match PATCH_SIZE (50). Confirm against network.py.
network_io = network(30, 30)
input, output = network_io

In [None]:
# Plain SGD on the mean squared error between predicted and true noise patches.
model = Model(inputs=input, outputs=output)
model.compile(loss='mean_squared_error', optimizer='sgd')

# Train model

In [None]:
# TensorBoard event files for this run. Change the run directory to keep runs apart.
TENSORBOARD_LOG_DIR = './logs/run7'
board = TensorBoard(log_dir=TENSORBOARD_LOG_DIR)

In [None]:
# Train on (noisy patch -> noise patch) pairs. verbose=0 keeps stdout quiet, and
# progress is logged through the TensorBoard callback instead.
EPOCHS = 20
BATCH_SIZE = 512
model.fit(
    x=trainX,
    y=trainY,
    validation_data=(testX, testY),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    shuffle=True,
    verbose=0,
    callbacks=[board],
)

In [None]:
# Persist the trained model (architecture + weights) to HDF5.
# NOTE(review): '*' is not a valid filename character on Windows and is a shell
# glob character. Consider 'x' instead, and check any load_model callers first.
MODEL_PATH = "model-30*30-20-30.h5"
model.save(MODEL_PATH)

In [None]:
# patchify_batch stacks images in order, so the first test image's patches come
# first in testX. There are rows * cols of them.
patch_rows, patch_cols = PATCH_MATRIX_SHAPE
first_image_indices = patch_rows * patch_cols
first_image_patches = testX[:first_image_indices]

# The model was trained to predict the noise, so a denoised patch is
# input minus predicted noise.
predY = model.predict(first_image_patches)

# Stitch the overlapping patches back into whole images for display / PSNR.
# These are in testX's units, so they are only comparable with `original` if
# testX was not rescaled.
recovered = unpatchify((first_image_patches - predY).reshape(FULL_PATCH_MATRIX_SHAPE), step=STEP)
corrupted = unpatchify(first_image_patches.reshape(FULL_PATCH_MATRIX_SHAPE), step=STEP)
original = origin_test[0]

In [None]:
# Side-by-side comparison of the first test image.
# All panels share the original's intensity range, so brightness/contrast
# differences between them stay visible. Otherwise imshow autoscales each panel
# independently, which can make a poor reconstruction look as good as the original.
# There is no legend: nothing is labelled, so plt.legend() only produced an
# empty legend and a "no handles with labels" warning.
panels = [("original", original), ("recovered", recovered), ("corrupted", corrupted)]
shared_vmin, shared_vmax = original.min(), original.max()
fig, axes = plt.subplots(1, 3, figsize=(18, 9))
for ax, (title, image) in zip(axes, panels):
    ax.imshow(image, cmap="gray", vmin=shared_vmin, vmax=shared_vmax)
    ax.set_title(title)

In [None]:
# PSNR of the corrupted and recovered images against the clean original, over
# the area covered by the stitched patches. The crop presumably drops any border
# that the patch grid does not reach.
r_h, r_w = recovered.shape
clean_crop = original[:r_h, :r_w]
psnr_corr = compare_psnr(clean_crop, corrupted[:r_h, :r_w])
psnr_reco = compare_psnr(clean_crop, recovered[:r_h, :r_w])
print(f"psnr corrupted: {psnr_corr}, recovered: {psnr_reco}")