In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Input
from tensorflow.keras.models import Model
from tensorflow.keras.utils import to_categorical


# -------- MODEL --------
def build_fcf_nn(input_shape, num_classes, *, hidden_units=(512, 256, 128)):
    """Build and compile a fully connected feed-forward classifier.

    The input is flattened, passed through one ReLU ``Dense`` layer per
    entry of ``hidden_units`` (in order), and mapped to ``num_classes``
    softmax outputs.

    Args:
        input_shape: Shape of a single example without the batch axis
            (the caller in this notebook passes ``x_train.shape[1:]``).
        num_classes: Number of output classes / softmax units.
        hidden_units: Width of each hidden layer, first to last. The default
            reproduces the original 512-256-128 architecture.

    Returns:
        A Keras ``Model`` compiled with the Adam optimizer, categorical
        cross-entropy loss (so targets must be one-hot encoded) and an
        accuracy metric.
    """
    inputs = Input(shape=input_shape)
    x = Flatten()(inputs)
    # One ReLU layer per configured width instead of hand-written copies.
    for units in hidden_units:
        x = Dense(units, activation='relu')(x)
    outputs = Dense(num_classes, activation='softmax')(x)

    model = Model(inputs, outputs)
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model


# -------- TRAIN FUNCTION --------
def run_dataset(dataset_name, *, epochs=10, batch_size=128, validation_split=0.1):
    """Train and evaluate a fresh ``build_fcf_nn`` model on a Keras dataset.

    Args:
        dataset_name: ``"mnist"``, ``"fashion"`` (Fashion-MNIST) or
            ``"cifar"`` (CIFAR-10). All three are 10-class datasets.
        epochs: Passed to ``model.fit`` (default 10).
        batch_size: Passed to ``model.fit`` (default 128).
        validation_split: Fraction of the training data Keras holds out for
            validation during ``model.fit`` (default 0.1).

    Returns:
        The test-set accuracy reported by ``model.evaluate``. It is also
        printed.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    datasets = {
        "mnist": tf.keras.datasets.mnist,
        "fashion": tf.keras.datasets.fashion_mnist,
        "cifar": tf.keras.datasets.cifar10,
    }
    # Fail fast with a clear message; otherwise an unknown name would leave
    # x_train unbound and crash further down with UnboundLocalError.
    if dataset_name not in datasets:
        raise ValueError(
            f"Unknown dataset {dataset_name!r}; expected one of {sorted(datasets)}"
        )
    (x_train, y_train), (x_test, y_test) = datasets[dataset_name].load_data()

    # Normalize pixel values to [0, 1]. Cast to float32 first: dividing the
    # integer arrays by a Python float would otherwise produce float64,
    # doubling memory, and Keras works in float32 by default anyway.
    x_train = x_train.astype("float32") / 255.0
    x_test = x_test.astype("float32") / 255.0

    # One-hot encode labels to match the categorical cross-entropy loss.
    num_classes = 10
    y_train = to_categorical(y_train, num_classes)
    y_test = to_categorical(y_test, num_classes)

    model = build_fcf_nn(x_train.shape[1:], num_classes)

    model.fit(x_train, y_train,
              validation_split=validation_split,
              epochs=epochs,
              batch_size=batch_size)

    _, acc = model.evaluate(x_test, y_test)
    print(f"{dataset_name} Test Accuracy:", acc)
    return acc


# -------- MAIN --------
# Train and evaluate one model per dataset, in this fixed order.
for _dataset in ("mnist", "fashion", "cifar"):
    run_dataset(_dataset)
del _dataset  # don't leave the loop variable behind as a notebook global


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
Epoch 1/10
422/422 ━━━━━━━━━━━━━━━━━━━━ 9s 17ms/step - accuracy: 0.8643 - loss: 0.4606 - val_accuracy: 0.9697 - val_loss: 0.0980
Epoch 2/10
422/422 ━━━━━━━━━━━━━━━━━━━━ 7s 16ms/step - accuracy: 0.9723 - loss: 0.0918 - val_accuracy: 0.9765 - val_loss: 0.0820
Epoch 3/10
422/422 ━━━━━━━━━━━━━━━━━━━━ 6s 15ms/step - accuracy: 0.9829 - loss: 0.0562 - val_accuracy: 0.9810 - val_loss: 0.0683
Epoch 4/10
422/422 ━━━━━━━━━━━━━━━━━━━━ 7s 17ms/step - accuracy: 0.9898 - loss: 0.0337 - val_accuracy: 0.9772 - val_loss: 0.0803
Epoch 5/10
422/422 ━━━━━━━━━━━━━━━━━━━━ 6s 14ms/step - accuracy: 0.9918 - loss: 0.0257 - val_accuracy: 0.9790 - val_loss: 0.0804
Epoch 6/