# CNN model (digit prediction)

In [None]:
import tensorflow as tf
from keras.utils import to_categorical
from keras.datasets import mnist
from sklearn.model_selection import train_test_split

In [None]:
# Load MNIST (downloaded on first use). load_data() returns two (images, labels)
# pairs: the official training set and the official test set.
(x_train_main, y_train_main), (x_test, y_test) = mnist.load_data()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [None]:
# Hold out 20% of the official training set for validation; the fixed seed
# keeps the split identical from run to run.
x_train, x_val, y_train, y_val = train_test_split(
    x_train_main,
    y_train_main,
    test_size=0.2,
    random_state=42,
)

Array shapes after the 80/20 train/validation split:

* train = (48000, 28, 28)
* val = (12000, 28, 28)
* test = (10000, 28, 28)

In [None]:
# Add a trailing channel axis so each image is (28, 28, 1), the input shape the
# first Conv2D layer is declared with. Using -1 lets NumPy infer the sample
# count, so this cell keeps working if the split ratio or dataset size changes.
x_train = x_train.reshape(-1, 28, 28, 1)
x_val = x_val.reshape(-1, 28, 28, 1)
x_test = x_test.reshape(-1, 28, 28, 1)

In [None]:
# One-hot encode the integer digit labels, as required by the
# categorical_crossentropy loss. num_classes is pinned to 10 so all three
# arrays have the same width even if a split happens to be missing a digit
# (by default the width is inferred from the largest label present).
y_train = to_categorical(y_train, num_classes=10)
y_val = to_categorical(y_val, num_classes=10)
y_test = to_categorical(y_test, num_classes=10)

In [None]:
# Build the CNN layer by layer: two 3x3 convolutions (16 then 32 filters),
# flatten, dropout, one hidden dense layer, and a 10-way softmax output.
model = tf.keras.Sequential()

# Feature extraction on 28x28 single-channel images.
model.add(tf.keras.layers.Conv2D(16, kernel_size=3, activation='relu', input_shape=(28, 28, 1)))
model.add(tf.keras.layers.Conv2D(32, kernel_size=3, activation='relu'))

# Classification head; dropout (rate 0.4) is applied to the flattened features.
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dropout(0.4))
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(10, activation='softmax'))

In [None]:
class MyCallback(tf.keras.callbacks.Callback):
    """Stop training once the validation loss falls below 0.12."""

    def on_epoch_end(self, epoch, logs=None):
        """Request a training stop if this epoch's validation loss is low enough.

        Args:
            epoch: Index of the epoch that just finished (unused).
            logs: Metric dict supplied by Keras. It may be None, or lack the
                'val_loss' key when fit() is called without validation data.
        """
        val_loss = (logs or {}).get('val_loss')
        # With no validation metric there is nothing to compare against; keep
        # training instead of raising TypeError on `None < 0.12`.
        if val_loss is not None and val_loss < 0.12:
            print("\nVal loss lower than 0.12")
            self.model.stop_training = True


callback = MyCallback()

In [None]:
# Configure training: Adam with a small learning rate, categorical
# cross-entropy for the one-hot labels, and accuracy as the tracked metric.
model.compile(
    loss='categorical_crossentropy',
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    metrics=['acc'],
)

In [None]:
# Train for at most 100 epochs; MyCallback ends the run early once the
# validation loss drops below its threshold. Keras documents `callbacks` as a
# list of callback instances, so the single callback is wrapped in a list.
model.fit(
    x_train,
    y_train,
    validation_data=(x_val, y_val),
    epochs=100,
    callbacks=[callback],
)

Epoch 1/100
Epoch 2/100
Val loss lower than 0.12


<keras.src.callbacks.History at 0x78d8431c0ca0>

In [None]:
# Score the trained model on the held-out test set. With the loss and single
# metric configured in compile(), the result is [test loss, test accuracy].
model.evaluate(x_test, y_test)



[0.08816168457269669, 0.974399983882904]

# Save model

In [None]:
# Disabled by default. Uncomment to save the trained model as an HDF5 file.
# NOTE(review): the path looks like a mounted Google Drive in Colab — confirm
# the drive is mounted (and adjust the path) before enabling.
#!pip install pyyaml h5py
#model.save('/content/drive/MyDrive/model2.h5')