In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import scrapbook as sb
import time


For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.



In [2]:
# Default experiment settings. The papermill "Parameters" cell that follows may
# override any of the lower-case names for a particular run.

# Keras defaults, spelled out so the ConvBlock layers use them explicitly.
LEAKY_RELU_ALPHA = 0.3
BATCH_NORM_MOMENTUM = 0.99

# Per-block keyword arguments forwarded as ConvBlock(**params).
conv_0_0_params = {'filters': 16}
conv_0_1_params = {'filters': 16}
conv_1_0_params = {'filters': 32}
conv_1_1_params = {'filters': 32}

# Architecture head: pooling in {'max', 'average'}, reduce in {'flatten', 'gap'}.
pooling = 'max'
reduce = 'flatten'
extra_dense = False
extra_dense_units = 128

# Training.
train_batch_size = 20
num_epochs = 1

# Evaluation and input pipeline.
eval_batch_size = 20
prefetch_size = 100
shuffle_buffer_size = 1000

In [3]:
# Parameters
# Per-run values; they replace the defaults assigned in the previous cell.
pooling, reduce = "average", "flatten"
extra_dense, extra_dense_units = False, 64


In [4]:
# TF 1.x: eager execution has to be enabled at program startup, before any
# graph ops are created. Afterwards only ERROR-level TF log messages are shown.
tf.enable_eager_execution()
tf.logging.set_verbosity(tf.logging.ERROR)

# Fix TensorFlow's global seed so ops that derive their seed from it (such as
# Dataset.shuffle and initializers without an explicit seed) repeat across
# papermill runs that differ only in their parameters.
# NOTE(review): confirm that the Keras functional-model initializers in this
# TF version pick up the eager global seed.
RANDOM_SEED = 0
tf.set_random_seed(RANDOM_SEED)

In [5]:
# Load Fashion-MNIST. Because `split` is a list, `ds` is a list of datasets in
# the same order ([train, test]); `info` holds the dataset metadata.
ds, info = tfds.load(
    'fashion_mnist',
    split=['train', 'test'],
    with_info=True,
)

In [6]:
# https://www.tensorflow.org/tutorials/eager/custom_layers
class ConvBlock(tf.keras.Model):
    """Reusable Conv2D / BatchNormalization / activation block.

    With ``conv_first=True`` (the default) the sublayers run as
    conv -> batch norm -> activation; with ``conv_first=False`` they run as
    batch norm -> activation -> conv. Batch normalization and the activation
    can each be disabled.
    """

    def __init__(self,
                 filters=16,
                 kernel_size=3,
                 strides=1,
                 padding='same',
                 activation='leaky_relu',
                 batch_normalization=True,
                 conv_first=True):
        """2D Convolution -> Batch Normalization -> Activation stack builder

        # Arguments
            ## Conv2D features:
            filters (int): number of filters used by Conv2D
            kernel_size (int): square kernel dimension
            strides (int): square stride dimension
            padding (str): one of 'same' or 'valid'

            ## Other cell features
            activation (str, callable or None):
                'leaky_relu' (case-insensitive) -> LeakyReLU(alpha=LEAKY_RELU_ALPHA);
                any other string or a callable -> tf.keras.layers.Activation
                (raises for unknown names);
                None or any other non-callable value -> no activation.
            batch_normalization (bool): whether to use batch normalization
                (momentum=BATCH_NORM_MOMENTUM)
            conv_first (bool): conv -> bn         -> activation, if True;
                               bn   -> activation -> conv,       if False
        """
        super(ConvBlock, self).__init__(name='')

        self.conv_first = conv_first
        self.conv = tf.keras.layers.Conv2D(
            filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding)

        if batch_normalization:
            self.batch_norm = tf.keras.layers.BatchNormalization(
                momentum=BATCH_NORM_MOMENTUM)
        else:
            self.batch_norm = None

        # Resolve the activation. A callable such as tf.nn.relu is wrapped in an
        # Activation layer instead of being ignored.
        if isinstance(activation, str):
            if activation.lower() == 'leaky_relu':
                self.activation_fn = tf.keras.layers.LeakyReLU(
                    alpha=LEAKY_RELU_ALPHA)
            else:
                # May raise for an unknown activation name.
                self.activation_fn = tf.keras.layers.Activation(activation)
        elif callable(activation):
            self.activation_fn = tf.keras.layers.Activation(activation)
        else:
            self.activation_fn = None

    def call(self, input_tensor, training=None):
        """Apply the block to `input_tensor`.

        # Arguments
            input_tensor: 4-D image batch accepted by Conv2D.
            training: forwarded to BatchNormalization. The default is None rather
                than False: when Keras does not pass `training` explicitly,
                BatchNormalization then falls back to the Keras learning phase
                instead of always running in inference mode, where batch
                statistics are not used and moving averages are never updated.
        """
        x = input_tensor
        if self.conv_first:
            x = self.conv(x)
            x = self._normalize_and_activate(x, training)
        else:
            x = self._normalize_and_activate(x, training)
            x = self.conv(x)
        return x

    def _normalize_and_activate(self, x, training):
        """Apply the optional batch norm, then the optional activation, to `x`."""
        if self.batch_norm is not None:
            x = self.batch_norm(x, training=training)
        if self.activation_fn is not None:
            x = self.activation_fn(x)
        return x

