In [16]:
from torch import nn
import numpy as np

Compute self-attention for an input sequence.
For every token, we produce a new encoding vector, so the output is a sequence of encoding vectors with the same shape as the input.

In [17]:
def self_attention(input_sequence):
    """Compute unparameterised scaled dot-product self-attention.

    Every token attends to every token in the sequence (itself included).
    The attention score between two tokens is their dot product divided by
    the square root of the embedding size; the scores of each token are
    turned into weights with a softmax, and the token's new representation
    is the weighted sum of all token vectors.

    Args:
        input_sequence: 2-D array-like of shape (sequence length,
            embedding size), one row per token.

    Returns:
        A float ``numpy.ndarray`` with the same shape as the input, where
        row ``i`` is the attention-weighted sum of all input rows for
        token ``i``.
    """
    tokens = np.asarray(input_sequence, dtype=float)

    # Nothing to attend to: mirror the input shape with an empty result
    # (a max-reduction over an empty axis would raise otherwise).
    if tokens.shape[0] == 0:
        return np.zeros(shape=tokens.shape)

    # scores[i, j] = <token_i, token_j> / sqrt(embedding size)
    scores = tokens @ tokens.T
    scores /= np.sqrt(tokens.shape[1])

    # Row-wise softmax computed directly on the array. Subtracting the row
    # maximum leaves the result unchanged but keeps exp() from overflowing.
    scores -= scores.max(axis=1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)

    # Row i of the output is the sum of all tokens weighted by weights[i].
    return weights @ tokens

Take an input sequence of some length and some embedding vector size. From it we form 3 tensors (query, key, value), each of dimension sequence length by embedding vector size, and multiply each by its own weight tensor of embedding vector size by embedding vector size. This gives us tensors with dimensions of sequence length by embedding vector size, which we split along the embedding dimension into the number of heads we want. Each of these slices serves as one head of the multi-headed attention, and its size is the ratio between the embedding vector size and the number of desired heads. Each head has access to the full input sequence, but is limited to its own slice of the embedding of each token. Within each head, compute the scaled query-key scores and apply a softmax over them, then use the resulting weights to sum the values. Finally, concatenate the heads back together and multiply them by an output weight tensor of embedding vector size by embedding vector size, so the result has the same dimensions as the input.

In [None]:
# Note: all dimensions actually include a leading batch dimension.
# Each tensor looks like (batch, sequence length, d_model, ...),
# because we process a whole batch of sequences at once.
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The query, key and value inputs are each projected by a learned linear
    layer, split along the embedding dimension into ``n_heads`` heads of
    size ``d_k = d_model // n_heads``, attended per head, concatenated back
    to ``d_model`` and passed through a final linear layer.

    Args:
        d_model: embedding vector size; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.

    Attributes:
        attention_scores: attention weights produced by the most recent
            ``forward`` call, shape (batch, n_heads, query length,
            key length). Only set after ``forward`` has run.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model  # embedding vector size
        self.n_heads = n_heads  # number of heads
        assert d_model % n_heads == 0, "embedding vector size (d_model) needs to be divisible by the nubmer of heads (n_heads)"
        self.d_k = d_model // n_heads  # size of each head
        # "Attention Is All You Need" names the value size d_v; here it
        # is equal to d_k.
        self.d_v = self.d_k

        # the 3 weight tensors of embedding vector size by embedding vector size
        self.w_q = nn.Linear(d_model, d_model)  # query weight tensor
        self.w_k = nn.Linear(d_model, d_model)  # key weight tensor
        self.w_v = nn.Linear(d_model, d_model)  # value weight tensor

        # weight tensor to multiply by the concatenated heads at the end
        self.w_concat = nn.Linear(n_heads * self.d_v, d_model)

        # dropout applied to the attention weights
        self.Dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(Q, K, V, mask, dropout):
        """Scaled dot-product attention.

        Declared static because it uses no instance state; it can be called
        as ``self.attention(...)`` or ``MultiHeadAttention.attention(...)``.

        Args:
            Q: query tensor, shape (..., query length, d_k).
            K: key tensor, shape (..., key length, d_k).
            V: value tensor, shape (..., key length, d_v).
            mask: optional tensor broadcastable to the score shape
                (..., query length, key length); positions where it equals
                0 are hidden from attention. ``None`` disables masking.
            dropout: optional callable applied to the attention weights
                (e.g. an ``nn.Dropout`` module), or ``None``.

        Returns:
            Tuple ``(output, attn)``: the attended values of shape
            (..., query length, d_v) and the attention weights of shape
            (..., query length, key length).
        """
        # size of each key vector (last dimension of K)
        d_k = K.shape[-1]

        # Query times the transpose of the last two dimensions of the key,
        # scaled by sqrt(d_k) so the dot products do not grow with d_k and
        # push the softmax into regions with vanishing gradients.
        attn = (Q / d_k ** 0.5) @ K.transpose(-2, -1)

        # Masked positions get a very large negative score, so the softmax
        # assigns them (effectively) zero weight.
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9)

        # normalise the scores over the key dimension
        attn = attn.softmax(dim=-1)

        if dropout is not None:
            attn = dropout(attn)

        output = attn @ V
        # return output and attention scores
        return output, attn

    def forward(self, Q, K, V, mask=None):
        """Apply multi-head attention.

        Args:
            Q: query input, shape (batch, query length, d_model).
            K: key input, shape (batch, key length, d_model).
            V: value input, shape (batch, key length, d_model).
            mask: optional mask broadcastable to (batch, n_heads,
                query length, key length); entries equal to 0 mark
                positions that must not be attended to. Default is no
                mask, in which case every token attends to every token.

        Returns:
            Tensor of shape (batch, query length, d_model).
        """
        # multiply each tensor with its weight tensor, then split by the
        # number of heads
        query_tensor = self.split(self.w_q(Q))
        key_tensor = self.split(self.w_k(K))
        value_tensor = self.split(self.w_v(V))

        # per-head attention; the weights are kept for later inspection
        output, self.attention_scores = self.attention(
            query_tensor, key_tensor, value_tensor, mask, self.Dropout
        )

        # Concatenate the heads: (batch, n_heads, length, d_k) ->
        # (batch, length, n_heads * d_k). contiguous() is required because
        # view() cannot be applied to the transposed (strided) tensor.
        batch_size = output.shape[0]
        output = output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.n_heads * self.d_k
        )

        # multiply by output weight tensor
        return self.w_concat(output)

    def split(self, tensor):
        """Split the embedding dimension of ``tensor`` into heads.

        Dimensions go from (batch, sequence length, d_model) to
        (batch, number of heads, sequence length, size of each head).
        """
        batch_size, seq_len = tensor.shape[0], tensor.shape[1]
        heads = tensor.view(batch_size, seq_len, self.n_heads, self.d_k)
        return heads.transpose(1, 2)


