In [None]:
from tensorflow import keras

import tensorflow_addons as tfa
import tensorflow as tf
import tensorflow_datasets as tfds
import wandb
from wandb.keras import WandbCallback

In [None]:
# Start the Weights & Biases run that WandbCallback logs to during training.
wandb.init(
    project="Conv-Mixer",
    name="ConvMixer",
    resume=True,
)

In [None]:
# Training hyperparameters (read by run()).
learning_rate = 0.001   # AdamW learning rate
weight_decay = 0.0001   # AdamW weight decay
# NOTE: the input pipeline batches with a literal .batch(64) and run() assumes a
# batch of 64 when computing steps_per_epoch; keep this value equal to that so it
# reflects what training actually uses.
batch_size = 64
num_epochs = 32         # passed to model.fit as `epochs`

In [None]:
# Load CIFAR-10 as (image, label) pairs together with its DatasetInfo metadata.
dataset, info = tfds.load(name='cifar10', as_supervised=True, with_info=True)

In [None]:
# Pull out the predefined train and test splits.
train_dataset = dataset['train']
test_dataset = dataset['test']

In [None]:
# Preprocessing shared by the train and test pipelines
def normalize_and_resize(image, label):
    """Convert an image to float32 scaled by 1/255 and resize it to 28x28.

    Args:
        image: image tensor (presumably uint8 in [0, 255] from tfds — TODO confirm),
            so the scaled result lies in [0, 1].
        label: class label, returned unchanged.

    Returns:
        Tuple (image, label) where image is a float32 tensor resized to 28x28.
    """
    scaled = tf.cast(image, tf.float32) / 255.0
    return tf.image.resize(scaled, [28, 28]), label

# Random augmentation applied to training images only
def augment(image, label):
    """Apply random geometric and photometric augmentation to one image.

    Args:
        image: float32 image tensor with values expected in [0, 1]
            (as produced by normalize_and_resize).
        label: class label, returned unchanged.

    Returns:
        Tuple (image, label) with the augmented image clipped to [0, 1].
    """
    image = tf.image.random_flip_left_right(image)
    # NOTE(review): vertical flips are unusual for natural-image datasets such as
    # CIFAR-10; kept to preserve the existing augmentation policy.
    image = tf.image.random_flip_up_down(image)
    image = tf.image.random_brightness(image, max_delta=0.5)
    image = tf.image.random_contrast(image, lower=0.2, upper=1.8)
    # Brightness/contrast can push pixels outside [0, 1]. The hue and saturation
    # ops convert RGB<->HSV, which is only well defined on [0, 1], so clip first.
    image = tf.clip_by_value(image, 0.0, 1.0)
    image = tf.image.random_hue(image, max_delta=0.2)
    image = tf.image.random_saturation(image, lower=0.2, upper=1.8)
    # Keep training inputs in the same [0, 1] range as the (unaugmented) test inputs.
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image, label

