In [1]:
import numpy as np
import os
import tensorflow as tf
import keras
import segmentation_models as sm
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import EarlyStopping, ModelCheckpoint

Using TensorFlow backend.


Segmentation Models: using `keras` framework.


### Data preprocessing (CIFAR-10)

In [2]:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# Cast both image splits to float32 and divide by 255; assuming 8-bit pixel
# values, this puts every image in [0, 1] (labels are left untouched).
x_train, x_test = (split.astype('float32') / 255 for split in (x_train, x_test))

In [3]:
# On-the-fly augmentation for DAE training: small random rotations and shifts
# plus horizontal flips. Centering, std-normalization, ZCA whitening and
# vertical flips stay at their Keras defaults (disabled).
datagen = ImageDataGenerator(
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
)
# datagen.fit(x_train) is intentionally not called: it is only required when
# featurewise_center, featurewise_std_normalization or zca_whitening is
# enabled, and here it would only make a full copy of x_train. Re-add it if
# any of those options is turned on.

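### Denoising autoencoder (DAE) training

Two U-Net denoising autoencoders, one with a ResNet-18 encoder and one with a VGG-16 encoder, are trained. Each learns to reconstruct clean CIFAR-10 images from copies corrupted with Gaussian noise (std 0.1) and clipped to [0, 1]. The best weights of each model are saved under `model/`.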

In [4]:
# Training configuration.
nb_epochs = 200  # maximum number of epochs per model
batch_size = 256

# Both DAEs use the same U-Net head on 32x32x3 inputs and differ only in the
# encoder backbone; 3 sigmoid output channels give an RGB reconstruction in (0, 1).
input_shape = (32, 32, 3)
DAE_resnet, DAE_vgg = (
    sm.Unet(backbone, input_shape=input_shape, classes=3, activation='sigmoid')
    for backbone in ('resnet18', 'vgg16')
)
DAE_resnet_path, DAE_vgg_path = 'model/DAE_resnet.h5', 'model/DAE_vgg.h5'

# Pixel-wise MSE reconstruction loss, Adam with default settings.
DAE_resnet.compile(optimizer='adam', loss='mse')
DAE_vgg.compile(optimizer='adam', loss='mse')



In [5]:
def _add_gaussian_noise(images, std):
    """Return a float32 copy of `images` plus N(0, std**2) noise, clipped to [0, 1]."""
    noise = np.random.normal(loc=0, scale=std, size=images.shape)
    return np.clip(images + noise, 0, 1).astype('float32')


def train_DAE(DAE, DAE_path, noise_std=0.1, max_patience=30):
    """Train a denoising autoencoder, or load its cached weights.

    Reads the notebook globals `x_train`, `y_train`, `x_test`, `datagen`,
    `nb_epochs` and `batch_size`.

    If `DAE_path` already exists, its weights are loaded into `DAE` and no
    training happens. Otherwise, each epoch draws `len(x_train) // batch_size`
    (at least one) augmented batches from `datagen`. Each batch is corrupted
    with Gaussian noise, and one gradient step is taken on it (noisy input ->
    clean target). After every epoch the loss on a fixed noisy copy of `x_test`
    is printed. Whenever it improves, the weights are written to `DAE_path`.
    Training stops once the loss has failed to improve for more than
    `max_patience` consecutive epochs. On return `DAE` holds the best saved
    weights, not the last-epoch ones.

    Args:
        DAE: compiled Keras model mapping noisy images to clean images.
        DAE_path: file the best weights are saved to / loaded from.
        noise_std: standard deviation of the additive Gaussian noise.
        max_patience: number of consecutive non-improving epochs tolerated.
    """
    if os.path.exists(DAE_path):
        # Cached run: load the weights, otherwise the freshly built model
        # would silently remain untrained.
        DAE.load_weights(DAE_path)
        return

    save_dir = os.path.dirname(DAE_path)
    if save_dir:
        # save_weights does not create missing parent directories.
        os.makedirs(save_dir, exist_ok=True)

    # One fixed noisy validation set so losses are comparable across epochs.
    x_test_noise = _add_gaussian_noise(x_test, noise_std)
    # datagen.flow yields batches forever; cap each epoch at one pass over x_train.
    steps_per_epoch = max(1, len(x_train) // batch_size)
    min_loss = np.inf
    patience = 0
    for e in range(nb_epochs):
        batches = datagen.flow(x_train, y_train, batch_size=batch_size)
        for _, (x_batch, _labels) in zip(range(steps_per_epoch), batches):
            # One optimizer step on the whole batch; Model.fit would re-split it
            # into mini-batches of its default size (32), ignoring batch_size.
            DAE.train_on_batch(_add_gaussian_noise(x_batch, noise_std), x_batch)
        loss = DAE.evaluate(x_test_noise, x_test, batch_size=batch_size, verbose=0)
        print('epoch ', e, ', loss: ', loss)
        if loss < min_loss:
            min_loss = loss
            DAE.save_weights(DAE_path)
            patience = 0
        else:
            patience += 1
        if patience > max_patience:
            break

    if os.path.exists(DAE_path):
        # Restore the best checkpoint instead of keeping the last-epoch weights.
        DAE.load_weights(DAE_path)

In [6]:
# ResNet-18-encoder DAE; training is skipped if DAE_resnet_path already exists.
train_DAE(DAE=DAE_resnet, DAE_path=DAE_resnet_path)

In [7]:
# VGG-16-encoder DAE; training is skipped if DAE_vgg_path already exists.
train_DAE(DAE=DAE_vgg, DAE_path=DAE_vgg_path)