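# Multi-Head Attention (MHA) Performance

Times 1000 forward passes of `MultiHeadAttentionSlow` and `MultiHeadAttention` on identically shaped random inputs.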

In [1]:
import torch
from tqdm import tqdm
from transformer.attention import MultiHeadAttentionSlow, MultiHeadAttention

In [2]:
# Build both attention variants from the same constructor arguments so the
# timing cells below compare like with like. The slow variant is constructed
# first, exactly as before.
mha_args = (4096, 64)
mha_slow = MultiHeadAttentionSlow(*mha_args)
mha_fast = MultiHeadAttention(*mha_args)

In [3]:
%%time
for _ in tqdm(range(1000)):
    src = torch.rand((16, 4, 4096))
    tgt = torch.rand((16, 8, 4096))

    mha_slow(query=tgt, key=src, value=src)

100%|██████████| 1000/1000 [01:24<00:00, 11.78it/s]

CPU times: total: 6min 54s
Wall time: 1min 24s





In [4]:
%%time
for _ in tqdm(range(1000)):
    src = torch.rand((16, 4, 4096))
    tgt = torch.rand((16, 8, 4096))

    mha_fast(query=tgt, key=src, value=src)

100%|██████████| 1000/1000 [01:00<00:00, 16.41it/s]

CPU times: total: 4min 34s
Wall time: 1min





# Usage

In [5]:
import torch
import torch.onnx
from transformer.transformer import Transformer

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Small demo configuration of the Transformer, used by the cells below.
model = Transformer(
    d_model=32,
    nhead=2,
    num_encoder_layers=2,
    num_decoder_layers=2,
    d_ffn=64,
    dropout=0.1,
)

# Count every scalar element across all tensors returned by model.parameters().
total_params = sum(weight.numel() for weight in model.parameters())
print(f"Total Parameters: {total_params}")

Total Parameters: 42752


In [7]:
# Smoke test: push random inputs through the model and inspect the output shape.
src = torch.rand(4, 4, 32)
tgt = torch.rand(4, 8, 32)

# Size the subsequent mask from tgt's second dimension (8 here).
tgt_mask = model.generate_square_subsequent_mask(tgt.size(1))

outputs = model(src, tgt, tgt_mask=tgt_mask)
outputs.shape

torch.Size([4, 8, 32])

In [8]:
# Export the model (weights included) to ONNX.
#
# The forward pass takes the target mask as a keyword argument, so it is
# supplied through a trailing dict in `args`. Previously only (src, tgt) were
# passed, which meant the graph was traced without the mask even though
# `input_names` declared a third `tgt_mask` input.
torch.onnx.export(
    model,
    (src, tgt, {"tgt_mask": tgt_mask}),
    "assets/transformer.onnx",
    export_params=True,
    opset_version=12,
    do_constant_folding=True,
    input_names=['src', 'tgt', 'tgt_mask'],
    output_names=['outputs'],
    # Only axis 0 of the data tensors is marked dynamic. `tgt_mask` is left
    # static: it is presumably a square (tgt_len, tgt_len) mask with no batch
    # axis, so tagging its axis 0 as 'batch_size' was incorrect.
    # TODO confirm against generate_square_subsequent_mask.
    dynamic_axes={
        'src': {0: 'batch_size'},
        'tgt': {0: 'batch_size'},
        'outputs': {0: 'batch_size'}
    }
)

  d_key = torch.tensor(key.size(3))
  d_key = torch.tensor(key.size(3))
