In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy as np

In [2]:
class ScaledDotProdAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(Q @ K^T / sqrt(d_k)) @ V``.

    Args:
        dropout: probability of zeroing an attention weight during training.
    """
    def __init__(self, dropout = 0.1):
        super(ScaledDotProdAttention, self).__init__()

        # Wrap the probability in a module: forward() *calls* self.dropout, so
        # storing the bare float fails with "'float' object is not callable".
        # As a module it also follows train()/eval() automatically.
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask = None):
        """Attend ``query`` over ``key`` / ``value``.

        Args:
            query: tensor of shape ``(..., query_len, d_k)``.
            key: tensor of shape ``(..., key_len, d_k)``.
            value: tensor of shape ``(..., key_len, d_v)``.
            mask: optional tensor broadcastable to ``(..., query_len, key_len)``;
                positions where ``mask == 0`` are excluded from attention.

        Returns:
            Tuple ``(output, attention)``: the weighted values of shape
            ``(..., query_len, d_v)`` and the (post-dropout) attention weights
            of shape ``(..., query_len, key_len)``.
        """
        # math.sqrt keeps the scale a plain Python float instead of a numpy scalar.
        scale = math.sqrt(query.size(-1))
        attention_scores = torch.matmul(query, key.transpose(-2, -1)) / scale

        if mask is not None:
            # A large negative score drives the softmax weight to ~0.
            attention_scores = attention_scores.masked_fill(mask == 0, -1e10)

        attention = F.softmax(attention_scores, dim = -1)
        attention = self.dropout(attention)

        return torch.matmul(attention, value), attention

In [3]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention over batch-first inputs.

    Projects query/key/value with separate linear layers, splits the model
    dimension into ``nhead`` heads, runs scaled dot-product attention per head,
    concatenates the heads and applies an output projection plus dropout.

    Args:
        d_model: width of the input and output features.
        nhead: number of attention heads; must divide ``d_model``.
        dropout: dropout probability for the attention weights and the output.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``nhead``.
    """
    def __init__(self, d_model, nhead, dropout = 0.1):
        super(MultiHeadAttention, self).__init__()

        # Without this check the .view() calls in forward() either fail with an
        # opaque shape error or silently regroup features across positions.
        if d_model % nhead != 0:
            raise ValueError(
                "d_model (%d) must be divisible by nhead (%d)" % (d_model, nhead)
            )

        self.d_model = d_model
        self.nhead = nhead
        self.d_k = d_model // nhead
        self.d_v = d_model // nhead

        self.linear_q = nn.Linear(d_model, d_model)
        self.linear_k = nn.Linear(d_model, d_model)
        self.linear_v = nn.Linear(d_model, d_model)

        self.scaled_dot_prod_attention = ScaledDotProdAttention(dropout)

        self.linear_layer = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def _merge_masks(mask, key_padding, key_padding_mask):
        """Fold the optional padding mask into the attention mask (0 = blocked)."""
        padding = key_padding_mask if key_padding_mask is not None else key_padding
        if padding is None:
            return mask

        # (batch, key_len) -> (batch, 1, 1, key_len) so it broadcasts over heads
        # and query positions. Any non-zero entry (True, 1, -inf) marks padding.
        keep = (padding == 0)[:, None, None, :]
        if mask is None:
            return keep
        return (mask != 0) & keep

    def forward(self, query, key, value, mask = None, key_padding = None, key_padding_mask = None):
        """Compute multi-head attention.

        Args:
            query: tensor of shape ``(batch, query_len, d_model)``.
            key: tensor of shape ``(batch, key_len, d_model)``.
            value: tensor of shape ``(batch, key_len, d_model)``.
            mask: optional attention mask broadcastable to
                ``(batch, nhead, query_len, key_len)``; entries equal to 0 are
                blocked (the convention of the attention core's ``mask == 0``).
            key_padding: optional ``(batch, key_len)`` mask whose non-zero
                entries mark padded key positions to be ignored.
            key_padding_mask: alias of ``key_padding`` (the keyword the
                encoder/decoder blocks in this file use); takes precedence when
                both are given.

        Returns:
            Tensor of shape ``(batch, query_len, d_model)``.
        """
        batch_size = query.size(0)

        # (batch, len, d_model) -> (batch, nhead, len, d_head)
        query = self.linear_q(query).view(batch_size, -1, self.nhead, self.d_k).transpose(1,2)
        key = self.linear_k(key).view(batch_size, -1, self.nhead, self.d_k).transpose(1,2)
        value = self.linear_v(value).view(batch_size, -1, self.nhead, self.d_v).transpose(1,2)

        # The masks were previously accepted but never reached the attention core.
        attn_mask = self._merge_masks(mask, key_padding, key_padding_mask)
        output, _ = self.scaled_dot_prod_attention(query, key, value, attn_mask)

        # (batch, nhead, query_len, d_head) -> (batch, query_len, d_model)
        output_concat = output.transpose(1,2).contiguous().view(batch_size, -1, self.d_model)

        output_concat = self.linear_layer(output_concat)

        return self.dropout(output_concat)


