In [1]:
import math
import pandas as pd
import torch
from torch import nn
from d2l import torch as d2l

In [2]:
class FFN(nn.Module):
    def __init__(self, hidden_d, output_d):
        super().__init__()
        self.dense1 = nn.LazyLinear(hidden_d)
        self.act = nn.ReLU()
        self.dense2 = nn.Linear(hidden_d, output_d)

    def forward(self, X):
        return self.dense2(self.act(self.dense1(X)))

In [3]:
ffn = FFN(100, 8)
test_tensor = torch.ones((5, 100))
ffn(test_tensor)

tensor([[ 0.0449,  0.3388, -0.2792,  0.0094, -0.0191,  0.1003,  0.0348, -0.1328],
        [ 0.0449,  0.3388, -0.2792,  0.0094, -0.0191,  0.1003,  0.0348, -0.1328],
        [ 0.0449,  0.3388, -0.2792,  0.0094, -0.0191,  0.1003,  0.0348, -0.1328],
        [ 0.0449,  0.3388, -0.2792,  0.0094, -0.0191,  0.1003,  0.0348, -0.1328],
        [ 0.0449,  0.3388, -0.2792,  0.0094, -0.0191,  0.1003,  0.0348, -0.1328]],
       grad_fn=<AddmmBackward0>)

In [4]:
ln = nn.LayerNorm(3)
bn = nn.LazyBatchNorm1d()
test_tensor = torch.tensor([[0, 2, 3], [2, 3, 4], [2, 3, 4], [5, 6, 7]], dtype=torch.float32)

In [5]:
ln(test_tensor)

tensor([[-1.3363,  0.2673,  1.0690],
        [-1.2247,  0.0000,  1.2247],
        [-1.2247,  0.0000,  1.2247],
        [-1.2247,  0.0000,  1.2247]], grad_fn=<NativeLayerNormBackward0>)

In [6]:
bn(test_tensor)

tensor([[-1.2603, -1.0000, -1.0000],
        [-0.1400, -0.3333, -0.3333],
        [-0.1400, -0.3333, -0.3333],
        [ 1.5403,  1.6667,  1.6667]], grad_fn=<NativeBatchNormBackward0>)

