In [100]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

# Reproducibility: seed torch's global RNG before any random draw below.
torch.manual_seed(0)

# Word-embedding walkthrough using sequence modeling as the example.
# There is a source sentence and a target sentence per sample; every token in a
# sequence is represented by its index in the corresponding vocabulary.
batch_size = 2

# Vocabulary sizes (source / target) and the embedding width.
max_num_src_words, max_num_tgt_words = 8, 8
model_dim = 8

# Upper bounds on source / target sentence length, and the number of rows in
# the position-embedding table.
max_src_len, max_tgt_len = 5, 5
max_postion_len = 6

# Per-sample lengths are fixed here; the random variant is kept for reference.
# src_len = torch.randint(2, 5, (batch_size,))
# tgt_len = torch.randint(2, 5, (batch_size,))
src_len = torch.tensor([2, 4], dtype=torch.int32)
tgt_len = torch.tensor([4, 3], dtype=torch.int32)

# step 1: create sequence
# Every sample is a run of random token ids drawn from [1, vocabulary size),
# right-padded with 0 up to the longest sample of its batch. The padded rows
# are stacked into a single (batch_size, longest_length) index tensor.

# Source side (samples are generated in batch order, source before target).
src_pad_to = int(src_len.max())
src_seq = torch.stack([
    F.pad(torch.randint(1, max_num_src_words, (int(n),)), (0, src_pad_to - int(n)))
    for n in src_len
])

# Target side, built the same way from the target lengths and vocabulary.
tgt_pad_to = int(tgt_len.max())
tgt_seq = torch.stack([
    F.pad(torch.randint(1, max_num_tgt_words, (int(n),)), (0, tgt_pad_to - int(n)))
    for n in tgt_len
])

# step 2: create word embedding
# One lookup table per vocabulary, each with one row more than the vocabulary
# size and model_dim columns. The source table is constructed first.
src_embedding_table = nn.Embedding(num_embeddings=max_num_src_words + 1, embedding_dim=model_dim)
tgt_embedding_table = nn.Embedding(num_embeddings=max_num_tgt_words + 1, embedding_dim=model_dim)

# Look up every token index: the result gains a trailing model_dim axis.
src_embedding, tgt_embedding = src_embedding_table(src_seq), tgt_embedding_table(tgt_seq)

# step 3: create position embedding (fixed sinusoidal table)
#   PE[pos, 2i]   = sin(pos / 10000^(2i / model_dim))
#   PE[pos, 2i+1] = cos(pos / 10000^(2i / model_dim))
pos_mat = torch.arange(max_postion_len).unsqueeze(1)  # column of positions: (max_postion_len, 1)
# Even feature indices 0, 2, 4, ... are derived from model_dim instead of being
# hard-coded, so the table keeps matching the embedding width when model_dim
# changes. NOTE: the sin/cos interleaving below assumes model_dim is even.
i_mat = torch.pow(10000, torch.arange(0, model_dim, 2).unsqueeze(0) / model_dim)
pe_embedding_table = torch.zeros(max_postion_len, model_dim)

# Broadcasting (positions x frequencies) yields one angle per (pos, i) pair;
# sines fill the even columns and cosines the odd columns.
angles = pos_mat / i_mat
pe_embedding_table[:, 0::2] = torch.sin(angles)
pe_embedding_table[:, 1::2] = torch.cos(angles)

# Wrap the table in an nn.Embedding so positions can be looked up by index.
# The module is still constructed the ordinary way before its weight is swapped
# for the fixed table, which is excluded from gradient updates.
pe_embedding = nn.Embedding(max_postion_len, model_dim)
pe_embedding.weight = nn.Parameter(pe_embedding_table, requires_grad=False)

# Position ids 0 .. longest_length-1, repeated once per sample in the batch.
# The longest length must not exceed max_postion_len, or the lookup fails.
src_pos = torch.arange(int(src_len.max())).repeat(len(src_len), 1)
tgt_pos = torch.arange(int(tgt_len.max())).repeat(len(tgt_len), 1)

src_pe_embedding = pe_embedding(src_pos)
tgt_pe_embedding = pe_embedding(tgt_pos)

