# GPT-2 LLM Architecture

The time has come to code the GPT-2 LLM architecture in `ttnn`. Doing so means putting together everything we have built so far.

We first define the configuration of the GPT-2 124M model:

In [49]:
# Configuration of the GPT-2 124M model used throughout this notebook.
GPT_CONFIG_124M = dict(
    vocab_size=50257,     # Vocabulary size
    context_length=1024,  # Context length
    emb_dim=768,          # Embedding dimension
    n_heads=12,           # Number of attention heads
    n_layers=12,          # Number of layers
    drop_rate=0.1,        # Dropout rate
    qkv_bias=False,       # Query-Key-Value bias
)

# Torch Implementation

We will be bringing back a lot of the code we wrote in the previous notebooks, tweaking and adjusting it as necessary.

In [50]:
import torch
from torch import nn

In [51]:
class DummyGPTModel(nn.Module):
    """GPT skeleton with real embeddings and output head but placeholder internals.

    The transformer blocks and the final norm are the Dummy* pass-through
    modules, so this class only demonstrates the data flow from token ids to
    vocabulary logits.

    Args:
        cfg: Mapping providing "vocab_size", "context_length", "emb_dim",
            "n_layers" and "drop_rate".
    """

    def __init__(self, cfg):
        super().__init__()

        emb_dim = cfg["emb_dim"]
        vocab_size = cfg["vocab_size"]

        self.tok_emb = nn.Embedding(vocab_size, emb_dim)
        self.pos_emb = nn.Embedding(cfg["context_length"], emb_dim)
        self.drop_emb = nn.Dropout(cfg["drop_rate"])

        placeholder_blocks = [DummyTransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        self.trf_blocks = nn.Sequential(*placeholder_blocks)

        self.final_norm = DummyLayerNorm(emb_dim)
        self.out_head = nn.Linear(emb_dim, vocab_size, bias=False)

    def forward(self, in_idx):
        """Map a 2-D tensor of token ids to logits over the vocabulary."""
        _, seq_len = in_idx.shape

        # One position id per column, created on the same device as the input.
        positions = torch.arange(seq_len, device=in_idx.device)

        hidden = self.tok_emb(in_idx) + self.pos_emb(positions)
        hidden = self.drop_emb(hidden)
        hidden = self.trf_blocks(hidden)
        hidden = self.final_norm(hidden)

        return self.out_head(hidden)

In [52]:
class DummyTransformerBlock(nn.Module):
    """Placeholder transformer block that returns its input unchanged."""

    def __init__(self, cfg):
        # cfg is accepted but not used; the block has no parameters.
        super().__init__()

    def forward(self, x):
        """Return x as-is."""
        return x

class DummyLayerNorm(nn.Module):
    """Placeholder layer norm that returns its input unchanged."""

    def __init__(self, normalized_shape, eps=1e-5):
        # Both arguments are accepted but not used; nothing is normalised.
        super().__init__()

    def forward(self, x):
        """Return x as-is."""
        return x

    

In [53]:
import tiktoken

In [54]:
# Encode two prompts with the tiktoken "gpt2" encoding and stack the token-id
# tensors into one batch (torch.stack requires both to have the same length).
tokenizer = tiktoken.get_encoding("gpt2")
txt1 = "Every effort moves you"
txt2 = "Every day holds a"

batch = torch.stack(
    [torch.tensor(tokenizer.encode(text)) for text in (txt1, txt2)],
    dim=0,
)

print(batch)

tensor([[6109, 3626, 6100,  345],
        [6109, 1110, 6622,  257]])


In [55]:
# Seed before constructing the model so its randomly initialised weights are
# the same on every run.
torch.manual_seed(123)

model = DummyGPTModel(GPT_CONFIG_124M)
logits = model(batch)


In [56]:
# Display the shape of the logits together with their values.
logits.shape, logits

(torch.Size([2, 4, 50257]),
 tensor([[[-0.9289,  0.2748, -0.7557,  ..., -1.6070,  0.2702, -0.5888],
          [-0.4476,  0.1726,  0.5354,  ..., -0.3932,  1.5285,  0.8557],
          [ 0.5680,  1.6053, -0.2155,  ...,  1.1624,  0.1380,  0.7425],
          [ 0.0447,  2.4787, -0.8843,  ...,  1.3219, -0.0864, -0.5856]],
 
         [[-1.5474, -0.0542, -1.0571,  ..., -1.8061, -0.4494, -0.6747],
          [-0.8422,  0.8243, -0.1098,  ..., -0.1434,  0.2079,  1.2046],
          [ 0.1355,  1.1858, -0.1453,  ...,  0.0869, -0.1590,  0.1552],
          [ 0.1666, -0.8138,  0.2307,  ...,  2.5035, -0.3055, -0.3083]]],
        grad_fn=<UnsafeViewBackward0>))

## Layer Normalization

In [57]:
# Toy example: a (2, 5) random input passed through Linear(5 -> 6) + ReLU.
torch.manual_seed(123)
batch_example = torch.randn(2, 5)
layer = nn.Sequential(nn.Linear(5, 6), nn.ReLU())
out = layer(batch_example)
print(out)

tensor([[0.2260, 0.3470, 0.0000, 0.2216, 0.0000, 0.0000],
        [0.2133, 0.2394, 0.0000, 0.5198, 0.3297, 0.0000]],
       grad_fn=<ReluBackward0>)


In [58]:
# Per-row statistics over the last dimension; keepdim=True keeps a trailing
# size-1 dimension so the results broadcast against `out`.
mean = out.mean(dim=-1, keepdim=True)
variance = out.var(dim=-1, keepdim=True)

mean, variance

(tensor([[0.1324],
         [0.2170]], grad_fn=<MeanBackward1>),
 tensor([[0.0231],
         [0.0398]], grad_fn=<VarBackward0>))

In [59]:
# Normalise each row by hand: subtract the row mean, divide by the row std.
out_norm = (out - mean) / torch.sqrt(variance)

# Recompute the statistics on the normalised rows (rebinds `mean`/`variance`).
mean = out_norm.mean(dim=-1, keepdim=True)
variance = out_norm.var(dim=-1, keepdim=True)

out_norm, mean, variance

(tensor([[ 0.6159,  1.4126, -0.8719,  0.5872, -0.8719, -0.8719],
         [-0.0189,  0.1121, -1.0876,  1.5173,  0.5647, -1.0876]],
        grad_fn=<DivBackward0>),
 tensor([[    0.0000],
         [    0.0000]], grad_fn=<MeanBackward1>),
 tensor([[1.0000],
         [1.0000]], grad_fn=<VarBackward0>))

In [60]:
# Turn off scientific notation for tensor printing (a global torch setting).
torch.set_printoptions(sci_mode=False)
print("Mean", mean)
print("Variance", variance)

Mean tensor([[    0.0000],
        [    0.0000]], grad_fn=<MeanBackward1>)
Variance tensor([[1.0000],
        [1.0000]], grad_fn=<VarBackward0>)


Finally, now an implementation!

In [61]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with learnable scale and shift.

    Args:
        emb_dim: Size of the last (normalised) dimension.
        eps: Small constant added to the variance before the square root to
            avoid division by zero. Defaults to 1e-5.
    """

    def __init__(self, emb_dim, eps=1e-5):
        super().__init__()

        self.eps = eps

        # Start as the identity transform: scale of ones, shift of zeros.
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        """Normalise x along its last dimension, then apply scale and shift."""
        mean = x.mean(dim=-1, keepdim=True)
        # unbiased=False: divide by N rather than N - 1.
        var = x.var(dim=-1, keepdim=True, unbiased=False)

        norm_x = (x - mean) / torch.sqrt(var + self.eps)

        return self.scale * norm_x + self.shift

In [62]:
# Sanity check: after LayerNorm each row should have mean ~0 and variance ~1.
ln = LayerNorm(emb_dim=5)
out_ln = ln(batch_example)

mean = out_ln.mean(dim=-1, keepdim=True)
# unbiased=False matches the estimator LayerNorm.forward uses.
variance = out_ln.var(dim=-1, unbiased=False, keepdim=True)

print("Mean", mean)
print("Variance", variance)

Mean tensor([[    -0.0000],
        [     0.0000]], grad_fn=<MeanBackward1>)
Variance tensor([[1.0000],
        [1.0000]], grad_fn=<VarBackward0>)


## GELU Activation


In [63]:
class GELU(nn.Module):
    """GELU activation computed with the tanh approximation.

    Evaluates 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Apply the activation element-wise to x."""
        sqrt_two_over_pi = torch.sqrt(torch.tensor(2.0 / torch.pi))
        inner = sqrt_two_over_pi * (x + 0.044715 * torch.pow(x, 3))
        return 0.5 * x * (1 + torch.tanh(inner))

## Feed Forward

The feed-forward module is a linear layer, followed by a GELU activation, followed by a second linear layer.

In [None]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear.

    The hidden layer is four times as wide as cfg["emb_dim"]; the output has
    the same last dimension as the input.
    """

    def __init__(self, cfg):
        super().__init__()

        emb_dim = cfg["emb_dim"]
        hidden_dim = 4 * emb_dim

        self.layer = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),
            GELU(),
            nn.Linear(hidden_dim, emb_dim),
        )

    def forward(self, x):
        """Run x through the expand / activate / project stack."""
        return self.layer(x)

