In [0]:
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

In [0]:
# multi-head scaled dot-product attention
class MultiHeadedScaledDotProductAttention(nn.Module):
    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()
        self.n = n_head
        self.d_k = d_k
        self.d_v = d_v
        self.temperature = np.sqrt(d_k)
        # self.temperature = torch.tensor(np.sqrt(d_k), dtype=torch.float, requires_grad=True)
        self.wq = nn.Linear(d_model, d_k * n_head)
        self.wk = nn.Linear(d_model, d_k * n_head)
        self.wv = nn.Linear(d_model, d_v * n_head)
        self.output = nn.Linear(d_v * n_head, d_model)
        self.dropout = nn.Dropout(p=dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, q, k, v, residual=None, mask=None):
        if residual is None:
            residual = q
        # (batch_size, n, seq_length, d)
        q = self.wq(q).view(q.size(0), q.size(1), self.n, self.d_k).transpose(1, 2)
        k = self.wk(k).view(k.size(0), k.size(1), self.n, self.d_k).transpose(1, 2)
        v = self.wv(v).view(v.size(0), v.size(1), self.n, self.d_v).transpose(1, 2)
        attn = torch.matmul(q, k.transpose(2, 3))
        if mask is not None:
            attn = attn.masked_fill(mask.unsqueeze(1)==0, -np.inf)
        attn = F.softmax(attn / self.temperature, dim=-1)
        attn_dot_v = torch.matmul(attn, v).transpose(1, 2).reshape(q.size(0), -1, self.d_v * self.n)
        output = self.output(attn_dot_v)
        output = self.dropout(output)
        output = self.norm(output + residual)
        return output, attn

# positionwise feed-forward
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_hidden, dropout=0.1):
        super().__init__()
        self.L1 = nn.Linear(d_model, d_hidden)
        self.L2 = nn.Linear(d_hidden, d_model)
        self.dropout = nn.Dropout(p=dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        residual = x
        output = self.L2(F.relu(self.L1(x)))
        output = self.dropout(output)
        output = self.norm(output + residual)
        return output

In [0]:
# encoder layer
class EncoderLayer(nn.Module):
    def __init__(self, d_model, d_hidden, n_head, d_k, d_v, dropout=0.1):
        super().__init__()
        self.slf_attn = MultiHeadedScaledDotProductAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffnn = PositionwiseFeedForward(d_model, d_hidden, dropout=dropout)

    def forward(self, enc_input, slf_attn_mask=None):
        slf_attn_output, slf_attn = self.slf_attn(enc_input, enc_input, enc_input, mask=slf_attn_mask)
        enc_output = self.pos_ffnn(slf_attn_output)
        return enc_output, slf_attn

# decoder layer
class DecoderLayer(nn.Module):
    def __init__(self, d_model, d_hidden, n_head, d_k, d_v, dropout=0.1):
        super().__init__()
        self.slf_attn = MultiHeadedScaledDotProductAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.enc_dec_attn = MultiHeadedScaledDotProductAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffnn = PositionwiseFeedForward(d_model, d_hidden, dropout=dropout)

    def forward(self, dec_input, enc_output, slf_attn_mask=None, enc_dec_attn_mask=None):
        slf_attn_output, slf_attn = self.slf_attn(dec_input, dec_input, dec_input, mask=slf_attn_mask)
        enc_dec_attn_output, enc_dec_attn = self.enc_dec_attn(slf_attn_output, enc_output, enc_output, mask=enc_dec_attn_mask)
        dec_output = self.pos_ffnn(enc_dec_attn_output)
        return dec_output, slf_attn, enc_dec_attn

# positional encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, n_position=100):
        super().__init__()
        self.pos_table = self._sinusoid_encoding_table(n_position, d_model)

    def _sinusoid_encoding_table(self, n_position, d_model):
        sinusoid_table = np.array([pos / np.power(10000, np.arange(d_model) // 2 * 2 / d_model)
            for pos in range(n_position)], dtype=np.float)
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) 
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])
        # (batch_size, n_position, d_model)
        return torch.tensor(sinusoid_table, dtype=torch.float).unsqueeze(0)

    def forward(self, x):
        return x + self.pos_table[:, :x.size(1)]

