# In [0]:
import tensorflow as tf
import keras # requires to remove later
from tensorflow import keras as keras
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Print the version string of the imported TensorFlow package.
print(tf.__version__)

class OperationLayer(keras.layers.Layer):
    """Gated ("highway"-style) 2-D convolution layer.

    Two convolutions are applied to the same input: a transform branch ``h``
    (with the configured activation) and a gate branch ``c`` (squashed with a
    sigmoid). The output is the element-wise mix ``h * c + x * (1 - c)``.

    NOTE(review): the mix is purely element-wise, so the input is assumed to
    have the same spatial size and channel count as the convolution outputs
    (e.g. ``padding='same'`` and ``units`` input channels) — confirm at call
    sites.

    Args:
        units: number of filters for both convolutions.
        kernel_size: kernel size forwarded to both ``Conv2D`` layers.
        padding: padding mode forwarded to both ``Conv2D`` layers.
        activation: activation identifier (name or callable) for the
            transform branch; resolved with ``keras.activations.get``.
        **kwargs: forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, units, kernel_size, padding, activation, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.kernel_size = kernel_size
        self.padding = padding
        self.activation = keras.activations.get(activation)

        # Transform branch: activated convolution of the input.
        self.h = keras.layers.Conv2D(filters=self.units,
                                     kernel_size=self.kernel_size,
                                     padding=self.padding,
                                     activation=self.activation)

        # Gate branch: linear convolution, sigmoid is applied in `call`.
        # The bias starts at -3 (sigmoid(-3) is about 0.05), which biases the
        # gate towards passing the input through unchanged.
        self.c = keras.layers.Conv2D(filters=self.units,
                                     kernel_size=self.kernel_size,
                                     padding=self.padding,
                                     activation=None,
                                     bias_initializer=keras.initializers.Constant(-3.))

    @tf.function
    def call(self, x):
        """Return the gated mix of the transformed input and the input itself."""
        h = self.h(x)
        c = keras.activations.sigmoid(self.c(x))

        return h * c + x * (1. - c)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created.

        Without this, a layer with required constructor arguments cannot be
        rebuilt from its config when a model containing it is saved or cloned.
        """
        config = super().get_config()
        config.update({
            'units': self.units,
            'kernel_size': self.kernel_size,
            'padding': self.padding,
            'activation': keras.activations.serialize(self.activation),
        })
        return config