## TransformerBlock

In [None]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Total output size across all heads; must be divisible by
            num_heads.
        context_length: Largest sequence length the causal mask covers.
        dropout: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads.
        qkv_bias: Whether the query/key/value projections have a bias.

    Raises:
        ValueError: If d_out is not divisible by num_heads.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()

        if d_out % num_heads != 0:
            raise ValueError(
                f"d_out ({d_out}) must be divisible by num_heads ({num_heads})"
            )

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # per-head share of the output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)

        # Registered as a buffer so the mask follows the module through
        # .to(device) / .double() etc.; a plain attribute would be left behind.
        # persistent=False keeps it out of state_dict, so existing checkpoints
        # still load unchanged.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
            persistent=False,
        )

    def forward(self, x):
        """Attend over x of shape (batch, num_tokens, d_in).

        Returns a tensor of shape (batch, num_tokens, d_out).
        """
        b, num_tokens, _ = x.shape

        def split_heads(t):
            # (b, tokens, d_out) -> (b, heads, tokens, head_dim)
            return t.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)

        keys = split_heads(self.W_key(x))
        queries = split_heads(self.W_query(x))
        values = split_heads(self.W_value(x))

        attn_scores = queries @ keys.transpose(2, 3)

        # Ones above the diagonal mark future positions; fill them with -inf so
        # softmax assigns them zero weight.
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / self.head_dim ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (b, heads, tokens, head_dim) -> (b, tokens, d_out)
        context_vec = (attn_weights @ values).transpose(1, 2)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)

        return self.out_proj(context_vec)

