<h4>Change List:</h4>

1. Replacing the separate `Head` and `MultiHeadAttention` classes with a single class, `CasualHeadAttention` (causal self-attention), that handles the projection, splitting into heads, RoPE, and attention all in one place.



In [None]:
import torch
from torch.nn import functional as F
import torch.nn as nn
from typing import Tuple    

# Seed PyTorch's random number generator so repeated runs draw the same numbers.
torch.manual_seed(1337)
print("mps : ",torch.backends.mps.is_available()) # True only if the MPS backend is usable on this machine

In [None]:
# Path of the plain-text training corpus.
INPUT_DATA_FILE = r'/Users/kunal/My Works/Learning/GptFromScratch/data/input/tinyStoriesData.txt'
# Prefer CUDA, then Apple MPS, and fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
BATCH_SIZE = 4 # how many independent sequences are processed in parallel
BLOCK_SIZE = 8 # maximum context length (in tokens) the model sees when making a prediction
MAX_ITERS = 1000 # number of optimisation steps in the training loop
EVAL_INTERVAL = 300 # report train/val loss every this many steps
LEARNING_RATE = 3e-4
EVAL_ITERS = 200 # number of batches averaged per split when estimating the loss

# --------------
N_EMBD = 24
NUM_HEAD = 2
# ---- Every head has dim = N_EMBD // NUM_HEAD => 24 // 2 = 12


N_LAYER = 6 # Number of transformer blocks
DROPOUT = 0.2 # 20% dropout; regularization to prevent overfitting

In [None]:
# Read the whole training corpus into memory as a single UTF-8 decoded string.
with open(INPUT_DATA_FILE, "r", encoding="utf-8") as f:
    text = f.read()


## Encoder and Decoder

In [None]:
# Vocabulary: every distinct character of the corpus, sorted so that the
# character -> id assignment is deterministic for a given text.
chars = sorted(set(text))
vocab_size = len(chars)
print("".join(chars))
print(f"Calculated Vocab Size : {vocab_size}")

In [None]:
# create a mapping from characters to integers (and the inverse mapping)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    """Encode a string as a list of integer token ids, one per character.

    Raises:
        KeyError: If ``s`` contains a character that is not in the vocabulary.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Decode an iterable of integer token ids back into a string.

    Raises:
        KeyError: If an id is not present in the vocabulary.
    """
    return ''.join(itos[i] for i in l)

## Data Preparation

In [None]:
# Keep only the first 70% of the raw text (for faster training) and encode it
# into a 1-D tensor of token ids that lives on DEVICE.
data_divider = int(0.7 * len(text))
data = torch.tensor(encode(text[:data_divider]), dtype=torch.long, device=DEVICE)

# Split this data into 90% training and 10% validation (contiguous slices, no shuffling)
split_point = int(0.9 * len(data))
train_data = data[:split_point]
val_data = data[split_point:]

## Batch Processing Data

In [None]:
def get_batch(split):
    """Sample a random batch of (input, target) windows from one data split.

    Args:
        split: ``'train'`` selects ``train_data``; any other value selects
            ``val_data``.

    Returns:
        A pair ``(x, y)`` of tensors, each of shape (BATCH_SIZE, BLOCK_SIZE).
        ``y`` is ``x`` shifted one position to the right, i.e. the next-token
        targets.
    """
    source = train_data if split == 'train' else val_data
    # Random start offsets; the upper bound leaves room for the shifted target window.
    starts = torch.randint(len(source) - BLOCK_SIZE, (BATCH_SIZE,))
    inputs = torch.stack([source[s:s + BLOCK_SIZE] for s in starts])
    targets = torch.stack([source[s + 1:s + BLOCK_SIZE + 1] for s in starts])
    return inputs, targets


## Single Attention Head

In [None]:
class Head(nn.Module):
    """A single head of causal self-attention."""

    def __init__(self, head_size):
        super().__init__()
        # Bias-free projections from the embedding width to this head's width.
        self.key = nn.Linear(N_EMBD, head_size, bias=False)
        self.query = nn.Linear(N_EMBD, head_size, bias=False)
        self.value = nn.Linear(N_EMBD, head_size, bias=False)

        # Lower-triangular matrix of ones; stored as a buffer, not a trainable parameter.
        causal = torch.ones(BLOCK_SIZE, BLOCK_SIZE).tril()
        self.register_buffer('tril', causal)

        self.dropout = nn.Dropout(DROPOUT)

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, C) and return a (B, T, head_size) tensor."""
        _, seq_len, _ = x.shape
        keys = self.key(x)        # (B, T, hs)
        queries = self.query(x)   # (B, T, hs)
        values = self.value(x)    # (B, T, hs)

        # Scaled dot-product affinities: (B, T, hs) @ (B, hs, T) -> (B, T, T)
        scale = keys.shape[-1] ** -0.5
        scores = queries @ keys.transpose(-2, -1) * scale

        # Positions where the mask is zero are future tokens; hide them before the softmax.
        hidden = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(hidden, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))

        # Weighted aggregation of the values: (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return weights @ values

