In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [1]:
def get_angles(pos, i, d_model):
    """Compute the sinusoidal positional-encoding angles pos / 10000^(2*(i//2)/d_model).

    Args:
        pos: Position indices (tensor or number); broadcast against ``i``.
        i: Channel indices (tensor or integer). Consecutive even/odd channels
            share one frequency because of the ``i // 2`` term.
        d_model: Dimension used to normalise the frequency exponent.

    Returns:
        Floating-point tensor with the broadcast shape of ``pos`` and ``i``.
    """
    # torch.float32 is a dtype object, not a callable cast, so the divisor is
    # converted with the builtin float() instead.
    channel = torch.as_tensor(i)
    exponent = (2 * (channel // 2)) / float(d_model)
    angle_rates = 1 / torch.pow(10000, exponent)
    return pos * angle_rates

In [2]:
def positional_encoding_1d(position, d_model):
    """Build a sinusoidal positional encoding for a 1-D sequence.

    Args:
        position: Number of positions to encode (0 .. position - 1).
        d_model: Number of encoding channels per position.

    Returns:
        Float tensor of shape (1, position, d_model); even channels hold the
        sine of the angle, odd channels hold the cosine.
    """
    positions = torch.arange(position).unsqueeze(1)   # (position, 1)
    channels = torch.arange(d_model).unsqueeze(0)     # (1, d_model)
    angles = get_angles(positions, channels, d_model)

    even = slice(0, None, 2)
    odd = slice(1, None, 2)
    # Overwrite the angle table in place: sin on even channels, cos on odd ones.
    angles[:, even] = angles[:, even].sin()
    angles[:, odd] = angles[:, odd].cos()

    # Add the leading batch axis before returning.
    return angles.unsqueeze(0).float()

In [3]:
def positional_encoding_2d(row, col, d_model):
    """Build a sinusoidal positional encoding for a row-major flattened 2-D grid.

    The first ``d_model // 2`` channels encode the row index and the remaining
    ``d_model // 2`` channels encode the column index. Each half is computed by
    ``get_angles`` with ``d_model // 2`` as its model dimension.

    Args:
        row: Number of grid rows.
        col: Number of grid columns.
        d_model: Total number of encoding channels; must be even.

    Returns:
        Float tensor of shape (1, row * col, d_model).

    Raises:
        AssertionError: If ``d_model`` is odd.
    """
    assert d_model % 2 == 0

    half = d_model // 2

    # torch has no module-level ``repeat``. For the row-major flattening:
    #   rows:    0,0,...,0,1,1,...  -> each row index repeated ``col`` times
    #   columns: 0,1,...,col-1,0,1,... -> the whole column range tiled ``row`` times
    row_pos = torch.arange(row).repeat_interleave(col).unsqueeze(1)  # (row*col, 1)
    col_pos = torch.arange(col).repeat(row).unsqueeze(1)             # (row*col, 1)

    channels = torch.arange(half).unsqueeze(0)  # (1, half)
    angle_rads_row = get_angles(row_pos, channels, half)
    angle_rads_col = get_angles(col_pos, channels, half)

    # sin on even channels, cos on odd channels, independently for each half.
    for angle_rads in (angle_rads_row, angle_rads_col):
        angle_rads[:, 0::2] = torch.sin(angle_rads[:, 0::2])
        angle_rads[:, 1::2] = torch.cos(angle_rads[:, 1::2])

    pos_encoding = torch.cat([angle_rads_row, angle_rads_col], dim=1).unsqueeze(0)
    return pos_encoding.float()

In [5]:
def create_padding_mask(seq):
    """Flag the entries of ``seq`` that are equal to 0.

    Returns a float tensor holding 1.0 where ``seq`` is 0 and 0.0 elsewhere,
    with two singleton axes inserted after the first axis. For a
    (batch_size, seq_len) input the result is (batch_size, 1, 1, seq_len).
    """
    is_pad = torch.eq(seq, 0).to(torch.float32)
    # Insert two singleton axes right after the leading axis.
    return is_pad[:, None, None]

In [6]:
def create_look_ahead_mask(size):
    """Return a (size, size) mask with 1.0 strictly above the diagonal.

    Entry [r, c] is 1.0 when c > r and 0.0 otherwise, i.e. the diagonal and
    everything below it are 0.
    """
    # Keeping only the strictly-upper triangle of an all-ones matrix is the
    # complement of its lower triangle (diagonal included).
    return torch.triu(torch.ones(size, size), diagonal=1)  # (seq_len, seq_len)

In [7]:
def scaled_dot_product_attention(q, k, v, mask):
    """Compute softmax(q @ k^T / sqrt(depth)) @ v with an optional additive mask.

    Args:
        q: Query tensor of shape (..., seq_len_q, depth).
        k: Key tensor of shape (..., seq_len_k, depth).
        v: Value tensor of shape (..., seq_len_k, depth_v).
        mask: ``None``, or a tensor broadcastable against
            (..., seq_len_q, seq_len_k). It is multiplied by -1e9 and added to
            the logits, so entries equal to 1 are effectively removed from
            the softmax.

    Returns:
        Tuple ``(output, attention_weights)`` where ``attention_weights`` is
        the softmax over the last axis of the (masked) logits and ``output``
        is ``attention_weights @ v``.
    """
    matmul_qk = torch.matmul(q, k.transpose(-2, -1))  # (..., seq_len_q, seq_len_k)
    scaled_attention_logits = matmul_qk / (k.size(-1) ** 0.5)

    if mask is not None:
        # Add out of place: an in-place ``+=`` raises when the mask broadcasts
        # to more dimensions than the logits have. Casting back keeps the
        # logits' dtype, as the in-place add did.
        masked_logits = scaled_attention_logits + (mask * -1e9)
        scaled_attention_logits = masked_logits.to(scaled_attention_logits.dtype)

    attention_weights = F.softmax(scaled_attention_logits, dim=-1)
    output = torch.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

    return output, attention_weights

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention: project q/k/v, attend per head, merge, project again.

    ``d_model`` is split evenly across ``num_heads`` heads, each of width
    ``depth = d_model // num_heads``.
    """

    def __init__(self, d_model, num_heads):
        """Create the four d_model -> d_model linear layers.

        Raises:
            AssertionError: If ``d_model`` is not divisible by ``num_heads``.
        """
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_model = d_model
        self.depth = d_model // num_heads

        # Input projections for query, key and value, then the output projection.
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.dense = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        """Reshape (batch_size, seq_len, d_model) to (batch_size, num_heads, seq_len, depth)."""
        per_head = x.view(batch_size, -1, self.num_heads, self.depth)
        # Swap the sequence and head axes so that heads come right after batch.
        return per_head.transpose(1, 2)

    def forward(self, v, k, q, mask=None):
        """Run attention; note the argument order is (v, k, q).

        Returns:
            Tuple ``(output, attention_weights)``; ``output`` has shape
            (batch_size, seq_len_q, d_model) and ``attention_weights`` is
            whatever ``scaled_dot_product_attention`` returns as its second
            value.
        """
        batch_size = q.size(0)

        # Project, then split into heads: (batch_size, num_heads, seq_len_*, depth)
        q_heads = self.split_heads(self.wq(q), batch_size)
        k_heads = self.split_heads(self.wk(k), batch_size)
        v_heads = self.split_heads(self.wv(v), batch_size)

        attended, attention_weights = scaled_dot_product_attention(
            q_heads, k_heads, v_heads, mask
        )

        # (batch_size, seq_len_q, num_heads, depth) -> (batch_size, seq_len_q, d_model)
        merged = (
            attended.permute(0, 2, 1, 3)
            .contiguous()
            .view(batch_size, -1, self.d_model)
        )

        return self.dense(merged), attention_weights