In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Seed torch's RNG once, up front, so torch.randn inputs and layer weight init below
# are reproducible under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)
torch.cuda.is_available()

False

In [160]:
def conv_2d(inp, oup, kernel_size=3, stride=1, groups=1, bias=False, norm=True, act=True):
    """Build a Conv2d -> optional BatchNorm2d -> optional ReLU6 stack.

    Padding is (kernel_size - 1) // 2, i.e. "same" spatial size for odd kernels at stride 1.

    Args:
        inp: input channels.
        oup: output channels.
        kernel_size: integer kernel size of the convolution.
        stride: convolution stride (int or tuple, passed straight to nn.Conv2d).
        groups: convolution groups (groups=inp gives a depthwise conv).
        bias: whether the convolution has a bias term.
        norm: append a BatchNorm2d(oup) named 'BatchNorm2d'.
        act: append a ReLU6 named 'Activation'.

    Returns:
        nn.Sequential whose children are named 'conv' [, 'BatchNorm2d'] [, 'Activation'].
    """
    stages = [
        ('conv', nn.Conv2d(inp, oup, kernel_size, stride, (kernel_size - 1) // 2,
                           bias=bias, groups=groups)),
    ]
    if norm:
        stages.append(('BatchNorm2d', nn.BatchNorm2d(oup)))
    if act:
        stages.append(('Activation', nn.ReLU6()))

    block = nn.Sequential()
    for stage_name, stage in stages:
        block.add_module(stage_name, stage)
    return block

