In [1]:
import os
import math
import datetime
import time

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.utils import data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torchvision import transforms as T
from torchvision.datasets import ImageFolder
from torch.autograd import Variable
from torchvision.utils import save_image
from torch.backends import cudnn


## Multi-Head Attention

In [252]:
class ScaledDPAttention(nn.Module):
    """Scaled dot-product attention: softmax(q @ k^T / sqrt(dim_key)) @ v.

    Args:
        dim_key: key/query feature size d_k; attention scores are divided by
            sqrt(dim_key).
        dim_val: value feature size d_v (stored for reference; not used in the
            computation).
        masked: if True, apply a causal mask so that query position i can only
            attend to key positions j <= i.
    """

    def __init__(self, dim_key, dim_val, masked=False):
        super().__init__()
        self.dim_key = dim_key
        self.dim_val = dim_val
        self.masked = masked
        # Scores are scaled by sqrt(d_k) -- the key/query width -- not by d_v.
        self.scale = math.sqrt(self.dim_key)
        # Normalize over the key axis (last dimension of the score matrix).
        self.sm = nn.Softmax(dim=-1)

    def forward(self, q, k, v):
        """Attend from queries `q` to keys `k` and return weighted sums of `v`.

        Args:
            q: (..., len_q, d_k) queries.
            k: (..., len_k, d_k) keys.
            v: (..., len_k, d_v) values.

        Returns:
            (..., len_q, d_v) tensor.
        """
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale
        if self.masked:
            len_q, len_k = scores.size(-2), scores.size(-1)
            # True strictly above the diagonal = "future" keys to hide. Built on
            # the scores' device and sized (len_q, len_k) so it also works on
            # GPU and for non-square attention.
            future = torch.triu(
                torch.ones(len_q, len_k, dtype=torch.bool, device=scores.device),
                diagonal=1,
            )
            scores = scores.masked_fill(future, float('-inf'))
        weights = self.sm(scores)
        return torch.matmul(weights, v)

class MultiHeadAttention(nn.Module):
    """Multi-head attention built from `h` independent ScaledDPAttention heads.

    Each head projects Q, K and V from `dim_model` down to `dim_model // h`
    features with its own bias-free linear layer and attends. The head outputs
    are concatenated and projected back to `dim_model`.

    Args:
        dim_model: input/output feature size.
        dim_key, dim_val: stored for reference only; the per-head size actually
            used for queries, keys and values is `dim_model // h`.
        h: number of heads.
        masked: forwarded to every head; enables causal masking.
    """

    def __init__(self, dim_model, dim_key, dim_val, h, masked=False):
        super().__init__()
        self.h = h
        self.dim_model = dim_model
        self.dim_key = dim_key
        self.dim_val = dim_val
        self.dim_head = dim_model // h

        # nn.ModuleList (not a plain Python list) registers the per-head layers
        # as submodules. Their weights then appear in .parameters() and
        # .state_dict(), so an optimizer trains them, and they follow
        # .to(device) / .train() / .eval().
        self.v_layers = nn.ModuleList()
        self.k_layers = nn.ModuleList()
        self.q_layers = nn.ModuleList()
        self.attention_layers = nn.ModuleList()
        for _ in range(self.h):
            self.q_layers.append(nn.Linear(in_features=dim_model, out_features=self.dim_head, bias=False))
            self.k_layers.append(nn.Linear(in_features=dim_model, out_features=self.dim_head, bias=False))
            self.v_layers.append(nn.Linear(in_features=dim_model, out_features=self.dim_head, bias=False))
            self.attention_layers.append(ScaledDPAttention(self.dim_head, self.dim_head, masked))
        # If dim_model is not divisible by h, then h * dim_head < dim_model.
        # This projection maps the concatenated heads back to dim_model either way.
        self.linear = nn.Linear(in_features=h * self.dim_head, out_features=dim_model, bias=False)

    def forward(self, Q, K, V):
        """Run every head on (Q, K, V), concatenate, and project to dim_model.

        Args:
            Q: queries whose last dimension is dim_model.
            K, V: keys/values whose last dimension is dim_model.

        Returns:
            Tensor shaped like Q with last dimension dim_model.
        """
        heads = [
            attend(q_proj(Q), k_proj(K), v_proj(V))
            for q_proj, k_proj, v_proj, attend in zip(
                self.q_layers, self.k_layers, self.v_layers, self.attention_layers
            )
        ]
        # Head outputs are joined along the feature (last) axis -> h * dim_head.
        return self.linear(torch.cat(heads, dim=-1))

### Debug 


