In [5]:
# pip install tokenizers safetensors

import os, sys, math, random, textwrap
import numpy as np
import matplotlib.pyplot as plt
from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional
from transformers import GPT2Tokenizer
from yaml import safe_load, Loader

sys.path.append(f"{os.environ['TT_METAL_HOME']}/tt-train/build/sources/ttml")
import _ttml as ttml

def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)

set_seed()
# Change working directory to TT_METAL_HOME
os.chdir(os.environ['TT_METAL_HOME'])

@dataclass
class TransformerConfig:
    n_head: int = 12
    embed_dim: int = 768
    dropout: float = 0.2
    n_blocks : int = 12
    vocab_size: int = 96
    max_seq_len: int = 1024
    runner_type: str = "memory_efficient"
    weight_tying: str = "enabled"

In [6]:
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
print(os.getcwd())
transformer_cfg = safe_load(open("tt-train/configs/training_shakespeare_gpt2s.yaml", "r"))["training_config"]["transformer_config"]

/home/ubuntu/tt-metal


In [23]:
def build_causal_mask(T: int) -> ttml.autograd.Tensor:
    # [1,1,T,T] float32 with 1s for allowed positions (i >= j), else 0
    m = np.tril(np.ones((T, T), dtype=np.float32))
    return ttml.autograd.Tensor.from_numpy(m.reshape(1, 1, T, T), ttml.Layout.TILE, ttml.autograd.DataType.FLOAT32)



In [13]:
def create_model(cfg, vocab_size: int, seq_len: int):
    # GPT2 config via your bindings
    gcfg = ttml.models.gpt2.GPT2TransformerConfig()
    gcfg.num_heads = cfg["num_heads"]
    gcfg.embedding_dim = cfg["embedding_dim"]
    gcfg.num_blocks = cfg["num_blocks"]
    gcfg.vocab_size = int(vocab_size)
    gcfg.max_sequence_length = seq_len
    gcfg.dropout_prob = cfg["dropout_prob"]
    # optional flags exist (runner_type, weight_tying, positional_embedding_type, experimental, ...)
    # we keep defaults for a minimal demo

    model = ttml.models.gpt2.create_gpt2_model(gcfg)
    return model

vocab_size = tokenizer.vocab_size

