In [37]:
import math
import pandas as pd
import torch
from torch import nn

In [38]:
class FFN(nn.Module):
    def __init__(self, hidden_d, output_d):
        super().__init__()
        self.dense1 = nn.LazyLinear(hidden_d)
        self.act = nn.ReLU()
        self.dense2 = nn.Linear(hidden_d, output_d)

    def forward(self, X):
        return self.dense2(self.act(self.dense1(X)))

In [39]:
ffn = FFN(100, 8)
test_tensor = torch.ones((5, 100))
ffn(test_tensor)

tensor([[-0.0923, -0.1244, -0.0296,  0.2299, -0.0257, -0.3215,  0.0891, -0.0597],
        [-0.0923, -0.1244, -0.0296,  0.2299, -0.0257, -0.3215,  0.0891, -0.0597],
        [-0.0923, -0.1244, -0.0296,  0.2299, -0.0257, -0.3215,  0.0891, -0.0597],
        [-0.0923, -0.1244, -0.0296,  0.2299, -0.0257, -0.3215,  0.0891, -0.0597],
        [-0.0923, -0.1244, -0.0296,  0.2299, -0.0257, -0.3215,  0.0891, -0.0597]],
       grad_fn=<AddmmBackward0>)

In [40]:
ln = nn.LayerNorm(3)
bn = nn.LazyBatchNorm1d()
test_tensor = torch.tensor([[0, 2, 3], [2, 3, 4], [2, 3, 4], [5, 6, 7]], dtype=torch.float32)

In [41]:
ln(test_tensor)

tensor([[-1.3363,  0.2673,  1.0690],
        [-1.2247,  0.0000,  1.2247],
        [-1.2247,  0.0000,  1.2247],
        [-1.2247,  0.0000,  1.2247]], grad_fn=<NativeLayerNormBackward0>)

In [42]:
bn(test_tensor)

tensor([[-1.2603, -1.0000, -1.0000],
        [-0.1400, -0.3333, -0.3333],
        [-0.1400, -0.3333, -0.3333],
        [ 1.5403,  1.6667,  1.6667]], grad_fn=<NativeBatchNormBackward0>)

In [43]:
class AddNorm(nn.Module):
    def __init__(self, norm_shape, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(norm_shape)
        self.drop = nn.Dropout(dropout)

    def forward(self, X, Y):
        return self.norm(self.drop(Y) + X)

In [44]:
an = AddNorm(3, 0.1)
test_x = torch.ones((2, 3, 3))
test_y = torch.ones((2, 3, 3))
an(test_x, test_y)

tensor([[[ 0.7071, -1.4142,  0.7071],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000]],

        [[-1.4142,  0.7071,  0.7071],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.7071, -1.4142,  0.7071]]], grad_fn=<NativeLayerNormBackward0>)

