In [32]:
import torch
from torch import nn
from einops import rearrange, repeat

# SelfAttention

[![](https://mermaid.ink/img/pako:eNq1UjtPwzAQ_ivWTY2UVmLNwFQ2xNJuGEWmORqL2A7OGaja_nfOdltIeUzFQy6--x72J29h5RqECtZe9a1YzqUVvLTtA00kpCru70rRoS1Fo02t7YOEQkyn1zsJ1s5utUXlJ3mUIS5QIWEnXgL6DaukOlZhSJS5pNszRi_-_rvTq-oCskqqf7jFs7CeYEHyyg69G5D53K4pH3W2TPQDM-mM6DUdBR6NidaKCC1zU_20HvFy3D_w8jhTD2O-6MI9kVHvk-lVcULWb6jXLZ2M8nb4xfAr5czXpbjdedijmHKM33hxCCUY9Ebphp_oNnY4yhYNx1jxbxcdJUi7Z2DoG0V402hyHiryAUtQgdxiY1fHfcbMteLnbnJz_wHsB_sC)](https://mermaid.live/edit#pako:eNq1UjtPwzAQ_ivWTY2UVmLNwFQ2xNJuGEWmORqL2A7OGaja_nfOdltIeUzFQy6--x72J29h5RqECtZe9a1YzqUVvLTtA00kpCru70rRoS1Fo02t7YOEQkyn1zsJ1s5utUXlJ3mUIS5QIWEnXgL6DaukOlZhSJS5pNszRi_-_rvTq-oCskqqf7jFs7CeYEHyyg69G5D53K4pH3W2TPQDM-mM6DUdBR6NidaKCC1zU_20HvFy3D_w8jhTD2O-6MI9kVHvk-lVcULWb6jXLZ2M8nb4xfAr5czXpbjdedijmHKM33hxCCUY9Ebphp_oNnY4yhYNx1jxbxcdJUi7Z2DoG0V402hyHiryAUtQgdxiY1fHfcbMteLnbnJz_wHsB_sC)

In [9]:
class SelfAttention(nn.Module):
    """Single-head dot-product self-attention.

    Queries, keys and values are three separate linear projections of the
    same input. The attention scores are the raw (unscaled) query/key dot
    products, normalised by a softmax over the last axis.

    args:
        in_dim: feature size of the input
        out_dim: feature size of the output
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.to_q = nn.Linear(in_dim, out_dim)
        self.to_k = nn.Linear(in_dim, out_dim)
        self.to_v = nn.Linear(in_dim, out_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        args:
            x: tensor (b_s, len_seq, in_dim)

        return:
            out: tensor (b_s, len_seq, out_dim)
        """
        queries, keys, values = self.to_q(x), self.to_k(x), self.to_v(x)
        # (b_s, len_seq, len_seq): one row of weights per query position
        weights = self.softmax(torch.bmm(queries, keys.transpose(1, 2)))
        return torch.bmm(weights, values)

def test_attention():
    """Shape check: SelfAttention maps (b_s, sequence_len, in_dim) to (b_s, sequence_len, out_dim)."""
    in_dim = 128
    out_dim = 128
    b_s = 8
    sequence_len = 10
    x = torch.randn(b_s, sequence_len, in_dim)
    block = SelfAttention(in_dim, out_dim)
    # One forward pass is enough. No graph-plotting helper is called here
    # because none is imported by this notebook.
    out = block(x)
    assert out.shape == (b_s, sequence_len, out_dim)
test_attention()

# Naive MultiheadAttention

In [20]:
class MultiheadSelfAttention(nn.Module):
    """Naive multi-head self-attention: one independent SelfAttention per head.

    Every head projects the full input down to `in_dim // num_head` features;
    the head outputs are concatenated along the feature axis (giving `in_dim`
    features again) and mixed by a final linear projection.

    args:
        num_head: number of attention heads
        in_dim: feature size of the input (and of the output); must be
            divisible by num_head
    """

    def __init__(self, num_head, in_dim):
        super().__init__()
        # Integer check instead of comparing `//` with float `/`.
        assert in_dim % num_head == 0, "in_dim must be divisible by num_head"
        self.proj = nn.Linear(in_dim, in_dim)
        self.num_head = num_head

        inter_dim = in_dim // num_head
        # Each head must output `inter_dim` features so that the concatenation
        # matches the input size of `self.proj`; SelfAttention(in_dim, out_dim)
        # does exactly that. ModuleList (not Sequential) is used because the
        # heads run side by side on the same input, never chained.
        self.layers = nn.ModuleList(
            [SelfAttention(in_dim, inter_dim) for _ in range(num_head)]
        )

        # Kept as an attribute for reference; it is not applied in forward().
        self.scale = inter_dim ** -0.5

    def forward(self, x):
        """
        args:
            x: tensor (b_s, len_seq, in_dim)

        return:
            out: tensor (b_s, len_seq, in_dim)
        """
        head_outs = [head(x) for head in self.layers]
        merged = torch.cat(head_outs, dim=-1)
        return self.proj(merged)

def test_multiheadattention():
    """Shape check: MultiheadSelfAttention returns a tensor shaped like its input."""
    b_s, sequence_len, in_dim = 8, 10, 128
    num_head = 8
    expected_shape = (b_s, sequence_len, in_dim)
    x = torch.randn(*expected_shape)
    block = MultiheadSelfAttention(num_head, in_dim)
    out = block(x)
    assert out.shape == expected_shape
test_multiheadattention()

# Parallel MultiheadAttention

In [21]:
class Attention(nn.Module):
    """Multi-head self-attention with all heads computed in one batched pass.

    A single linear layer produces queries, keys and values for every head at
    once; the result is split per head, attended, merged back and projected
    to the input size.

    args:
        dim: input dim
        num_head: #head in multihead attention
        head_dim: dim of each head
        dropout: dropout probability applied to the attention weights, default 0.
    """

    def __init__(self, dim, num_head=8, head_dim=64, dropout=0.):
        super().__init__()
        self.num_head = num_head
        self.inner_dim = head_dim * num_head
        self.scale = head_dim ** -0.5

        # one projection yields q, k and v stacked along the last axis
        self.to_qkv = nn.Linear(dim, 3 * self.inner_dim)
        self.softmax = nn.Softmax(dim=-1)
        self.proj = nn.Linear(self.inner_dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        args:
            x: tensor (b_s, len_seq, dim)

        return:
            out: tensor (b_s, len_seq, dim)
        """
        # three equal chunks along the feature axis, each moved to (b, h, n, d)
        q, k, v = (
            rearrange(part, "b n (h d) ->  b h n d", h=self.num_head)
            for part in self.to_qkv(x).chunk(3, dim=-1)
        )

        # scaled dot-product scores -> softmax -> dropout
        scores = self.scale * torch.matmul(q, k.transpose(-1, -2))
        weights = self.dropout(self.softmax(scores))

        # weighted sum of values, heads folded back into the feature axis
        merged = rearrange(torch.matmul(weights, v), "b h n d -> b n (h d)")
        return self.proj(merged)

def test_attention():
    """Shape check: Attention with default heads returns a tensor shaped like its input."""
    b_s, sequence_len, in_dim = 8, 10, 128
    expected_shape = (b_s, sequence_len, in_dim)
    x = torch.randn(*expected_shape)
    block = Attention(in_dim)
    out = block(x)
    assert out.shape == expected_shape
test_attention()

In [3]:
class Norm(nn.Module):
    """Thin wrapper around nn.LayerNorm over the last axis.

    args:
        dim: size of the last (normalised) axis
    """

    def __init__(self, dim):
        super().__init__()
        self.layer = nn.LayerNorm(dim)

    def forward(self, x):
        """Return `x` layer-normalised; the shape is unchanged."""
        normed = self.layer(x)
        return normed

def test_norm():
    """Shape check: Norm returns a tensor shaped like its input."""
    b_s, sequence_len, dim = 8, 10, 128
    expected_shape = (b_s, sequence_len, dim)
    x = torch.randn(*expected_shape)
    block = Norm(dim)
    out = block(x)
    assert out.shape == expected_shape
test_norm()

In [None]:
class InputEmbed(nn.Module):
    """Turn an image batch into a sequence of patch embeddings plus a class token.

    The image is cut into non-overlapping, contiguous patch_h x patch_w tiles
    (row-major over the tile grid). Each tile is flattened, linearly embedded,
    a learnable class token is appended at the END of the sequence, and a
    learnable position embedding is added.

    args:
        img_h, img_w: h/w of input image
        patch_h, patch_w: h/w of patch
        channel: channel of image
        dim: dim of embedding
    """

    def __init__(self, img_h, img_w, patch_h, patch_w, channel, dim):
        super().__init__()
        assert img_h % patch_h == 0, "img_h must be divisible by patch_h"
        assert img_w % patch_w == 0, "img_w must be divisible by patch_w"
        self.p_h = patch_h
        self.p_w = patch_w
        patch_dim = patch_h * patch_w * channel
        self.patch_embed = nn.Linear(patch_dim, dim)

        num_patch = (img_h // patch_h) * (img_w // patch_w)

        # position embedding is just learnable params (one slot per patch + class token)
        self.pos_embed = nn.Parameter(torch.randn(1, num_patch + 1, dim))

        # class token for feature representation
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

    def forward(self, img):
        """
        args:
            img: tensor (b, c, h, w)
        return:
            out: tensor (b, num_patches + 1, dim); the class token is the last entry
        """
        b_s, channel, img_h, img_w = img.shape
        n_h, n_w = img_h // self.p_h, img_w // self.p_w

        # Split into contiguous tiles. In einops notation this is
        #   "b c (n_h p_h) (n_w p_w) -> b (n_h n_w) (p_h p_w c)"
        # The tile index must be the OUTER factor of each spatial axis;
        # writing (p_h n_h) instead would gather pixels n_h rows apart.
        patches = img.reshape(b_s, channel, n_h, self.p_h, n_w, self.p_w)
        patches = patches.permute(0, 2, 4, 3, 5, 1)
        patches = patches.reshape(b_s, n_h * n_w, self.p_h * self.p_w * channel)
        patches = self.patch_embed(patches)

        # expand() broadcasts the single class token over the batch without copying
        cls_token = self.cls_token.expand(b_s, -1, -1)
        features = torch.cat((patches, cls_token), dim=1)

        # pos_embed has a leading axis of size 1 and broadcasts over the batch
        return features + self.pos_embed


def test_embed():
    """Shape check: InputEmbed yields one embedding per patch plus one class token."""
    img_h, img_w = 64, 64
    patch_h, patch_w = 8, 8
    channel = 3
    dim = 128
    b_s = 8
    # Integer division keeps the expected shape a tuple of ints
    # (true division would give 64.0 and compare equal only by accident).
    num_patch = (img_h // patch_h) * (img_w // patch_w)
    x = torch.randn(b_s, channel, img_h, img_w)
    block = InputEmbed(img_h, img_w, patch_h, patch_w, channel, dim)
    res = block(x)
    assert res.shape == (b_s, num_patch + 1, dim)
test_embed()

torch.Size([8, 65, 128])


In [5]:
class MLP(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    args:
        in_dim: feature size of the input and of the output
        hidden_dim: feature size between the two linear layers
        drop_out: dropout probability used after the activation and after
            the second linear layer, default 0.
    """

    def __init__(self, in_dim, hidden_dim, drop_out=0.):
        super().__init__()
        expand = nn.Linear(in_dim, hidden_dim)
        contract = nn.Linear(hidden_dim, in_dim)
        stages = [expand, nn.GELU(), nn.Dropout(drop_out), contract, nn.Dropout(drop_out)]
        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the feed-forward stack; the last axis keeps size `in_dim`."""
        out = self.layer(x)
        return out

def test_MLP():
    """Shape check: MLP returns a tensor shaped like its input."""
    b_s, sequence_len, dim = 8, 10, 128
    expected_shape = (b_s, sequence_len, dim)
    x = torch.randn(*expected_shape)
    block = MLP(dim, 256)
    out = block(x)
    assert out.shape == expected_shape
test_MLP()

In [None]:
class EncoderBlock(nn.Module):
    """Pre-norm transformer encoder block.

    Two sub-layers, each wrapped in a residual connection and preceded by a
    Norm: multi-head self-attention (the batched `Attention` class), then the
    feed-forward `MLP`.

    args:
        dim: feature size of every token (input and output)
        num_head: number of attention heads; must divide dim
        hidden_dim: hidden size of the MLP
    """

    def __init__(self, dim, num_head, hidden_dim):
        super().__init__()
        if dim % num_head != 0:
            raise ValueError(f"dim ({dim}) must be divisible by num_head ({num_head})")
        self.norm1 = Norm(dim)
        self.norm2 = Norm(dim)
        self.fcs = MLP(dim, hidden_dim)
        # Split `dim` evenly across the heads so the attention's inner width equals dim.
        self.attention = Attention(dim, num_head=num_head, head_dim=dim // num_head)

    def forward(self, x):
        """
        args:
            x: tensor (b_s, len_seq, dim)

        return:
            out: tensor (b_s, len_seq, dim)
        """
        # attention sub-layer with residual
        x = x + self.attention(self.norm1(x))
        # feed-forward sub-layer with residual
        return x + self.fcs(self.norm2(x))


def test_encoder_block():
    """Shape check: EncoderBlock returns a tensor shaped like its input."""
    dim = 128
    num_head = 8
    hidden_dim = 256
    b_s = 8
    sequence_len = 10
    x = torch.randn(b_s, sequence_len, dim)
    block = EncoderBlock(dim, num_head, hidden_dim)
    assert block(x).shape == (b_s, sequence_len, dim)
test_encoder_block()

In [None]:
class Encoder(nn.Module):
    """Placeholder for the full encoder; the body is not implemented yet."""

    def __init__(self):
        # nn.Module keeps its parameter/submodule/hook registries on the
        # instance; they only exist after the base initialiser has run.
        super().__init__()

    def forward(self):
        # TODO: not implemented; currently returns None.
        pass

In [2]:
# Scratch check: multiplying two 3-D tensors multiplies the trailing matrices
# per batch item, (8, 3, 4) x (8, 4, 3) -> (8, 3, 3).
a = torch.randn(8, 3, 4)
b = torch.randn(8, 4, 3)
torch.matmul(a, b).shape

torch.Size([8, 3, 3])