if vocab_size % 32 != 0:
    print(f"Warning: vocab size {vocab_size} is not multiple of 32, padding for tilizing.")
    padded_vocab_size = ((tokenizer.vocab_size + 31) // 32) * 32

model = create_model(transformer_cfg, vocab_size, transformer_cfg["max_sequence_length"])
model


Transformer configuration:
    Vocab size: 50257
    Max sequence length: 1024
    Embedding dim: 768
    Num heads: 12
    Dropout probability: 0.2
    Num blocks: 12
    Positional embedding type: Trainable
    Runner type: Default
    Composite layernorm: false
    Weight tying: Disabled


<_ttml.models.gpt2.GPT2Transformer at 0x7fa15177cc70>

In [37]:
model.eval()

logits_mask = np.zeros((1, 1, 1, padded_vocab_size), dtype=np.float32)
logits_mask[:, :, :, vocab_size:] = 1e4

logits_mask_tensor = ttml.autograd.Tensor.from_numpy(logits_mask, ttml.Layout.ROW_MAJOR, ttml.autograd.DataType.BFLOAT16)   # [1,1,1,T], float32

input_str = "The difference between cats and dogs is:"
input_tokens = tokenizer.encode(input_str)

if len(input_tokens) < transformer_cfg["max_sequence_length"]:
    input_tokens += [tokenizer.eos_token_id] * (transformer_cfg["max_sequence_length"] - len(input_tokens))

input_tokens_tensor = ttml.autograd.Tensor.from_numpy(np.array(input_tokens, dtype=np.int32).reshape(1, 1, 1, -1), ttml.Layout.TILE, ttml.autograd.DataType.UINT32)  # [1, seq_len], int32
causal_mask = build_causal_mask(transformer_cfg["max_sequence_length"])  # [1,1,seq_len,seq_len], float32

# generator = ttml.autograd.AutoContext().get_generator()

while True:
    logits = model(input_tokens_tensor, causal_mask)  # [1,1,1, vocab_size]
    ttml.ops.sample.sample_op(logits, 1.0, 42, logits_mask_tensor)

2025-09-29 22:05:09.952 | critical |          Always | Broadcasting rule violation for rank -1, dim a: 50257, dim b: 50272 (assert.hpp:103)


RuntimeError: TT_FATAL @ /home/ubuntu/tt-metal/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.cpp:301: a_dim == b_dim || a_dim == 1 || b_dim == 1
info:
Broadcasting rule violation for rank -1, dim a: 50257, dim b: 50272
backtrace:
 --- /home/ubuntu/tt-metal/build/lib/_ttnncpp.so(+0xaae4bb) [0x7fa14db4e4bb]
 --- ttnn::operations::binary_ng::BinaryNgDeviceOperation::validate_on_program_cache_hit(ttnn::operations::binary_ng::BinaryNgDeviceOperation::operation_attributes_t const&, ttnn::operations::binary_ng::BinaryNgDeviceOperation::tensor_args_t const&)
 --- ttnn::operations::binary_ng::BinaryNgDeviceOperation::validate_on_program_cache_miss(ttnn::operations::binary_ng::BinaryNgDeviceOperation::operation_attributes_t const&, ttnn::operations::binary_ng::BinaryNgDeviceOperation::tensor_args_t const&)
 --- /home/ubuntu/tt-metal/build/lib/_ttnncpp.so(_ZN4ttnn16device_operation6detail29launch_operation_with_adapterINS0_26MeshDeviceOperationAdapterINS_10operations9binary_ng23BinaryNgDeviceOperationEEEEEvRKNT_22operation_attributes_tERKNS8_13tensor_args_tERNS8_21tensor_return_value_tEPN2tt8tt_metal11distributed10MeshDeviceE+0x1af) [0x7fa14db081cf]
 --- ttnn::operations::binary_ng::BinaryNgDeviceOperation::tensor_return_value_t ttnn::device_operation::detail::launch_on_device<ttnn::operations::binary_ng::BinaryNgDeviceOperation>(ttnn::operations::binary_ng::BinaryNgDeviceOperation::operation_attributes_t const&, ttnn::operations::binary_ng::BinaryNgDeviceOperation::tensor_args_t const&)
 --- ttnn::operations::binary_ng::BinaryNgDeviceOperation::tensor_return_value_t ttnn::device_operation::detail::invoke<ttnn::operations::binary_ng::BinaryNgDeviceOperation>(ttnn::operations::binary_ng::BinaryNgDeviceOperation::operation_attributes_t const&, ttnn::operations::binary_ng::BinaryNgDeviceOperation::tensor_args_t const&)
 --- /home/ubuntu/tt-metal/build/lib/_ttnncpp.so(+0xa6789d) [0x7fa14db0789d]
 --- /home/ubuntu/tt-metal/build/lib/_ttnncpp.so(+0xa67524) [0x7fa14db07524]
 --- /home/ubuntu/tt-metal/build/lib/_ttnncpp.so(+0xa4ac67) [0x7fa14daeac67]
 --- ttnn::operations::binary::BinaryOperation<(ttnn::operations::binary::BinaryOpType)1>::invoke(tt::tt_metal::Tensor const&, tt::tt_metal::Tensor const&, std::optional<tt::tt_metal::DataType const> const&, std::optional<tt::tt_metal::MemoryConfig> const&, std::optional<tt::tt_metal::Tensor> const&, std::span<ttnn::operations::unary::BasicUnaryWithParam<float, int, unsigned int> const, 18446744073709551615ul>, std::span<ttnn::operations::unary::BasicUnaryWithParam<float, int, unsigned int> const, 18446744073709551615ul>, std::span<ttnn::operations::unary::BasicUnaryWithParam<float, int, unsigned int> const, 18446744073709551615ul>, std::optional<bool> const&)
 --- /home/ubuntu/tt-metal/tt-train/build/sources/ttml/_ttml.cpython-310-x86_64-linux-gnu.so(_ZNK4ttnn10decorators22registered_operation_tIXtlN7reflect6v1_2_512fixed_stringIcLm14EEEtlA15_cLc116ELc116ELc110ELc110ELc58ELc58ELc115ELc117ELc98ELc116ELc114ELc97ELc99ELc116EEEENS_10operations6binary15BinaryOperationILNS8_12BinaryOpTypeE1EEEE16invoke_compositeIJRN2tt8tt_metal6TensorESH_EEEDaDpOT_+0x10c) [0x7fa1503207ec]
 --- /home/ubuntu/tt-metal/tt-train/build/sources/ttml/_ttml.cpython-310-x86_64-linux-gnu.so(_ZNK4ttnn10decorators22registered_operation_tIXtlN7reflect6v1_2_512fixed_stringIcLm14EEEtlA15_cLc116ELc116ELc110ELc110ELc58ELc58ELc115ELc117ELc98ELc116ELc114ELc97ELc99ELc116EEEENS_10operations6binary15BinaryOperationILNS8_12BinaryOpTypeE1EEEE6invokeIJRN2tt8tt_metal6TensorESH_EEEDaDpOT_+0x30) [0x7fa1503206d0]
 --- /home/ubuntu/tt-metal/tt-train/build/sources/ttml/_ttml.cpython-310-x86_64-linux-gnu.so(_ZNK4ttnn10decorators22registered_operation_tIXtlN7reflect6v1_2_512fixed_stringIcLm14EEEtlA15_cLc116ELc116ELc110ELc110ELc58ELc58ELc115ELc117ELc98ELc116ELc114ELc97ELc99ELc116EEEENS_10operations6binary15BinaryOperationILNS8_12BinaryOpTypeE1EEEE13traced_invokeIJRN2tt8tt_metal6TensorESH_EEEDaDpOT_+0x89) [0x7fa150320639]
 --- /home/ubuntu/tt-metal/tt-train/build/sources/ttml/_ttml.cpython-310-x86_64-linux-gnu.so(_ZNK4ttnn10decorators22registered_operation_tIXtlN7reflect6v1_2_512fixed_stringIcLm14EEEtlA15_cLc116ELc116ELc110ELc110ELc58ELc58ELc115ELc117ELc98ELc116ELc114ELc97ELc99ELc116EEEENS_10operations6binary15BinaryOperationILNS8_12BinaryOpTypeE1EEEEclIJRN2tt8tt_metal6TensorESH_EEEDaDpOT_+0x30) [0x7fa15031f480]
 --- ttml::ttnn_fixed::sample(tt::tt_metal::Tensor const&, float, unsigned int, std::optional<tt::tt_metal::Tensor>)
 --- ttml::ops::sample_op(std::shared_ptr<ttml::autograd::Tensor> const&, float, unsigned int, std::shared_ptr<ttml::autograd::Tensor> const&)
 --- /home/ubuntu/tt-metal/tt-train/build/sources/ttml/_ttml.cpython-310-x86_64-linux-gnu.so(+0x55b28f) [0x7fa15025228f]
 --- /home/ubuntu/tt-metal/tt-train/build/sources/ttml/libnanobind.so(+0x1471c) [0x7fa25c85a71c]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyEval_EvalFrameDefault+0x5603) [0x5602a2f45763]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(+0x25a566) [0x5602a3025566]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(PyEval_EvalCode+0x86) [0x5602a3025436]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(+0x25fb4d) [0x5602a302ab4d]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(+0x18b419) [0x5602a2f56419]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyEval_EvalFrameDefault+0x6c0) [0x5602a2f40820]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(+0x1a7270) [0x5602a2f72270]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyEval_EvalFrameDefault+0x278e) [0x5602a2f428ee]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(+0x1a7270) [0x5602a2f72270]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyEval_EvalFrameDefault+0x278e) [0x5602a2f428ee]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(+0x1a7270) [0x5602a2f72270]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(+0x277a5f) [0x5602a3042a5f]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(+0x195eab) [0x5602a2f60eab]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyEval_EvalFrameDefault+0x807) [0x5602a2f40967]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyFunction_Vectorcall+0x7c) [0x5602a2f561bc]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyEval_EvalFrameDefault+0x6c0) [0x5602a2f40820]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyFunction_Vectorcall+0x7c) [0x5602a2f561bc]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyEval_EvalFrameDefault+0x807) [0x5602a2f40967]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(+0x198311) [0x5602a2f63311]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(PyObject_Call+0x122) [0x5602a2f63fb2]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyEval_EvalFrameDefault+0x2a8e) [0x5602a2f42bee]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(+0x198311) [0x5602a2f63311]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyEval_EvalFrameDefault+0x1990) [0x5602a2f41af0]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(+0x1a7270) [0x5602a2f72270]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyEval_EvalFrameDefault+0x278e) [0x5602a2f428ee]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(+0x1a7270) [0x5602a2f72270]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyEval_EvalFrameDefault+0x278e) [0x5602a2f428ee]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(+0x1a7270) [0x5602a2f72270]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyEval_EvalFrameDefault+0x278e) [0x5602a2f428ee]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(+0x1a7270) [0x5602a2f72270]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyEval_EvalFrameDefault+0x278e) [0x5602a2f428ee]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(+0x1a7270) [0x5602a2f72270]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyEval_EvalFrameDefault+0x278e) [0x5602a2f428ee]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(+0x1a7270) [0x5602a2f72270]
 --- /usr/lib/python3.10/lib-dynload/_asyncio.cpython-310-x86_64-linux-gnu.so(+0x928e) [0x7fa2a405e28e]
 --- /usr/lib/python3.10/lib-dynload/_asyncio.cpython-310-x86_64-linux-gnu.so(+0xa49b) [0x7fa2a405f49b]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(+0x18a384) [0x5602a2f55384]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(+0x25bc65) [0x5602a3026c65]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(+0x2c895a) [0x5602a309395a]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(+0x17e50f) [0x5602a2f4950f]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyEval_EvalFrameDefault+0x65c3) [0x5602a2f46723]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyFunction_Vectorcall+0x7c) [0x5602a2f561bc]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyEval_EvalFrameDefault+0x807) [0x5602a2f40967]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyFunction_Vectorcall+0x7c) [0x5602a2f561bc]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyEval_EvalFrameDefault+0x807) [0x5602a2f40967]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyFunction_Vectorcall+0x7c) [0x5602a2f561bc]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyEval_EvalFrameDefault+0x807) [0x5602a2f40967]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyFunction_Vectorcall+0x7c) [0x5602a2f561bc]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyEval_EvalFrameDefault+0x807) [0x5602a2f40967]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyFunction_Vectorcall+0x7c) [0x5602a2f561bc]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyEval_EvalFrameDefault+0x807) [0x5602a2f40967]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(+0x198311) [0x5602a2f63311]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyEval_EvalFrameDefault+0x5603) [0x5602a2f45763]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(+0x25a566) [0x5602a3025566]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(PyEval_EvalCode+0x86) [0x5602a3025436]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(+0x25fb4d) [0x5602a302ab4d]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(+0x18b419) [0x5602a2f56419]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyEval_EvalFrameDefault+0x6c0) [0x5602a2f40820]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyFunction_Vectorcall+0x7c) [0x5602a2f561bc]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyEval_EvalFrameDefault+0x6c0) [0x5602a2f40820]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_PyFunction_Vectorcall+0x7c) [0x5602a2f561bc]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(+0x27545d) [0x5602a304045d]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(Py_RunMain+0x128) [0x5602a303f218]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(Py_BytesMain+0x2d) [0x5602a301947d]
 --- /lib/x86_64-linux-gnu/libc.so.6(+0x29d90) [0x7fa2a5474d90]
 --- /lib/x86_64-linux-gnu/libc.so.6(__libc_start_main+0x80) [0x7fa2a5474e40]
 --- /home/ubuntu/.tenstorrent-venv/bin/python(_start+0x25) [0x5602a3019375]
