In [27]:
try:
    # %tensorflow_version only exists in Colab.
    %tensorflow_version 2.x
    # !pip install -q -U tfx==0.15.0rc0
    print("You can safely ignore the package incompatibility errors.")
except Exception:
    pass


import tensorflow as tf
from tensorflow import keras
assert tf.__version__ >= "2.0"
print(tf.__version__)

import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.utils import plot_model
from tensorflow.keras import backend as K
from tensorflow.keras import datasets
from tqdm import tqdm_notebook
from sklearn import metrics as skm

tf.random.set_seed(1228)

You can safely ignore the package incompatibility errors.
2.1.0


In [2]:
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [0]:
# create numpy to tensor dataset 
train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
test_data = tf.data.Dataset.from_tensor_slices((x_test, y_test))

# shuffle the data
train_data = train_data.shuffle(1024)


# create preprocess function 
@tf.function
def preprocess_train(x, y):
  x = tf.cast(x, tf.float32)
  x = tf.expand_dims(x, 2)
  x = x / 255.
  # x = tf.image.random_flip_left_right(x, seed=1228)
  # x = tf.image.random_flip_up_down(x, seed=1228)
  # x = tf.image.rot90(x)
  return x, tf.cast(y, tf.float32)

@tf.function
def preprocess_test(x, y):
  x = tf.cast(x, tf.float32)
  x = tf.expand_dims(x, 2)
  x = x / 255.
  return x, tf.cast(y, tf.float32)

In [0]:
train_data = train_data.map(preprocess_train)
test_data = test_data.map(preprocess_test)

In [0]:
train_data = train_data.batch(512).prefetch(128)
test_data = test_data.batch(512).prefetch(128)

In [0]:
train_data

<PrefetchDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.float32)>

In [0]:
# counter = 0
# for i in test_data:
#   print(i)
#   print(i[0].shape)
#   print(i[1].shape)
#   plt.imshow(i[0][0].numpy().reshape(28,28), cmap='gray')
#   counter += 1
#   if counter == 1:
#     break

## Base model CNN

In [0]:
# model archtecture 
input_ = tf.keras.layers.Input(shape=(28,28,1))
x = tf.keras.layers.Conv2D(128,
                           (3,3),
                           padding='same',
                           activation='relu',
                           kernel_initializer='he_normal',
                           kernel_regularizer=tf.keras.regularizers.l2(l=0.02),
                           kernel_constraint=tf.keras.constraints.MinMaxNorm())(input_)
