# Attention

> Single-head scaled dot-product attention implemented as a PyTorch `nn.Module`.

In [1]:
#| default_exp attention

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| hide
# Run nbdev's export step from inside the notebook.
import nbdev
nbdev.nbdev_export()

In [3]:
#| export
from typing import Optional, Tuple

import torch
from torch import nn
import torch.nn.functional as F

from torchtyping import TensorType

In [4]:
#| export
class Attention(nn.Module):
    """Single-head scaled dot-product attention with learned input projections.

    The query, key and value inputs are each passed through their own
    `nn.Linear` layer; attention weights are then computed as
    `softmax(Q @ K^T / sqrt(key_dim))`, passed through dropout, and used to
    take a weighted sum of the projected values.

    Args:
        query_dim: Size of the last dimension of the `query` input.
        key_dim: Size of the last dimension of the `key` input; also the
            size both queries and keys are projected to.
        value_dim: Size of the last dimension of the `value` input (and of
            the output).
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, query_dim: int, key_dim: int, value_dim: int, dropout: float=0.1):
        super().__init__()
        self.query_dim = query_dim
        self.key_dim = key_dim
        self.value_dim = value_dim
        # Dot products are multiplied by 1/sqrt(key_dim) before the softmax.
        self.scale = 1 / (key_dim ** 0.5)

        # Queries are projected into the key space so the dot product is defined.
        self.query = nn.Linear(query_dim, key_dim)
        self.key = nn.Linear(key_dim, key_dim)
        self.value = nn.Linear(value_dim, value_dim)
        # `self.dropout` is the dropout module (not the raw probability).
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        query: TensorType["batch_size", "seq_len", "n_dim"],
        key: TensorType["batch_size", "seq_len", "n_dim"],
        value: TensorType["batch_size", "seq_len", "n_dim"],
        mask: Optional[TensorType["batch_size", "seq_len", "seq_len"]] = None
    ) -> Tuple[
        TensorType["batch_size", "seq_len", "n_dim"],
        TensorType["batch_size", "seq_len", "seq_len"]
    ]:
        """Apply attention and return the attended values and the weights.

        Args:
            query: Tensor whose last dimension is `query_dim`.
            key: Tensor whose last dimension is `key_dim`.
            value: Tensor whose last dimension is `value_dim`.
            mask: Optional tensor broadcastable to the score matrix.
                Positions where `mask == 0` are excluded from attention.

        Returns:
            A tuple `(output, p_attn)`: the weighted sum of the projected
            values, and the attention weights after dropout.
        """
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)
        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale

        if mask is not None:
            # Fill with the most negative finite value of the score dtype.
            # A hard-coded -1e9 is not representable in float16 and makes
            # masked_fill raise under half precision.
            fill_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask == 0, fill_value)

        p_attn = F.softmax(scores, dim=-1)
        p_attn = self.dropout(p_attn)

        output = torch.matmul(p_attn, value)

        return output, p_attn