In [1]:
import torch
import torch.nn as nn
import math
import torchvision
import torchvision.transforms as transforms

In [16]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention (no output projection).

    The input is projected to queries, keys and values by three separate linear
    layers. The channel dimension is then split into ``num_heads`` equal
    sub-spaces, attention is computed independently in each, and the per-head
    results are concatenated back along the channel dimension.

    :param embed_dim: channel size C of the input and output.
    :param num_heads: number of heads H; must be a positive divisor of ``embed_dim``.
    :raises ValueError: if ``num_heads`` does not evenly divide ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        # Fail early with a clear message instead of an opaque reshape error in forward().
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                "embed_dim ({}) must be divisible by num_heads ({})".format(embed_dim, num_heads))
        self.linear_q = nn.Linear(embed_dim, embed_dim)
        self.linear_k = nn.Linear(embed_dim, embed_dim)
        self.linear_v = nn.Linear(embed_dim, embed_dim)
        self.num_heads = num_heads

    def _reshape_to_batches(self, x):
        """Fold heads into the batch dim: (B, N, C) -> (B * H, N, C // H).

        Head ``h`` receives channels ``[h * C // H, (h + 1) * C // H)``.
        """
        batch_size, seq_len, embed_dim = x.size()
        sub_dim = embed_dim // self.num_heads
        x = x.reshape(batch_size, seq_len, self.num_heads, sub_dim)
        x = x.permute(0, 2, 1, 3)               # (B, H, N, C // H)
        return x.reshape(batch_size * self.num_heads, seq_len, sub_dim)

    def _reshape_from_batches(self, x):
        """Inverse of ``_reshape_to_batches``: (B * H, N, C // H) -> (B, N, C)."""
        batch_heads, seq_len, sub_dim = x.size()
        batch_size = batch_heads // self.num_heads
        embed_dim = sub_dim * self.num_heads
        x = x.reshape(batch_size, self.num_heads, seq_len, sub_dim)
        x = x.permute(0, 2, 1, 3)               # (B, N, H, C // H)
        return x.reshape(batch_size, seq_len, embed_dim)

    def forward(self, x):
        """
        :param x: (B, N, C).
        :return: (B, N, C).
        """
        # Obtain Query, Key and Value by linear transformation, then split heads.
        # For num_heads == 1 the fold is a value no-op, so no special case is needed.
        x_q = self._reshape_to_batches(self.linear_q(x))     # (B*H, N, C/H)
        x_k = self._reshape_to_batches(self.linear_k(x))     # (B*H, N, C/H)
        x_v = self._reshape_to_batches(self.linear_v(x))     # (B*H, N, C/H)

        # Scaled dot-product between Query and Key.
        dk = x_q.shape[-1]
        qk = torch.bmm(x_q, x_k.transpose(1, 2)) / math.sqrt(dk)    # (B*H, N, N)

        # Attention weights over keys, then weighted sum of Values.
        attn = qk.softmax(-1)                  # (B*H, N, N)
        out = torch.bmm(attn, x_v)             # (B*H, N, C/H)
        return self._reshape_from_batches(out)     # (B, N, C)

In [17]:
class TransformerEncoderLayer(nn.Module):
    """One encoder block: multi-head self-attention followed by a 2-layer ReLU MLP.

    With the defaults the block computes ``feed_forward(attn(x))`` exactly as the
    original did: there are no skip connections and no normalization. Passing
    ``residual=True`` switches to the post-norm layout
    ``h = LN(x + attn(x)); out = LN(h + feed_forward(h))``.

    :param embed_dim: channel size C of the input and output.
    :param num_heads: number of attention heads (must divide ``embed_dim``).
    :param ff_dim: hidden width of the feed-forward MLP; ``None`` means ``embed_dim``.
    :param residual: if True, add residual connections and LayerNorm around both sub-layers.
    """

    def __init__(self, embed_dim, num_heads, ff_dim=None, residual=False):
        super(TransformerEncoderLayer, self).__init__()
        if ff_dim is None:
            ff_dim = embed_dim
        self.residual = residual
        self.attn = MultiHeadAttention(embed_dim, num_heads)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        if residual:
            # Created only when enabled, so the default state_dict keys are unchanged.
            self.norm1 = nn.LayerNorm(embed_dim)
            self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """
        :param x: (B, N, C).
        :return: (B, N, C).
        """
        if not self.residual:
            x = self.attn(x)                       # (B, N, C)
            return self.feed_forward(x)            # (B, N, C)
        h = self.norm1(x + self.attn(x))           # (B, N, C)
        return self.norm2(h + self.feed_forward(h))    # (B, N, C)