class MultiQueryAttentionLayerWithDownSampling(nn.Module):
    def __init__(self, inp, num_heads, key_dim, value_dim, query_h_strides, query_w_strides, kv_strides, dw_kernel_size=3, dropout=0.0):
        """Multi Query Attention with spatial downsampling.

        Referenced from https://github.com/tensorflow/models/blob/master/official/vision/modeling/layers/nn_blocks.py

        Each of the num_heads query heads has its own key_dim-wide projection, while a
        single key projection and a single value projection are shared by all heads.

        Three parameters control the spatial downsampling:
        1. kv_strides: downsampling factor on keys and values only
           (depthwise conv + BatchNorm before their 1x1 projections).
        2. query_h_strides: vertical average-pooling stride on the query only.
        3. query_w_strides: horizontal average-pooling stride on the query only.
        When the query is downsampled, the attention output is bilinearly resized back
        to the input's spatial size before the output projection.

        Projections are written out as 1x1 Conv2D.

        Args:
            inp: number of input (and output) channels.
            num_heads: number of query heads.
            key_dim: channels of each query head and of the shared key.
            value_dim: channels of the shared value (and of each head's output).
            query_h_strides: vertical query pooling stride.
            query_w_strides: horizontal query pooling stride.
            kv_strides: stride of the key/value depthwise convolutions.
            dw_kernel_size: kernel size of those depthwise convolutions.
            dropout: dropout probability applied to the attention weights.
        """
        super().__init__()
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.value_dim = value_dim
        self.query_h_strides = query_h_strides
        self.query_w_strides = query_w_strides
        self.kv_strides = kv_strides
        self.dw_kernel_size = dw_kernel_size
        # Kept for backward compatibility only; the attention scale uses key_dim,
        # the size of each query/key dot product.
        self.head_dim = key_dim // num_heads

        if self.query_h_strides > 1 or self.query_w_strides > 1:
            self._query_downsampling_norm = nn.BatchNorm2d(inp)
        self._query_proj = conv_2d(inp, num_heads * key_dim, 1, 1, norm=False, act=False)

        if self.kv_strides > 1:
            self._key_dw_conv = conv_2d(inp, inp, dw_kernel_size, kv_strides, groups=inp, norm=True, act=False)
            self._value_dw_conv = conv_2d(inp, inp, dw_kernel_size, kv_strides, groups=inp, norm=True, act=False)
        self._key_proj = conv_2d(inp, key_dim, 1, 1, norm=False, act=False)
        self._value_proj = conv_2d(inp, value_dim, 1, 1, norm=False, act=False)

        self._output_proj = conv_2d(num_heads * value_dim, inp, 1, 1, norm=False, act=False)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Apply downsampled multi-query attention.

        Args:
            x: input of shape [batch, inp, height, width].

        Returns:
            Tensor of shape [batch, inp, height, width].
        """
        batch_size, _, height, width = x.shape
        downsample_query = self.query_h_strides > 1 or self.query_w_strides > 1

        if downsample_query:
            q = F.avg_pool2d(x, (self.query_h_strides, self.query_w_strides))
            q = self._query_downsampling_norm(q)
        else:
            q = x
        q = self._query_proj(q)  # [B, num_heads*key_dim, qh, qw]
        q_h, q_w = q.shape[-2:]
        # Channels are head-major in NCHW, so split them into (heads, key_dim) while space
        # is still flattened last, then move positions ahead of channels. A direct
        # view(B, heads, -1, key_dim) would mix spatial positions into the "key" axis.
        q = q.view(batch_size, self.num_heads, self.key_dim, q_h * q_w).transpose(2, 3)  # [B, heads, Lq, key_dim]

        if self.kv_strides > 1:
            k = self._key_proj(self._key_dw_conv(x))
            v = self._value_proj(self._value_dw_conv(x))
        else:
            k = self._key_proj(x)
            v = self._value_proj(x)
        # One shared key/value "head" (axis 1 of size 1) broadcasts across all query heads.
        k = k.flatten(2).unsqueeze(1)  # [B, 1, key_dim, Lkv]
        v = v.flatten(2).transpose(1, 2).unsqueeze(1)  # [B, 1, Lkv, value_dim]

        # Scaled dot-product over key_dim channels; dropout acts on the normalized weights.
        attn = torch.matmul(q, k) * self.key_dim ** -0.5  # [B, heads, Lq, Lkv]
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        context = torch.matmul(attn, v)  # [B, heads, Lq, value_dim]
        # Back to NCHW with head-major channels: [B, heads*value_dim, qh, qw].
        context = context.transpose(2, 3).reshape(batch_size, self.num_heads * self.value_dim, q_h, q_w)

        if downsample_query:
            # Resize to the exact input size so odd / non-divisible H, W still line up.
            context = F.interpolate(context, size=(height, width), mode='bilinear', align_corners=False)
        return self._output_proj(context)

In [161]:
# Toy configuration for stepping through the attention layer by hand.
inp = 3  # input channels
num_heads = 2
key_dim = 8
query_h_strides, query_w_strides = 3, 3
head_dim = key_dim // num_heads

input_tensor = torch.randn((1, inp, 15, 15))
# NOTE: the second dimension is the channel count (== inp), not a sequence length.
batch_size, seq_length, _, _ = input_tensor.shape

# Sanity check: a default conv_2d (3x3, stride 1, padding 1) keeps the 15x15 spatial size.
conv_2d(inp, 7)(input_tensor).shape

torch.Size([1, 7, 15, 15])

In [162]:
# Query path, mirroring the layer: average-pool by the query strides, BatchNorm, then a
# 1x1 projection to num_heads * key_dim channels. Only keys/values get a depthwise conv,
# so no depthwise conv is built for the query.
query_proj = conv_2d(inp, num_heads * key_dim, 1, 1, norm=False, act=False)
query_downsampling_norm = nn.BatchNorm2d(inp)
q = F.avg_pool2d(input_tensor, (query_h_strides, query_w_strides))
q = query_downsampling_norm(q)
q = query_proj(q)
q.shape

torch.Size([1, 16, 5, 5])

In [163]:
# q is [B, num_heads*key_dim, H, W] with head-major channels: split channels into
# (num_heads, key_dim) first, then move the flattened positions ahead of key_dim.
# A direct view(B, num_heads, -1, key_dim) would mix spatial positions into key_dim.
px = q.size(2)
q = q.view(batch_size, num_heads, key_dim, -1).transpose(2, 3)  # [B, num_heads, H*W, key_dim]
q.shape

torch.Size([1, 2, 25, 8])

In [164]:
# Key path: a stride-kv_strides depthwise conv (+ BatchNorm) downsamples keys spatially.
kv_strides = 2
key_dw_conv = conv_2d(inp, inp, 3, kv_strides, groups=inp, norm=True, act=False)
k = key_dw_conv(input_tensor)
k.shape

torch.Size([1, 3, 8, 8])

In [165]:
# Shared 1x1 key projection: inp -> key_dim channels.
# NOTE: overwrites k; re-running this cell alone fails because k already has key_dim channels.
key_proj = conv_2d(inp, key_dim, kernel_size=1, stride=1, norm=False, act=False)
k = key_proj(k)
k.shape

torch.Size([1, 8, 8, 8])

In [166]:
# Flatten space and add a singleton head axis: [B, key_dim, h, w] -> [B, 1, key_dim, h*w].
# The size-1 axis broadcasts the single shared key across all query heads in torch.matmul(q, k).
k = k.view(batch_size, 1, key_dim, -1)
k.shape

torch.Size([1, 1, 8, 64])

In [167]:
# Scaled dot-product scores: each query/key dot product contracts over key_dim channels,
# so scale by sqrt(key_dim) (head_dim = key_dim // num_heads is not the contracted size).
attn_score = torch.matmul(q, k) / (key_dim ** 0.5)
attn_score.shape

torch.Size([1, 2, 25, 64])

In [154]:
# Normalize scores over key positions (last axis).
# NOTE: overwrites attn_score, so re-running this cell applies softmax twice.
attn_score = attn_score.softmax(dim=-1)
attn_score.shape

torch.Size([1, 2, 25, 64])

In [155]:
# Value path mirrors the key path (stride-2 depthwise conv + 1x1 projection) instead of
# reusing k. Then [B, key_dim, h, w] -> [B, 1, key_dim, h*w] -> [B, 1, h*w, key_dim]:
# a direct view(B, 1, -1, key_dim) would mix spatial positions into the channel axis.
value_dw_conv = conv_2d(inp, inp, 3, 2, groups=inp, norm=True, act=False)
value_proj = conv_2d(inp, key_dim, 1, 1, norm=False, act=False)
v = value_proj(value_dw_conv(input_tensor))
v = v.view(batch_size, 1, key_dim, -1).transpose(2, 3)
v.shape

torch.Size([1, 1, 64, 8])

In [156]:
# Weighted sum of values per head; v's size-1 head axis broadcasts over num_heads.
context = attn_score @ v
context.shape

torch.Size([1, 2, 25, 8])

In [157]:
# context is [B, num_heads, H*W, key_dim]: move key_dim ahead of positions before merging
# heads into channels, otherwise positions and channels are interleaved.
# NOTE: px, px assumes a square query map (true for the 15x15 toy input).
context = context.transpose(2, 3).reshape(batch_size, num_heads * key_dim, px, px)
context.shape

torch.Size([1, 16, 5, 5])

In [158]:
# Undo the query average-pooling: bilinearly upsample the attention output by the
# query strides so it returns to the input's spatial size.
output_upsample = nn.Upsample(scale_factor=(query_h_strides, query_w_strides), mode='bilinear')
output = output_upsample(context)
output.shape

torch.Size([1, 16, 15, 15])

In [159]:
# Final 1x1 projection from num_heads * key_dim channels back to inp channels.
# NOTE: overwrites output; re-running this cell alone fails (channel count already reduced).
output_proj = conv_2d(num_heads * key_dim, inp, kernel_size=1, stride=1, norm=False, act=False)
output = output_proj(output)
output.shape

torch.Size([1, 3, 15, 15])