This notebook is a test implementation of the attention mechanism used in Transformers.

In [40]:
import torch
from torch import nn
from torch.nn import functional as F

In [5]:
# Constant and Parameter
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU so
# the notebook still runs end to end on CPU-only machines.
device_type = 'cuda' if torch.cuda.is_available() else 'cpu'

Basic structure of Attention:

    B: batch size
    N: sequence length
    C: (raw) embedding dimension
    d: (inner) embedding dimension
    nh: number of heads; note C = d * nh
    alpha: scaling factor, typically 1 / sqrt(d)
    mQ, mK, mV, mO:         matrix (C, C),          embedding project matrix
    x:                      matrix (B, N, C),       input tensor
                            matrix (B, nh, N, d)        reshape
    M:                      matrix (1, 1, N, N),    mask on attention scope: 1 for attended tokens, 0 for un-attended tokens (their scores are set to '-INF')
    D_attn, D_out:          dropout layer

Forward Pass:

    Q := x @ mQ, K := x @ mK, V := x @ mV:            
                            matrix (B, nh, N, d)        reshape
    S := alpha * Q @ K^T:   matrix (B, nh, N, N)    
    P := softmax(S * M + '-INF' * (1-M)):     
                            matrix (B, nh, N, N)
    O := D_attn(P) @ V:     matrix (B, nh, N, d)
    y := D_out(O' @ mO):    matrix (B, N, C),       output tensor
                            where O' is O reverse-reshaped from (B, nh, N, d) to (B, N, C)


    

In [68]:
# Basic module class
class Attention(nn.Module):
    '''
    Causal multi-head self-attention layer.

    The input (B, N, C) is projected to queries, keys and values, split into
    `n_head` heads of size d = C // n_head, combined by scaled dot-product
    attention under a lower-triangular mask, merged back to (B, N, C) and
    passed through an output projection.

    Attributes,
        MAX_SEQ_LENGTH  int     largest sequence length N covered by the pre-built mask
    '''
    MAX_SEQ_LENGTH = 1024

    def __init__(self, n_embd: int,
                       n_head: int = 1,
                       dropout_ratio: float = 0,
                       **kwargs):
        '''
        Input,
            n_embd          int     embedding dimension C
            n_head          int     number of heads nh, MUST divide n_embd
            dropout_ratio   float   dropout probability shared by D_attn and D_out
            **kwargs                accepted for call-site compatibility, unused
        '''
        super().__init__()
        assert n_embd % n_head == 0, 'Embedding size MUST be multiple of N head'
        self.n_embd = n_embd
        self.n_head = n_head
        # per-head (inner) embedding dimension
        self.d = self.n_embd // self.n_head
        # Q / K / V / output projections, each C -> C
        self.mQ = nn.Linear(self.n_embd, self.n_embd)
        self.mK = nn.Linear(self.n_embd, self.n_embd)
        self.mV = nn.Linear(self.n_embd, self.n_embd)
        self.mO = nn.Linear(self.n_embd, self.n_embd)
        # scaling factor 1 / sqrt(d) applied to the attention scores
        self.alpha = 1 / (self.d ** .5)
        # Lower-triangular mask of ones, shape (1, 1, MAX, MAX): entry (i, j) is 1
        # when j <= i. Registered as a buffer so it follows the module across
        # devices without being a trainable parameter.
        self.register_buffer(
            'mask',
            torch.tril(
                torch.ones(self.MAX_SEQ_LENGTH, self.MAX_SEQ_LENGTH),
            ).view(1, 1, self.MAX_SEQ_LENGTH, self.MAX_SEQ_LENGTH)
        )
        self.D_attn = nn.Dropout(dropout_ratio)
        self.D_out = nn.Dropout(dropout_ratio)

    def forward(self, x: torch.Tensor,
                      **kwargs) -> torch.Tensor:
        '''
        Input,
            x           tensor(B, N, C)     input, C == n_embd, N <= MAX_SEQ_LENGTH
        Output,
            y           tensor(B, N, C)     output
        '''
        B, N, C = x.size()
        assert C == self.n_embd, 'Embedding dimension MUST match'
        # Without this check an over-long sequence fails later inside masked_fill
        # with an opaque broadcasting error, because the mask slice stops at
        # MAX_SEQ_LENGTH.
        assert N <= self.MAX_SEQ_LENGTH, 'Sequence length MUST NOT exceed MAX_SEQ_LENGTH'

        Q, K, V = self.mQ(x), self.mK(x), self.mV(x)
        # reshape to (B, nh, N, d = C // nh)
        Q = Q.reshape(B, N, self.n_head, -1).transpose(1, 2)
        K = K.reshape(B, N, self.n_head, -1).transpose(1, 2)
        V = V.reshape(B, N, self.n_head, -1).transpose(1, 2)
        # scaled scores, (B, nh, N, N)
        S = self.alpha * Q @ K.transpose(-2, -1)
        # positions where the mask is 0 (j > i) get -inf so softmax gives them weight 0
        S = S.masked_fill(self.mask[:,:,:N,:N] == 0, float('-inf'))
        P = F.softmax(S, dim=-1)
        P = self.D_attn(P)
        O = P @ V
        # reshape back to (B, N, C)
        O = O.transpose(1, 2).reshape(B, N, C)
        y = self.mO(O)
        y = self.D_out(y)

        return y



In [70]:
# Smoke test: build a 4-head layer over 32-dim embeddings with dropout disabled.
attn = Attention(n_embd=32, n_head=4, dropout_ratio=0)

# Random batch laid out as (B, N, C) = (64, 100, 32).
x = torch.rand(64, 100, 32)

# Display the shape of the layer's output for this batch.
attn(x).shape

torch.Size([64, 100, 32])