# alpha1 = 0.1
# alpha2 = 10
# score = torch.randn(5)
# prob1 = F.softmax(score*alpha1, dim=-1)
# prob2 = F.softmax(score*alpha2, dim=-1)

# def softmax_func(score):
#     return F.softmax(score, dim=-1)

# jaco_mat1 = torch.autograd.functional.jacobian(softmax_func, score*alpha1)
# jaco_mat2 = torch.autograd.functional.jacobian(softmax_func, score*alpha2)

# step 4: create encoder's self-attention mask
# mask shape: (batch_size, max(src_len), max(src_len)); dtype bool, where True
# marks a (query, key) pair in which at least one of the two positions is padding.

# Validity column vectors: 1.0 for real tokens, 0.0 for padding, with a trailing
# axis of size 1 so they can be multiplied as (length, 1) matrices.
src_steps = torch.arange(int(src_len.max()))
valid_encoder_pos = (src_steps < src_len.unsqueeze(1)).to(torch.get_default_dtype()).unsqueeze(2)

tgt_steps = torch.arange(int(tgt_len.max()))
valid_decoder_pos = (tgt_steps < tgt_len.unsqueeze(1)).to(torch.get_default_dtype()).unsqueeze(2)

# Outer product of the validity vector with itself: entry (i, j) is 1.0 only
# when both position i and position j hold real tokens.
valid_encoder_pos_matrix = torch.matmul(valid_encoder_pos, valid_encoder_pos.transpose(1, 2))
invalid_encoder_pos_matrix = 1 - valid_encoder_pos_matrix
mask_encoder_self_attention = invalid_encoder_pos_matrix.bool()

# score = torch.randn(batch_size, max(src_len), max(src_len))
# masked_score = score.masked_fill(mask_encoder_self_attention, -1e9)
# prob = F.softmax(masked_score, dim=-1)

# step 5: create the cross-attention mask (decoder queries over encoder keys)
# Q @ K^T shape: [batch_size, tgt_seq_len, src_seq_len]
# Entry (t, s) of the product is 1.0 only when target position t and source
# position s are both real tokens; the boolean mask is True everywhere else.
valid_cross_pos = valid_decoder_pos @ valid_encoder_pos.transpose(-1, -2)
invalid_cross_pos = 1 - valid_cross_pos
mask_cross_attention = invalid_cross_pos.bool()

# step 6: decoder self-attention mask
# For each sample, the top-left (length x length) corner holds a lower-triangular
# block of ones (a position may look at itself and at earlier positions); the
# remainder, which corresponds to padding, stays zero.
longest_tgt = int(tgt_len.max())
valid_decoder_tri_matrix = torch.zeros(len(tgt_len), longest_tgt, longest_tgt)
for sample_idx, L in enumerate(tgt_len):
    n = int(L)
    tri_matrix = torch.ones(n, n).tril()
    valid_decoder_tri_matrix[sample_idx, :n, :n] = tri_matrix

# Flip validity into a boolean mask: True marks entries to be masked out.
invalid_decoder_tri_matrix = 1 - valid_decoder_tri_matrix
mask_decoder_self_attention = invalid_decoder_tri_matrix.bool()

# score = torch.randn(batch_size, max(tgt_len), max(tgt_len))
# masked_score = score.masked_fill(mask_decoder_self_attention, -1e9)
# prob = F.softmax(masked_score, dim=-1)

# step 7: scaled dot-product attention
def scaled_dot_product_attention(Q, K, V, atten_mask):
    """Compute softmax(Q @ K^T / sqrt(d_k)) @ V with masked positions suppressed.

    Args:
        Q: queries, shape [batch_size*num_head, query_len, model_dim/num_head].
        K: keys, shape [batch_size*num_head, key_len, model_dim/num_head].
        V: values, shape [batch_size*num_head, key_len, value_dim].
        atten_mask: boolean tensor broadcastable to
            [batch_size*num_head, query_len, key_len]; True marks a score that
            must not receive attention.

    Returns:
        Context tensor of shape [batch_size*num_head, query_len, value_dim].
    """
    # Shape trace of the operands (this is what the recorded cell output shows).
    print(Q.shape, K.shape, V.shape, atten_mask.shape)
    # Scale by the width of the vectors actually being dotted (the per-head
    # key dimension), not by the full model width: with several heads the two
    # differ, and the per-head width is what keeps the score variance near 1.
    d_k = Q.size(-1)
    score = torch.bmm(Q, K.transpose(-2, -1)) / d_k ** 0.5
    # A large negative fill (rather than -inf) drives masked weights to ~0 while
    # keeping rows that are masked entirely free of NaNs after the softmax.
    masked_score = score.masked_fill(atten_mask, -1e9)
    prob = F.softmax(masked_score, dim=-1)
    context = torch.bmm(prob, V)
    return context

