In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import math

import torchvision.datasets as dset
import torchvision.transforms as T
from torch.nn import functional as F

In [2]:
NUM_TRAIN = 49000

# Preprocessing shared by every split: convert each image to a tensor, then
# standardise each RGB channel with the hardcoded per-channel mean and
# standard deviation below.
transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
])

# A Dataset yields one example at a time; wrapping it in a DataLoader groups
# the examples into minibatches. The CIFAR-10 training set is divided into a
# train part and a val part purely through the Sampler handed to each
# DataLoader: indices [0, NUM_TRAIN) are used for training ...
cifar10_train = dset.CIFAR10(root='../data/CIFAR10', train=True,
                             transform=transform, download=True)
loader_train = DataLoader(
    cifar10_train,
    batch_size=64,
    sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)),
)

# ... and indices [NUM_TRAIN, 50000) of the same training set for validation.
cifar10_val = dset.CIFAR10(root='../data/CIFAR10', train=True,
                           transform=transform, download=True)
loader_val = DataLoader(
    cifar10_val,
    batch_size=64,
    sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, 50000)),
)

# The held-out test split is iterated in order, without a sampler.
cifar10_test = dset.CIFAR10(root='../data/CIFAR10', train=False,
                            transform=transform, download=True)
loader_test = DataLoader(cifar10_test, batch_size=64)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [3]:
USE_GPU = True
# Floating-point type that input batches are cast to before reaching a model.
dtype = torch.float32

# Run on CUDA only when it is both requested and actually available;
# otherwise fall back to the CPU.
device = torch.device('cuda' if USE_GPU and torch.cuda.is_available() else 'cpu')

# Number of training iterations between two loss / accuracy reports.
print_every = 100
device

device(type='cpu')

In [4]:
def check_accuracy(loader, model):
    """
    Measure the classification accuracy of a model over every batch of a loader.

    Inputs:
    - loader: A DataLoader yielding (x, y) minibatches. When its dataset has a
      truthy `train` attribute the run is reported as the validation set,
      otherwise (including when the attribute is missing) as the test set.
    - model: A PyTorch Module mapping a batch x to per-class scores.

    Returns: The fraction of correctly classified samples as a Python float
    (0.0 if the loader yields no samples). The counts are also printed.

    Side effect: the model is switched to evaluation mode and left there.
    """
    if getattr(loader.dataset, 'train', False):
        print('Checking accuracy on validation set')
    else:
        print('Checking accuracy on test set')
    num_correct = 0
    num_samples = 0
    model.eval()  # set model to evaluation mode
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)
            scores = model(x)
            _, preds = scores.max(1)
            # .item() keeps the running total a plain Python int instead of a
            # tensor living on the compute device.
            num_correct += (preds == y).sum().item()
            num_samples += preds.size(0)
    # Guard against an empty loader, which would otherwise divide by zero.
    acc = num_correct / num_samples if num_samples else 0.0
    print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))
    return acc

In [5]:
def train(model, optimizer, epochs=1):
    """
    Fit a model to the CIFAR-10 training split using the PyTorch Module API.

    Inputs:
    - model: The PyTorch Module to optimise; it is moved to the global `device`.
    - optimizer: The Optimizer object that updates the model's parameters.
    - epochs: (Optional) A Python integer, the number of passes over `loader_train`.

    Returns: Nothing. Every `print_every` iterations the current loss is
    printed and `check_accuracy` is run on `loader_val`.
    """
    model = model.to(device=device)
    for _ in range(epochs):
        for step, (images, labels) in enumerate(loader_train):
            # check_accuracy() below switches the model to eval mode, so
            # training mode is re-enabled at the start of every iteration.
            model.train()
            images = images.to(device=device, dtype=dtype)
            labels = labels.to(device=device, dtype=torch.long)

            # Forward pass and loss.
            loss = F.cross_entropy(model(images), labels)

            # Clear stale gradients, backpropagate, then apply the update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Nothing to report on iterations that are not a multiple of
            # print_every.
            if step % print_every:
                continue
            print('Iteration %d, loss = %.4f' % (step, loss.item()))
            check_accuracy(loader_val, model)
            print()

