In [1]:
import numpy as np
import os
import denoise_model
import tensorflow as tf
import keras
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import EarlyStopping, ModelCheckpoint

Using TensorFlow backend.


Segmentation Models: using `keras` framework.


### Data preprocessing (CIFAR-10)

In [2]:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

In [3]:
datagen = ImageDataGenerator(
    featurewise_center=False,
    samplewise_center=False,
    featurewise_std_normalization=False,
    samplewise_std_normalization=False,
    zca_whitening=False,
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    vertical_flip=False
)
datagen.fit(x_train)

### Train validation model

In [4]:
nb_epochs = 200
batch_size = 256

In [5]:
def train_val_model(model, denoise_model, denoise_model_path):
    if not os.path.exists(denoise_model_path):
        test_noise = np.random.normal(loc=0, scale=0.1, size=x_test.shape)
        x_test_noise = np.clip(x_test + test_noise, 0, 1)
        min_loss = np.finfo(np.float32).max
        patience = 0
        max_patience = 30
        for e in range(nb_epochs):
            batches = 0
            for x_batch, y_batch in datagen.flow(x_train, y_train, batch_size=batch_size):
                train_noise = np.random.normal(loc=0, scale=0.1, size=x_batch.shape)
                x_batch_noise = np.clip(x_batch + train_noise, 0, 1)
                denoise_model.fit(x_batch_noise, y_batch, verbose=0)
                model.fit(x_batch, y_batch, verbose=0)
                batches += 1
                if batches >= len(x_train) // batch_size: break
            denoise_model_loss = denoise_model.evaluate(x_test_noise, y_test)
            orig_model_loss = model.evaluate(x_test, y_test)
            loss = denoise_model_loss + orig_model_loss
            print('epoch ', e, ', loss: ', loss)
            if loss < min_loss:
                min_loss = loss
                denoise_model.save_weights(denoise_model_path)
                patience = 0
            else: patience += 1
            if patience > max_patience:
                break

In [6]:
dm_gaussian_blur_path = 'model/dm_gaussian_blur.h5'
dm_median_blur_path = 'model/dm_median_blur.h5'
dm_nl_means_path = 'model/dm_nl_means.h5'
dm_dae_resnet_model_resnet_path = 'model/dm_dae_resnet_model_resnet.h5'
dm_dae_vgg_model_resnet_path = 'model/dm_dae_vgg_model_resnet.h5'
dm_dae_resnet_model_vgg_path = 'model/dm_dae_resnet_model_vgg.h5'
dm_dae_vgg_model_vgg_path = 'model/dm_dae_vgg_model_vgg.h5'
dae_resnet_path = 'model/DAE_resnet.h5'
dae_vgg_path = 'model/DAE_vgg.h5'

dm_gaussian_blur = denoise_model.DM_GAUSSIAN_BLUR()
dm_median_blur = denoise_model.DM_MEDIAN_BLUR()
dm_nl_means = denoise_model.DM_NL_MEANS()
dm_dae_resnet_model_resnet = denoise_model.DM_DAE_RESNET()
dm_dae_vgg_model_resnet = denoise_model.DM_DAE_VGG()
dm_dae_resnet_model_vgg = denoise_model.DM_DAE_RESNET(model="vgg16")
dm_dae_vgg_model_vgg = denoise_model.DM_DAE_VGG(model="vgg16")

In [7]:
loss_object = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

dm_dae_resnet_model_resnet.denoiser.trainable = False
dm_dae_vgg_model_resnet.denoiser.trainable = False
dm_dae_resnet_model_vgg.denoiser.trainable = False
dm_dae_vgg_model_vgg.denoiser.trainable = False
dm_dae_resnet_model_resnet.denoiser.compile(loss='mse', optimizer='adam')
dm_dae_vgg_model_resnet.denoiser.compile(loss='mse', optimizer='adam')
dm_dae_resnet_model_vgg.denoiser.compile(loss='mse', optimizer='adam')
dm_dae_vgg_model_vgg.denoiser.compile(loss='mse', optimizer='adam')

