In [1]:
import numpy as np
import os
import denoise_model
import keras
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import EarlyStopping, ModelCheckpoint

Using TensorFlow backend.


### Data preprocessing (MNIST)

In [2]:
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# The reshape below uses this one side length for both height and width,
# i.e. it assumes square images.
image_size = x_train.shape[1]
# Scale pixel values into [0, 1] as float32 and append a single channel axis:
# (N, image_size, image_size) -> (N, image_size, image_size, 1).
x_train, x_test = (
    np.reshape(images.astype('float32') / 255, [-1, image_size, image_size, 1])
    for images in (x_train, x_test)
)

In [3]:
# Mild geometric augmentation used by the manual training loop in train_val_model:
# random rotations of up to 8 degrees and shifts of up to 10% of width/height.
# All centering / normalization / whitening / flip options are left at their
# Keras defaults (off), so they are not spelled out here.
datagen = ImageDataGenerator(
    rotation_range=8,
    width_shift_range=0.1,
    height_shift_range=0.1,
)
# fit() only computes dataset statistics for featurewise_center,
# featurewise_std_normalization and zca_whitening. With all three off it would
# merely copy x_train, so only call it when one of them is enabled.
if datagen.featurewise_center or datagen.featurewise_std_normalization or datagen.zca_whitening:
    datagen.fit(x_train)

### Train and validate the denoising models

In [4]:
# Upper bound on training epochs (early stopping may end training sooner)
# and the number of samples per mini-batch.
nb_epochs, batch_size = 200, 256

In [5]:
def train_val_model(model, denoise_model, denoise_model_path, noise_std=0.1, max_patience=30):
    """Jointly train two models with a manual loop, checkpointing ``denoise_model``.

    Does nothing if ``denoise_model_path`` already exists. Otherwise, for up to
    ``nb_epochs`` epochs, it draws ``len(x_train) // batch_size`` augmented batches
    (at least one) from the module-level ``datagen``. For each batch,
    ``denoise_model`` takes one update on a Gaussian-noised copy and ``model``
    takes one update on the clean batch.

    After each epoch both models are evaluated: ``denoise_model`` on a fixed noisy
    copy of ``x_test``, ``model`` on the clean ``x_test``. The weights of
    ``denoise_model`` are saved whenever the summed loss reaches a new minimum.

    Reads the module-level globals ``x_train``, ``y_train``, ``x_test``, ``y_test``,
    ``datagen``, ``nb_epochs`` and ``batch_size``.

    Args:
        model: compiled Keras model trained on clean images.
        denoise_model: compiled Keras model trained on noisy images; its weights
            are the ones written to ``denoise_model_path``.
        denoise_model_path: weights file; missing parent directories are created.
        noise_std: standard deviation of the additive Gaussian pixel noise.
        max_patience: stop after this many consecutive epochs without a new
            minimum (same counting as ``EarlyStopping(patience=...)``).
    """
    if os.path.exists(denoise_model_path):
        return

    def first_loss(result):
        # evaluate() returns a bare scalar when the model was compiled without
        # metrics, but a [loss, metric, ...] list when metrics were given.
        # Adding the raw results (list + float) would raise TypeError.
        return float(np.atleast_1d(result)[0])

    os.makedirs(os.path.dirname(denoise_model_path) or '.', exist_ok=True)
    # One fixed noisy test set keeps the per-epoch losses comparable.
    test_noise = np.random.normal(loc=0, scale=noise_std, size=x_test.shape)
    x_test_noise = np.clip(x_test + test_noise, 0, 1)
    steps_per_epoch = len(x_train) // batch_size
    min_loss = np.finfo(np.float32).max
    patience = 0
    for epoch in range(nb_epochs):
        batches = datagen.flow(x_train, y_train, batch_size=batch_size)
        # flow() keeps yielding batches, so stop after steps_per_epoch of them.
        # The check runs after training, so at least one batch is always used.
        for step, (x_batch, y_batch) in enumerate(batches, start=1):
            train_noise = np.random.normal(loc=0, scale=noise_std, size=x_batch.shape)
            x_batch_noise = np.clip(x_batch + train_noise, 0, 1)
            # train_on_batch applies exactly one update to the whole batch; fit()
            # would re-split it into mini-batches of its default size (32).
            denoise_model.train_on_batch(x_batch_noise, y_batch)
            model.train_on_batch(x_batch, y_batch)
            if step >= steps_per_epoch:
                break
        denoise_model_loss = first_loss(denoise_model.evaluate(x_test_noise, y_test, verbose=0))
        orig_model_loss = first_loss(model.evaluate(x_test, y_test, verbose=0))
        loss = denoise_model_loss + orig_model_loss
        print('epoch ', epoch, ', loss: ', loss)
        if loss < min_loss:
            min_loss = loss
            denoise_model.save_weights(denoise_model_path)
            patience = 0
        else:
            patience += 1
            if patience >= max_patience:
                break

