In [1]:
import tensorflow as tf
from sklearn.model_selection import KFold

In [2]:
# Fetch MNIST through the Keras dataset helper. load_data() hands back a
# (train, test) pair, each of which is an (inputs, labels) tuple.
mnist = tf.keras.datasets.mnist
train_split, test_split = mnist.load_data()
x_train, y_train = train_split
x_test, y_test = test_split
del train_split, test_split  # leave only the four arrays in the namespace


In [3]:
# Divide both image arrays by 255.0 (presumably mapping 0-255 pixel
# intensities into [0, 1] -- confirm against the raw dtype/range).
# NOTE: the names are rebound, so running this cell twice divides twice;
# re-run the loading cell first if you need to execute it again.
x_train = x_train / 255.0
x_test = x_test / 255.0

In [4]:
# Splitter used to carve a validation set out of the training data;
# every KFold option other than the fold count is left at its default.
N_SPLITS = 5
kf = KFold(n_splits=N_SPLITS)

In [5]:
# NOTE: only ONE train/validation partition is used by the cells below --
# the final index pair yielded by kf.split(). The earlier pairs are never
# trained on, so this is a single hold-out split rather than k-fold
# cross-validation. To cross-validate, build, fit and score a fresh model
# inside a loop over kf.split(x_train) and aggregate the per-fold scores.
#
# Take the last index pair directly so the unused folds are never
# materialised as array copies (the original loop sliced every fold and
# then overwrote the result on the next iteration).
train_index, val_index = list(kf.split(x_train))[-1]
x_trainNew, x_validate = x_train[train_index], x_train[val_index]
y_trainNew, y_validate = y_train[train_index], y_train[val_index]

In [6]:
# Fix TensorFlow's global seed before any layer is created so that weight
# initialisation and dropout masks repeat under Restart & Run All.
# (Hardware-level nondeterminism, e.g. some GPU kernels, may still remain.)
SEED = 42
tf.random.set_seed(SEED)

# Small dense classifier: flatten each 28x28 input, one 128-unit ReLU
# hidden layer with 20% dropout, then a 10-unit output layer.
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    # No activation here: the layer emits raw logits, which is why the
    # loss below is built with from_logits=True.
    tf.keras.layers.Dense(10),
])


In [7]:
# Scratch check: split() returns a lazy generator and nothing consumes it
# here, so this cell only displays the generator's repr.
kf.split(x_train)

<generator object _BaseKFold.split at 0x0000020D800088C8>

In [8]:
# Forward pass on the first five training images. model.fit() has not been
# called yet at this point in the notebook, so these are the outputs of the
# freshly initialised network.
sample_images = x_trainNew[:5]
predictions = model(sample_images).numpy()
predictions

array([[-0.56863916,  0.48715016,  0.918115  ,  0.20932673, -0.34033895,
         0.00482205,  0.376382  ,  0.18755674,  0.0556089 , -0.20087108],
       [ 0.01784399,  0.42656547,  0.4839636 , -0.33098048, -0.67142874,
         0.5823812 ,  0.11994042, -0.12319333,  0.02701378, -0.00309842],
       [-0.42536652, -0.09492633,  0.13521388, -0.6215933 ,  0.23637794,
         0.5459731 ,  0.5312194 ,  0.28240892,  0.15680578, -0.01617901],
       [-0.02080064,  0.13191366,  0.70834494,  0.23154882, -0.25943685,
         0.2619541 ,  0.43560153, -0.01516246, -0.0087954 , -0.7256506 ],
       [-0.95274067,  0.65920955,  0.118721  ,  0.49051687, -0.5086674 ,
        -0.32423648,  0.53903645, -0.35202697, -0.17892955, -0.7803909 ]],
      dtype=float32)

In [9]:
# Turn each row of logits into a probability distribution over the classes
# (axis=-1 is the default, spelled out here for clarity).
tf.nn.softmax(predictions, axis=-1).numpy()

array([[0.04647006, 0.13356623, 0.20552391, 0.10116715, 0.05838788,
        0.08245638, 0.11956131, 0.09898854, 0.08675224, 0.06712633],
       [0.09066827, 0.136446  , 0.14450687, 0.063968  , 0.04551012,
        0.15945227, 0.10041422, 0.07874148, 0.0915035 , 0.0887892 ],
       [0.0572327 , 0.07964391, 0.10025388, 0.04703531, 0.11092672,
        0.15117908, 0.148965  , 0.11615214, 0.1024421 , 0.08616921],
       [0.08525057, 0.09931625, 0.17675073, 0.10972139, 0.067152  ,
        0.11310873, 0.13455823, 0.08573259, 0.0862802 , 0.04212936],
       [0.03813148, 0.19113699, 0.11133034, 0.1614665 , 0.05944868,
        0.07148905, 0.16949396, 0.06952969, 0.08266954, 0.04530375]],
      dtype=float32)

In [10]:
# Cross-entropy over integer class labels. from_logits=True because the
# model's final Dense layer has no activation and therefore outputs logits.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True,
)

In [11]:
# Loss of the pre-training predictions against the first five labels.
loss_fn(y_true=y_trainNew[:5], y_pred=predictions).numpy()

2.499746

In [12]:
# Attach the training configuration: the logits-based loss defined above,
# the Adam optimizer with its default settings, and accuracy as the metric.
model.compile(
    loss=loss_fn,
    optimizer='adam',
    metrics=['accuracy'],
)

In [13]:
# Train on the selected training partition and report loss/accuracy on the
# held-out validation partition after every epoch.
model.fit(
    x=x_trainNew,
    y=y_trainNew,
    batch_size=50,
    epochs=10,
    validation_data=(x_validate, y_validate),
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x20d8a538208>

In [14]:
# Score the trained model on the test split; returns [loss, accuracy].
model.evaluate(x=x_test, y=y_test, verbose=2)

313/313 - 0s - loss: 0.0708 - accuracy: 0.9797


[0.07077046483755112, 0.9797000288963318]

In [15]:
# Wrap the trained logits model with a Softmax layer so that calling the
# wrapper yields class probabilities instead of raw logits.
probability_model = tf.keras.Sequential()
probability_model.add(model)
probability_model.add(tf.keras.layers.Softmax())

In [16]:
# Class probabilities for the first five test images (one row per image,
# one column per class).
probability_model(x_test[:5])

<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
array([[1.22530674e-09, 3.73228914e-09, 7.76893614e-07, 9.16007848e-05,
        1.84010495e-13, 8.07300093e-09, 3.10705048e-13, 9.99906778e-01,
        1.00173963e-07, 6.93777906e-07],
       [2.68920797e-09, 2.37271342e-05, 9.99974370e-01, 1.72356590e-06,
        1.18596700e-18, 4.04456486e-08, 2.67031321e-08, 5.53421364e-16,
        2.82407996e-07, 9.07896669e-15],
       [1.24879057e-06, 9.98585820e-01, 4.51339620e-05, 4.99264752e-06,
        7.55726569e-06, 1.46426896e-06, 1.00686266e-05, 1.04680681e-03,
        2.96944490e-04, 1.72672060e-07],
       [9.99761522e-01, 1.12232659e-08, 2.16959132e-04, 6.66968845e-08,
        3.91663013e-09, 2.97716838e-06, 1.43862635e-05, 3.48822232e-06,
        5.78243409e-09, 6.51074117e-07],
       [1.02093293e-06, 7.80819942e-09, 9.37675577e-06, 2.68124083e-08,
        9.97686625e-01, 9.15143090e-08, 1.41833589e-05, 3.43706852e-05,
        5.53745622e-07, 2.25371704e-03]], dtype=float32)>

In [17]:
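# Recorded results ("initial loss" is presumably the pre-training loss_fn value):
#   Without KFold: accuracy = 0.9764, initial loss = 2.689857
#   With KFold:    accuracy = 0.9762, initial loss = 2.3394208
# Note: the "With KFold" run trains once, on the last split only, so these are
# single-run figures rather than cross-validated averages. They also vary from
# run to run -- the evaluation output shown above reports accuracy 0.9797.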