### In this notebook we train a simple 2D CNN on the MNIST and CIFAR-10 datasets. <br> We will use these models for testing Quantus with TensorFlow.

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import quantus
import tensorflow_addons as tfa

BATCH_SIZE = 1024
SEED = 42

# Seed TensorFlow's global RNG (weight initialisation, tf.data shuffles) so that
# retraining the models used for Quantus testing is more repeatable.
# NOTE(review): this alone does not guarantee bit-identical results on GPU.
tf.random.set_seed(SEED)

# Show which devices (CPU/GPU) TensorFlow can see; displayed as the cell output.
tf.config.list_physical_devices()

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

#### Load MNIST dataset

In [2]:
# Load raw (unbatched) splits; batching happens in `configure_ds` so that
# individual examples can be reshuffled every epoch before they are batched.
mnist_train_raw, mnist_val_raw = tfds.load(
    "mnist",
    try_gcs=True,
    as_supervised=True,
    split=["train", "test"],
    shuffle_files=True,
)


def configure_ds(
    ds: tf.data.Dataset,
    image_size: tuple = (28, 28),
    shuffle_buffer: int = 0,
    batch_size: int = 0,
) -> tf.data.Dataset:
    """Build an input pipeline: resize -> cache -> [shuffle] -> [batch] -> prefetch.

    Args:
        ds: Dataset of ``(image, label)`` pairs.
        image_size: ``(height, width)`` passed to ``tf.image.resize``. Resizing also
            casts images to float32; pixel values are not rescaled.
        shuffle_buffer: If > 0, shuffle examples with this buffer size after the
            cache so the order changes every epoch (use for the training split only).
        batch_size: If > 0, batch the dataset. Leave at 0 when ``ds`` is already
            batched (e.g. loaded with ``tfds.load(..., batch_size=...)``).

    Returns:
        The configured, prefetched dataset.
    """
    ds = ds.map(
        lambda x, y: (tf.image.resize(x, image_size), y),
        num_parallel_calls=tf.data.AUTOTUNE,
    ).cache()
    if shuffle_buffer > 0:
        # Shuffle *after* cache(): the cache replays a single fixed order, so
        # shuffling before it would freeze the example order across epochs.
        ds = ds.shuffle(shuffle_buffer)
    if batch_size > 0:
        ds = ds.batch(batch_size)
    return ds.prefetch(tf.data.AUTOTUNE)


# A 10k-example buffer gives a good shuffle at modest memory cost
# (a 60k buffer would fully shuffle MNIST's training split).
train_ds = configure_ds(mnist_train_raw, shuffle_buffer=10_000, batch_size=BATCH_SIZE)
val_ds = configure_ds(mnist_val_raw, batch_size=BATCH_SIZE)

train_ds

Metal device set to: AMD Radeon Pro 560


<PrefetchDataset element_spec=(TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.int64, name=None))>

#### Create 2D CNN model for MNIST

In [3]:
# 28x28 single-channel (grayscale) inputs, 10 digit classes.
model = quantus.CNN_2D_TF(28, 28, 10, num_channels=1)

# AdamW with decoupled weight decay; the loss expects raw logits from the model.
mnist_optimizer = tfa.optimizers.AdamW(weight_decay=0.004)
mnist_loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model.compile(optimizer=mnist_optimizer, loss=mnist_loss_fn, metrics=["accuracy"])
model.summary()

