---
title: "A General Purpose Deep Learning Architecture - Perceiver IO From Scratch"
toc: true
format:
    html: 
        code-fold: False
jupyter: python3
author: "Yann Dupis"
date: "2022-10-31"
categories: [deep learning, code, from scratch]
image: "./img/cross-attention.png"
---

In the recent years, we saw tremendous progress across all the machine learning tasks from image classification, object detection, translation, question answering, text classification etc. Main drivers of this progress are more compute, data, new training approaches such as transfer learning, but also more sophisticated neural network architectures. Fortunately these models are available at out fingertips. If you tackle a computer vision task, you can leverage the [ Pytorch Image Models (timm)](https://timm.fast.ai/#List-Models-with-Pretrained-Weights) library. As of now it contains more than 50 different type of architectures translating into 600 models which have [different trade offs](https://www.kaggle.com/code/jhoward/which-image-models-are-best) in terms of accuracy and inference time depending on the task. For NLP, you can use the popular [transformer library](https://github.com/huggingface/transformers) from HuggingFace which offers almost [80k models](https://huggingface.co/models) based on about 140 different type of architectures.

Each of these neural network architectures built upon one another by the research community, were able to achieve better performance thanks to very carefully hand crafted [inductive bias](https://samiraabnar.github.io/articles/2020-05/recurrence) for each specific task. For example, CNNs used in computer vision models such as ResNet have a locality bias where they assume close pixels are related to each other. LSTMs which used to be popular for NLP tasks before Transformers appeared, have a sequential bias where they process each element of a sequence one after another. However because the inductive bias is hand crafted for a specific task, existing architectures are often not able to generalize across multiple modalities. The Perceiver IO ([Jaegle et al, 2021](https://arxiv.org/abs/2107.14795)) architecture released by DeepMind in 2021 aim to solve this challenge and demonstrate they can achieve strong results across a wide variety of single-modal and multi-modal tasks such as: GLUE language benchmark, predicting optical flow between images, image classification, multi-modal video + audio classification, audio-video-label multi-modal autoencoding etc. 

In this blog post, we will implement a Perceiver IO architecture from scratch and we will apply it to two different domains: image classification and text classification. In conjunction to this blog post, I highly recommend watching [Drew Jaegle's talk](https://www.youtube.com/watch?v=wTZ3o36lXoQ) who's the main author of the Perceiver IO paper, and also read [this HuggingFace's blog post](https://huggingface.co/blog/perceiver). These resources in addition to training several Perceiver models with the [transformers library](https://github.com/huggingface/transformers) inspired me to implement this architecture from scratch.

In [1]:
import math
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


## Transformers


The Perceiver architecture is based on the Transformer architecture ([Vaswani et al., 2017](https://arxiv.org/pdf/1706.03762.pdf)). As described in [Drew Jaegle's presentation](https://www.youtube.com/watch?v=wTZ3o36lXoQ), the main benefits of transformers are:

- They don't make specific assumption about the domain.
- General purpose inductive bias.
- Position is a feature instead of being a constraint.
- Mainly uses matmuls which are efficiently computed by GPU and TPU.

The cons are: 

- Attention scales quadratically based on the input. 
- Multilayer perceptron (MLP) scales linearly based on the input. 

Transformers consist of: a self-attention module, multilayer perceptron (also called feed-forward layer), layer normalization and skip connections. But really the heart of the transformers are their self-attention module and that's where one of the main Perceiver's innovation has been introduced. So let's first implement it, understand why it scales quadratically and learn how the perceiver architecture addresses this challenge.

<center><img src="./img/transformer-architecture.png" width="150"/></center>


### Self-Attention

To get an intuition about the self-attention mechanism, I highly recommend [Peter Bloem's blog post](https://peterbloem.nl/blog/transformers) or the natural language processing with transformers [book](https://learning.oreilly.com/library/view/natural-language-processing/9781098136789/). Later in this post, we will demonstrate how the perceiver can be applied to images. But let's assume for now that our input is a sentence where each words of the sentence are represented by token embeddings. 

The self attention mechanism consists of the following steps:

- Project the input (token embeddings in our case) into a query, key and values
- Compute the similarity between the query and the key using a dot product. The result of the dot product represents the attention scores.
- Multiply the attention scores by a scaling factor to normalize their variance and apply a softmax so each column of the matrix, representing the attention weights, sum up to 1.
- Compute the dot product between the attention weights and the values.

The main purpose of the self-attention mechanism is to produce a new representation of the token embeddings where now each new embeddings is a linear combination (or weighted average) of the ordinal embeddings. This allow us to encode some contextual information. For example, we have the following sentence: "The tree bark is rough to the touch." and "I hope her dog doesn’t bark when I knock on the door." In the two sentences, we have the word bark but with a different meaning. In the first sentence, bark refers to the tree and in the second sentence it refers to the dog. When applying the self attention mechanism to the embedding bark, the new embedding should give more weights to the word tree in the first sentence and dog in the second one in order to encode this context. 

In [2]:
class SelfAttention(nn.Module):
    def __init__(
        self,
        input_dim,
        n_channels,
    ):
        super().__init__()
        self.q = nn.Linear(input_dim, n_channels)
        self.k = nn.Linear(input_dim, n_channels)
        self.v = nn.Linear(input_dim, n_channels)

    def forward(self, input):
        # (N, input_dim) . (input_dim, qk_channels) -> (N, qk_channels)
        query = self.q(input)
        # (N, input_dim) . (input_dim, qk_channels) -> (N, qk_channels)
        key = self.k(input)
        # (N, input_dim) . (input_dim, v_channels) -> (N, v_channels)
        value = self.v(input)

        scale = 1.0 / math.sqrt(query.size(-1))
        # (N, qk_channels) . (qk_channels, N) -> (N, N)
        scores = torch.bmm(query, key.transpose(-1, -2)) * scale
        print(f"Attention score shape: {scores.shape}")
        weights = F.softmax(scores, dim=-1)
        # (N, N) . (N, v_channels) -> (N, v_channels)
        return torch.bmm(weights, value)

Let's apply the self-attention module to a tensor of shape (1, 3, 5). We will assume the input is a sentence containing three words, each represented by an embedding token of length 5. 

In [3]:
x_embed = torch.ones(1, 3, 5)
print(f"Input shape: {x_embed.shape}")
self_attn = SelfAttention(5, 3)
attn_out = self_attn(x_embed)
print(f"Output shape {attn_out.shape}")

Input shape: torch.Size([1, 3, 5])
Attention score shape: torch.Size([1, 3, 3])
Output shape torch.Size([1, 3, 3])


We can see the attention score has a shape of (1, 3, 3) and the output shape is (1, 3, 5). Let's now apply the self-attention mechanism to an input containing 10 words and see how it impacts the size of attention score tensor and the output.

In [4]:
x = torch.ones(1, 10, 5)
print(f"Input shape: {x.shape}")
self_attn = SelfAttention(5, 3)
attn_out = self_attn(x)
print(f"Output shape {attn_out.shape}")

Input shape: torch.Size([1, 10, 5])
Attention score shape: torch.Size([1, 10, 10])
Output shape torch.Size([1, 10, 3])


For an input of shape (1, 10, 5), the new attention score shape is (1, 10, 10). So as you can see the attention score scale quadratically based on the input. Same thing happen for the second dot product between the weights and the values. And the output shape is (1, 10, 3), so it increases linearly. Later we will see that the output of the self-attention module is fed into an MLP layer. Because the output of the self-attention module scales linearly based on the input, the MLP layer will also scale linearly based on the input.

Ok, so let's learn how Perceiver IO is addressing the transformer scalability issue while maintaining its generality property.

### Cross-Attention

Perceiver IO introduces the idea of using a cross attention module to encode the input into a latent space. To do so, instead of passing the input to the query, they use a latent variable which has smaller dimension than the input. You can think about this a latent variable as the learned initial stated for RNN.

<center><img src="./img/cross-attention.png" width="400" /></center>

You can find other examples of cross-attention for computer vision in the papers: End-to-End Object Detection with Transformers ([Carion et al, 2020](https://arxiv.org/abs/2005.12872)) and Object-Centric Learning with Slot Attention ([Locatello et al, 2020](https://arxiv.org/abs/2006.15055))


In [5]:
class CrossAttention(nn.Module):
    def __init__(
        self,
        input_kv_dim,
        input_q_dim,
        qk_channels,
        v_channels,
    ):
        super().__init__()
        self.q = nn.Linear(input_q_dim, qk_channels)
        self.k = nn.Linear(input_kv_dim, qk_channels)
        self.v = nn.Linear(input_kv_dim, v_channels)

    def forward(self, input_kv, input_q):
        # (M, input_q_dim) . (input_q_dim, qk_channels) -> (N, qk_channels)
        query = self.q(input_q)
        # (N, input_kv_dim) . (input_kv_dim, qk_channels) -> (N, qk_channels)
        key = self.k(input_kv)
        # (N, input_kv_dim) . (input_kv_dim, v_channels) -> (N, v_channels)
        value = self.v(input_kv)

        scale = 1.0 / math.sqrt(query.size(-1))
        # (M, qk_channels) . (qk_channels, N) -> (M, N)
        scores = torch.bmm(query, key.transpose(-1, -2)) * scale
        print(f"Attention score shape: {scores.shape}")
        weights = F.softmax(scores, dim=-1)
        # (M, N) . (N, v_channels) -> (M, v_channels)
        return torch.bmm(weights, value)

To demonstrate it's cheaper to encode the input to a latent space instead of using self-attention, let's go through concrete examples.

In [6]:
x = torch.ones(1, 10, 5)
latent = nn.Parameter(torch.randn(1, 2, 3))
print(f"Input shape: {x.shape}")
print("Latent shape: {latent.shape}")
self_attn = CrossAttention(5, 3, 4, 3)
attn_out = self_attn(x, latent)
print(f"Output shape {attn_out.shape}")

Input shape: torch.Size([1, 10, 5])
Latent shape: {latent.shape}
Attention score shape: torch.Size([1, 2, 10])
Output shape torch.Size([1, 2, 3])


Let's say we have our previous input of shape (1, 10, 5) and a new latent variable of shape (1, 2, 3). The attention score has now a shape of (1, 2, 10). So instead of the scaling quadratically based on the input size, the score scale linearly based on the size of the latent variable which can be controlled. The second matrix of the cross-attention module also scale linearly. But more importantly the output which will be used by the MLP layer also scales linearly based on the latent size instead of the input size. The output has now a shape of (1, 2, 3) instead of (1, 10, 3). 

Excellent! We now have implemented a cross-attention module which scales linearly based on the latent variable which will allow us to compute on larger inputs such as images or longer text sequences. Later, you will see that the Pereiver IO architecture also uses a self-attention module on the latent array. So let's create a general Attention module which can be parametrized to become a self-attention or cross-attention module based on our need.

In [7]:
class AttentionHead(nn.Module):
    def __init__(
        self,
        is_cross_attention,
        input_kv_dim,
        input_q_dim,
        qk_channels_per_head,
        v_channels_per_head,
        attention_prob_dropout_prob=0.1,
    ):
        super().__init__()
        self.is_cross_attention = is_cross_attention
        if not is_cross_attention:
            input_kv_dim = input_q_dim
        
        self.q = nn.Linear(input_q_dim, qk_channels_per_head)
        self.k = nn.Linear(input_kv_dim, qk_channels_per_head)
        self.v = nn.Linear(input_kv_dim, v_channels_per_head)
        self.dropout = nn.Dropout(attention_prob_dropout_prob)

    def forward(self, input_kv, input_q):
        query = self.q(input_q)
        
        if self.is_cross_attention & (input_kv is not None):
            key = self.k(input_kv)
            value = self.v(input_kv)
        else:
            key = self.k(input_q)
            value = self.v(input_q)

        scale = 1.0 / math.sqrt(query.size(-1))
        scores = torch.bmm(query, key.transpose(-1, -2)) * scale
        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)
        return torch.bmm(weights, value)

Until now we have implemented a single head attention but most transformers multi-head attention. The benefits of having multiple heads, it's that each head can focus on different aspects of an image (edges, colors, etc.) or a sentence instead of a single aspect. 

We can simply create multi-head attention layer by instantiating several heads and concatenating their outputs. Usually we also apply a linear layer to its final output. Note that it's possible to avoid instantiating an `AttentionHead` for each head and concatenating their output. It's possible instead to have the linear layers with number of output features equal to `number of channels per head * number of heads` for the query, key and the value. Then reshape the output from the query, key and value to `(batch_size, num_heads, time (N or M), number of channels per head)`. As an example, you can check the [transformers implementation](https://github.com/huggingface/transformers/blob/bd469c40659ce76c81f69c7726759d249b4aef49/src/transformers/models/perceiver/modeling_perceiver.py#L259).

In [8]:
class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        is_cross_attention,
        input_kv_dim,
        input_q_dim,
        qk_channels,
        v_channels,
        num_heads,
    ):
        super().__init__()
        
        qk_channels_per_head = qk_channels // num_heads
        v_channels_per_head = v_channels // num_heads

        self.heads = nn.ModuleList(
            [
                AttentionHead(
                    is_cross_attention,
                    input_kv_dim,
                    input_q_dim,
                    qk_channels_per_head,
                    v_channels_per_head,
                )
                for _ in range(num_heads)
            ]
        )

        self.linear = nn.Linear(v_channels, v_channels)

    def forward(self, input, latent_embedding):
        x = torch.cat([h(input, latent_embedding) for h in self.heads], dim=-1)
        return self.linear(x)

In [9]:
self_attention = MultiHeadAttention(True, 5, 3, 8, 6, 2)
self_attention_multi_head = self_attention(x, latent)
print(f"Multi-head attention input shape: {x.shape}")
print(f"Multi-head attention output shape: {self_attention_multi_head.shape}")


Multi-head attention input shape: torch.Size([1, 10, 5])
Multi-head attention output shape: torch.Size([1, 2, 6])


Voilà, we have our multi-head attention layer which can become a self-attention or cross-attention by setting the `is_cross_attention` parameter. Since we have implemented the main building block, we are ready to implement the full perceiver architecture!

## Perceiver Architecture

The perceiver architecture consists of three main build blocks: encoder, processor and decoder. The input gets first encoded into a latent array, then the latent representation gets refined via several processing layers. Finally, the latent gets decoded into an output. As you can see on the diagram below, the encoder and decoder are using a cross-attention module, and the processor is using a self-attention module. WHat's amazing with the perceiver architecture is that it can handle any modalities thanks to the encoder and decoder modules. And the size of the inputs and outputs is not a problem anymore since they both use a cross-attention module where the latent size is independent of the input size. 

<center><img src="./img/perceiver-architecture.png" width="400" /></center>

Let's implement the encoder!

### Encoder

The Perceiver encoder module is very similar to the transformer encoder you could find for in Bert except it's using cross-attention with a latent variable. In addition to the cross-attention module we just need to add:

- A multi-perceptron module: two fully connected layer processing each latent vector independently, a GELU activation and dropout layer.
- Two layer normalization layers  
- Two skip connections

Ok let's see, how we can combine these components to build our encoder.

In [10]:
class MLP(nn.Module):
    def __init__(self, input_size, widening_factor, dropout_prob=0.0):
        super().__init__()
        self.dense1 = nn.Linear(input_size, input_size * widening_factor)
        self.dense2 = nn.Linear(input_size * widening_factor, input_size)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout()

    def forward(self, x):
        x = self.dense1(x)
        x = self.gelu(x)
        x = self.dense2(x)
        return self.dropout(x)


class PerceiverEncoder(nn.Module):
    def __init__(
        self,
        input_dim,
        latent_embedding_dim,
        qk_channels,
        v_channels,
        num_heads,
        widening_factor,
    ):
        super().__init__()

        self.layer_norm_1 = nn.LayerNorm(input_dim)
        self.layer_norm_2 = nn.LayerNorm(latent_embedding_dim)

        self.attention = MultiHeadAttention(
            is_cross_attention=True,
            input_kv_dim=input_dim,
            input_q_dim=latent_embedding_dim,
            qk_channels=qk_channels,
            v_channels=v_channels,
            num_heads=num_heads,
        )

        self.mlp = MLP(v_channels, widening_factor=widening_factor)

    def forward(self, input, latent):
        input_norm = self.layer_norm_1(input)
        latent_embedding_norm = self.layer_norm_2(latent)
        x_qkv = self.attention(input_norm, latent_embedding_norm)
        x_qkv = x_qkv + latent
        x_qkv = x_qkv + self.mlp(latent_embedding_norm)
        return x_qkv

In [11]:
perceiver_encoded = PerceiverEncoder(6, 6, 8, 6, 2, 2)
input = torch.ones(size=(1, 3, 6))
latent = nn.Parameter(torch.randn(1, 2, 6))
encoder_output = perceiver_encoded(input, latent)
print(f"Input shape: {input.shape}")
print(f"Latent variable shape: {latent.shape}")
print(f"Encoder output shape: {encoder_output.shape}")

Input shape: torch.Size([1, 3, 6])
Latent variable shape: torch.Size([1, 2, 6])
Encoder output shape: torch.Size([1, 2, 6])


Excellent, we have implemented successfully our encoder layer. In fact, the processor and the decoder use the exact same ingredients, except the processor will use a self-attention. So let's refactor our encoder to use PerceiverLayer which can be parametrized to use a cross-attention or self-attention module.

In [12]:
class PerceiverLayer(nn.Module):
    def __init__(
        self,
        input_kv_dim,
        input_q_dim,
        qk_channels,
        v_channels,
        num_heads,
        widening_factor,
        is_cross_attention,
    ):
        super().__init__()

        if input_kv_dim is None:
            input_kv_dim = input_q_dim

        self.layer_norm_1 = nn.LayerNorm(input_kv_dim)
        self.layer_norm_2 = nn.LayerNorm(input_q_dim)

        self.attention = MultiHeadAttention(
            is_cross_attention=is_cross_attention,
            input_kv_dim=input_kv_dim,
            input_q_dim=input_q_dim,
            qk_channels=qk_channels,
            v_channels=v_channels,
            num_heads=num_heads,
        )

        self.mlp = MLP(v_channels, widening_factor=widening_factor)

    def forward(self, input_kv, input_q):
        input_kv_norm = self.layer_norm_1(input_kv)
        input_q_norm = self.layer_norm_2(input_q)
        x_qkv = self.attention(input_kv_norm, input_q_norm)
        x_qkv = x_qkv + input_q
        x_qkv = x_qkv + self.mlp(input_q_norm)
        return x_qkv


class PerceiverEncoder(nn.Module):
    def __init__(
        self,
        input_dim,
        latent_dim,
        qk_channels,
        v_channels,
        num_heads,
        widening_factor,
    ):
        super().__init__()

        self.encoder = PerceiverLayer(
            input_kv_dim=input_dim,
            input_q_dim=latent_dim,
            qk_channels=qk_channels,
            v_channels=v_channels,
            num_heads=num_heads,
            widening_factor=widening_factor,
            is_cross_attention=True,
        )

    def forward(self, input, latent_embeddings):
        return self.encoder(input_kv=input, input_q=latent_embeddings)

In [13]:
perceiver_encoded = PerceiverEncoder(6, 6, 8, 6, 2, 2)
input = torch.ones(size=(1, 3, 6))
latent = nn.Parameter(torch.randn(1, 2, 6))
perceiver_encoded(input, latent).shape

encoder_output = perceiver_encoded(input, latent)
print(f"Input shape: {input.shape}")
print(f"Latent variable shape: {latent.shape}")
print(f"Encoder output shape: {encoder_output.shape}")

Input shape: torch.Size([1, 3, 6])
Latent variable shape: torch.Size([1, 2, 6])
Encoder output shape: torch.Size([1, 2, 6])


Our PerceiverEncoder layer now looks much simpler with our new generic PerceiverLayer.

### Processor

Our Processor layer is going to be responsible of refining the latent representation we obtained from the encoder. So this layer has a single input which is the latent array and we apply a self-attention module in conjunction with an MLP layer, layer normalization and skip connection (i.e `PerceiverLayer` configured as self-attention).

In [14]:
class PerceiverProcessor(nn.Module):
    def __init__(self, latent_dim, qk_channels, v_channels, num_heads, widening_factor):
        super().__init__()

        self.processor = PerceiverLayer(
            input_kv_dim=None,
            input_q_dim=latent_dim,
            qk_channels=qk_channels,
            v_channels=v_channels,
            num_heads=num_heads,
            widening_factor=widening_factor,
            is_cross_attention=False,
        )

    def forward(self, latent):
        return self.processor(input_kv=latent, input_q=latent)

In [15]:
latent = nn.Parameter(torch.randn(1, 2, 6))
perceiver_processor = PerceiverProcessor(6, 8, 6, 2, 2)
processor_output = perceiver_processor(latent)

print(f"Latent array (input) shape: {latent.shape}")
print(f"Processor output shape: {processor_output.shape}")

Latent array (input) shape: torch.Size([1, 2, 6])
Processor output shape: torch.Size([1, 2, 6])


We are almost ready to combine everything to have a complete architecture. We just need implement the decoder module.

### Decoder

The perceiver decoder will map the latent array to an output array. To do so, we can query the latent array with a query vector. Note that's this query vector can be hand-designed, or learned embeddings or a function of the input. For this blog post, we will use learned embeddings. To query the latent array returned by the processor, we simply need to compute the cross attention between a learned query variable and the latent array. The query should have the same number of elements as the desired output.

In [16]:
class PerceiverDecoder(nn.Module):
    def __init__(
        self,
        num_output_channels,
        latent_dim,
        query_dim,
        qk_channels,
        v_channels,
        num_heads,
        widening_factor,
    ):
        super().__init__()

        self.decoder = PerceiverLayer(
            input_kv_dim=latent_dim,
            input_q_dim=query_dim,
            qk_channels=qk_channels,
            v_channels=v_channels,
            num_heads=num_heads,
            widening_factor=widening_factor,
            is_cross_attention=True,
        )

        self.dense = nn.Linear(query_dim, num_output_channels)

    def forward(self, latent, query):
        attn_output = self.decoder(latent, query)
        logit = self.dense(attn_output)
        return logit

Let's run the decoder with an expected output of 10 elements. For exmple, for an image classification task with 10 potential different label.

In [17]:
latent = torch.ones(1, 3, 5)
query_variable = nn.Parameter(torch.randn(1, 1, 10))
q_dim = 10
kv_dim = 5
num_output_channels = 10
qk_channels = q_dim
v_channels = qk_channels
num_heads = 1

perceiver_decoder = PerceiverDecoder(
    num_output_channels, kv_dim, q_dim, qk_channels, v_channels, num_heads, 1
)

perceiver_decoder_output = perceiver_decoder(latent, query_variable)
print(f"Latent array shape: {latent.shape}")
print(f"Query variable shape: {query_variable.shape}")
print(f"Perceiver decoder output shape: {perceiver_decoder_output.shape}")

Latent array shape: torch.Size([1, 3, 5])
Query variable shape: torch.Size([1, 1, 10])
Perceiver decoder output shape: torch.Size([1, 1, 10])


As expected, the output of the decoder returns 10 elements matching the number of labels.

### Perceiver IO = Encoder + Processor + Decoder

We have now all the building block to create a complete PerceiverIO architecture. As discussed earlier, the latent variable for the encoder and the query variable for the decoder are learned. So we can create a `LearnedEmbedding` layer to instantiate a latent and query variables. Otherwise, it's straight forward. the encoder takes as an input the input array and the latent variable then returned a latent array. Then this latent array is fed into the processor. And finally the latent arrays is query by the decoder with a learned query. The output of our decoder will be the logits we will use to compute our loss when training the image and text classification task.

In [18]:
class LearnedEmbeddings(nn.Module):
    def __init__(self, index_dim, num_channels=128):
        super().__init__()
        self.index_dim = index_dim
        self.num_channels = num_channels
        self.learned_embeddings = nn.Parameter(torch.randn(index_dim, num_channels))

    def forward(self, batch_size):
        return self.learned_embeddings.expand(batch_size, -1, -1)


class PerceiverIO(nn.Module):
    def __init__(
        self,
        n_labels,
        input_dim,
        latent_embedding_dim,
        query_dim,
        qk_channels,
        v_channels,
        num_latents,
        num_heads,
        widening_factor,
        input_processor=None,
    ):
        super().__init__()

        self.input_processor = input_processor if input_processor else nn.Identity()

        self.latent_embeddings = LearnedEmbeddings(num_latents, latent_embedding_dim)
        self.query_embeddings = LearnedEmbeddings(1, query_dim)

        self.encoder = PerceiverEncoder(
            input_dim=input_dim,
            latent_dim=latent_embedding_dim,
            qk_channels=qk_channels,
            v_channels=v_channels,
            num_heads=num_heads,
            widening_factor=widening_factor,
        )

        self.processor = PerceiverProcessor(
            latent_dim=latent_embedding_dim,
            qk_channels=qk_channels,
            v_channels=v_channels,
            num_heads=num_heads,
            widening_factor=widening_factor,
        )

        self.decoder = PerceiverDecoder(
            num_output_channels=n_labels,
            latent_dim=latent_embedding_dim,
            query_dim=query_dim,
            qk_channels=qk_channels,
            v_channels=v_channels,
            num_heads=num_heads,
            widening_factor=widening_factor,
        )

    def forward(self, inputs):
        batch_size = inputs.shape[0]

        latent_embeddings = self.latent_embeddings(batch_size)
        query_embeddings = self.query_embeddings(batch_size)

        inputs = self.input_processor(inputs)
        encoder_output = self.encoder(inputs, latent_embeddings)
        processor_output = self.processor(encoder_output)
        logits = self.decoder(processor_output, query_embeddings)

        return logits[:, 0, :]

Voilà, we have now our full Perceiver IO architecture. Let's now dive into how in the Perceiver IO paper they pre-process an image or text to obtain the input arrays which will be fed into the encoder.

## Image Classification

### Pre-Processing

To pre-process an image, in the Perceiver IO paper, they use different techniques involving 2D Fourier position embeddings with 2D convolution or flatten pixel values. But for this example, we will use another approach used in the paper where they flatten the pixels by applying a 1D convolution and add learned absolute positional 1D position embeddings. Based on the paper, the other approaches provide better results, but what's unique with this pre-processing approach it's that it doesn't provide any information about the 2D image structure.

Let's assume we have a tensor representing a batch of images of shape (batch_size, 3, 32, 32). So each image have a width and height of 32 and 3 channels. To pre-process this image we will:

- Apply a Conv1D layer to increase the number of channels, for example to 256. So now, let's say our output has a shape of (batch_size, 256, 32, 32)
- The Conv1D will return a tensor with channel first. We want to make it channel last. Our new shape is (batch_size, 32, 32, 256)
- We flatten the height with the width of the image so now we have a shape of (batch_size, 1024, 256)
- We instantiate a trainable 1D position embedding for each pixel with for example 256 channels. So the shape of the embedding is (batch_size, 1024, 256)
- We concatenate the output of the Conv1D layer with the trainable 1D position embedding based on the last dimension. So the final shape will be (batch_size, 1024, 512) 

That's it! Let's implement it.

In [19]:
class PositionalEncoding(nn.Module):
    def __init__(self, index_dim, num_channels=128):
        super().__init__()
        self.index_dim = index_dim
        self.num_channels = num_channels
        self.position_embeddings = nn.Parameter(torch.randn(index_dim, num_channels))

    def forward(self, batch_size):
        return self.position_embeddings.expand(batch_size, -1, -1)


class ImagePreProcessor(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        spatial_downsample,
        position_encoding_index_dim,
        position_encoding_out_channels,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.spatial_downsample = spatial_downsample
        self.postion_encoding_index_dim = position_encoding_index_dim

        self.conv1d = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(1, 1),
            stride=(spatial_downsample, spatial_downsample),
        )

        self.pos_enc = PositionalEncoding(
            position_encoding_index_dim, num_channels=position_encoding_out_channels
        )

    def forward(self, inputs):
        batch_size = inputs.shape[0]
        # Increase the number of channels while keeping same height and width
        inputs_post_conv1d = self.conv1d(inputs)
        # Make channel last
        inputs = torch.moveaxis(inputs, 1, -1)
        # Flatten from (batch_size, img_size, img_size, num_channels) to (batch_size, img_size*img_size, num_channels)
        inputs_post_conv1d = torch.reshape(
            inputs_post_conv1d, [batch_size, np.prod(inputs.shape[1:-1]), -1]
        )
        # Instantiate learned 1D positional embeddings
        pos_encoded = self.pos_enc(batch_size)
        # Concat inputs post conv1d with 1D positional embeddings
        return torch.cat([inputs_post_conv1d, pos_encoded], dim=-1)

We can validate we get the expected output shape using an image of shape (3, 32, 32) with Conv1D output channels of 256 and a 1D positional embeddings with 256 channels.

In [20]:
x = torch.ones((1, 3, 32, 32))
image_processor = ImagePreProcessor(3, 256, 1, 32**2, 256)
processed_image = image_processor(x)

print(f"Image shape: {x.shape}")
print(f"Processed image shape: {processed_image.shape}")

Image shape: torch.Size([1, 3, 32, 32])
Processed image shape: torch.Size([1, 1024, 512])


Excellent, as expected our tensor has a shape of (1, 1024, 512). That's all we need to pre-process our images.

We can finally train an image classification model using the Perceiver architecture!

### Training

To demonstrate we can use the Perceiver IO architecture to classify images, we will use the MNIST dataset. Here, we are using a simple dataset to show the generality of the architecture. But note in the paper they were able to able to achieve strong results on the ImageNet dataset especially when the model has been pre-trained and the pre-processor is using 2D convolution and MaxPooling layers.

In [63]:
import torch
from torch.optim import AdamW
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from torchvision import datasets
from torchvision import transforms

from sklearn.metrics import accuracy_score

device = "cuda" if torch.has_cuda else "cpu"

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

training_data = datasets.MNIST(
    root="data", train=True, download=True, transform=transform
)

eval_data = datasets.MNIST(root="data", train=False, download=True, transform=transform)


train_dataloader = DataLoader(training_data, batch_size=128, shuffle=True)
eval_dataloader = DataLoader(eval_data, batch_size=128, shuffle=True)


img_size = 28
img_channels = 1
img_processor_output_channels = 32
img_processor_pos_encoding_out_channels = 32
n_labels = 10
input_dim = 64
latent_embedding_dim = 128
num_latents = 258
query_dim = 128
qk_channels = query_dim
v_channels = qk_channels
num_heads = 1
widening_factor = 1
spatial_downsample = 1


image_processor = ImagePreProcessor(
    img_channels,
    img_processor_output_channels,
    spatial_downsample,
    img_size**2,
    img_processor_pos_encoding_out_channels,
)

model = PerceiverIO(
    n_labels,
    input_dim,
    latent_embedding_dim,
    query_dim,
    qk_channels,
    v_channels,
    num_latents,
    num_heads,
    widening_factor,
    input_processor=image_processor,
)

model = model.to(device)

Let's define a generic training and evaluation loop we can reuse later to classify text.

In [64]:
def train(model, loss_fn, device, train_loader, optimizer, epoch, log_interval=50):
    model.train()
    for batch_idx, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = loss_fn(logits, y)
        loss.backward()
        optimizer.step()

        pred = logits.argmax(-1).cpu().numpy()
        acc = accuracy_score(y_true=y.cpu().numpy(), y_pred=pred)

        if batch_idx % log_interval == 0:
            print(
                f"Train Epoch: {epoch} [{batch_idx * len(x)}/{len(train_loader.dataset)} ({(100. * batch_idx  * len(x) / len(train_loader.dataset)):.2f}%)]\t"
                f"Loss: {loss.item():.2f} - Accuracy: {acc:.2f}"
            )


def eval(model, loss_fn, device, eval_loader):
    model.eval()
    eval_loss = 0
    correct = 0
    with torch.no_grad():
        for x, y in eval_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            eval_loss += loss_fn(logits, y).item()
            pred = logits.argmax(-1, keepdim=True)
            correct += pred.eq(y.view_as(pred)).sum().item()

    eval_loss /= len(eval_loader.dataset)

    print(
        f"\nTest set: Average loss: {eval_loss:.2f}, Accuracy: "
        f"{correct}/{len(eval_loader.dataset)} ({(100. * correct/len(eval_loader.dataset)):.2f}%)\n"
    )

Traning time!

In [65]:
EPOCHS = 20
loss_fn = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=4e-3, weight_decay=1e-1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.7)

for e in range(EPOCHS):
    train(model, loss_fn, device, train_dataloader, optimizer, e, log_interval=100)
    eval(model, loss_fn, device, eval_dataloader)
    scheduler.step()


Test set: Average loss: 0.00, Accuracy: 9143/10000 (91.43%)


Test set: Average loss: 0.00, Accuracy: 9456/10000 (94.56%)


Test set: Average loss: 0.00, Accuracy: 9552/10000 (95.52%)


Test set: Average loss: 0.00, Accuracy: 9589/10000 (95.89%)


Test set: Average loss: 0.00, Accuracy: 9626/10000 (96.26%)


Test set: Average loss: 0.00, Accuracy: 9652/10000 (96.52%)


Test set: Average loss: 0.00, Accuracy: 9698/10000 (96.98%)


Test set: Average loss: 0.00, Accuracy: 9708/10000 (97.08%)


Test set: Average loss: 0.00, Accuracy: 9725/10000 (97.25%)


Test set: Average loss: 0.00, Accuracy: 9730/10000 (97.30%)


Test set: Average loss: 0.00, Accuracy: 9747/10000 (97.47%)


Test set: Average loss: 0.00, Accuracy: 9739/10000 (97.39%)


Test set: Average loss: 0.00, Accuracy: 9739/10000 (97.39%)


Test set: Average loss: 0.00, Accuracy: 9750/10000 (97.50%)


Test set: Average loss: 0.00, Accuracy: 9733/10000 (97.33%)


Test set: Average loss: 0.00, Accuracy: 9749/10000 (97.49%)


Test se

After training our model we acheived X accuracy. If we were spedind more time on tuning the model, use a 2D Fourier or 2D convolutional pre-processor it's very likely we could achieve a higher accuracy. But nonetheless we were able to classify images with the Perceiver model with zero information about the 2D structure of the images.

Let's now explore how we can use the same architecture for a text classification task. 

## Text Classification

### Prep-Processing

For this example, we will use the [AG_NEWS](https://pytorch.org/text/stable/datasets.html#ag-news) dataset which contains a corpus of news articles. Each article is classified as one of the following category: World, Sports, Business or Sci/Tech.

With Perceiver IO, tokenization is extremely simple, you just need to convert the string to raw UTF-8 bytes. You don't need to apply more sophisticated tokenizer such as WordPeice or SentencePiece or BPE etc. Then for the numericalization process, we will just convert each byte to a byte ID. And finally, we will pad the sequence to a max sequence length. Perceiver IO aim to get rid of tokenizers is that they tend to perform less well on rare words and they don't transfer well from one language to another. 

In the paper implementation, they have some reserved tokens such as [BOS], [EOS], [SEP] etc. to represent the beginning of the sentence, end of the sentence but for simplicity, we will just convert the sentence to raw bytes then ID. 

In [56]:
class BytesTokenizer:
    def __init__(self):
        pass

    def to_int(self, inputs):
        if isinstance(inputs, str):
            inputs = inputs.encode("utf-8")
        encoded = torch.frombuffer(inputs, dtype=torch.uint8).to(torch.int32)
        return encoded.to(torch.int32)

    @property
    def vocab_size(self):
        return 256

We can validate the tokenizer is working properly with an example.

In [57]:
input = "Hello Hello"
tokenizer = BytesTokenizer()
tokenized_input = tokenizer.to_int(input)
print(f"Sentence: {input}")
print(f"Tokenized sentence: {tokenized_input}")

Sentence: Hello Hello
Tokenized sentence: tensor([ 72, 101, 108, 108, 111,  32,  72, 101, 108, 108, 111],
       dtype=torch.int32)


We can see the number of tokens matches the number of characters in the original sentence and also the world Hello is repeating twice.

Next step is to implement a padding function so all the sequences in the batch will have the same length. If the length of the sequence exceeds the maximum sequence length, it will be truncated. 

In [58]:
def pad(max_sequence_length, inputs, pad_value=0):
    input_len = inputs.shape[1]
    # Truncate sequence if exceeds max sequence length
    if input_len > max_sequence_length:
        _text = _text[:max_sequence_length]
    # Pad sequence with pad value if shorter the max sequence length
    pad_len = max_sequence_length - input_len
    padded_input = torch.nn.functional.pad(inputs, pad=((0, pad_len)), value=pad_value)
    return padded_input

Now that we have our tokenizer and padding function, we can create a collate_batch function. This function will be used by our `DataLoader` to tokenize and pad each sentence contained in the batch. 

In [59]:
MAX_SEQUENCE_LEN = 1024

def collate_batch(batch):
    label_list, text_list = [], []

    for (_label, _text) in batch:
        # Convert labels [1, 2, 3, 4] to [0, 1, 2, 3] for loss function
        label_processed = _label - 1
        label_list.append(label_processed)
        # Tokenize and numericalize sentence
        tokenized_text = torch.unsqueeze(tokenizer.to_int(_text), 0)
        # Pad and truncate the tokenized sentence to match MAX_SEQUENCE_LEN
        padded_text = pad(MAX_SEQUENCE_LEN, tokenized_text)
        text_list.append(padded_text)
    
    label_list = torch.tensor(label_list, dtype=torch.int64)
    text_list = torch.cat(text_list, dim=0)
    return text_list.to(device), label_list.to(device)

We are just missing one piece of the puzzle before starting to train our model which is a text processor layer. This processing layer is very similar to the one you could find in a Bert model. It consists of:

- Convert each token in the sentence to embeddings
- Represent also the position of each token as embeddings
- Add the tokens embeddings to the function embeddings

In [60]:
class TextProcessor(nn.Module):
    def __init__(self, vocab_size, d_model, max_position_embeddings=2048):
        super().__init__()
        self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
        self.position_embeddings = nn.Embedding(max_position_embeddings, d_model)

    def forward(self, inputs):
        embeddings = self.embeddings(inputs)
        seq_len = inputs.shape[1]
        position_ids = torch.arange(0, seq_len, device=inputs.device)
        embeddings = embeddings + self.position_embeddings(position_ids)
        return embeddings

### Training

We can finally train our model to classify news articles. First we create a Dataloader for the training and evaluation datasets using the `collate_batch` function previously defined. 

In [64]:
from torchtext.datasets import AG_NEWS
from torchtext.data.functional import to_map_style_dataset

train_iter = AG_NEWS(split="train")
eval_iter = AG_NEWS(split="test")

train_dataset = to_map_style_dataset(train_iter)
eval_dataset = to_map_style_dataset(eval_iter)

train_dataloader = DataLoader(
    train_dataset, batch_size=32, shuffle=True, collate_fn=collate_batch
)
eval_dataloader = DataLoader(
    eval_dataset, batch_size=32, shuffle=True, collate_fn=collate_batch
)

Then we instantiate the Perceiver model. It's the exact same model we used for the image classification task except we use a text processor layer, set the number of labels to 4 and adjust the input_dim parameter to the embedding length for each token.

In [65]:
vocab_size = tokenizer.vocab_size
n_labels = 4
input_dim = 1024
latent_embedding_dim = 64
query_dim = 64
qk_channels = query_dim
v_channels = qk_channels
num_latents = 64
num_heads = 1
widening_factor = 1

text_processor = TextProcessor(vocab_size, input_dim)

model = PerceiverIO(
    n_labels,
    input_dim,
    latent_embedding_dim,
    query_dim,
    qk_channels,
    v_channels,
    num_latents,
    num_heads,
    widening_factor,
    input_processor=text_processor,
)

model = model.to(device)

To train the model we can reuse the training and evaluation loop for the previous tasks.

In [66]:
EPOCHS = 20
loss_fn = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=1e-4)

for e in range(EPOCHS):
    train(model, loss_fn, device, train_dataloader, optimizer, e, log_interval=1000)
    eval(model, loss_fn, device, eval_dataloader)



85% accuracy on the evaluation set, not bad. With the same architecture we were able to solve to distinc tasks. 

## Conclusion

## Ressources

- [Perceiver IO: A General Architecture for Structured Inputs & Outputs (Jaegle et al, 2021)](https://arxiv.org/abs/2107.14795)
- [CS25 I Stanford Seminar - DeepMind's Perceiver and Perceiver IO: new data family architecture](https://www.youtube.com/watch?v=wTZ3o36lXoQ)
- [Perceiver IO: a scalable, fully-attentional model that works on any modality](https://huggingface.co/blog/perceiver)
- [Perceiver IO implementation from DeepMind](https://github.com/deepmind/deepmind-research/tree/master/perceiver)
- [Perceiver IO implementation from HuggingFace Transformers](https://github.com/huggingface/transformers/tree/main/src/transformers/models/perceiver)
- [Natural Language Processing with Transformers book](https://learning.oreilly.com/library/view/natural-language-processing/9781098136789/)