In [18]:
class TransformerImageClassifier(nn.Module):
    """Sequence classifier built from encoder layers.

    Pipeline: per-element linear embedding (``fc1``), a stack of
    ``TransformerEncoderLayer``s, max-pooling over the sequence dimension, then a
    linear head (``fc2``).

    NOTE: ``fc1``, the attention, the feed-forward and the max-pool are all either
    per-element or symmetric over the sequence, so without positional information
    the logits are invariant to the order of the N input elements. For example,
    shuffling the rows of an image gives the same prediction. Pass ``max_seq_len``
    to add a learned positional embedding that breaks this symmetry. The default
    (``None``) keeps the original position-free model and state_dict.

    :param input_dim: feature size C_in of each input element.
    :param num_layers: number of encoder layers.
    :param transformer_embed_dim: channel size C inside the encoder.
    :param transformer_num_heads: attention heads per layer (must divide C).
    :param num_classes: number of output logits.
    :param max_seq_len: optional maximum sequence length N for a learned positional embedding.
    :raises ValueError: if ``max_seq_len`` is given and is smaller than 1.
    """

    def __init__(self,
                 input_dim,
                 num_layers,
                 transformer_embed_dim,
                 transformer_num_heads,
                 num_classes,
                 max_seq_len=None):
        super(TransformerImageClassifier, self).__init__()
        self.fc1 = nn.Linear(input_dim, transformer_embed_dim)
        if max_seq_len is None:
            # Registered as None so `self.pos_embed` always exists but adds no parameter.
            self.register_parameter("pos_embed", None)
        else:
            if max_seq_len < 1:
                raise ValueError("max_seq_len must be >= 1, got {}".format(max_seq_len))
            self.pos_embed = nn.Parameter(torch.zeros(1, max_seq_len, transformer_embed_dim))
            nn.init.normal_(self.pos_embed, std=0.02)  # small random init
        self.transformer_layers = nn.ModuleList(
            TransformerEncoderLayer(embed_dim=transformer_embed_dim,
                                    num_heads=transformer_num_heads)
            for _ in range(num_layers)
        )
        self.fc2 = nn.Linear(transformer_embed_dim, num_classes)

    def forward(self, x):
        """
        :param x: (B, N, C_in).
        :return: (B, num_classes) logits.
        :raises ValueError: if positional embeddings are enabled and N > max_seq_len.
        """
        x = self.fc1(x)     # (B, N, C)
        if self.pos_embed is not None:
            seq_len = x.shape[1]
            if seq_len > self.pos_embed.shape[1]:
                raise ValueError("sequence length {} exceeds max_seq_len {}".format(
                    seq_len, self.pos_embed.shape[1]))
            x = x + self.pos_embed[:, :seq_len]     # (B, N, C)
        for transformer_layer in self.transformer_layers:
            x = transformer_layer(x)      # (B, N, C)

        # Merge the information of all elements by channel-wise max over the sequence.
        x = x.max(dim=1).values      # (B, C)
        return self.fc2(x)     # (B, num_classes)

In [19]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Hyperparameters for the shape smoke tests below
# (the training cell further down defines its own values).
input_dim = 20
num_layers = 1  # try 2
transformer_embed_dim = 32
transformer_num_heads = 1   # try 8
num_classes = 10
batch_size, seq_len = 4, 100

# Unit-test for Attention module: the output keeps the (B, N, C) input shape.
# Also run with 8 heads so the multi-head reshape path is exercised at least once.
for heads in (transformer_num_heads, 8):
    attn = MultiHeadAttention(embed_dim=transformer_embed_dim, num_heads=heads)
    attn = attn.to(device)
    x = torch.randn(batch_size, seq_len, transformer_embed_dim).to(device)
    out = attn(x)
    print(out.shape)
    assert out.shape == (batch_size, seq_len, transformer_embed_dim)

# Unit-test for Transformer network: (B, N, C_in) -> (B, num_classes).
model = TransformerImageClassifier(input_dim, num_layers, transformer_embed_dim,
                                   transformer_num_heads, num_classes)
model.to(device)
x = torch.randn(batch_size, seq_len, input_dim).to(device)
out = model(x)
print(out.shape)
assert out.shape == (batch_size, num_classes)

torch.Size([4, 100, 32])
torch.Size([4, 10])


In [20]:
def create_dataloader(root='root', batch_size=64, download=True):
    """Build the MNIST train and test ``DataLoader``s.

    Images are converted with ``transforms.ToTensor()``. The training split is
    shuffled every epoch; the test split keeps a fixed order.

    :param root: directory where torchvision stores/reads the MNIST files.
    :param batch_size: samples per batch for both loaders.
    :param download: fetch the dataset into ``root`` if it is not already there.
    :return: ``(train_loader, test_loader)``.
    """
    def make_loader(train, shuffle):
        # One dataset + loader per split; identical settings apart from the split flag.
        dataset = torchvision.datasets.MNIST(root=root,
                                             train=train,
                                             download=download,
                                             transform=transforms.ToTensor())
        return torch.utils.data.DataLoader(dataset=dataset,
                                           batch_size=batch_size,
                                           shuffle=shuffle)

    train_loader = make_loader(train=True, shuffle=True)
    test_loader = make_loader(train=False, shuffle=False)
    return train_loader, test_loader