In [None]:
class TransformerBlock(nn.Module):
    """Transformer block with two residual sub-layers.

    Each sub-layer (attention, then feed-forward) is applied as
    LayerNorm -> sub-layer -> dropout -> add the sub-layer's input back.
    """

    def __init__(self, cfg):
        super().__init__()

        emb_dim = cfg["emb_dim"]
        drop_rate = cfg["drop_rate"]

        self.att = MultiHeadAttention(
            d_in=emb_dim,
            d_out=emb_dim,
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            dropout=drop_rate,
            qkv_bias=cfg["qkv_bias"],
        )
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(emb_dim)
        self.norm2 = LayerNorm(emb_dim)
        self.drop_shortcut = nn.Dropout(drop_rate)

    def forward(self, x):
        """Apply the attention and feed-forward sub-layers to x."""
        # Attention sub-layer with residual connection.
        residual = x
        x = self.drop_shortcut(self.att(self.norm1(x))) + residual

        # Feed-forward sub-layer with residual connection.
        residual = x
        x = self.drop_shortcut(self.ff(self.norm2(x))) + residual

        return x


In [None]:
# Smoke test: one transformer block should preserve the input shape.
torch.manual_seed(123)
x = torch.rand(2, 4, 768)
block = TransformerBlock(GPT_CONFIG_124M)
output = block(x)

x.shape, output.shape

## Coding the GPT-2 Architecture