In [7]:
class AddNorm(nn.Module):
    def __init__(self, norm_shape, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(norm_shape)
        self.drop = nn.Dropout(dropout)

    def forward(self, X, Y):
        return self.norm(self.drop(Y) + X)

In [8]:
an = AddNorm(3, 0.1)
test_x = torch.ones((2, 3, 3))
test_y = torch.ones((2, 3, 3))
an(test_x, test_y)

tensor([[[ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [-0.7071, -0.7071,  1.4142]],

        [[ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000]]], grad_fn=<NativeLayerNormBackward0>)

In [9]:
class MultiheadAttention(nn.Module):
    def __init__(self, num_hidden, num_head, dropout, use_bias=False):
        super().__init__()
        self.num_head = num_head
        self.dropout = nn.Dropout(dropout)
        self.W_q = nn.LazyLinear(num_hidden, bias=use_bias)
        self.W_k = nn.LazyLinear(num_hidden, bias=use_bias)
        self.W_v = nn.LazyLinear(num_hidden, bias=use_bias)
        self.W_o = nn.LazyLinear(num_hidden, bias=use_bias)

    def transpose_qkv(self, x):
        x = x.reshape(x.shape[0], x.shape[1], self.num_head, -1)
        x = x.permute(0, 2, 1, 3)
        x = x.reshape(-1, x.shape[2], x.shape[3])
        return x

    def transpose_o(self, x):
        # x shape: (batch * num_head, seq_len, traits_dim / num_head)
        x = x.reshape(-1, self.num_head, x.shape[1], x.shape[2])
        x = x.permute(0, 2, 1, 3)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        return x

    def dot_product_attetion_score(self, q, k, v):
        d = q.shape[-1]

    def sequence_mask(self, valid_len, x, value=0):
        #shape of valid_len: (batch)
        #shape of x: (batch, dim)
        max_len = x.shape[-1]
        mask = torch.repeat_interleave((torch.arange(max_len, dtype=torch.float32)).reshape(1, -1), repeats=valid_len.shape[0], dim=0)
        valid_len = torch.repeat_interleave(valid_len, repeats=max_len).reshape(-1, max_len)
        mask = mask < valid_len
        x[~mask] = value
        return x

    def masked_softmax(self, valid_len, x):
        # shape of x: (batch, query_size, kv_size)
        # it means, in each batch, there are query_size query and
        # each of them map to kv_size key-value pair.
        # shape of valid_len: (batch, query_size) or (batch)
        # it means for each query, we should just take care how
        # many key-value pair. So we could say each element in valid_len
        # denotes the valid length of a batch
        if valid_len is None:
            return nn.functional.softmax(x, dim=-1)
        else:
            shape = x.shape
            if valid_len.dim() == 1:
                valid_len = torch.repeat_interleave(valid_len, shape[1])
            else:
                valid_len = valid_len.reshape(-1)

            x = x.reshape((-1, shape[-1]))
            x = self.sequence_mask(valid_len, x, value=-1e6)
            return nn.functional.softmax(x.reshape(shape), dim=-1)

    def forward(self, Q, K, V, valid_len=None):
        # shape of Q, K and V: (Batch, seq_len, traits_dimension)
        q = self.transpose_qkv(self.W_q(Q))
        k = self.transpose_qkv(self.W_k(K))
        v = self.transpose_qkv(self.W_v(V))
        # shape of q, k and v: (Batch * num_head, seq_len, traits_dimension / num_head)
        attetion_score = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(q.shape[-1])
        if valid_len is not None:
            valid_len = torch.repeat_interleave(valid_len, repeats=self.num_head, dim=0)
        self.attention_weights = self.masked_softmax(valid_len, attetion_score)
        # res shape: (batch * num_head, seq_len, traits_dim / num_head)
        res = torch.bmm(self.dropout(self.attention_weights), v)
        # o shape: (batch, seq_len, traits_dim)
        o = self.transpose_o(res)
        output = self.W_o(o)
        return output

In [10]:
att = MultiheadAttention(8, 2, 0)
# query size is 5, kv size is 10.
# query size is length of query, kv size is the length of kv sequence.
test_mask = torch.ones((2, 5, 10))
test_mask_valid_len = torch.tensor([6, 8])
test_res = att.masked_softmax(test_mask_valid_len, test_mask)
test_res, test_res.shape

(tensor([[[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000,
           0.0000, 0.0000]],
 
         [[0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250,
           0.0000, 0.0000],
          [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250,
           0.0000, 0.0000],
          [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250,
           0.0000, 0.0000],
          [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250,
           0.0000, 0.0000],
          [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1

In [11]:
att = MultiheadAttention(8, 2, 0)
test_att = torch.ones((3, 5, 8))
res = att(test_att, test_att, test_att)
res.shape

torch.Size([3, 5, 8])

In [12]:
class TransformerEncoderBlock(nn.Module):
    def __init__(self, hidden_d, ffn_hidden_d, num_head, dropout, use_bias=False):
        super().__init__()
        self.attention = MultiheadAttention(hidden_d, num_head, dropout, use_bias)
        self.addnorm1 = AddNorm(hidden_d, dropout)
        self.ffn = FFN(ffn_hidden_d, hidden_d)
        self.addnorm2 = AddNorm(hidden_d, dropout)

    def forward(self, X, valid_lens=None):
        # X shape: (batch, seq_len, traits_dim)
        attention_output = self.attention(X, X, X, valid_lens)
        addnorm1 = self.addnorm1(X, attention_output)
        ffn = self.ffn(addnorm1)
        addnorm2 = self.addnorm2(addnorm1, ffn)
        return addnorm2

In [13]:
# the 1st dim for teb should equal to last dim for test_teb,
# because we should it's additive for add norm layer 1.
teb = TransformerEncoderBlock(8, 16, 2, 0)
test_teb = torch.ones((3, 5, 8))
res = teb(test_teb)
res.shape

torch.Size([3, 5, 8])

In [14]:
class Embedding(nn.Module):
    def __init__(self, vocab_size, num_hidden):
        super().__init__()
        self.vocab_size = vocab_size
        self.num_hidden = num_hidden
        self.linear = nn.Linear(vocab_size, num_hidden, bias=False)

    def forward(self, X):
        # shape of X: (batch, seq_len)
        one_hot_x = torch.nn.functional.one_hot(X, num_classes=self.vocab_size).float()
        embedding_res = self.linear(one_hot_x)
        return embedding_res

In [15]:
emd = Embedding(10, 20)
test_emd = torch.tensor([1, 5, 6])
emd_res = emd(test_emd)
emd_res, emd_res.shape

(tensor([[ 0.0136,  0.0778, -0.0492,  0.1628,  0.2445, -0.2048,  0.1555, -0.0560,
           0.1219,  0.0179,  0.2353,  0.2344, -0.2532, -0.0658, -0.0751, -0.2879,
           0.0574,  0.2323,  0.1044,  0.2377],
         [-0.1416,  0.0018, -0.2485,  0.0349, -0.0452, -0.0141,  0.0453,  0.2215,
           0.1966, -0.1201, -0.1656,  0.2511, -0.2178, -0.1673,  0.0027, -0.0467,
           0.1079, -0.0530,  0.1784,  0.2047],
         [-0.0287,  0.0944, -0.2321, -0.0761, -0.2841,  0.2770,  0.0496, -0.0037,
          -0.1656, -0.2322, -0.1400, -0.1111,  0.0765, -0.0110,  0.2477,  0.1008,
          -0.1553,  0.1353, -0.2510, -0.0174]], grad_fn=<MmBackward0>),
 torch.Size([3, 20]))

In [16]:
class PositionEncoding(nn.Module):
    def __init__(self, num_dim, dropout, max_len=1000):
        super().__init__()
        # We use a fixed priori as P for current position encoding now.
        # There are 3 dims:
        # 1st dim for batch, because we could use boardcast mechanism, we set it as 1.
        # 2nd dim is the max len, we think it's t he row index for calculation.
        # 3rd dim is the num_dim, we think it's the column for calculation.
        self.P = torch.empty((1, max_len, num_dim))
        tmp = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / torch.pow(
            1000, torch.arange(0, num_dim, 2, dtype=torch.float32) / num_dim)
        self.P[:, :, 0::2] = torch.sin(tmp)
        self.P[:, :, 1::2] = torch.cos(tmp)
        self.drop = nn.Dropout(dropout)

    def forward(self, X):
        tmp = X + self.P[:, :X.shape[1],:].to(X.device)
        return self.drop(tmp)

In [17]:
pe = PositionEncoding(20, 0, 10)
pe_test = torch.randn((3, 8, 20))
pe_res = pe(pe_test)
pe_res.shape

torch.Size([3, 8, 20])

In [18]:
class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, num_hidden, ffn_num_hidden, num_head, num_blks, dropout, use_bias=False):
        super().__init__()
        self.num_hidden = num_hidden
        self.embedding = Embedding(vocab_size, num_hidden)
        self.pos_encoding = PositionEncoding(num_hidden, dropout)
        self.blks = nn.Sequential()
        for i in range(num_blks):
            self.blks.add_module("block"+str(i), TransformerEncoderBlock(num_hidden, ffn_num_hidden, num_head, dropout, use_bias))

    # shape of X: (batch, seq_len)
    def forward(self, X, valid_lens=None):
        #shape of embedding: (batch, seq_len, num_hidden)
        embedding = self.embedding(X)
        #shape of pos_encoding: (batch, seq_len, num_hidden)
        pos_encoding = self.pos_encoding(embedding * math.sqrt(self.num_hidden))
        output = pos_encoding
        for i, blk in enumerate(self.blks):
            output = blk(output, valid_lens)
        return output

In [19]:
encoder = TransformerEncoder(10, 20, 30, 4, 2, 0)
# shape of encoder_test: (1, 7), batch is 1, seq_len is 7
encoder_test = torch.tensor([1, 2, 3, 4, 5, 6, 7]).reshape(1, -1)
encoder_res = encoder(encoder_test)
encoder_res, encoder_res.shape

(tensor([[[-0.4015,  0.6324, -0.3742, -0.5606,  0.1977,  1.1944, -0.6989,
           -0.8887,  0.7093, -0.7248, -0.0934,  1.7166, -1.6133,  0.4738,
            0.0736,  1.7724, -0.4961,  1.2448, -2.1177, -0.0457],
          [ 0.3790,  0.9859,  0.8214, -0.3690, -0.6692,  1.0659, -1.7662,
           -0.6512,  0.6832, -0.6441, -0.7891,  0.2895, -1.4407,  0.9283,
           -1.3427,  1.2187, -0.2690,  0.4907, -0.9162,  1.9949],
          [-0.6043, -1.5576,  0.7326,  0.0359, -0.1923,  1.0640,  0.6331,
           -0.7142, -1.3970, -0.4047, -1.0121, -0.2890,  0.2993,  0.3074,
            0.0421,  0.0322,  0.3062,  2.8072, -1.3098,  1.2208],
          [ 0.6357, -0.9416,  1.4164, -1.3482,  0.5388, -0.8428, -0.3452,
            0.7658, -0.1868,  0.5980, -0.7559,  0.1237,  0.4097,  1.1404,
            0.0214,  1.7599, -1.5986,  0.3772, -2.1826,  0.4148],
          [-1.3200, -0.8008,  0.3148, -1.9074, -0.3466,  1.5802,  0.7163,
           -1.1239,  0.4652, -0.7073,  1.2105, -0.3560, -1.5596,  0.20

In [31]:
class TransformerDecoderBlock(nn.Module):
    def __init__(self, num_hiddens, ffn_num_hiddens, num_head, dropout, i, use_bias=False):
        super().__init__()
        # i is the idx for this block
        self.i = i
        self.self_attention = MultiheadAttention(num_hiddens, num_head, dropout, use_bias)
        self.addnorm1 = AddNorm(num_hiddens, dropout)
        self.cross_attention = MultiheadAttention(num_hiddens, num_head, dropout, use_bias)
        self.addnorm2 = AddNorm(num_hiddens, dropout)
        self.ffn = FFN(ffn_num_hiddens, num_hiddens)
        self.addnorm3 = AddNorm(num_hiddens, dropout)

    def forward(self, X, state):
        # shape of X: (batch_size, seq_len, num_hiddens)
        # X came from decoder self.
        # For example, at first step for decoder, it's bos (begin of sentence)
        # state came from encoder, it's a list:
        # [output from encoder, valid_len, [key and value for i-th decoder block]]
        enc_outputs, enc_valid_lens = state[0], state[1]
        
        self_key_value = X

        # Given the argument of masked_softmax is attention score, its shape is
        # (batch, query_size, kv_size), so the meanning of dec_len means how many
        # kv the current query will pay attention to.
        # Hence for the 1st query the dec_len should be 1, the 2nd query should be
        # 2... For the last query, the dec_len should be qurry_total_len
        batch_size, query_total_len = X.shape[0], X.shape[1]
        dec_valid_len = torch.arange(1, query_total_len + 1, device=X.device).repeat(batch_size, 1)

        self_att = self.self_attention(X, self_key_value, self_key_value, dec_valid_len)
        addnor1 = self.addnorm1(X, self_att)
        cross_att = self.cross_attention(addnor1, enc_outputs, enc_outputs, enc_valid_lens)
        addnor2 = self.addnorm1(addnor1, cross_att)
        ffn = self.ffn(addnor2)
        addnorm3 = self.addnorm3(addnor2, ffn)

        return addnorm3, state

In [32]:
dec_blk = TransformerDecoderBlock(10, 20, 2, 0, 0)
test_dec_blk_x = torch.ones((2, 5, 10))
test_dec_blk_state = [torch.randn((2, 7, 10)), torch.tensor([4, 6])]
test_res = dec_blk(test_dec_blk_x, test_dec_blk_state)
test_res[0].shape

torch.Size([2, 5, 10])

In [33]:
class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_head, num_blks, dropout, use_bias=False):
        super().__init__()
        self.num_hiddens = num_hiddens
        self.num_blks = num_blks

        self.embedding = Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(self.num_blks):
            self.blks.add_module("block"+str(i), TransformerDecoderBlock(num_hiddens, ffn_num_hiddens, num_head, dropout, i))
        self.linear = nn.LazyLinear(vocab_size)

    def forward(self, X, state):
        # X shape is (batch, seq_len)
        emb = self.embedding(X)
        pos_encoding = self.pos_encoding(emb * math.sqrt(self.num_hiddens))
        blk_input = pos_encoding
        for i, blk in enumerate(self.blks):
            blk_input, state = blk(blk_input, state)
        return self.linear(blk_input), state
        

In [37]:
td = TransformerDecoder(20, 10, 20, 2, 2, 0)
test_td = torch.tensor([[1, 2]])
test_td_state = [torch.randn(2, 5, 10), torch.tensor([2, 3])]
test_res = td(test_td, test_td_state)
test_res[0].shape

RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [2, 5] but got: [4, 5].

In [22]:
batch_size = 2
num_steps = 5
dec_valid_lens = torch.arange(
                1, num_steps + 1).repeat(batch_size, 1)
dec_valid_lens

tensor([[1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5]])

In [23]:
data = d2l.MTFraEng(batch_size=128)
print(data)

<d2l.torch.MTFraEng object at 0x716eed3fd0c0>


In [24]:
for a in data.train_dataloader():
    print(a)
    for tt in a:
        print(tt.shape)
    break

[tensor([[158,  91,   2,  ...,   4,   4,   4],
        [ 28, 122,   2,  ...,   4,   4,   4],
        [183,  98,   2,  ...,   4,   4,   4],
        ...,
        [ 11, 163,   2,  ...,   4,   4,   4],
        [ 39, 122,   2,  ...,   4,   4,   4],
        [159,  91,   2,  ...,   4,   4,   4]]), tensor([[  3,   6,   0,  ...,   5,   5,   5],
        [  3,  15,   0,  ...,   5,   5,   5],
        [  3, 135,   6,  ...,   5,   5,   5],
        ...,
        [  3,   6,   2,  ...,   5,   5,   5],
        [  3,   6,   0,  ...,   5,   5,   5],
        [  3,   6,   2,  ...,   5,   5,   5]]), tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 4, 5, 4, 3, 4, 4, 4, 4,
        3, 4, 4, 4, 4, 3, 4, 4, 4, 4, 4, 5, 4, 4, 4, 4, 4, 5, 4, 4, 4, 4, 4, 4,
        4, 5, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 4, 4, 4,
        3, 4, 4, 4, 3, 4, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 4, 5, 4, 4, 5, 4, 4, 5,
        4, 4, 4, 