## Multi-Head Attention

In [None]:
class MultiHeadAttention(nn.Module):
    """Several ``Head`` modules run in parallel, concatenated and projected.

    Args:
        num_head: Number of attention heads.
        head_size: Output width of each head.
    """

    def __init__(self, num_head, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_head)])
        # The concatenated head outputs have width num_head * head_size, which
        # equals N_EMBD only when the heads divide the embedding evenly.
        self.proj = nn.Linear(num_head * head_size, N_EMBD)
        self.dropout = nn.Dropout(DROPOUT)

    def forward(self, x):
        """Return the mixed output of all heads, shape (B, T, N_EMBD)."""
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        # Dropout was constructed but never applied before; it is applied to the
        # projected output, as CasualHeadAttention does.
        return self.dropout(self.proj(out))

## RoPE Implementation (Rotary Positional Encoding) from llama 2 implementation

Instead of adding a vector, RoPE rotates the Query and Key vectors.
Imagine the Query vector is an arrow on a 2D graph (Channel dim).
- If the token is at position 0, we rotate the arrow by 0 degrees.
- If the token is at position 1, we rotate the arrow by θ degrees.
- If the token is at position 2, we rotate the arrow by 2θ degrees.

Because the dot product depends on the angle between vectors, rotating them creates a system where the model cares about the relative distance between tokens, not their absolute position. This allows the model to generalize to longer sequences better.



### What is freqs_cis ?
- freqs: Refers to the "frequencies" or angles ($\theta$) calculated for each dimension of the hidden state.<br>
- cis: Refers to the complex exponential form $e^{i\theta}$, which, according to Euler's formula, is:<br><p>
$e^{i\theta} = \cos(\theta) + i\sin(\theta)$

`freqs_cis` is a precomputed table (matrix) of complex numbers.
It contains the rotation values for every possible position in your sequence (Time) and every pair of dimensions in your head (Head_Dim).