class HWModel(keras.Model):
    """Convolutional classifier built around a stack of gated layers.

    Pipeline: two ReLU convolutions (7x7 then 5x5, 256 filters each), five
    ``OperationLayer`` blocks, one more 3x3 ReLU convolution, global average
    pooling, and a 10-way softmax ``Dense`` head.

    Args:
        **kwargs: forwarded to ``keras.Model``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.c1 = keras.layers.Conv2D(256, 7, 1, 'same', activation='relu')
        self.c2 = keras.layers.Conv2D(256, 5, 1, 'same', activation='relu')
        self.h1 = [OperationLayer(256, 3, 'same', 'relu') for _ in range(5)]
        self.c3 = keras.layers.Conv2D(256, 3, 1, 'same', activation='relu')
        # Created once here rather than inside `call`, so the same layer
        # object is reused instead of a new one being built on every trace.
        self.pool = keras.layers.GlobalAvgPool2D()
        self.out = keras.layers.Dense(10, activation='softmax')

    @tf.function
    def call(self, x):
        """Run the forward pass and return the softmax output of the head."""
        z = self.c1(x)
        z = self.c2(z)
        for layer in self.h1:
            z = layer(z)
        z = self.c3(z)
        z = self.pool(z)
        return self.out(z)


# Load the MNIST dataset and unpack it into train/test (inputs, labels) pairs.
# NOTE: an import statement resolves the top-level `keras` package, not the
# `tensorflow.keras` module that the name `keras` was rebound to above.
from keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

@tf.function
def img_preprocess(x, y):
    """Convert one raw (image, label) sample into float32 tensors.

    Args:
        x: image tensor without a channel axis (presumably integer pixel
            values in [0, 255] — TODO confirm against the data source).
        y: label tensor.

    Returns:
        A ``(image, label)`` tuple: the image cast to float32, divided by 255
        and given a trailing channel axis; the label cast to float32.
    """
    x = tf.cast(x, tf.float32) / 255.
    # Append the channel axis last (-1) instead of at a hard-coded index 2:
    # identical for a rank-2 image, and still correct if a leading batch
    # axis is present.
    x = tf.expand_dims(x, -1)
    return x, tf.cast(y, tf.float32)


# Build the input pipelines. Training samples are shuffled (buffer of 1024)
# before preprocessing; both splits are preprocessed per sample and then
# grouped into batches of 128.
train_loader = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(1024)
    .map(img_preprocess)
    .batch(128)
)

test_loader = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .map(img_preprocess)
    .batch(128)
)


# Instantiate the network and call it once on a symbolic 28x28x1 input; the
# input/output tensors are additionally wrapped in a functional Model.
model = HWModel()

input_ = keras.layers.Input(shape=(28, 28, 1))
x = model(input_)
f_model = keras.Model(inputs=input_, outputs=x)




# Custom training loop setup: loss function and optimizer.
loss_fn = keras.losses.SparseCategoricalCrossentropy()
optim = keras.optimizers.Adam()

# Running metrics, tracked separately for the training and test passes.
train_loss, test_loss = keras.metrics.Mean(), keras.metrics.Mean()
train_acc, test_acc = (keras.metrics.SparseCategoricalAccuracy(),
                       keras.metrics.SparseCategoricalAccuracy())

@tf.function
def train_step(img, label):
    """Run one optimizer update on a batch and record the training metrics.

    Args:
        img: batch of inputs passed to ``model``.
        label: batch of targets passed to ``loss_fn`` and ``train_acc``.
    """
    with tf.GradientTape() as tape:
        predictions = model(img, training=True)
        batch_loss = loss_fn(label, predictions)

    # Read the variable list after the forward pass and reuse it for both
    # the gradient computation and the optimizer update.
    variables = model.trainable_variables
    gradients = tape.gradient(batch_loss, variables)
    optim.apply_gradients(zip(gradients, variables))

    train_loss(batch_loss)
    train_acc(label, predictions)
    
@tf.function
def test_step(img, label):
    """Evaluate one batch and record the test metrics (no weight update).

    Args:
        img: batch of inputs passed to ``model``.
        label: batch of targets passed to ``loss_fn`` and ``test_acc``.
    """
    # Request inference mode explicitly, mirroring `training=True` in
    # `train_step`, instead of relying on Keras' implicit default.
    pred = model(img, training=False)
    loss = loss_fn(label, pred)
    test_loss(loss)
    test_acc(label, pred)


n_epochs = 100

# Metrics accumulate over an epoch, so all of them are cleared at its start.
epoch_metrics = (train_loss, train_acc, test_loss, test_acc)

for epoch in range(n_epochs):

    for metric in epoch_metrics:
        # Newer Keras releases expose `reset_state`; older ones only provide
        # the deprecated `reset_states`. Use whichever is available.
        reset = getattr(metric, 'reset_state', None) or metric.reset_states
        reset()

    for img, label in train_loader:
        train_step(img, label)

    for img, label in test_loader:
        test_step(img, label)

    # Convert the metric tensors to Python floats so the log shows plain
    # numbers, and build the message from adjacent literals so no source
    # indentation leaks into the output.
    print(
        f'Epoch : {epoch} \t'
        f'Train Loss : {float(train_loss.result()):.4f} \t'
        f'Train Acc : {float(train_acc.result()) * 100:.2f} \t'
        f'Test Loss : {float(test_loss.result()):.4f} \t'
        f'Test Acc : {float(test_acc.result()) * 100:.2f}'
    )
 