# The Intuited Transformer

"*What I cannot create, I do not understand.*"

-- Richard Feynman

"*There is nothing magic about magic. The magician merely understands something simple which doesn’t appear to be simple or natural to the untrained audience. Once you learn how to hold a card while making your hand look empty, you only need practice before you, too, can “do magic.”*"

– Jeffrey Friedl in the book Mastering Regular Expressions

### Why This Post?
The transformer, proposed in Vaswani et al.'s 'Attention Is All You Need' paper, has shaken the foundations of the NLP community, all but replacing RNNs as NLP's prized workhorse. Ladies and gentlemen, drop the mic, *there's a new golden boy in town*. The amazing breakthroughs that have come about as a result of this architecture (heard of BERT, anyone?) can make it all the more intimidating to dig into and understand the crux of the architecture that made them possible. What I find most incredible is that once you open the black box, there's nothing to be intimidated by... so much of the architecture is very interpretable, and beautifully so.

This is my humble attempt at 
* explaining the architecture in an intuitive way, amalgamating the insights and understanding I've gathered from many different resources online.
> I've personally never been satisfied by just knowing *what* a model is doing, and always find myself questioning *why* it was built this way, and *how* it's doing what the authors say it's doing. This is really just an attempt to satisfy some of my own curiosity.
![image.png](attachment:image.png)
<center><figcaption>Me analyzing ML models</figcaption></center>


* integrating functional PyTorch code to build and train a vanilla transformer from scratch, alongside intuitions and explanations of the model architecture itself. Seeing both side by side is something I would've found incredibly illuminating in my journey to understand this architecture.
>This combination unfortunately feels a bit lacking in two other great posts. *The Illustrated Transformer* has amazing visuals (several of which are leveraged here; a big thanks to the author, Jay Alammar, for beautiful, clear content!) and very clear explanations, but doesn't pair them with code. *The Annotated Transformer* has PyTorch code (albeit written in a way that can be a bit hard to read and intuit at first), but relies heavily on the original paper for its explanations, which can make it hard to parse fully. *(My intention here isn't to knock either of these amazing posts (I've learned a lot from them!), but simply to fill in some gaps, in a way I could have benefited from in my own learning journey, for others in similar shoes.)*

    
### A few things this post specifically tries to provide some intuition for (i.e. things that have kept me up at night):
* **Motivation behind self-attention, what is it really trying to achieve?**
    - most blogs just go into how to compute self-attention, without explaining its end goal
   
    
* **Why does self-attention have a query, key, value formulation?**
    - Why do we use dot products to compute similarity between queries and keys?
    - Why use two vectors for keys and values, instead of just the same vector?
    
    
* **How is multi-head attention different from single-head attention (particularly a large single head)?**
    - How is MHA actually computed in code?

Some more notes:
> * 'Word' and 'Token' are used interchangeably in this post
> * 'Sentence' and 'Sequence' are also used interchangeably

Let's get started!

### Introduction
In a nutshell, **a transformer model is an encoder-decoder architecture, composed of a stack of encoder layers connected to a stack of decoder layers.** 

This *encoder-decoder* formulation is very similar to the idea captured by seq2seq RNN models: encoding some input sequence of tokens, and then conditioning on that encoding to decode into a 'new type' of sequence. For example, a machine translation model can encode an English sequence and then decode it into French.

<figure align="center">
  <center><img src="images/encoder_decoder_transformer.png" alt="Encoder Decoder Stacks: Transformer Architecture"></center>
  <center><figcaption>Encoder Decoder Stacks: Transformer Architecture <a href="https://arxiv.org/pdf/1706.03762.pdf">Illustrated Transformer</a></figcaption></center>
</figure>

All encoder layers share the same structure, as do all decoder layers (though each layer learns its own weights), and the outputs of one layer are simply passed in as inputs to the next. The output of the final encoder layer is then passed to *every* decoder layer, where it is used in that layer's encoder-decoder attention.
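
To see this stacking in code before building it ourselves, here is a minimal sketch using PyTorch's built-in `nn.Transformer` (not the from-scratch implementation this post builds). The sizes are hypothetical and deliberately tiny, and random tensors stand in for embedded sentences. The point is the data flow: the encoder stack produces one output (`memory`), and the decoder stack consumes that same output in every one of its layers.

```python
import torch
import torch.nn as nn

# Tiny, hypothetical sizes (the paper uses d_model=512, 8 heads, 6 encoder + 6 decoder layers).
d_model = 32
model = nn.Transformer(
    d_model=d_model,
    nhead=4,
    num_encoder_layers=2,  # stack of encoder layers: same structure, separate weights
    num_decoder_layers=2,  # stack of decoder layers: same structure, separate weights
    dim_feedforward=64,
    batch_first=True,      # tensors are (batch_size, seq_len, d_model)
)

src = torch.randn(3, 10, d_model)  # stand-in for embedded source sentences (e.g. English)
tgt = torch.randn(3, 7, d_model)   # stand-in for embedded target prefixes (e.g. French)

memory = model.encoder(src)        # output of the final encoder layer: (3, 10, 32)
out = model.decoder(tgt, memory)   # every decoder layer attends to the same `memory`: (3, 7, 32)
```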

Opening the lid of this black box, these are the components that make up each of the encoder and decoder layers:

<figure align="center">
  <center><img src="images/encoder_decoder_pieces.png" alt="Encoder Decoder Components"></center>
  <center><figcaption>Encoder Decoder Components <a href="https://jalammar.github.io/illustrated-transformer/">Illustrated Transformer</a></figcaption></center>
</figure>

We'll go over each of these components in detail, starting with **the true secret sauce of the transformer architecture, the well-sung hero, Self-Attention**. As we can see from the image above, the self-attention block is present in both the encoder and decoder layers (we'll also see later that encoder-decoder attention is really just a specialized version of self-attention), and is the fundamental component of the transformer architecture. 

So we'll start our journey by diving into Self-Attention...

### **Motivating Self-Attention: *Contextualize Tokens as Weighted Average of Other Tokens***

First, let's see what *goes into* the self-attention layer, and what *comes out of* the layer (or what transforms into what ;)).

* **Inputs**: For the first encoder layer, the embeddings of each input token (summed with positional encodings); for later layers, the outputs of the previous layer.
* **Outputs:** Can be seen as *contextualized* encodings of each token, where self-attention *reinterprets* the token in the context of how it relates to the other tokens in the sequence.


The key insight here is that word encodings should not be 'standalone' or fixed, but rather determined by the particular surrounding context of the word. This is similar to humans often being able to discern the meanings of words unfamiliar to them, just by using the 'context clues' provided by the surrounding words in the text.
<figure align="center">
  <center><img src="images/self_attention_input_output.jpg" alt="Self Attention Input Output"></center>
  <figcaption>Transformer Architecture <a href="https://arxiv.org/pdf/1706.03762.pdf">Attention Is All You Need (Transformer Paper)</a></figcaption>
</figure>

To contextualize the embedding of a word w<sub>i</sub> based on its surrounding words w<sub>j</sub>, we need to <span style="color:blue">**find the impact that** ***each*** **word w<sub>j</sub> has on our understanding of** ***this*** **word w<sub>i</sub>**</span>.

For example, in the sentence '*The animal didn't cross the street because **it** was too tired*', our understanding of the word 'it' comes from the surrounding words in the sentence, with each word contributing to a varying degree. Also note how the meaning of the same word 'it' changes in the sentence '*The animal didn't cross the street because **it** was too wide.*', where different words become more important to our understanding.

<figure align="center">
  <center><img src="images/self_attention_head.png" alt="Attention head visual"></center>
  <figcaption>Attention head visual <a href="https://ai.googleblog.com/2017/08/transformer-novel-neural-network.html">Google Transformer Blog</a></figcaption>
</figure>

Once we find the impacts the different words w<sub>j</sub> have on our understanding of w<sub>i</sub>, we can <span style="color:blue"> **recalculate an encoding for w<sub>i</sub>, that factors in the different w<sub>j</sub>'s based on their contribution to our understanding of w<sub>i</sub>**. </span>

We can almost ***consider self-attention to be a fancier weighted average of sorts***, where encodings of different w_j's are combined to form an encoding of w_i, weighted based on their importance in understanding w_i.
> Note: I do say fancier, so of course there are other bells and whistles to be discussed. But conceptualizing self-attention as a weighted average is a great way to build intuition for why the architecture was designed the way it was.

This simple idea of doing a pairwise evaluation of words (a one-to-many mapping from each word in the sequence to all words in the sequence, itself included), and recomputing each word as a weighted average of those words, is what drives the entire transformer!
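
Here is a minimal sketch of that weighted-average idea with made-up numbers: the encodings and the weights below are hand-picked for illustration (in the real model the weights come from the query/key machinery described next). Recomputing one word is a weighted sum of rows; recomputing every word at once is a single matrix product.

```python
import torch

# Hypothetical 4-dim encodings for a 3-word sequence (values made up for illustration).
encodings = torch.tensor([
    [1.0, 0.0, 0.0, 0.0],  # w_0
    [0.0, 1.0, 0.0, 0.0],  # w_1
    [0.0, 0.0, 1.0, 1.0],  # w_2
])

# Hypothetical weights: how much each w_j matters for understanding w_0.
# Non-negative and summing to 1, like the softmax outputs used in self-attention.
weights_for_w0 = torch.tensor([0.2, 0.7, 0.1])

# Contextualized encoding of w_0 = sum_j weight_j * encoding_j
new_w0 = (weights_for_w0.unsqueeze(1) * encodings).sum(dim=0)  # [0.2, 0.7, 0.1, 0.1]

# For every word at once: row i of `weights` holds w_i's weights over all words.
weights = torch.tensor([
    [0.2, 0.7, 0.1],
    [0.3, 0.3, 0.4],
    [0.5, 0.25, 0.25],
])
contextualized = weights @ encodings  # (3, 3) @ (3, 4) -> (3, 4)
```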

Next, let's dig deeper and see how the weights in this weighted average are calculated (cue bells and whistles!)...

### **Self-Attention Mechanics: *Queries, Keys, and Values Intuited***
#### **How do we find the impact that a surrounding word w_j has on our understanding of word w_i?** 

<span style="color:blue">**We can ask questions to learn more about w_i**</span>, and <span style="color:orange">**check the *extents* to which the different w_j's can provide information about w_i**</span>. **The more information w_j has to contribute about what w_i wants to know, the more impact it has on our understanding of w_i**, and the more heavily it should be weighted in our contextual reconstruction of w_i.

* So for w_i, we can construct a question vector, known as the <span style="color:blue">**query**</span>, to get information useful for contextualizing w_i.
* Each word w_j can then have a <span style="color:orange">**key**</span> vector, broadcasting the extent to which it can provide this information.

How well the key of w_j matches what the query vector of w_i is looking for is our measure of the impact, or weight, w_j has on our understanding of w_i. The extent of the match between query and key is calculated as the dot product of the two vectors.
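
A minimal sketch of that scoring step, with random vectors standing in for learned queries and keys (hypothetical values): the score of word w_j for word w_i is `q_i · k_j`, and stacking queries and keys as matrix rows computes every pairwise score at once as `Q Kᵀ`.

```python
import torch

torch.manual_seed(0)
seq_len, dim = 5, 8

# Hypothetical query/key vectors, one row per word (random here, learned in a real model).
queries = torch.randn(seq_len, dim)
keys = torch.randn(seq_len, dim)

# Scores of every word w_j for a single word w_i: the dot product q_i . k_j
i = 2
scores_for_i = torch.stack([torch.dot(queries[i], keys[j]) for j in range(seq_len)])

# All pairs at once: entry (i, j) of Q K^T is q_i . k_j
all_scores = queries @ keys.T  # (seq_len, seq_len)
```

Row `i` of `all_scores` matches `scores_for_i`, which is why attention implementations compute scores as one matrix multiplication instead of a loop.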

> **Why compute the dot product?**  
> Consider the task of making movie recommendations for users, something that a service like Netflix regularly does.  
> An intuitive approach is to break each user and each movie down into a set of dimensions (in this case romance, action, and comedy). The user vector defines the extent to which the user likes these dimensions in movies, whereas the movie vector defines the extent to which the movie itself contains them.  
> Taking the dot product of the user and movie vectors gives us a score of how well the user and the movie align on these dimensions.  
> For example, a user really liking action *(high positive value in this dimension of the user vector)* multiplied by the movie containing a lot of action *(high positive value in this dimension of the movie vector)* leads to a high positive contribution to the dot product, as would the movie being serious and non-comical *(high negative value in the comedy dimension of the movie vector)* combined with the user strongly disliking comedy *(high negative value in this dimension of the user vector)*. A mismatch in like or dislike instead decreases the dot product score (or increases it by only a small amount), by adding a negative value (or a small positive value).  

><figure align="center">
  <center><img src="images/dot_product_user_movie.png" alt="query visual" width="600"></center>
  <figcaption><center>User Movie Match <a href="https://ai.googleblog.com/2017/08/transformer-novel-neural-network.html">Google Transformer Blog</a></center></figcaption>
</figure>
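
The user/movie example can be sketched with a few made-up numbers along the dimensions [romance, action, comedy]; the vectors and titles are hypothetical. The action-loving, comedy-hating user scores high against the serious action movie (matching signs add up) and negative against the romantic comedy (mismatched signs subtract).

```python
import torch

# Dimensions: [romance, action, comedy]; all numbers are made up for illustration.
# User vector: how much the user likes each dimension (negative = dislikes).
user = torch.tensor([0.1, 0.9, -0.8])

# Movie vectors: how much each movie contains each dimension (negative = the opposite,
# e.g. a very serious, non-comical movie has a negative comedy value).
movies = {
    "serious action thriller": torch.tensor([0.0, 0.9, -0.7]),
    "romantic comedy": torch.tensor([0.9, 0.0, 0.8]),
}

# Dot product per movie: matching signs add positive terms, mismatches subtract.
scores = {title: torch.dot(user, movie).item() for title, movie in movies.items()}
# thriller: 0.81 + 0.56 = 1.37;  rom-com: 0.09 - 0.64 = -0.55
```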

**What do Query/Key Vectors Represent?**

Now, let's go a bit deeper with an example of what these query and key vectors could represent. Consider the word **it** in the sentence from the attention head visual: "*The animal didn't cross the street because **it** was too tired.*"

* <span style="color:blue">**query vector:**</span> think of the query vector as a way to get information about a set of features of the sentence, which can help understand the word in the context of the sentence. 

    To understand the word ***it*** in context, we can build a query using a set of features we want to know about, to better understand this word.
    For example, we can build a query vector that asks about part-of-speech (POS) attributes of the sentence, such as *'the subject/actor of the sentence'* (i.e. the one doing the action), *'the verb of the sentence'* (i.e. the action being done), and *'the object of the sentence'* (i.e. the thing acted on). Gaining information about these could help us better understand the word (perhaps by learning about its role in the sentence).
    
    We can visualize each element of the query vector as corresponding to one of these POS features, with the vector's value for that element denoting *how much* the word needs information about that POS. The query vectors of all words would be aligned in asking about the same features (in our case: subject, verb, object).
    >(This vector is analogous to the user preference vector above: instead of how much a user likes a particular attribute, here we consider how much a word needs information about a particular attribute.)
    
<figure align="center">
  <center><img src="images/query_vector_visual.png" alt="query visual" width="400"></center>
  <figcaption><center>query vector visual <a href="https://ai.googleblog.com/2017/08/transformer-novel-neural-network.html">Google Transformer Blog</a></center></figcaption>
</figure>

* <span style="color:orange">**key vector:**</span> think of this as a way of broadcasting some attributes of word w_j, to give the other words a sense of how relevant it can be to their queries.
    * Each word produces a single key vector, so it broadcasts the same information about itself to every word that queries it.
    * Other words will use their query vectors to determine how relevant this word is, based on its key.
    
So continuing the same example as above, the key vectors of words in the sequence could be advertising how much information they have available on the *actor of the sentence*, *the verb of the sentence*, and the *object of the sentence*.  




**Query/Key Communication**

As shown above, since we take the dot product between queries and keys, queries and keys across words need to agree on what each of their dimensions means for the dot product to make sense. In the user/movie example, the user vector's first element is how much the user likes romance, while the movie vector's first element is how much romance the movie has. Similarly, in the POS example, each element of the query and key vectors would need to refer to the same POS.  
Since the dot product is what determines relevance between queries and keys, the model learns (through training the W_q and W_k projections) to align the query and key dimensions in this way, so that the most relevant words receive the highest weight.
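
A minimal sketch of why this alignment matters, using hand-set (hypothetical) vectors over the dimensions [subject/actor, verb, object] for a few words from the example sentence. With aligned dimensions, the query for "it" picks out "animal". If the keys instead listed their dimensions in a different order, the very same dot product would pick the wrong word. A trained model learns this alignment; here it is simply set by hand.

```python
import torch

# Shared dimension order for queries AND keys (hand-set values): [subject/actor, verb, object]
words = ["animal", "cross", "street", "tired"]
keys = torch.tensor([
    [0.9, 0.0, 0.1],  # "animal": lots of info about the subject/actor
    [0.0, 0.9, 0.0],  # "cross":  info about the verb
    [0.1, 0.0, 0.9],  # "street": info about the object
    [0.3, 0.2, 0.0],  # "tired":  a little about the subject
])

# Query for "it": mostly wants to know who the subject/actor is.
query_it = torch.tensor([1.0, 0.2, 0.1])

scores = keys @ query_it                   # one dot product per word
best = words[scores.argmax().item()]       # "animal"

# Misaligned keys: same numbers, but ordered [object, subject, verb] instead.
misaligned = keys[:, [2, 0, 1]]
misaligned_best = words[(misaligned @ query_it).argmax().item()]  # "street"
```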

<!--- We can think of an 'attention head' as abstraction used to impose this order, attributes words are interested in asking about (to understand themselves better), and also consequently which attributes words can broadcast having information about.

* head: ok, lets learn about the parts of speech! Ask and answer questions about the actor, the verb, and the object of the verb
* Wq -> aight, let's learn how to build query vectors to ask that!
* Wk -> let's learn how to build key vectors to answer that! -->


**Integrating Values**  
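Once we have a dot-product-based score from word w_i to every word w_j in the sequence (computed from q_i and each k_j), we pass these scores through a softmax so they become non-negative weights that sum to 1 (in practice the scores are first divided by the square root of the key dimension, as in the code below). The new, contextualized encoding of w_i is then the weighted average of each word's **value** vector v_j, using these weights.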
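
A minimal sketch of turning one word's query/key scores into its new encoding, with random stand-ins for learned vectors (hypothetical values): scale the dot products by `sqrt(dim)`, softmax them into weights, then take the weighted average of the value vectors. This mirrors, for a single query, what `scaled_dot_prod_attention` in the code later in this post does for a whole batch.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, dim = 4, 6

# Random stand-ins for learned vectors (hypothetical values).
q_i = torch.randn(dim)              # query of word w_i
keys = torch.randn(seq_len, dim)    # key k_j of every word w_j
values = torch.randn(seq_len, dim)  # value v_j of every word w_j

scores = keys @ q_i / math.sqrt(dim)  # q_i . k_j for each j, scaled by sqrt(dim)
weights = F.softmax(scores, dim=0)    # non-negative weights that sum to 1
new_w_i = weights @ values            # weighted average of the value vectors: shape (dim,)
```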

**Splitting into Multiple Heads**

Some more interesting questions:
> **Where does the query, key, value notion come from?**  
> This concept also stems from information retrieval: as in a database, a query is matched against a set of keys, and the value(s) associated with the most relevant keys are returned.
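
A small sketch contrasting the two kinds of lookup. A Python dict does a *hard* lookup: exactly one key matches and its value comes back. Attention does a *soft* lookup: the query is compared against every key, and the result blends all values by similarity. The 2-dim key embeddings and the values are hypothetical.

```python
import torch
import torch.nn.functional as F

# Hard lookup: a dict returns the value of the single exactly-matching key.
database = {"cat": 1.0, "dog": 2.0, "car": 10.0}
hard_result = database["dog"]  # 2.0

# Soft (attention-style) lookup with hypothetical 2-dim key embeddings.
key_vecs = torch.tensor([[5.0, 0.0],   # "cat"
                         [4.5, 0.5],   # "dog"
                         [0.0, 5.0]])  # "car"
value_vecs = torch.tensor([[1.0], [2.0], [10.0]])
query = torch.tensor([0.95, 0.05])     # something between "cat" and "dog"

weights = F.softmax(key_vecs @ query, dim=0)  # ~[0.61, 0.39, 0.01]
soft_result = weights @ value_vecs            # mostly cat/dog values, barely any "car"
```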

Some more references on queries, keys, and values:
* <a href="https://ai.googleblog.com/2017/08/transformer-novel-neural-network.html">Illustrated GPT-2</a>

## Multi-Head Attention in Code
> NOTE: in some explanations/implementations of self-attention online, 'multi-head attention' with n heads is made to seem like this: feed the same input vectors of dim k into each head, give each head its own full-size weight matrices (W_q, W_k, and W_v), get output vectors of dim k from each head, concatenate them into vectors of size k $\cdot$ n, and then use a matrix to downscale back to dim k. This is NOT what the original transformer paper (Vaswani et al.) does. Their multi-head attention is implemented as described below, with each head working in a smaller space of dim k / n, which still gets us the benefits we're looking for at a lower computational cost (as described in the aside that follows).
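
To put a number on "lower computational cost", this sketch counts projection parameters (biases ignored) for the two variants in the note, assuming d_model = 512 and 8 heads (the paper's base configuration). `count_params` is a helper defined here for illustration. The paper-style split uses the same parameter budget as a single full-size head, while the naive version grows with the number of heads.

```python
import torch.nn as nn


def count_params(module):
    """Total number of parameters in a module (helper for this comparison)."""
    return sum(p.numel() for p in module.parameters())


d_model, num_heads = 512, 8
head_dim = d_model // num_heads  # 64

# Naive variant from the note: every head gets full-size (d_model x d_model) W_q, W_k, W_v,
# then the n concatenated outputs (n * d_model) are downscaled back to d_model.
naive = nn.ModuleList(
    [nn.Linear(d_model, d_model, bias=False) for _ in range(3 * num_heads)]
    + [nn.Linear(num_heads * d_model, d_model, bias=False)]
)

# Paper-style variant: each head projects down to head_dim = d_model / num_heads,
# and the output projection W_o maps num_heads * head_dim = d_model back to d_model.
split = nn.ModuleList(
    [nn.Linear(d_model, head_dim, bias=False) for _ in range(3 * num_heads)]
    + [nn.Linear(num_heads * head_dim, d_model, bias=False)]
)

naive_count = count_params(naive)  # 24 * 512*512 + 4096*512 = 8,388,608
split_count = count_params(split)  # 24 * 512*64  + 512*512  = 1,048,576
```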

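* **Inputs** - for the query (x_q), key (x_k), and value (x_v), each with dimension d_model (the size of the embeddings, or of the output vectors from the previous layer). For self-attention, in both the encoder and the decoder, all three are the same tensor; in the decoder's encoder-decoder attention, x_q comes from the decoder while x_k and x_v are the output of the final encoder layer.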
    * *dimensions of inputs: (batch_size, seq_len, d_model)*
   
   
* **Queries, keys, and values generation** - the inputs are put through the W_q, W_k, and W_v matrices, which apply a linear transform to generate the queries, keys, and values for each input token
    * *dimensions of W_q, W_k, W_v: (d_model, d_model)*


* **Split into multiple heads** - the q, k, v vectors are then split along their length (d_model) into n equal chunks, where n is the number of attention heads, so each head works with vectors of size depth = d_model / n. (Does this split actually change the result, or is it just a way of speeding up computation or 'viewing things a certain way'? See the aside below.)
    * *dimensions of q, k, v post split: (batch_size, num_heads, seq_len, depth), where d_model = num_heads * depth*
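
One common way to implement this split is a `view` followed by a `transpose`, sketched below under the assumption that tensors are laid out as (batch_size, seq_len, d_model); `split_heads` and `merge_heads` are illustrative helper names. Head h simply gets features `h*depth` to `(h+1)*depth` of every vector. Applying one (d_model, d_model) W_q and then splitting is equivalent to giving each head its own (d_model, depth) slice of W_q, which is how the `AttentionHead` class below is organized.

```python
import torch

batch_size, seq_len, d_model, num_heads = 2, 5, 8, 2
depth = d_model // num_heads

q = torch.randn(batch_size, seq_len, d_model)  # e.g. the output of W_q


def split_heads(x, num_heads):
    b, s, d = x.shape
    # (b, s, d_model) -> (b, s, num_heads, depth) -> (b, num_heads, s, depth)
    return x.view(b, s, num_heads, d // num_heads).transpose(1, 2)


def merge_heads(x):
    b, h, s, depth = x.shape
    # inverse of split_heads: (b, h, s, depth) -> (b, s, h * depth)
    return x.transpose(1, 2).reshape(b, s, h * depth)


q_heads = split_heads(q, num_heads)  # (2, 2, 5, 4)

# Head 1 sees exactly the second chunk of `depth` features of every vector.
assert torch.equal(q_heads[:, 1], q[..., depth:2 * depth])
# Merging undoes the split.
assert torch.equal(merge_heads(q_heads), q)
```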

>### Aside - How is splitting into multiple heads different from a single large attention head?
>Splitting really does change the numbers: reshaping the query, key, and value matrices and splicing them into smaller vectors for several heads gives a different result from using the original large vectors in a single head. (I was initially a bit uncertain whether splitting led to a computationally distinct result, or was just a way to parallelize a single-head computation, which would have meant no expressivity benefits!)

><figure align="center">
  <center><img src="images/Hydra.jpg" alt="Pay attention to all the heads!"></center>
  <center><figcaption>Pay attention to all the heads! <a href="https://arxiv.org/pdf/1706.03762.pdf">Illustrated Transformer</a></figcaption></center>
</figure>

><figure align="center">
  <center><img src="images/splitting_heads_code_visual.jpg" alt="Pay attention to all the heads!"></center>
  <center><figcaption>Pay attention to all the heads! <a href="https://arxiv.org/pdf/1706.03762.pdf">Illustrated Transformer</a></figcaption></center>
</figure>

>Let's look at splitting the q, k, v vectors into 2 heads in this illustration. In particular, let's consider the self-attention operation for x<sub>1</sub>. q<sub>x1</sub> has been split in half, the first half of the vector is for the first attention head and the second half is for the second attention head. The same is true for the key vectors we're considering, k<sub>x1</sub>, k<sub>x2</sub>, and k<sub>x3</sub>.

>In the case of a single large attention head, we would take the dot product of the entire q<sub>x1</sub> vector with each of the entire key vectors (k<sub>x1</sub>, k<sub>x2</sub>, and k<sub>x3</sub>). Then we would take the softmax over these full dot products to calculate the weight x<sub>1</sub> should place on x<sub>1</sub>, x<sub>2</sub>, and x<sub>3</sub>.

>However, this means that we can't focus on specific parts of the query vector. For example, it's possible that the first part of k<sub>x2</sub> matches really well with the first part of q<sub>x1</sub>, but the second parts of these vectors (which are expressing some different aspects of their words), don't end up matching well. This info gets lost inside a single-head computation.

>Having multi-headed attention (in this example 2 heads) here allows us to distinguish between these effects. Basically for some aspects of x<sub>1</sub>, we want to place a high weight on x<sub>2</sub>, but for others x<sub>2</sub> isn't relevant at all and should be weighted minimally. We're able to weight words highly for the dims in which they're relevant, and weight them lowly for dims in which they're irrelevant. We're able to discern such different facets of interpreting x<sub>1</sub> in the context of the other words, by splitting the vectors 'into different heads'. 

>If we didn't split the q, k, v vectors at all (single large head), we would just get one mixed/averaged out sense of how relevant other words are to x<sub>1</sub>, where the effects of all dimensions of the words (d dimensions of their vector reps) would be blended (as the dot prod op across the full vectors just sums the prods of each of the corresponding dims of the vector, and then we softmax this set of dot prods which just combines info across all dims).
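
Here is a small numerical check of this argument, with hand-made 4-dim vectors split into 2 heads of 2 dims each (illustrative values only). The key of x<sub>2</sub> matches the first half of q<sub>x1</sub> but clashes with the second half, so in one large head those effects cancel and x<sub>2</sub> ends up weighted exactly like x<sub>3</sub>, whose key is all zeros. Split into two heads, head 1 weights x<sub>2</sub> most while head 2 weights it least.

```python
import torch
import torch.nn.functional as F

# Hand-made vectors (illustrative values); head 1 gets dims 0-1, head 2 gets dims 2-3.
q_x1 = torch.tensor([2.0, 0.0, 0.0, 2.0])
keys = torch.tensor([
    [0.0, 0.0, 0.0, 1.0],   # k_x1: matches only the second half of q_x1
    [1.0, 0.0, 0.0, -1.0],  # k_x2: matches the first half, clashes with the second
    [0.0, 0.0, 0.0, 0.0],   # k_x3: matches nothing
])


def attention_weights(q, k):
    # softmax over dot products; the 1/sqrt(d) scaling is left out for readability
    # (it would not change which word each head prefers)
    return F.softmax(k @ q, dim=0)


single = attention_weights(q_x1, keys)             # one large head over all 4 dims
head_1 = attention_weights(q_x1[:2], keys[:, :2])  # first half of every vector
head_2 = attention_weights(q_x1[2:], keys[:, 2:])  # second half of every vector

# single: x2 gets exactly the same weight as x3 (scores 2, 0, 0)
# head_1: x2 gets the most weight (scores 0, 2, 0)
# head_2: x1 gets the most weight, x2 the least (scores 2, -2, 0)
```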

```python
import math

import torch
import torch.nn as nn


class AttentionHead(nn.Module):
    """One attention head: project inputs down to attention_head_dim, then apply scaled dot-product attention."""

    def __init__(self, attention_vec_dim, attention_head_dim):
        super().__init__()
        self.attention_vec_dim = attention_vec_dim    # d_model
        self.attention_head_dim = attention_head_dim  # d_model // num_heads

        # per-head weight matrices W_q, W_k, W_v: attention_vec_dim -> attention_head_dim
        self.wq = nn.Linear(self.attention_vec_dim, self.attention_head_dim)
        self.wk = nn.Linear(self.attention_vec_dim, self.attention_head_dim)
        self.wv = nn.Linear(self.attention_vec_dim, self.attention_head_dim)

    def scaled_dot_prod_attention(self, q, k, v, mask=None):
        # (batch_size, q_seq_len, head_dim) x (batch_size, head_dim, k_seq_len) = (batch_size, q_seq_len, k_seq_len)
        query_key_match = torch.bmm(q, k.transpose(1, 2))
        scaled_query_key_match = query_key_match / math.sqrt(self.attention_head_dim)

        if mask is not None:
            # mask must broadcast to (batch_size, q_seq_len, k_seq_len); it is 0 where attention is NOT allowed
            assert mask.shape[-1] == k.shape[1]
            # a large negative score becomes ~0 weight after the softmax (filling with 0 would not mask anything)
            scaled_query_key_match = scaled_query_key_match.masked_fill(mask == 0, -1e9)

        # softmax over the key dimension, so each query's weights sum to 1
        attention_scores = nn.functional.softmax(scaled_query_key_match, dim=-1)
        # (batch_size, q_seq_len, k_seq_len) x (batch_size, k_seq_len, head_dim) = (batch_size, q_seq_len, head_dim)
        z = torch.bmm(attention_scores, v)
        return z

    def forward(self, x_q, x_k, x_v, mask=None):
        # All inputs share the last dim (attention_vec_dim); x_k and x_v must share seq_len,
        # while x_q's seq_len may differ (e.g. in encoder-decoder attention).
        # (batch_size, seq_len, attention_vec_dim) -> (batch_size, seq_len, attention_head_dim)
        q = self.wq(x_q)
        k = self.wk(x_k)
        v = self.wv(x_v)
        return self.scaled_dot_prod_attention(q, k, v, mask)  # (batch_size, q_seq_len, attention_head_dim)


class MultiHeadAttention(nn.Module):
    def __init__(self, attention_vec_dim, num_heads):
        super().__init__()
        assert attention_vec_dim % num_heads == 0, "attention_vec_dim must be divisible by num_heads"
        self.attention_vec_dim = attention_vec_dim
        self.attention_head_dim = attention_vec_dim // num_heads
        self.num_heads = num_heads

        # nn.ModuleList (not a plain list) so each head's parameters are registered, trained, and moved by .to(device)
        self.attention_heads = nn.ModuleList(
            [AttentionHead(self.attention_vec_dim, self.attention_head_dim) for _ in range(num_heads)]
        )

        # output projection W_o mixes the concatenated head outputs
        self.wo = nn.Linear(attention_vec_dim, attention_vec_dim)

    def forward(self, x_q, x_k, x_v, mask=None):
        # argument order matches AttentionHead.forward: query, key, value
        assert (x_q.shape[2] == x_k.shape[2]) and (x_k.shape[1] == x_v.shape[1])

        # run every head and concatenate along the feature dim: num_heads * head_dim = attention_vec_dim
        z = torch.cat([head(x_q, x_k, x_v, mask) for head in self.attention_heads], dim=2)
        assert z.size() == x_q.size()  # (batch_size, q_seq_len, attention_vec_dim)

        mha_output = self.wo(z)
        assert mha_output.size() == x_q.size()  # (batch_size, q_seq_len, attention_vec_dim)
        return mha_output
```
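
The `mask` argument above is 0 wherever a query should not see a key. A minimal sketch of building two common masks in that format: a padding mask and a causal ("look-ahead") mask for decoder self-attention. The token ids and the padding id `PAD_ID = 0` are hypothetical choices for illustration.

```python
import torch

# Hypothetical batch of token-id sequences, right-padded with PAD_ID (an assumption).
PAD_ID = 0
token_ids = torch.tensor([
    [5, 7, 9, 0],
    [3, 4, 0, 0],
])
batch_size, seq_len = token_ids.shape

# Padding mask: True where the key position is a real token, False where it is padding.
# Shape (batch_size, 1, k_seq_len) broadcasts over every query position.
padding_mask = (token_ids != PAD_ID).unsqueeze(1)

# Causal mask: query position i may only attend to key positions j <= i.
# Shape (1, q_seq_len, k_seq_len) broadcasts over the batch.
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool().unsqueeze(0)

# Combined decoder self-attention mask: (batch_size, q_seq_len, k_seq_len), False = hidden.
decoder_mask = padding_mask & causal_mask
# decoder_mask[1] rows: [1,0,0,0], [1,1,0,0], [1,1,0,0], [1,1,0,0]
```

The encoder would typically use only `padding_mask`, while the decoder's self-attention combines both so that it can neither peek at future tokens nor attend to padding.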



# References
* The Illustrated Transformer: 