In [1]:
!pip install labml

Collecting labml
  Downloading labml-0.5.3-py3-none-any.whl.metadata (7.1 kB)
Downloading labml-0.5.3-py3-none-any.whl (94 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/94.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m94.6/94.6 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: labml
Successfully installed labml-0.5.3


In [2]:
import math
from typing import Optional, List
import torch
from torch import nn
from labml import tracker

In [3]:
class headSplit(nn.Module):
  """Project the last dimension of `x` and split it into attention heads.

  A single linear layer maps `d_model` features to `heads * d_k` features,
  and the result is viewed as `[..., heads, d_k]`. All leading dimensions
  of the input are preserved unchanged.
  """

  def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
    """
    Args:
      d_model: size of the input's last (feature) dimension.
      heads: number of heads to split the projection into.
      d_k: feature size of each head.
      bias: whether the linear projection has an additive bias.
    """
    super().__init__()
    # `nn.Linear` (capital L); `nn.linear` does not exist.
    self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
    self.heads = heads
    self.d_k = d_k

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Return `x` projected and reshaped from `[..., d_model]` to `[..., heads, d_k]`."""
    # Keep every leading dimension; only the last (feature) dim gets split.
    head_shape = x.shape[:-1]
    x = self.linear(x)
    return x.view(*head_shape, self.heads, self.d_k)

In [4]:
class MultiHeadedAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  Inputs use a sequence-first layout: `query`, `key` and `value` are
  `[seq_len, batch_size, d_model]`. After the per-head projections they become
  `[seq_len, batch_size, heads, d_k]`. Scores are `[seq_q, seq_k, batch, heads]`,
  and the softmax is taken over the key axis (dim 1).
  """

  def __init__(self, heads: int, d_model: int, dropout_prob: float, bias: bool = True):
    """
    Args:
      heads: number of attention heads; must divide `d_model`.
      d_model: feature size of the inputs and of the output.
      dropout_prob: dropout probability applied to the attention weights.
      bias: whether the query and key projections use a bias (the value
        projection always uses one).
    """
    super().__init__()
    # heads * d_k must equal d_model; otherwise the output projection would
    # receive the wrong number of features in forward().
    assert d_model % heads == 0, "d_model must be divisible by heads"
    self.heads = heads
    self.d_k = d_model // heads

    self.query = headSplit(d_model, heads, self.d_k, bias=bias)
    self.key = headSplit(d_model, heads, self.d_k, bias=bias)
    # The value projection (previously `self.query` was overwritten here and
    # `self.value` was never defined).
    self.value = headSplit(d_model, heads, self.d_k, bias=True)

    # Scores are [seq_q, seq_k, batch, heads]; normalise over the keys.
    self.softmax = nn.Softmax(dim=1)

    self.output = nn.Linear(d_model, d_model)

    self.dropout = nn.Dropout(p=dropout_prob)

    # 1 / sqrt(d_k) keeps the dot products from growing with the head size.
    self.scale = 1 / math.sqrt(self.d_k)

    # Attention weights from the most recent forward() call (after dropout, detached).
    self.attention = None

  def get_scores(self, query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
    """Unscaled dot products between every query and key position.

    `query` is `[seq_q, batch, heads, d_k]` and `key` is `[seq_k, batch, heads, d_k]`;
    the result is `[seq_q, seq_k, batch, heads]`.
    """
    return torch.einsum('ibhd,jbhd->ijbh', query, key)

  def prep_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]) -> torch.Tensor:
    """Validate `mask` and append a heads axis so it broadcasts over the scores.

    `mask` is `[seq_q or 1, seq_k, batch or 1]`. The returned mask is
    `[seq_q or 1, seq_k, batch or 1, 1]`, which broadcasts against
    `[seq_q, seq_k, batch, heads]`.
    """
    assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
    assert mask.shape[1] == key_shape[0]
    assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]

    # unsqueeze returns a new view; it must be reassigned to take effect.
    mask = mask.unsqueeze(-1)

    return mask

  def forward(self, *, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None):
    """Attend from `query` to `key`/`value`.

    Args:
      query: `[seq_len, batch_size, d_model]`.
      key, value: `[seq_k, batch_size, d_model]`.
      mask: optional `[seq_len or 1, seq_k, batch_size or 1]`; positions equal
        to 0 are excluded from attention.

    Returns:
      `[seq_len, batch_size, d_model]` tensor.
    """
    seq_len, batch_size, _ = query.shape
    if mask is not None:
      mask = self.prep_mask(mask, query.shape, key.shape)

    # [seq, batch, d_model] -> [seq, batch, heads, d_k]
    query = self.query(query)
    key = self.key(key)
    value = self.value(value)

    scores = self.get_scores(query, key)

    scores *= self.scale

    if mask is not None:
      # NOTE: a query row whose keys are all masked becomes all -inf, and the
      # softmax of that row is NaN.
      scores = scores.masked_fill(mask == 0, float('-inf'))

    attention = self.softmax(scores)

    tracker.debug('attention', attention)

    attention = self.dropout(attention)

    # Weighted sum over key positions j:
    # [seq_q, seq_k, b, h] x [seq_k, b, h, d_k] -> [seq_q, b, h, d_k]
    x = torch.einsum('ijbh,jbhd->ibhd', attention, value)

    self.attention = attention.detach()

    # Concatenate the heads: [seq_q, b, h, d_k] -> [seq_q, b, h * d_k]
    x = x.reshape(seq_len, batch_size, -1)

    return self.output(x)