# Perceiver

Perceiver is a transformer-based model that combines cross-attention and self-attention layers to build representations of multimodal data. A latent array extracts information from the input byte array through top-down (feedback) processing: attention over the byte array is controlled by the latent array, which already carries information about the byte array from the previous layer.
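To make the latent-to-input read concrete, here is a minimal sketch of one cross-attention step in PyTorch. It is not the library's implementation: the class name `LatentCrossAttention`, the single head, the `attn_dim` size, and the input width of 29 are illustrative assumptions, and layer normalization and the feed-forward block are left out.

```python
import torch
import torch.nn as nn


class LatentCrossAttention(nn.Module):
    """Hypothetical single-head cross-attention: latents query the byte array."""

    def __init__(self, latent_dim: int, input_dim: int, attn_dim: int = 64):
        super().__init__()
        self.to_q = nn.Linear(latent_dim, attn_dim, bias=False)  # queries from latents
        self.to_k = nn.Linear(input_dim, attn_dim, bias=False)   # keys from the input
        self.to_v = nn.Linear(input_dim, attn_dim, bias=False)   # values from the input
        self.to_out = nn.Linear(attn_dim, latent_dim)
        self.scale = attn_dim ** -0.5

    def forward(self, latents, byte_array):
        # latents: (batch, N, latent_dim); byte_array: (batch, M, input_dim)
        q = self.to_q(latents)
        k = self.to_k(byte_array)
        v = self.to_v(byte_array)
        # The attention map has shape (batch, N, M), so cost grows linearly in M.
        attn = torch.softmax(q @ k.transpose(-2, -1) * self.scale, dim=-1)
        # Residual update: the latent array keeps its shape.
        return latents + self.to_out(attn @ v)


torch.manual_seed(0)
layer = LatentCrossAttention(latent_dim=128, input_dim=29)
latents = torch.randn(4, 32, 128)          # N = 32 latents
byte_array = torch.randn(4, 32 * 32, 29)   # M = 1024 flattened pixels (assumed width 29)
updated = layer(latents, byte_array)
print(updated.shape)  # torch.Size([4, 32, 128])
```

Because the queries come from the N latents rather than the M inputs, the attention map is N x M instead of M x M, which is what keeps the read from a large input affordable.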


# Code

## Installation

A PyTorch implementation of the model is available on PyPI as `perceiver-pytorch`. We can install it directly with pip using the following commands.

In [None]:
!python -m pip install pip --upgrade --user -q --no-warn-script-location
!python -m pip install numpy pandas seaborn matplotlib scipy statsmodels scikit-learn nltk gensim perceiver-pytorch torch torchvision tensorflow keras --user -q --no-warn-script-location

# Restart the kernel so the newly installed packages are picked up
import IPython
IPython.Application.instance().kernel.do_shutdown(True)

## Data Loading

Let’s use the CIFAR10 image dataset to train the model. The boilerplate code below for loading the data is adapted from the official PyTorch tutorial.

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# functions to show an image


def imshow(img):
    print(img.shape)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)
print(images.shape)
# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

## Model

The model architecture can be instantiated as follows.

In [None]:
import torch
from perceiver_pytorch import Perceiver

model = Perceiver(
    input_channels = 3,          # number of channels for each token of the input
    input_axis = 2,              # number of axes of the input data (2 for images, 3 for video)
    num_freq_bands = 6,          # number of frequency bands K; each axis is encoded with 2 * K + 1 features
    max_freq = 10.,              # maximum frequency, a hyperparameter depending on how fine the data is
    depth = 6,                   # depth of the network
    num_latents = 32,            # number of latents (called induced set points or centroids in other papers)
    latent_dim = 128,            # latent dimension
    cross_heads = 1,             # number of heads for cross-attention (the paper uses 1)
    latent_heads = 2,            # number of heads for latent self-attention (set to 2 here instead of 8)
    cross_dim_head = 8,          # dimension of each cross-attention head
    latent_dim_head = 8,         # dimension of each latent self-attention head
    num_classes = 10,            # number of output classes (10 for CIFAR10)
    attn_dropout = 0.,           # dropout applied in the attention layers
    ff_dropout = 0.,             # dropout applied in the feed-forward layers
    weight_tie_layers = False    # whether to share weights across layers (optional, as indicated in the diagram)
)
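The `num_freq_bands` and `max_freq` arguments control a Fourier-feature position encoding that is concatenated to each pixel. The sketch below illustrates the idea and the resulting channel count; the helper name `fourier_features` and the linear frequency spacing are assumptions of this example and may not match the library's internals exactly.

```python
import math
import torch


def fourier_features(pos: torch.Tensor, num_bands: int, max_freq: float) -> torch.Tensor:
    """Encode positions in [-1, 1] as [sin, cos, raw position] features."""
    # Assumed spacing: frequencies run linearly from 1 to max_freq / 2.
    freqs = torch.linspace(1.0, max_freq / 2, num_bands)
    angles = pos.unsqueeze(-1) * freqs * math.pi   # (..., num_bands)
    # 2 * num_bands sinusoids plus the raw coordinate: 2 * K + 1 features.
    return torch.cat([angles.sin(), angles.cos(), pos.unsqueeze(-1)], dim=-1)


# Coordinates of a 32x32 image, scaled to [-1, 1] along each axis.
ys = torch.linspace(-1.0, 1.0, 32)
xs = torch.linspace(-1.0, 1.0, 32)
grid = torch.stack(torch.meshgrid(ys, xs, indexing="ij"), dim=-1)   # (32, 32, 2)

enc = fourier_features(grid, num_bands=6, max_freq=10.0)            # (32, 32, 2, 13)
enc = enc.flatten(start_dim=-2)                                     # (32, 32, 26)

# Concatenate with a channels-last RGB image to form the per-pixel tokens.
pixels = torch.randn(32, 32, 3)
tokens = torch.cat([pixels, enc], dim=-1)
print(tokens.shape)  # torch.Size([32, 32, 29])
```

With two axes and six bands, each pixel gains 2 * (2 * 6 + 1) = 26 positional features on top of its 3 colour channels.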


## Training Loop

In [None]:
import torch.optim as optim

criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

In [None]:
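# Use a GPU when one is available; otherwise fall back to the CPU
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(dev)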

In [None]:
for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs, labels = inputs.to(dev), labels.to(dev)
        # zero the parameter gradients
        optimizer.zero_grad()
        
        # forward + backward + optimize
        outputs = model(inputs.permute(0,2,3,1))
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 1000 == 999:    # print the average loss every 1000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 1000))
            running_loss = 0.0

print('Finished Training')
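The `testloader` created during data loading is not used by the training loop. One way to use it is to measure classification accuracy on held-out data. The helper below is a sketch, not part of the original notebook: `evaluate_accuracy` is a hypothetical name, and the stand-in model and random data exist only so the example runs on its own.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def evaluate_accuracy(model, loader, device):
    """Return the fraction of correctly classified samples in `loader`."""
    model.eval()                      # switch off dropout for evaluation
    correct, total = 0, 0
    with torch.no_grad():             # no gradients are needed here
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            # The model is fed channels-last input: (batch, H, W, C).
            outputs = model(inputs.permute(0, 2, 3, 1))
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    return correct / total


class AlwaysZero(torch.nn.Module):
    """Stand-in model (illustration only) that always predicts class 0."""

    def forward(self, x):
        logits = torch.zeros(x.shape[0], 10)
        logits[:, 0] = 1.0
        return logits


# Random images with half the labels equal to 0, in place of CIFAR10.
fake_images = torch.randn(8, 3, 32, 32)
fake_labels = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
fake_loader = DataLoader(TensorDataset(fake_images, fake_labels), batch_size=4)
print(evaluate_accuracy(AlwaysZero(), fake_loader, torch.device("cpu")))  # 0.5
```

In the notebook, the call would be `evaluate_accuracy(model, testloader, dev)`. The helper leaves the model in evaluation mode, so call `model.train()` before resuming training.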

In [None]:
# Take the first image of the last training batch as a channels-last batch of one
inp = inputs[0].permute(1, 2, 0).unsqueeze(0)
inp = inp.to(dev)
inp.shape

In [None]:
# Show the image being classified (inputs[0]), unnormalized back to [0, 1]
plt.imshow(inputs[0].cpu().permute(1, 2, 0).numpy() / 2 + 0.5)

In [None]:
model.eval()
with torch.no_grad():  # no gradients are needed for inference
    preds = model(inp)
preds = preds.cpu().numpy()

In [None]:
preds[0]

In [None]:
classes[np.argmax(preds[0])]

## Inference

In [None]:
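def predict(inputs):
    """Return the predicted class name for each channels-last image in the batch."""
    model.eval()
    with torch.no_grad():
        preds = model(inputs)
    preds = preds.cpu().numpy()
    labels = [classes[np.argmax(i)] for i in preds]
    return labels


predict(inputs.permute(0, 2, 3, 1).to(dev))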