# Embeddings Example with TTNN

First define some variables that we will be using:

In [1]:
# GPT-2 
vocab_size = 50257
output_dim = 256

Import `torch` and `ttnn` libraries. If `ttnn` is succesffully imported, you'll see some logging related to `Config`.

In [2]:
import torch
import ttnn

2025-04-28 16:07:12.596 | DEBUG    | ttnn:<module>:83 - Initial ttnn.CONFIG:
Config{cache_path=/home/avgdev/.cache/ttnn,model_cache_path=/home/avgdev/.cache/ttnn/models,tmp_dir=/tmp/ttnn,enable_model_cache=false,enable_fast_runtime_mode=true,throw_exception_on_fallback=false,enable_logging=false,enable_graph_report=false,enable_detailed_buffer_report=false,enable_detailed_tensor_report=false,enable_comparison_mode=false,comparison_mode_should_raise_exception=false,comparison_mode_pcc=0.9999,root_report_path=generated/ttnn/reports,report_name=std::nullopt,std::nullopt}


## Prepare Dataset
We will use the short story `the-verdict.txt` as our sample dataset. Download the text to `data/the-verdict.txt`

In [3]:
import os
import urllib.request

if not os.path.exists("data/the-verdict.txt"):
    url = ("https://raw.githubusercontent.com/rasbt/"
           "LLMs-from-scratch/main/ch02/01_main-chapter-code/"
           "the-verdict.txt")
    
    file_path = "data/the-verdict.txt"
    urllib.request.urlretrieve(url, file_path)

In [4]:
with open("data/the-verdict.txt", "r", encoding="utf-8") as f:
    raw_text = f.read()

print(raw_text[:50])

I HAD always thought Jack Gisburn rather a cheap g


Define the `context_length` and `batch_size` to be used throughout the notebook. The `batch_size` is the number of tensors that we get total for each iteration of our dataloader.

In [5]:
context_length = 4
batch_size = 8

`pip` install `tiktoken` if not already in the system.

In [6]:
!pip install tiktoken

Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cpu
Collecting tiktoken
  Using cached tiktoken-0.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Collecting regex>=2022.1.18 (from tiktoken)
  Using cached regex-2024.11.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)
Using cached tiktoken-0.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
Using cached regex-2024.11.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (781 kB)
Installing collected packages: regex, tiktoken
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2/2[0m [tiktoken]
[1A[2KSuccessfully installed regex-2024.11.6 tiktoken-0.9.0


## Creating Custom Dataloader

This dataloader will store our inputs and targets tensors. They are still in `torch` form. We can initialize the dataloader using the `create_dataloader` function call

In [8]:
import torch
from torch.utils.data import Dataset, DataLoader
import tiktoken

class GPTDataset(Dataset):
    def __init__(self, txt, tokenizer, max_length, stride):
        self._input_ids = []
        self._target_ids = []

        token_ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})

        for i in range(0, len(token_ids) - max_length, stride):
            input_chunk = token_ids[i:i + max_length]
            target_chunk = token_ids[i + 1: i + max_length + 1]

            self._input_ids.append(torch.tensor(input_chunk))
            self._target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        return len(self._input_ids)

    def __getitem__(self, idx):
        return self._input_ids[idx], self._target_ids[idx]
    
def create_dataloader(txt, batch_size=4, max_length=256, stride=128, shuffle=True, drop_last=True):
    tokenizer = tiktoken.get_encoding("gpt2")
    dataset = GPTDataset(txt, tokenizer, max_length, stride)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=0
    )

    return dataloader


In [9]:
dataloader = create_dataloader(
    raw_text, 
    batch_size=batch_size, 
    max_length=context_length, 
    stride=context_length,
    shuffle=False
)

data_iter = iter(dataloader)
inputs, targets = next(data_iter)

inputs, targets

(tensor([[   40,   367,  2885,  1464],
         [ 1807,  3619,   402,   271],
         [10899,  2138,   257,  7026],
         [15632,   438,  2016,   257],
         [  922,  5891,  1576,   438],
         [  568,   340,   373,   645],
         [ 1049,  5975,   284,   502],
         [  284,  3285,   326,    11]]),
 tensor([[  367,  2885,  1464,  1807],
         [ 3619,   402,   271, 10899],
         [ 2138,   257,  7026, 15632],
         [  438,  2016,   257,   922],
         [ 5891,  1576,   438,   568],
         [  340,   373,   645,  1049],
         [ 5975,   284,   502,   284],
         [ 3285,   326,    11,   287]]))

# Torch Embeddings Example

It is easier to do this exercise first using `torch` and then we'll adapt and rewrite it using `ttnn` to the best of our ability.

In [12]:
token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)
token_embeddings = token_embedding_layer(inputs)
token_embeddings, token_embeddings.shape