In [6]:
def train_val_model_ver2(model, denoise_model, denoise_model_path, noise_std=0.1, max_patience=30):
    """Train ``denoise_model`` on noisy images with ``fit()`` and Keras callbacks.

    Does nothing if ``denoise_model_path`` already exists. Otherwise:

    - Gaussian noise (std ``noise_std``, result clipped to [0, 1]) is added once
      to ``x_train`` and once to ``x_test``.
    - ``denoise_model`` is fit for up to ``nb_epochs`` epochs with ``batch_size``.
    - ``EarlyStopping`` halts training after ``max_patience`` epochs without a
      better ``val_accuracy``.
    - ``ModelCheckpoint`` writes the best model (by ``val_accuracy``) to
      ``denoise_model_path``.

    ``model`` is not used; it is kept so the signature matches ``train_val_model``.

    NOTE(review): the noisy test set doubles as validation data for early stopping
    and checkpoint selection, so test accuracy of the saved model is optimistic.
    EarlyStopping is used without ``restore_best_weights``, so after fit() returns
    the in-memory ``denoise_model`` holds the last epoch's weights. The best
    weights exist only in the checkpoint file.

    Reads the module-level globals ``x_train``, ``y_train``, ``x_test``, ``y_test``,
    ``nb_epochs`` and ``batch_size``.
    """
    if os.path.exists(denoise_model_path):
        return
    # Make sure the checkpoint directory (e.g. 'model/') exists before training.
    os.makedirs(os.path.dirname(denoise_model_path) or '.', exist_ok=True)

    def add_noise(images):
        noise = np.random.normal(loc=0, scale=noise_std, size=images.shape)
        return np.clip(images + noise, 0, 1)

    x_train_noise = add_noise(x_train)
    x_test_noise = add_noise(x_test)
    # 'val_accuracy' is the metric name in Keras >= 2.3 / tf.keras
    # (older Keras reported it as 'val_acc').
    callbacks = [
        EarlyStopping(monitor='val_accuracy', patience=max_patience),
        ModelCheckpoint(denoise_model_path, monitor='val_accuracy', save_best_only=True),
    ]
    denoise_model.fit(x_train_noise, y_train, epochs=nb_epochs, batch_size=batch_size,
                      validation_data=(x_test_noise, y_test), callbacks=callbacks)

In [7]:
# Checkpoint files, all under the relative 'model/' directory.
# The dm_*.h5 paths are passed as denoise_model_path to the training helpers;
# DAE.h5 is loaded into the DM_DAE denoiser.
dm_gaussian_blur_path, dm_median_blur_path, dm_nl_means_path, dm_dae_path = (
    'model/dm_gaussian_blur.h5',
    'model/dm_median_blur.h5',
    'model/dm_nl_means.h5',
    'model/dm_dae.h5',
)
dae_path = 'model/DAE.h5'

# One wrapper per denoising front end, presumably Gaussian blur, median blur,
# non-local means and a denoising autoencoder (see the denoise_model module).
# Each exposes .model and .denoise_model; DM_DAE also exposes .denoiser.
dm_gaussian_blur, dm_median_blur, dm_nl_means, dm_dae = (
    denoise_model.DM_GAUSSIAN_BLUR(),
    denoise_model.DM_MEDIAN_BLUR(),
    denoise_model.DM_NL_MEANS(),
    denoise_model.DM_DAE(),
)

In [8]:
# NOTE(review): from_logits=True assumes the classifiers output unnormalized
# logits (no final softmax) — TODO confirm against the denoise_model module.
loss_object = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Freeze the DAE denoiser before the wrapper models are compiled; Keras applies
# changes to `trainable` at compile time. It is compiled on its own with MSE.
dm_dae.denoiser.trainable = False
dm_dae.denoiser.compile(loss='mse', optimizer='adam')

_dm_wrappers = (dm_gaussian_blur, dm_median_blur, dm_nl_means, dm_dae)
# The string optimizer='adam' gives every model its own fresh Adam instance.
for _dm in _dm_wrappers:
    _dm.model.compile(loss=loss_object, optimizer='adam')
for _dm in _dm_wrappers:
    # With an accuracy metric, evaluate() on these models returns [loss, accuracy].
    _dm.denoise_model.compile(loss=loss_object, optimizer='adam', metrics=['accuracy'])

dm_dae.denoiser.load_weights(dae_path)

In [9]:
train_val_model_ver2(
    model=dm_gaussian_blur.model,
    denoise_model=dm_gaussian_blur.denoise_model,
    denoise_model_path=dm_gaussian_blur_path,
)

Train on 60000 samples, validate on 10000 samples
Epoch 1/200
Epoch 2/200
Epoch 3/200
Epoch 4/200
Epoch 5/200
Epoch 6/200
Epoch 7/200
Epoch 8/200
Epoch 9/200
Epoch 10/200
Epoch 11/200
Epoch 12/200
Epoch 13/200
Epoch 14/200
Epoch 15/200
Epoch 16/200
Epoch 17/200
Epoch 18/200
Epoch 19/200
Epoch 20/200
Epoch 21/200
Epoch 22/200
Epoch 23/200
Epoch 24/200
Epoch 25/200
Epoch 26/200
Epoch 27/200
Epoch 28/200
Epoch 29/200
Epoch 30/200
Epoch 31/200
Epoch 32/200
Epoch 33/200
Epoch 34/200
Epoch 35/200
Epoch 36/200
Epoch 37/200
Epoch 38/200
Epoch 39/200
Epoch 40/200
Epoch 41/200
Epoch 42/200
Epoch 43/200
Epoch 44/200
Epoch 45/200
Epoch 46/200
Epoch 47/200
Epoch 48/200
Epoch 49/200
Epoch 50/200
Epoch 51/200
Epoch 52/200
Epoch 53/200
Epoch 54/200
Epoch 55/200


