#### Understanding the Difference Between Embedding Layers and Linear Layers

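###### -> Embedding layers in PyTorch accomplish the same thing as linear layers that perform matrix multiplications on one-hot encoded inputs; we use embedding layers instead for computational efficiency, since a row look-up avoids building one-hot vectors and multiplying by mostly zeros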
###### -> We will take a look at this relationship step by step using code examples in PyTorch
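
###### To make the efficiency argument concrete, the sketch below counts the work each route does. The sizes are assumptions chosen for illustration and are not from this notebook: a GPT-2-sized vocabulary of 50,257 tokens, a 1,024-token input, and 768-dimensional embeddings.

```python
# Hypothetical sizes for illustration (assumed, not from the notebook)
vocab_size = 50257   # GPT-2-sized vocabulary
num_tokens = 1024    # tokens in one input
emb_dim = 768        # embedding dimension

# One-hot route: build a (num_tokens x vocab_size) matrix, then multiply it
# by a (vocab_size x emb_dim) weight matrix
onehot_entries = num_tokens * vocab_size
matmul_mults = num_tokens * vocab_size * emb_dim

# Look-up route: copy one row of emb_dim values per token
lookup_copies = num_tokens * emb_dim

print(f"One-hot matrix entries:    {onehot_entries:,}")
print(f"Multiplications (one-hot): {matmul_mults:,}")
print(f"Values copied (look-up):   {lookup_copies:,}")
print(f"Ratio:                     {matmul_mults // lookup_copies:,}x")
```

The ratio is exactly the vocabulary size: each one-hot row has `vocab_size` entries, of which all but one contribute zeros to the product.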

In [None]:
import torch

print("PyTorch version:", torch.__version__)

#### Using nn.Embedding

In [None]:
# Suppose we have the following 3 training examples,
# which may represent token IDs in an LLM context
idx = torch.tensor([2, 3, 1])

# The number of rows in the embedding matrix is the largest token ID + 1:
# if the highest token ID is 3, we want 4 rows, one for each of the
# possible token IDs 0, 1, 2, 3
# (.item() turns the 0-dim tensor returned by max() into a plain Python int)
num_idx = idx.max().item() + 1

# The desired embedding dimension is a hyperparameter
out_dim = 5
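
###### Why the row count must be at least the largest ID + 1: a minimal sketch, assuming the layer runs on CPU, showing that an ID with no matching row is rejected. On CPU, PyTorch raises an `IndexError` in this case (on GPU the failure surfaces differently, as a device-side error).

```python
import torch

torch.manual_seed(123)
idx = torch.tensor([2, 3, 1])
num_idx = idx.max().item() + 1  # 4 rows cover the IDs 0, 1, 2, 3
embedding = torch.nn.Embedding(num_idx, 5)

# Every ID in idx has a matching row, so this look-up succeeds
print(embedding(idx).shape)  # torch.Size([3, 5])

# An ID equal to num_idx has no matching row; on CPU this raises IndexError
try:
    embedding(torch.tensor([num_idx]))
    raised = False
except IndexError as err:
    raised = True
    print("IndexError:", err)
```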

###### Let's instantiate a simple embedding layer:

In [None]:
# We set the random seed for reproducibility, since the
# weights in the embedding layer are initialized with
# random values (drawn from a standard normal distribution)
torch.manual_seed(123)

embedding = torch.nn.Embedding(num_idx, out_dim)

###### We can optionally take a look at the embedding weights:

In [None]:
embedding.weight
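
###### The weight matrix has one row per token ID and one column per embedding dimension. The sketch below reproduces the sizes used above (4 IDs, 5 dimensions) and checks the shape and that the weights are a trainable `nn.Parameter`.

```python
import torch

torch.manual_seed(123)
num_idx, out_dim = 4, 5
embedding = torch.nn.Embedding(num_idx, out_dim)

# One row per token ID, one column per embedding dimension
print(embedding.weight.shape)  # torch.Size([4, 5])

# The weights are a trainable nn.Parameter, so training updates them
print(isinstance(embedding.weight, torch.nn.Parameter), embedding.weight.requires_grad)
```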

###### We can then use the embedding layer to obtain the vector representation of a training example with ID 1:

In [None]:
embedding(torch.tensor([1]))

###### Under the hood, this is a simple look-up: the layer returns row 1 of its weight matrix, `embedding.weight`.
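
###### The look-up can be checked directly: for every valid ID, calling the layer gives exactly the corresponding row of the weight matrix. This sketch re-creates a 4 x 5 embedding layer so that it runs on its own.

```python
import torch

torch.manual_seed(123)
embedding = torch.nn.Embedding(4, 5)

# For every valid ID, calling the layer returns exactly the matching row of the
# weight matrix (unsqueeze(0) adds the leading dimension that the layer output has)
all_match = all(
    torch.equal(embedding(torch.tensor([i])), embedding.weight[i].unsqueeze(0))
    for i in range(4)
)
print(all_match)  # True
```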

###### Similarly, we can use the embedding layer to obtain the vector representation of a training example with ID 2:

In [None]:
embedding(torch.tensor([2]))

###### Now, let's convert all the training examples we have defined previously:

In [None]:
idx = torch.tensor([2, 3, 1])
embedding(idx)

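###### Under the hood, this is still the same look-up: each ID selects the corresponding row of the weight matrix, and the rows are returned in the order of the input IDs (2, 3, 1).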
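
###### A short sketch of the batched look-up: indexing the weight matrix with a tensor of IDs gathers the same rows as calling the layer. The second input is a hypothetical batch of two sequences, added to show that the layer also accepts 2D inputs and appends the embedding dimension.

```python
import torch

torch.manual_seed(123)
embedding = torch.nn.Embedding(4, 5)

idx = torch.tensor([2, 3, 1])
# Advanced indexing gathers rows 2, 3, 1 of the weight matrix, in that order
same_rows = torch.equal(embedding(idx), embedding.weight[idx])
print(same_rows)  # True

# Hypothetical batch of 2 sequences with 3 token IDs each:
# shape (batch, seq_len) -> (batch, seq_len, out_dim)
batch = torch.tensor([[2, 3, 1],
                      [0, 0, 3]])
batch_out = embedding(batch)
print(batch_out.shape)  # torch.Size([2, 3, 5])
```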

#### Using nn.Linear
###### Now, we will demonstrate that the embedding layer above accomplishes exactly the same as an nn.Linear layer applied to a one-hot encoded representation in PyTorch
###### First, let's convert the token IDs into a one-hot representation:

In [None]:
onehot = torch.nn.functional.one_hot(idx)
onehot
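
###### Two details of `one_hot` matter here. Without `num_classes`, the width is inferred as max(idx) + 1, and the result is an integer tensor, which is why it is converted with `.float()` before the matrix multiplication further below. The width of 6 in this sketch is a hypothetical vocabulary size, used only to show the `num_classes` argument.

```python
import torch
import torch.nn.functional as F

idx = torch.tensor([2, 3, 1])

# Without num_classes, one_hot infers max(idx) + 1 = 4 columns
onehot = F.one_hot(idx)
print(onehot)

# Passing num_classes fixes the width to the vocabulary size, even when the
# current batch lacks the largest ID (hypothetical vocabulary of 6 here)
onehot_wide = F.one_hot(idx, num_classes=6)
print(onehot.shape, onehot_wide.shape)  # torch.Size([3, 4]) torch.Size([3, 6])

# The result is int64, not float
print(onehot.dtype)  # torch.int64
```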

###### Next, we initialize a Linear layer, which carries out a matrix multiplication:

In [None]:
torch.manual_seed(123)
linear = torch.nn.Linear(num_idx, out_dim, bias=False)
linear.weight
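
###### As a sketch of what "carries out a matrix multiplication" means here: with `bias=False`, `nn.Linear` computes x @ W^T, where W is stored with shape (out_features, in_features). The random input `x` below is a stand-in, not data from the notebook.

```python
import torch

torch.manual_seed(123)
linear = torch.nn.Linear(4, 5, bias=False)

# nn.Linear stores its weight as (out_features, in_features)
print(linear.weight.shape)  # torch.Size([5, 4])

# Stand-in input: 3 rows with 4 features each
x = torch.randn(3, 4)

# With bias=False, the forward pass is the matrix product x @ W^T
matches = torch.allclose(linear(x), x @ linear.weight.T)
print(matches)  # True
```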

In [None]:
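# nn.Linear also starts from its own random weights; to compare it directly
# with the embedding layer, we give it the embedding weights instead.
# nn.Linear stores its weight as (out_features, in_features) = (5, 4), while
# embedding.weight has shape (num_idx, out_dim) = (4, 5), hence the transpose .T
linear.weight = torch.nn.Parameter(embedding.weight.T)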
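
###### An alternative way to share the weights, which is not what the cell above does: copy the transposed embedding weights in place into the linear layer's existing parameter under `torch.no_grad()`. The sketch re-creates both layers so that it runs on its own.

```python
import torch

torch.manual_seed(123)
embedding = torch.nn.Embedding(4, 5)
linear = torch.nn.Linear(4, 5, bias=False)

# embedding.weight is (4, 5); linear.weight is (5, 4), so a transpose is needed
print(embedding.weight.shape, linear.weight.shape)

# Copy the values in place into linear's existing weight tensor;
# no_grad() keeps this copy out of the autograd graph
with torch.no_grad():
    linear.weight.copy_(embedding.weight.T)

print(torch.equal(linear.weight.T, embedding.weight))  # True
```

Either approach leaves the two layers holding the same numbers; the in-place copy keeps the parameter object that `nn.Linear` created.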

###### Now we can use the linear layer on the one-hot encoded representation of the inputs:

In [None]:
linear(onehot.float())

###### As we can see, this is exactly the same as what we got when we used the embedding layer:

In [None]:
embedding(idx)
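
###### Instead of comparing the two printed tensors by eye, the equivalence can be checked programmatically. This self-contained sketch also compares gradients, which goes beyond what the notebook shows: using `sum()` as a stand-in loss, the linear layer's weight gradient is the transpose of the embedding layer's weight gradient, so both layers would receive the same updates in training.

```python
import torch

torch.manual_seed(123)
idx = torch.tensor([2, 3, 1])
num_idx, out_dim = idx.max().item() + 1, 5

embedding = torch.nn.Embedding(num_idx, out_dim)
linear = torch.nn.Linear(num_idx, out_dim, bias=False)
with torch.no_grad():
    linear.weight.copy_(embedding.weight.T)  # same weights, transposed layout

onehot = torch.nn.functional.one_hot(idx, num_classes=num_idx).float()

# Forward passes agree
same_output = torch.allclose(embedding(idx), linear(onehot))

# Gradients agree too (sum() is only a stand-in loss for this check)
embedding(idx).sum().backward()
linear(onehot).sum().backward()
same_grad = torch.allclose(embedding.weight.grad, linear.weight.grad.T)

print(same_output, same_grad)  # True True
```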