# Causal Attention Mechanisms

This notebook builds off the Simple Weighted Attention Mechanism notebooks. We implemented the causal attention mechanism which is just a self-attention mechanism (like the simple weighted one) where we only consider tokens that appear previously when attempting to predict the next token in the sequence.

Causal attention
> Restricts model to only consider previous and current inputs in a sequence when processing any given token when computing attention scores. 
>
> -- _Sebastian Raschka - Build a Large Language Model from Scratch_

The code is adapted again, from Build a Large Language Model from Scratch.

Import the required modules.

In [None]:
import torch
import ttnn
from torch import nn

Declare the familiar input, `context` representing the string `Your journey starts with one step`.

In [None]:
context = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

For optimizations relating to TTNN, we can declare the dimensions for `CoreGrid` for all optimized `linear` and `matmul` calls.

In [None]:
core_grid_x = 8
core_grid_y = 8

Let's now bring back our `ttnn` optimized `SelfAttention` class and use that to demonstrate causal attention.

In [None]:
torch.manual_seed(789)

class SelfAttention_v2(nn.Module):
  def __init__(self, d_in, d_out, device):
    super().__init__()

    self.W_query = nn.Linear(d_in, d_out , bias=False)
    self.W_key = nn.Linear(d_in, d_out, bias=False)
    self.W_value = nn.Linear(d_in, d_out, bias=False)

    self._device = device
    
    # Extract weight matrices from PyTorch layers and convert to TTNN once
    self.W_query_ttnn = ttnn.from_torch(
      self.W_query.weight, 
      dtype=ttnn.bfloat16, 
      layout=ttnn.TILE_LAYOUT, 
      device=self._device,
      memory_config=ttnn.L1_MEMORY_CONFIG
    )
    
    self.W_key_ttnn = ttnn.from_torch(
      self.W_key.weight, 
      dtype=ttnn.bfloat16, 
      layout=ttnn.TILE_LAYOUT, 
      device=self._device,
      memory_config=ttnn.L1_MEMORY_CONFIG
    )
    
    self.W_value_ttnn = ttnn.from_torch(
      self.W_value.weight, 
      dtype=ttnn.bfloat16, 
      layout=ttnn.TILE_LAYOUT, 
      device=self._device,
      memory_config=ttnn.L1_MEMORY_CONFIG
    )

    self._scaler = 1 / (d_out ** 0.5)

  def forward(self, x):
    x_ttnn = ttnn.from_torch(
      x, 
      dtype=ttnn.bfloat16, 
      layout=ttnn.TILE_LAYOUT, 
      device=device,
    )
    
    queries_ttnn = ttnn.linear(
      x_ttnn,
      self.W_query_ttnn,
      transpose_b=True,
      core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
    )
    values_ttnn = ttnn.linear(
      x_ttnn,
      self.W_value_ttnn,
      transpose_b=True,
      core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
    )
    keys_ttnn = ttnn.linear(
      x_ttnn,
      self.W_key_ttnn,
      transpose_b=True,
      core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
    )

    attn_scores_ttnn = ttnn.matmul(
      queries_ttnn, 
      ttnn.permute(keys_ttnn, (1, 0)),
      core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
    )

    attn_weights_ttnn = ttnn.softmax(
      attn_scores_ttnn * self._scaler,
      dim=-1
    )

    context_vec_ttnn = ttnn.matmul(
      attn_weights_ttnn,
      values_ttnn,
      core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
    )

    context_vec = ttnn.to_torch(context_vec_ttnn)

    return context_vec

2025-05-03 04:29:43.177 | DEBUG    | ttnn:<module>:83 - Initial ttnn.CONFIG:
Config{cache_path=/home/avgdev/.cache/ttnn,model_cache_path=/home/avgdev/.cache/ttnn/models,tmp_dir=/tmp/ttnn,enable_model_cache=false,enable_fast_runtime_mode=true,throw_exception_on_fallback=false,enable_logging=false,enable_graph_report=false,enable_detailed_buffer_report=false,enable_detailed_tensor_report=false,enable_comparison_mode=false,comparison_mode_should_raise_exception=false,comparison_mode_pcc=0.9999,root_report_path=generated/ttnn/reports,report_name=std::nullopt,std::nullopt}


In [None]:
d = context.shape
d_in = context.shape[1]
d_out = context.shape[1]

d_in, d_out

(torch.Size([6, 3]), 3, 3)

## Causal Attention TTNN Demonstration


### Manually Compute Attention Weights