x = tf.keras.layers.LayerNormalization()(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.MaxPool2D()(x)
x = tf.keras.layers.Dropout(0.4, trainable=True)(x)
for i in range(10):
  x = tf.keras.layers.Conv2D(128,
                             (3,3),
                             padding='same',
                             activation='relu',
                             kernel_initializer='he_normal',
                             kernel_regularizer=tf.keras.regularizers.l2(l=0.02),
                             kernel_constraint=tf.keras.constraints.MinMaxNorm())(x)
  x = tf.keras.layers.LayerNormalization()(x)
  x = tf.keras.layers.BatchNormalization()(x)
  if i == 7 :
    x = tf.keras.layers.MaxPool2D()(x)
  x = tf.keras.layers.Dropout(0.4, trainable=True)(x)
x = tf.keras.layers.GlobalAvgPool2D()(x)
x = tf.keras.layers.LayerNormalization()(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Dropout(0.4, trainable=True)(x)
x = tf.keras.layers.Dense(128, activation='relu')(x)
x = tf.keras.layers.LayerNormalization()(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Dropout(0.4, trainable=True)(x)
out_ = tf.keras.layers.Dense(10, activation='softmax')(x)

# define model
model = tf.keras.models.Model(input_, out_)
model.summary()

# Training

# Define loss 
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()

# Define optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=0.008)

# Define batch metcis summary units
train_loss = tf.keras.metrics.Mean(name='train_loss')
test_loss = tf.keras.metrics.Mean(name='test_loss')
train_acc = tf.keras.metrics.SparseCategoricalAccuracy(name='train_acc')
test_acc = tf.keras.metrics.SparseCategoricalAccuracy(name='test_acc')

# Define train function
# Calculate Gradients, and apply to optimizer
@tf.function
def train_step(images, labels):
  with tf.GradientTape() as tape:
    pred = model(images, training=True)
    loss = loss_fn(labels, pred)
  g = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(g, model.trainable_variables))
  train_loss(loss)
  train_acc(labels, pred)

# Define test functionm
@tf.function
def test_step(images, labels):
  pred = model(images, training=False)
  loss = loss_fn(labels, pred)
  test_loss(loss)
  test_acc(labels, pred)

# Train
EPOCHS = 10

for epoch in range(EPOCHS):
  train_loss.reset_states()
  test_loss.reset_states()

  train_acc.reset_states()
  test_acc.reset_states()

  for img, label in train_data:
    train_step(img, label)
  
  for img, label in test_data:
    test_step(img, label)

  template = 'Epoch:{} \t Loss:{} \t  Acc:{} \t Val_loss:{} \t Val_acc:{}'
  print(template.format(epoch, train_loss.result(), train_acc.result()*100, test_loss.result(), test_acc.result()*100))

Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
conv2d_32 (Conv2D)           (None, 28, 28, 128)       1280      
_________________________________________________________________
layer_normalization_16 (Laye (None, 28, 28, 128)       256       
_________________________________________________________________
batch_normalization_16 (Batc (None, 28, 28, 128)       512       
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 14, 14, 128)       0         
_________________________________________________________________
dropout_26 (Dropout)         (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_33 (Conv2D)           (None, 14, 14, 128)       1475

## Define with keras eager execution

In [0]:
# model archtecture 
input_ = tf.keras.layers.Input(shape=(28,28,1))
x = tf.keras.layers.Conv2D(128,
                           (3,3),
                           padding='same',
                           activation='relu',
                          #  kernel_initializer='he_normal',
                           kernel_regularizer=tf.keras.regularizers.l2(l=0.02),
                           kernel_constraint=tf.keras.constraints.MinMaxNorm())(input_)
x = tf.keras.layers.LayerNormalization()(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.MaxPool2D()(x)
x = tf.keras.layers.Dropout(0.4, trainable=True)(x)
for i in range(10):
  x = tf.keras.layers.Conv2D(128,
                             (3,3),
                             padding='same',
                             activation='relu',
                            #  kernel_initializer='he_normal',
                             kernel_regularizer=tf.keras.regularizers.l2(l=0.02),
                             kernel_constraint=tf.keras.constraints.MinMaxNorm())(x)
  x = tf.keras.layers.LayerNormalization()(x)
  x = tf.keras.layers.BatchNormalization()(x)
  if i == 7 :
    x = tf.keras.layers.MaxPool2D()(x)
  x = tf.keras.layers.Dropout(0.4, trainable=True)(x)
x = tf.keras.layers.GlobalAvgPool2D()(x)
x = tf.keras.layers.LayerNormalization()(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Dropout(0.4, trainable=True)(x)
x = tf.keras.layers.Dense(128, activation='relu')(x)
x = tf.keras.layers.LayerNormalization()(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Dropout(0.4, trainable=True)(x)
out_ = tf.keras.layers.Dense(10, activation='softmax')(x)

# define model
model = tf.keras.models.Model(input_, out_)
model.summary()

optimizer = tf.keras.optimizers.Adam(learning_rate=0.008)
model.compile(loss='sparse_categorical_crossentropy', optimizer=optimizer, metrics=['acc'])
hist = model.fit(train_data, epochs=10, validation_data=test_data)

Model: "model_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_5 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
conv2d_54 (Conv2D)           (None, 28, 28, 128)       1280      
_________________________________________________________________
layer_normalization_42 (Laye (None, 28, 28, 128)       256       
_________________________________________________________________
batch_normalization_42 (Batc (None, 28, 28, 128)       512       
_________________________________________________________________
max_pooling2d_8 (MaxPooling2 (None, 14, 14, 128)       0         
_________________________________________________________________
dropout_52 (Dropout)         (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_55 (Conv2D)           (None, 14, 14, 128)       1475

# Create Highway networks
### 1) create as plane
### 2) create as function api
### 3) create as sub-class api

In [0]:
tf.keras.backend.clear_session()

# model archtecture 
input_ = tf.keras.layers.Input(shape=(28,28,1))
x = tf.keras.layers.Conv2D(128,
                           (3,3),
                           padding='same',
                           activation='relu',
                          #  kernel_initializer='he_normal',
                           kernel_regularizer=tf.keras.regularizers.l2(l=0.02),
                           kernel_constraint=tf.keras.constraints.MinMaxNorm())(input_)
x = tf.keras.layers.LayerNormalization()(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.MaxPool2D()(x)
x = tf.keras.layers.Dropout(0.4, trainable=True)(x)
for i in range(10):
  # highway block
  h = tf.keras.layers.Conv2D(128,
                             (3,3),
                             padding='same',
                             activation='relu',
                            #  kernel_initializer='he_normal',
                             kernel_regularizer=tf.keras.regularizers.l2(l=0.02),
                             kernel_constraint=tf.keras.constraints.MinMaxNorm())(x)
  c = tf.keras.layers.Conv2D(128, 
                             (3,3), 
                             padding='same',
                             activation='sigmoid',
                             kernel_regularizer=tf.keras.regularizers.l2(l=0.02),
                             kernel_constraint=tf.keras.constraints.MinMaxNorm(),
                             bias_initializer = tf.keras.initializers.Constant(-3.))(x)
  hc = tf.keras.layers.Multiply()([h,c])
  ad = tf.keras.layers.Lambda(lambda x: 1.0 - x)(c)

  xc = tf.keras.layers.Multiply()([x, ad])
  x = tf.keras.layers.Add()([hc, xc])  
  if i == 7 :
    x = tf.keras.layers.MaxPool2D()(x)
  x = tf.keras.layers.Dropout(0.4, trainable=True)(x)
x = tf.keras.layers.GlobalAvgPool2D()(x)
x = tf.keras.layers.LayerNormalization()(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Dropout(0.4, trainable=True)(x)
x = tf.keras.layers.Dense(128, activation='relu')(x)
x = tf.keras.layers.LayerNormalization()(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Dropout(0.4, trainable=True)(x)
out_ = tf.keras.layers.Dense(10, activation='softmax')(x)

# define model
model = tf.keras.models.Model(input_, out_)
model.summary()

# plot model
tf.keras.utils.plot_model(model, show_shapes=True, dpi=80)

# define loss
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()

# Set summary metrics
train_loss = tf.keras.metrics.Mean()
train_acc = tf.keras.metrics.SparseCategoricalAccuracy()

test_loss = tf.keras.metrics.Mean()
test_acc = tf.keras.metrics.SparseCategoricalAccuracy()


@tf.function
def train_step(img, label):
  with tf.GradientTape() as tape:
    pred = model(img, training=True)
    loss = loss_fn(label, pred)
  grad = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(grad, model.trainable_variables))
  train_loss(loss)
  train_acc(label, pred)

@tf.function
def test_step(img, label):
  pred = model(img, training=False)
  loss = loss_fn(label, pred)
  test_loss(loss)
  test_acc(label, pred)


# Train Loop
n_epochs = 50
for epoch in range(n_epochs):
  train_loss.reset_states()
  train_acc.reset_states()
  test_loss.reset_states()
  test_acc.reset_states()

  for img, lab in tqdm_notebook(train_data):
    train_step(img, lab)

  for img, lab in test_data:
    test_step(img, lab)

  template = 'Epoch:{} \t Loss:{} \t  Acc:{} \t Val_loss:{} \t Val_acc:{}'
  print(template.format(epoch, train_loss.result(), train_acc.result()*100, test_loss.result(), test_acc.result()*100))

  

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 28, 28, 128)  1280        input_1[0][0]                    
__________________________________________________________________________________________________
layer_normalization (LayerNorma (None, 28, 28, 128)  256         conv2d[0][0]                     
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 28, 28, 128)  512         layer_normalization[0][0]        
______________________________________________________________________________________________

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:0 	 Loss:2.5990819931030273 	  Acc:17.866666793823242 	 Val_loss:1.8096681833267212 	 Val_acc:34.34000015258789


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:1 	 Loss:1.6695348024368286 	  Acc:40.608333587646484 	 Val_loss:1.1889042854309082 	 Val_acc:60.630001068115234


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:2 	 Loss:1.2952954769134521 	  Acc:53.584999084472656 	 Val_loss:0.7855485677719116 	 Val_acc:72.94999694824219


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:3 	 Loss:0.9705101251602173 	  Acc:66.23332977294922 	 Val_loss:0.5505603551864624 	 Val_acc:83.0


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:4 	 Loss:0.801274299621582 	  Acc:72.78500366210938 	 Val_loss:0.48742884397506714 	 Val_acc:83.5300064086914


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:5 	 Loss:0.6568916440010071 	  Acc:78.2249984741211 	 Val_loss:0.32796886563301086 	 Val_acc:89.62000274658203


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:6 	 Loss:0.5785620808601379 	  Acc:81.2733383178711 	 Val_loss:0.30629298090934753 	 Val_acc:90.36000061035156


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:7 	 Loss:0.5426816344261169 	  Acc:82.44666290283203 	 Val_loss:0.2757303714752197 	 Val_acc:91.58999633789062


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:8 	 Loss:0.4572766125202179 	  Acc:85.65833282470703 	 Val_loss:0.23188233375549316 	 Val_acc:92.8499984741211


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:9 	 Loss:0.5127370357513428 	  Acc:83.49166870117188 	 Val_loss:0.37806349992752075 	 Val_acc:86.47000122070312


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:10 	 Loss:0.49965643882751465 	  Acc:83.66166687011719 	 Val_loss:0.27029696106910706 	 Val_acc:92.1199951171875


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:11 	 Loss:0.462035208940506 	  Acc:85.13500213623047 	 Val_loss:0.24397416412830353 	 Val_acc:92.48999786376953


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:12 	 Loss:0.46344268321990967 	  Acc:85.51166534423828 	 Val_loss:0.23511378467082977 	 Val_acc:92.69000244140625


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:13 	 Loss:0.3968488276004791 	  Acc:87.53500366210938 	 Val_loss:0.18851986527442932 	 Val_acc:94.18000030517578


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:14 	 Loss:0.7557682991027832 	  Acc:75.22000122070312 	 Val_loss:0.352718323469162 	 Val_acc:88.6300048828125


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:15 	 Loss:0.6199231147766113 	  Acc:79.7066650390625 	 Val_loss:0.2897387146949768 	 Val_acc:91.58999633789062


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:16 	 Loss:0.5390926003456116 	  Acc:82.42833709716797 	 Val_loss:0.31172341108322144 	 Val_acc:90.04000091552734


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:17 	 Loss:0.5071117281913757 	  Acc:83.5183334350586 	 Val_loss:0.2779403626918793 	 Val_acc:91.25999450683594


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:18 	 Loss:0.5430063009262085 	  Acc:82.2683334350586 	 Val_loss:1.2029389142990112 	 Val_acc:63.76000213623047


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:19 	 Loss:0.8513026833534241 	  Acc:71.375 	 Val_loss:0.5312917232513428 	 Val_acc:83.4800033569336


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:20 	 Loss:0.7435156106948853 	  Acc:75.01666259765625 	 Val_loss:0.47693389654159546 	 Val_acc:83.91999816894531


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:21 	 Loss:0.7363930940628052 	  Acc:75.30166625976562 	 Val_loss:0.8937384486198425 	 Val_acc:69.4000015258789


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:22 	 Loss:0.9933388233184814 	  Acc:65.87999725341797 	 Val_loss:0.5204464197158813 	 Val_acc:83.55000305175781


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:23 	 Loss:0.8886530995368958 	  Acc:69.3933334350586 	 Val_loss:0.521729588508606 	 Val_acc:83.4000015258789


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:24 	 Loss:0.8669584393501282 	  Acc:70.11833190917969 	 Val_loss:0.613081157207489 	 Val_acc:78.98999786376953


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:25 	 Loss:0.8202111721038818 	  Acc:72.15499877929688 	 Val_loss:0.6246150135993958 	 Val_acc:78.47999572753906


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:26 	 Loss:1.1530505418777466 	  Acc:60.05500030517578 	 Val_loss:0.7225214242935181 	 Val_acc:77.24000549316406


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:27 	 Loss:0.961911678314209 	  Acc:66.69833374023438 	 Val_loss:0.4975806176662445 	 Val_acc:84.0999984741211


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:28 	 Loss:0.8163377642631531 	  Acc:72.11166381835938 	 Val_loss:0.594054639339447 	 Val_acc:80.47999572753906


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:29 	 Loss:0.747271716594696 	  Acc:74.94499969482422 	 Val_loss:0.37589794397354126 	 Val_acc:87.84000396728516


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:30 	 Loss:0.7227592468261719 	  Acc:75.4816665649414 	 Val_loss:0.41290420293807983 	 Val_acc:86.20999908447266


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:31 	 Loss:0.6761678457260132 	  Acc:77.31500244140625 	 Val_loss:0.372750461101532 	 Val_acc:88.30999755859375


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:32 	 Loss:0.8877430558204651 	  Acc:69.68167114257812 	 Val_loss:0.498513400554657 	 Val_acc:84.93000030517578


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:33 	 Loss:0.806276798248291 	  Acc:72.30332946777344 	 Val_loss:0.41058582067489624 	 Val_acc:86.80999755859375


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:34 	 Loss:0.7115423679351807 	  Acc:75.7316665649414 	 Val_loss:0.5461608171463013 	 Val_acc:81.33000183105469


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:35 	 Loss:1.2915475368499756 	  Acc:55.71833419799805 	 Val_loss:0.7694435715675354 	 Val_acc:78.11000061035156


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:36 	 Loss:1.132615566253662 	  Acc:60.6966667175293 	 Val_loss:0.803520679473877 	 Val_acc:73.81999969482422


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:37 	 Loss:1.0873479843139648 	  Acc:62.279998779296875 	 Val_loss:0.6216964721679688 	 Val_acc:81.13999938964844


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:38 	 Loss:0.9685041904449463 	  Acc:66.61333465576172 	 Val_loss:0.5924363732337952 	 Val_acc:81.22000122070312


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:39 	 Loss:0.8779110312461853 	  Acc:69.96666717529297 	 Val_loss:0.485836923122406 	 Val_acc:85.04999542236328


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:40 	 Loss:0.9164681434631348 	  Acc:68.50333404541016 	 Val_loss:0.48860305547714233 	 Val_acc:84.58999633789062


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:41 	 Loss:0.8540305495262146 	  Acc:70.80332946777344 	 Val_loss:0.46014994382858276 	 Val_acc:85.6199951171875


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:42 	 Loss:0.8055254220962524 	  Acc:72.54666900634766 	 Val_loss:0.4349783957004547 	 Val_acc:87.20999908447266


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:43 	 Loss:0.7757413983345032 	  Acc:73.77999877929688 	 Val_loss:0.4117286205291748 	 Val_acc:87.72000122070312


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:44 	 Loss:0.7772415280342102 	  Acc:73.59832763671875 	 Val_loss:0.37575027346611023 	 Val_acc:88.52999877929688


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Epoch:45 	 Loss:0.7433900237083435 	  Acc:74.61500549316406 	 Val_loss:0.3608364462852478 	 Val_acc:89.01000213623047


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

In [0]:
# counter = 0
# for i in train_data:
#   print(i)
#   print(i[0].numpy().shape)
#   print(i[1].numpy().shape)

#   counter += 1
#   if counter == 1:
#     break

# TF - 1.X

In [0]:
# in_a = tf.placeholder(dtype=tf.float32, shape=(2))
# in_b = tf.placeholder(dtype=tf.float32, shape=(2))
 
# def forward(x):
#   with tf.variable_scope("matmul", reuse=tf.AUTO_REUSE):

#     W = tf.get_variable("W", initializer=tf.ones(shape=(2,2)),
#                         regularizer=tf.contrib.layers.l2_regularizer(0.04))
#     b = tf.get_variable("b", initializer=tf.zeros(shape=(2)))
#     return W * x + b
 
# out_a = forward(in_a)
# out_b = forward(in_b)
 
# reg_loss = tf.losses.get_regularization_loss(scope="matmul")
 
# with tf.Session() as sess:
#   sess.run(tf.global_variables_initializer())
#   outs = sess.run([out_a, out_b, reg_loss],
#                   feed_dict={in_a: [1, 0], in_b: [0, 1]})

# TF - 2.X

In [0]:
# W = tf.Variable(tf.ones(shape=(2,2)), name="W")
# b = tf.Variable(tf.zeros(shape=(2)), name="b")

# @tf.function
# def forward(x):
#   return W * x + b

# out_a = forward([1,0])
# print(out_a)

## Create the small custom layer with @tf.function 

In [0]:
# tf.keras.backend.clear_session()

# # model archtecture 
# input_ = tf.keras.layers.Input(shape=(28,28,1))
# x = tf.keras.layers.Conv2D(128,
#                            (3,3),
#                            padding='same',
#                            activation='relu',
#                           #  kernel_initializer='he_normal',
#                            kernel_regularizer=tf.keras.regularizers.l2(l=0.02),
#                            kernel_constraint=tf.keras.constraints.MinMaxNorm())(input_)
# x = tf.keras.layers.LayerNormalization()(x)
# x = tf.keras.layers.BatchNormalization()(x)
# x = tf.keras.layers.MaxPool2D()(x)
# x = tf.keras.layers.Dropout(0.4, trainable=True)(x)
# for i in range(10):
#   # highway block
#   h = tf.keras.layers.Conv2D(128,
#                              (3,3),
#                              padding='same',
#                              activation='relu',
#                             #  kernel_initializer='he_normal',
#                              kernel_regularizer=tf.keras.regularizers.l2(l=0.02),
#                              kernel_constraint=tf.keras.constraints.MinMaxNorm())(x)
#   c = tf.keras.layers.Conv2D(128, 
#                              (3,3), 
#                              padding='same',
#                              activation='sigmoid',
#                              kernel_regularizer=tf.keras.regularizers.l2(l=0.02),
#                              kernel_constraint=tf.keras.constraints.MinMaxNorm(),
#                              bias_initializer = tf.keras.initializers.Constant(-3.))(x)
#   hc = tf.keras.layers.Multiply()([h,c])
#   ad = tf.keras.layers.Lambda(lambda x: 1.0 - x)(c)

#   xc = tf.keras.layers.Multiply()([x, ad])
#   x = tf.keras.layers.Add()([hc, xc])  
#   if i == 7 :
#     x = tf.keras.layers.MaxPool2D()(x)
#   x = tf.keras.layers.Dropout(0.4, trainable=True)(x)
# x = tf.keras.layers.GlobalAvgPool2D()(x)
# x = tf.keras.layers.LayerNormalization()(x)
# x = tf.keras.layers.BatchNormalization()(x)
# x = tf.keras.layers.Dropout(0.4, trainable=True)(x)
# x = tf.keras.layers.Dense(128, activation='relu')(x)
# x = tf.keras.layers.LayerNormalization()(x)
# x = tf.keras.layers.BatchNormalization()(x)
# x = tf.keras.layers.Dropout(0.4, trainable=True)(x)
# out_ = tf.keras.layers.Dense(10, activation='softmax')(x)

In [0]:
# Function hw model

def hw_model(x):
  h = tf.keras.layers.Conv2D(128,
                              (3,3),
                              padding='same',
                              activation='relu',
                              # kernel_initializer='he_normal',
                              # kernel_regularizer=tf.keras.regularizers.l2(l=0.02),
                              # kernel_constraint=tf.keras.constraints.MinMaxNorm()
                              )(x)

  c = tf.keras.layers.Conv2D(128, 
                              (3,3), 
                              padding='same',
                              activation='sigmoid',
                            #  kernel_regularizer=tf.keras.regularizers.l2(l=0.02),
                            #  kernel_constraint=tf.keras.constraints.MinMaxNorm(),
                              bias_initializer = tf.keras.initializers.Constant(-3.))(x)

  hc = tf.keras.layers.Multiply()([h,c])
  ad = tf.keras.layers.Lambda(lambda x: 1.0 - x)(c)

  xc = tf.keras.layers.Multiply()([x, ad])
  x = tf.keras.layers.Add()([hc, xc])
  return x

In [0]:
class hwnet(tf.keras.layers.Layer):
  def __init__(self, units, **kwargs):
    super(hwnet, self).__init__(**kwargs)
    self.units = units

    self.h = tf.keras.layers.Conv2D(self.units, 3, padding='same')
    self.c = tf.keras.layers.Conv2D(self.units, 3, padding='same')  
  # @tf.function
  def call(self, x, traininig=True):
    h_x = self.h(x)
    h_x = tf.nn.relu(h_x)

    c_x = self.c(x)
    c_x = tf.nn.sigmoid(c_x)
    
    hc = c_x * h_x + x*(1. - c_x)
    
    return hc

In [0]:
my_layer = hwnet(128)

In [0]:
input_ = tf.keras.layers.Input(shape=(28,28,1))
x = my_layer(input_)

mode = tf.keras.models.Model(input_, x)

In [70]:
mode.summary()

Model: "model_10"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_14 (InputLayer)        [(None, 28, 28, 1)]       0         
_________________________________________________________________
hwnet_8 (hwnet)              (None, 28, 28, 128)       2560      
Total params: 2,560
Trainable params: 2,560
Non-trainable params: 0
_________________________________________________________________


In [0]:
input_ = tf.keras.layers.Input(shape=(28,28,1))
x = hw_model(input_)
mode2 = tf.keras.models.Model(input_, x)

In [65]:
mode2.summary()

Model: "model_8"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_12 (InputLayer)           [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 28, 28, 128)  1280        input_12[0][0]                   
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 28, 28, 128)  1280        input_12[0][0]                   
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 28, 28, 128)  0           conv2d_17[0][0]                  
____________________________________________________________________________________________

In [0]:
m1 = mode.trainable_variables

In [0]:
m2 = mode.trainable_variables

# Funtion Tester

In [0]:
# Function hw model

def hw_model(x):
  h = tf.keras.layers.Conv2D(128,
                              (3,3),
                              padding='same',
                              activation='relu',
                              # kernel_initializer='he_normal',
                              # kernel_regularizer=tf.keras.regularizers.l2(l=0.02),
                              # kernel_constraint=tf.keras.constraints.MinMaxNorm()
                              )(x)

  c = tf.keras.layers.Conv2D(128, 
                              (3,3), 
                              padding='same',
                              activation='sigmoid',
                            #  kernel_regularizer=tf.keras.regularizers.l2(l=0.02),
                            #  kernel_constraint=tf.keras.constraints.MinMaxNorm(),
                              bias_initializer = tf.keras.initializers.Constant(-3.))(x)

  hc = tf.keras.layers.Multiply()([h,c])
  ad = tf.keras.layers.Lambda(lambda x: 1.0 - x)(c)

  xc = tf.keras.layers.Multiply()([x, ad])
  x = tf.keras.layers.Add()([hc, xc])
  return x

In [83]:
tf.keras.backend.clear_session()

# model archtecture 
input_ = tf.keras.layers.Input(shape=(28,28,1))
x = tf.keras.layers.Conv2D(128,
                           (3,3),
                           padding='same',
                           activation='relu',
                          #  kernel_initializer='he_normal',
                           kernel_regularizer=tf.keras.regularizers.l2(l=0.02),
                           kernel_constraint=tf.keras.constraints.MinMaxNorm())(input_)
x = tf.keras.layers.LayerNormalization()(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.MaxPool2D()(x)
x = tf.keras.layers.Dropout(0.4, trainable=True)(x)
for i in range(10):
  # highway block
  x = hw_model(x)
  if i == 7 :
    x = tf.keras.layers.MaxPool2D()(x)
  x = tf.keras.layers.Dropout(0.4, trainable=True)(x)
x = tf.keras.layers.GlobalAvgPool2D()(x)
x = tf.keras.layers.LayerNormalization()(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Dropout(0.4, trainable=True)(x)
x = tf.keras.layers.Dense(128, activation='relu')(x)
x = tf.keras.layers.LayerNormalization()(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Dropout(0.4, trainable=True)(x)
out_ = tf.keras.layers.Dense(10, activation='softmax')(x)


# define model
model = tf.keras.models.Model(input_, out_)
model.summary()

# plot model
tf.keras.utils.plot_model(model, show_shapes=True, dpi=80)

# define loss
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()

# Set summary metrics
train_loss = tf.keras.metrics.Mean()
train_acc = tf.keras.metrics.SparseCategoricalAccuracy()

test_loss = tf.keras.metrics.Mean()
test_acc = tf.keras.metrics.SparseCategoricalAccuracy()


@tf.function
def train_step(img, label):
  with tf.GradientTape() as tape:
    pred = model(img, training=True)
    loss = loss_fn(label, pred)
  grad = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(grad, model.trainable_variables))
  train_loss(loss)
  train_acc(label, pred)

@tf.function
def test_step(img, label):
  pred = model(img, training=False)
  loss = loss_fn(label, pred)
  test_loss(loss)
  test_acc(label, pred)


# Train Loop
n_epochs = 50
for epoch in range(n_epochs):
  train_loss.reset_states()
  train_acc.reset_states()
  test_loss.reset_states()
  test_acc.reset_states()

  for img, lab in tqdm_notebook(train_data):
    train_step(img, lab)

  for img, lab in test_data:
    test_step(img, lab)

  template = 'Epoch:{} \t Loss:{} \t  Acc:{} \t Val_loss:{} \t Val_acc:{}'
  print(template.format(epoch, train_loss.result(), train_acc.result()*100, test_loss.result(), test_acc.result()*100))

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 28, 28, 128)  1280        input_1[0][0]                    
__________________________________________________________________________________________________
layer_normalization (LayerNorma (None, 28, 28, 128)  256         conv2d[0][0]                     
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 28, 28, 128)  512         layer_normalization[0][0]        
______________________________________________________________________________________________

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:0 	 Loss:2.642965078353882 	  Acc:16.518333435058594 	 Val_loss:2.0929808616638184 	 Val_acc:19.68000030517578


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:1 	 Loss:1.903296947479248 	  Acc:32.7066650390625 	 Val_loss:2.167893886566162 	 Val_acc:19.23000144958496


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:2 	 Loss:1.5378406047821045 	  Acc:45.04833221435547 	 Val_loss:1.3899304866790771 	 Val_acc:47.71000289916992


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:3 	 Loss:1.3074982166290283 	  Acc:53.393333435058594 	 Val_loss:1.2795355319976807 	 Val_acc:55.61000061035156


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:4 	 Loss:1.1674259901046753 	  Acc:58.666664123535156 	 Val_loss:0.7831487059593201 	 Val_acc:73.91999816894531


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:5 	 Loss:0.9686303734779358 	  Acc:66.34667205810547 	 Val_loss:0.5480973720550537 	 Val_acc:81.23999786376953


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:6 	 Loss:0.7643668055534363 	  Acc:73.61333465576172 	 Val_loss:0.49524030089378357 	 Val_acc:83.33000183105469


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:7 	 Loss:0.6487959623336792 	  Acc:78.1500015258789 	 Val_loss:0.3445304334163666 	 Val_acc:87.63999938964844


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:8 	 Loss:0.5412564873695374 	  Acc:81.91666412353516 	 Val_loss:0.2971004843711853 	 Val_acc:90.56999969482422


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:9 	 Loss:0.455912709236145 	  Acc:85.05500030517578 	 Val_loss:0.23441115021705627 	 Val_acc:91.98999786376953


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:10 	 Loss:0.789532482624054 	  Acc:75.36833190917969 	 Val_loss:2.5036356449127197 	 Val_acc:19.14000129699707


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:11 	 Loss:1.6471704244613647 	  Acc:42.474998474121094 	 Val_loss:1.679119348526001 	 Val_acc:36.5099983215332


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:12 	 Loss:1.2444698810577393 	  Acc:56.01166534423828 	 Val_loss:0.8481962084770203 	 Val_acc:73.72999572753906


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:13 	 Loss:1.0644267797470093 	  Acc:62.79500198364258 	 Val_loss:0.6400256752967834 	 Val_acc:79.16999816894531


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:14 	 Loss:0.8741239905357361 	  Acc:69.96499633789062 	 Val_loss:0.5900257229804993 	 Val_acc:81.29000091552734


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:15 	 Loss:0.8600128293037415 	  Acc:70.54332733154297 	 Val_loss:0.44141992926597595 	 Val_acc:86.95999908447266


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:16 	 Loss:0.8439032435417175 	  Acc:70.93333435058594 	 Val_loss:0.47154608368873596 	 Val_acc:83.8800048828125


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:17 	 Loss:0.7046555876731873 	  Acc:76.11499786376953 	 Val_loss:0.4073246419429779 	 Val_acc:86.86000061035156


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:18 	 Loss:0.6310387253761292 	  Acc:78.80166625976562 	 Val_loss:0.29576513171195984 	 Val_acc:91.02999877929688


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:19 	 Loss:0.5586283206939697 	  Acc:81.58833312988281 	 Val_loss:0.3238520324230194 	 Val_acc:89.20000457763672


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:20 	 Loss:0.6046640276908875 	  Acc:80.23333740234375 	 Val_loss:0.3624538779258728 	 Val_acc:89.16000366210938


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:21 	 Loss:0.6869674921035767 	  Acc:77.2066650390625 	 Val_loss:0.3456323742866516 	 Val_acc:89.3800048828125


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:22 	 Loss:0.6846669316291809 	  Acc:77.14666748046875 	 Val_loss:0.31931012868881226 	 Val_acc:90.33000183105469


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:23 	 Loss:0.5821357369422913 	  Acc:80.7266616821289 	 Val_loss:0.2704907953739166 	 Val_acc:91.83999633789062


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:24 	 Loss:0.5284662246704102 	  Acc:82.49666595458984 	 Val_loss:0.3051919937133789 	 Val_acc:90.05000305175781


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:25 	 Loss:0.5624176859855652 	  Acc:81.21833038330078 	 Val_loss:0.3641680181026459 	 Val_acc:88.63999938964844


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:26 	 Loss:0.5315314531326294 	  Acc:82.53499603271484 	 Val_loss:0.24204476177692413 	 Val_acc:92.5199966430664


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:27 	 Loss:0.495760440826416 	  Acc:83.79499816894531 	 Val_loss:0.2892589271068573 	 Val_acc:90.77999877929688


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:28 	 Loss:0.6147192120552063 	  Acc:79.44166564941406 	 Val_loss:0.4382011890411377 	 Val_acc:84.73999786376953


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:29 	 Loss:0.7670836448669434 	  Acc:74.13333129882812 	 Val_loss:0.3854840397834778 	 Val_acc:89.01000213623047


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:30 	 Loss:0.6476397514343262 	  Acc:78.49833679199219 	 Val_loss:0.34474989771842957 	 Val_acc:89.27000427246094


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:31 	 Loss:0.5590430498123169 	  Acc:81.76333618164062 	 Val_loss:0.33360856771469116 	 Val_acc:89.29000091552734


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:32 	 Loss:0.6418601870536804 	  Acc:79.10333251953125 	 Val_loss:0.34226471185684204 	 Val_acc:89.63999938964844


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:33 	 Loss:0.7178331613540649 	  Acc:75.84833526611328 	 Val_loss:0.2611663043498993 	 Val_acc:92.51000213623047


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:34 	 Loss:0.7161076664924622 	  Acc:76.06499481201172 	 Val_loss:0.2877451777458191 	 Val_acc:91.81999969482422


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:35 	 Loss:0.7491903305053711 	  Acc:74.75666046142578 	 Val_loss:0.4323233664035797 	 Val_acc:86.58999633789062


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:36 	 Loss:0.7263030409812927 	  Acc:75.27833557128906 	 Val_loss:0.33513742685317993 	 Val_acc:89.81999969482422


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:37 	 Loss:0.7693406939506531 	  Acc:74.038330078125 	 Val_loss:0.5741264224052429 	 Val_acc:82.20000457763672


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:38 	 Loss:0.7516979575157166 	  Acc:74.70832824707031 	 Val_loss:0.4764224886894226 	 Val_acc:84.1500015258789


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:39 	 Loss:0.9387038946151733 	  Acc:67.68333435058594 	 Val_loss:0.4024984836578369 	 Val_acc:87.58999633789062


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:40 	 Loss:0.6521201729774475 	  Acc:78.24666595458984 	 Val_loss:0.3042925000190735 	 Val_acc:90.76000213623047


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:41 	 Loss:0.6206641793251038 	  Acc:79.42500305175781 	 Val_loss:0.2974497973918915 	 Val_acc:90.87999725341797


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:42 	 Loss:0.5462089776992798 	  Acc:81.9816665649414 	 Val_loss:0.2336399257183075 	 Val_acc:93.01000213623047


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:43 	 Loss:0.5277111530303955 	  Acc:82.5683364868164 	 Val_loss:0.2246612310409546 	 Val_acc:93.16999816894531


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:44 	 Loss:0.5225438475608826 	  Acc:82.88499450683594 	 Val_loss:0.23085328936576843 	 Val_acc:92.90999603271484


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:45 	 Loss:0.5355451703071594 	  Acc:82.41500091552734 	 Val_loss:0.2166265994310379 	 Val_acc:93.5


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:46 	 Loss:0.5657757520675659 	  Acc:81.55999755859375 	 Val_loss:0.2879483997821808 	 Val_acc:90.83999633789062


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:47 	 Loss:0.5235001444816589 	  Acc:82.99833679199219 	 Val_loss:0.25443035364151 	 Val_acc:91.88999938964844


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:48 	 Loss:0.5514200329780579 	  Acc:81.91666412353516 	 Val_loss:0.4504472315311432 	 Val_acc:85.25


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:49 	 Loss:0.5988316535949707 	  Acc:80.40666198730469 	 Val_loss:0.24089853465557098 	 Val_acc:92.51000213623047


In [0]:
class hwnet(tf.keras.layers.Layer):
  def __init__(self, units, **kwargs):
    super(hwnet, self).__init__(**kwargs)
    self.units = units

    self.h = tf.keras.layers.Conv2D(self.units, 3, padding='same')
    self.c = tf.keras.layers.Conv2D(self.units, 3, padding='same', bias_initializer=tf.keras.initializers.Constant(-3.))  
  # @tf.function
  def call(self, x, traininig=True):
    h_x = self.h(x)
    h_x = tf.nn.relu(h_x)

    c_x = self.c(x)
    c_x = tf.nn.sigmoid(c_x)
    
    hc = c_x * h_x + x*(1. - c_x)
    
    return hc

In [86]:
tf.keras.backend.clear_session()

# model archtecture 
input_ = tf.keras.layers.Input(shape=(28,28,1))
x = tf.keras.layers.Conv2D(128,
                           (3,3),
                           padding='same',
                           activation='relu',
                          #  kernel_initializer='he_normal',
                           kernel_regularizer=tf.keras.regularizers.l2(l=0.02),
                           kernel_constraint=tf.keras.constraints.MinMaxNorm())(input_)
x = tf.keras.layers.LayerNormalization()(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.MaxPool2D()(x)
x = tf.keras.layers.Dropout(0.4, trainable=True)(x)
for i in range(10):
  # highway block
  x = hwnet(128)(x)
  if i == 7 :
    x = tf.keras.layers.MaxPool2D()(x)
  x = tf.keras.layers.Dropout(0.4, trainable=True)(x)
x = tf.keras.layers.GlobalAvgPool2D()(x)
x = tf.keras.layers.LayerNormalization()(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Dropout(0.4, trainable=True)(x)
x = tf.keras.layers.Dense(128, activation='relu')(x)
x = tf.keras.layers.LayerNormalization()(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Dropout(0.4, trainable=True)(x)
out_ = tf.keras.layers.Dense(10, activation='softmax')(x)


# define model
model = tf.keras.models.Model(input_, out_)
model.summary()

# plot model
tf.keras.utils.plot_model(model, show_shapes=True, dpi=80)

# define loss
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()

# Set summary metrics
train_loss = tf.keras.metrics.Mean()
train_acc = tf.keras.metrics.SparseCategoricalAccuracy()

test_loss = tf.keras.metrics.Mean()
test_acc = tf.keras.metrics.SparseCategoricalAccuracy()


@tf.function
def train_step(img, label):
  with tf.GradientTape() as tape:
    pred = model(img, training=True)
    loss = loss_fn(label, pred)
  grad = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(grad, model.trainable_variables))
  train_loss(loss)
  train_acc(label, pred)

@tf.function
def test_step(img, label):
  pred = model(img, training=False)
  loss = loss_fn(label, pred)
  test_loss(loss)
  test_acc(label, pred)


# Train Loop
n_epochs = 50
for epoch in range(n_epochs):
  train_loss.reset_states()
  train_acc.reset_states()
  test_loss.reset_states()
  test_acc.reset_states()

  for img, lab in tqdm_notebook(train_data):
    train_step(img, lab)

  for img, lab in test_data:
    test_step(img, lab)

  template = 'Epoch:{} \t Loss:{} \t  Acc:{} \t Val_loss:{} \t Val_acc:{}'
  print(template.format(epoch, train_loss.result(), train_acc.result()*100, test_loss.result(), test_acc.result()*100))

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
conv2d (Conv2D)              (None, 28, 28, 128)       1280      
_________________________________________________________________
layer_normalization (LayerNo (None, 28, 28, 128)       256       
_________________________________________________________________
batch_normalization (BatchNo (None, 28, 28, 128)       512       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 128)       0         
_________________________________________________________________
dropout (Dropout)            (None, 14, 14, 128)       0         
_________________________________________________________________
hwnet (hwnet)                (None, 14, 14, 128)       295168

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:0 	 Loss:2.6473326683044434 	  Acc:17.434999465942383 	 Val_loss:2.3466312885284424 	 Val_acc:9.809999465942383


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:1 	 Loss:1.8086100816726685 	  Acc:36.46666717529297 	 Val_loss:1.9726921319961548 	 Val_acc:23.3700008392334


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:2 	 Loss:1.4664957523345947 	  Acc:47.71666717529297 	 Val_loss:1.1578290462493896 	 Val_acc:57.06999969482422


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:3 	 Loss:1.2535818815231323 	  Acc:55.323333740234375 	 Val_loss:0.8606882095336914 	 Val_acc:71.81999969482422


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:4 	 Loss:1.1013411283493042 	  Acc:61.165000915527344 	 Val_loss:0.6803735494613647 	 Val_acc:76.97000122070312


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:5 	 Loss:1.0338962078094482 	  Acc:63.42832946777344 	 Val_loss:0.6533309817314148 	 Val_acc:78.58999633789062


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:6 	 Loss:0.9273470640182495 	  Acc:67.62166595458984 	 Val_loss:0.6247683763504028 	 Val_acc:79.06999969482422


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:7 	 Loss:0.8501074314117432 	  Acc:70.94499969482422 	 Val_loss:0.5704400539398193 	 Val_acc:80.45999908447266


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:8 	 Loss:0.7758262753486633 	  Acc:73.61166381835938 	 Val_loss:0.38786208629608154 	 Val_acc:88.56999969482422


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:9 	 Loss:0.6718999743461609 	  Acc:77.61499786376953 	 Val_loss:0.3304932415485382 	 Val_acc:89.56000518798828


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:10 	 Loss:0.576257586479187 	  Acc:81.26499938964844 	 Val_loss:0.34981769323349 	 Val_acc:88.5999984741211


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:11 	 Loss:0.5301187038421631 	  Acc:82.88833618164062 	 Val_loss:0.32375091314315796 	 Val_acc:88.84000396728516


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:12 	 Loss:0.4824641942977905 	  Acc:84.64166259765625 	 Val_loss:0.23028278350830078 	 Val_acc:92.80999755859375


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:13 	 Loss:0.4612663686275482 	  Acc:85.47666931152344 	 Val_loss:0.2029246836900711 	 Val_acc:93.69000244140625


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:14 	 Loss:0.4403134882450104 	  Acc:86.038330078125 	 Val_loss:0.20095381140708923 	 Val_acc:93.73999786376953


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:15 	 Loss:0.6911023259162903 	  Acc:77.67333221435547 	 Val_loss:0.3652738928794861 	 Val_acc:89.0999984741211


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:16 	 Loss:0.717606782913208 	  Acc:75.97000122070312 	 Val_loss:0.3797578513622284 	 Val_acc:89.0


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:17 	 Loss:1.1756348609924316 	  Acc:59.323333740234375 	 Val_loss:1.114208459854126 	 Val_acc:60.68000030517578


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:18 	 Loss:0.9983742833137512 	  Acc:65.15666961669922 	 Val_loss:0.5188337564468384 	 Val_acc:83.63999938964844


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:19 	 Loss:0.8963980674743652 	  Acc:69.27999877929688 	 Val_loss:0.45660948753356934 	 Val_acc:86.12999725341797


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:20 	 Loss:0.9277284145355225 	  Acc:68.16333770751953 	 Val_loss:0.6060566902160645 	 Val_acc:82.04000091552734


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:21 	 Loss:0.883016049861908 	  Acc:69.82167053222656 	 Val_loss:0.4883461594581604 	 Val_acc:83.99000549316406


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:22 	 Loss:0.8914530873298645 	  Acc:69.72167205810547 	 Val_loss:0.7035316824913025 	 Val_acc:76.29000091552734


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:23 	 Loss:0.799427330493927 	  Acc:72.91166687011719 	 Val_loss:0.8466266393661499 	 Val_acc:70.4000015258789


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:24 	 Loss:0.7693642973899841 	  Acc:74.05166625976562 	 Val_loss:0.4726037383079529 	 Val_acc:84.22000122070312


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:25 	 Loss:0.8799980878829956 	  Acc:69.99333190917969 	 Val_loss:0.5588588714599609 	 Val_acc:83.25


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:26 	 Loss:0.8083878755569458 	  Acc:72.18499755859375 	 Val_loss:0.46384748816490173 	 Val_acc:85.12999725341797


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:27 	 Loss:0.741334855556488 	  Acc:74.76333618164062 	 Val_loss:0.40709638595581055 	 Val_acc:87.51000213623047


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:28 	 Loss:0.8369259834289551 	  Acc:71.77666473388672 	 Val_loss:0.4306454062461853 	 Val_acc:86.68000030517578


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:29 	 Loss:0.7364579439163208 	  Acc:74.94999694824219 	 Val_loss:0.40234971046447754 	 Val_acc:86.47999572753906


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:30 	 Loss:0.7125886678695679 	  Acc:75.63999938964844 	 Val_loss:0.38391047716140747 	 Val_acc:87.29000091552734


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:31 	 Loss:0.6866780519485474 	  Acc:76.78499603271484 	 Val_loss:0.35703861713409424 	 Val_acc:88.61000061035156


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:32 	 Loss:0.7779655456542969 	  Acc:73.58333587646484 	 Val_loss:0.5593663454055786 	 Val_acc:81.2300033569336


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:33 	 Loss:0.7895306348800659 	  Acc:72.97833251953125 	 Val_loss:1.2862141132354736 	 Val_acc:57.970001220703125


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:34 	 Loss:0.954068124294281 	  Acc:67.25166320800781 	 Val_loss:0.5650211572647095 	 Val_acc:80.36000061035156


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:35 	 Loss:0.7573732137680054 	  Acc:74.1199951171875 	 Val_loss:0.39214983582496643 	 Val_acc:87.08999633789062


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:36 	 Loss:0.6938803195953369 	  Acc:76.65332794189453 	 Val_loss:0.38357672095298767 	 Val_acc:87.58000183105469


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:37 	 Loss:0.7115454077720642 	  Acc:75.95999908447266 	 Val_loss:0.3528289198875427 	 Val_acc:88.84000396728516


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:38 	 Loss:0.678758442401886 	  Acc:77.11499786376953 	 Val_loss:0.3773038983345032 	 Val_acc:87.66000366210938


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:39 	 Loss:0.7763702273368835 	  Acc:73.34832763671875 	 Val_loss:0.4084215760231018 	 Val_acc:87.81999969482422


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:40 	 Loss:0.6844228506088257 	  Acc:76.72166442871094 	 Val_loss:0.36107179522514343 	 Val_acc:88.37000274658203


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:41 	 Loss:0.6095004677772522 	  Acc:79.45832824707031 	 Val_loss:0.32030394673347473 	 Val_acc:89.70999908447266


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:42 	 Loss:0.6415389776229858 	  Acc:78.30000305175781 	 Val_loss:0.32906192541122437 	 Val_acc:89.12000274658203


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:43 	 Loss:0.6166014671325684 	  Acc:79.30500030517578 	 Val_loss:0.3119105398654938 	 Val_acc:89.67000579833984


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:44 	 Loss:0.603192150592804 	  Acc:79.95500183105469 	 Val_loss:0.3686971068382263 	 Val_acc:87.44000244140625


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:45 	 Loss:0.6148582696914673 	  Acc:79.45999908447266 	 Val_loss:0.2896623909473419 	 Val_acc:91.00999450683594


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:46 	 Loss:0.5815957188606262 	  Acc:80.47000122070312 	 Val_loss:0.3121838867664337 	 Val_acc:89.80000305175781


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:47 	 Loss:0.5675150156021118 	  Acc:81.28166961669922 	 Val_loss:0.2772147059440613 	 Val_acc:91.1199951171875


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:48 	 Loss:0.6114705801010132 	  Acc:79.69667053222656 	 Val_loss:0.28582069277763367 	 Val_acc:90.93000030517578


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Epoch:49 	 Loss:0.5816935896873474 	  Acc:80.68666076660156 	 Val_loss:0.2678208351135254 	 Val_acc:91.75