In [None]:
# Build the input pipelines from the raw splits in `dataset` (not by reassigning
# train_dataset/test_dataset onto themselves), so re-running this cell does not
# apply normalize_and_resize a second time.
#
# Training: preprocess once and cache; augment runs after cache() so every pass
# draws fresh random transforms; then shuffle, batch, and repeat indefinitely
# (run() bounds each epoch with steps_per_epoch).
train_dataset = (
    dataset['train']
    .map(normalize_and_resize, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .map(augment, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(1000)
    .batch(64)
    .repeat()
    .prefetch(tf.data.AUTOTUNE)
)
# Evaluation: deterministic preprocessing only — no augmentation, no shuffling.
test_dataset = (
    dataset['test']
    .map(normalize_and_resize, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .batch(64)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Post-activation block: GELU followed by batch normalization
def activation_normalization_layer(x):
    """Apply a GELU activation and then batch normalization.

    Args:
        x: input Keras tensor.

    Returns:
        The Keras tensor produced by Activation('gelu') followed by BatchNormalization().
    """
    gelu = keras.layers.Activation('gelu')
    batch_norm = keras.layers.BatchNormalization()
    return batch_norm(gelu(x))


# Patch embedding: a convolution whose kernel size and stride both equal the patch size
def patch_conv_layer(x, filters, patch_size):
    """Embed non-overlapping patches with a strided convolution.

    Args:
        x: input Keras tensor.
        filters: number of output channels (the hidden dimension).
        patch_size: used as both kernel size and stride, so patches do not overlap.

    Returns:
        The patch embeddings after GELU and batch normalization.
    """
    patch_embedding = keras.layers.Conv2D(
        filters=filters,
        kernel_size=patch_size,
        strides=patch_size,
    )
    return activation_normalization_layer(patch_embedding(x))


#This is the main ConvMixer layer which is repeated "depth" times
def conv_mixer_layer(x, filters, kernel_size):
    """Apply one ConvMixer block: residual depthwise conv, then pointwise conv.

    Args:
        x: input Keras tensor.
        filters: number of output channels of the pointwise (1x1) convolution.
        kernel_size: spatial kernel size of the depthwise convolution.

    Returns:
        Keras tensor with `filters` channels and the same spatial size as `x`.
    """
    # Spatial mixing with a residual connection. DepthwiseConv2D with
    # padding="same" (stride 1, depth_multiplier 1) keeps both the spatial size
    # and the channel count of `x`, so an element-wise Add is always
    # shape-compatible. Add, not Concatenate, is what makes this a residual;
    # concatenating would double the channels fed into the pointwise conv.
    shortcut = x
    x = keras.layers.DepthwiseConv2D(kernel_size=kernel_size, padding="same")(x)
    x = activation_normalization_layer(x)
    x = keras.layers.Add()([x, shortcut])

    # Channel mixing: pointwise 1x1 convolution.
    x = keras.layers.Conv2D(filters=filters, kernel_size=1, padding="same")(x)
    x = activation_normalization_layer(x)

    return x
    
def conv_mixer_model(
    image_size=28,
    filters=256,
    depth=8,
    kernel_size=5,
    patch_size=2,
    num_classes=10,
):
    """Build a ConvMixer image classifier as a functional Keras model.

    Architecture: patch embedding -> `depth` ConvMixer blocks ->
    global average pooling -> dense softmax classifier.

    Args:
        image_size: height and width of the square 3-channel input.
        filters: hidden dimension (channel count) used throughout the network.
        depth: number of stacked conv_mixer_layer blocks.
        kernel_size: depthwise kernel size inside each block.
        patch_size: kernel size and stride of the patch embedding.
        num_classes: number of output classes.

    Returns:
        An uncompiled keras.Model mapping (image_size, image_size, 3) inputs to
        `num_classes` softmax probabilities.
    """
    inputs = keras.Input(shape=(image_size, image_size, 3))

    features = patch_conv_layer(inputs, filters, patch_size)
    for _block in range(depth):
        features = conv_mixer_layer(features, filters, kernel_size)

    pooled = keras.layers.GlobalAveragePooling2D()(features)
    probabilities = keras.layers.Dense(num_classes, activation="softmax")(pooled)

    return keras.Model(inputs=inputs, outputs=probabilities)


In [None]:
def run(model, train_batch_size=64):
    """Compile `model` with AdamW and train it on the CIFAR-10 pipelines.

    Reads the notebook-level `learning_rate`, `weight_decay`, `num_epochs`,
    `train_dataset`, `test_dataset` and `info`.

    Args:
        model: a keras.Model with softmax outputs to compile and train.
        train_batch_size: batch size of `train_dataset`, used only to derive
            steps_per_epoch. Defaults to 64 to match the pipeline's `.batch(64)`.

    Returns:
        Tuple (history, model): the History returned by fit and the trained model.
    """
    optimizer = tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    )

    model.compile(
        optimizer=optimizer,
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )

    # train_dataset repeats forever, so each epoch must be bounded explicitly.
    # Take the training-set size from the dataset metadata (50,000 for CIFAR-10)
    # rather than hard-coding it; 60,000 is MNIST's size and overshot each epoch.
    steps_per_epoch = info.splits["train"].num_examples // train_batch_size

    history = model.fit(
        train_dataset,
        validation_data=test_dataset,
        epochs=num_epochs,
        steps_per_epoch=steps_per_epoch,
        callbacks=[WandbCallback()],
    )

    return history, model

In [None]:
# Build the classifier with its default hyperparameters spelled out, then print its summary.
model = conv_mixer_model(
    image_size=28,
    filters=256,
    depth=8,
    kernel_size=5,
    patch_size=2,
    num_classes=10,
)
model.summary()

In [None]:
# Train; `history` holds the per-epoch metrics and `model` the trained network.
history, model = run(model=model)