In [None]:
class GPTModel(nn.Module):
    """GPT-style model: embeddings, a stack of TransformerBlocks, final norm
    and a bias-free linear head producing logits over the vocabulary.

    Args:
        cfg: Mapping providing "vocab_size", "context_length", "emb_dim",
            "n_layers" and "drop_rate", plus whatever TransformerBlock reads.
    """

    def __init__(self, cfg):
        super().__init__()

        emb_dim = cfg["emb_dim"]
        vocab_size = cfg["vocab_size"]

        self.tok_emb = nn.Embedding(vocab_size, emb_dim)
        self.pos_emb = nn.Embedding(cfg["context_length"], emb_dim)
        self.drop_emb = nn.Dropout(cfg["drop_rate"])

        blocks = [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        self.trf_blocks = nn.Sequential(*blocks)

        self.final_norm = LayerNorm(emb_dim)
        self.out_head = nn.Linear(emb_dim, vocab_size, bias=False)

    def forward(self, in_idx):
        """Map a 2-D tensor of token ids to logits over the vocabulary."""
        _, seq_len = in_idx.shape

        # One position id per column, created on the same device as the input.
        positions = torch.arange(seq_len, device=in_idx.device)

        hidden = self.tok_emb(in_idx) + self.pos_emb(positions)
        hidden = self.drop_emb(hidden)
        hidden = self.trf_blocks(hidden)
        hidden = self.final_norm(hidden)

        return self.out_head(hidden)

In [None]:
# Build the full model with a fixed seed and run the two-prompt batch through it.
torch.manual_seed(123)
model = GPTModel(GPT_CONFIG_124M)

out = model(batch)
print("Input batch:\n", batch)
print("Input batch shape:\n", batch.shape)
print(out)

out.shape

In [None]:
# Encode the prompt that text generation will continue from.
start_context = "Hello, I am"
encoded = tokenizer.encode(start_context)
print("encoded:", encoded)

# unsqueeze(0) adds a leading dimension of size 1 (a batch of one prompt).
encoded_tensor = torch.tensor(encoded).unsqueeze(0)
print("encoded_tensor.shape:", encoded_tensor.shape)

In [None]:
# generate_text_simple is a project helper whose source is not shown here;
# presumably it appends max_new_tokens token ids to idx -- confirm in scripts/generate.py.
from scripts.generate import generate_text_simple

# Put the model in evaluation mode (disables its nn.Dropout layers).
model.eval()
out = generate_text_simple(
  model=model,
  idx=encoded_tensor,
  max_new_tokens=6,
  context_size=GPT_CONFIG_124M["context_length"]
)
print("output:", out)
print("output length:", len(out[0]))

In [None]:
# Drop the leading batch dimension and turn the token ids back into text.
decoded_text = tokenizer.decode(out.squeeze(0).tolist())
decoded_text

## Porting to tt-train

I am personally not skilled enough to use the other frameworks right now, so let's be a little creative and just focus on the forward pass.

It will become a problem later on when we need to do backpropagation, but let's not worry about that now. (still looking into how to do this with the tenstorrent hardware)

Import everything

In [None]:
import ttnn
import torch
from torch import nn
import tiktoken

# Fix the RNG seed so weight initialisation is repeatable.
torch.manual_seed(123)

In [None]:
# Encode two prompts with the tiktoken "gpt2" encoding and stack the token-id
# tensors into one batch (torch.stack requires both to have the same length).
tokenizer = tiktoken.get_encoding("gpt2")
txt1 = "Every effort moves you"
txt2 = "Every day holds a"

batch = torch.stack(
    [torch.tensor(tokenizer.encode(text)) for text in (txt1, txt2)],
    dim=0,
)

print(batch)

In [None]:
class LayerNorm_ttnn(nn.Module):
  """Layer normalisation over the last dimension, computed with ttnn ops.

  Takes a torch tensor, uploads it to the device, normalises it, applies the
  learnable scale and shift, and returns a torch.bfloat16 tensor.

  Args:
    emb_dim: Size of the last (normalised) dimension.
    device: Opened ttnn device the computation runs on.
  """

  def __init__(self, emb_dim, device):
    super().__init__()

    self.device = device

    self.eps = 1e-5

    self.scale = nn.Parameter(torch.ones(emb_dim))
    self.shift = nn.Parameter(torch.zeros(emb_dim))

    # Device copies are taken once, here. They are snapshots: later changes to
    # self.scale / self.shift are not reflected until they are re-uploaded.
    self.scale_ttnn = ttnn.from_torch(
      self.scale,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      memory_config=ttnn.L1_MEMORY_CONFIG,
      device=self.device
    )
    self.shift_ttnn = ttnn.from_torch(
      self.shift,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      memory_config=ttnn.L1_MEMORY_CONFIG,
      device=self.device
    )

  def forward(self, x):
    """Normalise torch tensor x along its last dimension on the device."""
    x_ttnn = ttnn.from_torch(
      x,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self.device
    )

    # keepdim=True keeps a trailing size-1 dimension so the statistics
    # broadcast against x_ttnn, as in the torch LayerNorm reference.
    mean_ttnn = ttnn.mean(
      x_ttnn,
      dim=-1,
      keepdim=True
    )
    # NOTE(review): the torch reference uses unbiased=False; confirm which
    # estimator ttnn.var uses in the installed version.
    var_ttnn = ttnn.var(
      x_ttnn,
      dim=-1,
      keepdim=True
    )

    norm_x_ttnn = ttnn.div(
      ttnn.subtract(
        x_ttnn,
        mean_ttnn
      ),
      ttnn.sqrt(
        ttnn.add(
          var_ttnn,
          self.eps
        )
      )
    )

    # Apply the learnable affine transform. Previously scale/shift were
    # uploaded but never used, so non-default parameters had no effect.
    out_ttnn = ttnn.add(
      ttnn.multiply(
        norm_x_ttnn,
        self.scale_ttnn
      ),
      self.shift_ttnn
    )

    result_torch = ttnn.to_torch(
      out_ttnn,
      dtype=torch.bfloat16,
      device=self.device
    )

    return result_torch

In [None]:
device_id = 0
device = ttnn.open_device(device_id=device_id)

# try/finally: release the device even if the forward pass raises, so a failed
# run does not leave it open for the next cell.
try:
  ln = LayerNorm_ttnn(emb_dim=5, device=device)
  out_ln = ln(batch_example)

  mean = out_ln.mean(dim=-1, keepdim=True)
  variance = out_ln.var(dim=-1, unbiased=False, keepdim=True)
finally:
  ttnn.close_device(device)

print("Mean", mean)
print("Variance", variance)

In [123]:
class GELU_ttnn(nn.Module):
  """Tanh-approximation GELU computed with ttnn ops.

  Evaluates 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))) on the
  device. Takes a torch tensor and returns a torch.float16 tensor.

  Args:
    device: Opened ttnn device the computation runs on.
  """

  # sqrt(2 / pi) is a constant, so it is computed once on the host and passed
  # to ttnn.multiply as a scalar instead of being uploaded as a tensor and
  # square-rooted on the device on every call.
  _SQRT_2_OVER_PI = (2.0 / torch.pi) ** 0.5

  def __init__(self, device):
    super().__init__()

    self.device = device

  def forward(self, x):
    """Apply the activation element-wise to torch tensor x."""
    x_ttnn = ttnn.from_torch(
      x,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self.device
    )

    x_ttnn_cubed = ttnn.pow(
      x_ttnn,
      3
    )

    # sqrt(2 / pi) * (x + 0.044715 * x^3)
    tanh_arg_ttnn = ttnn.multiply(
      ttnn.add(
        x_ttnn,
        ttnn.multiply(
          x_ttnn_cubed,
          0.044715
        )
      ),
      self._SQRT_2_OVER_PI
    )

    tanh_factor_ttnn = ttnn.tanh(
      tanh_arg_ttnn
    )

    # x * (0.5 * (1 + tanh(...)))
    result_ttnn = ttnn.multiply(
      x_ttnn,
      ttnn.multiply(
        ttnn.add(tanh_factor_ttnn, 1),
        0.5
      )
    )

    # NOTE(review): returned as float16 although the device tensor is
    # bfloat16; presumably so the result can be plotted directly -- confirm
    # before changing.
    result_torch = ttnn.to_torch(result_ttnn, dtype=torch.float16, device=self.device)

    return result_torch


In [None]:
# Imported before the device is opened so a failed import cannot leave it open.
import matplotlib.pyplot as plt

device_id = 0
device = ttnn.open_device(device_id=device_id)

# try/finally: release the device even if evaluation or plotting raises.
try:
    gelu, gelu_ttnn, relu = GELU(), GELU_ttnn(device), nn.ReLU()

    # Some sample data
    x = torch.linspace(-3, 3, 100)
    y_gelu, y_gelu_ttnn, y_relu = gelu(x), gelu_ttnn(x), relu(x)

    plt.figure(figsize=(8, 3))
    for i, (y, label) in enumerate(zip([y_gelu, y_gelu_ttnn, y_relu], ["GELU", "GELU_ttnn", "RELU"]), 1):
        plt.subplot(1, 3, i)
        plt.plot(x, y)
        plt.title(f"{label} activation function")
        plt.xlabel("x")
        plt.ylabel(f"{label}(x)")
        plt.grid(True)

    plt.tight_layout()
    plt.show()
finally:
    ttnn.close_device(device)

In [88]:
class FeedForward_ttnn(nn.Module):
  """Feed-forward network (Linear -> GELU -> Linear) computed with ttnn ops.

  Takes a torch tensor and returns a torch tensor. The torch nn.Linear layers
  hold the parameters; device copies of their weights and biases are taken
  once in __init__ and are snapshots of the parameters at that time.

  Args:
    cfg: Mapping providing "emb_dim".
    device: Opened ttnn device the computation runs on.
  """

  def __init__(self, cfg, device):
    super().__init__()

    self.device = device

    self.lin_1 = nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"])
    self.gelu = GELU()  # not used in forward; gelu_ttnn below is
    self.lin_2 = nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"])

    self.lin_1_ttnn = self._upload_l1(self.lin_1.weight)
    self.lin_2_ttnn = self._upload_l1(self.lin_2.weight)
    # Biases are uploaded once alongside the weights rather than on every
    # forward call.
    self.lin_1_bias_ttnn = self._upload_l1(self.lin_1.bias)
    self.lin_2_bias_ttnn = self._upload_l1(self.lin_2.bias)
    self.gelu_ttnn = GELU_ttnn(self.device)

  def _upload_l1(self, tensor):
    # Upload a torch tensor as a tiled bfloat16 tensor using the L1 memory config.
    return ttnn.from_torch(
      tensor,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self.device,
      memory_config=ttnn.L1_MEMORY_CONFIG
    )

  def forward(self, x):
    """Apply the feed-forward network to torch tensor x."""
    x_ttnn = ttnn.from_torch(
      x,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self.device
    )

    hidden_ttnn = ttnn.linear(
      x_ttnn,
      self.lin_1_ttnn,
      transpose_b=True,
      bias=self.lin_1_bias_ttnn,
    )

    # GELU_ttnn takes and returns torch tensors, so the activation is brought
    # back to the host here and uploaded again below.
    hidden_torch = self.gelu_ttnn(
      ttnn.to_torch(hidden_ttnn, device=self.device)
    )
    hidden_ttnn = ttnn.from_torch(
      hidden_torch,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self.device,
    )

    out_ttnn = ttnn.linear(
      hidden_ttnn,
      self.lin_2_ttnn,
      transpose_b=True,
      bias=self.lin_2_bias_ttnn,
    )

    result = ttnn.to_torch(out_ttnn, device=self.device)

    return result

In [91]:
# Torch reference output; seeded with the same value as the ttnn run that follows.
torch.manual_seed(123)

ffn = FeedForward(GPT_CONFIG_124M)

# input shape: [batch_size, num_token, emb_size]
x = torch.rand(2, 3, 768) 
out = ffn(x)
out, out.shape

(tensor([[[ 0.0139,  0.0136, -0.1403,  ..., -0.2203, -0.1387, -0.0715],
          [-0.0181,  0.0461, -0.1763,  ..., -0.1154, -0.0052,  0.0039],
          [ 0.0301, -0.0376, -0.1168,  ..., -0.1506, -0.1201, -0.1278]],
 
         [[ 0.0862,  0.0446,  0.0118,  ..., -0.1831, -0.0280, -0.0259],
          [-0.0950,  0.0471, -0.1487,  ..., -0.1297, -0.0834, -0.0053],
          [-0.0161, -0.0762, -0.0622,  ..., -0.0481, -0.0952, -0.1189]]],
        grad_fn=<ViewBackward0>),
 torch.Size([2, 3, 768]))

In [92]:
# Same seed as the torch reference cell so weights and input match.
torch.manual_seed(123)

device_id = 0
device = ttnn.open_device(device_id=device_id)

# try/finally: release the device even if construction or the forward pass raises.
try:
  ffn_ttnn = FeedForward_ttnn(GPT_CONFIG_124M, device)
  x = torch.rand(2, 3, 768)
  out_ttnn = ffn_ttnn(x)
finally:
  ttnn.close_device(device)

out_ttnn, out_ttnn.shape

                  Metal | INFO     | Initializing device 0. Program cache is NOT enabled
                  Metal | INFO     | AI CLK for device 0 is:   1000 MHz
ttnn.Tensor([[[ 0.19238,  0.11621,  ..., -0.08545,  0.29297],
              [ 0.12402, -0.09180,  ..., -0.01697,  0.10791],
              [-0.05151,  0.08301,  ...,  0.12451,  0.23926]],

             [[ 0.24805,  0.34375,  ..., -0.01184,  0.15625],
              [ 0.00290,  0.01758,  ...,  0.20801,  0.23828],
              [-0.08252,  0.02332,  ...,  0.19727,  0.06982]]], shape=Shape([2, 3, 3072]), dtype=DataType::BFLOAT16, layout=Layout::TILE)
                  Metal | INFO     | Closing device 0
                  Metal | INFO     | Disabling and clearing program cache on device 0


(TorchTensor([[[ 0.0128,  0.0142, -0.1465,  ..., -0.2295, -0.1396, -0.0747],
               [-0.0187,  0.0457, -0.1807,  ..., -0.1216, -0.0044,  0.0077],
               [ 0.0288, -0.0361, -0.1216,  ..., -0.1553, -0.1221, -0.1279]],
 
              [[ 0.0889,  0.0483,  0.0128,  ..., -0.1855, -0.0322, -0.0288],
               [-0.0967,  0.0474, -0.1543,  ..., -0.1318, -0.0854, -0.0061],
               [-0.0129, -0.0767, -0.0654,  ..., -0.0474, -0.0996, -0.1230]]],
             dtype=torch.bfloat16),
 torch.Size([2, 3, 768]))

In [93]:
# compare_tensors is a project helper (source not shown here); called with the
# torch FeedForward output and the ttnn FeedForward output.
from scripts.compare_tensors import compare_tensors

compare_tensors(out, out_ttnn)

=== Tensor Comparison ===
Shapes: PyTorch torch.Size([2, 3, 768]), TTNN torch.Size([2, 3, 768])
Data types: PyTorch torch.bfloat16, TTNN torch.bfloat16

Tolerance Checks:
  Max Absolute Diff: 0.017578 (Tolerance: 0.020000) ✅ PASS
  Mean Absolute Diff: 0.003067 (Tolerance: 0.020000) ✅ PASS
  Correlation: 0.996094 (Tolerance: 0.990000) ✅ PASS

Overall Status: ✅ PASS

Sample Value Comparisons (first 3 positions):
  Position [0,0,0]: PyTorch=0.013855, TTNN=0.012756, Diff=0.001099 ✅
  Position [0,0,1]: PyTorch=0.013550, TTNN=0.014160, Diff=0.000610 ✅
  Position [0,0,2]: PyTorch=-0.140625, TTNN=-0.146484, Diff=0.005859 ✅
  Position [0,0,3]: PyTorch=0.056396, TTNN=0.060791, Diff=0.004395 ✅
  Position [0,0,4]: PyTorch=0.082520, TTNN=0.086426, Diff=0.003906 ✅
  Position [0,0,5]: PyTorch=0.165039, TTNN=0.169922, Diff=0.004883 ✅
  Position [0,0,6]: PyTorch=0.062500, TTNN=0.065918, Diff=0.003418 ✅
  Position [0,0,7]: PyTorch=0.038330, TTNN=0.041992, Diff=0.003662 ✅
  Position [0,0,8]: PyTorch=-0.2

{'max_diff': 0.017578125,
 'mean_diff': 0.0030670166015625,
 'correlation': 0.99609375,
 'max_diff_status': True,
 'mean_diff_status': True,
 'correlation_status': True,
 'overall_status': True}

In [110]:
# Dimensions passed to ttnn.CoreGrid for the matmul / linear calls in
# MultiHeadAttention_ttnn.
core_grid_x = 8
core_grid_y = 8
# Large negative value used in place of -inf when building the additive mask.
MINUS_INFINITY = -1e9
class MultiHeadAttention_ttnn(nn.Module):
  """Causal multi-head self-attention computed with ttnn ops.

  Takes a torch tensor of shape (batch, num_tokens, d_in) and returns a torch
  tensor of shape (batch, num_tokens, d_out). The torch nn.Linear layers hold
  the parameters; device copies of the weights and of the causal mask are
  taken once in __init__ and are snapshots of the values at that time.

  Args:
    d_in: Size of the last dimension of the input.
    d_out: Total output size across all heads.
    context_length: Largest sequence length the causal mask covers.
    dropout: Dropout probability applied to the attention weights.
    num_heads: Number of attention heads.
    device: Opened ttnn device the computation runs on.
    qkv_bias: Whether the query/key/value projections have a bias.
  """

  def __init__(self, d_in, d_out, context_length, dropout, num_heads, device, qkv_bias=False):
    super().__init__()

    self.device = device
    self.dropout_prob = dropout

    self.d_out = d_out
    self.num_heads = num_heads
    self.head_dim = d_out // num_heads # Reduce the projection dimension to match the output dim

    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.out_proj = nn.Linear(d_out, d_out)
    self.dropout = nn.Dropout(dropout)
    self.mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)

    self.W_query_ttnn = self._upload(self.W_query.weight, memory_config=ttnn.L1_MEMORY_CONFIG)
    self.W_key_ttnn = self._upload(self.W_key.weight, memory_config=ttnn.L1_MEMORY_CONFIG)
    self.W_value_ttnn = self._upload(self.W_value.weight, memory_config=ttnn.L1_MEMORY_CONFIG)
    self.out_proj_ttnn = self._upload(self.out_proj.weight, memory_config=ttnn.L1_MEMORY_CONFIG)
    self.mask_ttnn = self._upload(self.mask, memory_config=ttnn.L1_MEMORY_CONFIG)

  def _upload(self, tensor, memory_config=None):
    # Upload a torch tensor to the device as a tiled bfloat16 ttnn tensor.
    return ttnn.from_torch(
      tensor,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self.device,
      memory_config=memory_config
    )

  def _linear(self, x_ttnn, weight_ttnn, bias):
    # x @ weight^T on the device, followed by an optional bias add.
    # `bias` is the torch bias parameter, or None when the layer has no bias.
    out_ttnn = ttnn.linear(
      x_ttnn,
      weight_ttnn,
      transpose_b=True,
      core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
    )
    if bias is not None:
      out_ttnn = ttnn.add(out_ttnn, self._upload(bias))
    return out_ttnn

  def forward(self, x):
    """Attend over torch tensor x and return the result as a torch tensor."""
    b, num_tokens, d_in = x.shape

    x_ttnn = self._upload(x)

    # The q/k/v biases exist only when qkv_bias=True; previously they were
    # created but never applied.
    keys_ttnn = self._linear(x_ttnn, self.W_key_ttnn, self.W_key.bias)
    queries_ttnn = self._linear(x_ttnn, self.W_query_ttnn, self.W_query.bias)
    values_ttnn = self._linear(x_ttnn, self.W_value_ttnn, self.W_value.bias)

    keys_ttnn = ttnn.reshape(keys_ttnn, (b, num_tokens, self.num_heads, self.head_dim))
    values_ttnn = ttnn.reshape(values_ttnn, (b, num_tokens, self.num_heads, self.head_dim))
    queries_ttnn = ttnn.reshape(queries_ttnn, (b, num_tokens, self.num_heads, self.head_dim))

    # NOTE! This is intentional. We want the transposed version of keys_ttnn. That's why the
    # shape has a different permutation than the values_ttnn and queries_ttnn!
    keys_transposed_ttnn = ttnn.permute(keys_ttnn, (0, 2, 3, 1))
    values_ttnn = ttnn.permute(values_ttnn, (0, 2, 1, 3))
    queries_ttnn = ttnn.permute(queries_ttnn, (0, 2, 1, 3))

    attn_scores_ttnn = ttnn.matmul(
      queries_ttnn,
      keys_transposed_ttnn,
      core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
    )

    attn_scores_ttnn = attn_scores_ttnn * (1 / (self.head_dim ** 0.5))

    # Additive causal mask: ones above the diagonal become MINUS_INFINITY so
    # softmax gives future positions (near-)zero weight.
    attn_mask_ttnn = self.mask_ttnn[:num_tokens, :num_tokens] * MINUS_INFINITY
    attn_mask_ttnn = ttnn.reshape(attn_mask_ttnn, (1, 1, num_tokens, num_tokens))
    attn_scores_ttnn += attn_mask_ttnn

    attn_weights_ttnn = ttnn.softmax(attn_scores_ttnn, dim=-1)

    # Dropout only in training mode, matching nn.Dropout in the torch
    # implementation; previously it also ran after .eval().
    # NOTE(review): the seed is fixed, so presumably the same elements are
    # dropped on every call -- confirm whether that is intended.
    if self.training and self.dropout_prob > 0.0:
      attn_weights_ttnn = ttnn.experimental.dropout(
        attn_weights_ttnn,
        seed=123,
        probability=self.dropout_prob,
        scale=1.0 / (1.0 - self.dropout_prob)
      )

    context_vec_ttnn = ttnn.matmul(
      attn_weights_ttnn,
      values_ttnn,
      core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
    )

    context_vec_ttnn = ttnn.permute(context_vec_ttnn, (0, 2, 1, 3))
    context_vec_ttnn = ttnn.reshape(context_vec_ttnn, (b, num_tokens, self.d_out))

    context_vec_ttnn = self._linear(context_vec_ttnn, self.out_proj_ttnn, self.out_proj.bias)

    # Send the context vector back to the CPU
    context_vec = ttnn.from_device(context_vec_ttnn)
    context_vec = ttnn.to_torch(context_vec)

    return context_vec

