In [59]:
import math
import torch
import torch.nn as nn
from torch import Tensor

In [60]:
# Fix the RNG so the random inputs and the randomly initialised weights in this
# notebook are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Model hyper-parameters: number of attention heads and embedding width (d_model).
num_heads = 8
size = embed_dim = 256

In [61]:
# Per-head width: the embedding is split evenly across the attention heads.
# Floor division would silently truncate if it did not divide evenly, so check.
assert size % num_heads == 0, "embed_dim must be divisible by num_heads"
head_size = size // num_heads
head_size

32

In [62]:
# Width of the position-wise feed-forward layer, and the dropout probability
# (0.0 disables dropout).
dropout = 0.0
dim_feedforward = 512

In [63]:
# Small encoder-decoder Transformer: 2 encoder layers, 1 decoder layer.
transformer_model = nn.Transformer(
    d_model=embed_dim,
    nhead=num_heads,
    num_encoder_layers=2,
    num_decoder_layers=1,
    dim_feedforward=dim_feedforward,
    dropout=dropout,
)

In [64]:
# Show the module tree (sub-layers and their sizes).
transformer_model

Transformer(
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): _LinearWithBias(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=512, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (linear2): Linear(in_features=512, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.0, inplace=False)
        (dropout2): Dropout(p=0.0, inplace=False)
      )
      (1): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): _LinearWithBias(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=512, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (linear2): Lin

# 直接构造源句子、目标句子Embedding
### S: source sequence length
### T: target sequence length
### N: batch size; E: feature (embedding) dimension

In [83]:
# Random stand-ins for already-embedded sentences, laid out as (S, N, E) for
# the source (S=10, N=32) and (T, N, E) for the target (T=20, N=32).
src = torch.rand(10, 32, embed_dim)
tgt = torch.rand(20, 32, embed_dim)

In [84]:
# Forward pass with no masks; the decoder output has the target's layout.
out = transformer_model(src=src, tgt=tgt)
out.shape  # (T, N, E)

torch.Size([20, 32, 256])

# 源句子、目标句子的构造
### src_idx: (N,S)
### tgt_idx: (N,T)

In [90]:
# Toy token-id batch.
N = 2   # batch size
S = 5   # source sequence length
T = 10  # target sequence length
# Ids are drawn uniformly from [0, 11), i.e. 0..10.
src_idx = torch.randint(low=0, high=11, size=(N, S))  # (N, S)
tgt_idx = torch.randint(low=0, high=11, size=(N, T))  # (N, T)

In [81]:
# Source token ids, batch-first (N, S).
src_idx

tensor([[ 5,  4,  7,  1, 10],
        [ 3,  0,  2,  1,  8]])

In [82]:
# Target token ids, batch-first (N, T).
tgt_idx

tensor([[ 9,  5,  1,  7,  9,  6,  8, 10,  6, 10],
        [ 5,  5,  6,  6,  0, 10,  5,  2,  7,  7]])

In [85]:
# Switch from batch-first (N, S)/(N, T) to sequence-first (S, N)/(T, N).
src_idx_reshaped = src_idx.transpose(0, 1)
tgt_idx_reshaped = tgt_idx.transpose(0, 1)

In [86]:
# Sequence-first source ids (S, N).
src_idx_reshaped

tensor([[ 5,  3],
        [ 4,  0],
        [ 7,  2],
        [ 1,  1],
        [10,  8]])

In [87]:
# Sequence-first target ids (T, N).
tgt_idx_reshaped

tensor([[ 9,  5],
        [ 5,  5],
        [ 1,  6],
        [ 7,  6],
        [ 9,  0],
        [ 6, 10],
        [ 8,  5],
        [10,  2],
        [ 6,  7],
        [10,  7]])

# Attention相关的输入参数构造
### src_mask: (S, S)
### tgt_mask: (T, T)
### memory_mask: (T, S)

In [95]:
# No mask on encoder self-attention or on decoder-encoder attention for now.
src_mask = None
memory_mask = None
# Causal (look-ahead) mask for the decoder: row i allows (0) columns <= i and
# blocks (-inf) every later column.
tgt_mask = transformer_model.generate_square_subsequent_mask(sz=T)
tgt_mask

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [96]:
# Sanity check: the causal mask is square with side T.
assert tgt_mask.shape == (T, T)
T, tgt_mask.shape

(10, torch.Size([10, 10]))

In [97]:
# Random embedded inputs matching the toy batch. The batch dimension is N
# (defined with the toy batch); no variable `B` exists in this notebook.
src = torch.rand((S, N, embed_dim))  # (S, N, E)
tgt = torch.rand((T, N, embed_dim))  # (T, N, E)

In [98]:
# Forward pass with the attention masks built above.
out = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask, memory_mask=memory_mask)
assert out.shape == (T, N, embed_dim)
out.shape  # (T, N, E)