num_head = 2
# evaluate scaled_dot_product_attention on random per-head tensors:
# queries span the target positions, keys and values span the source positions.
Q = torch.randn(batch_size*num_head, max(tgt_len), model_dim//num_head)
K = torch.randn(batch_size*num_head, max(src_len), model_dim//num_head)
V = torch.randn(batch_size*num_head, max(src_len), model_dim//num_head)
# The scores Q @ K^T are (batch_size*num_head, tgt_len, src_len), so the mask
# has to be the cross-attention one. The encoder self-attention mask is
# (batch_size, src_len, src_len) and only fits when the longest source and
# target sequences happen to have the same length (both are 4 here).
# NOTE(review): .repeat(num_head, 1, 1) tiles the whole batch num_head times
# (sample 0, sample 1, sample 0, sample 1, ...). If real projections are
# flattened sample-major (all heads of sample 0 first), repeat_interleave is
# the matching operation -- confirm once Q/K/V stop being random.
atten_mask = mask_cross_attention.repeat(num_head, 1, 1)

context = scaled_dot_product_attention(Q, K, V, atten_mask)

# Uncomment any of the entries below to inspect that tensor in the cell output.
display(
    # src_embedding,
    # tgt_embedding,
    # atten_mask,
    # context
    # mask_decoder_self_attention,
    # masked_score,
    # prob
)

torch.Size([2, 4])
torch.Size([2, 4, 1])
torch.Size([4, 4, 4]) torch.Size([4, 4, 4]) torch.Size([4, 4, 4]) torch.Size([4, 4, 4])


In [105]:
# step 8: mask loss

# Random logits arranged as (batch, num_classes, seq_len) = (2, 4, 3), i.e. with
# the class axis second, and random class labels of shape (batch, seq_len).
logits = torch.randn(2, 3, 4).transpose(1, 2)
label = torch.randint(0, 4, (2, 3))

# NOTE: this rebinds `tgt_len`, which the previous cell also defines.
tgt_len = torch.tensor([2, 3], dtype=torch.int32)

# 1.0 at real target positions and 0.0 at padded ones; shape (batch, max(tgt_len)).
tgt_positions = torch.arange(int(tgt_len.max()))
tgt_mask = (tgt_positions < tgt_len.unsqueeze(1)).to(torch.get_default_dtype())

# Give every padded position the label -100, which is F.cross_entropy's default
# ignore_index, so those positions contribute zero loss. The positions come from
# tgt_mask rather than a hand-written index, so they follow tgt_len.
# Requires max(tgt_len) to equal the label sequence length (3 here).
label = label.masked_fill(tgt_mask == 0, -100)

# Show, in order: the padding mask, the logits, the labels (padded positions
# carry -100), and the per-position loss. reduction='none' keeps one loss value
# per (sample, position) instead of averaging them.
display(tgt_mask)
display(logits)
display(label)
display(F.cross_entropy(input=logits, target=label, reduction='none'))

tensor([[1., 1., 0.],
        [1., 1., 1.]])

tensor([[[ 0.7980,  0.7009,  0.2552],
         [-1.1071,  2.1099,  0.1299],
         [ 2.3306, -1.1239,  1.3476],
         [-1.0456,  0.7607,  0.5408]],

        [[-0.9478, -0.0352,  0.8931],
         [ 0.2021,  0.4964, -1.4541],
         [-0.3507, -0.8405,  1.1875],
         [ 0.5450,  0.4047, -0.2995]]])

tensor([[   1,    2, -100],
        [   2,    0,    3]])

tensor([[3.6864, 3.6677, 0.0000],
        [1.7471, 1.5478, 2.2011]])