In [4]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to batch-first embeddings.

    Args:
        d_model: feature width of the embeddings (odd values are supported).
        dropout: dropout probability applied after adding the encodings.
        max_len: number of positions in the precomputed table.
    """
    def __init__(self, d_model, dropout = 0.1, max_len = 100):
        super(PositionalEncoding, self).__init__()

        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype = torch.float ).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))

        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # trim the cosine block to fit instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles)[:, : d_model // 2]
        # Stored as (max_len, 1, d_model); shape kept so saved buffers still load.
        pe = pe.unsqueeze(0).transpose(0,1)
        self.register_buffer('pe', pe) # to not change these values while training

    def forward(self, x):
        """Add position encodings to ``x`` and apply dropout.

        Args:
            x: tensor of shape ``(batch, seq_len, d_model)`` -- the same
                batch-first layout that ``MultiHeadAttention`` in this file
                assumes (it reads the batch size from dim 0).

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: if ``seq_len`` exceeds the ``max_len`` table size.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(0):
            raise ValueError(
                "sequence length %d exceeds max_len %d" % (seq_len, self.pe.size(0))
            )
        # Slice by sequence length (dim 1), not batch size (dim 0): indexing by
        # x.size(0) gave every token of a sample the same offset, chosen by the
        # sample's batch index. (seq_len, 1, d) -> (1, seq_len, d) broadcasts
        # over the batch.
        x = x + self.pe[:seq_len].transpose(0, 1)
        return self.dropout(x)
        

In [5]:
class PoswiseFeedForward(nn.Module):
    """Two-layer feed-forward network applied along the last dimension.

    Computes ``linear_2(dropout(relu(linear_1(x))))``.

    Args:
        d_model: size of the input and output features.
        d_mlp: size of the hidden layer.
        dropout: dropout probability applied after the ReLU.
    """
    def __init__(self, d_model, d_mlp = 1024, dropout = 0.1):
        super().__init__()

        # Expand to the hidden width, regularize, then project back.
        self.linear_1 = nn.Linear(in_features = d_model, out_features = d_mlp)
        self.dropout = nn.Dropout(p = dropout)
        self.linear_2 = nn.Linear(in_features = d_mlp, out_features = d_model)

    def forward(self, x):
        """Return the transformed tensor; only the last dimension is mixed."""
        hidden = torch.relu(self.linear_1(x))
        return self.linear_2(self.dropout(hidden))
    

In [6]:
class LayerNorm(nn.Module):
    """Normalizes the last dimension, then applies a learned scale and shift.

    Uses ``torch.Tensor.std`` with its default settings and adds ``eps`` to the
    standard deviation itself (not to the variance under a square root).

    Args:
        d_model: size of the normalized (last) dimension.
        eps: constant added to the standard deviation to avoid division by zero.
    """
    def __init__(self, d_model, eps = 1e-5):
        super().__init__()

        self.eps = eps
        # Learned per-feature scale (starts at 1) and shift (starts at 0).
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))

    def forward(self, x):
        """Return ``gamma * (x - mean) / (std + eps) + beta`` over the last dim."""
        centered = x - x.mean(dim = -1, keepdim = True)
        spread = x.std(dim = -1, keepdim = True) + self.eps
        normalized = centered / spread
        return self.gamma * normalized + self.beta
                                                                                        

