In [1]:
%matplotlib inline
from keras.layers import Conv2D, Conv2DTranspose, MaxPool2D, UpSampling2D, BatchNormalization, Lambda, Input, Concatenate, Dropout
from keras.callbacks import ReduceLROnPlateau, EarlyStopping
from keras.models import Sequential, Model
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
from keras import layers
from keras import optimizers
import matplotlib.pyplot as plt
import numpy as np
import cv2

Using TensorFlow backend.


In [2]:
# Load dataset
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Map data to floats in [0, 1]
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Add Gaussian noise to data
noise_factor = 0.1
NOISE_SEED = 0  # fixed so the noisy dataset is reproducible across kernel restarts
noise_rng = np.random.RandomState(NOISE_SEED)


def add_gaussian_noise(images, factor, rng):
    """Return ``images + factor * N(0, 1)`` clamped to [0, 1], keeping the images' dtype.

    Args:
        images: float array of pixel values in [0, 1].
        factor: standard deviation of the additive noise.
        rng: ``np.random.RandomState`` used to draw the noise.

    The noise is cast to ``images.dtype`` before adding: NumPy draws float64
    samples, which would otherwise promote the whole float32 array to float64
    (doubling memory, ~1.2 GB for the CIFAR-10 training set).
    """
    noise = factor * rng.standard_normal(images.shape).astype(images.dtype)
    # Clamp noisy data back into the valid pixel range
    return np.clip(images + noise, 0.0, 1.0)


x_train_noise = add_gaussian_noise(x_train, noise_factor, noise_rng)
x_test_noise = add_gaussian_noise(x_test, noise_factor, noise_rng)

In [4]:
# Amount of information passed through each Direct Symmetric Connection (DSC).
# In every decoder stage the matching encoder feature map (the skip branch) is
# multiplied by `gating_factor` and then concatenated with the deconv output:
#   1.0 = skip features passed through unchanged (full DSC)
#   0.0 = skip features zeroed (decoder effectively sees only the deconv output)
gating_factor = 1.0


def encoder_block(x, filters):
    """Stride-2 4x4 convolution (halves height/width) with ReLU, then batch norm."""
    x = Conv2D(filters, 4, strides=(2, 2), activation='relu', padding='same')(x)
    return BatchNormalization()(x)


def decoder_block(x, skip, filters, gate):
    """Stride-2 4x4 transposed convolution (doubles height/width) with ReLU,
    concatenated with the gated encoder feature map `skip`, then batch norm.

    `gate` is bound through Lambda's `arguments` rather than captured from a
    global, so changing `gating_factor` later cannot silently alter the layer.
    """
    x = Conv2DTranspose(filters, 4, strides=(2, 2), activation='relu', padding='same')(x)
    gated_skip = Lambda(lambda t, g: t * g, arguments={'g': gate})(skip)
    x = Concatenate()([x, gated_skip])
    return BatchNormalization()(x)


# Input
inputs = Input(shape=x_train.shape[1:])

# Encoder: five stride-2 stages
conv_1 = encoder_block(inputs, 64)
conv_2 = encoder_block(conv_1, 64)
conv_3 = encoder_block(conv_2, 128)
conv_4 = encoder_block(conv_3, 128)
conv_5 = encoder_block(conv_4, 256)

# Decoder: each stage is joined with the encoder stage of the same resolution
deconv_1 = decoder_block(conv_5, conv_4, 128, gating_factor)
deconv_2 = decoder_block(deconv_1, conv_3, 128, gating_factor)
deconv_3 = decoder_block(deconv_2, conv_2, 64, gating_factor)
deconv_4 = decoder_block(deconv_3, conv_1, 64, gating_factor)
deconv_4 = Dropout(0.25)(deconv_4)

# Output: back to 3 channels, sigmoid keeps pixels in [0, 1] like the targets
deconv_5 = Conv2DTranspose(3, 4, strides=(2, 2), activation='sigmoid', padding='same')(deconv_4)