dm_gaussian_blur.model.compile(loss=loss_object, optimizer='adam')
dm_median_blur.model.compile(loss=loss_object, optimizer='adam')
dm_nl_means.model.compile(loss=loss_object, optimizer='adam')
dm_dae_resnet_model_resnet.model.compile(loss=loss_object, optimizer='adam')
dm_dae_vgg_model_resnet.model.compile(loss=loss_object, optimizer='adam')
dm_dae_resnet_model_vgg.model.compile(loss=loss_object, optimizer='adam')
dm_dae_vgg_model_vgg.model.compile(loss=loss_object, optimizer='adam')

dm_gaussian_blur.denoise_model.compile(loss=loss_object, optimizer='adam')
dm_median_blur.denoise_model.compile(loss=loss_object, optimizer='adam')
dm_nl_means.denoise_model.compile(loss=loss_object, optimizer='adam')
dm_dae_resnet_model_resnet.denoise_model.compile(loss=loss_object, optimizer='adam')
dm_dae_vgg_model_resnet.denoise_model.compile(loss=loss_object, optimizer='adam')
dm_dae_resnet_model_vgg.denoise_model.compile(loss=loss_object, optimizer='adam')
dm_dae_vgg_model_vgg.denoise_model.compile(loss=loss_object, optimizer='adam')

dm_dae_resnet_model_resnet.denoiser.load_weights(dae_resnet_path)
dm_dae_vgg_model_resnet.denoiser.load_weights(dae_vgg_path)
dm_dae_resnet_model_vgg.denoiser.load_weights(dae_resnet_path)
dm_dae_vgg_model_vgg.denoiser.load_weights(dae_vgg_path)

In [8]:
train_val_model(dm_gaussian_blur.model, dm_gaussian_blur.denoise_model, dm_gaussian_blur_path)

In [9]:
train_val_model(dm_median_blur.model, dm_median_blur.denoise_model, dm_median_blur_path)

In [10]:
train_val_model(dm_dae_resnet_model_resnet.model, dm_dae_resnet_model_resnet.denoise_model, dm_dae_resnet_model_resnet_path)

In [11]:
train_val_model(dm_dae_vgg_model_resnet.model, dm_dae_vgg_model_resnet.denoise_model, dm_dae_vgg_model_resnet_path)

In [12]:
train_val_model(dm_dae_resnet_model_vgg.model, dm_dae_resnet_model_vgg.denoise_model, dm_dae_resnet_model_vgg_path)

In [13]:
train_val_model(dm_dae_vgg_model_vgg.model, dm_dae_vgg_model_vgg.denoise_model, dm_dae_vgg_model_vgg_path)

epoch  0 , loss:  2.711516697502136
epoch  1 , loss:  2.3426033720016477
epoch  2 , loss:  2.1635936340332034
epoch  3 , loss:  2.0398984000205993
epoch  4 , loss:  1.8155785824775696
epoch  5 , loss:  1.727113952922821
epoch  6 , loss:  1.741789979839325
epoch  7 , loss:  1.664573510313034
epoch  8 , loss:  1.6160442505836488
epoch  9 , loss:  1.669359286594391
epoch  10 , loss:  1.6750792260169982
epoch  11 , loss:  1.5349127403259277
epoch  12 , loss:  1.521119339799881
epoch  13 , loss:  2.3936730996608735
epoch  14 , loss:  1.7558628487586976
epoch  15 , loss:  1.613863602399826
epoch  16 , loss:  1.5558285702705383
epoch  17 , loss:  1.6247297393798827
epoch  18 , loss:  1.6057055946350096
epoch  19 , loss:  1.7788822031021119
epoch  20 , loss:  1.690976013803482
epoch  21 , loss:  4.605265830993653


KeyboardInterrupt: 