In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import numpy as np

# PyTorch Transformer

> Based on the paper [Attention is All You Need](https://arxiv.org/abs/1706.03762)

> Google's original implementation in TensorFlow [here](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py)

The Transformer network is an attention-only sequence transduction model, "dispensing with recurrence and convolutions entirely".

# Transformer Architecture 

The model is made up of $N$ identical encoder/decoder layers ($N=6$ in the paper).

![Transformer Model Architecture](./images/network_architecture.png)

Each encoder/decoder is made up of two types of Sub-Layers:

- A Multi-Head Attention Layer
- A Feed-Forward Network

**TODO**: Each sub-layer is wrapped in a residual connection followed by layer normalization, i.e. its output is $LayerNorm(x + Sublayer(x))$, before being passed to the next. Don't forget to implement this in the full model later.

We're going to implement this more or less from scratch, so we need to dig down into the sub-layers and build up from there.

## Sub-Layer: Multi-Head Attention

> An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum
of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.

>-- Attention is All You Need, 3.2

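In practice, the queries, keys and values for a whole sequence are packed into the matrices $Q$, $K$ and $V$, so attention is computed for every query at once.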

The sub-layer looks like:

![Multi-head attention](./images/multi-head.png)

Where:

$$MultiHead(Q, K, V) = Concat(head_{1}, \ldots{}, head_{h}) W^{O}$$

and 

$$head_i = Attention(QW_{i}^{Q}, KW_{i}^{K}, VW_{i}^{V})$$
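
To make the shapes in these two formulas concrete, here is a minimal sketch. The sizes, the random projection matrices and the `attention` helper are all assumptions made for illustration (the helper is a plain stand-in for the attention function covered in the next section), not the implementation this notebook builds.

```python
import math
import torch

torch.manual_seed(0)

def attention(q, k, v):
    # Stand-in helper: softmax(q k^T / sqrt(d_k)) v
    scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(q.size(-1))
    return torch.bmm(torch.softmax(scores, dim=-1), v)

# Hypothetical sizes, chosen small for readability
batch, seq_len, d_model, n_heads = 2, 5, 8, 2
d_k = d_v = d_model // n_heads

# One (W_i^Q, W_i^K, W_i^V) triple per head, plus the output projection W^O
W_q = [torch.randn(d_model, d_k) for _ in range(n_heads)]
W_k = [torch.randn(d_model, d_k) for _ in range(n_heads)]
W_v = [torch.randn(d_model, d_v) for _ in range(n_heads)]
W_o = torch.randn(n_heads * d_v, d_model)

# Self-attention: Q, K and V all come from the same input
x = torch.randn(batch, seq_len, d_model)

# head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)
heads = [attention(x @ W_q[i], x @ W_k[i], x @ W_v[i]) for i in range(n_heads)]

# MultiHead = Concat(head_1, ..., head_h) W^O
multi_head = torch.cat(heads, dim=-1) @ W_o

print(heads[0].shape)    # torch.Size([2, 5, 4])
print(multi_head.shape)  # torch.Size([2, 5, 8])
```

Each head works in a smaller $d_k$-dimensional space, and concatenating the $h$ heads brings the result back to the model dimension before $W^{O}$ is applied.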


### Single Attention Computation

Recursing ever further into our model, we need the math behind a single Attention calculation:

![Attention Function](./images/attention_calc.png)

or, in notation:

$$Attention(Q,K,V) = softmax \left( \frac{ QK^{T} }{ \sqrt{ d_{k} } } \right) V$$

This is a version of multiplicative (dot-product) attention. The scores are divided by $\sqrt{d_{k}}$, the square root of the key dimension, because for large $d_{k}$ the dot products grow large in magnitude and push the softmax into regions with extremely small gradients (see this [blog post](http://ruder.io/deep-learning-nlp-best-practices/index.html#attention) for a nice overview of additive vs. multiplicative attention).
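
The effect of the $\sqrt{d_{k}}$ factor can be checked numerically. The sketch below assumes query and key components are independent with zero mean and unit variance (the same assumption the paper makes); the sample count and variable names are arbitrary choices for illustration.

```python
import math
import torch

torch.manual_seed(0)

d_k = 64           # per-head key dimension used in the paper
n_samples = 10000  # arbitrary number of random query/key pairs

# Queries and keys with independent, unit-variance components
q = torch.randn(n_samples, d_k)
k = torch.randn(n_samples, d_k)

raw = (q * k).sum(dim=-1)      # unscaled dot products
scaled = raw / math.sqrt(d_k)  # scaled as in the attention formula

print(raw.var().item())     # close to d_k
print(scaled.var().item())  # close to 1

# Larger logits make the softmax sharper: more weight lands on a single key
logits = torch.randn(10)
sharp = torch.softmax(logits * math.sqrt(d_k), dim=-1)
smooth = torch.softmax(logits, dim=-1)
print(sharp.max().item(), smooth.max().item())
```

Without scaling, the variance of the scores grows linearly with $d_{k}$, so the softmax output drifts towards a one-hot vector and its gradients shrink.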


Some notes about the Google implementation:
- The optional mask is implemented simply by setting masked scores to $-\infty$ before the softmax, so those positions receive zero weight.
- Later in the paper
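
The masking trick can be seen in isolation with a toy example. The scores and mask below are made-up values for illustration, with `True` taken to mean "hide this position" (an assumed convention).

```python
import torch

# Hypothetical scores for one query over four keys (batch of 1)
scores = torch.tensor([[[2.0, 1.0, 0.5, 3.0]]])

# True marks positions that must not be attended to
mask = torch.tensor([[[False, False, True, True]]])

# exp(-inf) is 0, so masked positions drop out of the softmax
masked_scores = scores.masked_fill(mask, float('-inf'))
weights = torch.softmax(masked_scores, dim=-1)

print(weights)  # masked positions get weight 0; the rest sum to 1
```

One caveat: if every position in a row is masked, the softmax of that row is all `nan`, so a mask should leave at least one key visible per query.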

This is something we can implement! 


In [16]:
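class ScaledMultiAttention(nn.Module):
    
    def __init__(self, dim_key, drop_percent=0.1):
        
        super(ScaledMultiAttention, self).__init__()
        
        # The value to scale by will be constant throughout model
        self.scale_value = float(np.sqrt(dim_key))
        
        # Layers
        self.dropout = nn.Dropout(drop_percent)
        # Normalize over the last dim (the keys) so each query's weights sum to 1
        self.softmax = nn.Softmax(dim=-1)
        
    def forward(self, Q, K, V, mask=None):
        
        # Remember, the 1st dim of torch tensors is the batch size
        # Q: (batch, len_q, dim_key), K: (batch, len_k, dim_key)
        attention = torch.bmm(Q, K.transpose(1, 2)) / self.scale_value
        
        # mask is a boolean tensor broadcastable to (batch, len_q, len_k);
        # True marks the positions to hide
        if mask is not None:
            attention = attention.masked_fill(mask, -float('inf'))
            
        # Ugh... python... get you a pipe operator or something
        attention = self.dropout(self.softmax(attention))
        
        # Weighted sum of the values: (batch, len_q, dim_value)
        attention = torch.bmm(attention, V)
        
        return attention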
        