In [21]:
import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.datasets.mnist import load_data
from tensorflow.keras.utils import to_categorical
import numpy as np

Load & Preprocess dataset

In [22]:
# Load the MNIST train/test split via the Keras dataset helper.
(x_train, y_train), (x_test, y_test) = load_data()

# Scale pixel values by 1/255 with the same divisor for both splits, casting to
# float32 first so the arrays are not promoted to float64 (twice the memory,
# and not the dtype Keras layers use by default).
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# One-hot encode the labels into 10 classes, the form CategoricalCrossentropy
# and CategoricalAccuracy expect.
y_train = to_categorical(y_train, num_classes=10)
y_test = to_categorical(y_test, num_classes=10)

In [23]:
def create_model():
    """Build a fully connected softmax classifier for 28x28 inputs.

    Each input is flattened and passed through three ReLU Dense layers of
    128, 64 and 32 units, followed by a 10-unit softmax output layer.

    Returns:
        An uncompiled `Model` mapping a `(28, 28)` input to 10 softmax outputs.
    """
    image_input = Input(shape=(28, 28))
    hidden = Flatten()(image_input)
    # Hidden stack, widest layer first; layers are created in the listed order.
    for units in (128, 64, 32):
        hidden = Dense(units, activation='relu')(hidden)
    class_probs = Dense(10, activation='softmax')(hidden)
    return Model(inputs=image_input, outputs=class_probs)

In [24]:
model = create_model()
# create_model() ends in a softmax layer, so the model emits probabilities
# rather than raw logits. The loss must therefore be configured with
# from_logits=False; declaring from_logits=True on softmax output makes Keras
# warn about the mismatch and misstates what the model produces.
loss_fn = CategoricalCrossentropy(from_logits=False)
# Adam with its default hyperparameters.
optimizer = tf.keras.optimizers.Adam()

In [25]:
@tf.function
def train_step(model, images, labels):
    """Run one optimisation step on a single batch and return its loss.

    Uses the notebook-level `loss_fn` and `optimizer`.

    Args:
        model: Model to update; called with `training=True`.
        images: Batch of model inputs.
        labels: Targets for `images`, in the form `loss_fn` expects.

    Returns:
        The loss value computed by `loss_fn` for this batch.
    """
    # Record the forward pass so the loss can be differentiated.
    with tf.GradientTape() as tape:
        loss = loss_fn(labels, model(images, training=True))

    weights = model.trainable_variables
    grads = tape.gradient(loss, weights)
    optimizer.apply_gradients(zip(grads, weights))
    return loss

Train model with GradientTape

In [26]:
# Hyperparameters of the custom training loop, named so they can be tuned in
# one place instead of being buried in the pipeline below.
NUM_EPOCHS = 5
BATCH_SIZE = 32

# Size the shuffle buffer from the data itself so the whole training set is
# shuffled even if the number of samples changes.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(len(x_train))
    .batch(BATCH_SIZE)
)

for epoch in range(NUM_EPOCHS):
    # Fresh running mean per epoch so the reported loss covers this epoch only.
    epoch_loss = tf.keras.metrics.Mean()
    for images, labels in train_dataset:
        loss = train_step(model, images, labels)
        epoch_loss.update_state(loss)
    print(f'Epoch {epoch + 1}, Loss: {epoch_loss.result().numpy():.4f}')

  output, from_logits = _get_logits(
2025-06-22 20:29:38.983142: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 1, Loss: 0.2619
Epoch 2, Loss: 0.1089
Epoch 3, Loss: 0.0770
Epoch 4, Loss: 0.0593
Epoch 5, Loss: 0.0473


Evaluate model


In [27]:
# Batch the test split in its original order (no shuffling is applied).
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_dataset = test_dataset.batch(32)

# Accumulates accuracy over all test batches.
test_accuracy = tf.keras.metrics.CategoricalAccuracy()

for images, labels in test_dataset:
    # Call the model in inference mode (training=False).
    predictions = model(images, training=False)
    test_accuracy.update_state(y_true=labels, y_pred=predictions)

print(f'\nGradientTape Final Test Accuracy: {test_accuracy.result().numpy():.4f}')


GradientTape Final Test Accuracy: 0.9758
