In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F



In [9]:
class Attention(nn.Module):
    """Single-head scaled dot-product attention.

    Queries, keys and values are each produced by their own bias-free
    ``d_model -> d_model`` linear projection of the encodings passed to
    ``forward``.

    Args:
        d_model: Number of features per token encoding; used as both the input
            and output size of every projection.
        row_dim: First of the two axes swapped when the keys are transposed.
        col_dim: Second of the two swapped axes. It is also the axis whose
            size sets the scaling factor and the axis softmax is applied over.
    """

    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        super().__init__()

        self.row_dim = row_dim
        self.col_dim = col_dim

        def make_projection():
            # Square, bias-free linear map applied to each token encoding.
            return nn.Linear(d_model, d_model, bias=False)

        # Created in query, key, value order so that a seeded run always draws
        # the same initial weights for each projection.
        self.W_q = make_projection()
        self.W_k = make_projection()
        self.W_v = make_projection()

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        """Return the attention-weighted combination of the projected values.

        Args:
            encodings_for_q: Encodings projected into queries.
            encodings_for_k: Encodings projected into keys.
            encodings_for_v: Encodings projected into values.
            mask: Optional mask handed to ``masked_fill``; positions where it
                is True are set to -1e9 before softmax. Presumably a boolean
                tensor broadcastable to the similarity matrix -- confirm
                against callers.

        Returns:
            The softmax-normalised similarities multiplied by the values.
        """
        queries = self.W_q(encodings_for_q)
        keys = self.W_k(encodings_for_k)
        values = self.W_v(encodings_for_v)

        # Similarity of every query with every key, divided by the square root
        # of the key size along col_dim.
        keys_transposed = keys.transpose(self.row_dim, self.col_dim)
        scale = torch.tensor(keys.size(self.col_dim) ** 0.5)
        scores = torch.matmul(queries, keys_transposed) / scale

        if mask is not None:
            # A very negative score becomes (near) zero weight after softmax.
            scores = scores.masked_fill(mask, -1e9)

        weights = F.softmax(scores, dim=self.col_dim)
        return torch.matmul(weights, values)


In [10]:
## create the matrices of token encodings: 3 tokens, 2 values per token.
## the query, key and value encodings hold the same numbers here; each one is
## still its own independent tensor.
encodings_for_q = torch.tensor([
    [1.16, 0.23],
    [0.57, 1.36],
    [4.41, -2.16],
])
encodings_for_k = encodings_for_q.clone()
encodings_for_v = encodings_for_q.clone()

In [11]:
## seed PyTorch's random number generator so the randomly initialised
## projection weights created in the next cell are reproducible
torch.manual_seed(42)

<torch._C.Generator at 0x108bade30>

In [12]:
## create an attention object for 2-value token encodings (d_model=2).
## row_dim and col_dim are the two axes swapped when the keys are transposed;
## col_dim is also the axis softmax is applied over.
attention = Attention(d_model=2,
                      row_dim=0,
                      col_dim=1)

In [13]:
## calculate attention. the query, key and value encodings defined above hold
## identical values, so this call is effectively self-attention; passing
## different encodings for q than for k/v would give encoder-decoder attention.
attention(encodings_for_q, encodings_for_k, encodings_for_v)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)

In [14]:
class MultiHeadAttention(nn.Module):
    """Run several independent ``Attention`` heads and concatenate their outputs.

    Args:
        d_model: Number of features per token encoding, passed to every head.
        row_dim: Passed to every head (axis swapped with ``col_dim`` when the
            keys are transposed).
        col_dim: Passed to every head and also the axis along which the head
            outputs are concatenated.
        num_heads: Number of ``Attention`` heads to create.
    """

    def __init__(self, d_model=2, row_dim=0, col_dim=1, num_heads=1):
        super().__init__()

        # nn.ModuleList registers every head as a sub-module, so their weights
        # show up in .parameters() and in the module repr.
        self.heads = nn.ModuleList(
            [Attention(d_model, row_dim, col_dim)
             for _ in range(num_heads)]
        )

        self.col_dim = col_dim

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        """Apply every head to the same inputs and join the results.

        Args:
            encodings_for_q: Encodings each head projects into queries.
            encodings_for_k: Encodings each head projects into keys.
            encodings_for_v: Encodings each head projects into values.
            mask: Optional mask forwarded unchanged to every head. The default
                of None means no masking, which matches the behaviour before
                this parameter existed.

        Returns:
            The per-head outputs concatenated along ``col_dim``. Each head
            contributes its own block, so the size along ``col_dim`` grows
            with the number of heads.
        """
        head_outputs = [
            head(encodings_for_q, encodings_for_k, encodings_for_v, mask=mask)
            for head in self.heads
        ]
        return torch.cat(head_outputs, dim=self.col_dim)

In [15]:
## re-seed PyTorch's random number generator so the randomly initialised
## weights of the multi-head module created next are reproducible
torch.manual_seed(42)


<torch._C.Generator at 0x108bade30>

In [16]:
## create a multi-head attention object with a single head
multiHeadAttention = MultiHeadAttention(d_model=2, row_dim = 0, col_dim=1, num_heads=1)



In [17]:
## display the module structure: one Attention head with its three projections
multiHeadAttention

MultiHeadAttention(
  (heads): ModuleList(
    (0): Attention(
      (W_q): Linear(in_features=2, out_features=2, bias=False)
      (W_k): Linear(in_features=2, out_features=2, bias=False)
      (W_v): Linear(in_features=2, out_features=2, bias=False)
    )
  )
)

In [18]:
## run the encodings through multi-head attention.
## NOTE(review): the tensor returned by this call is discarded -- only the last
## expression of a cell is displayed, and here that is the module itself, so
## the output below is the module repr rather than the attention values.
multiHeadAttention(encodings_for_q, encodings_for_k, encodings_for_v)
multiHeadAttention

MultiHeadAttention(
  (heads): ModuleList(
    (0): Attention(
      (W_q): Linear(in_features=2, out_features=2, bias=False)
      (W_k): Linear(in_features=2, out_features=2, bias=False)
      (W_v): Linear(in_features=2, out_features=2, bias=False)
    )
  )
)