# tt-NN Embedding Layer Example

This notebook shows how you can create an embedding layer out of `ttnn` tensors.

Let's assume we are developing a GPT-2-style LLM. We need to choose a `vocab_size` (50257, the size of the GPT-2 tokenizer's vocabulary) and an `output_dim` (the length of each embedding vector).

In [16]:
vocab_size = 50257
output_dim = 256

Next, let's import some dependencies:

In [17]:
import torch
import ttnn
from scripts.prepare_data import create_dataloader_v1

## Data Preparation

Let's build a simple dataset, starting with some raw text. We will use `the-verdict.txt`, downloading it if it isn't already present.

In [18]:
import os
import urllib.request

file_path = "data/the-verdict.txt"

if not os.path.exists(file_path):
    # urlretrieve does not create missing directories, so make data/ first.
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    url = ("https://raw.githubusercontent.com/rasbt/"
           "LLMs-from-scratch/main/ch02/01_main-chapter-code/"
           "the-verdict.txt")
    urllib.request.urlretrieve(url, file_path)

with open(file_path, "r", encoding="utf-8") as f:
    raw_text = f.read()

print(raw_text[:50])

I HAD always thought Jack Gisburn rather a cheap g


Next, let's create a dataloader so that we can draw batches of token IDs. We'll use a context length of 4 and a batch size of 8. Setting `stride` equal to the context length makes consecutive windows non-overlapping.

In [19]:
context_length = 4
batch_size = 8

dataloader = create_dataloader_v1(
    raw_text, batch_size=batch_size, max_length=context_length,
    stride=context_length, shuffle=False
)
data_iter = iter(dataloader)
inputs, targets = next(data_iter)
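Each `targets` row is the matching `inputs` row shifted one token to the right, which is the next-token prediction setup. As a minimal sketch, assuming `create_dataloader_v1` uses the same sliding-window approach as the *LLMs-from-scratch* repository the text file comes from (its source is not shown in this notebook), the windowing could look like this. `sliding_windows` is a hypothetical helper, and the token IDs are the first two rows of the dataloader output printed later in the notebook.

```python
def sliding_windows(token_ids, max_length, stride):
    """Split a flat list of token IDs into (input, target) windows.

    Hypothetical helper: each target window is its input window shifted
    one position to the right (next-token prediction).
    """
    inputs, targets = [], []
    # Stop early enough that the shifted target window still fits.
    for start in range(0, len(token_ids) - max_length, stride):
        inputs.append(token_ids[start:start + max_length])
        targets.append(token_ids[start + 1:start + max_length + 1])
    return inputs, targets


# Token IDs from the first two rows of the notebook's dataloader output.
demo_ids = [40, 367, 2885, 1464, 1807, 3619, 402, 271, 10899]
demo_inputs, demo_targets = sliding_windows(demo_ids, max_length=4, stride=4)
print(demo_inputs)   # [[40, 367, 2885, 1464], [1807, 3619, 402, 271]]
print(demo_targets)  # [[367, 2885, 1464, 1807], [3619, 402, 271, 10899]]
```

With `stride == max_length`, the second input window starts exactly where the first one ended, so no token appears in two input windows.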

## Torch Example

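First, in `torch`, we typically create input embeddings by adding together the outputs of two embedding layers. The token embedding layer receives the input batch of token IDs, and the positional embedding layer receives the positions `0, 1, ..., context_length - 1` (via `torch.arange`). The resulting `(context_length, output_dim)` positional embeddings are broadcast across the batch when added to the `(batch_size, context_length, output_dim)` token embeddings. It is pretty simple.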

In [20]:
token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)
token_embeddings = token_embedding_layer(inputs)

positional_embedding_layer = torch.nn.Embedding(context_length, output_dim)
positional_embeddings = positional_embedding_layer(torch.arange(context_length))

input_embeddings = token_embeddings + positional_embeddings

print(input_embeddings[0:2])
print(input_embeddings.shape)

tensor([[[-0.2732,  2.1288,  0.8971,  ..., -1.2919, -0.0776,  2.9252],
         [-1.3840,  3.2146, -2.3263,  ...,  2.0142, -1.3516,  0.9443],
         [ 2.5248,  1.8677, -1.3374,  ...,  0.3811,  2.8022,  0.7201],
         [-0.1077, -0.5230,  0.9847,  ..., -0.7862, -0.1137, -1.3520]],

        [[-3.3425,  1.7041,  1.2159,  ..., -0.4244, -1.1343,  0.3792],
         [-2.1216,  0.7098, -0.6508,  ...,  0.8105, -2.8036,  0.3345],
         [ 1.7248,  1.1010, -0.3695,  ..., -1.2099,  1.5292,  1.6327],
         [ 0.7812, -0.2927,  1.2999,  ..., -1.6942,  0.6392, -0.6615]]],
       grad_fn=<SliceBackward0>)
torch.Size([8, 4, 256])
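The addition `token_embeddings + positional_embeddings` works because PyTorch broadcasts the `(4, 256)` positional tensor across the batch dimension of the `(8, 4, 256)` token tensor: every sequence in the batch gets the same positional vector added at each position. Here is a plain-Python sketch of that broadcast, using a hypothetical `add_positional` helper and toy integer sizes rather than the notebook's tensors:

```python
def add_positional(token_embeddings, positional_embeddings):
    """Add a (context_length, dim) table to every sequence of a
    (batch, context_length, dim) batch, mimicking broadcasting."""
    return [
        [
            [tok + pos for tok, pos in zip(token_vec, pos_vec)]
            for token_vec, pos_vec in zip(sequence, positional_embeddings)
        ]
        for sequence in token_embeddings
    ]


# Toy sizes: batch=2, context_length=3, dim=2.
toy_token_emb = [
    [[1, 1], [2, 2], [3, 3]],
    [[10, 10], [20, 20], [30, 30]],
]
toy_pos_emb = [[100, 200], [300, 400], [500, 600]]

toy_input_emb = add_positional(toy_token_emb, toy_pos_emb)
print(toy_input_emb[0])  # [[101, 201], [302, 402], [503, 603]]
print(toy_input_emb[1])  # [[110, 210], [320, 420], [530, 630]]
```

Both sequences receive the identical positional offsets, which is exactly what the `ttnn` version below has to reproduce by hand.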


## tt-NN Example

Unfortunately, life isn't as easy with `ttnn`, but we can get there. We'll create the token embeddings and the positional embeddings one at a time, then combine them into the input embeddings.

Several operations require the tensors to be on the device, so let's open it first.

In [21]:
device_id = 0 
device = ttnn.open_device(device_id=device_id)

                  Metal | INFO     | Initializing device 0. Program cache is NOT enabled
                  Metal | INFO     | AI CLK for device 0 is:   1000 MHz


In [22]:
inputs_ttnn = ttnn.from_torch(inputs, dtype=ttnn.uint32)
targets_ttnn = ttnn.from_torch(targets, dtype=ttnn.uint32)

inputs_ttnn = ttnn.to_device(inputs_ttnn, device)
targets_ttnn = ttnn.to_device(targets_ttnn, device)

inputs_ttnn, targets_ttnn

(ttnn.Tensor([[   40,   367,  ...,  2885,  1464],
              [ 1807,  3619,  ...,   402,   271],
              ...,
              [ 1049,  5975,  ...,   284,   502],
              [  284,  3285,  ...,   326,    11]], shape=Shape([8, 4]), dtype=DataType::UINT32, layout=Layout::ROW_MAJOR),
 ttnn.Tensor([[  367,  2885,  ...,  1464,  1807],
              [ 3619,   402,  ...,   271, 10899],
              ...,
              [ 5975,   284,  ...,   502,   284],
              [ 3285,   326,  ...,    11,   287]], shape=Shape([8, 4]), dtype=DataType::UINT32, layout=Layout::ROW_MAJOR))

Creating the token embeddings is more involved. We first need to initialize a weight matrix ourselves, with one row per vocabulary entry and one column per output dimension.

Its shape is `(vocab_size, output_dim)`, i.e. (50257, 256).
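For a sense of scale, the token embedding table alone holds `vocab_size × output_dim` values. Assuming 2 bytes per `bfloat16` element and ignoring any padding or layout overhead on the device, the arithmetic is:

```python
vocab_size = 50257
output_dim = 256
bytes_per_bfloat16 = 2  # bfloat16 stores each value in 16 bits

num_weights = vocab_size * output_dim
num_bytes = num_weights * bytes_per_bfloat16

print(f"{num_weights:,} weights")     # 12,865,792 weights
print(f"{num_bytes / 2**20:.2f} MiB")  # 24.54 MiB
```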

In [23]:
token_embedding_weights_ttnn = ttnn.from_torch(
    torch.randn(vocab_size, output_dim),
    dtype=ttnn.bfloat16
)
token_embedding_weights_ttnn = ttnn.to_device(token_embedding_weights_ttnn, device)

token_embedding_weights_ttnn

ttnn.Tensor([[-1.49219, -0.12695,  ..., -0.68750, -1.35156],
             [-2.32812,  0.56641,  ..., -1.00781,  0.49219],
             ...,
             [ 0.22852, -0.76562,  ...,  0.34766, -0.01544],
             [-0.14844,  1.01562,  ..., -1.02344, -0.09717]], shape=Shape([50257, 256]), dtype=DataType::BFLOAT16, layout=Layout::ROW_MAJOR)

Now we can create the token embeddings in one shot with `ttnn.embedding`, which looks up the row of the weight matrix for each token ID.

In [24]:
token_embeddings_ttnn = ttnn.embedding(inputs_ttnn, token_embedding_weights_ttnn)
token_embeddings_ttnn

ttnn.Tensor([[[-2.10938,  0.69531,  ..., -0.66406,  0.23926],
              [ 1.18750,  1.09375,  ...,  0.46484, -0.04858],
              ...,
              [-0.54297,  0.03662,  ...,  0.90625, -1.48438],
              [ 0.15820,  1.18750,  ..., -0.15430, -0.79688]],

             [[-2.28125, -0.88672,  ...,  0.57422, -0.40820],
              [-0.38672, -2.31250,  ..., -1.14844,  0.68359],
              ...,
              [-0.44727, -1.00781,  ...,  0.21191,  0.01501],
              [ 1.12500,  2.42188,  ..., -2.37500, -1.69531]],

             ...,

             [[-0.61328, -1.60156,  ..., -0.27930,  1.07031],
              [-0.42773,  0.78906,  ...,  0.52734, -1.40625],
              ...,
              [ 2.29688, -0.69141,  ...,  1.17969, -0.17969],
              [-0.69531, -1.21875,  ...,  0.91406, -0.88672]],

             [[ 2.29688, -0.69141,  ...,  1.17969, -0.17969],
              [-0.97266, -0.61328,  ...,  0.57422,  0.76953],
              ...,
              [ 0.07910,  1.101
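Conceptually, both `torch.nn.Embedding` and `ttnn.embedding` are table lookups: each token ID selects the matching row of the weight matrix, so an `(8, 4)` batch of IDs against a `(50257, 256)` table yields an `(8, 4, 256)` tensor. The sketch below illustrates only these semantics, not how either library implements the lookup on hardware; `embedding_lookup` is a hypothetical helper and the table is a toy one.

```python
def embedding_lookup(token_ids, weights):
    """Gather one row of `weights` per token ID.

    token_ids: nested list of shape (batch, seq_len) of integer IDs.
    weights:   nested list of shape (vocab_size, dim).
    Returns a nested list of shape (batch, seq_len, dim).
    """
    return [[list(weights[token_id]) for token_id in row] for row in token_ids]


# Toy vocabulary of 5 tokens with 3-dimensional embeddings: row i is [10i, 10i+1, 10i+2].
toy_weights = [[float(10 * i + j) for j in range(3)] for i in range(5)]
toy_ids = [[4, 0], [2, 2]]

looked_up = embedding_lookup(toy_ids, toy_weights)
print(looked_up[0])  # [[40.0, 41.0, 42.0], [0.0, 1.0, 2.0]]
print(looked_up[1])  # [[20.0, 21.0, 22.0], [20.0, 21.0, 22.0]]
```

A repeated ID always returns the same row. That is visible in the output above: token ID 284 appears in the last two input sequences, and both occurrences map to the same vector starting `[ 2.29688, -0.69141, ...]`.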

We can repeat the same process for the positional embeddings.

First, we need some positional inputs: a simple tensor holding the positions 0 through `context_length - 1`.

In [25]:
positional_inputs_ttnn = ttnn.arange(end=context_length, dtype=ttnn.uint32)
positional_inputs_ttnn = ttnn.to_device(positional_inputs_ttnn, device)

positional_inputs_ttnn

ttnn.Tensor([    0,     1,  ...,     2,     3], shape=Shape([4]), dtype=DataType::UINT32, layout=Layout::ROW_MAJOR)

Now we can create the positional embedding weights, with shape `(context_length, output_dim)`.

In [26]:
positional_embeddings_weights = ttnn.from_torch(
    torch.randn(context_length, output_dim),
    dtype=ttnn.bfloat16
)
positional_embeddings_weights = ttnn.to_device(positional_embeddings_weights, device)

Next, we create the positional embeddings.

In [27]:
positional_embeddings_ttnn = ttnn.embedding(positional_inputs_ttnn, positional_embeddings_weights)
positional_embeddings_ttnn

ttnn.Tensor([[ 0.01794, -1.33594,  ...,  1.28125,  0.78125],
             [-0.25195,  0.05713,  ...,  1.07031,  0.60938],
             ...,
             [ 0.79688, -1.25000,  ...,  0.17578, -1.56250],
             [-0.19629,  0.48828,  ..., -0.88672, -1.84375]], shape=Shape([4, 256]), dtype=DataType::BFLOAT16, layout=Layout::ROW_MAJOR)
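Because the positional inputs are simply `0, 1, ..., context_length - 1`, looking them up returns every row of the `(context_length, output_dim)` weight table in order, i.e. the table itself. That is why the result above has shape `(4, 256)`, the same as the weights. A plain-Python sketch with a hypothetical `lookup_rows` helper and toy 2-dimensional weights:

```python
def lookup_rows(indices, weights):
    """Return the row of `weights` selected by each index, in order."""
    return [list(weights[i]) for i in indices]


toy_context_length = 4
# Toy (context_length, 2) positional weight table: row p is [p, -p].
toy_pos_weights = [[float(p), float(-p)] for p in range(toy_context_length)]

toy_positions = list(range(toy_context_length))  # [0, 1, 2, 3], like the arange above
toy_pos_embeddings = lookup_rows(toy_positions, toy_pos_weights)

print(toy_pos_embeddings == toy_pos_weights)  # True: the lookup returns the table itself
```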

We're not quite done with `positional_embeddings_ttnn` yet: we have to reshape it for the upcoming addition. This involves:
1. Reshaping `positional_embeddings_ttnn` to have the same number of dimensions as `token_embeddings_ttnn`.
2. Using `repeat_interleave` along dimension 0 to copy it `batch_size` times, so that it matches `token_embeddings_ttnn` element for element when the two are added. This is effectively a manual broadcast.

We expect to turn the (4, 256) tensor into an (8, 4, 256) tensor.

In [28]:
positional_embeddings_ttnn = ttnn.reshape(positional_embeddings_ttnn, (1, context_length, output_dim))
positional_embeddings_ttnn = ttnn.repeat_interleave(positional_embeddings_ttnn, repeats=batch_size, dim=0)
positional_embeddings_ttnn

ttnn.Tensor([[[ 0.01794, -1.33594,  ...,  1.28125,  0.78125],
              [-0.25195,  0.05713,  ...,  1.07031,  0.60938],
              ...,
              [ 0.79688, -1.25000,  ...,  0.17578, -1.56250],
              [-0.19629,  0.48828,  ..., -0.88672, -1.84375]],

             [[ 0.01794, -1.33594,  ...,  1.28125,  0.78125],
              [-0.25195,  0.05713,  ...,  1.07031,  0.60938],
              ...,
              [ 0.79688, -1.25000,  ...,  0.17578, -1.56250],
              [-0.19629,  0.48828,  ..., -0.88672, -1.84375]],

             ...,

             [[ 0.01794, -1.33594,  ...,  1.28125,  0.78125],
              [-0.25195,  0.05713,  ...,  1.07031,  0.60938],
              ...,
              [ 0.79688, -1.25000,  ...,  0.17578, -1.56250],
              [-0.19629,  0.48828,  ..., -0.88672, -1.84375]],

             [[ 0.01794, -1.33594,  ...,  1.28125,  0.78125],
              [-0.25195,  0.05713,  ...,  1.07031,  0.60938],
              ...,
              [ 0.79688, -1.250
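`repeat_interleave` repeats each element along a dimension consecutively, whereas a plain repeat tiles the whole sequence. Since the reshaped positional tensor has size 1 along dimension 0, both give the same `(batch_size, context_length, output_dim)` result here. A plain-Python sketch of the two behaviors (the helpers are hypothetical stand-ins, not `ttnn` functions):

```python
def repeat_interleave(items, repeats):
    """Repeat each element `repeats` times in place: [a, b] -> [a, a, b, b]."""
    return [item for item in items for _ in range(repeats)]


def repeat(items, repeats):
    """Tile the whole sequence `repeats` times: [a, b] -> [a, b, a, b]."""
    return [item for _ in range(repeats) for item in items]


# With more than one element along the dimension, the two differ...
print(repeat_interleave(["a", "b"], 2))  # ['a', 'a', 'b', 'b']
print(repeat(["a", "b"], 2))             # ['a', 'b', 'a', 'b']

# ...but along a size-1 dimension, such as dim 0 of the reshaped
# (1, context_length, output_dim) tensor, they give the same result.
pos_table = [[0.1, 0.2], [0.3, 0.4]]      # stands in for one (context_length, dim) slice
batched = repeat_interleave([pos_table], 3)  # shape (3, 2, 2); entries alias pos_table
print(batched == repeat([pos_table], 3))     # True
```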

We can now compute the input embeddings by adding `token_embeddings_ttnn` and `positional_embeddings_ttnn`. Both are first converted to tile layout with `ttnn.tilize`.

In [30]:
input_embeddings_ttnn = ttnn.add(
    ttnn.tilize(token_embeddings_ttnn),
    ttnn.tilize(positional_embeddings_ttnn)
)
input_embeddings_ttnn

ttnn.Tensor([[[-2.09375, -0.64062,  ...,  0.61719,  1.02344],
              [ 0.93750,  1.14844,  ...,  1.53906,  0.56250],
              ...,
              [ 0.25391, -1.21094,  ...,  1.08594, -3.04688],
              [-0.03809,  1.67969,  ..., -1.03906, -2.64062]],

             [[-1.64062,  0.16406,  ..., -3.67188,  1.45312],
              [ 1.88281, -1.03906,  ..., -0.11719, -0.57812],
              ...,
              [ 0.64062, -0.12109,  ..., -0.50781, -1.46094],
              [ 0.51953, -0.50391,  ..., -0.61328,  2.45312]],

             ...,

             [[ 0.00000,  0.00000,  ..., -2.56250, -0.89844],
              [-4.28125,  1.70312,  ..., -2.46875,  1.97656],
              ...,
              [ 0.00000,  0.00000,  ..., -2.28125, -1.56250],
              [ 1.82812, -0.74219,  ...,  0.60938, -4.09375]],

             [[ 0.00000, -0.00000,  ..., 14637248544768.00000, 14637248544768.00000],
              [-0.00000, -0.00000,  ..., 29274497089536.00000, 14637248544768.00000],
  

There's a lot of padding inserted, which is why you will see extreme values at the end of the tensors. We can untilize the result to convert it back to row-major layout.
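Tile layout stores data in fixed-size tiles, so a dimension that is not a multiple of the tile size gets padded. Assuming the usual 32 × 32 tiles, the `(8, 4, 256)` embeddings need their second-to-last dimension padded from 4 to 32, so 7/8 of every tile is padding. A small sketch of that arithmetic, with a hypothetical `padded_shape` helper:

```python
import math

TILE_SIZE = 32  # assumed tile height and width


def round_up(n, multiple):
    """Smallest multiple of `multiple` that is >= n."""
    return ((n + multiple - 1) // multiple) * multiple


def padded_shape(shape, tile_size=TILE_SIZE):
    """Pad the last two dimensions of `shape` up to whole tiles."""
    *leading, height, width = shape
    return (*leading, round_up(height, tile_size), round_up(width, tile_size))


logical_shape = (8, 4, 256)  # (batch_size, context_length, output_dim)
tiled_shape = padded_shape(logical_shape)

print(tiled_shape)                                         # (8, 32, 256)
print(math.prod(tiled_shape) // math.prod(logical_shape))  # 8
```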

In [31]:
input_embeddings_ttnn = ttnn.untilize(input_embeddings_ttnn)
input_embeddings_ttnn



ttnn.Tensor([[[-2.09375, -0.64062,  ...,  0.61719,  1.02344],
              [ 0.93750,  1.14844,  ...,  1.53906,  0.56250],
              ...,
              [ 0.25391, -1.21094,  ...,  1.08594, -3.04688],
              [-0.03809,  1.67969,  ..., -1.03906, -2.64062]],

             [[-2.26562, -2.21875,  ...,  1.85938,  0.37305],
              [-0.64062, -2.25000,  ..., -0.07812,  1.29688],
              ...,
              [ 0.34961, -2.26562,  ...,  0.38867, -1.54688],
              [ 0.92969,  2.90625,  ..., -3.26562, -3.54688]],

             ...,

             [[-0.59375, -2.93750,  ...,  1.00000,  1.85156],
              [-0.67969,  0.84766,  ...,  1.60156, -0.79688],
              ...,
              [ 3.09375, -1.94531,  ...,  1.35938, -1.74219],
              [-0.89062, -0.73047,  ...,  0.02734, -2.73438]],

             [[ 2.31250, -2.03125,  ...,  2.46875,  0.60156],
              [-1.22656, -0.55469,  ...,  1.64844,  1.38281],
              ...,
              [ 0.87500, -0.148

Let's do a sanity check. We expect the same (8, 4, 256) shape as in the `torch` example.

That is a batch of 8 sequences, each 4 tokens long, with every token represented by a 256-dimensional vector. A larger embedding dimension gives more room to record "detail" about each token.

In [32]:
print(input_embeddings_ttnn[0:2])
print(input_embeddings_ttnn.shape)

ttnn.Tensor([[[-2.09375, -0.64062,  ...,  0.61719,  1.02344],
              [ 0.93750,  1.14844,  ...,  1.53906,  0.56250],
              ...,
              [ 0.25391, -1.21094,  ...,  1.08594, -3.04688],
              [-0.03809,  1.67969,  ..., -1.03906, -2.64062]],

             [[-1.64062,  0.16406,  ..., -3.67188,  1.45312],
              [ 1.88281, -1.03906,  ..., -0.11719, -0.57812],
              ...,
              [ 0.64062, -0.12109,  ..., -0.50781, -1.46094],
              [ 0.51953, -0.50391,  ..., -0.61328,  2.45312]]], shape=Shape([2, 4, 256]), dtype=DataType::BFLOAT16, layout=Layout::ROW_MAJOR)
Shape([8, 4, 256])


Finally, don't forget to clean up.

In [33]:
ttnn.close_device(device)

                  Metal | INFO     | Closing device 0
                  Metal | INFO     | Disabling and clearing program cache on device 0
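If a cell raises an error between opening and closing the device, `ttnn.close_device` is never reached. One way to guarantee cleanup is a small context manager built on `try`/`finally`. This is a minimal sketch: `opened_device` is a hypothetical helper, not part of `ttnn`, and it takes the open/close functions as arguments so it has no hard dependency on the library.

```python
from contextlib import contextmanager


@contextmanager
def opened_device(open_device, close_device, device_id=0):
    """Open a device and guarantee it is closed, even if the body raises.

    open_device / close_device are passed in (for example ttnn.open_device
    and ttnn.close_device), so this helper does not import ttnn itself.
    """
    device = open_device(device_id=device_id)
    try:
        yield device
    finally:
        # Runs on normal exit and when an exception propagates.
        close_device(device)
```

In this notebook it would be used as `with opened_device(ttnn.open_device, ttnn.close_device) as device:`, with the device work placed inside the `with` body.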