In [237]:
dim_model = 64
dim_key = 8
dim_val = 8
K = torch.ones(10, 16, dim_model)
V = torch.ones(10, 16, dim_model)
Q = torch.ones(10, 16, dim_model)
sdp_attention = ScaledDPAttention(dim_key, dim_val, masked=False)
# forward() takes (q, k, v); pass the tensors in that order.
output = sdp_attention(Q, K, V)
output.shape

torch.Size([10, 16, 64])

In [253]:
masked_sdp_attention = ScaledDPAttention(dim_key, dim_val, masked=True)
# forward() takes (q, k, v); pass the tensors in that order.
output = masked_sdp_attention(Q, K, V)
output.shape

torch.Size([10, 16, 64])

In [215]:
h = 8
dim_model = 64
mha = MultiHeadAttention(dim_model, dim_key, dim_val, h)
mha_out = mha(Q, K, V)
mha_out.shape

torch.Size([10, 16, 64])

## Positional Encoding

In [54]:
def positional_encoding(batch_size, n_dim, embedding_length, base=10000):
    """Sinusoidal positional encoding, tiled over a batch.

    PE[pos, 2i]   = sin(pos / base ** (2i / n_dim))
    PE[pos, 2i+1] = cos(pos / base ** (2i / n_dim))

    Args:
        batch_size: number of copies stacked along a new leading axis.
        n_dim: encoding width; may be odd.
        embedding_length: number of positions; may be 0.
        base: wavelength base (10000 reproduces the original behavior).

    Returns:
        float64 numpy array of shape (batch_size, embedding_length, n_dim).
    """
    positions = np.arange(embedding_length, dtype=np.float64)[:, None]
    dims = np.arange(n_dim)
    # Column pairs (2i, 2i+1) share one frequency, hence the integer halving.
    divisors = base ** (2 * (dims // 2) / n_dim)
    # Broadcast (L, 1) / (n_dim,) -> (L, n_dim). This stays 2-D when L == 0,
    # so the column slicing below is valid for empty sequences too.
    angles = positions / divisors
    encoding = np.zeros((embedding_length, n_dim))
    encoding[:, 0::2] = np.sin(angles[:, 0::2])
    encoding[:, 1::2] = np.cos(angles[:, 1::2])
    return np.tile(encoding, (batch_size, 1, 1))

### Debug

In [213]:
# These three sizes are reused below as the model dimensions for the
# encoder/decoder debug cells.
batch_size, embedding_length, n_dim = 8, 125, 125
pe = positional_encoding(batch_size, n_dim, embedding_length)
pe.shape

(125, 125)


(8, 125, 125)

## Encoder/Decoder

In [117]:
class EncoderLayer(nn.Module):
    """One post-norm Transformer encoder layer.

    Runs a self-attention sublayer, then a position-wise feed-forward sublayer.
    Each sublayer is wrapped as ``x = LayerNorm(x + Sublayer(x))``.

    Args:
        dim_model: feature size of the input and output.
        dim_key, dim_val: forwarded to MultiHeadAttention.
        dim_hidden: inner width of the feed-forward network.
        h: number of attention heads.
    """

    def __init__(self, dim_model, dim_key, dim_val, dim_hidden, h=8):
        super().__init__()
        self.attention = MultiHeadAttention(dim_model, dim_key, dim_val, h)
        self.FFN = nn.Sequential(
            nn.Linear(in_features=dim_model, out_features=dim_hidden),
            nn.ReLU(),
            nn.Linear(in_features=dim_hidden, out_features=dim_model)
        )
        # One LayerNorm per sublayer so each residual block learns its own
        # gain and bias instead of sharing a single set of parameters.
        self.norm = nn.LayerNorm(dim_model)      # after self-attention
        self.ffn_norm = nn.LayerNorm(dim_model)  # after the feed-forward network

    def forward(self, X):
        """Apply self-attention then feed-forward to X (last dim = dim_model)."""
        attended = self.norm(X + self.attention(Q=X, K=X, V=X))
        # Named `ff`, not `F`, to avoid shadowing torch.nn.functional.
        ff = self.FFN(attended)
        return self.ffn_norm(attended + ff)
    


In [259]:
class DecoderLayer(nn.Module):
    """One post-norm Transformer decoder layer.

    It has three sublayers, each wrapped as ``x = LayerNorm(x + Sublayer(x))``:
      1. causally-masked self-attention over the decoder input;
      2. attention with decoder queries and encoder-feature keys/values;
      3. a position-wise feed-forward network.

    The layer takes and returns an ``(x, encoder_features)`` tuple so several
    layers can be chained with ``nn.Sequential``. The encoder features are
    passed through unchanged.

    Args:
        dim_model: feature size of the input and output.
        dim_key, dim_val: forwarded to MultiHeadAttention.
        dim_hidden: inner width of the feed-forward network.
        h: number of attention heads.
    """

    def __init__(self, dim_model, dim_key, dim_val, dim_hidden, h=8):
        super().__init__()
        self.masked_attention = MultiHeadAttention(dim_model, dim_key, dim_val, h, masked=True)
        self.attention = MultiHeadAttention(dim_model, dim_key, dim_val, h, masked=False)
        self.FFN = nn.Sequential(
            nn.Linear(in_features=dim_model, out_features=dim_hidden),
            nn.ReLU(),
            nn.Linear(in_features=dim_hidden, out_features=dim_model)
        )
        # One LayerNorm per sublayer so each residual block learns its own
        # gain and bias instead of sharing a single set of parameters.
        self.norm = nn.LayerNorm(dim_model)        # after masked self-attention
        self.cross_norm = nn.LayerNorm(dim_model)  # after encoder-decoder attention
        self.ffn_norm = nn.LayerNorm(dim_model)    # after the feed-forward network

    def forward(self, inputs):
        """Run the three sublayers.

        Args:
            inputs: ``(X, features)`` -- decoder input and encoder output.

        Returns:
            ``(output, features)`` with ``features`` returned unchanged.
        """
        X, features = inputs
        self_attended = self.norm(X + self.masked_attention(Q=X, K=X, V=X))
        cross = self.attention(Q=self_attended, K=features, V=features)
        cross_attended = self.cross_norm(self_attended + cross)
        # Named `ff`, not `F`, to avoid shadowing torch.nn.functional.
        ff = self.FFN(cross_attended)
        return (self.ffn_norm(cross_attended + ff), features)

### Debug

In [256]:
dim_hidden = 64
# 125 is not divisible by h = 8, so MultiHeadAttention uses 15-wide heads
# (8 * 15 = 120) and its output projection maps back to 125.
dim_model = dim_key = dim_val = n_dim
X = torch.zeros(batch_size, embedding_length, n_dim)
encoder_layer = EncoderLayer(dim_model, dim_key, dim_val, dim_hidden, h)
encoder_output = encoder_layer(X)
encoder_output.shape

torch.Size([8, 125, 125])

In [220]:
decoder_layer = DecoderLayer(dim_model, dim_key, dim_val, dim_hidden, h)
decoded, passthrough_features = decoder_layer((X, encoder_output))
display(decoded.shape)
display(passthrough_features.shape)

torch.Size([8, 125, 125])

torch.Size([8, 125, 125])

## Transformer Model

In [255]:
class Transformer(nn.Module):
    """N encoder layers plus N decoder layers, followed by a linear + softmax head.

    Args:
        dim_model: feature size shared by every layer.
        dim_key, dim_val: forwarded to each layer's attention blocks.
        dim_hidden: inner width of each feed-forward network.
        output_dim: size of the output distribution (e.g. vocabulary size).
        N: number of encoder layers and of decoder layers.
        h: attention heads per attention block.
    """

    def __init__(self, dim_model, dim_key, dim_val, dim_hidden, output_dim, N=6, h=8):
        super().__init__()
        layer_args = (dim_model, dim_key, dim_val, dim_hidden, h)
        # Pairs are built so layers are constructed interleaved
        # (encoder 0, decoder 0, encoder 1, ...), which fixes the order in
        # which parameters draw from the RNG.
        layer_pairs = [(EncoderLayer(*layer_args), DecoderLayer(*layer_args)) for _ in range(N)]
        self.Encoder = nn.Sequential(*[enc for enc, _ in layer_pairs])
        self.Decoder = nn.Sequential(*[dec for _, dec in layer_pairs])
        # TODO: decide whether to tie this weight to a token embedding (none exists yet).
        self.FC = nn.Linear(in_features=dim_model, out_features=output_dim)
        self.sm = nn.Softmax(dim=2)

    def forward(self, inputs, outputs):
        """Encode `inputs`, decode `outputs` against that encoding, return probabilities.

        NOTE(review): this returns softmax probabilities. `nn.CrossEntropyLoss`
        expects raw logits, so do not combine the two directly when training.
        """
        memory = self.Encoder(inputs)
        # Every DecoderLayer consumes and returns an (x, encoder_features) tuple,
        # which is how nn.Sequential threads the encoder output through the stack.
        decoded = self.Decoder((outputs, memory))[0]
        return self.sm(self.FC(decoded))

### Debug

In [260]:
output_dim = 1000  # vocabulary ("dictionary") size
model = Transformer(dim_model, dim_key, dim_val, dim_hidden, output_dim=output_dim)
model_out = model(inputs=X, outputs=X)
model_out.shape

torch.Size([8, 125, 1000])

## Beam Search

**TODO:** implement beam-search decoding on top of the trained `Transformer`.