In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt


For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.



In [2]:
# Keras defaults, made explicit so they can be tuned in one place.
LEAKY_RELU_ALPHA = 0.3
BATCH_NORM_MOMENTUM = 0.99

# Keyword arguments for each ConvBlock: the two blocks before the pooling
# layer use 16 filters, the two after it use 32.
conv_0_0_params = {'filters': 16, 'padding': 'same', 'activation': 'leaky_relu'}
conv_0_1_params = {'filters': 16, 'padding': 'same', 'activation': 'leaky_relu'}
conv_1_0_params = {'filters': 32, 'padding': 'same', 'activation': 'leaky_relu'}
conv_1_1_params = {'filters': 32, 'padding': 'same', 'activation': 'leaky_relu'}

# Head options.
pooling = 'max'          # 'max' or 'average'
reduce = 'flatten'       # 'flatten' or 'gap' (global average pooling)
extra_dense = False      # insert an extra hidden Dense layer before the classifier
extra_dense_units = 128  # width of that extra Dense layer

In [3]:
# Only show TensorFlow log messages at ERROR level and above.
tf.logging.set_verbosity(tf.logging.ERROR)
# TF 1.x: switch to eager execution. This must run before any other TF ops are
# created, so it stays in an early cell.
tf.enable_eager_execution()

In [4]:
# Load the train and test splits (in that order) plus the dataset metadata.
ds, info = tfds.load(
    'fashion_mnist',
    with_info=True,
    split=['train', 'test'],
)

In [None]:
# Show the metadata returned by tfds.load through the notebook's rich display
# (a bare last expression) rather than print().
info

In [5]:
# `ds` holds one dataset per requested split, in the order requested above.
# TODO(review): fashion_test is never evaluated in the cells visible here.
[fashion_train, fashion_test] = ds

In [6]:
def parser(example):
    """Map a dataset example dict to an ``(image, label)`` pair.

    The image is divided by 255. This presumably rescales uint8 pixel values
    from [0, 255] to [0, 1]; confirm the dtype via `info`. The label is
    passed through unchanged.
    """
    pixels, target = example["image"], example["label"]
    return pixels / 255, target

In [7]:
# Training input pipeline: shuffle -> normalize/unpack -> batch -> prefetch.
# Build from the raw train split `ds[0]` rather than from `fashion_train`
# itself. Otherwise re-running this cell would shuffle and batch an
# already-batched dataset, producing nested batches.
SHUFFLE_BUFFER = 1000
BATCH_SIZE = 10
PREFETCH_BATCHES = 10  # prefetch runs after batch(), so this counts batches
fashion_train = (
    ds[0]
    .shuffle(SHUFFLE_BUFFER)
    .map(parser)
    .batch(BATCH_SIZE)
    .prefetch(PREFETCH_BATCHES)
)

In [8]:
# https://www.tensorflow.org/tutorials/eager/custom_layers
class ConvBlock(tf.keras.Model):
    """Conv2D, optional BatchNormalization and optional activation as one layer."""

    def __init__(self, 
                 filters=16, 
                 kernel_size=3, 
                 strides=1,
                 padding='same', 
                 activation='leaky_relu',
                 batch_normalization=True, 
                 conv_first=True,
                 name=None):
        """2D Convolution -> Batch Normalization -> Activation stack builder

        # Arguments
            ## Conv2D features:
            filters (int): number of filters used by Conv2D
            kernel_size (int): square kernel dimension
            strides (int): square stride dimension
            padding (str): one of 'same' or 'valid'

            ## Other cell features
            activation (str, callable or None):
                'leaky_relu' (case-insensitive) -> LeakyReLU(alpha=LEAKY_RELU_ALPHA);
                any other name or callable -> tf.keras.layers.Activation;
                None -> no activation.
            batch_normalization (bool): whether to use batch normalization
                (with momentum BATCH_NORM_MOMENTUM)
            conv_first (bool): conv -> bn         -> activation, if True; 
                               bn   -> activation -> conv,       if False
            name (str or None): optional layer name. None keeps the previous
                behaviour of passing name=''.
        """
        super(ConvBlock, self).__init__(name=name or '')

        self.conv_first = conv_first
        self.conv = tf.keras.layers.Conv2D(
            filters, 
            kernel_size=kernel_size,
            strides=strides,
            padding=padding)
        
        if batch_normalization:
            self.batch_norm = \
                tf.keras.layers.BatchNormalization(momentum=BATCH_NORM_MOMENTUM)
        else:
            self.batch_norm = None
        
        # Determine which activation function to use. Previously any non-str
        # value (e.g. a callable such as tf.nn.relu) was silently treated as
        # "no activation"; callables are now forwarded to Activation.
        if activation is None:
            self.activation_fn = None
        elif isinstance(activation, str) and activation.lower() == 'leaky_relu':
            self.activation_fn = \
                tf.keras.layers.LeakyReLU(alpha=LEAKY_RELU_ALPHA)
        else:
            # May raise an error for identifiers Keras does not recognize.
            self.activation_fn = tf.keras.layers.Activation(activation)

    def call(self, input_tensor, training=False):
        """Apply the block to ``input_tensor``.

        ``training`` is forwarded to BatchNormalization (batch statistics when
        True, moving averages when False).
        NOTE(review): the default is False, as in the linked tutorial. Confirm
        that the enclosing functional model propagates training=True during
        fit; otherwise batch norm would stay in inference mode while training.
        """
        x = input_tensor
        if self.conv_first:
            # conv -> bn -> activation
            x = self.conv(x)
            if self.batch_norm is not None:
                x = self.batch_norm(x, training=training)
            if self.activation_fn is not None:
                x = self.activation_fn(x)
        else:
            # bn -> activation -> conv (pre-activation ordering)
            if self.batch_norm is not None:
                x = self.batch_norm(x, training=training)
            if self.activation_fn is not None:
                x = self.activation_fn(x)
            x = self.conv(x)
        return x