In [120]:
class TransformerBlock_ttnn(nn.Module):
  """Transformer block assembled from the ttnn-backed sub-modules.

  Mirrors the torch TransformerBlock: each sub-layer (attention, then
  feed-forward) is applied as LayerNorm -> sub-layer -> dropout -> residual add.
  Takes and returns torch tensors.

  Args:
    cfg: Mapping providing "emb_dim", "context_length", "n_heads",
      "drop_rate" and "qkv_bias".
    device: Opened ttnn device the sub-modules run on.
  """

  def __init__(self, cfg, device):
    super().__init__()

    self.cfg = cfg
    self.device = device

    self.att = MultiHeadAttention_ttnn(
      d_in=cfg["emb_dim"],
      d_out=cfg["emb_dim"],
      context_length=cfg["context_length"],
      num_heads=cfg["n_heads"],
      dropout=cfg["drop_rate"],
      qkv_bias=cfg["qkv_bias"],
      device=self.device
    )

    self.ff = FeedForward_ttnn(cfg, self.device)

    self.norm1 = LayerNorm_ttnn(cfg["emb_dim"], self.device)
    self.norm2 = LayerNorm_ttnn(cfg["emb_dim"], self.device)

  def do_dropout(self, x):
    """Apply shortcut dropout to torch tensor x on the device.

    Returns x unchanged in eval mode or when the drop rate is zero, matching
    nn.Dropout in the torch TransformerBlock. Previously dropout was applied
    even after .eval().
    """
    drop_rate = self.cfg["drop_rate"]
    if not self.training or drop_rate <= 0.0:
      return x

    x_ttnn = ttnn.from_torch(
      x,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self.device
    )
    # NOTE(review): the seed is fixed, so presumably both sub-layers and every
    # call drop the same elements -- confirm whether that is intended.
    x_ttnn = ttnn.experimental.dropout(
      x_ttnn,
      seed=123,
      probability=drop_rate,
      scale=1.0 / (1.0 - drop_rate)
    )
    x = ttnn.to_torch(
      x_ttnn,
      device=self.device
    )

    return x

  def forward(self, x):
    """Apply the attention and feed-forward sub-layers to torch tensor x."""
    # Attention sub-layer with residual connection.
    shortcut = x
    x = self.norm1(x)
    x = self.att(x)
    x = self.do_dropout(x)
    x = x + shortcut

    # Feed-forward sub-layer with residual connection.
    shortcut = x
    x = self.norm2(x)
    x = self.ff(x)
    x = self.do_dropout(x)
    x = x + shortcut

    return x

