In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Layer

class SimpleDense(Layer):

  def __init__(self, units=32, activation=None):
      super(SimpleDense, self).__init__()
      self.units = units
      self.activation=tf.keras.activations.get(activation)

  def build(self, input_shape):  # Create the state of the layer (weights)
    w_init = tf.random_normal_initializer()
    self.w = tf.Variable(initial_value=w_init(shape=(input_shape[-1], self.units),
                             dtype='float32'),trainable=True)
    print(self.w)
    b_init = tf.zeros_initializer()
    self.b = tf.Variable(
        initial_value=b_init(shape=(self.units,), dtype='float32'),
        trainable=True)
    super().build(input_shape)

  def call(self, inputs):  # Defines the computation from inputs to outputs
      return self.activation(tf.matmul(inputs, self.w) + self.b)

In [2]:
# define the dataset
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

In [3]:
model=tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28,28)),
    SimpleDense(units=1, activation="relu"),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(units=10,activation="softmax")
])

model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
model.fit(x_train,y_train, epochs=50, verbose=0)

<tf.Variable 'simple_dense/Variable:0' shape=(784, 1) dtype=float32, numpy=
array([[ 0.0143884 ],
       [ 0.00981471],
       [ 0.06175914],
       [ 0.03876346],
       [-0.04282288],
       [-0.01524801],
       [ 0.05089285],
       [ 0.01208711],
       [-0.01610368],
       [-0.04366557],
       [-0.06130209],
       [-0.05301021],
       [-0.06177487],
       [ 0.03441837],
       [ 0.0022605 ],
       [ 0.00945387],
       [ 0.02267182],
       [-0.03139129],
       [-0.00231996],
       [ 0.01104939],
       [-0.00194078],
       [-0.00488062],
       [ 0.07904318],
       [ 0.01210235],
       [ 0.06127321],
       [-0.03265398],
       [-0.00600224],
       [-0.00677833],
       [ 0.01313514],
       [-0.03130886],
       [-0.00267588],
       [-0.0508193 ],
       [ 0.01370941],
       [ 0.01281381],
       [ 0.01268882],
       [ 0.03930015],
       [-0.01625279],
       [-0.06794672],
       [ 0.02319304],
       [-0.00174239],
       [ 0.03922104],
       [ 0.01155673],


<keras.callbacks.History at 0x2522bfa1520>

In [4]:
model.evaluate(x_test, y_test)



[1.7490650415420532, 0.34369999170303345]

In [5]:
!pip3 install pydot graphviz



In [6]:
!pip install pydotplus



In [7]:
input_layer=tf.keras.layers.Input(shape=[1], name="Deep_Input")

input_layer2=tf.keras.layers.Input(shape=[1], name="Wide_Input")

dense=tf.keras.layers.Dense(units=30, activation="relu")(input_layer)

dense_1=tf.keras.layers.Dense(units=30, activation="relu")(dense)

aux_output=tf.keras.layers.Dense(1)(dense_1)

concatnate=tf.keras.layers.concatenate([input_layer2, dense_1])

output=tf.keras.layers.Dense(1, name="output")(concatnate)

model=tf.keras.Model(inputs=[input_layer,input_layer2], outputs=[output,aux_output])


tf.keras.utils.plot_model(model)

model.summary()

('You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) ', 'for plot_model/model_to_dot to work.')
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
Deep_Input (InputLayer)         [(None, 1)]          0                                            
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 30)           60          Deep_Input[0][0]                 
__________________________________________________________________________________________________
Wide_Input (InputLayer)         [(None, 1)]          0                                            
__________________________________________________________________________________________________
dense_2 (Dense)        