In [1]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
import pandas as pd
from tensorflow.keras.datasets import mnist

In [6]:
class MyModel(keras.Model):
    """Two-layer fully connected classifier operating on raw pixel images.

    The forward pass flattens the input, casts it to float32, rescales it by
    1/255, applies a 64-unit dense layer followed by ReLU, and finishes with
    a dense layer of ``output`` units. No softmax is applied, so the returned
    values are unnormalised scores (logits).

    Args:
        output: Number of units in the final dense layer (default 10).
    """

    def __init__(self, output=10):
        super().__init__()
        # `keras.layers` has no `CustomDense` attribute, so the previous code
        # raised AttributeError on construction. The built-in Dense layer is
        # what the recorded summary ("dense_4 (Dense)", 50240 params) shows.
        self.dense1 = keras.layers.Dense(64)
        self.dense2 = keras.layers.Dense(output)
        self.flatten = keras.layers.Flatten()

    def call(self, x):
        """Run the forward pass and return the logits for ``x``.

        Args:
            x: Batch of inputs; anything past the batch axis is flattened.

        Returns:
            Tensor of shape ``(batch, output)`` holding unnormalised scores.
        """
        x = self.flatten(x)
        # Explicit float32 (this is what the Python `float` type maps to in
        # TensorFlow) so the division below is done in floating point.
        x = tf.cast(x, dtype=tf.float32)
        x = x/255.
        # Debug output. NOTE(review): the recorded training log shows these
        # lines only at the start of the first epoch, presumably because the
        # Python print runs when the function is traced rather than on every
        # batch -- use tf.print if per-batch output is wanted.
        print('X Shape:')
        print(x.shape)
        x = tf.nn.relu(self.dense1(x))
        return self.dense2(x)

    def model(self):
        """Wrap the forward pass in a functional ``keras.Model``.

        Subclassed models do not expose a per-layer summary until they are
        built; routing a symbolic input through ``call`` yields a functional
        model whose ``summary()`` lists every operation.

        Returns:
            A ``keras.Model`` mapping a ``(None, 784)`` input to the logits.
        """
        # The trailing comma matters: `(28*28)` is just the integer 784,
        # whereas `shape` is documented as a tuple.
        x = keras.layers.Input(shape=(28*28,))
        return keras.Model(inputs=[x], outputs=self.call(x))

In [7]:
# Instantiate the classifier with ten output units.
model = MyModel(output=10)
# Print a layer-by-layer summary of the functional wrapper built by `MyModel.model()`.
model.model().summary()

X Shape:
(None, 784)
Model: "model_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_3 (InputLayer)        [(None, 784)]             0         
                                                                 
 flatten_1 (Flatten)         (None, 784)               0         
                                                                 
 tf.cast_1 (TFOpLambda)      (None, 784)               0         
                                                                 
 tf.math.truediv_1 (TFOpLamb  (None, 784)              0         
 da)                                                             
                                                                 
 dense_4 (Dense)             (None, 64)                50240     
                                                                 
 tf.nn.relu_2 (TFOpLambda)   (None, 64)                0         
                                      

In [31]:
# Load the MNIST arrays as (train, test) pairs of images and labels.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Wrap the test split in a tf.data pipeline that yields batches of 32;
# it is passed to `fit` as the validation data further down.
val_dataset = (
    tf.data.Dataset
    .from_tensor_slices((x_test, y_test))
    .batch(32)
)

In [33]:
# Configure training: Adam with a 0.01 learning rate, sparse categorical
# cross-entropy computed on logits (from_logits=True), and accuracy as the
# tracked metric.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.01),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

In [34]:
# Train for five epochs in mini-batches of 32, validating against
# `val_dataset`; the call's return value is displayed as the cell output.
model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=5,
    verbose=1,
    validation_data=val_dataset,
)

Epoch 1/5
X Shape:
(32, 784)
X Shape:
(32, 784)
(None, 784)
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x28a31c0cf40>

In [35]:
class CustomDense(keras.layers.Layer):
    """Minimal fully connected layer computing ``x @ w + b``.

    Args:
        units: Number of output features.
        input_shape: Optional size of the last input axis. When given, the
            weights are created immediately in the constructor (the original
            behaviour, so ``layer.w`` / ``layer.b`` are available right after
            construction). When omitted, weight creation is deferred to
            ``build`` and the size is taken from the first input seen.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, units, input_shape=None, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        # Placeholders so `build` can tell whether weights already exist.
        self.w = None
        self.b = None
        if input_shape is not None:
            self._init_kernel_and_bias(input_shape)

    def _init_kernel_and_bias(self, input_dim):
        """Create the kernel ``w`` (input_dim, units) and bias ``b`` (units,)."""
        self.w = self.add_weight(
            name='w',
            shape=(input_dim, self.units),
            initializer='random_normal',
            trainable=True)

        self.b = self.add_weight(
            name='b',
            initializer='zeros',
            shape=(self.units,),
            trainable=True)

    def build(self, input_shape):
        """Create the weights lazily if the constructor did not.

        Raises:
            ValueError: If the last input dimension is unknown.
        """
        if self.w is None:
            last_dim = input_shape[-1]
            if last_dim is None:
                raise ValueError(
                    'CustomDense needs a defined last input dimension, '
                    'or an explicit `input_shape` at construction.')
            self._init_kernel_and_bias(int(last_dim))
        super().build(input_shape)

    def call(self, x):
        """Return the affine transform of ``x``: ``tf.matmul(x, w) + b``."""
        return tf.matmul(x, self.w) + self.b

In [36]:
# Stand-alone layer with 10 output units for inputs whose last axis has 784 entries.
dense = CustomDense(units=10, input_shape=784)

In [40]:
# Display the layer's bias variable; it is created with the 'zeros'
# initializer and shape (units,), as the recorded output confirms.
dense.b

<tf.Variable 'b:0' shape=(10,) dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>