In [124]:
# Torch reference run of a single transformer block; seeded with the same
# value as the ttnn run that follows.
torch.manual_seed(123)

x = torch.rand(2, 4, 768)  # Shape: [batch_size, num_tokens, emb_dim]
block = TransformerBlock(GPT_CONFIG_124M)
output = block(x)

print("Input shape:", x.shape)
print("Output shape:", output.shape)

output

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])


tensor([[[ 0.1648,  0.4002, -0.0749,  ...,  1.2646,  0.3324,  0.7243],
         [ 0.0293,  0.0498,  0.2529,  ...,  0.4698,  0.1281,  0.9749],
         [ 0.5532,  0.5788, -0.0310,  ...,  1.1544,  0.3947,  0.7600],
         [ 0.1631,  0.7128,  0.7271,  ...,  0.3312,  0.5730,  0.9258]],

        [[ 0.1787,  1.1682,  0.5810,  ...,  0.1828,  0.0073, -0.5603],
         [-0.2920,  0.6318,  0.2002,  ...,  0.3218,  0.4670, -0.0383],
         [ 0.9275,  0.4203,  0.3183,  ...,  0.3771,  0.7190, -0.1205],
         [ 0.6035,  0.5767,  0.3411,  ...,  1.3798,  1.2683,  0.3916]]],
       grad_fn=<AddBackward0>)