In [14]:
import torch
import torch.nn as nn
import math

In [15]:
def fourier_encode(x, max_freq, num_bands = 4, base = 2):
    """
    Append sine / cosine features of every element of `x` along a new last axis.

    Inputs:
    - x: Tensor of any shape.
    - max_freq: The largest frequency scale is max_freq / 2.
    - num_bands: Number of frequency scales, log-spaced from 1 to max_freq / 2.
    - base: Base of the logarithmic spacing.

    Returns: Tensor of shape x.shape + (2 * num_bands + 1,). Along the last
    axis it holds the num_bands sines, then the num_bands cosines, then the
    original value.
    """
    raw = x.unsqueeze(-1)

    top_exponent = math.log(max_freq / 2) / math.log(base)
    freqs = torch.logspace(0., top_exponent, num_bands, base = base,
                           device = raw.device, dtype = raw.dtype)
    # Give the frequency vector one leading singleton axis per axis of x so
    # that it broadcasts against `raw`.
    freqs = freqs.reshape((1,) * (raw.dim() - 1) + (-1,))

    angles = raw * freqs * math.pi
    return torch.cat([angles.sin(), angles.cos(), raw], dim = -1)


class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention of a query sequence over a context.

    Inputs to the constructor:
    - query_dim: Feature size of the query sequence (also the output size).
    - dim_head: Feature size of each attention head.
    - num_heads: Number of attention heads.
    - context_dim: Feature size of the context; defaults to query_dim.
    - dropout: Dropout probability applied to the projected output.
    """

    def __init__(self, query_dim, dim_head, num_heads=1, context_dim = None, dropout=0):
        super().__init__()
        inner_dim = dim_head * num_heads
        context_dim = query_dim if context_dim is None else context_dim

        self.H = num_heads
        self.E = inner_dim
        # Projections into the shared inner space, and back out to query_dim.
        self.q_linear = nn.Linear(query_dim, inner_dim)
        self.k_linear = nn.Linear(context_dim, inner_dim)
        self.v_linear = nn.Linear(context_dim, inner_dim)
        self.o_linear = nn.Linear(inner_dim, query_dim)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, context = None, attn_mask=None):
        """
        Attend from `query` (N, S, query_dim) to `context` (N, T, context_dim).

        With no context the query attends to itself. `attn_mask`, when given,
        is cast to bool and must broadcast against the (N, heads, S, T) score
        tensor; positions where it is False are excluded from the softmax.

        Returns: Tensor of shape (N, S, query_dim).
        """
        context = query if context is None else context
        _, q_len, _ = query.shape
        batch, kv_len, _ = context.shape
        head_dim = int(self.E / self.H)

        def split_heads(projected, length):
            # (batch, length, E) -> (batch, heads, length, head_dim)
            return projected.view((batch, length, self.H, head_dim)).transpose(1, 2)

        q_heads = split_heads(self.q_linear(query), q_len)
        k_heads = split_heads(self.k_linear(context), kv_len)
        v_heads = split_heads(self.v_linear(context), kv_len)

        scores = torch.matmul(q_heads, k_heads.transpose(2, 3)) / math.sqrt(self.E / self.H)
        if attn_mask is not None:
            scores = scores.masked_fill(~(attn_mask.type(torch.bool)), -math.inf)
        weights = torch.softmax(scores, dim=3)

        # Merge the heads back into one feature axis before the output projection.
        merged = torch.matmul(weights, v_heads).transpose(1, 2).reshape(batch, q_len, self.E)
        return self.dropout(self.o_linear(merged))

class FeedForward(nn.Module):
    """
    Position-wise two-layer MLP: Linear -> ReLU -> Dropout -> Linear.

    Inputs to the constructor:
    - dim: Feature size of both the input and the output.
    - mult: The hidden layer has dim * mult units.
    - dropout: Dropout probability applied after the ReLU.
    """

    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        hidden_dim = dim * mult
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP to the last axis of `x`; the shape is preserved."""
        transformed = self.net(x)
        return transformed

