### Attention block

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch as th
from module import AttentionBlock
from util import get_torch_size_string
np.set_printoptions(precision=3)
th.set_printoptions(precision=3)
%matplotlib inline
%config InlineBackend.figure_format='retina'
print ("PyTorch version:[%s]."%(th.__version__))

PyTorch version:[2.0.1].


### Let's see how `AttentionBlock` works
- First, we assume that an input tensor has a shape of [B x C x W x H].
- This can be thought of as having a total of WH tokens, with each token having C dimensions.
- The MHA operates by initially partitioning the channels, executing the qkv attention process, and then merging the results.
- Note that the number of channels should be divisible by the number of heads.

### `dims=2`
#### `x` has a shape of `[B x C x W x H]`

In [2]:
# Build an attention block with 128 channels and 4 heads, then feed it a random
# [B x C x W x H] batch. (n_groups is presumably the normalization group count;
# confirm against module.AttentionBlock.)
layer = AttentionBlock(n_channels=128, n_heads=4, n_groups=32)
x = th.randn(16, 128, 28, 28)
out, intermediate_output_dict = layer(x)

# One-line summary of the input and output shapes.
in_shape_str = get_torch_size_string(x)
out_shape_str = get_torch_size_string(out)
print("input shape:[%s] output shape:[%s]" % (in_shape_str, out_shape_str))

# Shape of every intermediate tensor returned by the block, in the order returned.
for name, tensor in intermediate_output_dict.items():
    shape_str = get_torch_size_string(tensor)
    print("[%10s]:[%15s]" % (name, shape_str))

input shape:[16x128x28x28] output shape:[16x128x28x28]
[         x]:[   16x128x28x28]
[     x_rsh]:[     16x128x784]
[     x_nzd]:[     16x128x784]
[       qkv]:[     16x384x784]
[     h_att]:[     16x128x784]
[    h_proj]:[     16x128x784]
[       out]:[   16x128x28x28]


### `dims=1`
#### `x` has a shape of `[B x C x L]`

In [3]:
# Same demonstration for sequence input: a small block (4 channels, 2 heads)
# applied to a random [B x C x L] batch.
layer = AttentionBlock(n_channels=4, n_heads=2, n_groups=1)
x = th.randn(16, 4, 100)
out, intermediate_output_dict = layer(x)

# One-line summary of the input and output shapes.
in_shape_str = get_torch_size_string(x)
out_shape_str = get_torch_size_string(out)
print("input shape:[%s] output shape:[%s]" % (in_shape_str, out_shape_str))

# Shape of every intermediate tensor returned by the block, in the order returned.
for name, tensor in intermediate_output_dict.items():
    shape_str = get_torch_size_string(tensor)
    print("[%10s]:[%15s]" % (name, shape_str))

input shape:[16x4x100] output shape:[16x4x100]
[         x]:[       16x4x100]
[     x_rsh]:[       16x4x100]
[     x_nzd]:[       16x4x100]
[       qkv]:[      16x12x100]
[     h_att]:[       16x4x100]
[    h_proj]:[       16x4x100]
[       out]:[       16x4x100]