In [5]:

device_id = 0
device = ttnn.open_device(device_id=device_id)

sa_v2 = SelfAttention_v2(d_in, d_out, device)
inputs_ttnn = ttnn.from_torch(
  context,
  dtype=ttnn.bfloat16,
  layout=ttnn.TILE_LAYOUT,
  device=device
)
queries_ttnn = ttnn.linear(
  inputs_ttnn,
  sa_v2.W_query_ttnn,
  transpose_b=True,
  core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
)
keys_ttnn = ttnn.linear(
  inputs_ttnn,
  sa_v2.W_key_ttnn,
  transpose_b=True,
  core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
)
attn_scores_ttnn = ttnn.matmul(
  queries_ttnn, 
  ttnn.permute(keys_ttnn, (1, 0)),
  core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
)

attn_weights_ttnn = ttnn.softmax(
  attn_scores_ttnn * (1 / (d_out ** 0.5)),
  dim=-1
)

attn_scores = ttnn.to_torch(attn_scores_ttnn, device=device)
attn_weights = ttnn.to_torch(attn_weights_ttnn, device= device)
ttnn.close_device(device)

attn_weights, attn_weights.shape

                 Device | INFO     | Opening user mode device driver
[32m2025-05-03 04:29:45.558[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Opened PCI device 0; KMD version: 1.33.0, IOMMU: disabled

[32m2025-05-03 04:29:45.571[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Opened PCI device 0; KMD version: 1.33.0, IOMMU: disabled
[32m2025-05-03 04:29:45.573[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Harvesting mask for chip 0 is 0x200 (physical layout: 0x1, logical: 0x200, simulated harvesting mask: 0x0).
[32m2025-05-03 04:29:45.574[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Opened PCI device 0; KMD version: 1.33.0, IOMMU: disabled
[32m2025-05-03 04:29:45.575[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Detected PCI devices: [0]
[32m2025-05-03 04:29:45.575[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Using local chip ids: 

New chip! We now have 1 chips
Chip initialization complete (found )
Chip initializing complete...
 ARC

 [4/4] DRAM

 [16/16] ETH

 CPU

Chip detection complete (found )


                  Metal | INFO     | Closing device 0
                  Metal | INFO     | Disabling and clearing program cache on device 0


(TorchTensor([[0.1621, 0.1621, 0.1621, 0.1680, 0.1914, 0.1562],
              [0.1621, 0.1602, 0.1611, 0.1719, 0.1846, 0.1621],
              [0.1631, 0.1611, 0.1621, 0.1719, 0.1846, 0.1621],
              [0.1641, 0.1631, 0.1631, 0.1699, 0.1738, 0.1660],
              [0.1641, 0.1641, 0.1650, 0.1680, 0.1787, 0.1611],
              [0.1641, 0.1611, 0.1621, 0.1709, 0.1748, 0.1660]],
             dtype=torch.bfloat16),
 torch.Size([6, 6]))

### Computing Causal Attention with Torch

In [6]:
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
mask_simple

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

In [7]:
masked_simple = attn_weights * mask_simple
masked_simple

TorchTensor([[0.1621, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.1621, 0.1602, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.1631, 0.1611, 0.1621, 0.0000, 0.0000, 0.0000],
             [0.1641, 0.1631, 0.1631, 0.1699, 0.0000, 0.0000],
             [0.1641, 0.1641, 0.1650, 0.1680, 0.1787, 0.0000],
             [0.1641, 0.1611, 0.1621, 0.1709, 0.1748, 0.1660]])

In [8]:
row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
masked_simple_norm

TorchTensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.5030, 0.4970, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.3353, 0.3313, 0.3333, 0.0000, 0.0000, 0.0000],
             [0.2485, 0.2470, 0.2470, 0.2574, 0.0000, 0.0000],
             [0.1953, 0.1953, 0.1965, 0.2000, 0.2128, 0.0000],
             [0.1642, 0.1613, 0.1623, 0.1711, 0.1750, 0.1662]])

In [9]:
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
masked

TorchTensor([[-0.2295,    -inf,    -inf,    -inf,    -inf,    -inf],
             [-0.3105, -0.3281,    -inf,    -inf,    -inf,    -inf],
             [-0.3086, -0.3242, -0.3164,    -inf,    -inf,    -inf],
             [-0.1729, -0.1855, -0.1826, -0.0996,    -inf,    -inf],
             [-0.1338, -0.1309, -0.1235, -0.0820,  0.0620,    -inf],
             [-0.2314, -0.2520, -0.2500, -0.1328, -0.0840, -0.1963]],
            dtype=torch.bfloat16)

In [10]:
attn_weights = torch.softmax(masked / d_out ** 0.5, dim=1)
attn_weights

TorchTensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.5039, 0.4980, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.3359, 0.3320, 0.3340, 0.0000, 0.0000, 0.0000],
             [0.2480, 0.2461, 0.2471, 0.2598, 0.0000, 0.0000],
             [0.1943, 0.1943, 0.1953, 0.2002, 0.2168, 0.0000],
             [0.1631, 0.1611, 0.1611, 0.1719, 0.1768, 0.1660]],
            dtype=torch.bfloat16)

### Dropout with Torch

In [11]:
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)
example = torch.ones(6, 6)
print(dropout(example))

tensor([[2., 2., 2., 2., 2., 2.],
        [0., 2., 0., 0., 0., 0.],
        [0., 0., 2., 0., 2., 0.],
        [2., 2., 0., 0., 0., 2.],
        [2., 0., 0., 0., 0., 2.],
        [0., 2., 0., 0., 0., 0.]])


In [12]:
dropout(attn_weights)

TorchTensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.0000, 0.6680, 0.0000, 0.0000, 0.0000],
             [0.0000, 0.4922, 0.0000, 0.5195, 0.0000, 0.0000],
             [0.0000, 0.3887, 0.3906, 0.4004, 0.4336, 0.0000],
             [0.3262, 0.3223, 0.0000, 0.0000, 0.3535, 0.3320]],
            dtype=torch.bfloat16)

### Dropout with TTNN

In [13]:
from ttnn import TILE_LAYOUT


device_id = 0
device = ttnn.open_device(device_id=device_id)

attn_weights_ttnn = ttnn.to_device(
  ttnn.from_torch(attn_weights, dtype=ttnn.bfloat16, layout=TILE_LAYOUT),
  device
)
dropout_ttnn = ttnn.experimental.dropout(
  attn_weights_ttnn,
  seed=123,
  probability=0.5,
  scale=1.0/(1.0-0.5)
)

dropped = ttnn.to_torch(dropout_ttnn)

ttnn.close_device(device)

dropped

                  Metal | INFO     | Initializing device 0. Program cache is NOT enabled
                  Metal | INFO     | AI CLK for device 0 is:   1000 MHz
                  Metal | INFO     | Closing device 0
                  Metal | INFO     | Disabling and clearing program cache on device 0


TorchTensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [1.0078, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
             [0.6719, 0.6641, 0.6680, 0.0000, 0.0000, 0.0000],
             [0.4961, 0.0000, 0.0000, 0.5195, 0.0000, 0.0000],
             [0.0000, 0.3887, 0.3906, 0.4004, 0.4336, 0.0000],
             [0.0000, 0.3223, 0.3223, 0.0000, 0.3535, 0.0000]],
            dtype=torch.bfloat16)

## Implementing Causal Attention with Torch

In [43]:
class CausalAttention(nn.Module):
  def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
    super().__init__()

    self.d_out = d_out
    
    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    
    self.dropout = nn.Dropout(dropout)

    self.mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)

  def forward(self, x): 
    b, num_tokens, d_in = x.shape
    keys = self.W_query(x)
    queries = self.W_query(x)
    values = self.W_value(x)

    attn_scores = queries @ keys.transpose(1, 2)
    attn_scores.masked_fill(
      self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
    )
    attn_weights = torch.softmax(
      attn_scores / keys.shape[-1]**0.5, dim=-1
    )

    attn_weights = self.dropout(attn_weights)

    print(attn_weights)

    context_vec = attn_weights @ values
    return context_vec


