In [1]:
from fastai.torch_core import *

import torch.nn as nn
import torch,math,sys

# Same behavior as `conv1d` in fastai/layers.py at commit 5c51f9e:
# https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py
def conv1d(ni:int, no:int, ks:int=1, stride:int=1, padding:int=0, bias:bool=False):
    """Return a spectrally normalized `nn.Conv1d` mapping `ni` -> `no` channels.

    The weight is Kaiming-normal initialised. The bias, which only exists when
    `bias` is true, is zeroed. The layer is then wrapped with `spectral_norm`.
    `ks`, `stride` and `padding` are passed straight to `nn.Conv1d`.
    """
    layer = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias)
    nn.init.kaiming_normal_(layer.weight)
    if layer.bias is not None:
        nn.init.zeros_(layer.bias)
    return spectral_norm(layer)


## We compare two implementations of SimpleSelfAttention

Both have at their core this self-attention operation:

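<img src="http://latex2png.com/output//latex_e06d8a3710c644867bc207268affb4d5.png" />

In code, both modules compute $o = \gamma \, x x^\top W x + x$, where $\gamma$ is a learnable scalar initialised to 0.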


$x$ is originally an input tensor of shape (input_channels, height, width) (per sample; the minibatch dimension is left out), which gets reshaped to (input_channels, N), where N = height * width.

W is a tensor of shape (input_channels, input_channels). $Wx$ is implemented as a 1×1 convolution (a `Conv1d` with the default kernel size 1 over the flattened spatial axis).

We will show that the order of operations matters: both orders give the same result, but their costs differ greatly.

SimpleSelfAttention1 multiplies matrices in the naive order:

<img src="http://latex2png.com/output//latex_0d9f4fa0f02db47f472dc4e681a8c54e.png" />

i.e. $x\,\big(x^\top (W x)\big)$, whose intermediate $x^\top (W x)$ is an $N \times N$ matrix.



In [2]:
class SimpleSelfAttention1(nn.Module):
    """Simplified self-attention, multiplied in the naive order.

    The input is flattened to (B, C, N) with N = prod(spatial dims). The module
    computes ``o = gamma * x @ (x^T @ (W x)) + x``, where ``W x`` is ``self.conv``.
    The intermediate ``x^T @ (W x)`` has shape (B, N, N), so the cost grows as
    O(C * N^2).

    Args:
        n_in: number of input (and output) channels C.
        ks: kernel size of ``self.conv``. Padding is ``ks // 2``, so only odd
            values keep the flattened length at N. An even ``ks`` makes the
            residual addition fail with a shape mismatch.
    """

    def __init__(self, n_in:int, ks=1):
        super().__init__()
        self.conv = conv1d(n_in, n_in, ks, padding=ks//2, bias=False)
        # Starts at 0, so a freshly built module returns its input unchanged.
        self.gamma = nn.Parameter(tensor([0.]))
        self.n_in = n_in

    def forward(self, x):
        size = x.size()                                   # (B, C, *spatial)
        # reshape (not view): also accepts inputs whose spatial dims cannot be
        # flattened in place (e.g. a transposed or cropped view); zero-copy otherwise.
        x = x.reshape(*size[:2], -1)                      # (B, C, N)
        # bmm accepts strided operands, so the transpose needs no .contiguous() copy.
        o = torch.bmm(x.transpose(1, 2), self.conv(x))    # x^T (W x): (B, N, N)
        o = self.gamma * torch.bmm(x, o) + x              # (B, C, N)
        return o.view(*size).contiguous()
    

   


While SimpleSelfAttention2 does it in a different order:


<img src="http://latex2png.com/output//latex_68f9f369ddb38f790f1033f4c09a184c.png" />

i.e. $\big(x x^\top\big)(W x)$, whose intermediate $x x^\top$ is only a $C \times C$ matrix.





In [3]:
class SimpleSelfAttention2(nn.Module):
    """Simplified self-attention, multiplied in the cheap order.

    The input is flattened to (B, C, N) with N = prod(spatial dims). The module
    computes ``o = gamma * (x @ x^T) @ (W x) + x``, where ``W x`` is ``self.conv``.
    This is mathematically the same as ``SimpleSelfAttention1``, but the largest
    intermediate is (B, C, C), so the cost is O(N * C^2) instead of O(C * N^2).

    Args:
        n_in: number of input (and output) channels C.
        ks: kernel size of ``self.conv``. Padding is ``ks // 2``, so only odd
            values keep the flattened length at N. An even ``ks`` makes the
            residual addition fail with a shape mismatch.
    """

    def __init__(self, n_in:int, ks=1):
        super().__init__()
        self.conv = conv1d(n_in, n_in, ks, padding=ks//2, bias=False)
        # Starts at 0, so a freshly built module returns its input unchanged.
        self.gamma = nn.Parameter(tensor([0.]))
        self.n_in = n_in

    def forward(self, x):
        size = x.size()                                   # (B, C, *spatial)
        # reshape (not view): also accepts inputs whose spatial dims cannot be
        # flattened in place (e.g. a transposed or cropped view); zero-copy otherwise.
        x = x.reshape(*size[:2], -1)                      # (B, C, N)
        convx = self.conv(x)                              # (C,C)·(C,N) -> (B, C, N), O(N C^2)
        # bmm accepts strided operands, so the transpose needs no .contiguous() copy.
        xxT = torch.bmm(x, x.transpose(1, 2))             # (C,N)·(N,C) -> (B, C, C), O(N C^2)
        o = torch.bmm(xxT, convx)                         # (C,C)·(C,N) -> (B, C, N), O(N C^2)
        o = self.gamma * o + x
        return o.view(*size).contiguous()

The complexity of computing the product of two rectangular matrices of shapes (n,m) and (m,p) is O(nmp).

Therefore, the operation in SimpleSelfAttention1 is O(NC^2 + CN^2): one NC^2 convolution plus two CN^2 products. For SimpleSelfAttention2 it is O(NC^2): three NC^2 terms.
Remember that N = height * width, so having complexity increase with N^2 is very undesirable!

## Let's see if this works in practice:

In [11]:
# Seed once so the random inputs, and the module weights initialised in the
# following cells, are reproducible on Restart & Run All.
torch.manual_seed(0)
n_in = 64
x = torch.randn(64, n_in, 32, 32)          # (minibatch, C, H, W)

In [12]:
# One module per multiplication order, each with its own randomly initialised conv.
sa1, sa2 = SimpleSelfAttention1(n_in), SimpleSelfAttention2(n_in)

We first check that the two modules have the same output:

In [13]:
# `gamma` is initialised to 0, so a fresh module returns `x` unchanged, and sa1/sa2
# also have independently random weights: `torch.equal(sa1(x), sa2(x))` is True no
# matter how either one multiplies. Instead compare two fresh modules that share
# identical weights and a non-zero gamma. A tolerance is needed because the two
# multiplication orders round differently in float32.
check1, check2 = SimpleSelfAttention1(n_in), SimpleSelfAttention2(n_in)
with torch.no_grad():
    check1.gamma.fill_(1.)
check2.load_state_dict(check1.state_dict())
torch.allclose(check1(x), check2(x), rtol=1e-3, atol=1e-2)

True

#### Let's compare the runtimes:

In [14]:
%%timeit
# Naive order: x @ (x^T @ (W x)) builds an (N, N) = (1024, 1024) matrix per sample.
sa1(x)

349 ms ± 28 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [15]:
%%timeit
# Reordered: (x @ x^T) @ (W x) only builds a (C, C) = (64, 64) matrix per sample.
sa2(x)

216 ms ± 31.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


#### Let's see what happens if we increase channel size by a factor of 16:

In [16]:
n_in = 1024                                # 16x the channels, same 32 x 32 spatial size
x = torch.randn(64, n_in, 32, 32)          # (minibatch, C, H, W)
sa1, sa2 = SimpleSelfAttention1(n_in), SimpleSelfAttention2(n_in)

In [17]:
%%timeit
# C = N = 1024 here, so the naive (N, N) intermediate and the reordered (C, C) one are the same size.
sa1(x)

2.07 s ± 39.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [18]:
%%timeit
# Reordered version at C = N = 1024: its (C, C) intermediate is now 1024 x 1024.
sa2(x)

2.17 s ± 40.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


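SimpleSelfAttention2 is more sensitive to channel size: its cost is O(NC^2). When C is as large as N (here C = N = 32 * 32 = 1024), the two orders cost about the same. Something to keep in mind if we work with high channel counts and small input spatial dimensions.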

#### What happens if we just double spatial dimensions?

Back to 64 channels

In [19]:
n_in = 64                                  # back to 64 channels
x = torch.randn(64, n_in, 64, 64)          # (minibatch, C, H, W): spatial size doubled
sa1, sa2 = SimpleSelfAttention1(n_in), SimpleSelfAttention2(n_in)

In [20]:
%%timeit
# N = 64 * 64 = 4096 (4x the 32 x 32 case), so the (N, N) intermediate is 16x larger.
sa1(x)

2.32 s ± 49 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [21]:
%%timeit
# C is 64 again, so the (C, C) intermediate stays 64 x 64; the O(N C^2) cost grows only linearly in N.
sa2(x)

435 ms ± 22.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


#### Let's double them again:

In [22]:
n_in = 64
x = torch.randn(64, n_in, 128, 128)        # (minibatch, C, H, W): spatial size doubled again
sa1, sa2 = SimpleSelfAttention1(n_in), SimpleSelfAttention2(n_in)

In [23]:
%%timeit
# N = 128 * 128 = 16384, so the (N, N) intermediate has about 2.7e8 entries per sample.
sa1(x)

34.5 s ± 1.78 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [24]:
%%timeit
# The (C, C) intermediate is still only 64 x 64.
sa2(x)

1.36 s ± 58.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


SimpleSelfAttention2 (the reordered version) is much better at handling larger spatial dimensions!

## How does this compare to the original Self Attention layer?

This is the original SelfAttention layer as currently implemented in fast.ai
https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py

This implementation is based on the SAGAN paper
https://arxiv.org/abs/1805.08318

In [5]:
class SelfAttention(nn.Module):
    """SAGAN-style self-attention layer (https://arxiv.org/abs/1805.08318) for n-d inputs.

    Adapted from fastai/layers.py at commit 5c51f9e. The only change is that the
    input is flattened with ``reshape`` instead of ``view``, so inputs whose
    spatial dims cannot be flattened in place are accepted. For contiguous
    inputs ``reshape`` returns a view, so the computation and its benchmark cost
    are unchanged.

    Args:
        n_channels: number of input (and output) channels C. The query and key
            projections use ``n_channels // 8`` channels.
    """
    def __init__(self, n_channels:int):
        super().__init__()
        self.query = conv1d(n_channels, n_channels//8)
        self.key   = conv1d(n_channels, n_channels//8)
        self.value = conv1d(n_channels, n_channels)
        # Starts at 0, so the layer initially returns its input unchanged.
        self.gamma = nn.Parameter(tensor([0.]))

    def forward(self, x):
        #Notation from https://arxiv.org/pdf/1805.08318.pdf
        size = x.size()                                   # (B, C, *spatial)
        x = x.reshape(*size[:2],-1)                       # (B, C, N), N = prod(spatial)
        f,g,h = self.query(x),self.key(x),self.value(x)   # (B, C//8, N), (B, C//8, N), (B, C, N)
        # (B, N, N) attention map; softmax over dim 1 makes each column sum to 1.
        # The softmax needs the full N x N matrix, so it cannot be reordered away.
        # The .contiguous() copy is kept as upstream so the baseline timing is unchanged.
        beta = F.softmax(torch.bmm(f.permute(0,2,1).contiguous(), g), dim=1)
        o = self.gamma * torch.bmm(h, beta) + x          # (B, C, N)
        return o.view(*size).contiguous()

It doesn't seem that we can use the same reordering trick here, because the softmax is applied to the full $N \times N$ matrix $f^\top g$, so that matrix has to be materialised.

The outputs from SelfAttention and SimpleSelfAttention won't match, but we can compare runtimes:

In [33]:
n_in = 32
x = torch.randn(64, n_in, 16, 16)          # (minibatch, C, H, W)
sa1, sa2, sa0 = SimpleSelfAttention1(n_in), SimpleSelfAttention2(n_in), SelfAttention(n_in)

In [34]:
%%timeit
# SAGAN baseline; N = 16 * 16 = 256, so softmax runs over a (256, 256) attention map per sample.
sa0(x)

323 ms ± 35.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [35]:
%%timeit
# Naive-order SimpleSelfAttention at the same size (C = 32, N = 256).
sa1(x)

102 ms ± 6.14 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [36]:
%%timeit
# Reordered SimpleSelfAttention at the same size (C = 32, N = 256).
sa2(x)

90.9 ms ± 11 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [37]:
# Double image size: 32 x 32, i.e. N = 1024
n_in = 32
x = torch.randn(64, n_in, 32, 32)          # (minibatch, C, H, W)
sa1, sa2, sa0 = SimpleSelfAttention1(n_in), SimpleSelfAttention2(n_in), SelfAttention(n_in)

In [38]:
%%timeit
# N = 32 * 32 = 1024: the (N, N) attention map is 16x larger than at 16 x 16.
sa0(x)

3.85 s ± 206 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [39]:
%%timeit
# Naive-order SimpleSelfAttention at C = 32, N = 1024.
sa1(x)

181 ms ± 10.3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [40]:
%%timeit
# Reordered SimpleSelfAttention at C = 32, N = 1024.
sa2(x)

100 ms ± 22.9 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [18]:
# Double image size again: 64 x 64, i.e. N = 4096
n_in = 32
x = torch.randn(64, n_in, 64, 64)          # (minibatch, C, H, W)
sa1, sa2, sa0 = SimpleSelfAttention1(n_in), SimpleSelfAttention2(n_in), SelfAttention(n_in)

In [22]:
%%time
sa0(x);

CPU times: user 2min 40s, sys: 9.91 s, total: 2min 50s
Wall time: 2min 29s


tensor([[[[-7.5878e-01,  1.2338e+00,  5.2843e-01,  ..., -2.9998e-01,
            7.1314e-01, -4.6787e-01],
          [ 5.3816e-01,  1.2520e+00, -5.7126e-02,  ..., -7.3750e-01,
           -2.7030e-01, -5.1272e-01],
          [ 1.5741e+00,  7.8619e-01,  7.7834e-01,  ...,  1.0803e+00,
            3.2582e-01,  4.1008e-01],
          ...,
          [ 2.6405e+00, -2.3683e+00,  3.3436e-01,  ...,  3.8543e-01,
            1.0957e+00, -5.4066e-01],
          [-1.4477e-01, -2.9981e-01, -7.0663e-01,  ..., -7.3166e-02,
           -1.0477e+00,  1.4565e+00],
          [-1.0083e+00,  6.1483e-01,  2.1860e-01,  ...,  1.2566e+00,
           -3.4978e-01, -2.0957e+00]],

         [[-8.5598e-01,  4.4064e-01,  1.0505e+00,  ...,  3.4780e-01,
            1.2274e+00,  1.5416e-01],
          [ 1.1609e-01,  8.6406e-01, -1.4297e+00,  ...,  9.5329e-01,
           -1.5561e-01,  5.7680e-01],
          [ 1.4711e+00,  1.6361e+00,  1.0770e+00,  ..., -1.3232e+00,
           -4.5743e-01,  2.2284e+00],
          ...,
     

In [20]:
%%timeit
# Naive order at N = 64 * 64 = 4096: the (N, N) intermediate is 4096 x 4096 per sample.
sa1(x)

1.68 s ± 95.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [21]:
%%timeit
# Reordered at N = 4096: the (C, C) intermediate is only 32 x 32.
sa2(x)

296 ms ± 41.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


The original SelfAttention layer applies softmax to an N × N matrix, which I believe makes it O(N^2) as well, since we do softmax N times and softmax is O(N) according to http://cs231n.stanford.edu/reports/2017/pdfs/130.pdf. In addition, its two bmm products with the N × N attention map already cost O(CN^2).

The improved SimpleSelfAttention layer seems to provide a major improvement in terms of complexity compared to the original SelfAttention layer. It remains to be demonstrated whether it can work as an equivalent layer (e.g. in self-attention GANs)