In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [3]:
# Choose the compute device once at start-up: the default CUDA GPU when
# PyTorch reports one is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

Using device: cpu


## Transformer Architecture

In [5]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  ``query``, ``key`` and ``value`` are each passed through their own linear
  projection. The projections are split into ``n_heads`` heads of size
  ``d_k = d_model // n_heads``. Each head computes
  ``softmax(Q K^T / sqrt(d_k)) V``. The heads are then concatenated and passed
  through a final output projection.

  Args:
    d_model: Feature size of the inputs and of the output.
    n_heads: Number of attention heads; must be positive and divide ``d_model``.
    dropout: Dropout probability applied to the attention weights.

  Raises:
    ValueError: If ``n_heads`` is not positive or does not divide ``d_model``.
  """

  def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
    super().__init__()
    # Explicit check (not ``assert``) so validation survives ``python -O`` and
    # n_heads == 0 gives a clear error instead of ZeroDivisionError.
    if n_heads <= 0 or d_model % n_heads != 0:
      raise ValueError(
          f'd_model ({d_model}) must be divisible by a positive n_heads ({n_heads})'
      )

    self.d_model = d_model
    self.n_heads = n_heads
    self.d_k = d_model // n_heads  # per-head size

    self.q_linear = nn.Linear(d_model, self.d_k * n_heads)
    self.k_linear = nn.Linear(d_model, self.d_k * n_heads)
    self.v_linear = nn.Linear(d_model, self.d_k * n_heads)  # d_v = d_k
    # Maps the concatenated heads (n_heads * d_k features) back to d_model.
    # Both sizes are equal, so the weight shape is (d_model, d_model) either way.
    self.out = nn.Linear(self.d_k * n_heads, d_model)

    self.dropout = nn.Dropout(dropout)

  def _split_heads(self, x, batch_size: int):
    """Reshape a projected tensor into (batch, n_heads, length, d_k)."""
    return x.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

  def forward(self, query, key, value, mask=None):
    """Attend from ``query`` over ``key``/``value``.

    Args:
      query: Expected shape (batch, q_len, d_model).
      key: Expected shape (batch, k_len, d_model).
      value: Expected shape (batch, k_len, d_model); its length must match ``key``.
      mask: Optional tensor broadcastable to (batch, n_heads, q_len, k_len).
        Positions where ``mask == 0`` receive zero attention weight.

    Returns:
      A tuple ``(output, attention_weights)``:
        - ``output`` has shape (batch, q_len, d_model).
        - ``attention_weights`` has shape (batch, n_heads, q_len, k_len) and is
          returned after dropout has been applied.
    """
    batch_size = query.size(0)

    Q = self._split_heads(self.q_linear(query), batch_size)  # (batch, n_heads, q_len, d_k)
    K = self._split_heads(self.k_linear(key), batch_size)    # (batch, n_heads, k_len, d_k)
    V = self._split_heads(self.v_linear(value), batch_size)  # (batch, n_heads, k_len, d_k)

    # Scaled dot product: (batch, n_heads, q_len, k_len)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / self.d_k ** 0.5

    if mask is not None:
      # Use the dtype's most negative finite value instead of a fixed -1e9,
      # which lies outside float16's range (max magnitude 65504).
      scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

    attention_weights = F.softmax(scores, dim=-1)  # (batch, n_heads, q_len, k_len)
    attention_weights = self.dropout(attention_weights)

    context = torch.matmul(attention_weights, V)  # (batch, n_heads, q_len, d_k)

    # Concatenate heads: (batch, n_heads, q_len, d_k) -> (batch, q_len, d_model)
    context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

    output = self.out(context)

    return output, attention_weights

