### Model Subclassing

In [2]:
from tensorflow import keras
from tensorflow.keras import layers, Input, Model
from tensorflow.keras.layers import Concatenate, Dense
import numpy as np
from keras.utils import plot_model

**Variation 1: Functional model using a subclassed model**

In [11]:
class Classifier(Model):
    """Classification head wrapped as a subclassed Keras ``Model``.

    With ``num_classes == 2`` the head is a single sigmoid unit (binary
    classification); with more classes it is a ``num_classes``-unit softmax
    layer.

    Args:
        num_classes: Number of target classes. Must be at least 2.
        **kwargs: Forwarded to ``Model.__init__`` (e.g. ``name``).

    Raises:
        ValueError: If ``num_classes`` is less than 2.
    """

    def __init__(self, num_classes=2, **kwargs):
        # A 1-unit softmax always outputs 1.0 and a 0-unit layer is
        # meaningless, so reject degenerate class counts up front instead
        # of silently building a useless head.
        if num_classes < 2:
            raise ValueError(f"num_classes must be >= 2, got {num_classes}")
        super().__init__(**kwargs)

        if num_classes == 2:
            # Binary case: one sigmoid unit is enough to model P(class 1).
            num_units = 1
            activation = "sigmoid"
        else:
            num_units = num_classes
            activation = "softmax"
        self.dense = Dense(num_units, activation=activation)

    def call(self, inputs):
        """Apply the classification layer to ``inputs``."""
        return self.dense(inputs)
    
# Functional graph: 3 input features -> 64-unit ReLU layer -> subclassed
# Classifier head configured for 10 classes.
inputs = Input(shape=(3,))
features = Dense(64, activation="relu")(inputs)
head = Classifier(num_classes=10)
outputs = head(features)
model = Model(inputs, outputs)

**Variation 2: Subclassed model using a Functional model**

In [10]:
# Standalone Functional binary classifier over 64-dim feature vectors;
# CustomModel uses it as its classification head.
inputs = Input(shape=(64,))
sigmoid_layer = Dense(1, activation="sigmoid")
outputs = sigmoid_layer(inputs)
binary_classifier = Model(inputs, outputs)

class CustomModel(Model):
    """Subclassed model whose classification head is a Functional model.

    A ``Dense(64, relu)`` feature layer feeds a classifier sub-model.

    Args:
        num_classes: Number of target classes. ``2`` (the default) reuses the
            module-level ``binary_classifier``; every instance built this way
            shares that same sub-model and therefore its weights. Larger
            values build a fresh Functional softmax head with ``num_classes``
            units.
        classifier: Optional pre-built model/layer to use as the head instead.
            It must accept 64-dim feature vectors. Takes precedence over
            ``num_classes``.
        **kwargs: Forwarded to ``Model.__init__`` (e.g. ``name``).

    Raises:
        ValueError: If no ``classifier`` is given and ``num_classes`` < 2.
    """

    def __init__(self, num_classes=2, classifier=None, **kwargs):
        super().__init__(**kwargs)
        feature_dim = 64
        self.dense = Dense(feature_dim, activation="relu")

        if classifier is None:
            if num_classes < 2:
                raise ValueError(f"num_classes must be >= 2, got {num_classes}")
            if num_classes == 2:
                classifier = binary_classifier
            else:
                # Previously num_classes was silently ignored; build a
                # matching multi-class Functional head instead.
                head_inputs = Input(shape=(feature_dim,))
                head_outputs = Dense(num_classes, activation="softmax")(head_inputs)
                classifier = Model(inputs=head_inputs, outputs=head_outputs)
        self.classifier = classifier

    def call(self, inputs):
        """Compute features with ``self.dense`` and classify them."""
        features = self.dense(inputs)
        return self.classifier(features)

# Instantiate the subclassed model; num_classes=2 is also the default.
model = CustomModel(num_classes=2)

## Using built-in training and evaluation loops

**The standard workflow: `compile()`, `fit()`, `evaluate()`, `predict()`**

In [0]:
from tensorflow.keras.datasets import mnist

def get_mnist_model():
    """Build an uncompiled MNIST classifier with the Functional API.

    Architecture: flattened 28*28 input -> Dense(512, relu) -> Dropout(0.5)
    -> Dense(10, softmax).

    Returns:
        A ``keras.Model`` mapping 784-dim inputs to 10 class probabilities.
    """
    pixels = keras.Input(shape=(28 * 28,))
    hidden = layers.Dense(512, activation="relu")(pixels)
    hidden = layers.Dropout(0.5)(hidden)
    class_probs = layers.Dense(10, activation="softmax")(hidden)
    return keras.Model(inputs=pixels, outputs=class_probs)

# Load MNIST, flatten each 28x28 image into a 784-vector and scale pixels
# to [0, 1]. Sizes come from the arrays themselves rather than hard-coded
# 60000/10000 counts.
(images, labels), (test_images, test_labels) = mnist.load_data()
images = images.reshape((len(images), 28 * 28)).astype("float32") / 255
test_images = test_images.reshape((len(test_images), 28 * 28)).astype("float32") / 255

# Hold out the first NUM_VAL training samples as a validation set.
NUM_VAL = 10_000
train_images, val_images = images[NUM_VAL:], images[:NUM_VAL]
train_labels, val_labels = labels[NUM_VAL:], labels[:NUM_VAL]

model = get_mnist_model()
model.compile(optimizer="rmsprop",
              # Labels are class indices rather than one-hot vectors,
              # hence the "sparse" variant of the loss.
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
history = model.fit(train_images, train_labels,
                    epochs=3,
                    validation_data=(val_images, val_labels))
test_metrics = model.evaluate(test_images, test_labels)
predictions = model.predict(test_images)