In [7]:
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention followed by a feed-forward network.

    Each sublayer's output is layer-normalized, passed through dropout and
    added back to the sublayer's input.

    Args:
        d_model: feature width of the layer input and output.
        nhead: number of attention heads.
        d_mlp: hidden width of the feed-forward network.
        dropout: dropout probability used throughout the layer.
    """
    def __init__(self, d_model, nhead, d_mlp, dropout = 0.1):
        super(EncoderBlock, self).__init__()

        self.multi_head_attention = MultiHeadAttention(d_model, nhead, dropout)

        self.feed_forward = PoswiseFeedForward(d_model, d_mlp, dropout)

        self.layer_norm1 = LayerNorm(d_model)
        self.layer_norm2 = LayerNorm(d_model)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, src_mask = None, src_key_padding_mask = None, is_causal = False):
        """Run the encoder layer.

        Args:
            x: input features; the attention module reads the batch size from
                dim 0, so ``(batch, seq_len, d_model)`` is expected.
            src_mask: optional attention mask forwarded to self-attention.
            src_key_padding_mask: optional key-padding mask forwarded to
                self-attention.
            is_causal: accepted but not used by this layer; no causal mask is
                built from it.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Multi-head attention.
        # The masks go positionally: MultiHeadAttention.forward's fifth
        # parameter is not named `key_padding_mask`, so that keyword raised a
        # TypeError. The module returns a single tensor, so indexing it with
        # [0] would have kept only the first batch element.
        x2 = self.multi_head_attention(x, x, x, src_mask, src_key_padding_mask)
        x2 = self.layer_norm1(x2)
        x = x + self.dropout1(x2)
        # Feed-forward network
        x2 = self.feed_forward(x)
        x2 = self.layer_norm2(x2)
        x = x + self.dropout2(x2)

        return x
        

In [8]:
class DecoderBlock(nn.Module):
    """One decoder layer: self-attention, cross-attention, feed-forward.

    Each sublayer's output is layer-normalized, passed through dropout and
    added back to the sublayer's input.

    Args:
        d_model: feature width of the layer input and output.
        nhead: number of attention heads.
        d_mlp: hidden width of the feed-forward network.
        dropout: dropout probability used throughout the layer.
    """
    def __init__(self, d_model, nhead, d_mlp, dropout = 0.1):
        super(DecoderBlock, self).__init__()

        self.masked_multi_head_attention = MultiHeadAttention(d_model, nhead, dropout)
        self.multi_head_attention = MultiHeadAttention(d_model, nhead, dropout)

        self.feed_forward = PoswiseFeedForward(d_model, d_mlp, dropout)

        self.layer_norm1 = LayerNorm(d_model)
        self.layer_norm2 = LayerNorm(d_model)
        self.layer_norm3 = LayerNorm(d_model)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, tgt, memory, tgt_mask = None, memory_mask = None, tgt_key_padding_mask = None, memory_key_padding_mask = None, tgt_is_causal = False, memory_is_causal = False):
        """Run the decoder layer.

        Args:
            tgt: target-side features; the attention modules read the batch
                size from dim 0, so ``(batch, tgt_len, d_model)`` is expected.
            memory: encoder output attended to by the cross-attention sublayer.
            tgt_mask: optional attention mask for target self-attention.
            memory_mask: optional attention mask for cross-attention.
            tgt_key_padding_mask: optional key-padding mask for the target.
            memory_key_padding_mask: optional key-padding mask for ``memory``.
            tgt_is_causal: accepted but not used. ``nn.TransformerDecoder``
                passes this keyword to every layer in PyTorch 2.x, so omitting
                it from the signature makes the stacked decoder raise TypeError.
            memory_is_causal: accepted but not used; same reason as above.

        Returns:
            Tensor with the same shape as ``tgt``.
        """
        # Masked multi-head attention.
        # The masks go positionally: MultiHeadAttention.forward's fifth
        # parameter is not named `key_padding_mask`, so that keyword raised a
        # TypeError. The module returns a single tensor, so indexing it with
        # [0] would have kept only the first batch element.
        tgt2 = self.masked_multi_head_attention(tgt, tgt, tgt, tgt_mask, tgt_key_padding_mask)
        tgt2 = self.layer_norm1(tgt2)
        tgt = tgt + self.dropout1(tgt2)
        # Multi-head attention with encoder output
        tgt2 = self.multi_head_attention(tgt, memory, memory, memory_mask, memory_key_padding_mask)
        tgt2 = self.layer_norm2(tgt2)
        tgt = tgt + self.dropout2(tgt2)
        # Feed-forward network
        tgt2 = self.feed_forward(tgt)
        tgt2 = self.layer_norm3(tgt2)
        tgt = tgt + self.dropout3(tgt2)

        return tgt

