In [1]:
import math
from typing import Optional, List
import torch
from torch import nn

In [2]:
class TransformKQV(nn.Module):
  """Linear projection whose output is unfolded into per-head vectors.

  The last axis of the input (size ``model_dim``) is mapped to
  ``heads * vec_dim_per_head`` features by a single ``nn.Linear`` and then
  viewed as two axes of sizes ``heads`` and ``vec_dim_per_head``. Every
  leading axis of the input is carried through unchanged.

  Args:
    model_dim: size of the last axis of the input.
    heads: number of heads the projection is split into.
    vec_dim_per_head: number of features in each head's vector.
    bias: whether the linear layer carries a bias term.
  """

  def __init__(self, model_dim, heads, vec_dim_per_head, bias):
    super().__init__()
    self.heads = heads
    self.vec_dim_per_head = vec_dim_per_head
    # One fused projection for all heads; it is split apart in forward().
    self.linear = nn.Linear(model_dim, heads * vec_dim_per_head, bias=bias)

  def forward(self, x):
    """Project ``x`` and return it with shape ``(*x.shape[:-1], heads, vec_dim_per_head)``."""
    leading_dims = x.shape[:-1]
    projected = self.linear(x)
    split_shape = (*leading_dims, self.heads, self.vec_dim_per_head)
    return projected.view(split_shape)

In [4]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  ``query``, ``key`` and ``value`` are each projected into ``heads`` vectors
  of ``num_features_kqv // heads`` features, attention weights are computed
  per head, and the per-head results are concatenated and passed through a
  final linear layer.

  ``forward`` unpacks ``query.shape`` as ``(sequence_length, batch_size, _)``,
  so inputs are expected in sequence-first layout.

  Args:
    heads: number of attention heads.
    num_features_kqv: feature size of the query/key/value inputs and of the
      output. Must be divisible by ``heads``.
    dropout: dropout probability applied to the attention weights.
    bias: whether the query and key projections carry a bias term.

  Raises:
    ValueError: if ``num_features_kqv`` is not divisible by ``heads``.
  """

  def __init__(self, heads, num_features_kqv, dropout = 0.1, bias = True):
    super().__init__()
    if num_features_kqv % heads != 0:
      # Otherwise heads * num_features_head != num_features_kqv and the
      # output layer would reject the concatenated heads at call time.
      raise ValueError(
        f"num_features_kqv ({num_features_kqv}) must be divisible by heads ({heads})"
      )
    self.num_features_kqv = num_features_kqv
    self.num_features_head = num_features_kqv // heads
    self.heads = heads
    self.query = TransformKQV(num_features_kqv, heads, self.num_features_head, bias = bias)
    self.key = TransformKQV(num_features_kqv, heads, self.num_features_head, bias = bias)
    # NOTE(review): the value projection always has a bias, regardless of
    # the `bias` argument; kept as in the original - confirm this is intended.
    self.value = TransformKQV(num_features_kqv, heads, self.num_features_head, bias = True)
    self.dropout = nn.Dropout(dropout)
    self.output = nn.Linear(num_features_kqv, num_features_kqv)
    # Scores are laid out as (query_pos, key_pos, batch, head); normalise
    # over the key positions.
    self.softmax = nn.Softmax(dim = 1)
    # Scale by the per-head width: each dot product sums over
    # num_features_head terms, not over the full model width.
    self.scale = 1 / math.sqrt(self.num_features_head)
    # Detached attention weights from the most recent forward() call.
    self.attention = None

  def get_scores(self, query, key):
    """Return the dot product of every query position with every key position.

    Both inputs are indexed ``(position, batch, head, feature)``; the result
    is indexed ``(query_pos, key_pos, batch, head)``.
    """
    return torch.einsum('ibhd, jbhd -> ijbh', query, key)

  def mask_gen(self, mask, query_shape, key_shape):
    """Validate ``mask`` and add a trailing axis so it broadcasts over heads.

    ``mask`` is indexed ``(query_pos, key_pos, batch)``; the first and last
    axes may be of size 1 to broadcast.

    Raises:
      AssertionError: if the mask shape is incompatible with the inputs.
    """
    assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
    assert mask.shape[1] == key_shape[0]
    assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]
    mask = mask.unsqueeze(-1)
    return mask

  def forward(self, *, query, key, value, mask = None):
    """Apply attention and return a tensor shaped ``(sequence_length, batch_size, features)``.

    Args:
      query, key, value: input tensors; ``query.shape`` is unpacked as
        ``(sequence_length, batch_size, _)``.
      mask: optional tensor accepted by ``mask_gen``; positions where it
        equals 0 are excluded from attention.

    Side effects:
      Stores the detached (post-dropout) attention weights in
      ``self.attention``.
    """
    sequence_length, batch_size, _ = query.shape
    if mask is not None:
      mask = self.mask_gen(mask, query.shape, key.shape)
    query = self.query(query)
    key = self.key(key)
    value = self.value(value)
    # Out-of-place multiply keeps the einsum result intact for autograd.
    scores = self.get_scores(query, key) * self.scale
    if mask is not None:
      # -inf becomes a weight of exactly 0 after the softmax.
      scores = scores.masked_fill(mask == 0, float('-inf'))
    attention = self.softmax(scores)
    attention = self.dropout(attention)
    x = torch.einsum("ijbh,jbhd->ibhd", attention, value)
    self.attention = attention.detach()
    # Concatenate the heads back into a single feature axis.
    x = x.reshape(sequence_length, batch_size, -1)
    return self.output(x)