In [9]:
import tensorflow as tf
import numpy as np

# Load and prepare data

In [2]:
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

In [12]:
print(x_train.dtype)
print(y_train.dtype)


float64
uint8


# Create Model

In [13]:
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10)
])

# Make Predictions, returning logits

In [22]:
predictions = model(x_train[:1]).numpy()
print(predictions)

[[ 0.7574823   0.40195763 -0.19585942  0.47278902 -0.15115862 -0.08864196
   0.28004393 -0.271615   -0.35497454 -0.4060803 ]]


In [25]:
# convert to softmax. 
tf.nn.softmax(predictions).numpy()

# Note that it's not recommended to put softmax as last layer b/c impossible to provide exact and numerically stable loss calculation for all models

array([[0.18905489, 0.13249074, 0.07287136, 0.1422156 , 0.07620267,
        0.08111867, 0.11728409, 0.06755486, 0.06215185, 0.05905534]],
      dtype=float32)

# Compute loss using sparse categorical cross entropy

In [28]:
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# The losses.SparseCategoricalCrossentropy loss takes a vector of logits and a True index and returns a scalar loss for each example.
# This loss is equal to the negative log probability of the true class: It is zero if the model is sure of the correct class.


In [31]:
print(loss_fn(y_train[:1], predictions).numpy())
print()
# This untrained model should give probability close to random (1/10 for each class) , so the initial loss should be ~tf.log(1/10) ~= 2,3

2.5118423

# Compile model using the optimizer, loss function and metrics

In [36]:
model.compile(optimizer='adam',
              loss=loss_fn,
              metrics=['accuracy'])

# Train Model

In [37]:
model.fit(x_train, y_train, epochs=5)


Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x13b69117358>

# Evaluate model

In [38]:
model.evaluate(x_test,  y_test, verbose=2)
# Make sure the trained model works for validation or test set

313/313 - 0s - loss: 0.0740 - accuracy: 0.9776


[0.0740356296300888, 0.9775999784469604]

# Return probablity 

In [39]:
probability_model = tf.keras.Sequential([
  model,
  tf.keras.layers.Softmax()
])

In [41]:
probability_model(x_test[:5]).numpy()


array([[2.9712282e-07, 3.6487595e-09, 1.5694095e-05, 9.9794415e-05,
        2.3483321e-10, 5.1303329e-07, 8.0362163e-14, 9.9987435e-01,
        5.6426131e-07, 8.7411790e-06],
       [7.0113691e-07, 1.6030492e-04, 9.9971873e-01, 1.0946670e-04,
        9.2107235e-15, 6.3925845e-06, 1.0855697e-07, 4.5965705e-13,
        4.1810795e-06, 1.6208323e-11],
       [1.7864748e-07, 9.9972337e-01, 5.2863397e-05, 1.3593894e-05,
        2.5864128e-05, 5.8707483e-06, 6.7855131e-06, 8.2485552e-05,
        8.7939836e-05, 1.0368002e-06],
       [9.9992621e-01, 1.0704584e-08, 2.4439641e-05, 7.5949015e-07,
        5.0569820e-06, 2.7772865e-06, 3.0540232e-05, 1.8572230e-07,
        2.8287906e-07, 9.9068147e-06],
       [2.7675680e-06, 3.2633019e-07, 3.3358037e-06, 5.1699992e-08,
        9.9830866e-01, 9.9775832e-07, 2.3970661e-06, 3.5246470e-05,
        2.5950405e-07, 1.6461656e-03]], dtype=float32)

# Comparing results from raw model, probability model 

In [77]:
i = 5
predictions = model(x_train[i-1:i]).numpy()
prob_pred = probability_model(x_train[i-1:i]).numpy()
print(predictions)
print(prob_pred)
print(y_train[i-1:i])
print(np.argmax(predictions[0]))
print(np.argmax(prob_pred[0]))


[[-14.007774   -5.3724833  -8.441552   -3.0162597   2.0305045  -7.5744143
  -12.416223   -1.8851573  -1.0188295   9.887919 ]]
[[4.1884291e-11 2.3567354e-07 1.0950429e-08 2.4866001e-06 3.8671185e-04
  2.6062997e-08 2.0570920e-10 7.7061441e-06 1.8326466e-05 9.9958450e-01]]
[9]
9
9


In [90]:
i = 2
predictions = model(x_train[i-1:i]).numpy()
prob_pred = probability_model(x_train[i-1:i]).numpy()

loss_fn(y_train[i-1:i], prob_pred).numpy()

1.4615047