---
title: "Transformers: Attention Is All You Need"
author: "richsi"
categories: [Attention, Paper, Transformer]
date: "2024-12-30"
image: "images/thumbnail.png"
---

["Attention Is All You Need"](https://arxiv.org/abs/1706.03762) by Ashish Vaswani et al., 2017.

```{python}
#| code-summary: Importing packages
#| code-fold: true

import torch
import numpy as np
```

## Introduction
* What is the topic of this guide, and why is it important?

The Transformer is a neural network architecture built entirely on attention. It learns context and meaning by modeling the relationships between all elements of a sequence, such as the words in a sentence, without using recurrence or convolutions.

* What problem does this concept or paper aim to solve? 

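Before the Transformer, sequence modeling and transduction problems such as machine translation were dominated by recurrent neural networks (RNNs), in particular long short-term memory networks (LSTMs). An RNN computes its hidden state at position $t$ from its hidden state at position $t-1$, so the positions of a sequence must be processed one after another. This sequential dependency prevents parallelizing computation within a training example, which becomes a bottleneck for long sequences.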

* What are the main contributions or breakthroughs introduced?

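The Transformer replaces recurrence entirely with attention, so all positions in a sequence can be processed in parallel during training. This improved both training efficiency and translation quality: the paper reports 28.4 BLEU on the WMT 2014 English-to-German task, more than 2 BLEU above the previous best results (including ensembles), at a small fraction of the training cost of earlier models.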

* How does this fit into the broader context of the field?

This paper laid the groundwork for state-of-the-art models such as BERT and GPT, which have revolutionized the field of natural language processing (NLP). Beyond NLP, transformers have been successfully adapted to other machine learning domains such as computer vision and reinforcement learning.


## Methods

### Encoder and Decoder Stacks

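**Encoder:** The encoder is a stack of $N = 6$ identical layers. Each layer has two sublayers: a multi-head self-attention mechanism and a position-wise fully connected feed-forward network. Each sublayer is wrapped in a residual connection followed by layer normalization, and all sublayers produce outputs of dimension $d_{model} = 512$.

**Decoder:** The decoder is also a stack of $N = 6$ identical layers. In addition to the two sublayers found in each encoder layer, each decoder layer inserts a third sublayer that performs multi-head attention over the output of the encoder stack. The decoder's self-attention sublayer is masked so that a position cannot attend to subsequent positions; combined with the output embeddings being offset by one position, this ensures that the prediction for position $i$ depends only on the known outputs at positions less than $i$.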
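In the paper, every sublayer in both stacks is wrapped as $\text{LayerNorm}(x + \text{Sublayer}(x))$, and the feed-forward sublayer is $\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$ with inner dimension $d_{ff} = 2048$. The minimal NumPy sketch below shows that wrapper around the feed-forward sublayer. It assumes random weights in place of learned ones, a toy sequence length of 10, and omits dropout and layer normalization's learnable gain and bias; `layer_norm`, `feed_forward`, and `sublayer_connection` are hypothetical helper names.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    """Normalize each row (token) to zero mean and unit variance.
    The learnable gain and bias of real layer norm are omitted here."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def feed_forward(x, W1, b1, W2, b2):
    """Position-wise feed-forward network: max(0, x W1 + b1) W2 + b2."""
    return np.maximum(0, x @ W1 + b1) @ W2 + b2

def sublayer_connection(x, sublayer):
    """Residual connection followed by layer norm: LayerNorm(x + Sublayer(x))."""
    return layer_norm(x + sublayer(x))

rng = np.random.default_rng(0)
N, d_model, d_ff = 10, 512, 2048   # N is an arbitrary toy sequence length
x = rng.standard_normal((N, d_model))

# Random stand-ins for learned FFN parameters
W1 = rng.standard_normal((d_model, d_ff)) * 0.02
b1 = np.zeros(d_ff)
W2 = rng.standard_normal((d_ff, d_model)) * 0.02
b2 = np.zeros(d_model)

out = sublayer_connection(x, lambda h: feed_forward(h, W1, b1, W2, b2))
print(out.shape)  # (10, 512)
```

Because the residual path adds the sublayer's input back to its output, the input signal reaches later layers directly, which is what makes stacking many such layers practical.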

### Scaled Dot-Product Attention

$$
\text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^{T}}{\sqrt{d_{k}}})V
$$
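The paper motivates the $1/\sqrt{d_k}$ factor as follows: if the components of a query $q$ and a key $k$ are independent random variables with mean 0 and variance 1, their dot product $q \cdot k$ has mean 0 and variance $d_k$, and the authors suspect that such large dot products push the softmax into regions with extremely small gradients. The quick simulation below checks the variance claim using random unit-variance vectors, an illustrative assumption rather than learned projections.

```python
import numpy as np

rng = np.random.default_rng(0)
d_k = 64              # per-head key dimension in the paper's base model
n_samples = 50_000

# Independent components with mean 0 and variance 1, as in the paper's argument
q = rng.standard_normal((n_samples, d_k))
k = rng.standard_normal((n_samples, d_k))

dots = np.sum(q * k, axis=1)     # one dot product q . k per sample
scaled = dots / np.sqrt(d_k)     # the scaling used in the attention formula

print(f"variance of q.k:            {dots.var():.1f}")    # close to d_k = 64
print(f"variance of q.k / sqrt(dk): {scaled.var():.2f}")  # close to 1
```

Dividing by $\sqrt{d_k}$ brings the scores back to unit variance regardless of $d_k$, so the softmax inputs stay in a range where its gradients are not vanishingly small.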

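The softmax is applied to each row of $QK^T$, so every query receives a set of weights over all keys that sums to 1. Scaled dot-product attention takes three inputs: queries $\mathbf{Q}$, keys $\mathbf{K}$, and values $\mathbf{V}$. In self-attention, each is a linear projection of the same input $\mathbf{X}$, using learnable weight matrices $W^Q$, $W^K$, and $W^V$ specific to queries, keys, and values: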

$$
\begin{aligned}
Q &= XW^Q \\
K &= XW^K \\
V &= XW^V
\end{aligned}
$$


* $\mathbf{X}$ is the input to the attention layer, with one row per token. For the first layer, $X$ is the sum of the word embeddings and the positional encodings; for later layers, it is the output of the previous layer. A word embedding is a learned vector that represents the meaning of a word (or token), and a positional encoding is a vector that encodes the token's position in the sequence.

* $\mathbf{Q}$ holds the queries: each row is the vector a token uses to "ask" which other tokens are relevant to it. Because it is computed from $X$, it carries both word and positional information.

* $\mathbf{K}$ holds the keys: each row describes one token, and the dot product of a query with a key scores how relevant that token is to the querying token. Like $Q$, it is computed from $X$, but through a different weight matrix, $W^K$.

* $\mathbf{V}$ holds the values: the information each token contributes. The output for each query is the sum of the value rows, weighted by the attention weights.

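The dimensions of $Q$, $K$, and $V$ are determined by the model's **latent space dimensionality** ($d_{model}$) and the number of **attention heads** ($h$). In the paper, each head uses queries and keys of dimension $d_k = d_{model}/h$ and values of dimension $d_v = d_{model}/h$. Following the paper's base model, we set the latent space dimension to **512** and the number of attention heads to **8**, so $d_k = d_v = 64$.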
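The projections $Q = XW^Q$, $K = XW^K$, $V = XW^V$ for a single head can be sketched in NumPy with $d_{model} = 512$ and $h = 8$, giving $d_k = d_v = d_{model}/h = 64$. The random `X` stands in for word embeddings plus positional encodings, the sequence length of 10 is arbitrary, and the randomly initialized weights stand in for learned ones; all three are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
N = 10                      # toy sequence length (assumption)
d_model, h = 512, 8         # base-model values from the paper
d_k = d_v = d_model // h    # 64 dimensions per head

X = rng.standard_normal((N, d_model))  # stand-in for embeddings + positional encodings

# Learnable in a real model; random here, scaled to keep outputs near unit variance
W_Q = rng.standard_normal((d_model, d_k)) / np.sqrt(d_model)
W_K = rng.standard_normal((d_model, d_k)) / np.sqrt(d_model)
W_V = rng.standard_normal((d_model, d_v)) / np.sqrt(d_model)

Q, K, V = X @ W_Q, X @ W_K, X @ W_V
print(Q.shape, K.shape, V.shape)  # (10, 64) (10, 64) (10, 64)
```

Each projection maps every token's 512-dimensional representation to a 64-dimensional query, key, or value, so one head's attention costs scale with $d_k = 64$ rather than the full $d_{model}$.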

```{python}
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """
    Compute scaled dot-product attention.

    Args:
        Q: Query matrix, shape (N, d_k)
        K: Key matrix, shape (N, d_k)
        V: Value matrix, shape (N, d_v)

    Returns:
        output: Weighted sum of values after applying attention, shape (N, d_v)
        attention_weights: Softmax-normalized scores, shape (N, N)
    """
    d_k = Q.shape[1]  # Dimensionality of keys/queries

    # Scaled attention scores: (N, d_k) @ (d_k, N) -> (N, N)
    scores = Q @ K.T / np.sqrt(d_k)

    # Row-wise softmax: subtract each row's max for numerical stability, and
    # use keepdims=True so each row is divided by its own sum
    exp_scores = np.exp(scores - scores.max(axis=1, keepdims=True))
    attention_weights = exp_scores / exp_scores.sum(axis=1, keepdims=True)

    # Weighted sum of values: (N, N) @ (N, d_v) -> (N, d_v)
    output = attention_weights @ V

    return output, attention_weights
```
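The decoder's self-attention uses a masked variant: the paper implements masking inside scaled dot-product attention by setting to $-\infty$ every softmax input that corresponds to an illegal connection, so those positions receive zero weight. The sketch below adds an optional boolean mask and builds a causal (lower-triangular) mask with `np.tril`. The function name `masked_attention`, the mask convention (`True` means "may attend"), and the toy sizes are assumptions for illustration.

```python
import numpy as np

def masked_attention(Q, K, V, mask=None):
    """Scaled dot-product attention with an optional boolean mask.
    mask[i, j] == False means query i may not attend to key j."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                        # (N, N)
    if mask is not None:
        scores = np.where(mask, scores, -np.inf)           # illegal connections -> -inf
    scores = scores - scores.max(axis=-1, keepdims=True)   # stability; each row keeps a finite max
    weights = np.exp(scores)                               # exp(-inf) == 0
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

N, d_k = 4, 8
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((N, d_k)) for _ in range(3))

# Position i may attend to positions 0..i (lower triangle including the diagonal)
causal_mask = np.tril(np.ones((N, N), dtype=bool))
out, weights = masked_attention(Q, K, V, mask=causal_mask)
print(np.round(weights, 2))  # entries above the diagonal are all 0
```

Because every row keeps at least its diagonal entry, each row's maximum is finite and the softmax is well defined; the first token can attend only to itself, so its weight row is `[1, 0, 0, 0]`.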

### Multi-Head Attention
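The paper defines multi-head attention as $\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O$, where $\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$, so each head attends in its own $d_k$-dimensional subspace. Below is a minimal NumPy sketch for self-attention, assuming random weights in place of learned ones and storing all $h$ heads' projections side by side in one $(d_{model}, d_{model})$ matrix; `multi_head_attention` and `softmax` are hypothetical helper names.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_Q, W_K, W_V, W_O, h):
    """Multi-head self-attention over X of shape (N, d_model).

    W_Q, W_K, W_V: (d_model, d_model), all h heads' projections side by side.
    W_O: (h * d_k, d_model) output projection.
    """
    N, d_model = X.shape
    d_k = d_model // h

    def split_heads(M):
        # (N, d_model) -> (h, N, d_k); head i uses columns i*d_k:(i+1)*d_k
        return M.reshape(N, h, d_k).transpose(1, 0, 2)

    Q, K, V = (split_heads(X @ W) for W in (W_Q, W_K, W_V))
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)       # (h, N, N)
    heads = softmax(scores) @ V                            # (h, N, d_k)
    concat = heads.transpose(1, 0, 2).reshape(N, h * d_k)  # Concat(head_1, ..., head_h)
    return concat @ W_O                                    # (N, d_model)

rng = np.random.default_rng(0)
N, d_model, h = 10, 512, 8
X = rng.standard_normal((N, d_model))
W_Q, W_K, W_V, W_O = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model)
                      for _ in range(4))

out = multi_head_attention(X, W_Q, W_K, W_V, W_O, h)
print(out.shape)  # (10, 512)
```

Packing the per-head matrices $W_i^Q$ into one wide matrix is equivalent to $h$ separate $(d_{model}, d_k)$ projections, but lets all heads be computed with a single matrix multiplication followed by a reshape.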

### Positional Encoding

$$
\begin{aligned}
PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{model}}) \\
PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{model}})
\end{aligned}
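$$

where $pos$ is the token's position and $i$ indexes pairs of dimensions, so even dimensions use sine and odd dimensions use cosine. Each dimension of the encoding is a sinusoid, and the wavelengths form a geometric progression from $2\pi$ to $10000 \cdot 2\pi$.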
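A minimal NumPy sketch of the sinusoidal encoding defined by the formulas above, assuming an even $d_{model}$ so that sine and cosine columns pair up; the function name `positional_encoding` is hypothetical.

```python
import numpy as np

def positional_encoding(max_len, d_model):
    """Sinusoidal positional encodings of shape (max_len, d_model).
    Assumes d_model is even."""
    pos = np.arange(max_len)[:, np.newaxis]             # (max_len, 1)
    i = np.arange(d_model // 2)[np.newaxis, :]          # (1, d_model / 2)
    angles = pos / np.power(10000.0, 2 * i / d_model)   # pos / 10000^(2i / d_model)
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angles)   # even dimensions 2i
    pe[:, 1::2] = np.cos(angles)   # odd dimensions 2i + 1
    return pe

pe = positional_encoding(max_len=50, d_model=512)
print(pe.shape)    # (50, 512)
print(pe[0, :4])   # [0. 1. 0. 1.]  (position 0: sin(0) = 0, cos(0) = 1)
```

In the paper, these encodings are added to the input embeddings at the bottoms of the encoder and decoder stacks. The authors chose sinusoids because they hypothesized that, for any fixed offset $k$, $PE_{pos+k}$ can be represented as a linear function of $PE_{pos}$, which may make it easy for the model to attend by relative position.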

## Applications and Insights

* Where and how is this concept applied in the real world?
* What problems does it solve, and what value does it provide?
* What are some examples or use cases where this has been impactful?
* What are the broader implications or insights gained from this work?
* How does this contribute to advancements in the field or industry?

## Conclusion

* What are the key takeaways or lessons from this guide?
* Why is this concept significant in the broader context of the field?
* What questions remain unanswered or open for further exploration?
* What resources or next steps can help deepen understanding?
* How can this knowledge be applied or expanded upon in practice?