In [1]:
import tt_lib as ttl
device_id = 0
device = ttl.device.CreateDevice(device_id)
ttl.device.SetDefaultDevice(device)

  from .autonotebook import tqdm as notebook_tqdm


[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Initializing device 0
[38;2;000;128;000m                 Device[0m | [1m[38;2;100;149;237mINFO    [0m | Opening device driver
[32m2023-10-24 19:09:42.219[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Detected 4 PCI devices
[32m2023-10-24 19:09:42.246[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Using 1 Hugepages/NumHostMemChannels for TTDevice (pci_interface_id: 3 device_id: 0xfaca revision: 0)
[32m2023-10-24 19:09:42.252[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Using 1 Hugepages/NumHostMemChannels for TTDevice (pci_interface_id: 2 device_id: 0xfaca revision: 0)
[32m2023-10-24 19:09:42.256[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Using 1 Hugepages/NumHostMemChannels for TTDevice (pci_interface_id: 1 device_id: 0xfaca revision: 0)
[32m2023-10-24 19:09:42.264[0m | [1m[38;2;100;149

In [2]:
import torch
import ttnn

# Configuration

In [3]:
batch_size = 1
sequence_size = 64
num_heads = 4
head_size = 32
hidden_size = num_heads * head_size

# Initialize activations and weights using torch

In [4]:
torch_hidden_states = torch.randn((batch_size, sequence_size, hidden_size), dtype=torch.bfloat16)

torch_attention_mask = torch.zeros((1, 1, 1, sequence_size), dtype=torch.bfloat16)
torch_attention_mask[:, :, ::2, :] = -1e9

torch_query_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_query_bias = torch.randn((1, 1, 1, hidden_size), dtype=torch.bfloat16)
torch_key_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_key_bias = torch.randn((1, 1, 1, hidden_size), dtype=torch.bfloat16)
torch_value_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_value_bias = torch.randn((1, 1, 1, hidden_size), dtype=torch.bfloat16)
torch_output_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_output_bias = torch.randn((1, 1, 1, hidden_size), dtype=torch.bfloat16)

# Convert activations and weights to ttnn

In [5]:
hidden_states = ttnn.from_torch(torch_hidden_states)
attention_mask = ttnn.from_torch(torch_attention_mask)

query_weight = ttnn.from_torch(torch_query_weight)
query_bias = ttnn.from_torch(torch_query_bias)
key_weight = ttnn.from_torch(torch_key_weight)
key_bias = ttnn.from_torch(torch_key_bias)
value_weight = ttnn.from_torch(torch_value_weight)
value_bias = ttnn.from_torch(torch_value_bias)
output_weight = ttnn.from_torch(torch_output_weight)
output_bias = ttnn.from_torch(torch_output_bias)

# Write multi_head_attention using ttnn

In [6]:
def multi_head_attention(
    hidden_states,
    attention_mask,
    query_weight,
    query_bias,
    key_weight,
    key_bias,
    value_weight,
    value_bias,
    output_weight,
    output_bias,
    *,
    head_size,
):
    batch_size, sequence_size, hidden_size = hidden_states.shape
    num_heads = hidden_size // head_size

    query = hidden_states @ query_weight
    query = query + query_bias
    query = ttnn.reshape(query, (batch_size, sequence_size, num_heads, head_size))
    query = ttnn.permute(query, (0, 2, 1, 3))

    key = hidden_states @ key_weight
    key = key + key_bias
    key = ttnn.reshape(key, (batch_size, sequence_size, num_heads, head_size))
    key = ttnn.permute(key, (0, 2, 3, 1))

    value = hidden_states @ value_weight
    value = value + value_bias
    value = ttnn.reshape(value, (batch_size, sequence_size, num_heads, head_size))
    value = ttnn.permute(value, (0, 2, 1, 3))

    attention_scores = query @ key
    attention_scores = attention_scores * (1 / (head_size**0.5))
    if attention_mask is not None:
        attention_scores = attention_scores + attention_mask

    attention_probs = ttnn.softmax(attention_scores, dim=-1)

    context_layer = attention_probs @ value
    context_layer = ttnn.permute(context_layer, (0, 2, 1, 3))
    context_layer = ttnn.reshape(context_layer, (batch_size, sequence_size, hidden_size))

    self_output = context_layer @ output_weight
    self_output = self_output + output_bias

    return self_output

# Run using ttnn

In [7]:
output = multi_head_attention(
    hidden_states,
    attention_mask,
    query_weight,
    query_bias,
    key_weight,
    key_bias,
    value_weight,
    value_bias,
    output_weight,
    output_bias,
    head_size=head_size,
)



# Use output

In [8]:
print("Printing ttnn tensor")
print(output.shape)
print(output[0, 0, :1])

print("\n\n")
print("Printing torch tensor")
torch_output = ttnn.to_torch(output)
print(torch_output.shape)
print(torch_output[0, 0, :1])

Printing ttnn tensor
[1, 1, 64, 128]
Tensor([ [-3.01562, 0.925781, 19.875, -19.25, 18.625, 7.65625, 14.75, -8.25, 0.21875, -14.4375, 7.25, -5.59375, 26.125, 26, 1.20312, 31.5, -14.8125, -16.625, -0.226562, 33.25, -1.33594, 5.3125, 8.8125, -12.9375, 2.73438, 9.875, 0.722656, -39, -9.125, -38.5, -29.625, 11.125, 33, 6.03125, 19, -24.125, 3.40625, -59.5, 13.8125, 3.75, -15, 51.25, -14.5, -25.25, 30.75, -31.625, 26.75, 43, 13.25, -18.625, 12.875, -20, -10.125, -5.3125, 35.5, 3.78125, -14.3125, -15.75, -23.25, -36.25, -2.29688, -1.71094, -6.75, -16.75, 13.75, -0.535156, -42.25, -26, -15.125, -39, -0.625, -5.84375, -6.625, 10.6875, -21.75, -0.015625, 5.6875, -52.5, 17.875, 54.75, 1.11719, 7.125, -34.25, 1.69531, 8.5, 13.6875, -15.1875, 11.625, -0.546875, -2.57812, -18.125, -16, -0.445312, -7.8125, 30.25, -13, 21.5, 0.714844, -15.9375, 3.96875, 10.5625, 1.46875, 19.5, 10.625, -0.296875, -6.9375, -2.48438, 1.61719, -9.5625, -11.75, -10.6875, -15.625, 35.25, 1.42969, -4.75, 9.9375, -24.375, 56.

# Free tensor

In [9]:
del output

In [10]:
ttl.device.CloseDevice(device)

[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Closing device 0


True