In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

import math

In [2]:
class Attention(nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, query, key, value, mask=None):
        """
        N: batch size
        seq_len: sequience lenght (e.g. sentence length)
        h: number of heads
        D_k: dim of key/query
        D_v: dim of value
        
        This is the case of single head
            query: N x seq_len x D_k
            key: N x seq_len x D_k
            value: N x seq_len x D_v
        
        In the case of multi head, the input would be
            query: N x h x seq_len x D_k
            key: N x h x seq_len x D_k
            value: N x h x seq_len x D_v
        
        So to make it easily support the above case, we just need to make sure
        that we we are working on the final 2 dim for matrix multi when we do
        the key x query
            
        """
        d_k = query.shape[-1]
        
        dot_product = torch.matmul(query, torch.transpose(key, -2, -1)) / math.sqrt(d_k)
        
        if mask is not None:
            dot_product = dot_product.masked_fill(mask == 0, -1e9)
        
        return torch.matmul(F.softmax(dot_product, dim=-1), value)

In [19]:
class MultiHead(nn.Module):
    def __init__(self, d_model, d_k, d_v, h=2):
        super().__init__()
        
        self.h = h
        self.d_k = d_k
        self.d_v = d_v
        
        # some invariant to check, in the future, d_k, d_v does not need to be provided
        assert d_v == d_k
        assert d_model // h == d_k
        
        self.WQ = nn.Linear(d_model, d_model)
        self.WK = nn.Linear(d_model, d_model)
        self.WV = nn.Linear(d_model, d_model)
        
        self.attn = Attention()
        
        self.fc = nn.Linear(d_model, d_model)
    
    def forward(self, query, key, value, mask=None):
        """
            query: N x seq_len x D_model
            key: N x seq_len x D_model
            value: N x seq_len x D_model
        """
        
        N, seq = query.shape[0], query.shape[1]
        
        q = self.WQ(query)  # N x seq_len x d_model
        k = self.WK(key)  # N x seq_len x d_model
        v = self.WV(value)  # N x seq_len x d_model
        
        q = q.view(N, -1, self.h, self.d_k).transpose(1, 2)  # N x h x seq x d_k
        k = k.view(N, -1, self.h, self.d_k).transpose(1, 2)  # N x h x seq x d_k
        v = v.view(N, -1, self.h, self.d_v).transpose(1, 2)  # N x h x seq x d_v
        
        
        
        out = self.attn(q, k, v)  # N x h x seq x d_v
        out = out.transpose(1, 2)  # N x seq x h x d_v
        
        return self.fc(out.reshape(N, seq, -1))  # N x seq x D_model
        

In [20]:
N = 10
d_model = 10
seq_len = 10

d_k, d_v = 5, 5

query = torch.rand(N, seq_len, d_model)
key = torch.rand(N, seq_len, d_model)
value = torch.rand(N, seq_len, d_model)

# WQ = nn.Linear(3, d_model, d_k)
# WQ(query)

mh = MultiHead(d_model, d_k, d_v, h=2)
mh(query, key, value)

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

In [None]:
WQ = nn.Linear(3, )

In [130]:
class MultiHead(nn.Module):
    def __init__(self, d_model, d_k, d_v, num_heads=2):
        super().__init__()
        
        self.h = num_heads
        
        self.attns = nn.ModuleList()
        self.WQ = nn.ModuleList()
        self.WK = nn.ModuleList()
        self.WV = nn.ModuleList()
        
        for _ in range(self.h):
            self.attns.append(Attention())
            self.WQ.append(nn.Linear(d_model, d_k))
            self.WK.append(nn.Linear(d_model, d_k))
            self.WV.append(nn.Linear(d_model, d_v))
        
        self.fc = nn.Linear(num_heads * d_v, d_model)
    
    def forward(self, query, key, value, mask=None):
        
        heads_out = []
        for i in range(self.h):
            heads_out.append(self.attns[i](self.WQ[i](query), self.WK[i](query), self.WV[i](value), mask=mask))
        
        return self.fc(
            torch.concat(tuple(heads_out), dim=-1)
        )

In [131]:
class AddAndNorm(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_model)
        
    def forward(self, out1, out2):
        return self.layer_norm(out1 + out2)

In [132]:
class FFN(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.fc1 = nn.Linear(d_model, 4 * d_model)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(4 * d_model, d_model)
    
    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

In [133]:
def subsequent_mask(size):
    return torch.from_numpy(
        np.triu(np.ones((size, size)), k=1).astype('uint8')
    ) == 0

In [134]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, d_k, d_v, num_heads=2):
        super().__init__()
        
        self.multi_head = MultiHead(d_model, d_k, d_v, num_heads=num_heads)
        self.add_norm1 = AddAndNorm(d_model)
        self.ffn = FFN(d_model)
        self.add_norm2 = AddAndNorm(d_model)
        
    def forward(self, x, mask):
        """
            x: N x seq_len x d_model
            mask: seq_len x seq_len, in encoder, we need this to prevent padding
        """
        out1 = self.add_norm1(x, self.multi_head(x, x, x, mask=mask))
        out2 = self.add_norm2(out1, self.ffn(out1))
        return out2

