In [112]:
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow import keras

In [113]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train.shape, y_train.shape, x_test.shape, y_test.shape

((60000, 28, 28), (60000,), (10000, 28, 28), (10000,))

In [103]:
# x_train = tf.expand_dims(x_train, -1)
# x_test = tf.expand_dims(x_test, -1)
# x_train.shape,y_train.shape, x_test.shape, y_test.shape

(TensorShape([60000, 28, 28, 1]),
 (60000,),
 TensorShape([10000, 28, 28, 1]),
 (10000,))

In [114]:
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(64).prefetch(tf.data.AUTOTUNE)
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(64).prefetch(tf.data.AUTOTUNE)

In [106]:
class CustomDense(keras.layers.Layer):
    def __init__(self, input_shape, units):
        super().__init__()
        self.w = self.add_weight(name='w', initializer='random_normal', trainable=True, shape=(input_shape, units))
        self.b = self.add_weight(name='b', initializer='zeros', trainable=True, shape=(units,))
        
    def call(self, x):
        return tf.matmul(x, self.w) + self.b

In [115]:
# class CustomCNN(keras.layers.Layer):
#     def __init__(self, filters, kernel_size=3):
#         super().__init__()
#         self.cnn = keras.layers.Conv2D(filters=filters, kernel_size=kernel_size, padding='same')
#         self.bn = keras.layers.BatchNormalization()
        
#     def call(self, x, training=False):
#         x = self.cnn(x)
#         x = self.bn(x, training=training)
#         return tf.nn.relu(x)

class CustomCNN(keras.layers.Layer):
    def __init__(self, filters, kernel_size=3):
        super().__init__()
        self.conv = keras.layers.Conv2D(filters, kernel_size=kernel_size, padding='same')
        self.bn = keras.layers.BatchNormalization()
        
    def call(self, input_tensor, training=False):
        x = self.conv(input_tensor)
        x = self.bn(x, training=training)
        x = tf.nn.relu(x)
        return x
    def model(self):
        x = keras.Input()
        return keras.Model(inputs=[x], outputs=self.call(x))

In [117]:
class MyModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.cnn1 = CustomCNN(32)
        self.cnn2 = CustomCNN(64)
        self.cnn3 = CustomCNN(128)
        self.flatten = keras.layers.Flatten()
        self.dense1 = CustomDense(28*28, 64)
        self.dense2 = keras.layers.Dense(10)
#         self.input_data_shape = 
        
    def call(self, x):
#         x = keras.Input(shape=(28,28, 1))(x)
#         x = self.cnn1(x)
#         x = self.cnn2(x)
#         x = self.cnn3(x)
        x = self.flatten(x)
        x = tf.cast(x, dtype=float)
        x = x/255.
        x = tf.nn.relu(self.dense1(x))
        return self.dense2(x)
    
    def model(self):
        x = keras.layers.Input(shape=(28*28))
        return keras.Model(inputs=[x], outputs=self.call(x))

In [118]:
model = MyModel()

In [119]:
model.model().summary()

Model: "model_5"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_18 (InputLayer)       [(None, 784)]             0         
                                                                 
 flatten_16 (Flatten)        (None, 784)               0         
                                                                 
 tf.cast_5 (TFOpLambda)      (None, 784)               0         
                                                                 
 tf.math.truediv_5 (TFOpLamb  (None, 784)              0         
 da)                                                             
                                                                 
 custom_dense_14 (CustomDens  (None, 64)               50240     
 e)                                                              
                                                                 
 tf.nn.relu_5 (TFOpLambda)   (None, 64)                0   

In [120]:
# model.compile(optimizer=keras.optimizers.Adam(), loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), 
#               metrics=['accuracy'])

In [121]:
# model.fit(train_dataset, validation_data=test_dataset, epochs=5, batch_size=32, verbose=1)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x21c03cdb400>

In [175]:
class CustomFit(keras.Model):
    def __init__(self, model):
        super().__init__()
        self.model = model
        
    def compile(self, loss, optimizer, metric):
        super(CustomFit, self).compile()
        self.optimizer = optimizer
        self.metric = metric
        self.loss = loss
        
    def train_step(self, data):
        x, y = data
        
        with tf.GradientTape() as tape:
            y_pred = self.model(x, training=True)
            y_pred = tf.cast(y_pred, tf.float16)
            y = tf.cast(y, tf.float16)
#             print('y type:', type(y))
#             print(y)
#             print('y pred type:', type(y_pred))
#             print(y_pred)
            loss = self.loss(y, y_pred)
        trainable_variables = self.trainable_variables
        grad = tape.gradient(loss, trainable_variables)
        
        self.optimizer.apply_gradients(zip(grad, trainable_variables))
        self.metric.update_state(y, y_pred)
        
        return {'train loss:':loss, 'train accuracy:':self.metric.result()}
    
    def test_step(self, data):
        x, y = data
        y_pred = self.model(x, training=False)
        loss = self.loss(y, y_pred)
        self.metric.update_state(y, y_pred)
        return {'test loss:':loss, 'test accuracy:':self.metric.result()}
        
        

In [176]:
cust_train = CustomFit(model)

In [177]:
cust_train.compile(optimizer=keras.optimizers.Adam(learning_rate=3e-4),
                   loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                   metric=keras.metrics.SparseCategoricalAccuracy(name='accuracy')
                  )

In [178]:
cust_train.fit(train_dataset, validation_data=test_dataset, batch_size=32, epochs=2)

Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x21c12037070>