In [9]:
# Map the string options from the config cell to Keras layer classes.
# Unknown options raise ValueError. An `assert` would be skipped entirely under
# `python -O` and gave no message about the accepted values.
_POOLING_LAYERS = {
    'max': tf.keras.layers.MaxPooling2D,
    'average': tf.keras.layers.AveragePooling2D,
}
_REDUCE_LAYERS = {
    'flatten': tf.keras.layers.Flatten,
    'gap': tf.keras.layers.GlobalAveragePooling2D,  # global average pooling
}

# Select the pooling method
if pooling.lower() not in _POOLING_LAYERS:
    raise ValueError("pooling must be one of {}, got {!r}".format(
        sorted(_POOLING_LAYERS), pooling))
Pooling2D = _POOLING_LAYERS[pooling.lower()]

# Select reduce method
if reduce.lower() not in _REDUCE_LAYERS:
    raise ValueError("reduce must be one of {}, got {!r}".format(
        sorted(_REDUCE_LAYERS), reduce))
Reduce = _REDUCE_LAYERS[reduce.lower()]

In [10]:
# Functional model: two ConvBlocks, pooling, two more ConvBlocks, a reduction
# to a vector, then a 10-way softmax classifier.
inputs = tf.keras.Input(shape=(28, 28, 1))
x = inputs
x = ConvBlock(**conv_0_0_params)(x)    # conv_0_0
x = ConvBlock(**conv_0_1_params)(x)    # conv_0_1
x = Pooling2D()(x)                     # default 2x2 pool halves height/width
x = ConvBlock(**conv_1_0_params)(x)    # conv_1_0
x = ConvBlock(**conv_1_1_params)(x)    # conv_1_1
x = Reduce()(x)
if extra_dense:
    # A Dense layer without an activation composes linearly with the softmax
    # Dense below and adds no representational power. Follow it with the
    # same LeakyReLU non-linearity the conv blocks use.
    x = tf.keras.layers.Dense(extra_dense_units)(x)
    x = tf.keras.layers.LeakyReLU(alpha=LEAKY_RELU_ALPHA)(x)
x = tf.keras.layers.Dense(10, activation='softmax')(x)

In [11]:
# Wrap the functional graph (input tensor -> softmax output tensor) as a Model.
model = tf.keras.Model(inputs, x)

In [12]:
# The softmax output pairs with sparse categorical cross-entropy, which expects
# the labels yielded by `parser` to be integer class ids (presumably 0-9).
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

In [13]:
# TF 1.x API: wrap the batched training pipeline in a one-shot iterator for
# model.fit below. NOTE(review): newer TF versions accept the Dataset itself in
# model.fit and remove Dataset.make_one_shot_iterator. Confirm against the
# installed TF version before upgrading.
fashion_iterator = fashion_train.make_one_shot_iterator()

In [14]:
# A single epoch of 6000 steps at batch size 10 is 60,000 examples, presumably
# one full pass over the train split. Confirm with
# info.splits['train'].num_examples.
model.fit(fashion_iterator, epochs=1, steps_per_epoch=6000)



<tensorflow.python.keras.callbacks.History at 0x7f49580c63c8>

In [None]:
# Print the layer-by-layer summary of the model built and trained above.
model.summary()