# Classical Multi-head attention

The description below is taken from the appendix of the DETR paper.

**Multi-head** The general form of _multi-head attention_ with M heads of dimension d is a function with the following signature (using $d^{\prime}=\frac{d}{M}$ by default, and giving matrix/tensor sizes under the braces)

$$\text { mh-attn }: \underbrace{X_{\mathrm{q}}}_{d \times N_{\mathrm{q}}}, \underbrace{X_{\mathrm{kv}}}_{d \times N_{\mathrm{kv}}}, \underbrace{T}_{M \times 3 \times d^{\prime} \times d}, \underbrace{L}_{d \times d} \mapsto \underbrace{\tilde{X}_{\mathrm{q}}}_{d \times N_{\mathrm{q}}}$$

where $X_{\mathrm{q}}$ is the query sequence of length $N_{\mathrm{q}}$, $X_{\mathrm{kv}}$ is the key-value sequence of length $N_{\mathrm{kv}}$ (with the same number of channels $d$ for simplicity of exposition), $T$ is the weight tensor to compute the so-called query, key and value embeddings, and $L$ is a projection matrix. The output is the same size as the query sequence. To fix the vocabulary before giving details, multi-head self-attention ($\operatorname{mh-s-attn}$) is the special case $X_{\mathrm{q}}=X_{\mathrm{kv}}$, i.e.

$$\operatorname{mh-s-attn} (X, T, L)=\operatorname{mh-attn} (X, X, T, L)$$

The multi-head attention is simply the concatenation of M single attention
heads followed by a projection with L. The common practice is to use residual connections, dropout and layer normalization. In other words, denoting $\tilde{X}_{\mathrm{q}} = \operatorname{mh-attn}\left(X_{\mathrm{q}}, X_{\mathrm{kv}}, T, L\right)$ and $X_{\mathrm{q}}^{\prime}$ the concatenation of attention heads, we have

$$X_{\mathrm{q}}^{\prime}=\left[\operatorname{attn}\left(X_{\mathrm{q}}, X_{\mathrm{kv}}, T_{1}\right) ; \ldots ; \operatorname{attn}\left(X_{\mathrm{q}}, X_{\mathrm{kv}}, T_{M}\right)\right]$$

$$\tilde{X}_{\mathrm{q}}=\operatorname{layernorm}\left(X_{\mathrm{q}}+\operatorname{dropout}\left(L X_{\mathrm{q}}^{\prime}\right)\right)$$

where [;] denotes concatenation on the channel axis.

**Single head** An attention head with weight tensor $T^{\prime} \in \mathbb{R}^{3 \times d^{\prime} \times d}$, denoted by $\operatorname{attn}\left(X_{\mathrm{q}}, X_{\mathrm{kv}}, T^{\prime}\right)$, depends on additional positional encoding $P_{\mathrm{q}} \in \mathbb{R}^{d \times N_{\mathrm{q}}}$ and
$P_{\mathrm{kv}} \in \mathbb{R}^{d \times N_{\mathrm{kv}}}$. It starts by computing so-called query, key and value embeddings after adding the query and key positional encodings:

$$[Q ; K ; V]=\left[T_{1}^{\prime}\left(X_{\mathrm{q}}+P_{\mathrm{q}}\right) ; T_{2}^{\prime}\left(X_{\mathrm{kv}}+P_{\mathrm{kv}}\right) ; T_{3}^{\prime} X_{\mathrm{kv}}\right]$$

where $T^{\prime}$ is the concatenation of $T_{1}^{\prime}, T_{2}^{\prime}, T_{3}^{\prime}$. The _attention weights_ $\alpha$ are then computed based on the softmax of dot products between queries and keys, so that each element of the query sequence attends to all elements of the key-value sequence ($i$ is a query index and $j$ a key-value index):

$$\alpha_{i, j}=\frac{e^{\frac{1}{\sqrt{d^{\prime}}} Q_{i}^{T} K_{j}}}{Z_{i}} \text{ where } Z_{i}=\sum_{j=1}^{N_{\mathrm{kv}}} e^{\frac{1}{\sqrt{d^{\prime}}} Q_{i}^{T} K_{j}}$$

The final output is the aggregation of values weighted by attention weights: The $i$-th row is given by $\operatorname{attn}_{i}\left(X_{\mathrm{q}}, X_{\mathrm{kv}}, T^{\prime}\right)=\sum_{j=1}^{N_{\mathrm{kv}}} \alpha_{i, j} V_{j}$.
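As a cross-check of the equations above, here is a minimal sketch that transcribes them literally in the text's column convention (channels × sequence), unlike the (batch, sequence, channels) layout of the notebook code below. The function names `attn` and `mh_attn` only mirror the notation; the layer norm without learnable affine parameters and dropout switched off by default (evaluation mode) are assumptions of this sketch.

```python
import math
import torch
import torch.nn.functional as F


def attn(X_q, X_kv, T, P_q, P_kv):
    """One attention head in the column convention of the text.

    X_q: (d, N_q), X_kv: (d, N_kv), T: (3, d', d), P_q: (d, N_q), P_kv: (d, N_kv).
    Returns the head output (d', N_q) and the attention weights alpha (N_q, N_kv).
    """
    Q = T[0] @ (X_q + P_q)    # (d', N_q): positional encodings are added to queries
    K = T[1] @ (X_kv + P_kv)  # (d', N_kv): and to keys
    V = T[2] @ X_kv           # (d', N_kv): but not to values
    logits = Q.T @ K / math.sqrt(T.shape[1])  # entry (i, j) is Q_i^T K_j / sqrt(d')
    alpha = torch.softmax(logits, dim=1)      # exponentiate and divide by Z_i
    return V @ alpha.T, alpha                 # column i is sum_j alpha_ij V_j


def mh_attn(X_q, X_kv, T, L, P_q, P_kv, p_drop=0.1, training=False):
    """T: (M, 3, d', d) and L: (d, d), assuming M * d' = d."""
    heads = [attn(X_q, X_kv, T[m], P_q, P_kv)[0] for m in range(T.shape[0])]
    X_prime = torch.cat(heads, dim=0)  # [;]: concatenation on the channel axis, (M d', N_q)
    Y = F.dropout(L @ X_prime, p=p_drop, training=training)
    # layer norm over channels, separately for each sequence element (no affine parameters)
    return F.layer_norm((X_q + Y).T, (X_q.shape[0],)).T


torch.manual_seed(0)
d, M, N_q, N_kv = 8, 2, 5, 7
X_q, X_kv = torch.randn(d, N_q), torch.randn(d, N_kv)
P_q, P_kv = torch.randn(d, N_q), torch.randn(d, N_kv)
T = torch.randn(M, 3, d // M, d)
L = torch.randn(d, d)

out = mh_attn(X_q, X_kv, T, L, P_q, P_kv)
print(out.shape)  # torch.Size([8, 5]), the same size as X_q
```

Keeping $T^{\prime}$ as a single `(3, d', d)` tensor makes the split into $T_{1}^{\prime}, T_{2}^{\prime}, T_{3}^{\prime}$ a plain indexing operation, exactly as in the definition.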

**Feed-forward network (FFN) layers** The original transformer alternates multi-head attention and so-called FFN layers, which are effectively multi-layer 1x1 convolutions with $M d$ input and output channels in our case. The FFN we consider is composed of two layers of 1x1 convolutions with ReLU activations. There is also a residual connection/dropout/layernorm after the two layers.
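The statement that FFN layers are "effectively 1x1 convolutions" can be checked numerically. The sketch below copies the weights of an `nn.Linear` into an `nn.Conv1d` with `kernel_size=1` and compares the outputs; this equivalence is why the notebook's `FFN` can use `nn.Linear` on (batch, sequence, channels) tensors. The sizes are arbitrary choices for the illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d, n = 6, 4
x = torch.randn(2, n, d)  # (batch, sequence length, channels)

linear = nn.Linear(d, d)
conv = nn.Conv1d(d, d, kernel_size=1)
# copy the linear weights into the 1x1 conv; the conv weight has shape (out, in, 1)
with torch.no_grad():
    conv.weight.copy_(linear.weight.unsqueeze(-1))
    conv.bias.copy_(linear.bias)

y_linear = linear(x)                              # applied to every position independently
y_conv = conv(x.transpose(1, 2)).transpose(1, 2)  # Conv1d expects (batch, channels, length)
print(torch.allclose(y_linear, y_conv, atol=1e-5))  # True
```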

In [1]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

In [2]:
class ScaledDotProductAttention(nn.Module):
    """
    Scaled Dot-Product Attention
    """

    def __init__(self, d_head):
        super(ScaledDotProductAttention, self).__init__()
        self.temper = math.sqrt(d_head)

    def forward(self, q, k, v):
        """
        :param q: query (M, b, n_q, d_head)
        :param k: key (M, b, n_kv, d_head)
        :param v: value (M, b, n_kv, d_head)
        """
        attn = torch.matmul(q, k.transpose(2, 3)) / self.temper  # (M, b, n_q, n_kv)
        attn = torch.softmax(attn, dim=3)  # (M, b, n_q, n_kv)
        output = torch.matmul(attn, v)  # (M, b, n_q, d_head)
        return output

In [3]:
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention module
    """

    def __init__(self, M, d, d_head):
        """
        :param M: number of parallel attention heads
        :param d: input hidden size
        :param d_head: number of hidden units in one head (d' in text above)
        """
        super(MultiHeadAttention, self).__init__()
        self.M = M
        self.d_head = d_head
        
        self.w_q = nn.Parameter(torch.FloatTensor(M, d, d_head), requires_grad=True)
        self.w_k = nn.Parameter(torch.FloatTensor(M, d, d_head), requires_grad=True)
        self.w_v = nn.Parameter(torch.FloatTensor(M, d, d_head), requires_grad=True)
        self.attention = ScaledDotProductAttention(d_head)
        self.proj = nn.Linear(M * d_head, d, bias=False)
        self.dropout = nn.Dropout(0.1)
        self.layer_norm = nn.LayerNorm(d)
        self.init_weights()

    def init_weights(self):
        init.xavier_normal_(self.w_q)
        init.xavier_normal_(self.w_k)
        init.xavier_normal_(self.w_v)
        init.xavier_normal_(self.proj.weight)

    def forward(self, x_q, x_kv=None, p_q=None, p_kv=None):
        """
        :param x_q: query sequence (b, n_q, d)
        :param x_kv: key-value sequence (b, n_kv, d); defaults to x_q (self-attention)
        :param p_q: positional embeddings for x_q (1, n_q, d)
        :param p_kv: positional embeddings for x_kv (1, n_kv, d)
        """
        if x_kv is None:
            x_kv = x_q

        # batch size and channel count must agree; sequence lengths may differ
        assert (x_q.size(0), x_q.size(2)) == (x_kv.size(0), x_kv.size(2))
        b, n_q, d = x_q.size()
        b, n_kv, d = x_kv.size()

        if p_q is None:
            p_q = x_q.new_zeros(1, n_q, d)
        if p_kv is None:
            p_kv = x_kv.new_zeros(1, n_kv, d)

        residual = x_q

        # as in the text, positional encodings go into queries and keys, not values
        q_in = (x_q + p_q).reshape(1, b * n_q, d).expand(self.M, b * n_q, d)
        k_in = (x_kv + p_kv).reshape(1, b * n_kv, d).expand(self.M, b * n_kv, d)
        v_in = x_kv.reshape(1, b * n_kv, d).expand(self.M, b * n_kv, d)

        q = torch.bmm(q_in, self.w_q).view(self.M, b, n_q, self.d_head)  # (M, b, n_q, d_head)
        k = torch.bmm(k_in, self.w_k).view(self.M, b, n_kv, self.d_head)  # (M, b, n_kv, d_head)
        v = torch.bmm(v_in, self.w_v).view(self.M, b, n_kv, self.d_head)  # (M, b, n_kv, d_head)

        out = self.attention(q, k, v).view(self.M * b, n_q, self.d_head)  # (M * b, n_q, d_head)
        # split into M per-head chunks (b, n_q, d_head) and concatenate them on the channel axis
        out = torch.cat(torch.split(out, b, dim=0), dim=-1)  # (b, n_q, M * d_head)
        out = self.layer_norm(residual + self.dropout(self.proj(out)))  # (b, n_q, d)

        return out

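class FFN(nn.Module):
    """
    Feed-forward network (FFN) layers: two position-wise linear layers (equivalent to
    1x1 convolutions) with a ReLU in between, followed by dropout, a residual connection
    and layer normalization
    """

    def __init__(self, d, dropout=0.1):
        super(FFN, self).__init__()
        self.ffn_layers = nn.Sequential(nn.Linear(d, d), nn.ReLU(True), nn.Linear(d, d))
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d)

    def forward(self, x):
        return self.layer_norm(x + self.dropout(self.ffn_layers(x)))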

In [4]:
d = 10
mha = MultiHeadAttention(2, d, 10)
x_q = torch.rand(2, 40, d)
x_kv = torch.rand(2, 5, d)
mha(x_q, x_kv).shape

# one_encoder_block = nn.Sequential()

torch.Size([2, 40, 10])
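The line `torch.cat(torch.split(out, b, dim=0), dim=-1)` in `MultiHeadAttention.forward` implements the $[;]$ concatenation of heads. A small sketch (with arbitrary sizes) confirming that it gives the same tensor as a single `permute` + `reshape`, which some readers may find easier to follow:

```python
import torch

M, b, n_q, d_head = 3, 2, 5, 4
heads = torch.randn(M, b, n_q, d_head)    # per-head outputs, as returned by the attention
flat = heads.reshape(M * b, n_q, d_head)  # layout used right before the split

concat_split = torch.cat(torch.split(flat, b, dim=0), dim=-1)        # (b, n_q, M * d_head)
concat_perm = heads.permute(1, 2, 0, 3).reshape(b, n_q, M * d_head)  # same layout in one step
print(torch.equal(concat_split, concat_perm))  # True
```

In both versions head `m` occupies channels `m * d_head` to `(m + 1) * d_head - 1` of the result.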

# Non-local neural networks

https://github.com/AlexHex7/Non-local_pytorch/blob/master/Non-Local_pytorch_0.4.1_to_1.1.0/lib/non_local_embedded_gaussian.py

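Compare the multi-head attention code above with the code below.

Conceptually, a non-local block is single-head attention, except that it uses _batch normalization_ instead of _layer normalization_, and the output normalization is applied _before_ the skip connection, whereas in the code above it is applied _after_. In addition, the query-key dot products are not scaled by $1/\sqrt{d^{\prime}}$, and no positional encodings are added.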
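To make the correspondence explicit, here is a hypothetical compact rewrite (`NonLocal2DAsAttention` is an illustrative name, 2D only, without sub-sampling) that names `theta`, `phi` and `g` as the query, key and value projections. As in the repository code, the batch-norm scale and shift are zero-initialized, so the block is an identity mapping at initialization.

```python
import torch
import torch.nn as nn


class NonLocal2DAsAttention(nn.Module):
    """Single-head attention over the h*w positions of a feature map (sketch)."""

    def __init__(self, c, c_inter):
        super().__init__()
        self.theta = nn.Conv2d(c, c_inter, kernel_size=1)  # query projection
        self.phi = nn.Conv2d(c, c_inter, kernel_size=1)    # key projection
        self.g = nn.Conv2d(c, c_inter, kernel_size=1)      # value projection
        self.W = nn.Conv2d(c_inter, c, kernel_size=1)      # output projection (L in the text)
        self.bn = nn.BatchNorm2d(c)
        nn.init.zeros_(self.bn.weight)  # zero-initialized BN: the block starts as identity
        nn.init.zeros_(self.bn.bias)

    def forward(self, x):
        b, c, h, w = x.shape
        q = self.theta(x).flatten(2).transpose(1, 2)  # (b, h*w, c_inter)
        k = self.phi(x).flatten(2)                    # (b, c_inter, h*w)
        v = self.g(x).flatten(2).transpose(1, 2)      # (b, h*w, c_inter)
        alpha = torch.softmax(q @ k, dim=-1)          # no 1/sqrt(d') scaling
        y = (alpha @ v).transpose(1, 2).reshape(b, -1, h, w)
        # normalization BEFORE the skip connection (the MHA above normalizes after it)
        return x + self.bn(self.W(y))


x = torch.randn(2, 3, 5, 5)
block = NonLocal2DAsAttention(3, 1)
print(torch.allclose(block(x), x))  # True at initialization
```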

In [6]:
class _NonLocalBlockND(nn.Module):
    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
        super(_NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
            bn = nn.BatchNorm1d

        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)

        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)
        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)

        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x):
        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=-1)

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z


class NONLocalBlock1D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock1D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=1, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NONLocalBlock2D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock2D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=2, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NONLocalBlock3D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock3D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=3, sub_sample=sub_sample,
                                              bn_layer=bn_layer)

In [7]:
for (sub_sample, bn_layer) in [(True, True), (False, False), (True, False), (False, True)]:
    img = torch.zeros(2, 3, 20)
    net = NONLocalBlock1D(3, sub_sample=sub_sample, bn_layer=bn_layer)
    out = net(img)
    print(out.size())

    img = torch.zeros(2, 3, 20, 20)
    net = NONLocalBlock2D(3, sub_sample=sub_sample, bn_layer=bn_layer)
    out = net(img)
    print(out.size())

    img = torch.randn(2, 3, 8, 20, 20)
    net = NONLocalBlock3D(3, sub_sample=sub_sample, bn_layer=bn_layer)
    out = net(img)
    print(out.size())

torch.Size([2, 3, 20])
torch.Size([2, 3, 20, 20])
torch.Size([2, 3, 8, 20, 20])
torch.Size([2, 3, 20])
torch.Size([2, 3, 20, 20])
torch.Size([2, 3, 8, 20, 20])
torch.Size([2, 3, 20])
torch.Size([2, 3, 20, 20])
torch.Size([2, 3, 8, 20, 20])
torch.Size([2, 3, 20])
torch.Size([2, 3, 20, 20])
torch.Size([2, 3, 8, 20, 20])
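All outputs above keep the input size, even with `sub_sample=True`, because the max-pool is appended only to `phi` (keys) and `g` (values), never to `theta` (queries). A minimal sketch mirroring the 2D branch shows the effect on the attention matrix: there are still $N = 20 \cdot 20$ queries, but only $N^{\prime} = 10 \cdot 10$ key-value positions.

```python
import torch
import torch.nn as nn

# mirrors the 2D branch of _NonLocalBlockND with sub_sample=True
c, c_inter = 3, 1
theta = nn.Conv2d(c, c_inter, kernel_size=1)
phi = nn.Sequential(nn.Conv2d(c, c_inter, kernel_size=1), nn.MaxPool2d(kernel_size=(2, 2)))

x = torch.randn(2, c, 20, 20)
theta_x = theta(x).view(2, c_inter, -1).permute(0, 2, 1)  # (b, N, c_inter), N = 400 queries
phi_x = phi(x).view(2, c_inter, -1)                        # (b, c_inter, N'), N' = 100 keys
f = torch.matmul(theta_x, phi_x)
print(f.shape)  # torch.Size([2, 400, 100])
```

The 3D branch pools with kernel `(1, 2, 2)`, so it likewise reduces only the spatial key-value positions by a factor of 4 and leaves the temporal axis intact.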


# Other types of attention besides ScaledDotProductAttention?

The code below is inspired by https://arxiv.org/abs/1904.05873 (An Empirical Study of Spatial Attention Mechanisms in Deep Networks).
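A tiny sketch of what sets the key-content-only term apart from scaled dot-product attention: the score of key $j$ is $u^{T} K_{j}$, which does not involve the query at all, so every query receives the same weighted average of the values. Here `u` is random; in the module below it is a learnable parameter per head.

```python
import torch

torch.manual_seed(0)
n, d_head = 6, 4
k = torch.randn(n, d_head)  # key embeddings K_j
v = torch.randn(n, d_head)  # value embeddings V_j
u = torch.randn(d_head)     # stands in for the learnable per-head vector

# key-content-only score of key j is u^T K_j, whatever the query i is
scores = k @ u                                       # (n,)
alpha_k = torch.softmax(scores, dim=0).expand(n, n)  # (n_q, n_kv): all rows identical
out_k = alpha_k @ v                                  # (n_q, d_head)
print(torch.allclose(out_k, out_k[0].expand_as(out_k)))  # True: every query gets the same output
```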

In [8]:
class KeyContentAttention(nn.Module):
    """
    Key-content-only attention: the score of key j is u^T K_j (one learnable vector u
    per head), so every query receives the same attention weights
    """

    def __init__(self, M, d_head):
        super(KeyContentAttention, self).__init__()
        self.u = nn.Parameter(torch.FloatTensor(M, 1, d_head), requires_grad=True)
        self.init_weights()

    def init_weights(self):
        init.xavier_normal_(self.u)

    def forward(self, k, v, n_q):
        """
        :param k: key (M, b, n_kv, d_head)
        :param v: value (M, b, n_kv, d_head)
        :param n_q: number of queries; all of them share the same weights
        """
        M, b, n_kv, _ = v.size()
        attn = torch.matmul(self.u.unsqueeze(1), k.transpose(2, 3))  # (M, b, 1, n_kv)
        attn = torch.softmax(attn, dim=3)  # (M, b, 1, n_kv)
        attn = attn.expand(M, b, n_q, n_kv)  # identical rows for every query
        output = torch.matmul(attn, v)  # (M, b, n_q, d_head)
        return output

In [9]:
class MultiHeadKeyContentOnlyAttention(nn.Module):

    def __init__(self, M, d, d_head):
        """
        :param M: number of parallel attention heads
        :param d: input hidden size
        :param d_head: number of hidden units in one head (d' in text above)
        """
        super(MultiHeadKeyContentOnlyAttention, self).__init__()
        self.M = M
        self.d_head = d_head

        self.w_k = nn.Parameter(torch.FloatTensor(M, d, d_head), requires_grad=True)
        self.w_v = nn.Parameter(torch.FloatTensor(M, d, d_head), requires_grad=True)
        self.attention = KeyContentAttention(M, d_head)
        self.proj = nn.Linear(M * d_head, d, bias=False)
        self.dropout = nn.Dropout(0.1)
        self.layer_norm = nn.LayerNorm(d)
        self.init_weights()

    def init_weights(self):
        init.xavier_normal_(self.w_k)
        init.xavier_normal_(self.w_v)
        init.xavier_normal_(self.proj.weight)

    def forward(self, x_q):
        """
        :param x_q: input sequence (b, n_q, d); keys and values are computed from it as well
        """
        residual = x_q

        x_kv = x_q  # self-attention: the key-value sequence is the query sequence
        b, n_kv, d = x_kv.size()
        n_q = n_kv

        x_kv = x_kv.reshape(1, b * n_kv, d).expand(self.M, b * n_kv, d)
        k = torch.bmm(x_kv, self.w_k).view(self.M, b, n_kv, self.d_head)  # (M, b, n_kv, d_head)
        v = torch.bmm(x_kv, self.w_v).view(self.M, b, n_kv, self.d_head)  # (M, b, n_kv, d_head)

        # no query embedding is needed: the weights depend only on the keys
        out = self.attention(k, v, n_q).view(self.M * b, n_q, self.d_head)  # (M * b, n_q, d_head)
        out = torch.cat(torch.split(out, b, dim=0), dim=-1)  # (b, n_q, M * d_head)
        out = self.layer_norm(residual + self.dropout(self.proj(out)))  # (b, n_q, d)

        return out

d = 10
mhka = MultiHeadKeyContentOnlyAttention(2, d, 10)
x_q = torch.rand(2, 40, d)
mhka(x_q).shape

# one_encoder_block = nn.Sequential()

torch.Size([2, 40, 10])