# Embedding

## Summary

The **Embedding Layer** serves as the entry point of the Transformer, <br>
acting as a lookup table that translates discrete, categorical data (integers) into a continuous, <br>
high-dimensional space where the model can perform "meaning-based" mathematics.


### Implementation

In [1]:
import math
import torch
from jaxtyping import Float
import torch.nn as nn
from torch.nn.parameter import Parameter

class Embedding(nn.Module):

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        device: torch.device | None=None,
        dtype: torch.dtype | None=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim

        self.weight = Parameter(
            torch.empty((num_embeddings, embedding_dim), **factory_kwargs)
        )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        nn.init.trunc_normal_(self.weight)
    
    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        return self.weight[token_ids]

## Step-by-Step

### Definition of the `Embedding` class

In [None]:
class Embedding(nn.Module):

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        device: torch.device | None=None,
        dtype: torch.dtype | None=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim

        self.weight = Parameter(
            torch.empty((num_embeddings, embedding_dim), **factory_kwargs)
        )
        self.reset_parameters()

### Initialization

**`trunc_normal_`** initialization

Same as `Linear`. See `linear_layer.ipynb` for details.

In [None]:
def reset_parameters(self) -> None:
    nn.init.trunc_normal_(self.weight)
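
As a quick sanity check of what the call above does when no arguments are passed, the sketch below fills a standalone tensor (a stand-in for `self.weight`, with an arbitrary shape chosen for illustration) and inspects it. With PyTorch's defaults (`mean=0.0`, `std=1.0`, `a=-2.0`, `b=2.0`), every value should fall inside `[-2, 2]`, and the sample standard deviation comes out a little below 1 because the tails are cut off.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for self.weight; the (1000, 64) shape is arbitrary.
weight = torch.empty(1000, 64)

# Defaults: mean=0.0, std=1.0, a=-2.0, b=2.0 (a and b are absolute cutoffs).
nn.init.trunc_normal_(weight)

print("min :", weight.min().item())
print("max :", weight.max().item())
print("mean:", weight.mean().item())
print("std :", weight.std().item())
```

Note that `a` and `b` are absolute bounds, not multiples of `std`, so changing `std` without also changing the bounds changes how much of the distribution is truncated.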

### `forward`

In [None]:
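def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
    # Use the integer IDs to index rows of the (num_embeddings, embedding_dim) table.
    # The output shape is token_ids.shape + (embedding_dim,).
    return self.weight[token_ids]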
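
To make the "lookup table" idea concrete, here is a minimal sketch comparing three ways of getting the same result: direct indexing (as in `forward`), multiplying one-hot vectors by the weight matrix, and PyTorch's `F.embedding`. The `weight` tensor and the batched `token_ids` are illustrative stand-ins, not values from the notebook.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, d_model = 10, 4
weight = torch.randn(vocab_size, d_model)  # stand-in for self.weight

# A batch of 2 sequences with 3 tokens each (arbitrary valid IDs).
token_ids = torch.tensor([[9, 0, 2], [7, 0, 7]])

# 1. Indexing, as in forward().
by_index = weight[token_ids]  # shape (2, 3, 4)

# 2. One-hot vectors times the weight matrix: each one-hot row selects one row of weight.
one_hot = F.one_hot(token_ids, num_classes=vocab_size).to(weight.dtype)  # (2, 3, 10)
by_matmul = one_hot @ weight  # shape (2, 3, 4)

# 3. PyTorch's functional embedding lookup.
by_functional = F.embedding(token_ids, weight)

print(by_index.shape)
print(torch.allclose(by_index, by_matmul))
print(torch.equal(by_index, by_functional))
```

Indexing avoids building the `(2, 3, 10)` one-hot tensor, which matters once the vocabulary has tens of thousands of entries.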

### Example

In [None]:
import torch

# 1. Setup dimensions
vocab_size = 10  
# {0: b' ', 1: b'a', 2: b'c', 3: b'e', 4: b'h', 5: b't', 6: b'r', 7: b'at', 8: b'th', 9: b'the'}
d_model = 4      

# 2. Initialize the module
# This creates a weight matrix of shape (10, 4)
embedding_layer = Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

# 3. Define your encoded input (the token IDs)
token_ids = torch.tensor([9, 0, 2, 7, 0, 7, 3, 9, 0, 6, 7])

# 4. Run the forward pass
output = embedding_layer(token_ids)

# 5. Check the results
print(f"Input Shape:  {token_ids.shape}")
print("Input Tensor (Token IDs):")
print(token_ids)

print(f"\nWeight Shape (Vocab Size x Embedding Dim): {embedding_layer.weight.shape}")
print("Weight Tensor (The Lookup Table):")
print(embedding_layer.weight)

print(f"\nOutput Shape (Sequence Length x Embedding Dim): {output.shape}")
print("Output Tensor (The Looked-up Vectors):")
print(output)

Input Shape:  torch.Size([11])
Input Tensor (Token IDs):
tensor([9, 0, 2, 7, 0, 7, 3, 9, 0, 6, 7])

Weight Shape (Vocab Size x Embedding Dim): torch.Size([10, 4])
Weight Tensor (The Lookup Table):
Parameter containing:
tensor([[-0.4136,  0.1851, -0.7171, -0.7127],
        [-1.0888,  0.2085, -1.3409, -0.0901],
        [ 1.5916,  0.6708, -0.2307, -0.5711],
        [-0.2284,  0.2545, -0.6289,  0.9571],
        [ 0.2058,  0.0875, -0.0067,  0.4083],
        [-1.0441, -1.1364, -1.2963,  0.3038],
        [ 0.9817, -0.9500, -0.8207,  1.0706],
        [ 0.3841, -0.4254, -1.2210,  1.8807],
        [ 0.4021, -1.1038, -1.1837,  0.3225],
        [ 0.9744, -0.6042, -0.2462, -0.3535]], requires_grad=True)

Output Shape (Sequence Length x Embedding Dim): torch.Size([11, 4])
Output Tensor (The Looked-up Vectors):
tensor([[ 0.9744, -0.6042, -0.2462, -0.3535],
        [-0.4136,  0.1851, -0.7171, -0.7127],
        [ 1.5916,  0.6708, -0.2307, -0.5711],
        [ 0.3841, -0.4254, -1.2210,  1.8807],
        

## Optimize Embedding Weight

The embedding weights are optimized through **Backpropagation**. <br>
Initially, the vectors in the weight matrix are random noise. <br>
Over time, the model "nudges" these numbers so that words used in similar contexts end up with similar vectors. <br>
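
One common way to put a number on "similar vectors" is cosine similarity. The sketch below assumes that metric (the text above does not name one) and uses a random, untrained stand-in for the embedding table, so the values it prints carry no meaning yet. After training, rows for tokens used in similar contexts would be expected to score closer to 1.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
weight = torch.randn(10, 4)  # untrained stand-in for the embedding table

# Cosine similarity between the rows for token IDs 7 and 9.
sim = F.cosine_similarity(weight[7], weight[9], dim=0)

# Pairwise similarity of every row with every other row.
normed = F.normalize(weight, dim=1)  # scale each row to unit length
sim_matrix = normed @ normed.T       # shape (10, 10); the diagonal is 1

print(sim.item())
print(sim_matrix.shape)
```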

Here is the overview:

---

**1. The Forward Pass (The Guess)**

The model takes the input `the cat ate the ...`, converts its tokens to vectors using the current weights, and predicts the next token ID.

* **Input:** `[9, 0, 2, 7, 0, 7, 3, 9, 0]` ("the cat ate the ")
* **Target:** `6` (ID for "r" in "rat")
* **Model Prediction:** It might guess ID `1` ("a") with high confidence.

**2. The Loss Calculation & Backpropagation**

We compare the model's guess to the actual target using a **Loss Function** (usually Cross-Entropy).

* If the model guessed "a" but the answer was "r", the "Loss" is high.
* The gradient travels backward through the network until it hits the **Embedding Layer**.
* Since the embedding layer is just a lookup table, the optimization is very specific: **only the rows (vectors) used in that sentence get updated.**
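* e.g.: Token ID `9` ("the") appears in the input, so row `9` of the weight matrix receives a gradient. Token ID `1` ("a") does not appear in the input, so row `1` of the embedding matrix gets a zero gradient from this example, even though the model predicted it.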
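
The row-wise behaviour described above can be checked with a small sketch. The "model" here is a deliberately simplified stand-in (average the token vectors, then apply a hypothetical projection `proj` to get vocabulary logits), not the Transformer itself. It only serves to produce a cross-entropy loss whose gradient flows back into the embedding table.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, d_model = 10, 4

# Stand-ins: an embedding table and a hypothetical output projection.
weight = torch.randn(vocab_size, d_model, requires_grad=True)
proj = torch.randn(d_model, vocab_size, requires_grad=True)

token_ids = torch.tensor([9, 0, 2, 7, 0, 7, 3, 9, 0])  # "the cat ate the "
target = torch.tensor([6])                              # "r"

# Toy forward pass: average the token vectors, then project to logits.
hidden = weight[token_ids].mean(dim=0, keepdim=True)    # (1, d_model)
logits = hidden @ proj                                  # (1, vocab_size)
loss = F.cross_entropy(logits, target)
loss.backward()

# Rows of the embedding table whose gradient is non-zero.
updated_rows = weight.grad.abs().sum(dim=1).nonzero().flatten().tolist()
print("loss:", loss.item())
print("rows with gradient:", updated_rows)
```

Only the IDs present in `token_ids` show up in `updated_rows`. A token that appears several times (ID `0` appears three times here) has the gradient contributions from each occurrence summed into its row.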

The update for a specific weight follows this logic:

$$
W_{\text{new}} = W_{\text{old}} - \eta \nabla L
$$

*  $\eta$: The **Learning Rate** (how big of a step we take).
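*  $\nabla L$: The gradient of the loss $L$ with respect to $W$. It points in the direction that increases the loss, which is why it is subtracted.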
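
A minimal sketch of one such update, applied by hand, is shown below. The loss is a placeholder (the sum of squares of the looked-up vectors), chosen only so the gradient is easy to verify, and `lr = 0.1` is an arbitrary value for $\eta$. Real training would normally use an optimizer such as `torch.optim.SGD` or Adam instead of this manual step.

```python
import torch

torch.manual_seed(0)
weight = torch.randn(10, 4, requires_grad=True)  # W_old
token_ids = torch.tensor([9, 0, 2])
lr = 0.1                                         # eta (arbitrary for illustration)

# Placeholder loss: sum of squares of the looked-up vectors, so grad = 2 * W on used rows.
loss = weight[token_ids].pow(2).sum()
loss.backward()

w_old = weight.detach().clone()
with torch.no_grad():
    weight -= lr * weight.grad                   # W_new = W_old - eta * grad

# Rows that actually changed.
changed = (weight.detach() != w_old).any(dim=1).nonzero().flatten().tolist()
print(changed)
```

Because the gradient is zero for every row that was not looked up, the subtraction leaves those rows exactly as they were.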
