In [1]:
import torch
import ttnn

device_id = 0

# Seed torch's global RNG so the torch.randn activations and weights created
# later in this notebook are reproducible across runs.
torch.manual_seed(0)

# Open the device through ttnn; it is released by ttnn.close(device) at the end.
device = ttnn.open(device_id)

[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Initializing device 0
[38;2;000;128;000m                 Device[0m | [1m[38;2;100;149;237mINFO    [0m | Opening device driver
[32m2023-10-30 22:22:08.864[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Detected 4 PCI devices
[32m2023-10-30 22:22:08.895[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Using 1 Hugepages/NumHostMemChannels for TTDevice (pci_interface_id: 3 device_id: 0xfaca revision: 0)
[32m2023-10-30 22:22:08.900[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Using 1 Hugepages/NumHostMemChannels for TTDevice (pci_interface_id: 2 device_id: 0xfaca revision: 0)
[32m2023-10-30 22:22:08.904[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Using 1 Hugepages/NumHostMemChannels for TTDevice (pci_interface_id: 1 device_id: 0xfaca revision: 0)
[32m2023-10-30 22:22:08.912[0m | [1m[38;2;100;149

# Configuration

In [2]:
# Problem size: one sequence of 64 tokens, split across 4 heads of width 32.
batch_size = 1
sequence_size = 64
num_heads, head_size = 4, 32
hidden_size = head_size * num_heads  # model width = heads x per-head width

# Initialize activations and weights using torch

In [3]:
torch_hidden_states = torch.randn((batch_size, sequence_size, hidden_size), dtype=torch.bfloat16)

# Additive attention mask of shape (1, 1, 1, sequence_size). Its singleton dims
# broadcast over batch, heads and query positions when it is added to the
# (batch, heads, query, key) attention scores; the last dim indexes key
# positions. 0 keeps a key, -1e9 effectively removes it from the softmax.
# Every other key position is masked here. The slice must be on the last dim:
# slicing dim 2 (size 1) with `::2` selects its only element and would mask
# *every* key position.
torch_attention_mask = torch.zeros((1, 1, 1, sequence_size), dtype=torch.bfloat16)
torch_attention_mask[:, :, :, ::2] = -1e9

# Projection weights are applied as `x @ weight`, each (hidden_size, hidden_size).
torch_query_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_query_bias = torch.randn((hidden_size,), dtype=torch.bfloat16)
torch_key_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_key_bias = torch.randn((hidden_size,), dtype=torch.bfloat16)
torch_value_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_value_bias = torch.randn((hidden_size,), dtype=torch.bfloat16)
torch_output_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_output_bias = torch.randn((hidden_size,), dtype=torch.bfloat16)

# Convert activations and weights to ttnn and move them to the device

In [4]:
# Convert each torch tensor to a ttnn tensor and move it onto `device` in one step.
# Activations, mask and weights use to_device's default memory config; the biases
# are explicitly placed with ttnn.L1_MEMORY_CONFIG.
# NOTE(review): no layout is requested here, while the output cell converts the
# result with to_layout(..., ROW_MAJOR_LAYOUT), and the recorded output contains
# NaNs -- confirm these tensors are in the layout the device ops expect.
hidden_states = ttnn.to_device(ttnn.from_torch(torch_hidden_states), device)
attention_mask = ttnn.to_device(ttnn.from_torch(torch_attention_mask), device)

query_weight = ttnn.to_device(ttnn.from_torch(torch_query_weight), device)
query_bias = ttnn.to_device(
    ttnn.from_torch(torch_query_bias), device, memory_config=ttnn.L1_MEMORY_CONFIG
)

key_weight = ttnn.to_device(ttnn.from_torch(torch_key_weight), device)
key_bias = ttnn.to_device(
    ttnn.from_torch(torch_key_bias), device, memory_config=ttnn.L1_MEMORY_CONFIG
)

value_weight = ttnn.to_device(ttnn.from_torch(torch_value_weight), device)
value_bias = ttnn.to_device(
    ttnn.from_torch(torch_value_bias), device, memory_config=ttnn.L1_MEMORY_CONFIG
)

output_weight = ttnn.to_device(ttnn.from_torch(torch_output_weight), device)
output_bias = ttnn.to_device(
    ttnn.from_torch(torch_output_bias), device, memory_config=ttnn.L1_MEMORY_CONFIG
)


# Write multi_head_attention using ttnn

In [5]:
def _project_heads(hidden_states, weight, bias, num_heads, head_size, permutation):
    """Project `hidden_states` with `weight`/`bias` and split the result into heads.

    The projection `hidden_states @ weight + bias` is reshaped to
    (batch_size, sequence_size, num_heads, head_size) and then permuted with
    `permutation`.
    """
    batch_size, sequence_size, _ = hidden_states.shape
    projected = hidden_states @ weight
    projected = projected + bias
    projected = ttnn.reshape(projected, (batch_size, sequence_size, num_heads, head_size))
    return ttnn.permute(projected, permutation)


def multi_head_attention(
    hidden_states,
    attention_mask,
    query_weight,
    query_bias,
    key_weight,
    key_bias,
    value_weight,
    value_bias,
    output_weight,
    output_bias,
    *,
    head_size,
):
    """Multi-head self-attention: Q/K/V projections, scaled dot-product
    attention, and an output projection.

    Args:
        hidden_states: activations of shape (batch_size, sequence_size, hidden_size).
        attention_mask: additive mask, broadcastable to the attention scores of
            shape (batch_size, num_heads, sequence_size, sequence_size), or None
            to skip masking.
        query_weight, key_weight, value_weight: projection matrices applied as
            `hidden_states @ weight`; each must produce hidden_size features per token.
        query_bias, key_bias, value_bias: biases added after the matching projection.
        output_weight, output_bias: final projection applied as
            `context @ output_weight + output_bias`.
        head_size: width of each attention head (keyword-only).

    Returns:
        Tensor of shape (batch_size, sequence_size, output_weight.shape[-1]).

    Raises:
        ValueError: if head_size is not a positive divisor of hidden_size.
    """
    batch_size, sequence_size, hidden_size = hidden_states.shape
    # Fail early with a clear message instead of letting `//` silently floor
    # (or divide by zero) and the reshape below fail opaquely.
    if head_size <= 0 or hidden_size % head_size != 0:
        raise ValueError(
            f"head_size ({head_size}) must be a positive divisor of hidden_size ({hidden_size})"
        )
    num_heads = hidden_size // head_size

    # Query and value become (batch, heads, seq, head_size); key is laid out as
    # (batch, heads, head_size, seq) so `query @ key` needs no extra transpose.
    query = _project_heads(hidden_states, query_weight, query_bias, num_heads, head_size, (0, 2, 1, 3))
    key = _project_heads(hidden_states, key_weight, key_bias, num_heads, head_size, (0, 2, 3, 1))
    value = _project_heads(hidden_states, value_weight, value_bias, num_heads, head_size, (0, 2, 1, 3))

    # Scaled dot-product scores: (batch, heads, seq_query, seq_key).
    attention_scores = query @ key
    attention_scores = attention_scores * (1 / (head_size**0.5))
    if attention_mask is not None:
        attention_scores = attention_scores + attention_mask

    attention_probs = ttnn.softmax(attention_scores, dim=-1)

    # Merge heads back: (batch, heads, seq, head_size) -> (batch, seq, hidden_size).
    context_layer = attention_probs @ value
    context_layer = ttnn.permute(context_layer, (0, 2, 1, 3))
    context_layer = ttnn.reshape(context_layer, (batch_size, sequence_size, hidden_size))

    self_output = context_layer @ output_weight
    self_output = self_output + output_bias

    return self_output

# Run multi_head_attention using ttnn

In [6]:
# Run the attention block on the device tensors prepared above; each
# projection weight is passed next to its bias.
output = multi_head_attention(
    hidden_states, attention_mask,
    query_weight, query_bias,
    key_weight, key_bias,
    value_weight, value_bias,
    output_weight, output_bias,
    head_size=head_size,
)



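# Explore the output

Note: the recorded ttnn output below contains `nan` entries and values such as `9.18355e-41` and `-2.26674e+24`. Random but finite inputs to this computation should not produce such values, which suggests something went wrong on the device. Treat these results as suspect and compare them against a torch reference before relying on them.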

In [7]:
print("Printing ttnn tensor")
# Bind new names instead of overwriting `output` so this cell is idempotent.
# Reassigning `output` would make a second run pass the already-converted host
# tensor back through to_layout/from_device.
output_row_major = ttnn.to_layout(output, ttnn.ROW_MAJOR_LAYOUT)
output_host = ttnn.from_device(output_row_major)
print(output_host.shape)
print(output_host[0, :1])

print("\n\n")
print("Printing torch tensor")
torch_output = ttnn.to_torch(output_host)
print(torch_output.shape)
print(torch_output[0, :1])

Printing ttnn tensor
[1, 64, 128]
Tensor([ [-2.26674e+24, -4.30737e-08, nan, 0, 0, 0, 0, 0, 9.18355e-41, 0, 0, 0, -2.07526e+19, -4.30737e-08, nan, 0, -8.04661e+27, -4.30737e-08, nan, 0, 3.20624e+35, 3.36295e+38, 2.69833e+38, 0, -2.26674e+23, -4.30737e-08, nan, 0, 0, 0, 0, 0, -4.79702e+27, -4.30737e-08, nan, 0, 0, 0, 3.25661e+38, 0, 9.18355e-41, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1.70141e+38, -nan, -nan, -nan, 0, 9.18355e-41, 1.83671e-40, 0, 7.14905e-31, 0, 0, 0, 2.51513e+13, -58368, 2.68504e+38, 0, -4.79702e+27, -4.30737e-08, nan, 0, 9.18355e-40, 0, 2.68504e+38, 0, -1.88895e+23, -4.30737e-08, nan, 0, 0, 0, nan, 0, 9.18355e-41, 0, nan, 0, -2.17607e+25, -4.30737e-08, nan, 0, 1.07448e-38, 0, 0, 0, 2.00447e+31, nan, 2.69833e+38, 0, 0, 0, 0, 0, 0, 0, 0, 0, -5.90296e+21, -4.30737e-08, nan, 0, 9.18355e-41, 0, 0, 0, -2.07526e+19, -4.30737e-08, nan, 0, 0, 0, 9.18355e-41, 0]], dtype=bfloat16 )




Printing torch tensor
torch.Size([1, 64, 128])
tensor([[-2.2667e+24, -4.3074e-08,       

# Close the device

In [8]:
# Release the device opened with ttnn.open(device_id) in the first cell.
ttnn.close(device)

[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Closing device 0