In [0]:
# encoder
class Encoder(nn.Module):
    def __init__(self, n_src_vocab, d_model, d_hidden, n_head, d_k, d_v, n_position=100, n_layer=2, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(n_src_vocab, d_model, padding_idx=0)
        self.pos_enc = PositionalEncoding(d_model, n_position=n_position)
        self.enc_layer_stack = nn.ModuleList([EncoderLayer(d_model, d_hidden, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layer)])

    def forward(self, src_seq, src_mask=None, return_attns=False):
        enc_slf_attn_list = []
        
        enc_output = self.pos_enc(self.embedding(src_seq))
        for enc_layer in self.enc_layer_stack:
            enc_output, enc_slf_attn = enc_layer(enc_output, slf_attn_mask=src_mask)
            enc_slf_attn_list += [enc_slf_attn] if return_attns else []
        
        return enc_output, enc_slf_attn_list

# decoder
class Decoder(nn.Module):
    def __init__(self, n_trg_vocab, d_model, d_hidden, n_head, d_k, d_v, n_position=100, n_layer=2, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(n_trg_vocab, d_model, padding_idx=0)
        self.pos_enc = PositionalEncoding(d_model, n_position=n_position)
        self.dec_layer_stack = nn.ModuleList([DecoderLayer(d_model, d_hidden, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layer)])
        
    def forward(self, trg_seq, trg_mask, enc_output, src_mask, return_attns=False):
        dec_slf_attn_list, enc_dec_attn_list = [], []

        dec_output = self.pos_enc(self.embedding(trg_seq))
        for dec_layer in self.dec_layer_stack:
            dec_output, dec_slf_attn, enc_dec_attn = dec_layer(dec_output, enc_output, slf_attn_mask=trg_mask, enc_dec_attn_mask=src_mask)
            dec_slf_attn_list += [dec_slf_attn] if return_attns else []
            enc_dec_attn_list += [enc_dec_attn] if return_attns else []
        
        return dec_output, dec_slf_attn_list, enc_dec_attn_list

In [0]:
# transformer
class Transformer(nn.Module):
    def __init__(self, n_src_vocab, n_trg_vocab, d_model=256, d_hidden=1024, 
                 n_head=8, d_k=64, d_v=64, n_position=100, n_layer=2, dropout=0.1, 
                 emb_src_trg_weight_sharing=True, trg_emb_prj_weight_sharing=True):
        super().__init__()
        self.encoder = Encoder(n_src_vocab, d_model, d_hidden, n_head, d_k, d_v, 
                               n_position=n_position, n_layer=n_layer, dropout=dropout)
        self.decoder = Decoder(n_trg_vocab, d_model, d_hidden, n_head, d_k, d_v, 
                               n_position=n_position, n_layer=n_layer, dropout=dropout)
        self.trg_prj = nn.Linear(d_model, n_trg_vocab, bias=False)

        if emb_src_trg_weight_sharing:
            self.encoder.embedding.weight = self.decoder.embedding.weight
        if trg_emb_prj_weight_sharing:
            self.trg_prj.weight = self.decoder.embedding.weight
            self.x_logit_scale = d_model ** -0.5

    def forward(self, src_seq, src_mask: torch.Tensor, trg_seq, trg_mask: torch.Tensor):
        enc_output, *_ = self.encoder(src_seq, src_mask)
        dec_output, *_ = self.decoder(trg_seq, trg_mask, enc_output, src_mask)
        transformer_output = F.softmax(self.trg_prj(dec_output) * self.x_logit_scale, dim=-1)

        return transformer_output

In [28]:
model = Transformer(30, 30)
x1 = torch.tensor([[3, 8, 9, 2]])
x2 = torch.tensor([[1, 5, 19]])
model(x1, None, x2, None)

tensor([[[7.6401e-06, 9.9968e-01, 3.2029e-06, 2.4810e-06, 5.6179e-06,
          2.2567e-06, 2.7340e-05, 7.3559e-07, 4.7759e-07, 8.8210e-06,
          1.8307e-05, 1.0132e-05, 1.3958e-06, 2.8679e-06, 6.7776e-06,
          6.4898e-06, 5.4635e-06, 1.5337e-05, 2.5972e-05, 1.0940e-05,
          9.3544e-07, 2.8428e-06, 8.2941e-06, 5.0713e-06, 9.1464e-06,
          9.1178e-05, 1.1957e-05, 5.3190e-06, 1.3683e-05, 6.6922e-06],
         [4.1794e-06, 2.4963e-06, 1.1406e-06, 3.8322e-06, 1.4624e-05,
          9.9983e-01, 2.1844e-05, 1.0107e-05, 1.8736e-06, 9.1035e-07,
          1.1114e-05, 4.1653e-06, 1.0109e-05, 8.3385e-06, 1.6902e-06,
          1.0906e-06, 1.7582e-05, 3.0540e-06, 1.0815e-06, 1.2542e-05,
          5.4959e-06, 9.4522e-07, 1.9852e-06, 1.2144e-06, 9.0445e-07,
          4.8632e-06, 3.4214e-06, 3.8293e-06, 1.5750e-05, 1.2982e-06],
         [5.9970e-06, 6.7244e-06, 1.2330e-05, 5.8685e-06, 1.7898e-06,
          1.0640e-05, 4.1920e-05, 3.3656e-06, 1.0078e-06, 1.8784e-06,
          1.0348e-

In [0]:
from tqdm import tqdm
import torch.optim as optim
import torch.utils.data as data

class MnistDataset(data.Dataset):
    def __init__(self, csv_file):
        self.df = pd.read_csv(csv_file, header=None)
    
    def __len__(self): 
        return len(self.df)

    def __getitem__(self, idx):
        data = self.df.loc[idx].to_numpy()
        label = data[0]
        img = torch.from_numpy((data[1:] / 255).astype(np.float32))
        return img, label

class mnistTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = EncoderLayer(784, 256, 4, 16, 16)
        self.output_layer = nn.Linear(784, 10)

    def forward(self, x):
        enc_output = self.encoder(x)
        return F.softmax(self.output_layer(enc_output), dim=-1)

In [0]:
train_data = MnistDataset('/content/sample_data/mnist_train_small.csv')
test_data = MnistDataset('/content/sample_data/mnist_test.csv')
train_loader = data.DataLoader(train_data, batch_size=16, shuffle=True)
test_loader = data.DataLoader(test_data, batch_size=16)

In [0]:
model = mnistTransformer()
loss_func = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

In [0]:
for epoch in range(2):
    for (x, y_true) in tqdm(train_loader):
        optimizer.zero_grad()
        y_pred = model(x)
        loss = loss_func(y_pred, y_true)
        loss.backward()
        optimizer.step()

100%|██████████| 1250/1250 [00:12<00:00, 103.47it/s]
100%|██████████| 1250/1250 [00:12<00:00, 103.05it/s]


In [0]:
model.eval()
correct = 0
for i, (X_test, y_test) in enumerate(test_loader):
    y_pred = model(X_test)
    correct += torch.argmax(y_pred, dim=-1).eq(y_test).sum()
print(correct.item() / len(test_data))

0.9657


In [0]:
a = 1 / torch.pow(10000, torch.arange(10) // 2 * 2 / 2.)