In [44]:
torch.manual_seed(123)
batch = torch.reshape(context, [1, context.shape[0], context.shape[1]])
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.5)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)
context_vecs


tensor([[[0.3615, 0.0000, 0.0000, 0.0000, 0.0000, 0.3078],
         [0.0000, 0.3768, 0.3752, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.3747, 0.0000, 0.0000, 0.3192],
         [0.0000, 0.3575, 0.3566, 0.3179, 0.0000, 0.3292],
         [0.0000, 0.3559, 0.3552, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.3622, 0.3153, 0.3117, 0.0000]]],
       grad_fn=<MulBackward0>)
context_vecs.shape: torch.Size([1, 6, 3])


tensor([[[ 0.2329,  0.3394, -0.1157],
         [ 0.2621,  0.4198, -0.1169],
         [ 0.2447,  0.3468, -0.0615],
         [ 0.4368,  0.6380, -0.1251],
         [ 0.2478,  0.3970, -0.1106],
         [ 0.1850,  0.3498, -0.1246]]], grad_fn=<UnsafeViewBackward0>)

## Implementing Causal Attention with TTNN

In [None]:
class CausalAttention_ttnn(nn.Module):
  def __init__(self, d_in, d_out, context_length, dropout, device, qkv_bias=False):
    super().__init__()

    self.d_out = d_out
    
    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.dropout = nn.Dropout(dropout)
    self.mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)

    self.device = device
    self.W_query_ttnn = ttnn.from_torch(
      self.W_query.weight,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self.device,
      memory_config=ttnn.L1_MEMORY_CONFIG
    )

    self.W_key_ttnn = ttnn.from_torch(
      self.W_key.weight,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self.device,
      memory_config=ttnn.L1_MEMORY_CONFIG
    )
    
    self.W_value_ttnn = ttnn.from_torch(
      self.W_value.weight,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self.device,
      memory_config=ttnn.L1_MEMORY_CONFIG
    )

    self.mask_ttnn = ttnn.from_torch(
      self.mask,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self.device,
      memory_config=ttnn.L1_MEMORY_CONFIG
    )

    self.dropout_prob = dropout
    self.context_length = context_length
    

  def forward(self, x):
    b, num_tokens, d_in = x.shape
    x_ttnn = ttnn.from_torch(
      x,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self.device
    )
    keys_ttnn = ttnn.linear(
      x_ttnn,
      self.W_key_ttnn,
      transpose_b=True,
      core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
    )
    values_ttnn = ttnn.linear(
      x_ttnn,
      self.W_value_ttnn,
      transpose_b=True,
      core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
    )
    queries_ttnn = ttnn.linear(
      x_ttnn,
      self.W_query_ttnn,
      transpose_b=True,
      core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
    )

    keys_ttnn_transpose = ttnn.permute(keys_ttnn, (0, 2, 1))
    attn_scores_ttnn = ttnn.matmul(
      queries_ttnn,
      keys_ttnn_transpose,
      core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
    )

    if self.dropout_prob > 0.0:
      inf_ttnn = ttnn.full_like(attn_scores_ttnn, -1e9, layout=ttnn.TILE_LAYOUT)
      attn_scores_ttnn = ttnn.where(self.mask_ttnn, inf_ttnn, attn_scores_ttnn)

    attn_weights_ttnn = ttnn.softmax(attn_scores_ttnn * (1/self.d_out ** 0.5), dim=-1)

    context_vec_ttnn = ttnn.matmul(
      attn_weights_ttnn,
      values_ttnn,
      core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
    )

    context_vec = ttnn.to_torch(context_vec_ttnn, device=self.device)

    return context_vec