In [21]:
def train(train_loader, model, criterion, optimizer, num_epochs, log_every=100, device=None):
    """Optimize ``model`` in place for ``num_epochs`` passes over ``train_loader``.

    Each image batch is reshaped to ``(B, 28, 28)``, so the 28 pixel rows of an
    MNIST image become a sequence of 28 elements with 28 features each.

    :param train_loader: iterable of ``(images, labels)`` batches that supports
        ``len()`` (used in the progress message).
    :param model: network to optimize; it is switched to training mode first.
    :param criterion: loss called as ``criterion(outputs, labels)``.
    :param optimizer: optimizer over ``model``'s parameters.
    :param num_epochs: number of passes over the data.
    :param log_every: print the current loss every this many steps (0 disables).
    :param device: where batches are moved; defaults to the device of the model's
        first parameter (CPU for a parameter-less model).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    # A previous evaluation may have left the model in eval mode.
    model.train()
    total_step = len(train_loader)
    for epoch in range(num_epochs):
        for step, (images, labels) in enumerate(train_loader):
            images = images.to(device).reshape(-1, 28, 28)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if log_every and (step + 1) % log_every == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                      .format(epoch + 1, num_epochs, step + 1, total_step, loss.item()))

In [22]:
def test(test_loader, model):
    # Test the model
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            images = images.reshape(-1, 28, 28)
            
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        print('Accuracy of the network on the 10000 test images: {} %'.format(100 * correct / total))

In [23]:
# Hyperparameters
# Each 28x28 MNIST image is fed as a sequence of 28 rows with 28 pixels per row.
input_dim = 28
num_layers = 1
transformer_embed_dim = 32
transformer_num_heads = 1
num_classes = 10
num_epochs = 5
learning_rate = 0.001
seed = 0

# Seed torch's RNG so weight init and data shuffling are reproducible on Restart & Run All.
torch.manual_seed(seed)

### step 1: prepare dataset and create dataloader
train_loader, test_loader = create_dataloader()

### step 2: create neural network
model = TransformerImageClassifier(input_dim,
                                   num_layers,
                                   transformer_embed_dim,
                                   transformer_num_heads,
                                   num_classes)
model.to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

### step 3: train the model
train(train_loader, model, criterion, optimizer, num_epochs=num_epochs)

### step 4: test the model
test(test_loader, model)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to root/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:17<00:00, 577140.30it/s] 


Extracting root/MNIST/raw/train-images-idx3-ubyte.gz to root/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to root/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 136539.54it/s]


Extracting root/MNIST/raw/train-labels-idx1-ubyte.gz to root/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to root/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:02<00:00, 696925.85it/s]


Extracting root/MNIST/raw/t10k-images-idx3-ubyte.gz to root/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to root/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 1159284.90it/s]


Extracting root/MNIST/raw/t10k-labels-idx1-ubyte.gz to root/MNIST/raw

Epoch [1/5], Step [100/938], Loss: 2.0232
Epoch [1/5], Step [200/938], Loss: 1.6656
Epoch [1/5], Step [300/938], Loss: 1.2335
Epoch [1/5], Step [400/938], Loss: 1.5771
Epoch [1/5], Step [500/938], Loss: 1.1899
Epoch [1/5], Step [600/938], Loss: 1.4120
Epoch [1/5], Step [700/938], Loss: 1.4072
Epoch [1/5], Step [800/938], Loss: 0.9263
Epoch [1/5], Step [900/938], Loss: 1.0644
Epoch [2/5], Step [100/938], Loss: 1.0740
Epoch [2/5], Step [200/938], Loss: 1.1407
Epoch [2/5], Step [300/938], Loss: 1.1531
Epoch [2/5], Step [400/938], Loss: 1.1012
Epoch [2/5], Step [500/938], Loss: 1.1540
Epoch [2/5], Step [600/938], Loss: 1.1275
Epoch [2/5], Step [700/938], Loss: 1.1825
Epoch [2/5], Step [800/938], Loss: 1.2798
Epoch [2/5], Step [900/938], Loss: 0.9894
Epoch [3/5], Step [100/938], Loss: 0.9308
Epoch [3/5], Step [200/938], Loss: 0.8015
Epoch [3/5], Step [300/938], Loss: 0.9995
Epoch [3/5], Step [400/938], Loss: 0.8802
Epoch