In [1]:
import tensorflow as tf

Load and prepare the **MNIST** dataset.

Convert the samples from integers to floating-point numbers by scaling the pixel values from the range 0 to 255 down to 0 to 1:

In [5]:
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
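`mnist.load_data()` returns pixel intensities as 8-bit unsigned integers from 0 to 255. Dividing by the float `255.0` promotes them to floating point and rescales them to [0, 1]. The NumPy sketch below performs the same conversion. It uses a small made-up array in place of a real image, so no download is needed.

```python
import numpy as np

# Hypothetical 2x2 "image" standing in for one 28x28 MNIST sample (uint8, 0..255).
pixels = np.array([[0, 51], [128, 255]], dtype=np.uint8)

# Dividing by a Python float promotes the array to floating point and rescales to [0, 1].
scaled = pixels / 255.0

print(scaled.dtype)  # float64
print(scaled)
```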

Build the `tf.keras.Sequential` model by stacking layers:

In [7]:
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10)
])
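Each `Dense` layer holds a weight matrix plus one bias per unit. `Flatten` and `Dropout` have no trainable parameters. The arithmetic below counts this architecture's parameters by hand. With TensorFlow installed, `model.summary()` should report the same total.

```python
# Input: a 28x28 image flattened to a 784-element vector (Flatten has no weights).
inputs = 28 * 28

def dense_params(n_in, n_out):
    """Weights (n_in * n_out) plus one bias per output unit."""
    return n_in * n_out + n_out

hidden = dense_params(inputs, 128)   # Dense(128, activation='relu')
output = dense_params(128, 10)       # Dense(10), one logit per digit class
# Dropout(0.2) has no trainable parameters.

total = hidden + output
print(hidden, output, total)  # 100480 1290 101770
```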

For each example, the model returns a vector of "logits" or "log-odds" scores, one for each class. The model is still untrained, so these scores are essentially random. The `tf.nn.softmax` function converts these logits to "probabilities" for each class:

In [12]:
predictions = model(x_train[:1]).numpy()
tf.nn.softmax(predictions).numpy()

array([[0.09442907, 0.0958738 , 0.06241443, 0.05443078, 0.11147247,
        0.12428605, 0.06256733, 0.16055244, 0.14737795, 0.08659574]],
      dtype=float32)
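`tf.nn.softmax` exponentiates each logit and divides by the sum of the exponentials, so the outputs are positive and sum to 1. A minimal NumPy version of that formula follows. It subtracts the row maximum before exponentiating, a common numerical-stability trick that does not change the result. That step is an implementation choice made here, not something taken from the notebook. The logits are made up for illustration.

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax: exp(z_i) / sum_j exp(z_j)."""
    z = logits - np.max(logits, axis=-1, keepdims=True)  # shift for numerical stability
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)

# Hypothetical logits for one example and three classes.
logits = np.array([[2.0, 1.0, 0.1]])
probs = softmax(logits)
print(probs, probs.sum())

# Equal logits give a uniform distribution: 1/10 for each of 10 classes.
print(softmax(np.zeros((1, 10))))
```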

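The `losses.SparseCategoricalCrossentropy` loss takes a vector of logits and the index of the true class, and returns a scalar loss for each example. Pass `from_logits=True` because the model outputs logits rather than probabilities: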

In [17]:
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
loss_fn(y_train[:1], predictions).numpy()

2.0851696
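With `from_logits=True`, the loss for one example is the negative log of the softmax probability assigned to the true class, `-log(softmax(z)[y])`. By default, Keras then averages these per-example losses over the batch. An untrained model spreads probability roughly evenly over the 10 classes, so its initial loss should be close to `-log(1/10) ≈ 2.3`. The 2.085 above is in that range. The NumPy sketch below uses made-up logits and labels.

```python
import numpy as np

def sparse_ce_from_logits(logits, labels):
    """Per-example loss -log(softmax(logits)[label]), computed via log-softmax."""
    z = logits - np.max(logits, axis=-1, keepdims=True)
    log_probs = z - np.log(np.sum(np.exp(z), axis=-1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels]

# Uniform logits over 10 classes: every class gets probability 1/10.
uniform_loss = sparse_ce_from_logits(np.zeros((1, 10)), np.array([5]))
print(uniform_loss)  # approximately [2.302585]

# Hypothetical logits where the true class (index 0) is favoured: the loss is smaller.
confident_loss = sparse_ce_from_logits(np.array([[4.0, 0.0, 0.0]]), np.array([0]))
print(confident_loss)
```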

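Before training, compile the model. This step sets the optimizer (Adam), the loss function defined above, and the metrics to report:

In [18]: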
model.compile(optimizer='adam',
              loss=loss_fn,
              metrics=['accuracy'])

In [19]:
model.fit(x_train, y_train, epochs=5)

Train on 60000 samples
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x6679d6b10>

In [21]:
model.evaluate(x_test, y_test, verbose=2)

10000/1 - 0s - loss: 0.0402 - accuracy: 0.9754


[0.07837730197343044, 0.9754]
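`model.evaluate` returns the loss followed by each compiled metric. The 0.9754 is therefore the fraction of the 10,000 test images classified correctly. Here the labels are integers and the model outputs logits. A prediction counts as correct when the highest-scoring class equals the label. Softmax does not change which entry is largest, so this works directly on logits. The NumPy sketch below applies the rule to made-up logits and labels.

```python
import numpy as np

def sparse_accuracy(logits, labels):
    """Fraction of examples whose highest-scoring class equals the integer label."""
    predicted = np.argmax(logits, axis=-1)
    return np.mean(predicted == labels)

# Hypothetical logits for four examples over three classes.
logits = np.array([
    [2.0, 0.5, -1.0],   # predicts class 0
    [0.1, 3.0, 0.2],    # predicts class 1
    [1.0, 1.5, 4.0],    # predicts class 2
    [0.3, 0.2, 0.1],    # predicts class 0
])
labels = np.array([0, 1, 2, 2])
print(sparse_accuracy(logits, labels))  # 0.75
```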

To get probabilities instead of logits, wrap the trained model and attach a softmax layer to it:

In [22]:
probability_model = tf.keras.Sequential([
  model,
  tf.keras.layers.Softmax()
])

In [24]:
probability_model(x_test[:1])

<tf.Tensor: id=31220, shape=(1, 10), dtype=float32, numpy=
array([[7.7479307e-08, 1.2133576e-08, 5.0500603e-05, 1.0046952e-03,
        2.0051105e-12, 6.6146044e-06, 7.3978726e-14, 9.9893421e-01,
        1.1884456e-06, 2.6114392e-06]], dtype=float32)>
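Each row of `probability_model`'s output is a probability distribution over the ten digits. The predicted digit is the index of the largest entry. The values printed above are copied into a NumPy array below, so the snippet runs without TensorFlow. The model assigns about 0.999 to index 7, so it predicts the digit 7 for the first test image.

```python
import numpy as np

# Probabilities printed above for x_test[:1], copied verbatim.
probs = np.array([[7.7479307e-08, 1.2133576e-08, 5.0500603e-05, 1.0046952e-03,
                   2.0051105e-12, 6.6146044e-06, 7.3978726e-14, 9.9893421e-01,
                   1.1884456e-06, 2.6114392e-06]], dtype=np.float32)

predicted_digit = int(np.argmax(probs, axis=-1)[0])
confidence = float(probs[0, predicted_digit])
print(predicted_digit, round(confidence, 4))  # 7 0.9989
```

In a live session, something like `probability_model(x_test[:5]).numpy().argmax(axis=-1)` gives the predicted digits for several images at once.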