In [None]:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Build the table of RoPE rotation factors ``exp(i * position * frequency)``.

    Args:
        dim: Head dimension; one frequency is produced per pair of channels.
        end: Number of positions to precompute.
        theta: Base of the geometric frequency progression.

    Returns:
        A complex64 tensor of shape (end, dim // 2) whose entries all have
        magnitude 1.
    """
    half = dim // 2
    # Frequency for channel pair k is theta ** (-2k / dim).
    exponents = torch.arange(0, dim, 2)[:half].float() / dim
    inv_freq = 1.0 / (theta ** exponents)

    positions = torch.arange(end, device=inv_freq.device)
    # angles[p, k] = p * inv_freq[k]
    angles = torch.outer(positions, inv_freq).float()

    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)  # complex64

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """View ``freqs_cis`` so that it broadcasts against ``x``.

    The sequence axis is taken to be dimension 1 of ``x`` and the channel-pair
    axis its last dimension; every other axis of the result has size 1. For
    the 4-D case used by ``apply_rotary_emb`` this turns (seq_len, dim // 2)
    into (1, seq_len, 1, dim // 2).

    Args:
        freqs_cis: 2-D tensor of shape (seq_len, dim // 2).
        x: Tensor the result must broadcast with, e.g.
            (batch, seq_len, heads, dim // 2). Only its rank is used.

    Returns:
        A view of ``freqs_cis`` with the same rank as ``x``.

    Raises:
        ValueError: If ``x`` has fewer than 3 dimensions, in which case there
            is no separate sequence and channel axis to align with.
    """
    ndim = x.ndim
    if ndim < 3:
        raise ValueError(f"x must have at least 3 dimensions, got {ndim}")
    seq_len, half_dim = freqs_cis.shape
    # Size-1 axes everywhere except the sequence axis and the last axis.
    shape = [1] * ndim
    shape[1] = seq_len
    shape[-1] = half_dim
    return freqs_cis.view(*shape)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors by the position-dependent RoPE factors.

    Args:
        xq: Queries of shape (Batch, Time, Heads, Head_Dim).
        xk: Keys, laid out like ``xq``.
        freqs_cis: Complex rotation table of shape (Time, Head_Dim // 2).

    Returns:
        The rotated ``(xq, xk)``, each with the shape and dtype of its input.
    """

    def _as_complex(t: torch.Tensor) -> torch.Tensor:
        # Pair up neighbouring channels: (..., Head_Dim) -> (..., Head_Dim // 2, 2),
        # then read each pair as one complex number.
        pairs = t.float().reshape(*t.shape[:-1], -1, 2)
        return torch.view_as_complex(pairs)

    def _as_real(t: torch.Tensor) -> torch.Tensor:
        # (..., Head_Dim // 2) complex -> (..., Head_Dim // 2, 2) -> (..., Head_Dim)
        return torch.view_as_real(t).flatten(3)

    q_complex = _as_complex(xq)
    k_complex = _as_complex(xk)

    # Shape the table as (1, Time, 1, Head_Dim // 2) so it broadcasts over batch and heads.
    rotation = reshape_for_broadcast(freqs_cis, q_complex)

    # A complex multiplication by a unit-magnitude factor is the rotation itself.
    q_rotated = _as_real(q_complex * rotation)
    k_rotated = _as_real(k_complex * rotation)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)

The following is a demonstration of how tensor reshaping works in general. This operation is used in various parts of the transformer implementation, especially in the attention mechanism, so it is important to understand it well.

---

```python
import torch  
data = torch.tensor([1,2,3,4])  
# convert to floating point

data.float() # -->  output : tensor([1., 2., 3., 4.])
```
example 1:
```python
print(data.float().reshape(*data.shape[:-1],-1,2))
```

```text
<!-- Output -->
tensor([[1., 2.],
        [3., 4.]])
```

example 2:
```python
print(data.float().reshape(*data.shape[:-1],-1,1))
```
```text
<!-- Output -->
tensor([[1.],
        [2.],
        [3.],
        [4.]])
```

---


<h3>Reason why Reshaping is important</h3>

When using `reshape(*xq.shape[:-1], -1, 2)` <p> The reshape operation changes the shape of `xq` to a new shape defined by the parameters. The `*` operator unpacks the shape tuple, effectively expanding it as separate arguments. The `-1` in reshape is a placeholder that automatically calculates the appropriate size for that dimension based on the total number of elements and the other specified dimensions. The `2` at the end splits the last dimension into two dimensions, which is essential for representing complex numbers (as they consist of real and imaginary parts).

For example, if `xq` initially has the shape (32, 128, 8, 64) (32 being the batch size, 128 the sequence length, 8 the number of heads, and 64 the dimensionality of each head, i.e. the (Batch, Time, Heads, Head_Dim) layout that `apply_rotary_emb` expects), then after this reshaping it will have the shape (32, 128, 8, 32, 2). The last dimension of 64 is split into 32 pairs, and the 2 values of each pair become the real and imaginary parts of one complex number.

<h3>Reason for why we use torch.view_as_complex()</h3> 

Let's look at an example of what happens when we use `torch.view_as_complex()`.

```python
import torch
data = torch.tensor([1,2,3,4])
print(torch.view_as_complex(data.float().reshape(*data.shape[:-1],-1,2)))

# Output :
# tensor([1.+2.j, 3.+4.j])


```

<p>

Q. **What is `j` here?** <p>
A. In Python, `j` is used to denote the imaginary part of a complex number. So, `1.+2.j` represents the complex number 1 + 2i, where 1 is the real part and 2 is the imaginary part.


Q. **Why does this matter?**<p>
A. This goes back to the rotation (RoPE). A complex number is just a point on a 2D plane. Rotating a "pair of numbers" using standard matrix multiplication is computationally expensive.

<u> *Rotating a "single complex number" is just a simple multiplication.* </u>

By converting the tensor to complex mode (with j), PyTorch can rotate the embeddings much faster.

> [!NOTE] 
> The essence of RoPE is that by applying these phase shifts, the dot product (used in the self-attention mechanism) between queries and keys becomes sensitive to their relative positions. It's not merely about measuring "real angle distance" but rather about how the phase-shifted dot product correlates with the positional relationships of tokens in the input sequence.

By converting xq and xk into complex tensors, these lines prepare them for such a rotation. The actual rotation is performed later in the code, where these complex tensors (xq_ and xk_) are multiplied by another complex tensor representing the rotation (typically through phase factors, like freqs_cis in this code).

---

## Causal Self-Attention

Instead of having separate classes for `Head` and `MultiHeadAttention`, we can consolidate everything into a single class called `CasualHeadAttention` (a causal self-attention module). This class handles the projection of inputs, splitting into multiple heads, applying RoPE (Rotary Positional Encoding), and performing the attention mechanism all in one place. This approach simplifies the architecture and makes it easier to manage the attention mechanism as a whole.

The key changes include:
- **Projection**: The class includes linear layers to project the input embeddings into query, key, and value vectors. Previously we had Q, K and V as three separate layers inside every head, evaluated in a loop over the heads. Here we do one large matrix multiplication instead. For example, if your embedding size is 64, we project it to 192 (64 for Query + 64 for Key + 64 for Value). This is significantly faster on GPUs than doing 3 smaller operations.



In [None]:
class CasualHeadAttention(nn.Module):
    """Multi-head causal self-attention with optional rotary position embeddings.

    Queries, keys and values for all heads come from one fused linear layer.
    Each head attends independently under a lower-triangular (causal) mask and
    the head outputs are mixed by a final projection.

    Args:
        n_embd: Embedding width ``C`` of the input; must be a multiple of ``n_head``.
        n_head: Number of attention heads.

    Raises:
        ValueError: If ``n_embd`` is not divisible by ``n_head``.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        # Fail early: a truncated head_dim would otherwise only surface as an
        # opaque `view` size error inside forward().
        if n_embd % n_head != 0:
            raise ValueError(
                f"n_embd ({n_embd}) must be divisible by n_head ({n_head})"
            )
        self.n_embd = n_embd
        self.n_head = n_head
        self.head_dim = n_embd // n_head

        # KEY CHANGE 1 in CasualAttention: calculate Q, K, V in ONE linear layer.
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=False)

        # Output projection applied after attention; it lets the results of the
        # different heads mix before moving on to the next block.
        # Previously this was named self.proj.
        self.c_proj = nn.Linear(n_embd, n_embd, bias=False)

        self.dropout = nn.Dropout(DROPOUT)

        # Lower-triangular mask (ones in the bottom-left, zeros in the top-right),
        # shaped (1, 1, BLOCK_SIZE, BLOCK_SIZE). It is a buffer: part of the
        # module state but not a trainable parameter.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(BLOCK_SIZE, BLOCK_SIZE)).view(1, 1, BLOCK_SIZE, BLOCK_SIZE),
        )

    def forward(self, x, freqs_cis=None):
        """Apply causal self-attention.

        Args:
            x: Input of shape (B, T, C) with ``C == n_embd``.
            freqs_cis: Optional complex RoPE table with at least ``T`` rows and
                ``head_dim // 2`` columns. When ``None`` no positional rotation
                is applied.

        Returns:
            Tensor of shape (B, T, C).

        Raises:
            ValueError: If ``T`` is larger than the causal mask.
        """
        B, T, C = x.shape
        max_len = self.bias.size(-1)
        if T > max_len:
            raise ValueError(
                f"sequence length {T} exceeds the causal mask size {max_len}"
            )

        # Q, K, V for ALL heads at once: (B, T, 3 * C)
        qkv = self.c_attn(x)

        # Split into three (B, T, C) tensors, then expose the heads:
        # (B, T, C) -> (B, T, n_head, head_dim)
        q, k, v = qkv.split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, self.head_dim)
        q = q.view(B, T, self.n_head, self.head_dim)
        v = v.view(B, T, self.n_head, self.head_dim)

        # Inject position information by rotating q and k according to their
        # position in the sequence (RoPE). Values are left unrotated.
        if freqs_cis is not None:
            # Slice the table to the current sequence length T.
            freqs_cis_sliced = freqs_cis[:T]
            q, k = apply_rotary_emb(q, k, freqs_cis_sliced)

        # Matrix multiplication acts on the last two dimensions, so move the
        # head axis forward: (B, T, Heads, Dim) -> (B, Heads, T, Dim).
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Scaled dot-product scores:
        # (B, Heads, T, Dim) @ (B, Heads, Dim, T) -> (B, Heads, T, T)
        # Scaling by 1/sqrt(head_dim) keeps the numbers stable.
        att = (q @ k.transpose(-2, -1)) * (1.0 / (k.size(-1) ** 0.5))

        # Causal masking: set masked-out (future) positions to -inf so the
        # softmax gives them zero weight.
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))

        att = F.softmax(att, dim=-1)
        att = self.dropout(att)

        # Weighted sum of the values:
        # (B, Heads, T, T) @ (B, Heads, T, Dim) -> (B, Heads, T, Dim)
        y = att @ v

        # Reassemble the heads: back to (B, T, Heads, Dim), then flatten the
        # last two dimensions into C. `contiguous()` is needed because `view`
        # requires contiguous memory after the transpose.
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # Final linear layer mixes the information found by the different heads.
        return self.dropout(self.c_proj(y))

## FFN

In [None]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network: expand, GELU, project back, dropout."""

    def __init__(self, n_embd):
        super().__init__()

        # The hidden (inner) layer is 4x wider than the embedding: the input
        # enters at width n_embd, is expanded to 4 * n_embd, and is projected
        # back to n_embd so it can be added to the residual pathway.
        hidden = 4 * n_embd

        expand = nn.Linear(n_embd, hidden)
        activation = nn.GELU()
        # Projection back onto the residual pathway.
        project = nn.Linear(hidden, n_embd)
        # Dropout is applied right before the result rejoins the residual pathway.
        regularise = nn.Dropout(DROPOUT)

        self.net = nn.Sequential(expand, activation, project, regularise)

    def forward(self, x):
        """Apply the network to ``x``; the last dimension stays ``n_embd``."""
        return self.net(x)

## RMS Normalization

Standard Layer Normalization re-centers data around zero (subtracting the mean) and then scales it (dividing by the standard deviation). RMSNorm assumes the mean is already close to zero and skips the subtraction. It only scales the input based on the Root Mean Square (RMS).

For our problem we will use RMSNorm as :
$\bar{x}_i = \frac{x_i}{\text{RMS}(x)} \cdot g_i$

Where : $\text{RMS}(x) = \sqrt{\frac{1}{n} \sum_{i=1}^{n} x_i^2 + \epsilon}$

----
<mark>So basically we calculate the mean of the squared values of the input (plus a small $\epsilon$), take the square root, then multiply the input by the reciprocal of that value and by the learned gain $g$.</mark>

---

In [None]:
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain.

    Args:
        dim: Size of the last (feature) dimension.
        eps: Small constant added to the mean square before the reciprocal
            square root, so an all-zero input does not divide by zero.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Starting the gain at one leaves the normalised output unscaled; the
        # gain is a trainable parameter.
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Mean of squares over the last (feature / channel) dimension.
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x):
        # Normalise in float32 for numerical stability, then cast back to the
        # input's type before applying the gain.
        normalised = self._norm(x.float())
        return normalised.type_as(x) * self.weight


## Transformer Block

In [None]:
class Block(nn.Module):
    """Pre-norm transformer block: causal self-attention, then a feed-forward network.

    Each sub-layer reads an RMS-normalised copy of the residual stream and its
    output is added back onto that stream.

    Args:
        n_embd: Width of the residual stream (embedding dimension).
        n_head: Number of attention heads passed to ``CasualHeadAttention``.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        self.sa = CasualHeadAttention(n_embd, n_head)
        self.fwd = FeedForward(n_embd)

        # Use this notebook's own RMSNorm, the same class the language model
        # uses for its final norm, rather than nn.RMSNorm. This keeps eps
        # identical everywhere and does not depend on a PyTorch release that
        # ships nn.RMSNorm.
        self.ln1 = RMSNorm(n_embd)
        self.ln2 = RMSNorm(n_embd)

    def forward(self, x, freqs_cis):
        """Run the block on ``x``.

        Args:
            x: Residual stream; its last dimension is ``n_embd``.
            freqs_cis: RoPE rotation table forwarded to the attention layer.

        Returns:
            The updated residual stream, same shape as ``x``.
        """
        # `x + ...` is the residual connection: each sub-layer forks from x and
        # its output is added back. Normalisation is applied to the sub-layer
        # input only.
        x = x + self.sa(self.ln1(x), freqs_cis)
        x = x + self.fwd(self.ln2(x))
        return x

## Bigram Language Model

In [None]:
import torch.nn as nn
from torch.nn import functional as F

class BigramLanguageModel(nn.Module):
    """Transformer language model over token ids, using RoPE for positions.

    Despite the name, the model stacks ``N_LAYER`` transformer blocks between
    a token-embedding table and a linear output head.

    Args:
        vocab_size: Number of distinct tokens; sizes the embedding table and
            the output layer.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, N_EMBD)

        # NOTE: there is no position embedding table; RoPE injects position
        # information inside the attention layers.

        # Precompute the RoPE rotation table for the maximum block size.
        head_dim = N_EMBD // NUM_HEAD
        freqs_cis = precompute_freqs_cis(head_dim, BLOCK_SIZE)

        # Registered as a buffer so it is saved with the model and moves with
        # it between devices.
        self.register_buffer('freqs_cis', freqs_cis)

        # A ModuleList rather than nn.Sequential: every block takes two
        # arguments (x, freqs_cis), so forward() loops over the blocks itself.
        # The sub-module names ("blocks.0", "blocks.1", ...) are unchanged.
        self.blocks = nn.ModuleList(
            [Block(N_EMBD, n_head=NUM_HEAD) for _ in range(N_LAYER)]
        )
        self.ln_f = RMSNorm(N_EMBD)  # Final normalization
        self.lm_head = nn.Linear(N_EMBD, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: Token ids of shape (B, T).
            targets: Optional token ids of shape (B, T).

        Returns:
            ``(logits, loss)``. Without ``targets``, ``logits`` has shape
            (B, T, vocab_size) and ``loss`` is ``None``. With ``targets``,
            ``logits`` is flattened to (B*T, vocab_size) and ``loss`` is the
            cross-entropy between the flattened logits and targets.
        """
        # Token embeddings only (no positional embeddings are added here).
        x = self.token_embedding_table(idx)

        # Every block needs the RoPE table as well as x.
        for block in self.blocks:
            x = block(x, self.freqs_cis)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        if targets is None:
            return logits, None

        # F.cross_entropy wants (N, C) logits and (N,) targets.
        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        targets = targets.view(B * T)
        loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Sample ``max_new_tokens`` tokens autoregressively.

        Sampling runs without gradient tracking and with the model temporarily
        in eval mode so that dropout is disabled; the previous train/eval mode
        is restored afterwards.

        Args:
            idx: (B, T) tensor of token ids forming the current context.
            max_new_tokens: Number of tokens to append.

        Returns:
            (B, T + max_new_tokens) tensor: the context followed by the samples.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Crop the context to the last BLOCK_SIZE tokens.
                idx_cond = idx[:, -BLOCK_SIZE:]
                logits, _ = self(idx_cond)
                # Only the last time step predicts the next token: (B, C)
                logits = logits[:, -1, :]
                probs = F.softmax(logits, dim=-1)  # (B, C)
                # Sample from the distribution rather than taking the argmax.
                idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
                # Append the sampled index to the running sequence.
                idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        finally:
            self.train(was_training)
        return idx
    
# Build the model and move its parameters and buffers to DEVICE.
model = BigramLanguageModel(vocab_size)
m = model.to(DEVICE)

## Loss estimation

In [None]:
@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss on the train and validation splits.

    The model is switched to eval mode while ``EVAL_ITERS`` random batches per
    split are scored, and switched back to train mode before returning.

    Returns:
        A dict with keys ``'train'`` and ``'val'``, each holding a scalar
        tensor with the mean loss for that split.
    """
    averages = {}
    model.eval()
    for split in ('train', 'val'):
        per_batch = torch.zeros(EVAL_ITERS)
        for step in range(EVAL_ITERS):
            inputs, targets = get_batch(split)
            _, batch_loss = model(inputs, targets)
            per_batch[step] = batch_loss.item()
        averages[split] = per_batch.mean()
    model.train()
    return averages

## optimizer

In [None]:
# AdamW over all model parameters, with the learning rate fixed at LEARNING_RATE.
optimizer = torch.optim.AdamW(m.parameters(), lr=LEARNING_RATE)

## Training Loop

In [None]:
# `step` rather than `iter`, so the builtin iter() is not shadowed.
for step in range(MAX_ITERS):

    # every once in a while (and on the final step) evaluate the loss on train and val sets
    if step % EVAL_INTERVAL == 0 or step == MAX_ITERS - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss, then take one optimisation step
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print("\nTraining loop completed successfully!")

# generate from the model, starting from a single token with id 0
context = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)
print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))

In [None]:
test = "Can you share it with me and"

# Encode the prompt character by character and add a batch dimension: (1, T).
# The previous torch.zeros(encode([test])) was wrong twice: encode([test])
# looks up the whole string as a single character, and torch.zeros expects a
# shape, not token ids.
# NOTE: encode raises KeyError if the prompt contains a character that is not
# in the training vocabulary.
context = torch.tensor([encode(test)], dtype=torch.long, device=DEVICE)
print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))


In [None]:
import torch
# Scratch tensor for the reshape / view_as_complex demo cells.
# NOTE: this rebinds the module-level name `data` that the data-preparation cell also assigns.
data = torch.tensor([1,2,3,4])

In [None]:
# Convert the integer tensor to floating point -> tensor([1., 2., 3., 4.])
data.float()

In [None]:
# Group the last dimension into pairs: shape (4,) -> (2, 2), i.e. [[1., 2.], [3., 4.]]
data.float().reshape(*data.shape[:-1],-1,2)

In [None]:
# Read each (real, imag) pair as one complex number -> tensor([1.+2.j, 3.+4.j])
torch.view_as_complex(data.float().reshape(*data.shape[:-1],-1,2))