# Self-Attention and the Transformer
CS-GY 9223 Deep Learning, Fall 2020

## Some helpful visual aids:

Transformer: http://jalammar.github.io/illustrated-transformer/

Attention: https://jalammar.github.io/visualizing-neural-machine-translation-mechanics-of-seq2seq-models-with-attention/

In [None]:
import torch 
from torch import nn
import torch.nn.functional as f
import numpy as np

# Self-attention

Consider a set of $t$ inputs $\{\mathbf{x}_{i}\}^{t}_{i=1}$, where each $\mathbf{x}_{i} \in \mathbb{R}^{n}.$

These $\mathbf{x}$'s can form a matrix with $n$ rows and $t$ columns:

$$\mathbf{X} = \begin{bmatrix} | & | &  & |\\
\mathbf{x}_{1}&\mathbf{x}_{2}&\cdots&\mathbf{x}_{t}\\
| & | &  & | \end{bmatrix} \in \mathbb{R}^{n\times t}.$$

We then consider a hidden representation which is a linear combination of our column vectors $\mathbf{x}_{i}$ : 
$$
\begin{align*}
\mathbf{h} &= \alpha_{1}\mathbf{x}_{1} + \alpha_{2}\mathbf{x}_{2} + \cdots + \alpha_{t}\mathbf{x}_{t},
\end{align*}
$$

i.e.,

$$
\begin{aligned}
\mathbf{h} &=
\begin{bmatrix} | & | &  & |\\
\mathbf{x}_{1}&\mathbf{x}_{2}&\cdots&\mathbf{x}_{t}\\
| & | &  & | \end{bmatrix}
\begin{bmatrix}
\alpha_{1}\\
\alpha_{2}\\
\vdots\\
\alpha_{t}
\end{bmatrix} = \mathbf{X}\mathbf{a} \in \mathbb{R}^{n}
\end{aligned}
$$

Where do these coefficients come from? In self-attention (before introducing queries and keys), they are scores obtained by applying the softmax function to the dot products between a given input $\mathbf{x}$ (one of the $\mathbf{x}_{i}$) and every element of the set:
$\mathbf{a} = \text{softmax}(\mathbf{X}^{\top}\mathbf{x})\in\mathbb{R}^{t}.$ [This](https://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative/) is an excellent discussion of how softmax works if you need to refresh your memory.

Each entry $a_{i}$ is derived from the dot product $\mathbf{x}_{i}^{\top}\mathbf{x}$ of the input vector $\mathbf{x}$ with a vector in the set $\mathbf{X}$ (including $\mathbf{x}$ itself), normalized across the set by the softmax.
<!-- Every element is the scalar product of the whole set $\mathbf{x}$ against a given $\mathbf{x}$. -->

<!-- Note: $\beta$ is a parameter of the soft $\arg\max$ (softmax); in energy terms, the inverse of the temperature - the exponential of the argument divided by summation of all exponentials).  -->
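A minimal sketch of the single-input computation in PyTorch, using random data (the sizes `n`, `t`, the seed, and the choice of the third input as $\mathbf{x}$ are arbitrary assumptions for illustration). It computes $\mathbf{a} = \mathrm{softmax}(\mathbf{X}^{\top}\mathbf{x})$ and $\mathbf{h} = \mathbf{X}\mathbf{a}$ directly from the definitions above.

```python
import torch

torch.manual_seed(0)
n, t = 4, 5                       # input dimension n, number of inputs t
X = torch.randn(n, t)             # columns are x_1, ..., x_t
x = X[:, 2]                       # pick one input as the "query" x

scores = X.T @ x                  # (t,) dot products x_i^T x
a = torch.softmax(scores, dim=0)  # (t,) mixing coefficients, positive and summing to 1
h = X @ a                         # (n,) h = a_1 x_1 + ... + a_t x_t

print(a)
print(h)
```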

### "Soft" vs "hard" attention

We can choose between "hard" attention via $\arg\max$ and "soft" attention via $\mathrm{softmax}$. In hard attention, $\mathbf{a}$ is a one-hot vector, and multiplication by $\mathbf{X}$ selects a single column: the element of the set $\mathbf{X}$ with the maximum similarity score. In this case the $L_0$ "norm" is $\lvert\lvert\mathbf{a}\rvert\rvert_{0} = 1.$ In soft attention, $\mathbf{a}$ is a probability distribution that assigns a non-zero probability to every element in the set $\mathbf{X}$, so its $L_1$ norm is $\lvert\lvert\mathbf{a}\rvert\rvert_{1} = 1.$ If you are wondering whether there is a function that assigns non-zero probability to only some of the elements, there is: it is called [sparsemax](https://arxiv.org/abs/1602.02068).
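To make the distinction concrete, here is a small sketch comparing the three choices on one score vector. The function names are hypothetical; `sparsemax` follows the sorting-based algorithm described in the linked paper and is written from scratch here rather than taken from a library.

```python
import torch

def hard_attention(scores):
    # One-hot vector at the position of the maximum score (arg max)
    a = torch.zeros_like(scores)
    a[scores.argmax()] = 1.0
    return a

def soft_attention(scores):
    # Softmax: every entry strictly positive, entries sum to 1
    return torch.softmax(scores, dim=0)

def sparsemax(scores):
    # Sorting-based sparsemax: Euclidean projection of the scores onto the
    # probability simplex, which can set some entries exactly to zero
    z_sorted, _ = torch.sort(scores, descending=True)
    k = torch.arange(1, scores.numel() + 1, dtype=scores.dtype)
    cumsum = torch.cumsum(z_sorted, dim=0)
    support = 1 + k * z_sorted > cumsum      # true for k = 1..k(z), false afterwards
    k_z = int(support.sum())                 # size of the support
    tau = (cumsum[k_z - 1] - 1) / k_z        # threshold
    return torch.clamp(scores - tau, min=0)

scores = torch.tensor([1.0, 0.5, -2.0])
for name, fn in [("hard", hard_attention), ("soft", soft_attention), ("sparse", sparsemax)]:
    a = fn(scores)
    print(name, a, "L0 =", int((a != 0).sum()), "L1 =", float(a.abs().sum()))
```

With these scores, hard attention keeps one element, softmax keeps all three, and sparsemax returns `[0.75, 0.25, 0.0]`, keeping two. All three outputs have $L_1$ norm 1.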

A set of $\mathbf{x}$'s implies a set of $\mathbf{a}$ score vectors, which can be stacked as columns into a matrix $\mathbf{A} \in \mathbb{R}^{t\times t}$ (each $\mathbf{a}$ has size $t$, one entry per row of $\mathbf{X}^{\top}$). For the set of $\mathbf{a}$'s we also have a set of $\mathbf{h}$'s, which stack into $\mathbf{H} \in \mathbb{R}^{n\times t},$ so we can finally write

$$\mathbf{H} = \mathbf{X}\mathbf{A} \in\mathbb{R}^{n\times t}.$$

$\mathbf{H}$ is a linear combination of the elements of $\mathbf{X}$ using the factors in the columns of $\mathbf{A}$.

Overall, we mix the elements of the set of $\mathbf{x}$'s using coefficients computed with the soft argmax, where each coefficient comes from the dot product (an unnormalized cosine similarity) of a given $\mathbf{x}$ with an element of the set $\mathbf{X}.$
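A sketch checking the matrix form against the per-input computation, again with arbitrary random data. Since column $j$ of $\mathbf{A}$ is $\mathbf{a}_j = \mathrm{softmax}(\mathbf{X}^{\top}\mathbf{x}_j)$, the softmax is taken down each column (`dim=0`).

```python
import torch

torch.manual_seed(0)
n, t = 4, 5
X = torch.randn(n, t)                    # columns x_1..x_t

# Column j of A is a_j = softmax(X^T x_j): softmax runs down each column
A = torch.softmax(X.T @ X, dim=0)        # (t, t)
H = X @ A                                # (n, t), column j is h_j = X a_j

# The same result computed one input at a time, then stacked as columns
H_loop = torch.stack(
    [X @ torch.softmax(X.T @ X[:, j], dim=0) for j in range(t)], dim=1
)
print(torch.allclose(H, H_loop, atol=1e-6))
```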

# Key-value store

Conceptually, we are checking how well the query aligns with all the keys in the store (i.e., how closely each stored item matches the query). We can retrieve the single best-matching element with $\arg\max$, or use the soft $\arg\max$ to obtain a distribution with a score for every element, which lets us rank the items by similarity.

Queries, keys, and values are learned linear transformations (informally, "rotations") of the input $\mathbf{x}$: 
$$\mathbf{q} = W_q \mathbf{x}$$ 
$$\mathbf{k} = W_k \mathbf{x}$$ 
$$\mathbf{v} = W_v \mathbf{x}$$
The matrices $W_q, W_k, W_v$ are trainable parameters.

Attention is built entirely from these linear maps and dot products, so it depends only on the relative *orientation* of vectors; the only nonlinear operation is the softmax. $\mathbf{q}$ and $\mathbf{k}$ must have the same dimension, since we take their dot product; $\mathbf{v}$ is the value (content) returned for a given key, and it can have any size, though in practice it is usually the same size as $\mathbf{q}$ and $\mathbf{k}$.

Given a set of $\mathbf{x}$'s, we get a set of queries, keys, and values, which we can stack into matrices $\mathbf{Q}$, $\mathbf{K}$, $\mathbf{V}$ with $t$ columns each (for $\mathbf{Q}$ and $\mathbf{K}$, each column is a vector of size $d$). In the attention operation, we check a query $\mathbf{q}$ against all keys by computing $\mathbf{K}^{\top}\mathbf{q} \in \mathbb{R}^{t}$. The softmax turns these $t$ scores into a probability distribution over the keys, $\mathbf{a} = \mathrm{softmax}(\mathbf{K}^{\top}\mathbf{q})$, and the output is the corresponding mix of values, $\mathbf{h} = \mathbf{V}\mathbf{a}$.
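The notes stack vectors as *columns*, while the code cells below store one element per *row* (shape `(batch, seq_len, dim)`). The sketch below, with random stand-ins for the learned matrices and arbitrary sizes, computes single-head attention both ways and checks that they agree. It uses a value size `d_v` different from `d` to show that only $\mathbf{q}$ and $\mathbf{k}$ must share a dimension, and it divides the scores by $\sqrt{d}$, as the `MultiHeadAttention` code does.

```python
import torch

torch.manual_seed(0)
n, t, d, d_v = 6, 5, 4, 3        # q and k share dimension d; v has size d_v
X = torch.randn(n, t)            # columns are the inputs x_1..x_t

# Random stand-ins for the learned projections: W_q, W_k are d x n, W_v is d_v x n
W_q, W_k, W_v = torch.randn(d, n), torch.randn(d, n), torch.randn(d_v, n)
Q, K, V = W_q @ X, W_k @ X, W_v @ X          # (d, t), (d, t), (d_v, t)

# Column convention of the notes: column j of A is softmax(K^T q_j / sqrt(d))
A = torch.softmax(K.T @ Q / d ** 0.5, dim=0)       # (t keys, t queries)
H = V @ A                                          # (d_v, t)

# Row convention of the code cells: one row per element, softmax over the last dim
A_rows = torch.softmax(Q.T @ K / d ** 0.5, dim=-1) # (t queries, t keys)
H_rows = A_rows @ V.T                              # (t, d_v)

print(torch.allclose(H, H_rows.T, atol=1e-5))
```

The two conventions are transposes of each other: `A_rows` equals `A.T`, so `H_rows` equals `H.T`.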

# Transformer Model

## Multi-head attention module

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
nn_Softargmax = nn.Softmax  # a more correct/descriptive name

In [None]:
# multiple heads: allows for multiple properties per query

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, p, d_input=None):
        # p: dropout probability; accepted for a uniform constructor signature but not used here
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        if d_input is None:
            d_xq = d_xk = d_xv = d_model
        else:
            d_xq, d_xk, d_xv = d_input
            
        # Make sure that the embedding dimension of model is a multiple of number of heads
        assert d_model % self.num_heads == 0

        self.d_k = d_model // self.num_heads

        # matrices allowing to rotate current input
        # (These are still of dimension d_model. They will be split into number of heads)
        self.W_q = nn.Linear(d_xq, d_model, bias=False)
        self.W_k = nn.Linear(d_xk, d_model, bias=False)
        self.W_v = nn.Linear(d_xv, d_model, bias=False)
        
        # Outputs of all sub-layers need to be of dimension d_model
        self.W_h = nn.Linear(d_model, d_model)
        
    def scaled_dot_product_attention(self, Q, K, V):
        batch_size = Q.size(0)
        k_length = K.size(-2) 
        
        # Scale by sqrt(d_k) so that the soft(arg)max doesn't saturate
        Q = Q / np.sqrt(self.d_k) # (bs, n_heads, q_length, dim_per_head)

        # multiplication between one query and all keys
        scores = torch.matmul(Q, K.transpose(2,3)) # (bs, n_heads, q_length, k_length)

        # compute the mixing coefficients
        A = nn_Softargmax(dim=-1)(scores) # (bs, n_heads, q_length, k_length)
        
        # get the weighted average of the values: multiply the mixing coefficients with the V matrix
        H = torch.matmul(A, V) # (bs, n_heads, q_length, dim_per_head)

        return H, A

        
    def split_heads(self, x, batch_size):
        """
        Split the last dimension into (heads X depth)
        Return after transpose to put in shape (batch_size X num_heads X seq_length X d_k)
        """
        return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def group_heads(self, x, batch_size):
        """
        Combine the heads again to get (batch_size X seq_length X (num_heads times d_k))
        """
        return x.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
    

    def forward(self, X_q, X_k, X_v):
        batch_size, seq_length, dim = X_q.size()

        # apply W transformation (learned rotation of x input), then split into num_heads 
        Q = self.split_heads(self.W_q(X_q), batch_size)  # (bs, n_heads, q_length, dim_per_head)
        K = self.split_heads(self.W_k(X_k), batch_size)  # (bs, n_heads, k_length, dim_per_head)
        V = self.split_heads(self.W_v(X_v), batch_size)  # (bs, n_heads, v_length, dim_per_head)
        
        # compute scaled dot product between one query against all keys
        # i.e. calculate the attention weights for each of the heads
        H_cat, A = self.scaled_dot_product_attention(Q, K, V)
        
        # Put all the heads back together by concat
        H_cat = self.group_heads(H_cat, batch_size)  # (bs, q_length, dim)
        
        # Final linear layer  
        H = self.W_h(H_cat)  # (bs, q_length, dim)
        
        return H, A
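`split_heads` and `group_heads` are pure reshapes and are inverses of each other. A standalone check of the same `view`/`transpose` calls (sizes chosen arbitrarily) shows that each head works on a contiguous slice of `d_k` features and that regrouping recovers the original tensor exactly:

```python
import torch

batch_size, seq_length, num_heads, d_k = 2, 7, 4, 8
d_model = num_heads * d_k
x = torch.randn(batch_size, seq_length, d_model)

# Same reshapes as MultiHeadAttention.split_heads / group_heads
split = x.view(batch_size, -1, num_heads, d_k).transpose(1, 2)   # (bs, heads, seq, d_k)
merged = split.transpose(1, 2).contiguous().view(batch_size, -1, num_heads * d_k)

print(split.shape)            # torch.Size([2, 4, 7, 8])
print(torch.equal(merged, x))
# Head 1 sees the feature slice [d_k, 2*d_k) of every element
print(torch.equal(split[:, 1], x[..., d_k:2 * d_k]))
```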

### Check how the self-attention mechanism works:

In [None]:
temp_mha = MultiHeadAttention(d_model=512, num_heads=8, p=0)
def print_out(Q, K, V):
    temp_out, temp_attn = temp_mha.scaled_dot_product_attention(Q, K, V)
    print('Attention weights are:', temp_attn.squeeze())
    print('Output is:', temp_out.squeeze())

To check that our self-attention does what we expect: if the query matches one of the keys, all the "attention" should be focused there, and the value returned should be the value at that index.

In [None]:
test_K = torch.tensor(
    [[10, 0, 0],
     [ 0,10, 0],
     [ 0, 0,10],
     [ 0, 0,10]]
).float()[None, None]

test_V = torch.tensor(
    [[   1,0,0],
     [  10,0,0],
     [ 100,5,0],
     [1000,6,0]]
).float()[None, None]

test_Q = torch.tensor(
    [[0, 10, 0]]
).float()[None, None]

print_out(test_Q, test_K, test_V)

We can see that it focuses on the second key and returns the second value. 

If we give a query that matches two keys exactly, it should return the average of the two values associated with those keys. 

In [None]:
test_Q = torch.tensor([[0, 0, 10]]).float()[None, None]
print_out(test_Q, test_K, test_V)

We see that it focuses equally on the third and fourth key and returns the average of their values.
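The numbers can also be checked by hand. `temp_mha` has `d_k = 512 / 8 = 64`, so the query is divided by 8: the query `[0, 0, 10]` scores 12.5 against the third and fourth keys and 0 against the others, giving weights of almost exactly 0.5 on each matching key and an output close to $0.5\,(100, 5, 0) + 0.5\,(1000, 6, 0) = (550, 5.5, 0)$. A standalone sketch of the same computation (a free function mirroring `scaled_dot_product_attention`, not the class method itself):

```python
import math
import torch

def scaled_dot_product_attention(Q, K, V, d_k):
    # Same steps as MultiHeadAttention.scaled_dot_product_attention
    scores = (Q / math.sqrt(d_k)) @ K.transpose(-2, -1)   # (q_length, k_length)
    A = torch.softmax(scores, dim=-1)                     # mixing coefficients
    return A @ V, A

K = torch.tensor([[10., 0, 0], [0, 10, 0], [0, 0, 10], [0, 0, 10]])
V = torch.tensor([[1., 0, 0], [10, 0, 0], [100, 5, 0], [1000, 6, 0]])
Q = torch.tensor([[0., 0, 10]])

H, A = scaled_dot_product_attention(Q, K, V, d_k=64)     # d_k = 512 / 8 as in temp_mha
print(A)
print(H)
```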

Now, passing all the queries at the same time:

In [None]:
test_Q = torch.tensor(
    [[0, 0, 10], [0, 10, 0], [10, 10, 0]]
).float()[None,None]
print_out(test_Q, test_K, test_V)

## 1D convolution with `kernel_size = 1`

This is equivalent to an MLP with one hidden layer and ReLU activation applied to each and every element in the set.

In [None]:
# element-wise feedforward = 1d convolution with kernel size 1
# linear layer maps a representation to some other representation (is a transformation)
# convolution maps one set to another set - which is what we are actually doing here
# apply same linear transform to every element in a sequence

# conv hidden layer is applied to every component in the set - every element treated separately
# if you apply same linear layer to every element in a sequence -> that's a convolution
# in practice, implementations generally use a linear layer

class CNN(nn.Module):
    def __init__(self, d_model, hidden_dim, p):
        super().__init__()
        self.k1convL1 = nn.Linear(d_model,    hidden_dim)
        self.k1convL2 = nn.Linear(hidden_dim, d_model)
        self.activation = nn.ReLU()

    def forward(self, x):
        x = self.k1convL1(x)
        x = self.activation(x)
        x = self.k1convL2(x)
        return x
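The claim that a kernel-size-1 convolution is the same as applying one linear layer to every element can be checked directly. In this sketch (sizes arbitrary), an `nn.Conv1d` with `kernel_size=1` is given the weights of an `nn.Linear`; note that `Conv1d` expects channels before the sequence dimension, hence the transposes.

```python
import torch
from torch import nn

torch.manual_seed(0)
batch, seq_len, d_model, hidden = 2, 5, 8, 16

linear = nn.Linear(d_model, hidden)
conv = nn.Conv1d(d_model, hidden, kernel_size=1)

# Copy the linear layer's parameters into the convolution.
# Conv1d weight shape: (out_channels, in_channels, kernel_size) = (hidden, d_model, 1)
with torch.no_grad():
    conv.weight.copy_(linear.weight.unsqueeze(-1))
    conv.bias.copy_(linear.bias)

x = torch.randn(batch, seq_len, d_model)              # (batch, seq, features) as in the notebook
out_linear = linear(x)                                # same linear map applied to each element
out_conv = conv(x.transpose(1, 2)).transpose(1, 2)    # Conv1d wants (batch, channels, seq)

print(torch.allclose(out_linear, out_conv, atol=1e-5))
```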

## Transformer encoder

In [None]:
# Components of encoder block:
# 1: self attention
# 2: convolution - MLP applied to every element in the set

class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, conv_hidden_dim, p=0.1):
        super().__init__()

        self.mha = MultiHeadAttention(d_model, num_heads, p)
        self.cnn = CNN(d_model, conv_hidden_dim, p)

        self.layernorm1 = nn.LayerNorm(normalized_shape=d_model, eps=1e-6)
        self.layernorm2 = nn.LayerNorm(normalized_shape=d_model, eps=1e-6)
    
    def forward(self, x):
        
        # Multi-head attention
        attn_output, _ = self.mha(x, x, x)  # (batch_size, input_seq_len, d_model)
        
        # Layer norm after adding the residual connection 
        out1 = self.layernorm1(x + attn_output)  # (batch_size, input_seq_len, d_model)
        
        # Feed forward 
        cnn_output = self.cnn(out1)  # (batch_size, input_seq_len, d_model)
        
        # Second layer norm after adding residual connection 
        out2 = self.layernorm2(out1 + cnn_output)  # (batch_size, input_seq_len, d_model)

        return out2

## Positional Embeddings

See https://kazemnejad.com/blog/transformer_architecture_positional_encoding/ for illustration.

The attention operation we have defined above is permutation equivariant: permuting the input elements simply permutes the outputs. For input which is ordered, such as words in a sentence, we need to somehow account for the order of the words. In the classic Transformer, position information is not built into the model itself; instead, each element of the input is augmented with information about its position. Since the Transformer architecture is equipped with residual connections, the positional information in the input is also able to propagate directly to further layers.
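A quick sketch of the permutation equivariance, using a bare single-head self-attention with random weights (a simplified stand-in for `MultiHeadAttention`, written in the row convention of the code): shuffling the input elements only shuffles the outputs in the same way.

```python
import torch

torch.manual_seed(0)
t, d = 6, 4
X = torch.randn(t, d)                                # one row per element
W_q, W_k, W_v = (torch.randn(d, d) for _ in range(3))

def self_attention(X):
    # Single head, no masking, scaled dot-product scores
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    A = torch.softmax(Q @ K.T / d ** 0.5, dim=-1)
    return A @ V

perm = torch.randperm(t)
out = self_attention(X)
out_perm = self_attention(X[perm])

# Permuting the inputs permutes the outputs identically: order carries no information
print(torch.allclose(out_perm, out[perm], atol=1e-5))
```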

Some criteria for a position-sensitive encoding function:
- It should output a unique encoding for each time-step (word position) in a sentence
- The distance between any two time-steps should be consistent across sentences of different lengths
- The model should generalize to longer sentences without extra effort; values should be bounded
- It must be deterministic

As such: let $t$ be the desired position in an input sentence, $p_{t} \in \mathbb{R}^{d}$ be its corresponding encoding, and $d$ be the encoding dimension. If $\psi(w_t) \in \mathbb{R}^{d}$ is the embedding of the word $w_t$ at position $t$, the position-aware embedding adds the encoding to it:

$$\psi^{\prime}(w_t) = \psi(w_t)+p_t.$$

Sinusoidal positional embeddings, for example, can be defined component-wise, for $i = 0, \dots, d/2 - 1$, with:

$$
\begin{aligned}
p_{t,2i}   &= \sin\left(t / 10000^{2i / d}\right) \\
p_{t,2i+1} &= \cos\left(t / 10000^{2i / d}\right),
\end{aligned}
$$

so that the positional embedding $p_t$ is a vector containing pairs of sines and cosines.

<!-- - represents $p_{t+\phi}$ as a linear function of $p_t$ for any fixed offset $\phi$ - the sines and cosines implement a rotation transformation -->

<!-- - position as the frequency of flip in value when incrementing, which varies depending on the bit position -> sinusoidal functions as the continuous version of alternating bits -->
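A NumPy sketch of the encoding above. It assumes an even $d$, and `sinusoidal_encoding` is a hypothetical vectorized helper that uses the same frequencies as `create_sinusoidal_embeddings` below. It also checks a useful property: for a fixed offset $\phi$, each (sin, cos) pair of $p_{t+\phi}$ is a rotation of the corresponding pair of $p_t$ by the angle $\phi / 10000^{2i/d}$, which follows from the angle-addition formulas.

```python
import numpy as np

def sinusoidal_encoding(num_positions, d):
    """Rows are p_t for t = 0..num_positions-1 (assumes d is even)."""
    t = np.arange(num_positions)[:, None]          # positions, shape (T, 1)
    i = np.arange(d // 2)[None, :]                 # pair index, shape (1, d/2)
    omega = 1.0 / np.power(10000, 2 * i / d)       # one frequency per (sin, cos) pair
    P = np.zeros((num_positions, d))
    P[:, 0::2] = np.sin(t * omega)                 # p_{t,2i}
    P[:, 1::2] = np.cos(t * omega)                 # p_{t,2i+1}
    return P

d = 16
P = sinusoidal_encoding(num_positions=100, d=d)
print(np.abs(P).max() <= 1.0)                      # values are bounded

# Shifting by a fixed offset phi rotates each (sin, cos) pair by the angle omega_i * phi
phi, t0 = 7, 20
omega = 1.0 / np.power(10000, 2 * np.arange(d // 2) / d)
sin_t, cos_t = P[t0, 0::2], P[t0, 1::2]
sin_shift = sin_t * np.cos(omega * phi) + cos_t * np.sin(omega * phi)
cos_shift = cos_t * np.cos(omega * phi) - sin_t * np.sin(omega * phi)
print(np.allclose(sin_shift, P[t0 + phi, 0::2]), np.allclose(cos_shift, P[t0 + phi, 1::2]))
```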


In [None]:
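def create_sinusoidal_embeddings(nb_p, dim, E):
    # theta[p, j] = p / 10000^(2*(j//2)/dim): even columns get sin, odd columns get cos
    theta = np.array([
        [p / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
        for p in range(nb_p)
    ])
    # E is an nn.Parameter (a leaf tensor that requires grad), so the in-place
    # writes must happen under no_grad to avoid an autograd RuntimeError
    with torch.no_grad():
        E[:, 0::2] = torch.FloatTensor(np.sin(theta[:, 0::2]))
        E[:, 1::2] = torch.FloatTensor(np.cos(theta[:, 1::2]))
    # Freeze the table: the sinusoidal encodings are fixed, not learned.
    # (No .to(device) here: the weight moves with the model on model.to(device).)
    E.requires_grad = False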

class Embeddings(nn.Module):
    def __init__(self, d_model, vocab_size, max_position_embeddings, p):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, d_model, padding_idx=1) # a simple lookup table that stores embeddings of a fixed dictionary and size
        self.position_embeddings = nn.Embedding(max_position_embeddings, d_model)
        create_sinusoidal_embeddings(
            nb_p=max_position_embeddings,
            dim=d_model,
            E=self.position_embeddings.weight
        )

        self.LayerNorm = nn.LayerNorm(d_model, eps=1e-12)

    def forward(self, input_ids):
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) # (max_seq_length)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)                      # (b, max_seq_length)
        
        # Get word embeddings for each input id
        word_embeddings = self.word_embeddings(input_ids) # (b, max_seq_length, dim)
        
        # Get position embeddings for each position id 
        position_embeddings = self.position_embeddings(position_ids) # (b, max_seq_length, dim)
        
        # Add them both 
        embeddings = word_embeddings + position_embeddings  # (b, max_seq_length, dim)
        
        # Layer norm 
        embeddings = self.LayerNorm(embeddings) # (b, max_seq_length, dim)
        return embeddings

## Overall Encoder 
#### (Blocks of N Encoder Layers + Positional encoding + Input embedding)

In [None]:
class Encoder(nn.Module):
    def __init__(self, 
                 num_layers, 
                 d_model, 
                 num_heads, 
                 ff_hidden_dim, 
                 input_vocab_size,
                 maximum_position_encoding, p=0.1):
        super().__init__()

        self.d_model = d_model
        self.num_layers = num_layers
        
        # apply permutation-sensitive embeddings
        self.embedding = Embeddings(d_model, 
                                    input_vocab_size,
                                    maximum_position_encoding, 
                                    p)

        self.enc_layers = nn.ModuleList()
        for _ in range(num_layers):
            self.enc_layers.append(EncoderLayer(d_model, 
                                                num_heads, 
                                                ff_hidden_dim, 
                                                p))
        
    def forward(self, x):
        x = self.embedding(x) # Transform to (batch_size, input_seq_length, d_model)
        # stack multiple to make network "more powerful"
        # append several encoders together
        for i in range(self.num_layers):
            x = self.enc_layers[i](x)

        return x  # (batch_size, input_seq_len, d_model)

In [None]:
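# These cells use the legacy torchtext API (Field, LabelField, BucketIterator).
# It is torchtext.data up to torchtext 0.8, torchtext.legacy.data in 0.9-0.11,
# and it was removed in 0.12.
import torchtext.data as data
import torchtext.datasets as datasets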

In [None]:
max_len = 200
text = data.Field(sequential=True, 
                  fix_length=max_len, 
                  batch_first=True, 
                  lower=True, 
                  dtype=torch.long)
label = data.LabelField(sequential=False, 
                        dtype=torch.long)

# using torch's IMDB dataset https://pytorch.org/text/stable/datasets.html#imdb
datasets.IMDB.download('./')
ds_train, ds_test = datasets.IMDB.splits(text, label, path='./imdb/aclImdb/')
print('train : ', len(ds_train))
print('test : ', len(ds_test))
print('train.fields :', ds_train.fields)

In [None]:
ds_train, ds_valid = ds_train.split(0.9)
print('train : ', len(ds_train))
print('valid : ', len(ds_valid))
print('test : ', len(ds_test))

In [None]:
num_words = 50000
text.build_vocab(ds_train, max_size=num_words)
label.build_vocab(ds_train)
vocab = text.vocab

In [None]:
batch_size = 164
train_loader, valid_loader, test_loader = data.BucketIterator.splits(
                                                (ds_train, ds_valid, ds_test), 
                                                batch_size=batch_size, 
                                                sort_key=lambda x: len(x.text), 
                                                repeat=False)

In [None]:
class TransformerClassifier(nn.Module):
    def __init__(self, 
                 num_layers, 
                 d_model, 
                 num_heads, 
                 conv_hidden_dim, 
                 input_vocab_size, 
                 num_answers):
        super().__init__()
        
        self.encoder = Encoder(num_layers, 
                               d_model, 
                               num_heads, 
                               conv_hidden_dim, 
                               input_vocab_size,
                               maximum_position_encoding=10000)
        self.dense = nn.Linear(d_model, num_answers)

    def forward(self, x):
        x = self.encoder(x)
        
        x, _ = torch.max(x, dim=1)  # max-pool over the sequence dimension: (bs, d_model)
        x = self.dense(x)
        return x

In [None]:
model = TransformerClassifier(num_layers=1, 
                              d_model=32, 
                              num_heads=2, 
                              conv_hidden_dim=128, 
                              input_vocab_size=50002, 
                              num_answers=2)
model.to(device)

TransformerClassifier(
  (encoder): Encoder(
    (embedding): Embeddings(
      (word_embeddings): Embedding(50002, 32, padding_idx=1)
      (position_embeddings): Embedding(10000, 32)
      (LayerNorm): LayerNorm((32,), eps=1e-12, elementwise_affine=True)
    )
    (enc_layers): ModuleList(
      (0): EncoderLayer(
        (mha): MultiHeadAttention(
          (W_q): Linear(in_features=32, out_features=32, bias=False)
          (W_k): Linear(in_features=32, out_features=32, bias=False)
          (W_v): Linear(in_features=32, out_features=32, bias=False)
          (W_h): Linear(in_features=32, out_features=32, bias=True)
        )
        (cnn): CNN(
          (k1convL1): Linear(in_features=32, out_features=128, bias=True)
          (k1convL2): Linear(in_features=128, out_features=32, bias=True)
          (activation): ReLU()
        )
        (layernorm1): LayerNorm((32,), eps=1e-06, elementwise_affine=True)
        (layernorm2): LayerNorm((32,), eps=1e-06, elementwise_affine=True)
      )
    )
  )
  (dense): Linear(in_features=32, out_features=2, bias=True)
)

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
epochs = 10
t_total = len(train_loader) * epochs

In [None]:
def train(train_loader, valid_loader):
    
    for epoch in range(epochs):
        train_iterator, valid_iterator = iter(train_loader), iter(valid_loader)
        nb_batches_train = len(train_loader)
        train_acc = 0
        model.train()
        losses = 0.0

        for batch in train_iterator:
            x = batch.text.to(device)
            y = batch.label.to(device)
            
            out = model(x)

            loss = f.cross_entropy(out, y)
            
            model.zero_grad()

            loss.backward()
            losses += loss.item()

            optimizer.step()
                        
            train_acc += (out.argmax(1) == y).cpu().numpy().mean()
        
        # '{i}' in the original format string raised a KeyError; plain '{}' matches the logged output
        print("Training loss at epoch {} is {}".format(epoch, losses / nb_batches_train))
        print("Training accuracy: {}".format(train_acc / nb_batches_train))
        print('Evaluating on validation:')
        evaluate(valid_loader)

In [None]:
def evaluate(data_loader):
    data_iterator = iter(data_loader)
    nb_batches = len(data_loader)
    model.eval()
    acc = 0 
    with torch.no_grad():  # no gradients are needed for evaluation
        for batch in data_iterator:
            x = batch.text.to(device)
            y = batch.label.to(device)

            out = model(x)
            acc += (out.argmax(1) == y).cpu().numpy().mean()

    print("Eval accuracy: {}".format(acc / nb_batches))

In [None]:
train(train_loader, valid_loader)

Training loss at epoch 0 is 0.7528660064158232
Training accuracy: 0.5295102067868505
Evaluating on validation:
Eval accuracy: 0.5893673780487805
Training loss at epoch 1 is 0.6515302515548208
Training accuracy: 0.6335056557087311
Evaluating on validation:
Eval accuracy: 0.6597179878048781
Training loss at epoch 2 is 0.5937000944994498
Training accuracy: 0.7013520678685047
Evaluating on validation:
Eval accuracy: 0.6926067073170731
Training loss at epoch 3 is 0.49995334036108374
Training accuracy: 0.7663706256627787
Evaluating on validation:
Eval accuracy: 0.7580411585365855
Training loss at epoch 4 is 0.4073252291351125
Training accuracy: 0.8173548515376458
Evaluating on validation:
Eval accuracy: 0.778887195121951
Training loss at epoch 5 is 0.3362871474329976
Training accuracy: 0.8570828914810882
Evaluating on validation:
Eval accuracy: 0.7899771341463414
Training loss at epoch 6 is 0.2818543931496316
Training accuracy: 0.8850079533404033
Evaluating on validation:
Eval accuracy: 0.80

In [None]:
evaluate(test_loader)

Eval accuracy: 0.8091711390970119