In [45]:
class MultiheadAttention(nn.Module):
    def __init__(self, num_hidden, num_head, dropout, use_bias=False):
        super().__init__()
        self.num_head = num_head
        self.dropout = nn.Dropout(dropout)
        self.W_q = nn.LazyLinear(num_hidden, bias=use_bias)
        self.W_k = nn.LazyLinear(num_hidden, bias=use_bias)
        self.W_v = nn.LazyLinear(num_hidden, bias=use_bias)
        self.W_o = nn.LazyLinear(num_hidden, bias=use_bias)

    def transpose_qkv(self, x):
        x = x.reshape(x.shape[0], x.shape[1], self.num_head, -1)
        x = x.permute(0, 2, 1, 3)
        x = x.reshape(-1, x.shape[2], x.shape[3])
        return x

    def transpose_o(self, x):
        # x shape: (batch * num_head, seq_len, traits_dim / num_head)
        x = x.reshape(-1, self.num_head, x.shape[1], x.shape[2])
        x = x.permute(0, 2, 1, 3)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        return x

    def dot_product_attetion_score(self, q, k, v):
        d = q.shape[-1]

    def sequence_mask(self, valid_len, x, value=0):
        #shape of valid_len: (batch)
        #shape of x: (batch, dim)
        max_len = x.shape[-1]
        mask = torch.repeat_interleave((torch.arange(max_len, dtype=torch.float32)).reshape(1, -1), repeats=valid_len.shape[0], dim=0)
        valid_len = torch.repeat_interleave(valid_len, repeats=max_len).reshape(-1, max_len)
        mask = mask < valid_len
        x[~mask] = value
        return x

    def masked_softmax(self, valid_len, x):
        # shape of x: (batch, seq_len, seq_len)
        # it means, in each batch, there are seq_len query and
        # each of them map to seq_len key-value pair.
        # shape of valid_len: (batch, seq_len)
        # it means for each query, we should just take care how
        # many key-value pair.
        if valid_len is None:
            return nn.functional.softmax(x, dim=-1)
        else:
            shape = x.shape
            if valid_len.dim() == 1:
                valid_len = torch.repeat_interleave(valid_len, shape[1])
            else:
                valid_len = valid_len.reshape(-1)

            x = x.reshape((-1, shape[-1]))
            x = sequence_mask(valid_len, x, value=-1e6)
            return nn.functional.softmax(x.reshape(shape), dim=-1)

    def forward(self, Q, K, V, valid_len=None):
        # shape of Q, K and V: (Batch, seq_len, traits_dimension)
        q = self.transpose_qkv(self.W_q(Q))
        k = self.transpose_qkv(self.W_k(K))
        v = self.transpose_qkv(self.W_v(V))
        # shape of q, k and v: (Batch * num_head, seq_len, traits_dimension / num_head)
        attetion_score = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(q.shape[-1])
        if valid_len is not None:
            valid_len = torch.repeat_interleave(valid_len, repeats=self.num_hidden, dim=0)
        self.attention_weights = self.masked_softmax(valid_len, attetion_score)
        # res shape: (batch * num_head, seq_len, traits_dim / num_head)
        res = torch.bmm(self.dropout(self.attention_weights), v)
        # o shape: (batch, seq_len, traits_dim)
        o = self.transpose_o(res)
        output = self.W_o(o)
        return output

In [46]:
att = MultiheadAttention(8, 2, 0)
test_att = torch.ones((3, 5, 8))
res = att(test_att, test_att, test_att)
res.shape

torch.Size([3, 5, 8])

In [47]:
class TransformerEncoderBlock(nn.Module):
    def __init__(self, hidden_d, ffn_hidden_d, num_head, dropout, use_bias=False):
        super().__init__()
        self.attention = MultiheadAttention(hidden_d, num_head, dropout, use_bias)
        self.addnorm1 = AddNorm(hidden_d, dropout)
        self.ffn = FFN(ffn_hidden_d, hidden_d)
        self.addnorm2 = AddNorm(hidden_d, dropout)

    def forward(self, X, valid_lens=None):
        # X shape: (batch, seq_len, traits_dim)
        attention_output = self.attention(X, X, X, valid_lens)
        addnorm1 = self.addnorm1(X, attention_output)
        ffn = self.ffn(addnorm1)
        addnorm2 = self.addnorm2(addnorm1, ffn)
        return addnorm2

In [48]:
# the 1st dim for teb should equal to last dim for test_teb,
# because we should it's additive for add norm layer 1.
teb = TransformerEncoderBlock(8, 16, 2, 0)
test_teb = torch.ones((3, 5, 8))
res = teb(test_teb)
res.shape

torch.Size([3, 5, 8])

In [49]:
class Embedding(nn.Module):
    def __init__(self, vocab_size, num_hidden):
        super().__init__()
        self.vocab_size = vocab_size
        self.num_hidden = num_hidden
        self.linear = nn.Linear(vocab_size, num_hidden, bias=False)

    def forward(self, X):
        # shape of X: (batch, seq_len)
        one_hot_x = torch.nn.functional.one_hot(X, num_classes=self.vocab_size).float()
        embedding_res = self.linear(one_hot_x)
        return embedding_res

In [50]:
emd = Embedding(10, 20)
test_emd = torch.tensor([1, 5, 6])
emd_res = emd(test_emd)
emd_res, emd_res.shape

(tensor([[ 0.0356, -0.0287, -0.2012,  0.2273, -0.2930,  0.2950, -0.0294, -0.1253,
           0.0435, -0.2875,  0.0802,  0.0469,  0.0069, -0.0489, -0.0305,  0.0602,
           0.1467, -0.2299,  0.0860, -0.1267],
         [-0.1032, -0.1029, -0.0809, -0.3148, -0.3031,  0.1387,  0.1939, -0.0663,
          -0.2699, -0.0253, -0.1051, -0.0866, -0.1352,  0.2682,  0.0427, -0.1792,
           0.2152, -0.1289, -0.1675, -0.1323],
         [ 0.0229,  0.1520,  0.1598,  0.0492, -0.2183,  0.0232, -0.2028, -0.2805,
          -0.2325,  0.1253,  0.2482,  0.0285, -0.0703,  0.2733, -0.1834, -0.1215,
          -0.0719, -0.2733, -0.1391, -0.0980]], grad_fn=<MmBackward0>),
 torch.Size([3, 20]))