(tensor([[[-0.8662, -0.0436,  1.5156,  ...,  0.2127, -3.2786, -0.5743],
          [ 1.3781, -0.1197,  0.0313,  ...,  0.2775,  0.3691,  0.2817],
          [-0.4382, -2.2696, -0.1110,  ..., -0.2127, -1.9321, -0.6311],
          [ 1.1630,  0.4681, -1.2171,  ...,  0.8214,  0.1959,  0.8996]],
 
         [[-0.4888,  0.3827, -1.6942,  ...,  1.2904,  0.5649,  0.6273],
          [ 0.1420, -0.7241, -1.1617,  ..., -1.3439,  1.6568, -0.6777],
          [-0.8134,  0.1044,  0.1628,  ...,  0.7499, -0.0285, -0.6549],
          [-0.1758,  0.5710,  0.9514,  ...,  4.1751,  1.1144, -2.0448]],
 
         [[-0.8470,  0.5591, -0.4546,  ..., -0.1664, -0.5267, -0.0662],
          [ 0.2358, -0.0514,  0.6138,  ...,  0.5978,  0.0912,  0.6538],
          [-1.0551, -0.2106,  1.7852,  ..., -2.0223,  1.3278, -1.5084],
          [-1.0741,  1.2759,  0.2565,  ...,  1.2733, -1.2093,  0.8373]],
 
         ...,
 
         [[ 0.4378,  1.2705,  0.2923,  ..., -0.5806, -0.2155,  0.4700],
          [ 2.2364, -1.3076, -0.6452,  

In [None]:
positional_embedding_layer = torch.nn.Embedding(context_length, output_dim)
positional_embeddings = positional_embedding_layer(torch.arange(context_length))

positional_embeddings, positional_embeddings.shape

(tensor([[ 0.6888, -0.6568, -0.8298,  ..., -0.3943,  1.1660, -0.2200],
         [ 2.4855,  1.4199, -0.4613,  ..., -0.4492, -0.6024,  0.7097],
         [ 1.2576, -1.4132,  0.4963,  ..., -3.4569,  1.3010,  0.8833],
         [-1.2176, -0.0290,  0.2775,  ...,  0.5771,  1.4045,  0.8579]],
        grad_fn=<EmbeddingBackward0>),
 torch.Size([4, 256]))

In [16]:
input_embeddings = token_embeddings + positional_embeddings
input_embeddings, input_embeddings.shape

(tensor([[[-1.7740e-01, -7.0034e-01,  6.8585e-01,  ..., -1.8157e-01,
           -2.1127e+00, -7.9434e-01],
          [ 3.8636e+00,  1.3002e+00, -4.3003e-01,  ..., -1.7162e-01,
           -2.3325e-01,  9.9138e-01],
          [ 8.1939e-01, -3.6828e+00,  3.8526e-01,  ..., -3.6697e+00,
           -6.3107e-01,  2.5222e-01],
          [-5.4543e-02,  4.3906e-01, -9.3961e-01,  ...,  1.3985e+00,
            1.6004e+00,  1.7574e+00]],
 
         [[ 1.9999e-01, -2.7411e-01, -2.5240e+00,  ...,  8.9612e-01,
            1.7308e+00,  4.0726e-01],
          [ 2.6275e+00,  6.9581e-01, -1.6231e+00,  ..., -1.7930e+00,
            1.0544e+00,  3.1978e-02],
          [ 4.4418e-01, -1.3088e+00,  6.5906e-01,  ..., -2.7070e+00,
            1.2725e+00,  2.2835e-01],
          [-1.3934e+00,  5.4199e-01,  1.2289e+00,  ...,  4.7522e+00,
            2.5189e+00, -1.1869e+00]],
 
         [[-1.5829e-01, -9.7678e-02, -1.2844e+00,  ..., -5.6076e-01,
            6.3923e-01, -2.8620e-01],
          [ 2.7213e+00,  1.3685

# ttnn Example

Let's rewrite everything using `ttnn` but first let's open a device by using `open_device` and storing the handle.

In [17]:
device_id = 0
device = ttnn.open_device(device_id=device_id)


                 Device | INFO     | Opening user mode device driver
[32m2025-04-28 16:37:58.737[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Opened PCI device 0; KMD version: 1.33.0, IOMMU: disabled

[32m2025-04-28 16:37:58.752[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Opened PCI device 0; KMD version: 1.33.0, IOMMU: disabled
[32m2025-04-28 16:37:58.754[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Harvesting mask for chip 0 is 0x200 (physical layout: 0x1, logical: 0x200, simulated harvesting mask: 0x0).
[32m2025-04-28 16:37:58.755[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Opened PCI device 0; KMD version: 1.33.0, IOMMU: disabled
[32m2025-04-28 16:37:58.756[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Detected PCI devices: [0]
[32m2025-04-28 16:37:58.756[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Using local chip ids: 

New chip! We now have 1 chips
Chip initialization complete (found )
Chip initializing complete...
 ARC

 [4/4] DRAM

 [16/16] ETH

 CPU

Chip detection complete (found )


In [18]:
inputs_ttnn = ttnn.from_torch(inputs, dtype=ttnn.uint32)
targets_ttnn = ttnn.from_torch(targets, dtype=ttnn.uint32)

inputs_ttnn, targets_ttnn

(ttnn.Tensor([[   40,   367,  ...,  2885,  1464],
              [ 1807,  3619,  ...,   402,   271],
              ...,
              [ 1049,  5975,  ...,   284,   502],
              [  284,  3285,  ...,   326,    11]], shape=Shape([8, 4]), dtype=DataType::UINT32, layout=Layout::ROW_MAJOR),
 ttnn.Tensor([[  367,  2885,  ...,  1464,  1807],
              [ 3619,   402,  ...,   271, 10899],
              ...,
              [ 5975,   284,  ...,   502,   284],
              [ 3285,   326,  ...,    11,   287]], shape=Shape([8, 4]), dtype=DataType::UINT32, layout=Layout::ROW_MAJOR))

In [19]:
inputs_ttnn = ttnn.to_device(inputs_ttnn, device)
targets_ttnn = ttnn.to_device(targets_ttnn, device)

inputs_ttnn, targets_ttnn

(ttnn.Tensor([[   40,   367,  ...,  2885,  1464],
              [ 1807,  3619,  ...,   402,   271],
              ...,
              [ 1049,  5975,  ...,   284,   502],
              [  284,  3285,  ...,   326,    11]], shape=Shape([8, 4]), dtype=DataType::UINT32, layout=Layout::ROW_MAJOR),
 ttnn.Tensor([[  367,  2885,  ...,  1464,  1807],
              [ 3619,   402,  ...,   271, 10899],
              ...,
              [ 5975,   284,  ...,   502,   284],
              [ 3285,   326,  ...,    11,   287]], shape=Shape([8, 4]), dtype=DataType::UINT32, layout=Layout::ROW_MAJOR))

In [20]:
token_embedding_weights_ttnn = ttnn.from_torch(
    torch.randn(vocab_size, output_dim),
    dtype=ttnn.bfloat16
)
token_embedding_weights_ttnn = ttnn.to_device(token_embedding_weights_ttnn, device)

token_embedding_weights_ttnn

ttnn.Tensor([[-0.53906, -0.62500,  ...,  0.75781, -0.21094],
             [ 0.85547, -0.80859,  ..., -0.12158, -0.64062],
             ...,
             [ 0.49414,  1.69531,  ..., -1.25000,  0.45898],
             [-0.90234, -0.91016,  ..., -0.05908, -0.94141]], shape=Shape([50257, 256]), dtype=DataType::BFLOAT16, layout=Layout::ROW_MAJOR)

In [21]:
token_embeddings_ttnn = ttnn.embedding(inputs_ttnn, token_embedding_weights_ttnn)
token_embeddings_ttnn

ttnn.Tensor([[[ 0.37500, -0.37305,  ..., -0.50391, -0.78516],
              [ 1.14062, -0.59766,  ...,  0.87500, -0.83984],
              ...,
              [-0.67578,  0.89453,  ..., -0.59766,  2.20312],
              [ 0.42969,  0.42578,  ...,  0.72266,  1.18750]],

             [[-0.78906,  0.70312,  ..., -0.07568,  0.79688],
              [ 0.36914, -1.42969,  ..., -2.81250,  0.01221],
              ...,
              [ 1.07031, -0.26562,  ...,  0.52734, -0.57422],
              [-0.13281,  0.92188,  ..., -0.99219, -1.99219]],

             ...,

             [[ 0.45703, -1.40625,  ...,  0.71484, -2.64062],
              [ 0.33594,  0.25391,  ..., -0.14551,  0.82422],
              ...,
              [ 0.71484, -0.55859,  ..., -0.55078, -0.31836],
              [-0.82422, -0.77734,  ..., -0.17090, -0.05322]],

             [[ 0.71484, -0.55859,  ..., -0.55078, -0.31836],
              [ 0.29297, -0.41211,  ..., -0.04321,  0.75000],
              ...,
              [-0.03540, -0.503

In [22]:
positional_inputs_ttnn = ttnn.arange(end=context_length, dtype=ttnn.uint32)
positional_inputs_ttnn = ttnn.to_device(positional_inputs_ttnn, device)

positional_inputs_ttnn

ttnn.Tensor([    0,     1,  ...,     2,     3], shape=Shape([4]), dtype=DataType::UINT32, layout=Layout::ROW_MAJOR)

In [24]:
positional_embeddings_weights_ttnn = ttnn.from_torch(
    torch.randn(context_length, output_dim),
    dtype=ttnn.bfloat16
)
positional_embeddings_weights_ttnn = ttnn.to_device(positional_embeddings_weights_ttnn, device)

positional_embeddings_weights_ttnn

ttnn.Tensor([[ 0.50781,  0.34961,  ...,  1.02344, -0.26953],
             [-0.48438, -0.66797,  ..., -1.00781, -0.40820],
             ...,
             [-0.49023, -0.19727,  ...,  0.14355, -1.26562],
             [ 1.04688,  1.87500,  ..., -0.72656, -0.57422]], shape=Shape([4, 256]), dtype=DataType::BFLOAT16, layout=Layout::ROW_MAJOR)

In [25]:
positional_embeddings_ttnn = ttnn.embedding(positional_inputs_ttnn, positional_embeddings_weights_ttnn)
positional_embeddings_ttnn

ttnn.Tensor([[ 0.50781,  0.34961,  ...,  1.02344, -0.26953],
             [-0.48438, -0.66797,  ..., -1.00781, -0.40820],
             ...,
             [-0.49023, -0.19727,  ...,  0.14355, -1.26562],
             [ 1.04688,  1.87500,  ..., -0.72656, -0.57422]], shape=Shape([4, 256]), dtype=DataType::BFLOAT16, layout=Layout::ROW_MAJOR)

In [26]:
# token embeddings -> [8, 4, 256]
# position embeddings -> [4, 256]

positional_embeddings_ttnn = ttnn.reshape(positional_embeddings_ttnn, (1, context_length, output_dim))
positional_embeddings_ttnn = ttnn.repeat_interleave(positional_embeddings_ttnn, repeats=batch_size, dim=0)
positional_embeddings_ttnn

ttnn.Tensor([[[ 0.50781,  0.34961,  ...,  1.02344, -0.26953],
              [-0.48438, -0.66797,  ..., -1.00781, -0.40820],
              ...,
              [-0.49023, -0.19727,  ...,  0.14355, -1.26562],
              [ 1.04688,  1.87500,  ..., -0.72656, -0.57422]],

             [[ 0.50781,  0.34961,  ...,  1.02344, -0.26953],
              [-0.48438, -0.66797,  ..., -1.00781, -0.40820],
              ...,
              [-0.49023, -0.19727,  ...,  0.14355, -1.26562],
              [ 1.04688,  1.87500,  ..., -0.72656, -0.57422]],

             ...,

             [[ 0.50781,  0.34961,  ...,  1.02344, -0.26953],
              [-0.48438, -0.66797,  ..., -1.00781, -0.40820],
              ...,
              [-0.49023, -0.19727,  ...,  0.14355, -1.26562],
              [ 1.04688,  1.87500,  ..., -0.72656, -0.57422]],

             [[ 0.50781,  0.34961,  ...,  1.02344, -0.26953],
              [-0.48438, -0.66797,  ..., -1.00781, -0.40820],
              ...,
              [-0.49023, -0.197

In [27]:
token_embeddings_ttnn = ttnn.to_layout(token_embeddings_ttnn, layout=ttnn.TILE_LAYOUT, device=device)
positional_embeddings_ttnn = ttnn.to_layout(positional_embeddings_ttnn, layout=ttnn.TILE_LAYOUT, device=device)

input_embeddings_ttnn = ttnn.add(
    token_embeddings_ttnn,
    positional_embeddings_ttnn    
)

input_embeddings_ttnn



ttnn.Tensor([[[ 0.88281, -0.02344,  ...,  0.51953, -1.05469],
              [ 0.65625, -1.26562,  ..., -0.13281, -1.25000],
              ...,
              [-1.16406,  0.69922,  ..., -0.45508,  0.93750],
              [ 1.47656,  2.29688,  ..., -0.00391,  0.61328]],

             [[-0.28125,  1.05469,  ...,  0.94922,  0.52734],
              [-0.11523, -2.09375,  ..., -3.82812, -0.39648],
              ...,
              [ 0.58203, -0.46289,  ...,  0.67188, -1.84375],
              [ 0.91406,  2.79688,  ..., -1.71875, -2.56250]],

             ...,

             [[ 0.96484, -1.05469,  ...,  1.74219, -2.90625],
              [-0.14844, -0.41406,  ..., -1.15625,  0.41602],
              ...,
              [ 0.22461, -0.75781,  ..., -0.40820, -1.58594],
              [ 0.22266,  1.10156,  ..., -0.89844, -0.62891]],

             [[ 1.22656, -0.20898,  ...,  0.47266, -0.58984],
              [-0.19141, -1.07812,  ..., -1.05469,  0.34180],
              ...,
              [-0.52734, -0.703

In [28]:
ttnn.close_device(device)

                  Metal | INFO     | Closing device 0
                  Metal | INFO     | Disabling and clearing program cache on device 0
