### Attention block

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch as th
import torch.nn as nn
from util import get_torch_size_string
# Print numpy arrays and torch tensors with 3 decimal places.
np.set_printoptions(precision=3)
th.set_printoptions(precision=3)
# Render matplotlib figures inline, at high (retina) resolution.
%matplotlib inline
%config InlineBackend.figure_format='retina'
# Record the PyTorch version the outputs below were produced with.
print ("PyTorch version:[%s]."%(th.__version__))

PyTorch version:[2.0.1].


### QKV Attention (Legacy)

In [2]:
def zero_module(module):
    """
    Set every parameter of ``module`` to zero, in place.

    :param module: the nn.Module whose parameters are zeroed.
    :return: the very same module object, so the call can wrap a constructor.
    """
    # In-place writes on parameters must not be recorded by autograd.
    with th.no_grad():
        for param in module.parameters():
            param.zero_()
    return module

class GroupNorm32(nn.GroupNorm):
    """
    GroupNorm that always runs the normalization in float32 and casts the
    result back to the dtype of its input.
    """
    def forward(self, x):
        in_dtype = x.dtype
        normed = super().forward(x.float())
        return normed.to(in_dtype)
    
def normalization(n_channels,n_groups=1):
    """
    Make a standard normalization layer.

    :param n_channels: number of input channels.
    :param n_groups: number of groups the channels are split into; must divide
        n_channels (torch.nn.GroupNorm raises otherwise). With n_groups=1 each
        sample is normalized over all channels and positions together, i.e.
        the same statistics as LayerNorm over (C, *spatial); the learnable
        scale/shift, however, is per-channel rather than per-element.
    :return: a GroupNorm32 module (GroupNorm computed in float32, output cast
        back to the input dtype).
    """
    return GroupNorm32(num_groups=n_groups,num_channels=n_channels)

print ("Ready.")

Ready.


In [3]:
import math
class QKVAttentionLegacy(nn.Module):
    """
    A module which performs multi-head QKV attention.

    Matches the legacy QKVAttention layout (the channel axis is split into
    heads first, then into q/k/v within each head) together with the
    input/output head reshaping.
    """
    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        batch, width, n_tokens = qkv.shape
        heads = self.n_heads
        assert width % (3 * heads) == 0
        head_dim = width // (3 * heads)

        # Fold heads into the batch axis, then cut each head's channels
        # into equal-sized q, k and v slices.
        per_head = qkv.reshape(batch * heads, 3 * head_dim, n_tokens)
        query, key, value = per_head.split(head_dim, dim=1)

        # q and k are each scaled by head_dim**-0.25 (so their product carries
        # head_dim**-0.5); this is more stable in float16 than dividing the
        # logits afterwards.
        qk_scale = 1 / math.sqrt(math.sqrt(head_dim))
        logits = th.einsum("bct,bcs->bts", query * qk_scale, key * qk_scale)

        # Softmax in float32 for stability, then back to the working dtype.
        attn = th.softmax(logits.float(), dim=-1).type(logits.dtype)
        mixed = th.einsum("bts,bcs->bct", attn, value)

        # Unfold heads back into the channel axis.
        return mixed.reshape(batch, -1, n_tokens)
    
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    The input [B x C x *spatial] is flattened into T = prod(spatial) tokens of
    C channels. The tokens are group-normalized and projected to Q/K/V with a
    1x1 conv. They then go through multi-head QKV attention and are projected
    back with a zero-initialized 1x1 conv. The result is added to the
    (un-normalized) input as a residual.
    """
    def __init__(
            self,
            n_channels         = 1,
            n_heads            = 1,
            n_groups           = 32,
    ):
        """
        :param n_channels: number of input (and output) channels C.
        :param n_heads: number of attention heads; must divide n_channels.
        :param n_groups: number of GroupNorm groups; must divide n_channels
            (torch.nn.GroupNorm raises otherwise). Defaults to 32, so small
            channel counts (including the default n_channels=1) need a
            compatible n_groups.
        """
        super().__init__()
        self.n_channels         = n_channels
        self.n_heads            = n_heads
        self.n_groups           = n_groups
        assert (
            n_channels % n_heads == 0
        ), f"n_channels:[{n_channels}] should be divisible by n_heads:[{n_heads}]."

        # Normalize
        self.norm = normalization(n_channels=n_channels,n_groups=n_groups)

        # Triple the channels: one C-wide slab each for Q, K and V
        self.qkv = nn.Conv1d(
            in_channels  = self.n_channels,
            out_channels = self.n_channels*3,
            kernel_size  = 1
        )

        # QKV Attention
        self.attention = QKVAttentionLegacy(
            n_heads = self.n_heads
        )

        # Output projection, zero-initialized so the attention branch
        # contributes nothing at initialization.
        self.proj_out = zero_module(
            nn.Conv1d(
                in_channels  = self.n_channels,
                out_channels = self.n_channels,
                kernel_size  = 1
            )
        )

    def forward(self, x):
        """
        :param x: [B x C x *spatial] tensor (e.g. [B x C x W x H])
        :return out: tensor with the same shape as x
        """
        b, c, *spatial = x.shape
        x   = x.reshape(b, c, -1)              # [B x C x T], T = prod(spatial)
        # Normalize only the attention branch's input; keep x itself intact
        # for the residual connection below.
        qkv = self.qkv(self.norm(x))           # [B x 3C x T]
        # QKV attention
        h   = self.attention(qkv)              # [B x C x T]
        h   = self.proj_out(h)                 # [B x C x T]
        # Residual on the un-normalized input: with the zero-initialized
        # proj_out the block is an identity map at initialization.
        out = (x + h).reshape(b, c, *spatial)  # [B x C x *spatial]
        return out

print ("Ready.")

Ready.


### Let's see how `AttentionBlock` works
- First, we assume that an input tensor has a shape of [B x C x W x H].
- This can be thought of as a total of WH tokens, each having C dimensions.
- Multi-head attention (MHA) first partitions the channels across the heads, then runs the QKV attention process within each head, and finally merges the per-head results.
- Note that the number of channels must be divisible by the number of heads.

In [9]:
th.manual_seed(0)  # fixed seed: layer weights and the random input are reproducible on re-run
layer = AttentionBlock(n_channels=128,n_heads=4)
x = th.randn(16,128,28,28)  # [B x C x W x H]
out = layer(x)
assert out.shape == x.shape, "AttentionBlock should preserve the input shape"
print ("input shape:[%s] output shape:[%s]"%
       (get_torch_size_string(x),get_torch_size_string(out)))

input shape:[16x128x28x28] output shape:[16x128x28x28]