In [7]:
# Map the case-insensitive `pooling` / `reduce` settings onto Keras layer
# classes. Unknown values raise a ValueError that lists the valid options.
# Bare `assert`s give no message and are skipped under `python -O`, which would
# silently fall through to average pooling / global average pooling.
POOLING_LAYERS = {
    'max': tf.keras.layers.MaxPooling2D,
    'average': tf.keras.layers.AveragePooling2D,
}
REDUCE_LAYERS = {
    'flatten': tf.keras.layers.Flatten,
    'gap': tf.keras.layers.GlobalAveragePooling2D,
}

# Select the pooling method
if pooling.lower() not in POOLING_LAYERS:
    raise ValueError('pooling must be one of {}, got {!r}'.format(
        sorted(POOLING_LAYERS), pooling))
Pooling2D = POOLING_LAYERS[pooling.lower()]

# Select reduce method
if reduce.lower() not in REDUCE_LAYERS:
    raise ValueError('reduce must be one of {}, got {!r}'.format(
        sorted(REDUCE_LAYERS), reduce))
Reduce = REDUCE_LAYERS[reduce.lower()]

In [8]:
# Functional model: two ConvBlock stages separated by spatial pooling, a
# reduction to a flat feature vector, an optional hidden Dense layer, and a
# 10-way softmax classifier.
inputs = tf.keras.Input(shape=(28, 28, 1))
x = inputs
x = ConvBlock(**conv_0_0_params)(x)    # conv_0_0
x = ConvBlock(**conv_0_1_params)(x)    # conv_0_1
x = Pooling2D()(x)
x = ConvBlock(**conv_1_0_params)(x)    # conv_1_0
x = ConvBlock(**conv_1_1_params)(x)    # conv_1_1
x = Reduce()(x)
if extra_dense:
    # Without a nonlinearity, this Dense layer and the softmax Dense below
    # compose into a single linear map (more parameters, no extra capacity).
    # Use the same LeakyReLU as the conv blocks; it adds no parameters.
    x = tf.keras.layers.Dense(extra_dense_units)(x)
    x = tf.keras.layers.LeakyReLU(alpha=LEAKY_RELU_ALPHA)(x)
x = tf.keras.layers.Dense(10, activation='softmax')(x)

In [9]:
# Wrap the functional graph (from `inputs` to the softmax output `x`) as a Model.
model = tf.keras.Model(inputs, x)

In [10]:
# Adam with Keras defaults. sparse_categorical_crossentropy takes integer class
# ids (not one-hot vectors) as targets.
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