In [125]:
# Same seed as the torch reference cell above.
torch.manual_seed(123)

device_id = 0
device = ttnn.open_device(device_id=device_id)

# try/finally: release the device even if construction or the forward pass raises.
try:
  x = torch.rand(2, 4, 768)
  block = TransformerBlock_ttnn(GPT_CONFIG_124M, device)
  output = block(x)
finally:
  ttnn.close_device(device)

print("Input shape:", x.shape)
print("Output shape:", output.shape)

output

                  Metal | INFO     | Initializing device 0. Program cache is NOT enabled
                  Metal | INFO     | AI CLK for device 0 is:   1000 MHz
Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])
                  Metal | INFO     | Closing device 0
                  Metal | INFO     | Disabling and clearing program cache on device 0


TorchTensor([[[ 0.1320,  0.2456, -0.1260,  ...,  1.3096,  0.3645,  0.5947],
              [ 0.0349, -0.0332,  0.0333,  ...,  0.6810,  0.1310,  0.9527],
              [ 0.5470,  0.4363,  0.0214,  ...,  1.1962,  0.4507,  0.6432],
              [ 0.0579,  0.6281,  0.9223,  ...,  0.4575,  0.4067,  0.9529]],

             [[ 0.2005,  1.0735,  0.5085,  ...,  0.2282,  0.1497, -0.4888],
              [-0.1730,  0.5932,  0.2341,  ...,  0.3284,  0.2286,  0.0083],
              [ 0.8647,  0.3382,  0.4286,  ...,  0.3010,  0.8743, -0.1834],
              [ 0.2771,  0.8157,  0.0662,  ...,  1.4832,  1.2212,  0.4542]]])

In [None]:
# NOTE(review): this cell repeats the TransformerBlock definition from the
# "TransformerBlock" section earlier in the notebook. Running it rebinds the
# name to an identical class; consider removing the duplicate.
class TransformerBlock(nn.Module):
    """Transformer block: attention and feed-forward sub-layers, each applied
    as LayerNorm -> sub-layer -> dropout -> residual add."""

    def __init__(self, cfg):
        super().__init__()

        self.att = MultiHeadAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"]
        )

        self.ff = FeedForward(cfg)

        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])

        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        # Attention sub-layer with residual connection.
        shortcut = x
        x = self.norm1(x)
        x = self.att(x)
        x = self.drop_shortcut(x)
        x = x + shortcut

        # Feed-forward sub-layer with residual connection.
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_shortcut(x)
        x = x + shortcut
        
        return x

In [118]:
# NOTE(review): every cell above that opens the device also closes it, so on a
# top-to-bottom run this closes a handle that is already closed. Presumably
# left over from manual cleanup after an interrupted cell -- confirm whether
# it is still needed.
ttnn.close_device(device)

                  Metal | INFO     | Closing device 0
                  Metal | INFO     | Disabling and clearing program cache on device 0
