# Image classification with ConvMixer

**Author:** [Sayak Paul](https://twitter.com/RisingSayak)<br>
**Date created:** 2021/10/12<br>
**Last modified:** 2021/10/12<br>
**Description:** An all-convolutional network applied to patches of images.

## Introduction

Vision Transformers (ViT; [Dosovitskiy et al.](https://arxiv.org/abs/2010.11929)) extract
small patches from the input images, linearly project them, and then apply the
Transformer ([Vaswani et al.](https://arxiv.org/abs/1706.03762)) blocks. The application
of ViTs to image recognition tasks is quickly becoming a promising area of research,
because ViTs eliminate the need to have strong inductive biases (such as convolutions) for
modeling locality. This presents them as a general computation primitive capable of
learning just from the training data with as few inductive priors as possible. ViTs
yield great downstream performance when trained with proper regularization, data
augmentation, and relatively large datasets.
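
To make the patch step concrete, here is a minimal NumPy sketch (not the ViT authors' code) suggesting that splitting an image into non-overlapping patches and linearly projecting each one gives the same result as a convolution whose kernel size and stride both equal the patch size. The function names, the toy image size, and the random weights are illustrative assumptions.

```python
import numpy as np


def extract_patches(image, patch_size):
    """Split an (H, W, C) image into flattened, non-overlapping patches."""
    h, w, c = image.shape
    gh, gw = h // patch_size, w // patch_size
    patches = image.reshape(gh, patch_size, gw, patch_size, c)
    patches = patches.transpose(0, 2, 1, 3, 4)  # (gh, gw, p, p, c)
    return patches.reshape(gh * gw, patch_size * patch_size * c)


def strided_conv(image, kernel, bias):
    """Naive convolution whose stride equals its (square) kernel size."""
    p = kernel.shape[0]
    h, w, _ = image.shape
    out = np.empty((h // p, w // p, kernel.shape[-1]))
    for i in range(h // p):
        for j in range(w // p):
            window = image[i * p:(i + 1) * p, j * p:(j + 1) * p, :]
            # Sum over the (p, p, c) window for every output channel.
            out[i, j] = np.tensordot(window, kernel, axes=3) + bias
    return out


rng = np.random.default_rng(0)
image = rng.normal(size=(8, 8, 3))      # toy image (assumed size)
kernel = rng.normal(size=(2, 2, 3, 4))  # patch size 2, embedding dim 4
bias = rng.normal(size=4)

# Route 1: patches followed by a linear projection (the ViT view).
tokens = extract_patches(image, 2) @ kernel.reshape(-1, 4) + bias
# Route 2: a single strided convolution.
feature_map = strided_conv(image, kernel, bias)

print(np.allclose(tokens, feature_map.reshape(-1, 4)))
```

The two routes differ only in how the result is laid out: a sequence of tokens versus a spatial feature map.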

In the [Patches Are All You Need](https://openreview.net/pdf?id=TVHS5Y4dNvM) paper (note:
at the time of writing, it is a submission to the ICLR 2022 conference), the authors extend
the idea of using patches to train an all-convolutional network and demonstrate
competitive results. Their architecture, **ConvMixer**, uses recipes from recent
isotropic architectures like ViT and MLP-Mixer
([Tolstikhin et al.](https://arxiv.org/abs/2105.01601)), such as using the same
depth and resolution across different layers in the network, residual connections,
and so on.

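In this example, we will implement the ConvMixer model and demonstrate its performance on
a custom five-class image dataset (256×256 RGB images loaded from Google Drive), used here
in place of the CIFAR-10 dataset.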

## Imports

In [1]:
!pip install tensorflow==2.18.0

import keras
from keras import layers

import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np


## Hyperparameters

We will train the model for 100 epochs. To focus on
the core ideas of ConvMixer, we will not use other training-specific elements like
RandAugment ([Cubuk et al.](https://arxiv.org/abs/1909.13719)). If you are interested in
learning more about those details, please refer to the
[original paper](https://openreview.net/pdf?id=TVHS5Y4dNvM).

In [2]:
learning_rate = 0.0001
weight_decay = 0.0001
batch_size = 32  # was 16
num_epochs = 100  # was 100

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
from PIL import Image
import glob
import numpy as np

filelist1 = glob.glob('/content/drive/MyDrive/TypeCoffee.v25i.folder/test/DuroRiadoRio/*.jpg')

xt_drr = np.array([np.array(Image.open(fname)) for fname in filelist1])
print(xt_drr.shape)
yt_drr = np.zeros((19,1),dtype=np.uint8)

filelist2 = glob.glob('/content/drive/MyDrive/TypeCoffee.v25i.folder/test/Mole/*.jpg')

xt_mole = np.array([np.array(Image.open(fname)) for fname in filelist2])
yt_mole = np.ones((19,1),dtype=np.uint8)

filelist3 = glob.glob('/content/drive/MyDrive/TypeCoffee.v25i.folder/test/Quebrado/*.jpg')

xt_q = np.array([np.array(Image.open(fname)) for fname in filelist3])
yt_q= np.full((20,1),2,dtype=np.uint8)

filelist4 = glob.glob('/content/drive/MyDrive/TypeCoffee.v25i.folder/test/RiadoRio/*.jpg')

xt_rr = np.array([np.array(Image.open(fname)) for fname in filelist4])
yt_rr= np.full ((22,1),3,dtype=np.uint8)

filelist5 = glob.glob('/content/drive/MyDrive/TypeCoffee.v25i.folder/test/RioFechado/*.jpg')

xt_rf = np.array([np.array(Image.open(fname)) for fname in filelist5])
yt_rf= np.full ((20,1),4,dtype=np.uint8)

x_test=np.concatenate((xt_drr,xt_mole,xt_q,xt_rr,xt_rf), axis=0)
y_test=np.concatenate((yt_drr,yt_mole,yt_q,yt_rr,yt_rf), axis=0)

print(x_test.shape)
print(y_test.shape)

(19, 256, 256, 3)
(100, 256, 256, 3)
(100, 1)


In [6]:
filelist6 = glob.glob('/content/drive/MyDrive/TypeCoffee.v25i.folder/train/DuroRiadoRio/*.jpg')

xtrain_drr = np.array([np.array(Image.open(fname)) for fname in filelist6])
ytrain_drr = np.zeros((210,1),dtype=np.uint8)

filelist7 = glob.glob('/content/drive/MyDrive/TypeCoffee.v25i.folder/train/Mole/*.jpg')

xtrain_mole = np.array([np.array(Image.open(fname)) for fname in filelist7])
ytrain_mole = np.ones((215,1),dtype=np.uint8)

filelist8 = glob.glob('/content/drive/MyDrive/TypeCoffee.v25i.folder/train/Quebrado/*.jpg')

xtrain_q = np.array([np.array(Image.open(fname)) for fname in filelist8])
ytrain_q= np.full((206,1),2,dtype=np.uint8)

filelist9 = glob.glob('/content/drive/MyDrive/TypeCoffee.v25i.folder/train/RiadoRio/*.jpg')

xtrain_rr = np.array([np.array(Image.open(fname)) for fname in filelist9])
ytrain_rr= np.full ((212,1),3,dtype=np.uint8)

filelist10 = glob.glob('/content/drive/MyDrive/TypeCoffee.v25i.folder/train/RioFechado/*.jpg')

xtrain_rf = np.array([np.array(Image.open(fname)) for fname in filelist10])
ytrain_rf= np.full ((206,1),4,dtype=np.uint8)

x_train=np.concatenate((xtrain_drr,xtrain_mole,xtrain_q,xtrain_rr,xtrain_rf), axis=0)
y_train=np.concatenate((ytrain_drr,ytrain_mole,ytrain_q,ytrain_rr,ytrain_rf), axis=0)

print(x_train.shape)
print(y_train.shape)

(1049, 256, 256, 3)
(1049, 1)


In [7]:
filelist11 = glob.glob('/content/drive/MyDrive/TypeCoffee.v25i.folder/val/DuroRiadoRio/*.jpg')

xv_drr = np.array([np.array(Image.open(fname)) for fname in filelist11])
yv_drr = np.zeros((13,1),dtype=np.uint8)

filelist12 = glob.glob('/content/drive/MyDrive/TypeCoffee.v25i.folder/val/Mole/*.jpg')

xv_mole = np.array([np.array(Image.open(fname)) for fname in filelist12])
yv_mole = np.ones((11,1),dtype=np.uint8)

filelist13 = glob.glob('/content/drive/MyDrive/TypeCoffee.v25i.folder/val/Quebrado/*.jpg')

xv_q = np.array([np.array(Image.open(fname)) for fname in filelist13])
yv_q= np.full((13,1),2,dtype=np.uint8)

filelist14 = glob.glob('/content/drive/MyDrive/TypeCoffee.v25i.folder/val/RiadoRio/*.jpg')

xv_rr = np.array([np.array(Image.open(fname)) for fname in filelist14])
yv_rr= np.full ((10,1),3,dtype=np.uint8)

filelist15 = glob.glob('/content/drive/MyDrive/TypeCoffee.v25i.folder/val/RioFechado/*.jpg')

xv_rf = np.array([np.array(Image.open(fname)) for fname in filelist15])
yv_rf= np.full ((13,1),4,dtype=np.uint8)

x_val=np.concatenate((xv_drr,xv_mole,xv_q,xv_rr,xv_rf), axis=0)
y_val=np.concatenate((yv_drr,yv_mole,yv_q,yv_rr,yv_rf), axis=0)

print(x_val.shape)
print(y_val.shape)

(60, 256, 256, 3)
(60, 1)
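
The three loading cells above repeat one pattern and hard-code the number of images per class, so a label array can silently fall out of step with the files on disk. One possible refactor, sketched here under the assumption that every split directory holds one subfolder per class, derives the labels from the files actually found. `CLASS_NAMES` and `collect_split` are hypothetical names introduced for illustration; the class order matches the label values used above.

```python
import glob
import os

import numpy as np

# Class folders in label order (0..4), as used in the cells above.
CLASS_NAMES = ["DuroRiadoRio", "Mole", "Quebrado", "RiadoRio", "RioFechado"]


def collect_split(split_dir, class_names=CLASS_NAMES, pattern="*.jpg"):
    """Return (paths, labels) for one split; labels have shape (N, 1)."""
    paths, labels = [], []
    for label, name in enumerate(class_names):
        class_paths = sorted(glob.glob(os.path.join(split_dir, name, pattern)))
        paths.extend(class_paths)
        # One label per file found, so counts never need to be typed by hand.
        labels.extend([label] * len(class_paths))
    return paths, np.array(labels, dtype=np.uint8).reshape(-1, 1)


# Possible usage, mirroring the cells above (requires the mounted drive):
# paths, y_test = collect_split("/content/drive/MyDrive/TypeCoffee.v25i.folder/test")
# x_test = np.array([np.array(Image.open(p)) for p in paths])
```

Sorting the paths keeps the image order reproducible between runs, which `glob.glob` alone does not guarantee.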


## Prepare `tf.data.Dataset` objects

Our data augmentation pipeline is different from what the authors used for the CIFAR-10
dataset, which is fine for the purpose of the example.
Note that it's OK to use **TF APIs for data I/O and preprocessing** with other backends
(JAX, PyTorch), as TensorFlow is a feature-complete framework when it comes to data preprocessing.

In [8]:
image_size = 256  # was 32
auto = tf.data.AUTOTUNE

augmentation_layers = [
    keras.layers.RandomCrop(image_size, image_size),
    keras.layers.RandomFlip("horizontal"),
]


def augment_images(images):
    for layer in augmentation_layers:
        images = layer(images, training=True)
    return images


def make_datasets(images, labels, is_train=False):
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    if is_train:
        dataset = dataset.shuffle(batch_size * 10)
    dataset = dataset.batch(batch_size)
    if is_train:
        dataset = dataset.map(
            lambda x, y: (augment_images(x), y), num_parallel_calls=auto
        )
    return dataset.prefetch(auto)


train_dataset = make_datasets(x_train, y_train, is_train=True)
val_dataset = make_datasets(x_val, y_val)
test_dataset = make_datasets(x_test, y_test)
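
Because `make_datasets` batches without dropping the final partial batch, the number of steps per epoch is the sample count divided by the batch size, rounded up. The short calculation below uses the split sizes printed earlier (1049, 60, and 100 images) and `batch_size = 32`; the helper name `steps_per_epoch` is only illustrative.

```python
import math


def steps_per_epoch(num_samples, batch_size):
    """Number of batches when the final partial batch is kept."""
    return math.ceil(num_samples / batch_size)


batch_size = 32
split_sizes = {"train": 1049, "val": 60, "test": 100}
steps = {name: steps_per_epoch(n, batch_size) for name, n in split_sizes.items()}
print(steps)  # {'train': 33, 'val': 2, 'test': 4}

# Size of the last (partial) training batch.
last_train_batch = split_sizes["train"] - (steps["train"] - 1) * batch_size
print(last_train_batch)  # 25
```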

## ConvMixer utilities

The following figure (taken from the original paper) depicts the ConvMixer model:

![](https://i.imgur.com/yF8actg.png)

ConvMixer is very similar to the MLP-Mixer model, with the following key
differences:

* Instead of using fully-connected layers, it uses standard convolution layers.
* Instead of LayerNorm (which is typical for ViTs and MLP-Mixers), it uses BatchNorm.

Two types of convolution layers are used in ConvMixer: **(1)** depthwise convolutions,
for mixing spatial locations of the images, and **(2)** pointwise convolutions (which follow
the depthwise convolutions), for mixing channel-wise information across the patches.
Another key point is the use of *larger kernel sizes* to allow a larger receptive field.
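
The division of labor between the two convolution types can be illustrated with a small NumPy sketch. It is a naive reference implementation written for clarity, not the Keras layers themselves: it assumes an odd kernel size, zero "same" padding, and no bias, and the helper names are made up for this illustration.

```python
import numpy as np


def depthwise_conv_same(x, kernels):
    """Depthwise convolution: one (k, k) spatial filter per channel.

    x has shape (H, W, C); kernels has shape (k, k, C) with odd k.
    """
    k = kernels.shape[0]
    pad = k // 2
    padded = np.pad(x, ((pad, pad), (pad, pad), (0, 0)))
    h, w, _ = x.shape
    out = np.zeros(x.shape, dtype=float)
    for i in range(h):
        for j in range(w):
            window = padded[i:i + k, j:j + k, :]  # (k, k, C)
            # Sum over space only; channels stay separate.
            out[i, j] = (window * kernels).sum(axis=(0, 1))
    return out


def pointwise_conv(x, weight):
    """1x1 convolution: mixes channels at each location independently."""
    return x @ weight  # (H, W, C_in) @ (C_in, C_out)


rng = np.random.default_rng(0)
x = rng.normal(size=(6, 6, 3))
kernels = rng.normal(size=(5, 5, 3))
weight = rng.normal(size=(3, 4))

# Perturb channel 1 only and see which outputs react.
x_shifted = x.copy()
x_shifted[..., 1] += 1.0

dw_a, dw_b = depthwise_conv_same(x, kernels), depthwise_conv_same(x_shifted, kernels)
pw_a, pw_b = pointwise_conv(x, weight), pointwise_conv(x_shifted, weight)

print(np.allclose(dw_a[..., 0], dw_b[..., 0]))  # channel 0 untouched by depthwise
print(np.allclose(pw_a[..., 0], pw_b[..., 0]))  # channel 0 changes after pointwise
```

A change in one input channel leaves the other channels of the depthwise output untouched, while the pointwise output picks it up in every output channel.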

In [9]:

def activation_block(x):
    x = layers.Activation("gelu")(x)
    return layers.BatchNormalization()(x)


def conv_stem(x, filters: int, patch_size: int):
    x = layers.Conv2D(filters, kernel_size=patch_size, strides=patch_size)(x)
    return activation_block(x)


def conv_mixer_block(x, filters: int, kernel_size: int):
    # Depthwise convolution.
    x0 = x
    x = layers.DepthwiseConv2D(kernel_size=kernel_size, padding="same")(x)
    x = layers.Add()([activation_block(x), x0])  # Residual.

    # Pointwise convolution.
    x = layers.Conv2D(filters, kernel_size=1)(x)
    x = activation_block(x)

    return x


def get_conv_mixer_256_8(
    image_size=256, filters=256, depth=8, kernel_size=5, patch_size=2, num_classes=5
):
    """ConvMixer-256/8: https://openreview.net/pdf?id=TVHS5Y4dNvM.
    The hyperparameter values are taken from the paper.
    """
    inputs = keras.Input((image_size, image_size, 3))
    x = layers.Rescaling(scale=1.0 / 255)(inputs)

    # Extract patch embeddings.
    x = conv_stem(x, filters, patch_size)

    # ConvMixer blocks.
    for _ in range(depth):
        x = conv_mixer_block(x, filters, kernel_size)

    # Classification block.
    x = layers.GlobalAvgPool2D()(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)

    return keras.Model(inputs, outputs)


The model used in this experiment is termed **ConvMixer-256/8**, where 256 denotes the
number of channels and 8 denotes the depth. With the layers defined above, the resulting
model has only about 0.6 million parameters.
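
The size of the model can be tallied by hand from the layer definitions. The sketch below assumes Keras defaults: biases enabled on every convolution and dense layer, and four parameters per channel for each `BatchNormalization` layer (two of which are non-trainable moving statistics). The helper names are illustrative, and the defaults mirror the arguments of `get_conv_mixer_256_8`.

```python
def conv2d_params(kernel_size, in_channels, out_channels):
    """Weights plus biases of a standard convolution."""
    return kernel_size * kernel_size * in_channels * out_channels + out_channels


def depthwise_conv2d_params(kernel_size, channels):
    """One (k, k) filter and one bias per channel."""
    return kernel_size * kernel_size * channels + channels


def batch_norm_params(channels):
    """gamma, beta, moving mean, and moving variance."""
    return 4 * channels


def conv_mixer_params(
    filters=256, depth=8, kernel_size=5, patch_size=2, num_classes=5, in_channels=3
):
    stem = conv2d_params(patch_size, in_channels, filters) + batch_norm_params(filters)
    block = (
        depthwise_conv2d_params(kernel_size, filters)
        + batch_norm_params(filters)
        + conv2d_params(1, filters, filters)  # pointwise convolution
        + batch_norm_params(filters)
    )
    head = filters * num_classes + num_classes  # final Dense layer
    return stem + depth * block + head


print(conv_mixer_params())  # 601605
```

Most of the budget sits in the pointwise convolutions (65,792 parameters per block), while each 5×5 depthwise convolution contributes only 6,656.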

## Model training and evaluation utility

In [10]:
# Code reference:
# https://keras.io/examples/vision/image_classification_with_vision_transformer/.


def run_experiment(model):
    optimizer = keras.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    )

    model.compile(
        optimizer=optimizer,
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )

    checkpoint_filepath = "/tmp/checkpoint.keras"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=False,
    )

    history = model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=num_epochs,
        callbacks=[checkpoint_callback],
    )

    model.load_weights(checkpoint_filepath)
    _, accuracy = model.evaluate(test_dataset)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")

    return history, model


## Train and evaluate model

In [None]:
conv_mixer_model = get_conv_mixer_256_8()
history, conv_mixer_model = run_experiment(conv_mixer_model)

Epoch 1/100
28/33 ━━━━━━━━━━━━━━━━━━━━ 1:45 21s/step - accuracy: 0.2858 - loss: 1.5627

In [None]:
!pip install scikit-learn # Install scikit-learn if you haven't already

from sklearn.metrics import precision_recall_fscore_support # Import the function
import numpy as np

# Predict on the whole test set in batches instead of one image at a time.
probabilities = conv_mixer_model.predict(x_test, batch_size=batch_size)
y_pred = np.argmax(probabilities, axis=1)  # Predicted class index per image
y_true = y_test.ravel()  # Labels are stored with shape (N, 1)

precision, recall, f1_score, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted')
print('Precision:', precision)
print('Recall:', recall)
print('F1 Score:', f1_score)
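
For readers curious about what `average='weighted'` does, the following NumPy sketch recomputes the three scores from a confusion matrix: per-class precision, recall, and F1 are averaged with weights proportional to each class's number of true samples. It is an illustrative reimplementation rather than scikit-learn's code, the toy labels are made up, and a class that is never predicted is given a precision of 0.

```python
import numpy as np


def weighted_prf(y_true, y_pred, num_classes):
    """Support-weighted precision, recall, and F1 from integer labels."""
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()

    # confusion[t, p] counts samples of true class t predicted as class p.
    confusion = np.zeros((num_classes, num_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        confusion[t, p] += 1

    tp = np.diag(confusion).astype(float)
    predicted = confusion.sum(axis=0)  # how often each class was predicted
    support = confusion.sum(axis=1)    # how often each class truly occurs

    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    recall = np.divide(tp, support, out=np.zeros_like(tp), where=support > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)

    weights = support / support.sum()
    return (precision * weights).sum(), (recall * weights).sum(), (f1 * weights).sum()


# Toy example with three classes (made-up labels).
toy_true = [0, 0, 1, 1, 2, 2]
toy_pred = [0, 1, 1, 1, 2, 0]
p, r, f = weighted_prf(toy_true, toy_pred, num_classes=3)
print(round(p, 4), round(r, 4), round(f, 4))  # 0.7222 0.6667 0.6556
```

With this weighting, recall equals plain accuracy (4 of 6 correct in the toy example), because each class's recall is weighted by its share of the samples.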

The gap in training and validation performance can be mitigated by using additional
regularization techniques. Nevertheless, reaching ~83% accuracy within 10 epochs on
CIFAR-10 with a model of well under one million parameters is a strong result.

## Visualizing the internals of ConvMixer

We can visualize the patch embeddings and the learned convolution filters. Recall
that each patch embedding and intermediate feature map have the same number of channels
(256 in this case). This will make our visualization utility easier to implement.

In [None]:
# Code reference: https://bit.ly/3awIRbP.


def visualization_plot(weights, idx=1):
    # First, apply min-max normalization to the
    # given weights to avoid isotropic scaling.
    p_min, p_max = weights.min(), weights.max()
    weights = (weights - p_min) / (p_max - p_min)

    # Visualize all the filters.
    num_filters = 256
    plt.figure(figsize=(8, 8))

    for i in range(num_filters):
        current_weight = weights[:, :, :, i]
        if current_weight.shape[-1] == 1:
            current_weight = current_weight.squeeze()
        ax = plt.subplot(16, 16, idx)
        ax.set_xticks([])
        ax.set_yticks([])
        plt.imshow(current_weight)
        idx += 1


# We first visualize the learned patch embeddings.
patch_embeddings = conv_mixer_model.layers[2].get_weights()[0]
visualization_plot(patch_embeddings)

Even though we did not train the network to convergence, we can notice that different
patches show different patterns. Some share similarity with others while some are very
different. These visualizations are more salient with larger image sizes.

Similarly, we can visualize the raw convolution kernels. This can help us understand
the patterns to which a given kernel is receptive.

In [None]:
# First, print the indices of the convolution layers that are not
# pointwise convolutions.
for i, layer in enumerate(conv_mixer_model.layers):
    if isinstance(layer, layers.DepthwiseConv2D):
        if layer.get_config()["kernel_size"] == (5, 5):
            print(i, layer)

idx = 26  # Taking a kernel from the middle of the network.

kernel = conv_mixer_model.layers[idx].get_weights()[0]
kernel = np.expand_dims(kernel.squeeze(), axis=2)
visualization_plot(kernel)
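
The `squeeze`/`expand_dims` step above only rearranges axes. A depthwise kernel for 256 channels with a 5×5 window is stored with shape `(5, 5, 256, 1)`, whereas `visualization_plot` indexes filters along the last axis. The NumPy sketch below uses a dummy array (placeholder values, not trained weights) to show the shapes involved.

```python
import numpy as np

# Stand-in for `layer.get_weights()[0]` of a DepthwiseConv2D layer.
dummy_kernel = np.arange(5 * 5 * 256, dtype=np.float32).reshape(5, 5, 256, 1)

# Drop the trailing depth-multiplier axis, then add a singleton "input
# channel" axis so that the per-channel filters sit on the last axis.
plot_ready = np.expand_dims(dummy_kernel.squeeze(), axis=2)
print(dummy_kernel.shape, "->", plot_ready.shape)  # (5, 5, 256, 1) -> (5, 5, 1, 256)

# Filter i of the depthwise layer ends up at index i of the last axis.
i = 7
same_filter = np.array_equal(plot_ready[:, :, 0, i], dummy_kernel[:, :, i, 0])
print(same_filter)  # True
```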

In [None]:
conv_mixer_model.summary()

In [None]:
# Save the full model in the native Keras format. `torch.save` targets
# PyTorch objects and would not store a state dictionary for a Keras model.
conv_mixer_model.save("meu_modelo_convmixer.keras")

We see that different filters in the kernel have different locality spans, and this
pattern is likely to evolve with more training.

## Final notes

There's been a recent trend of fusing convolutions with other data-agnostic operations
like self-attention. The following works are along this line of research:

* ConViT ([d'Ascoli et al.](https://arxiv.org/abs/2103.10697))
* CCT ([Hassani et al.](https://arxiv.org/abs/2104.05704))
* CoAtNet ([Dai et al.](https://arxiv.org/abs/2106.04803))