In [2]:
%load_ext autoreload
%autoreload 2

In [15]:
import torch
from transformer import Transformer, MultiHeadAttention

In [17]:
# Seed once, in the first cell that draws random numbers, so the random inputs, the layer
# weights and every later cell are reproducible under "Restart Kernel -> Run All".
SEED = 0
torch.manual_seed(SEED)

# Named `x` rather than `input` to avoid shadowing Python's builtin `input()`.
x = torch.randn(2, 3, 8)  # (batch, seq_len, emb_dim)

# Smoke test: pass the same tensor for all three inputs (presumably query, key, value,
# i.e. self-attention; confirm against the MultiHeadAttention signature).
mha = MultiHeadAttention(d_model=8, num_heads=2, qkv_proj=True)
mha(x, x, x)

tensor([[[-9.2432e-02, -7.4963e-02, -3.9764e-02,  1.9147e-01,  2.6207e-02,
          -4.1348e-02,  3.7420e-02,  7.5355e-02],
         [-1.9543e-01, -2.4322e-01, -1.9727e-01,  1.7659e-01,  7.3606e-02,
           7.7915e-02, -1.4432e-02,  9.0979e-02],
         [-1.5270e-01, -2.5979e-01, -1.6798e-01,  1.9020e-01,  5.8147e-02,
          -3.2706e-02,  2.1564e-03,  9.4873e-02]],

        [[-6.8309e-02,  2.4378e-01,  1.3729e-01, -2.4940e-02,  4.1162e-02,
           9.3957e-02,  8.1113e-02,  8.8998e-02],
         [ 4.6381e-02,  2.9870e-01,  1.7701e-01,  4.3833e-02,  7.3284e-05,
           6.6136e-02,  1.9146e-01,  9.2549e-02],
         [-3.2271e-02,  2.6685e-01,  1.3234e-01, -1.8044e-02,  5.9349e-02,
           5.3862e-02,  1.3626e-01,  1.0933e-01]]],
       grad_fn=<UnsafeViewBackward0>)

In [33]:
# Toy source batch: 50 sequences of 5 token ids drawn uniformly from [1, 99]
# (`high` is exclusive). Id 0 is never produced; presumably it is reserved
# (e.g. for padding). Confirm against the Transformer implementation.
src_input = torch.randint(low=1, high=100, size=(50, 5))  # (batch, seq_len)
print(src_input.shape)
src_input[:10]

torch.Size([50, 5])


tensor([[46, 85, 84, 32, 65],
        [14, 98,  1, 24, 19],
        [15, 85, 25, 77, 70],
        [55, 86, 50, 43, 22],
        [86, 37, 18, 52,  9],
        [84, 82, 79, 70, 56],
        [53, 88, 65, 53, 90],
        [51, 47, 33, 21, 95],
        [30, 87, 20, 48, 72],
        [89,  8, 34, 14, 75]])

In [34]:
# Toy target batch: same batch size as the source, but shorter sequences (3 tokens),
# with ids drawn uniformly from [1, 99] (`high` is exclusive).
tgt_input = torch.randint(low=1, high=100, size=(50, 3))  # (batch, seq_len)
print(tgt_input.shape)
tgt_input[:10]

torch.Size([50, 3])


tensor([[32,  8, 25],
        [ 5, 72, 69],
        [10, 39, 57],
        [65, 71, 24],
        [67, 47, 47],
        [36, 21, 17],
        [88, 99, 29],
        [47, 39, 17],
        [37, 96, 31],
        [54, 10, 87]])

In [35]:
# Hyperparameters for a small encoder-decoder Transformer sized to the toy data above.
TRANSFORMER_CONFIG = dict(
    vocab_size=100,   # number of distinct token ids (inputs are drawn from [1, 99])
    d_model=32,       # embedding / hidden width
    num_heads=4,      # attention heads per multi-head attention layer
    block_size=5,     # maximum sequence length
    num_encoders=6,   # stacked encoder blocks
    num_decoders=6,   # stacked decoder blocks
    bias=False,       # build the Linear layers without bias terms
    dropout=0.1,      # dropout probability
)

model = Transformer(**TRANSFORMER_CONFIG)

In [36]:
# Display the module tree (embedding, encoder/decoder stacks, per-layer sizes) to check
# that the hyperparameters above were wired through as intended.
model

Transformer(
  (input_embedding): InputEmbedding(
    (tok_embed): Embedding(100, 32)
    (pos_embed): PositionalEncoding()
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder_blocks): ModuleList(
    (0-5): 6 x EncoderBlock(
      (qkv_proj): Linear(in_features=32, out_features=96, bias=False)
      (attention): MultiHeadAttention(
        (dropout): Dropout(p=0.1, inplace=False)
        (output_fc): Linear(in_features=32, out_features=32, bias=False)
      )
      (dropout1): Dropout(p=0.1, inplace=False)
      (layer_norm1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=32, out_features=128, bias=False)
        (fc2): Linear(in_features=128, out_features=32, bias=False)
      )
      (layer_norm2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
    )
  )
  (decoder_blocks): ModuleList(
    (0-5): 6 x DecoderBlock(
      (qkv_proj): Linear(in_features=32, out_features=96, bias=False)
      (attention1): MultiHea

In [37]:
# Full encoder-decoder forward pass on the toy batch.
# NOTE(review): no `model.eval()` call is visible, so the model is presumably still in
# training mode and dropout (p=0.1) makes these outputs vary between runs.
output = model(src_input, tgt_input)

# The output has one vector of vocabulary scores per *target* position:
# (batch, tgt_seq_len, vocab_size). Assert it instead of just eyeballing the shape.
# Assumes the output vocabulary matches the token-embedding vocabulary.
vocab_size = model.input_embedding.tok_embed.num_embeddings
expected_shape = (src_input.size(0), tgt_input.size(1), vocab_size)
assert output.shape == expected_shape, f"unexpected output shape {tuple(output.shape)}"
output.shape

torch.Size([50, 3, 100])

In [38]:
# First sequence of the batch, every target position, first 10 of the vocabulary columns
# (unnormalised scores; see the softmax cell for probabilities).
output[:1, ..., :10]

tensor([[[ 0.0723,  0.1462,  0.7454,  0.2798,  0.4531,  0.2072, -0.1207,
           0.1987,  1.3193,  0.0299],
         [ 0.2442,  0.1691,  0.4099,  0.7136, -0.4226, -0.1607, -0.6516,
           0.7076, -0.1966, -0.3897],
         [ 1.0458,  0.0208, -0.5490, -0.1117,  0.3067,  1.0483, -0.5634,
           0.7478,  0.2823,  0.5300]]], grad_fn=<SliceBackward0>)

In [39]:
# Softmax over the last (vocabulary) dimension turns scores into per-position
# distributions. Only the first 10 entries of each distribution are shown, so the
# displayed rows do not sum to 1.
output.softmax(dim=-1)[:1, :, :10]

tensor([[[0.0080, 0.0086, 0.0157, 0.0098, 0.0117, 0.0091, 0.0066, 0.0091,
          0.0278, 0.0077],
         [0.0098, 0.0091, 0.0116, 0.0157, 0.0050, 0.0065, 0.0040, 0.0156,
          0.0063, 0.0052],
         [0.0210, 0.0075, 0.0043, 0.0066, 0.0100, 0.0211, 0.0042, 0.0156,
          0.0098, 0.0125]]], grad_fn=<SliceBackward0>)