In [51]:
class PositionEncoding(nn.Module):
    def __init__(self, num_dim, dropout, max_len=1000):
        super().__init__()
        # We use a fixed priori as P for current position encoding now.
        # There are 3 dims:
        # 1st dim for batch, because we could use boardcast mechanism, we set it as 1.
        # 2nd dim is the max len, we think it's t he row index for calculation.
        # 3rd dim is the num_dim, we think it's the column for calculation.
        self.P = torch.empty((1, max_len, num_dim))
        tmp = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / torch.pow(
            1000, torch.arange(0, num_dim, 2, dtype=torch.float32) / num_dim)
        self.P[:, :, 0::2] = torch.sin(tmp)
        self.P[:, :, 1::2] = torch.cos(tmp)
        self.drop = nn.Dropout(dropout)

    def forward(self, X):
        tmp = X + self.P[:, :X.shape[1],:].to(X.device)
        return self.drop(tmp)

In [52]:
pe = PositionalEncoding(20, 0, 10)
pe_test = torch.randn((3, 8, 20))
pe_res = pe(pe_test)
pe_res.shape

torch.Size([3, 8, 20])

In [53]:
class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, num_hidden, ffn_num_hidden, num_head, num_blks, dropout, use_bias=False):
        super().__init__()
        self.num_hidden = num_hidden
        self.embedding = Embedding(vocab_size, num_hidden)
        self.pos_encoding = PositionEncoding(num_hidden, dropout)
        self.blks = nn.Sequential()
        for i in range(num_blks):
            self.blks.add_module("block"+str(i), TransformerEncoderBlock(num_hidden, ffn_num_hidden, num_head, dropout, use_bias))

    # shape of X: (batch, seq_len)
    def forward(self, X, valid_lens=None):
        #shape of embedding: (batch, seq_len, num_hidden)
        embedding = self.embedding(X)
        #shape of pos_encoding: (batch, seq_len, num_hidden)
        pos_encoding = self.pos_encoding(embedding * math.sqrt(self.num_hidden))
        output = pos_encoding
        for i, blk in enumerate(self.blks):
            output = blk(output, valid_lens)
        return output

In [59]:
encoder = TransformerEncoder(10, 20, 30, 4, 2, 0)
# shape of encoder_test: (1, 7), batch is 1, seq_len is 7
encoder_test = torch.tensor([1, 2, 3, 4, 5, 6, 7]).reshape(1, -1)
encoder_res = encoder(encoder_test)
encoder_res, encoder_res.shape

(tensor([[[-4.4308e-01,  1.3189e+00, -8.5987e-01,  2.0115e+00, -1.4273e+00,
            3.4615e-02, -7.2143e-01, -4.8976e-01,  3.3794e-01, -3.7318e-01,
           -1.4779e+00,  5.2810e-01, -7.9528e-01,  1.5515e+00, -5.3594e-01,
           -1.1595e+00,  9.0920e-04,  9.5331e-01,  7.0481e-02,  1.4759e+00],
          [-3.8693e-01,  8.2518e-01, -1.5203e+00,  2.3357e-01,  2.6728e-02,
            9.4459e-01,  4.7949e-02,  1.2375e+00,  1.0800e+00,  1.5880e+00,
           -8.6881e-01,  1.5499e+00, -1.0151e+00,  6.0419e-01, -1.6954e+00,
            2.7596e-01, -7.5043e-01,  1.0690e-01, -1.1851e+00, -1.0983e+00],
          [ 8.2149e-01, -2.4622e+00,  7.3896e-01, -6.6537e-01,  8.6555e-01,
           -5.5550e-01, -4.0124e-01, -6.2197e-01,  1.1920e+00,  1.0700e+00,
            5.6262e-01,  8.3060e-01, -9.1356e-01,  1.8843e-01, -3.8616e-01,
           -1.4513e+00, -6.9863e-01, -5.9390e-01,  1.0762e+00,  1.4040e+00],
          [ 1.0509e-01, -1.4472e+00,  8.6075e-01,  6.3476e-01,  9.6304e-01,
         