In [67]:
device_id = 0
device = ttnn.open_device(device_id=device_id)

torch.manual_seed(123)
batch = torch.reshape(context, [1, context.shape[0], context.shape[1]])
context_length = batch.shape[1]
ca = CausalAttention_ttnn(d_in, d_out, context_length, 0.5, device)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)
context_vecs

ttnn.close_device(device)

context_vecs

                  Metal | INFO     | Initializing device 0. Program cache is NOT enabled
                  Metal | INFO     | AI CLK for device 0 is:   1000 MHz
ttnn.Tensor([[[-0.40234, -998244352.00000,  ..., -998244352.00000, -998244352.00000],
              [-0.26367,  0.16211,  ..., -998244352.00000, -998244352.00000],
              ...,
              [-0.20117,  0.01044,  ..., -0.00145, -998244352.00000],
              [-0.10596,  0.20605,  ...,  0.13965,  0.15625]]], shape=Shape([1, 6, 6]), dtype=DataType::BFLOAT16, layout=Layout::TILE)
context_vecs.shape: torch.Size([1, 6, 3])
                  Metal | INFO     | Closing device 0
                  Metal | INFO     | Disabling and clearing program cache on device 0


TorchTensor([[[ 0.3340,  0.5703, -0.3145],
              [ 0.3359,  0.5508, -0.2178],
              [ 0.3457,  0.5664, -0.2012],
              [ 0.3105,  0.4941, -0.1602],
              [ 0.2441,  0.4316, -0.1650],
              [ 0.2676,  0.4336, -0.1377]]], dtype=torch.bfloat16)

In [64]:
ttnn.close_device(device)

