#### In this notebook we train a LeNet-5 model on the MNIST dataset. <br> The trained weights are saved so they can be used for testing Quantus.

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import quantus
import tensorflow_addons as tfa

# Number of examples per training/validation batch.
BATCH_SIZE = 1024
# Global seed so weight initialisation and data shuffling are reproducible
# across "Restart Kernel -> Run All".
SEED = 42

# Seeds Python's `random`, NumPy and TensorFlow in a single call.
tf.keras.utils.set_random_seed(SEED)

# Show which devices (CPU/GPU) TensorFlow can see.
tf.config.list_physical_devices()

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

#### Load the MNIST dataset

In [2]:
# Load the MNIST train/test splits as (image, label) pairs. Batching is done
# later in `configure_ds`. That way the training set can be shuffled per
# example, instead of replaying the same fixed batches every epoch.
train_ds, val_ds = tfds.load(
    "mnist",
    try_gcs=True,
    as_supervised=True,
    split=["train", "test"],
    shuffle_files=True,
)


def configure_ds(
    ds: tf.data.Dataset,
    shuffle: bool = False,
    batch_size: int = BATCH_SIZE,
    shuffle_buffer: int = 10_000,
) -> tf.data.Dataset:
    """Build an input pipeline from an unbatched (image, label) dataset.

    Pipeline steps, in order:
    - Resize images to 28x28 with ``tf.image.resize``, which also returns float32.
    - Cache the results in memory.
    - Optionally shuffle.
    - Batch and prefetch.

    Args:
        ds: Unbatched dataset yielding ``(image, label)`` pairs.
        shuffle: If True, reshuffle examples on every iteration. Intended for
            the training split only.
        batch_size: Number of examples per batch.
        shuffle_buffer: Shuffle buffer size, used only when ``shuffle`` is True.

    Returns:
        A batched, prefetched ``tf.data.Dataset``.
    """
    # NOTE(review): pixel values are not rescaled here. They are presumably in
    # [0, 255], so the saved weights would expect raw-range inputs. Confirm
    # this matches how the Quantus tests feed data before adding normalization.
    ds = ds.map(
        lambda x, y: (tf.image.resize(x, (28, 28)), y),
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    ).cache()
    if shuffle:
        # Shuffle *after* `cache()` so each epoch sees a new example order and
        # new batch compositions. Shuffling before the cache would freeze a
        # single order for all epochs.
        ds = ds.shuffle(shuffle_buffer, reshuffle_each_iteration=True)
    return ds.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)


train_ds = configure_ds(train_ds, shuffle=True)
val_ds = configure_ds(val_ds)

train_ds

Metal device set to: AMD Radeon Pro 560


<PrefetchDataset element_spec=(TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.int64, name=None))>

#### Create the LeNet model

In [4]:
# LeNet architecture provided by Quantus (TensorFlow variant).
model = quantus.LeNetTF()

# Adam with weight decay, from TensorFlow Addons.
optimizer = tfa.optimizers.AdamW(weight_decay=0.004)
# Labels are integer class ids, hence the sparse loss; `from_logits=True`
# assumes the model's last layer emits unnormalized scores — TODO confirm
# against quantus.LeNetTF.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model.compile(optimizer=optimizer, loss=loss_fn, metrics=["accuracy"])
model.summary()

Model: "LeNetTF"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_2 (Conv2D)           (None, 26, 26, 6)         60        
                                                                 
 average_pooling2d_2 (Averag  (None, 13, 13, 6)        0         
 ePooling2D)                                                     
                                                                 
 conv2d_3 (Conv2D)           (None, 11, 11, 16)        880       
                                                                 
 average_pooling2d_3 (Averag  (None, 5, 5, 16)         0         
 ePooling2D)                                                     
                                                                 
 flatten_1 (Flatten)         (None, 400)               0         
                                                                 
 dense_3 (Dense)             (None, 120)               4812

#### Train the model

In [5]:
# Number of passes over the training set.
EPOCHS = 5

# Keep the returned History (per-epoch training/validation metrics) for later
# inspection, instead of letting its repr become the cell output.
history = model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x7f86fc00ba00>

#### Save the weights

In [6]:
# Make sure `assets/` exists on a fresh checkout so the save cannot fail on a
# missing directory. `tf.io.gfile.makedirs` succeeds if it already exists.
tf.io.gfile.makedirs("assets")
# NOTE(review): the on-disk format depends on the Keras version. In Keras 2.x,
# a `.keras` suffix appears to make `save_weights` default to HDF5. Keras 3
# instead expects a `.weights.h5` suffix. Keep this path in sync with the code
# that loads these weights.
model.save_weights("assets/lenet_mnist_weights.keras")