model = Model(inputs=inputs, outputs=deconv_5)

model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            (None, 32, 32, 3)    0                                            
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 16, 16, 64)   3136        input_2[0][0]                    
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 16, 16, 64)   256         conv2d_6[0][0]                   
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 8, 8, 64)     65600       batch_normalization_10[0][0]     
__________________________________________________________________________________________________
batch_norm

In [5]:
# --- Training hyper-parameters ---
# A single pass is used on purpose: the author observed that training this
# model for more than one epoch overfits CIFAR-10.
EPOCHS = 1
BATCH_SIZE = 128  # samples per gradient update
VAL_SPLIT = 0.1   # fraction of the training data held out for validation

In [6]:
# Compile model with Adam optimizer (Keras default hyper-parameters).
# `optimizers.Adam` is used rather than the lowercase `optimizers.adam` alias,
# which later Keras / tf.keras versions removed.
# Binary cross-entropy acts as a per-pixel reconstruction loss here: targets are
# clean images scaled to [0, 1] and the output layer is a sigmoid.
model.compile(loss='binary_crossentropy', optimizer=optimizers.Adam())

In [8]:
# Fit the denoiser: noisy images are the inputs, the clean originals the targets.
# The returned History object is kept in `H` for the loss plot below.
H = model.fit(
    x=x_train_noise,
    y=x_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_split=VAL_SPLIT,
)

Train on 45000 samples, validate on 5000 samples
Epoch 1/1


In [None]:
# Plot per-epoch training and validation loss.
# Markers are drawn so that a single-epoch run (EPOCHS = 1) still shows its one
# point; a plain line through a single point renders as nothing.
loss_history = H.history['loss']
val_loss_history = H.history['val_loss']
epoch_numbers = range(1, len(loss_history) + 1)  # 1-based, matching Keras' "Epoch 1/N"

fig, ax = plt.subplots()
ax.plot(epoch_numbers, loss_history, color='blue', marker='o', label='train loss')
ax.plot(epoch_numbers, val_loss_history, color='red', marker='o', label='val loss')
ax.set(title='Denoising autoencoder loss', xlabel='Epoch', ylabel='Loss')
ax.legend();

In [11]:
# Define result display constants
DIMS = (256, 256)   # on-screen size of each panel, passed to cv2.resize
INTERP = cv2.INTER_CUBIC
QUIT_KEY = ord('q')


def upscale(image):
    """Resize a single image to DIMS for easier viewing."""
    return cv2.resize(image, DIMS, interpolation=INTERP)


# Interactive viewer: shows [original | noisy | reconstructed] for a random test
# image in an OpenCV window. Any key shows a new image; "q" quits.
# NOTE: cv2.imshow needs a desktop GUI, and this cell blocks the kernel until "q".
while True:
    # Pick a random index into test set
    i = np.random.randint(x_test.shape[0])

    # Retrieve original image, noisy image, and label using index
    orig_image = x_test[i]
    noise_image = x_test_noise[i]
    true_label = int(np.squeeze(y_test[i]))  # plain int, not a 1-element array repr

    # Reconstruct the clean image from the noisy one (predict on a batch of one)
    recon_image = model.predict(noise_image[np.newaxis])[0]

    # Join upscaled original, noisy, and reconstructed images; [0, 1] floats -> 8-bit
    combo_image = np.hstack([upscale(orig_image), upscale(noise_image), upscale(recon_image)])
    combo_image = np.uint8(np.clip(combo_image * 255, 0, 255))
    # Images are RGB; OpenCV displays BGR
    combo_image = cv2.cvtColor(combo_image, cv2.COLOR_RGB2BGR)

    cv2.imshow('label {}'.format(true_label), combo_image)
    # Mask to the low byte: on some platforms waitKey returns extra modifier bits,
    # which would make an unmasked comparison with ord('q') never match.
    if (cv2.waitKey(0) & 0xFF) == QUIT_KEY:
        break
    cv2.destroyAllWindows()

cv2.destroyAllWindows()