Model: "CNN_2D_TF"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 28, 28, 16)        160       
                                                                 
 max_pooling2d (MaxPooling2D  (None, 14, 14, 16)       0         
 )                                                               
                                                                 
 conv2d_1 (Conv2D)           (None, 14, 14, 32)        4640      
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 7, 7, 32)         0         
 2D)                                                             
                                                                 
 test_conv (Conv2D)          (None, 7, 7, 64)          18496     
                                                                 
 max_pooling2d_2 (MaxPooling  (None, 3, 3, 64)         0 

#### Train the model

In [4]:
# Keep the returned History (per-epoch loss/accuracy in `.history`) in a named
# variable instead of leaving it as the cell's unlabelled output.
mnist_history = model.fit(train_ds, validation_data=val_ds, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x7fceb0c91d90>

#### Save the weights

In [6]:
# Create the output directory first (no-op if it already exists).
tf.io.gfile.makedirs("assets")
# NOTE(review): with tf.keras 2.x a `.keras` suffix makes save_weights write HDF5;
# Keras 3 requires a `.weights.h5` suffix for save_weights. The file name is kept
# unchanged because these weights are meant to be loaded for Quantus testing —
# confirm the loaders before renaming.
model.save_weights('assets/cnn_2d_mnist_weights.keras')

#### Load CIFAR-10 dataset

In [7]:
# Load raw (unbatched) splits; batching happens in `configure_ds` so that
# individual examples can be reshuffled every epoch before they are batched.
cifar_train_raw, cifar_val_raw = tfds.load(
    "cifar10",
    try_gcs=True,
    as_supervised=True,
    split=["train", "test"],
    shuffle_files=True,
)


# TODO: define this helper once in a shared cell instead of once per dataset.
def configure_ds(
    ds: tf.data.Dataset,
    image_size: tuple = (28, 28),
    shuffle_buffer: int = 0,
    batch_size: int = 0,
) -> tf.data.Dataset:
    """Build an input pipeline: resize -> cache -> [shuffle] -> [batch] -> prefetch.

    Args:
        ds: Dataset of ``(image, label)`` pairs.
        image_size: ``(height, width)`` passed to ``tf.image.resize``. Resizing also
            casts images to float32; pixel values are not rescaled.
        shuffle_buffer: If > 0, shuffle examples with this buffer size after the
            cache so the order changes every epoch (use for the training split only).
        batch_size: If > 0, batch the dataset. Leave at 0 when ``ds`` is already
            batched (e.g. loaded with ``tfds.load(..., batch_size=...)``).

    Returns:
        The configured, prefetched dataset.
    """
    ds = ds.map(
        lambda x, y: (tf.image.resize(x, image_size), y),
        num_parallel_calls=tf.data.AUTOTUNE,
    ).cache()
    if shuffle_buffer > 0:
        # Shuffle *after* cache(): the cache replays a single fixed order, so
        # shuffling before it would freeze the example order across epochs.
        ds = ds.shuffle(shuffle_buffer)
    if batch_size > 0:
        ds = ds.batch(batch_size)
    return ds.prefetch(tf.data.AUTOTUNE)


# CIFAR-10 images are 32x32 RGB. A 10k-example buffer gives a good shuffle at
# modest memory cost (a 50k buffer would fully shuffle the training split).
train_ds = configure_ds(
    cifar_train_raw, image_size=(32, 32), shuffle_buffer=10_000, batch_size=BATCH_SIZE
)
val_ds = configure_ds(cifar_val_raw, image_size=(32, 32), batch_size=BATCH_SIZE)
train_ds

<PrefetchDataset element_spec=(TensorSpec(shape=(None, 32, 32, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.int64, name=None))>

#### Create 2D CNN model for CIFAR-10

In [8]:
# 32x32 three-channel (RGB) inputs, 10 classes.
model = quantus.CNN_2D_TF(32, 32, 10, num_channels=3)

# AdamW with decoupled weight decay; the loss expects raw logits from the model.
cifar_optimizer = tfa.optimizers.AdamW(weight_decay=0.004)
cifar_loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model.compile(optimizer=cifar_optimizer, loss=cifar_loss_fn, metrics=["accuracy"])
model.summary()

Model: "CNN_2D_TF"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_2 (Conv2D)           (None, 32, 32, 16)        448       
                                                                 
 max_pooling2d_3 (MaxPooling  (None, 16, 16, 16)       0         
 2D)                                                             
                                                                 
 conv2d_3 (Conv2D)           (None, 16, 16, 32)        4640      
                                                                 
 max_pooling2d_4 (MaxPooling  (None, 8, 8, 32)         0         
 2D)                                                             
                                                                 
 test_conv (Conv2D)          (None, 8, 8, 64)          18496     
                                                                 
 max_pooling2d_5 (MaxPooling  (None, 4, 4, 64)         0 

#### Train the model

In [9]:
# Keep the returned History (per-epoch loss/accuracy in `.history`) in a named
# variable instead of leaving it as the cell's unlabelled output.
cifar_history = model.fit(train_ds, validation_data=val_ds, epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x7fceb06e6d00>

#### Save the weights

In [11]:
# Create the output directory first (no-op if it already exists).
tf.io.gfile.makedirs("assets")
# NOTE(review): with tf.keras 2.x a `.keras` suffix makes save_weights write HDF5;
# Keras 3 requires a `.weights.h5` suffix for save_weights. The file name is kept
# unchanged because these weights are meant to be loaded for Quantus testing —
# confirm the loaders before renaming.
model.save_weights('assets/cnn_2d_cifar_weights.keras')