In [1]:
import tt_lib as ttl
device_id = 0
device = ttl.device.CreateDevice(device_id)
ttl.device.SetDefaultDevice(device)

  from .autonotebook import tqdm as notebook_tqdm


[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Initializing device 0
[38;2;000;128;000m                 Device[0m | [1m[38;2;100;149;237mINFO    [0m | Opening device driver
[32m2023-10-25 15:53:16.862[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Detected 1 PCI device
[32m2023-10-25 15:53:16.873[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Using 1 Hugepages/NumHostMemChannels for TTDevice (pci_interface_id: 0 device_id: 0xfaca revision: 0)
[0;33m---- ttSiliconDevice::init_hugepage: bind_area_to_memory_nodeset() failed (physical_device_id: 0 ch: 0). Hugepage allocation is not on NumaNode matching TT Device. Side-Effect is decreased Device->Host perf (Issue #893).
[0m[32m2023-10-25 15:53:17.092[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Detected 1 PCI device
[32m2023-10-25 15:53:17.093[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Usin

In [2]:
import torch
import ttnn

# Configuration

In [3]:
batch_size = 1
sequence_size = 64
num_heads = 4
head_size = 32
hidden_size = num_heads * head_size

# Initialize activations and weights using torch

In [4]:
torch_hidden_states = torch.randn((batch_size, sequence_size, hidden_size), dtype=torch.bfloat16)

torch_attention_mask = torch.zeros((1, 1, 1, sequence_size), dtype=torch.bfloat16)
torch_attention_mask[:, :, ::2, :] = -1e9

torch_query_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_query_bias = torch.randn((1, 1, 1, hidden_size), dtype=torch.bfloat16)
torch_key_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_key_bias = torch.randn((1, 1, 1, hidden_size), dtype=torch.bfloat16)
torch_value_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_value_bias = torch.randn((1, 1, 1, hidden_size), dtype=torch.bfloat16)
torch_output_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_output_bias = torch.randn((1, 1, 1, hidden_size), dtype=torch.bfloat16)

# Convert activations and weights to ttnn

In [5]:
hidden_states = ttnn.from_torch(torch_hidden_states)
attention_mask = ttnn.from_torch(torch_attention_mask)

query_weight = ttnn.from_torch(torch_query_weight)
query_bias = ttnn.from_torch(torch_query_bias)
key_weight = ttnn.from_torch(torch_key_weight)
key_bias = ttnn.from_torch(torch_key_bias)
value_weight = ttnn.from_torch(torch_value_weight)
value_bias = ttnn.from_torch(torch_value_bias)
output_weight = ttnn.from_torch(torch_output_weight)
output_bias = ttnn.from_torch(torch_output_bias)

# Write multi_head_attention using ttnn

In [6]:
def multi_head_attention(
    hidden_states,
    attention_mask,
    query_weight,
    query_bias,
    key_weight,
    key_bias,
    value_weight,
    value_bias,
    output_weight,
    output_bias,
    *,
    head_size,
):
    batch_size, sequence_size, hidden_size = hidden_states.shape
    num_heads = hidden_size // head_size

    query = hidden_states @ query_weight
    query = query + query_bias
    query = ttnn.reshape(query, (batch_size, sequence_size, num_heads, head_size))
    query = ttnn.permute(query, (0, 2, 1, 3))

    key = hidden_states @ key_weight
    key = key + key_bias
    key = ttnn.reshape(key, (batch_size, sequence_size, num_heads, head_size))
    key = ttnn.permute(key, (0, 2, 3, 1))

    value = hidden_states @ value_weight
    value = value + value_bias
    value = ttnn.reshape(value, (batch_size, sequence_size, num_heads, head_size))
    value = ttnn.permute(value, (0, 2, 1, 3))

    attention_scores = query @ key
    attention_scores = attention_scores * (1 / (head_size**0.5))
    if attention_mask is not None:
        attention_scores = attention_scores + attention_mask

    attention_probs = ttnn.softmax(attention_scores, dim=-1)

    context_layer = attention_probs @ value
    context_layer = ttnn.permute(context_layer, (0, 2, 1, 3))
    context_layer = ttnn.reshape(context_layer, (batch_size, sequence_size, hidden_size))

    self_output = context_layer @ output_weight
    self_output = self_output + output_bias

    return self_output

# Run using ttnn

In [7]:
output = multi_head_attention(
    hidden_states,
    attention_mask,
    query_weight,
    query_bias,
    key_weight,
    key_bias,
    value_weight,
    value_bias,
    output_weight,
    output_bias,
    head_size=head_size,
)



# Use output

In [8]:
print("Printing ttnn tensor")
print(output.shape)
print(output[0, 0, :1])

print("\n\n")
print("Printing torch tensor")
torch_output = ttnn.to_torch(output)
print(torch_output.shape)
print(torch_output[0, 0, :1])

Printing ttnn tensor
[1, 1, 64, 128]
Tensor([ [-6.09375, 27.25, -39.25, 8.1875, -21.625, 0.265625, -18.375, 3.6875, 9.375, -21, 77, -4.84375, -7.78125, -25.5, -32.75, 11.625, -20.25, 0.515625, 21.75, 10, 31.375, 3.39062, -8.3125, -26, -36.25, -46.75, 26.625, 24.125, 9.6875, -50, 11, 16.25, -2.15625, 3.53125, 9, -0.00976562, -13.75, 22.25, 2.32812, -5.28125, 11, -24.625, -45.25, 30.125, -0.140625, -0.859375, -9.625, 7.8125, 16.125, -17.75, -20.125, -20.625, 5.28125, -6.15625, -13.5625, 30.375, 32.25, 3.76562, -42.5, -1.28125, 3.5625, 1.69531, 12.3125, -0.902344, 11.4375, -13.3125, -11.375, -24.375, -25.875, 6.09375, 33.5, -54.75, 31.125, -28.375, 31.125, -23.875, -9, 9.875, 5, 23.625, 0.625, -10.3125, -6.5, 13.4375, -5.125, -5.40625, 7.75, 5.4375, -15.125, -17.625, -25.75, -2.76562, -35, 21.125, -33.25, -6.59375, -17.75, -24.5, -9.0625, -25.5, -11.375, -20.875, 12.5, -21.75, -19.5, 10.8125, -18.25, 16.375, -36, 3.34375, 17.75, -12.75, 16.875, -6.3125, -41.25, 6.09375, -0.351562, -14.375

# Free tensor

In [9]:
del output

In [10]:
ttl.device.CloseDevice(device)

[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Closing device 0


True