In [1]:
import torch
import ttnn

torch.manual_seed(0)

device_id = 0
device = ttnn.open(device_id)

[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Initializing device 0
[38;2;000;128;000m                 Device[0m | [1m[38;2;100;149;237mINFO    [0m | Opening device driver
[32m2023-10-30 16:03:36.602[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Detected 4 PCI devices
[32m2023-10-30 16:03:36.629[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Using 1 Hugepages/NumHostMemChannels for TTDevice (pci_interface_id: 3 device_id: 0xfaca revision: 0)
[32m2023-10-30 16:03:36.635[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Using 1 Hugepages/NumHostMemChannels for TTDevice (pci_interface_id: 2 device_id: 0xfaca revision: 0)
[32m2023-10-30 16:03:36.639[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Using 1 Hugepages/NumHostMemChannels for TTDevice (pci_interface_id: 1 device_id: 0xfaca revision: 0)
[32m2023-10-30 16:03:36.647[0m | [1m[38;2;100;149

# Configuration

In [2]:
batch_size = 1
sequence_size = 64
num_heads = 4
head_size = 32
hidden_size = num_heads * head_size

# Initialize activations and weights using torch

In [3]:
torch_hidden_states = torch.randn((batch_size, sequence_size, hidden_size), dtype=torch.bfloat16)

torch_attention_mask = torch.zeros((1, 1, 1, sequence_size), dtype=torch.bfloat16)
torch_attention_mask[:, :, ::2, :] = -1e9

torch_query_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_query_bias = torch.randn((1, 1, 1, hidden_size), dtype=torch.bfloat16)
torch_key_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_key_bias = torch.randn((1, 1, 1, hidden_size), dtype=torch.bfloat16)
torch_value_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_value_bias = torch.randn((1, 1, 1, hidden_size), dtype=torch.bfloat16)
torch_output_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_output_bias = torch.randn((1, 1, 1, hidden_size), dtype=torch.bfloat16)

# Convert activations and weights to ttnn

In [4]:
hidden_states = ttnn.from_torch(torch_hidden_states)
attention_mask = ttnn.from_torch(torch_attention_mask)

query_weight = ttnn.from_torch(torch_query_weight)
query_bias = ttnn.from_torch(torch_query_bias)
key_weight = ttnn.from_torch(torch_key_weight)
key_bias = ttnn.from_torch(torch_key_bias)
value_weight = ttnn.from_torch(torch_value_weight)
value_bias = ttnn.from_torch(torch_value_bias)
output_weight = ttnn.from_torch(torch_output_weight)
output_bias = ttnn.from_torch(torch_output_bias)

hidden_states = ttnn.to_device(hidden_states, device)
attention_mask = ttnn.to_device(attention_mask, device)
query_weight = ttnn.to_device(query_weight, device)
query_bias = ttnn.to_device(query_bias, device, memory_config=ttnn.L1_MEMORY_CONFIG)
key_weight = ttnn.to_device(key_weight, device)
key_bias = ttnn.to_device(key_bias, device, memory_config=ttnn.L1_MEMORY_CONFIG)
value_weight = ttnn.to_device(value_weight, device)
value_bias = ttnn.to_device(value_bias, device, memory_config=ttnn.L1_MEMORY_CONFIG)
output_weight = ttnn.to_device(output_weight, device)
output_bias = ttnn.to_device(output_bias, device, memory_config=ttnn.L1_MEMORY_CONFIG)


# Write multi_head_attention using ttnn

In [5]:
def multi_head_attention(
    hidden_states,
    attention_mask,
    query_weight,
    query_bias,
    key_weight,
    key_bias,
    value_weight,
    value_bias,
    output_weight,
    output_bias,
    *,
    head_size,
):
    batch_size, sequence_size, hidden_size = hidden_states.shape
    num_heads = hidden_size // head_size

    query = hidden_states @ query_weight
    query = query + query_bias
    query = ttnn.reshape(query, (batch_size, sequence_size, num_heads, head_size))
    query = ttnn.permute(query, (0, 2, 1, 3))

    key = hidden_states @ key_weight
    key = key + key_bias
    key = ttnn.reshape(key, (batch_size, sequence_size, num_heads, head_size))
    key = ttnn.permute(key, (0, 2, 3, 1))

    value = hidden_states @ value_weight
    value = value + value_bias
    value = ttnn.reshape(value, (batch_size, sequence_size, num_heads, head_size))
    value = ttnn.permute(value, (0, 2, 1, 3))

    attention_scores = query @ key
    attention_scores = attention_scores * (1 / (head_size**0.5))
    if attention_mask is not None:
        attention_scores = attention_scores + attention_mask

    attention_probs = ttnn.softmax(attention_scores, dim=-1)

    context_layer = attention_probs @ value
    context_layer = ttnn.permute(context_layer, (0, 2, 1, 3))
    context_layer = ttnn.reshape(context_layer, (batch_size, sequence_size, hidden_size))

    self_output = context_layer @ output_weight
    self_output = self_output + output_bias

    return self_output

# Run multi_head_attention using ttnn

In [6]:
output = multi_head_attention(
    hidden_states,
    attention_mask,
    query_weight,
    query_bias,
    key_weight,
    key_bias,
    value_weight,
    value_bias,
    output_weight,
    output_bias,
    head_size=head_size,
)



# Explore output

In [7]:
print("Printing ttnn tensor")
output = ttnn.to_layout(output, ttnn.ROW_MAJOR_LAYOUT)
output = ttnn.from_device(output)
print(output.shape)
print(output[0, :1])

print("\n\n")
print("Printing torch tensor")
torch_output = ttnn.to_torch(output)
print(torch_output.shape)
print(torch_output[0, :1])

Printing ttnn tensor
[1, 64, 128]
Tensor([ [23, 3.625, -3.21875, -12.375, 8.8125, -2.15625, 27.125, -5.5, -7.21875, -26.5, -2.78125, 10.125, -16.875, -22.5, -21, 11.1875, -12.75, 3.32812, -2.78125, 0.5625, -2.48438, -6.5, -0.179688, -24, 14.3125, 32.25, -5.40625, -17.625, 2.48438, -1.28125, 24.5, 6.9375, -21.375, 37.75, 12.25, -6.9375, 11.375, 1.02344, 23.875, -5.5625, -28.375, 6, 4.78125, 8.875, 4.625, 23.125, -2.15625, -14.75, 17.75, -2.20312, 11.6875, -28.375, -27.625, -14.5625, 9.0625, -13, 1.04688, -11.25, 8.5625, 6.0625, 21.625, -17.5, -5.53125, -11.25, 25.875, -0.0986328, 1.39062, -6.53125, -9.875, 7.4375, -6.78125, 2.65625, -32.5, -13.375, 14.375, 12.875, -3.32812, 7.625, 9.4375, 20.875, -5.21875, -6.84375, 0.671875, 18.625, -23.625, -23.25, -8.1875, -5.40625, 12.75, 9.75, -16.75, 1.20312, 11.6875, 4.65625, -7.25, -11.625, -13.375, -17, -0.3125, -8.9375, -18.875, 21.5, 42, -18.5, 3.25, -12, 1.35938, -5.03125, 8.0625, -19.5, 26, 2.46875, -0.953125, 15.9375, -4.15625, 7.09375, 18

# Free tensor

In [8]:
ttnn.free(output)

# Free the device

In [9]:
ttnn.close(device)

[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Closing device 0