In [9]:
class TransformerModel(nn.Module):
    """Encoder-decoder transformer producing per-token vocabulary logits.

    Args:
        d_model: embedding / hidden width.
        nhead: number of attention heads per attention module.
        n_encoder: number of stacked encoder layers.
        n_decoder: number of stacked decoder layers.
        d_mlp: hidden width of each feed-forward network.
        max_len: number of positions in the positional-encoding table.
        vocab_size: number of rows in the embedding tables and size of the
            output logits.
        pad_idx: index of the padding token in both embedding tables.
        dropout: dropout probability used throughout the model.
    """
    def __init__(self, d_model, nhead, n_encoder, n_decoder, d_mlp, max_len, vocab_size, pad_idx, dropout = 0.1):
        super(TransformerModel, self).__init__()

        self.d_model = d_model

        # NOTE(review): nn.TransformerEncoder / nn.TransformerDecoder are built
        # around PyTorch's own layer classes; how they treat custom layers
        # (attributes read from the first layer, mask preprocessing) differs
        # between PyTorch releases -- verify against the installed version.

        # Encoder
        encoder_layer = EncoderBlock(d_model, nhead, d_mlp, dropout)
        encoder_norm = LayerNorm(d_model)
        self.encoder = nn.TransformerEncoder(encoder_layer, n_encoder, encoder_norm)

        # Decoder
        decoder_layer = DecoderBlock(d_model, nhead, d_mlp, dropout)
        decoder_norm = LayerNorm(d_model)
        self.decoder = nn.TransformerDecoder(decoder_layer, n_decoder, decoder_norm)

        # Positional Encoding (one module shared by source and target)
        self.pos_encoder = PositionalEncoding(d_model, dropout, max_len)

        # Embedding layers for input and output
        self.input_embedding = nn.Embedding(vocab_size, d_model, padding_idx = pad_idx)
        self.output_embedding = nn.Embedding(vocab_size, d_model, padding_idx = pad_idx)

        # Final Linear Layer
        self.linear = nn.Linear(d_model, vocab_size)


    def forward(self, src, output, src_mask = None, tgt_mask = None, src_key_padding_mask = None, tgt_key_padding_mask = None, memory_key_padding_mask = None, is_causal = False, memory_mask = None):
        """Encode ``src``, decode ``output`` against it, and return logits.

        Args:
            src: source token indices.
            output: target-side token indices fed to the decoder.
            src_mask: optional attention mask for encoder self-attention.
            tgt_mask: optional attention mask for decoder self-attention.
            src_key_padding_mask: optional padding mask for the source.
            tgt_key_padding_mask: optional padding mask for the target.
            memory_key_padding_mask: optional padding mask for the encoder
                output as seen by decoder cross-attention.
            is_causal: accepted but not used by this method.
            memory_mask: optional attention mask for decoder cross-attention.
                When omitted, ``src_mask`` is reused (the previous behavior),
                which is only shape-compatible when source and target lengths
                match; pass an explicit mask otherwise.

        Returns:
            Tensor of logits whose last dimension has size ``vocab_size``.
        """
        # Plain Python float; avoids mixing a numpy scalar into tensor math.
        embed_scale = math.sqrt(self.d_model)

        src = self.input_embedding(src) * embed_scale
        src = self.pos_encoder(src)

        encoder_outputs = self.encoder(src, mask = src_mask, src_key_padding_mask = src_key_padding_mask)

        output = self.output_embedding(output) * embed_scale
        output = self.pos_encoder(output)

        # Cross-attention scores are (target positions x source positions), so
        # the encoder's self-attention mask is only a fallback here.
        cross_mask = memory_mask if memory_mask is not None else src_mask

        decoder_outputs = self.decoder(output, encoder_outputs, tgt_mask = tgt_mask, memory_mask = cross_mask, tgt_key_padding_mask = tgt_key_padding_mask, memory_key_padding_mask = memory_key_padding_mask)

        outputs = self.linear(decoder_outputs)

        return outputs
        

In [10]:
# Hyperparameters for the model built in the next cell.
d_model = 512                             # embedding / hidden width
nhead = 1                                 # attention heads per attention module
n_encoder_layers = n_decoder_layers = 1   # depth of each stack
d_mlp = 1024                              # feed-forward hidden width
max_len = 6                               # rows in the positional-encoding table
# len() on a str already counts characters, so no intermediate list is needed.
# This is the raw string length (repeated characters included).
vocab_size = len("sdhfbdsksdkfsdfhsdfgiyasegfdhgdfgerwer")
pad_idx = 0                               # padding token index for the embeddings
dropout = 0.1
                 

In [11]:
# Build the model from the hyperparameters defined in the previous cell.
model = TransformerModel(
    d_model = d_model,
    nhead = nhead,
    n_encoder = n_encoder_layers,
    n_decoder = n_decoder_layers,
    d_mlp = d_mlp,
    max_len = max_len,
    vocab_size = vocab_size,
    pad_idx = pad_idx,
    dropout = dropout,
)