In [11]:
# Print the layer table with Keras' default layout (default values made explicit).
model.summary(line_length=None, positions=None)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 28, 28, 1)         0         
_________________________________________________________________
conv_block (ConvBlock)       (None, 28, 28, 16)        224       
_________________________________________________________________
conv_block_1 (ConvBlock)     (None, 28, 28, 16)        2384      
_________________________________________________________________
average_pooling2d (AveragePo (None, 14, 14, 16)        0         
_________________________________________________________________
conv_block_2 (ConvBlock)     (None, 14, 14, 32)        4768      
_________________________________________________________________
conv_block_3 (ConvBlock)     (None, 14, 14, 32)        9376      
_________________________________________________________________
flatten (Flatten)            (None, 6272)              0         
__________

In [12]:
# Total number of weights in the model, trainable and non-trainable. The
# per-block counts in the summary output include the BatchNormalization moving
# statistics. This count is glued alongside the training results at the end.
param_count = model.count_params()

In [13]:
# `ds` follows the order of `split` in tfds.load: [train, test].
[fashion_train, fashion_test] = ds

In [14]:
def parser(example):
    """Turn a TFDS feature dict into an ``(image / 255, label)`` pair for Keras.

    The image is true-divided by 255, which presumably scales 8-bit pixel
    values into [0, 1]. The label is returned unchanged.
    """
    image = example["image"]
    label = example["label"]
    return image / 255, label

In [15]:
# Training input pipeline. It is built from the raw train split (ds[0]) rather
# than from `fashion_train` itself, so re-running this cell does not stack a
# second shuffle/map/batch on an already-batched dataset.
# Raw examples are shuffled before batching; prefetch lets input preparation
# overlap with training.
fashion_train = (
    ds[0]
    .shuffle(shuffle_buffer_size)
    .map(parser)
    .batch(train_batch_size)
    .prefetch(prefetch_size)
)

In [16]:
# Train one epoch per loop iteration so wall-clock time, loss and accuracy are
# recorded per epoch. A new one-shot iterator is made each epoch because a
# one-shot iterator can only be consumed once.
# The step count comes from the dataset metadata instead of a hard-coded 60000.
# Ceil division matches the number of batches Dataset.batch produces, since
# the final batch may be partial.
num_train_examples = info.splits['train'].num_examples
train_steps = (num_train_examples + train_batch_size - 1) // train_batch_size

time_per_epoch = []
loss = []
accuracy = []
for _ in range(num_epochs):
    fashion_iterator = fashion_train.make_one_shot_iterator()

    start = time.time()
    hist = model.fit(x=fashion_iterator, steps_per_epoch=train_steps, epochs=1)
    end = time.time()

    time_per_epoch.append(end - start)
    # The accuracy history key is 'acc' in older Keras / TF 1.x (as in this
    # notebook's training log) and 'accuracy' in newer releases.
    acc_key = 'acc' if 'acc' in hist.history else 'accuracy'
    loss.append(hist.history['loss'][0])
    accuracy.append(hist.history[acc_key][0])

   1/3000 [..............................] - ETA: 1:06:10 - loss: 2.2948 - acc: 0.1500

   9/3000 [..............................] - ETA: 7:37 - loss: 2.1674 - acc: 0.2444   

  17/3000 [..............................] - ETA: 4:11 - loss: 1.7904 - acc: 0.3765

  25/3000 [..............................] - ETA: 2:56 - loss: 1.5601 - acc: 0.4500

  33/3000 [..............................] - ETA: 2:18 - loss: 1.4244 - acc: 0.5030

  41/3000 [..............................] - ETA: 1:54 - loss: 1.3298 - acc: 0.5293

  49/3000 [..............................] - ETA: 1:38 - loss: 1.2455 - acc: 0.5643

  57/3000 [..............................] - ETA: 1:27 - loss: 1.1888 - acc: 0.5877

  65/3000 [..............................] - ETA: 1:18 - loss: 1.1259 - acc: 0.6046

  73/3000 [..............................] - ETA: 1:11 - loss: 1.0782 - acc: 0.6212

  81/3000 [..............................] - ETA: 1:06 - loss: 1.0453 - acc: 0.6315

  89/3000 [..............................] - ETA: 1:02 - loss: 1.0052 - acc: 0.6433

  97/3000 [..............................] - ETA: 58s - loss: 0.9700 - acc: 0.6552 

 105/3000 [>.............................] - ETA: 55s - loss: 0.9451 - acc: 0.6643

 113/3000 [>.............................] - ETA: 52s - loss: 0.9272 - acc: 0.6699

 121/3000 [>.............................] - ETA: 50s - loss: 0.9098 - acc: 0.6756

 129/3000 [>.............................] - ETA: 47s - loss: 0.8873 - acc: 0.6814

 137/3000 [>.............................] - ETA: 46s - loss: 0.8601 - acc: 0.6909

 145/3000 [>.............................] - ETA: 44s - loss: 0.8398 - acc: 0.6979

 153/3000 [>.............................] - ETA: 42s - loss: 0.8222 - acc: 0.7036

 161/3000 [>.............................] - ETA: 41s - loss: 0.8077 - acc: 0.7093

 169/3000 [>.............................] - ETA: 40s - loss: 0.7950 - acc: 0.7127

 177/3000 [>.............................] - ETA: 39s - loss: 0.7871 - acc: 0.7150

 185/3000 [>.............................] - ETA: 38s - loss: 0.7760 - acc: 0.7197

 193/3000 [>.............................] - ETA: 37s - loss: 0.7679 - acc: 0.7223

 201/3000 [=>............................] - ETA: 36s - loss: 0.7554 - acc: 0.7269

 209/3000 [=>............................] - ETA: 35s - loss: 0.7432 - acc: 0.7297

 217/3000 [=>............................] - ETA: 34s - loss: 0.7364 - acc: 0.7336

 225/3000 [=>............................] - ETA: 34s - loss: 0.7276 - acc: 0.7373

 233/3000 [=>............................] - ETA: 33s - loss: 0.7220 - acc: 0.7386

 241/3000 [=>............................] - ETA: 32s - loss: 0.7131 - acc: 0.7421

 249/3000 [=>............................] - ETA: 32s - loss: 0.7049 - acc: 0.7450

 257/3000 [=>............................] - ETA: 31s - loss: 0.6981 - acc: 0.7471

 265/3000 [=>............................] - ETA: 31s - loss: 0.6916 - acc: 0.7489

 273/3000 [=>............................] - ETA: 30s - loss: 0.6869 - acc: 0.7496

 281/3000 [=>............................] - ETA: 30s - loss: 0.6830 - acc: 0.7512

 289/3000 [=>............................] - ETA: 29s - loss: 0.6828 - acc: 0.7526

 297/3000 [=>............................] - ETA: 29s - loss: 0.6790 - acc: 0.7549

 305/3000 [==>...........................] - ETA: 28s - loss: 0.6743 - acc: 0.7557

 313/3000 [==>...........................] - ETA: 28s - loss: 0.6678 - acc: 0.7580

 321/3000 [==>...........................] - ETA: 28s - loss: 0.6612 - acc: 0.7595

 329/3000 [==>...........................] - ETA: 27s - loss: 0.6589 - acc: 0.7609

 337/3000 [==>...........................] - ETA: 27s - loss: 0.6554 - acc: 0.7622

 345/3000 [==>...........................] - ETA: 27s - loss: 0.6519 - acc: 0.7633

 353/3000 [==>...........................] - ETA: 26s - loss: 0.6478 - acc: 0.7653

 361/3000 [==>...........................] - ETA: 26s - loss: 0.6436 - acc: 0.7666

 369/3000 [==>...........................] - ETA: 26s - loss: 0.6362 - acc: 0.7687

 377/3000 [==>...........................] - ETA: 26s - loss: 0.6342 - acc: 0.7699

 385/3000 [==>...........................] - ETA: 25s - loss: 0.6311 - acc: 0.7716

 393/3000 [==>...........................] - ETA: 25s - loss: 0.6293 - acc: 0.7718

 401/3000 [===>..........................] - ETA: 25s - loss: 0.6247 - acc: 0.7732

 409/3000 [===>..........................] - ETA: 24s - loss: 0.6224 - acc: 0.7751

 418/3000 [===>..........................] - ETA: 24s - loss: 0.6203 - acc: 0.7760

 426/3000 [===>..........................] - ETA: 24s - loss: 0.6173 - acc: 0.7772

 434/3000 [===>..........................] - ETA: 24s - loss: 0.6141 - acc: 0.7791

 442/3000 [===>..........................] - ETA: 24s - loss: 0.6102 - acc: 0.7807

 450/3000 [===>..........................] - ETA: 23s - loss: 0.6070 - acc: 0.7819

 458/3000 [===>..........................] - ETA: 23s - loss: 0.6041 - acc: 0.7828

 466/3000 [===>..........................] - ETA: 23s - loss: 0.6013 - acc: 0.7839

 474/3000 [===>..........................] - ETA: 23s - loss: 0.5967 - acc: 0.7853

 482/3000 [===>..........................] - ETA: 23s - loss: 0.5957 - acc: 0.7859

 490/3000 [===>..........................] - ETA: 22s - loss: 0.5960 - acc: 0.7865

 499/3000 [===>..........................] - ETA: 22s - loss: 0.5937 - acc: 0.7874

 507/3000 [====>.........................] - ETA: 22s - loss: 0.5923 - acc: 0.7879

 515/3000 [====>.........................] - ETA: 22s - loss: 0.5918 - acc: 0.7883

 523/3000 [====>.........................] - ETA: 22s - loss: 0.5889 - acc: 0.7893

 531/3000 [====>.........................] - ETA: 21s - loss: 0.5869 - acc: 0.7898

 539/3000 [====>.........................] - ETA: 21s - loss: 0.5861 - acc: 0.7906

 547/3000 [====>.........................] - ETA: 21s - loss: 0.5840 - acc: 0.7914

 555/3000 [====>.........................] - ETA: 21s - loss: 0.5818 - acc: 0.7923

 563/3000 [====>.........................] - ETA: 21s - loss: 0.5817 - acc: 0.7926

 571/3000 [====>.........................] - ETA: 21s - loss: 0.5789 - acc: 0.7937

 579/3000 [====>.........................] - ETA: 21s - loss: 0.5762 - acc: 0.7950

 587/3000 [====>.........................] - ETA: 20s - loss: 0.5731 - acc: 0.7962

 595/3000 [====>.........................] - ETA: 20s - loss: 0.5724 - acc: 0.7963

 603/3000 [=====>........................] - ETA: 20s - loss: 0.5730 - acc: 0.7967

 611/3000 [=====>........................] - ETA: 20s - loss: 0.5730 - acc: 0.7966

 619/3000 [=====>........................] - ETA: 20s - loss: 0.5718 - acc: 0.7969

 627/3000 [=====>........................] - ETA: 20s - loss: 0.5694 - acc: 0.7981

 635/3000 [=====>........................] - ETA: 20s - loss: 0.5674 - acc: 0.7990

 643/3000 [=====>........................] - ETA: 19s - loss: 0.5667 - acc: 0.7988

 651/3000 [=====>........................] - ETA: 19s - loss: 0.5655 - acc: 0.7990

 659/3000 [=====>........................] - ETA: 19s - loss: 0.5637 - acc: 0.7994

 667/3000 [=====>........................] - ETA: 19s - loss: 0.5629 - acc: 0.8001

 675/3000 [=====>........................] - ETA: 19s - loss: 0.5609 - acc: 0.8009

 683/3000 [=====>........................] - ETA: 19s - loss: 0.5576 - acc: 0.8023

 691/3000 [=====>........................] - ETA: 19s - loss: 0.5566 - acc: 0.8027

 699/3000 [=====>........................] - ETA: 19s - loss: 0.5559 - acc: 0.8030





























































































































































































































































































































































































































































































































































































In [17]:
# Evaluation input pipeline. It is built from the raw test split (ds[1]) rather
# than from `fashion_test` itself, so re-running this cell does not batch twice.
# There is no shuffling: example order does not change the evaluation metrics.
fashion_test = (
    ds[1]
    .map(parser)
    .batch(eval_batch_size)
    .prefetch(prefetch_size)
)

In [18]:
# Evaluate on the whole test split. The step count comes from the dataset
# metadata instead of a hard-coded 10000. Ceil division also covers a final
# partial batch when eval_batch_size does not divide the number of test examples.
# eval_results is [loss, accuracy], following the compile() metrics.
num_test_examples = info.splits['test'].num_examples
eval_steps = (num_test_examples + eval_batch_size - 1) // eval_batch_size
test_iterator = fashion_test.make_one_shot_iterator()
eval_results = model.evaluate(x=test_iterator, steps=eval_steps)

  1/500 [..............................] - ETA: 47s - loss: 0.2024 - acc: 0.9500

 23/500 [>.............................] - ETA: 3s - loss: 0.2615 - acc: 0.9130 

 48/500 [=>............................] - ETA: 1s - loss: 0.3250 - acc: 0.8875

 73/500 [===>..........................] - ETA: 1s - loss: 0.3158 - acc: 0.8884

 97/500 [====>.........................] - ETA: 1s - loss: 0.3373 - acc: 0.8835



































In [19]:
# Record this run's results with scrapbook. Python's json module cannot
# serialize numpy scalar types such as np.float32 or np.int64, so values are
# coerced to built-in float/int first (lists stay lists; an empty list stays empty).
# NOTE(review): the exact numpy types Keras returns vary by version.
sb.glue("time_per_epoch", np.asarray(time_per_epoch, dtype=float).tolist())
sb.glue("loss", np.asarray(loss, dtype=float).tolist())
sb.glue("accuracy", np.asarray(accuracy, dtype=float).tolist())
sb.glue("param_count", int(param_count))
sb.glue("eval_results", np.asarray(eval_results, dtype=float).tolist())