Epoch 56/200
Epoch 57/200
Epoch 58/200
Epoch 59/200
Epoch 60/200
Epoch 61/200
Epoch 62/200
Epoch 63/200
Epoch 64/200
Epoch 65/200
Epoch 66/200
Epoch 67/200
Epoch 68/200
Epoch 69/200
Epoch 70/200
Epoch 71/200
Epoch 72/200
Epoch 73/200
Epoch 74/200
Epoch 75/200
Epoch 76/200
Epoch 77/200
Epoch 78/200
Epoch 79/200
Epoch 80/200
Epoch 81/200
Epoch 82/200
Epoch 83/200
Epoch 84/200
Epoch 85/200
Epoch 86/200
Epoch 87/200
Epoch 88/200
Epoch 89/200
Epoch 90/200
Epoch 91/200
Epoch 92/200
Epoch 93/200
Epoch 94/200
Epoch 95/200
Epoch 96/200
Epoch 97/200
Epoch 98/200
Epoch 99/200


In [10]:
train_val_model_ver2(
    model=dm_median_blur.model,
    denoise_model=dm_median_blur.denoise_model,
    denoise_model_path=dm_median_blur_path,
)

Train on 60000 samples, validate on 10000 samples
Epoch 1/200
Epoch 2/200
Epoch 3/200
Epoch 4/200
Epoch 5/200
Epoch 6/200
Epoch 7/200
Epoch 8/200
Epoch 9/200
Epoch 10/200
Epoch 11/200
Epoch 12/200
Epoch 13/200
Epoch 14/200
Epoch 15/200
Epoch 16/200
Epoch 17/200
Epoch 18/200
Epoch 19/200
Epoch 20/200
Epoch 21/200
Epoch 22/200
Epoch 23/200
Epoch 24/200
Epoch 25/200
Epoch 26/200
Epoch 27/200
Epoch 28/200
Epoch 29/200
Epoch 30/200
Epoch 31/200
Epoch 32/200
Epoch 33/200
Epoch 34/200
Epoch 35/200
Epoch 36/200
Epoch 37/200
Epoch 38/200
Epoch 39/200
Epoch 40/200
Epoch 41/200
Epoch 42/200
Epoch 43/200
Epoch 44/200
Epoch 45/200
Epoch 46/200
Epoch 47/200
Epoch 48/200
Epoch 49/200
Epoch 50/200
Epoch 51/200
Epoch 52/200
Epoch 53/200
Epoch 54/200


Epoch 55/200
Epoch 56/200
Epoch 57/200
Epoch 58/200
Epoch 59/200
Epoch 60/200
Epoch 61/200
Epoch 62/200
Epoch 63/200
Epoch 64/200
Epoch 65/200
Epoch 66/200


In [11]:
train_val_model_ver2(
    model=dm_dae.model,
    denoise_model=dm_dae.denoise_model,
    denoise_model_path=dm_dae_path,
)

Train on 60000 samples, validate on 10000 samples
Epoch 1/200
Epoch 2/200
Epoch 3/200
Epoch 4/200
Epoch 5/200
Epoch 6/200
Epoch 7/200
Epoch 8/200
Epoch 9/200
Epoch 10/200
Epoch 11/200
Epoch 12/200
Epoch 13/200
Epoch 14/200
Epoch 15/200
Epoch 16/200
Epoch 17/200
Epoch 18/200
Epoch 19/200
Epoch 20/200
Epoch 21/200
Epoch 22/200
Epoch 23/200
Epoch 24/200
Epoch 25/200
Epoch 26/200
Epoch 27/200
Epoch 28/200
Epoch 29/200
Epoch 30/200
Epoch 31/200
Epoch 32/200
Epoch 33/200
Epoch 34/200
Epoch 35/200
Epoch 36/200
Epoch 37/200
Epoch 38/200
Epoch 39/200
Epoch 40/200
Epoch 41/200
Epoch 42/200
Epoch 43/200
Epoch 44/200
Epoch 45/200
Epoch 46/200
Epoch 47/200
Epoch 48/200
Epoch 49/200
Epoch 50/200
Epoch 51/200
Epoch 52/200
Epoch 53/200
Epoch 54/200
Epoch 55/200


Epoch 56/200
Epoch 57/200
Epoch 58/200
Epoch 59/200
Epoch 60/200
Epoch 61/200
Epoch 62/200
Epoch 63/200
Epoch 64/200
Epoch 65/200
Epoch 66/200