In [71]:
class VisionPerceiver(nn.Module):
    """
    Perceiver-style image classifier.

    A small array of latent vectors repeatedly cross-attends to the
    Fourier-encoded pixels of an image and is then refined by a stack of
    latent self-attention blocks. The mean of the final latents is mapped to
    class logits.

    Inputs to the constructor:
    - input_dim: Number of spatial positions per image (H * W). The input
      LayerNorm is sized for the fully flattened encoded image, so every image
      passed to forward() must have exactly this many positions.
    - input_channels: Number of raw channels per position.
    - max_freq, num_freq_bands: Forwarded to fourier_encode.
    - num_iterations: Number of (cross-attention + latent transformer) cells.
    - num_transformer_blocks: Self-attention blocks per latent transformer.
    - num_latents, latent_dim: Shape of the latent array.
    - cross_heads, cross_dim_head: Head layout of the cross attention.
    - latent_heads, latent_dim_head: Head layout of the latent self attention.
    - num_classes: Number of output logits.
    - attn_dropout, ff_dropout: Dropout probabilities of the attention and
      feed-forward modules.
    """

    def __init__(self, input_dim, input_channels=3,
                 max_freq = 8, num_freq_bands = 4,
                 num_iterations = 1, num_transformer_blocks = 4,
                 num_latents = 32, latent_dim = 128,
                 cross_heads = 1, cross_dim_head = 8,
                 latent_heads = 2, latent_dim_head = 8,
                 num_classes = 10,
                 attn_dropout = 0., ff_dropout = 0.):

        super().__init__()

        self.num_latents = num_latents
        self.latent_dim = latent_dim
        self.max_freq = max_freq
        self.num_freq_bands = num_freq_bands
        # fourier_encode turns every scalar into num_freq_bands sines,
        # num_freq_bands cosines and the raw value. Deriving the factor from
        # num_freq_bands keeps the layer sizes right for any band count.
        input_channels = input_channels * (2 * num_freq_bands + 1)
        # perceiver stacks
        self.layers = nn.ModuleList([])
        for i in range(num_iterations): # build each perceiver cell
            cell = nn.ModuleList([])
            # cross attention module: [0] latent norm, [1] input norm over the
            # whole flattened encoded image, [2] attention
            cell.append(nn.LayerNorm(latent_dim))
            cell.append(nn.LayerNorm(input_dim * input_channels))
            cell.append(MultiHeadAttention(latent_dim, dim_head = cross_dim_head,
                                           num_heads = cross_heads, context_dim = input_channels,
                                           dropout = attn_dropout))
            # feed forward: [3] norm, [4] MLP
            cell.append(nn.LayerNorm(latent_dim))
            cell.append(FeedForward(latent_dim, dropout = ff_dropout))

            # latent transformer: [5] list of self-attention blocks
            latent_transformer = nn.ModuleList([])
            for j in range(num_transformer_blocks):
                latent_transformer_block = nn.ModuleList([])
                # self attention
                latent_transformer_block.append(nn.LayerNorm(latent_dim))
                latent_transformer_block.append(MultiHeadAttention(latent_dim, dim_head = latent_dim_head,
                                           num_heads = latent_heads, dropout = attn_dropout))
                # feed forward
                latent_transformer_block.append(nn.LayerNorm(latent_dim))
                latent_transformer_block.append(FeedForward(latent_dim, dropout = ff_dropout))
                latent_transformer.append(latent_transformer_block)
            cell.append(latent_transformer)

            self.layers.append(cell)

        self.to_logits = nn.Sequential(
                            nn.LayerNorm(latent_dim),
                            nn.Linear(latent_dim, num_classes))

        # Learned initial latent array. It is registered here, at construction
        # time, so that model.parameters() exposes it to the optimizer and it
        # follows the module across devices. It is created last so that the
        # seeded initialisation of the layers above is unaffected.
        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))

    def _initial_latents(self, batch_size, device, latent_init, seed):
        """Return the starting latent array expanded to (batch_size, num_latents, latent_dim)."""
        expected_shape = (self.num_latents, self.latent_dim)
        if latent_init is not None:
            if tuple(latent_init.shape) != expected_shape:
                raise ValueError('latent_init must have shape %s, got %s'
                                 % (expected_shape, tuple(latent_init.shape)))
            latents = latent_init
        elif seed is not None:
            # A private generator gives reproducible latents without reseeding
            # the global RNG.
            generator = torch.Generator().manual_seed(seed)
            latents = torch.randn(*expected_shape, generator=generator)
        else:
            latents = self.latents
        latents = latents.to(device=device, dtype=self.latents.dtype)
        # expand() broadcasts over the batch without copying; the latents are
        # never modified in place below.
        return latents.unsqueeze(0).expand(batch_size, -1, -1)

    def forward(self, data, attn_mask = None, latent_init = None, seed = None):
        """
        Classify a batch of images.

        Inputs:
        - data: Tensor of shape (N, C, H, W).
        - attn_mask: Optional mask for the cross attention. It must broadcast
          against the (N, cross_heads, num_latents, H * W) score tensor; False
          / zero entries are excluded. It is not applied to the latent self
          attention, whose scores have a different shape.
        - latent_init: Optional tensor of shape (num_latents, latent_dim) used
          as the initial latents instead of the learned ones.
        - seed: Optional integer; when given (and latent_init is None) the
          initial latents are drawn from a generator seeded with it.

        Returns: Tensor of shape (N, num_classes) holding the class logits.

        Raises: ValueError if latent_init has the wrong shape.
        """
        # flatten the spatial axes: (N, C, H, W) -> (N, H*W, C). reshape()
        # also accepts non-contiguous batches.
        N, C, H, W = data.shape
        data = data.reshape(N, C, -1).transpose(1, 2)

        # encoding: (N, H*W, C) -> (N, H*W, C, 2 * num_freq_bands + 1)
        data = fourier_encode(data, self.max_freq, self.num_freq_bands)

        # determine the initial latent array, on the same device as the data
        x = self._initial_latents(N, data.device, latent_init, seed)

        for cell in self.layers:
            # cross attention: normalise the flattened encoded image, then
            # fold it back into one feature vector per spatial position
            y = cell[1](data.reshape(N, -1))
            x = cell[2](cell[0](x), y.reshape(N, H*W, -1), attn_mask=attn_mask) + x
            # feed forward
            x = cell[4](cell[3](x)) + x
            # latent transformer
            for latent_transformer in cell[5]:
                # self attention (latents only, so the input mask does not apply)
                x = latent_transformer[1](latent_transformer[0](x)) + x
                # feed forward
                x = latent_transformer[3](latent_transformer[2](x)) + x

        # average over the latents, then classify
        x = x.mean(dim = -2)
        return self.to_logits(x)

In [72]:
# input_dim is the number of pixels per image: 1024 = 32 * 32 for CIFAR-10.
model = VisionPerceiver(input_dim=1024, num_transformer_blocks=2)

In [73]:
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

In [74]:
train(model, optim)

Iteration 0, loss = 2.3777
Checking accuracy on validation set
Got 103 / 1000 correct (10.30)

Iteration 100, loss = 2.2177
Checking accuracy on validation set
Got 181 / 1000 correct (18.10)

Iteration 200, loss = 2.1532
Checking accuracy on validation set
Got 240 / 1000 correct (24.00)

Iteration 300, loss = 1.9445
Checking accuracy on validation set
Got 261 / 1000 correct (26.10)

Iteration 400, loss = 1.9280
Checking accuracy on validation set
Got 263 / 1000 correct (26.30)

Iteration 500, loss = 1.9815
Checking accuracy on validation set
Got 251 / 1000 correct (25.10)

Iteration 600, loss = 1.9697
Checking accuracy on validation set
Got 261 / 1000 correct (26.10)

Iteration 700, loss = 1.8497
Checking accuracy on validation set
Got 261 / 1000 correct (26.10)