In [135]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, d_k, d_v, num_heads=2):
        super().__init__()
        self.masked_multi_head = MultiHead(d_model, d_k, d_v, num_heads=num_heads)
        self.add_norm1 = AddAndNorm(d_model)
        self.multi_head = MultiHead(d_model, d_k, d_v, num_heads=num_heads)
        self.add_norm2 = AddAndNorm(d_model)
        self.ffn = FFN(d_model)
        self.add_norm3 = AddAndNorm(d_model)
    
    def forward(self, x, encoder_out, mask=None):
        """
            x: N x seq_len x d_model
            encoder_out: N x seq_len x d_model
            mask: seq_len x seq_len
        """
        out1 = self.add_norm1(
            x, self.masked_multi_head(x, x, x, mask=mask)
        )
        out2 = self.add_norm2(
            out1, self.multi_head(out1, encoder_out, encoder_out, mask=mask)
        )
        out3 = self.add_norm3(
            out2, self.ffn(out2)
        )
        return out3

In [144]:
class Transformer(nn.Module):
    def __init__(self, d_model, d_k, d_v, out_dim, num_heads=2):
        super().__init__()
        self.encoder = EncoderLayer(d_model, d_k, d_v, num_heads=num_heads)
        self.decoder = DecoderLayer(d_model, d_k, d_v, num_heads=num_heads)
        
        self.fc = nn.Linear(d_model, out_dim)
    
    def forward(self, src, tgt, src_mask, tgt_mask):
        encoder_out = self.encoder(src, src_mask)
        decoder_out = self.decoder(tgt, encoder_out, tgt_mask)
        return F.softmax(self.fc(decoder_out), dim=-1)

In [137]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len):
        super().__init__()
        
        pe = torch.ones((max_len, d_model)) 
        
        pos = torch.arange(0, max_len).reshape((max_len, 1))
        rng = torch.arange(0, d_model/2)
        rng = 2*rng / d_model
        div = torch.pow(10000, rng)
        
        pe[:, 0::2] = torch.sin(pos / div)
        pe[:, 1::2] = torch.cos(pos / div)
        
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        """
            x: N x seq_len x d_model
        """
        return x + self.pe[:x.shape[1], :].unsqueeze(0)

In [146]:
class TokenEmbedding(nn.Module):
    def __init__(self, voc_size, emb_size):
        super().__init__()
        self.emb = nn.Embedding(voc_size, emb_size)
        self.emb_size = emb_size
        
    def forward(self, x):
        return self.emb(x) / math.sqrt(self.emb_size)

Below is the function related to data processing

Below is the sections of playground for testing of the module above

In [48]:
d_model = 10
seq_len = 10

d_k, d_v = 5, 5

In [126]:
x = torch.rand(1, 5, d_model)
label = torch.tensor([0.0, 1.0, 1.0, 0.0, 0.0]).unsqueeze(dim=1)

In [124]:
pe = PositionalEncoding(d_model, 1000)
x = torch.rand(1, 5, d_model)
pe(x)

tensor([[[ 0.7941,  1.8423,  0.2766,  1.8735,  0.6056,  1.4991,  0.1369,
           1.0558,  0.1053,  1.1077],
         [ 1.8124,  0.6735,  0.2615,  1.1507,  0.8948,  1.0303,  0.3383,
           1.5239,  0.5827,  1.2680],
         [ 1.4656,  0.3187,  1.2982,  1.1240,  0.2975,  1.9091,  0.9909,
           1.1799,  0.5763,  1.1359],
         [ 0.1657, -0.8468,  1.0620,  1.0904,  0.8925,  1.0388,  0.6975,
           1.8768,  0.1383,  1.7975],
         [-0.0167, -0.1900,  0.6836,  1.7922,  0.2277,  1.7875,  0.3755,
           1.7339,  0.3565,  1.6054]]])

In [66]:
criterion = nn.MSELoss()
optimizer = optim.SGD(encoder.parameters(), lr=0.01, momentum=0.9)

optimizer.zero_grad()
pred = encoder(x)
loss = criterion(pred, label)
loss.backward()
optimizer.step()

  return F.mse_loss(input, target, reduction=self.reduction)


In [145]:
trans = Transformer(d_model, d_k, d_v, out_dim=20)
src = torch.rand(1, 5, d_model)
tgt = torch.rand(1, 5, d_model)
trans(src, tgt, src_mask=None, tgt_mask=None)

tensor([[[0.0470, 0.1035, 0.0756, 0.0151, 0.1349, 0.0103, 0.0362, 0.0578,
          0.0216, 0.1197, 0.0391, 0.0269, 0.0297, 0.0168, 0.0492, 0.0813,
          0.0374, 0.0171, 0.0220, 0.0589],
         [0.0228, 0.1159, 0.1013, 0.0301, 0.1040, 0.0168, 0.0687, 0.0217,
          0.0092, 0.0952, 0.0347, 0.0192, 0.0108, 0.0518, 0.1153, 0.0695,
          0.0241, 0.0187, 0.0231, 0.0470],
         [0.0498, 0.0562, 0.0134, 0.0172, 0.1349, 0.0250, 0.0134, 0.0504,
          0.0358, 0.0600, 0.0479, 0.0420, 0.0205, 0.1018, 0.0686, 0.0483,
          0.0659, 0.0143, 0.0210, 0.1136],
         [0.0408, 0.0766, 0.0582, 0.0123, 0.0890, 0.0110, 0.0442, 0.1523,
          0.0359, 0.1230, 0.0312, 0.0355, 0.0707, 0.0200, 0.0221, 0.0517,
          0.0258, 0.0195, 0.0220, 0.0583],
         [0.0191, 0.0756, 0.0329, 0.0282, 0.0898, 0.0279, 0.0453, 0.0530,
          0.0208, 0.0604, 0.0308, 0.0365, 0.0275, 0.1241, 0.1149, 0.0511,
          0.0391, 0.0152, 0.0376, 0.0702]]], grad_fn=<SoftmaxBackward0>)