torch.Size([10, 2, 256])

# Padding_mask相关的输入参数构造
### src_key_padding_mask: (N, S)
### tgt_key_padding_mask: (N, T)
### memory_key_padding_mask: (N, S)

In [111]:
# Token id that marks padding positions.
pad_idx = 1

In [114]:
# Fresh source ids (N, S), used below to build the source padding mask.
src_idx = torch.randint(low=0, high=11, size=(N, S))
src_idx

tensor([[ 1,  9,  2, 10,  2],
        [ 3,  7,  6,  1,  0]])

In [115]:
# True where the source token is padding. With a BoolTensor, True positions are
# ignored by attention and False positions are left unchanged.
src_key_padding_mask = src_idx.eq(pad_idx)
src_key_padding_mask

tensor([[ True, False, False, False, False],
        [False, False, False,  True, False]])

In [119]:
# Fresh target ids (N, T), used below to build the target padding mask.
tgt_idx = torch.randint(low=0, high=11, size=(N, T))
tgt_idx

tensor([[ 5, 10,  8, 10,  6,  7,  9,  8, 10,  4],
        [ 8,  0,  3,  5,  4,  1,  9,  2,  4,  7]])

In [120]:
# True where the target token is padding. With a BoolTensor, True positions are
# ignored by attention and False positions are left unchanged.
tgt_key_padding_mask = tgt_idx.eq(pad_idx)
tgt_key_padding_mask

tensor([[False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False,  True, False, False, False, False]])

In [121]:
# Random embedded inputs for the padding demo. The batch dimension is N
# (defined with the toy batch); no variable `B` exists in this notebook.
src = torch.rand((S, N, embed_dim))  # (S, N, E)
tgt = torch.rand((T, N, embed_dim))  # (T, N, E)

In [122]:
# Forward pass with attention masks and key-padding masks.
# memory_key_padding_mask marks padded *source* positions for the decoder's
# cross-attention over the encoder output (memory). It is therefore the same
# (N, S) mask as src_key_padding_mask. Passing None would let the decoder
# attend to encoder outputs at padding positions.
out = transformer_model(
    src,
    tgt,
    src_mask=src_mask,
    tgt_mask=tgt_mask,
    memory_mask=memory_mask,
    src_key_padding_mask=src_key_padding_mask,
    tgt_key_padding_mask=tgt_key_padding_mask,
    memory_key_padding_mask=src_key_padding_mask,
)
assert out.shape == (T, N, embed_dim)
out.shape  # (T, N, E)

torch.Size([10, 2, 256])

## 构造最后的Linear层

In [123]:
# Token ids range over 0..10, so the output vocabulary has 11 entries.
vocab_size = 11

In [124]:
# Projection from model width to one score per vocabulary entry.
lin = nn.Linear(in_features=embed_dim, out_features=vocab_size)

In [125]:
# Apply the projection at every (T, N) position of the decoder output.
ret=lin(out)

In [126]:
# Expected layout: (T, N, vocab_size).
ret.shape

torch.Size([10, 2, 11])

In [128]:
# Batch-first view for reading results per sentence: (N, T, vocab_size).
ret_reshaped = ret.transpose(0, 1)
ret_reshaped.shape

torch.Size([2, 10, 11])

In [131]:
# Greedy prediction: highest-scoring token id at each position -> (N, T).
predict = torch.argmax(ret_reshaped, dim=-1)
predict.shape

torch.Size([2, 10])

In [132]:
# Predicted token ids per sentence. No training is done in this notebook,
# so these predictions carry no meaning yet.
predict

tensor([[1, 3, 3, 1, 1, 1, 1, 1, 2, 1],
        [1, 5, 1, 1, 1, 1, 1, 1, 3, 1]])