In [166]:
import torch
from torch import nn
import math
torch.__version__

'2.0.1+cpu'

In [167]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cpu


In [168]:
#Hyperparameters
num_heads = 8
embed_len = 512
batch_size = 8
stack_len = 6
dropout = 0.1

output_vocab_size = 7000 #just for testing
input_vocab_size = 7000 #just for testing

Bulid the embedding block

In [169]:
class InputEmbedding(nn.Module):
    def __init__(self, input_vocab_size = input_vocab_size, embed_len = embed_len, dropout = dropout, device = device):
        super(InputEmbedding, self).__init__()
        self.input_vocab_size = input_vocab_size
        self.embed_len = embed_len
        self.dropout = dropout
        self.device = device

        self.firstEmbedding = nn.Embedding(self.input_vocab_size, self.embed_len)
        self.secondEmbedding = nn.Embedding(self.input_vocab_size, self.embed_len)
        self.dropoutLayer = nn.Dropout()

    def forward(self, input):
        first_embedding = self.firstEmbedding(input)
        batch_size, seq_len = input.shape

        positions_vector = torch.arange(0, seq_len).expand(batch_size, seq_len).to(device)
        positional_encoding = self.secondEmbedding(positions_vector)
        return self.dropoutLayer(first_embedding + positional_encoding)

In [170]:
input = torch.randint(10, (8,20)).to(device)

embedding = InputEmbedding().to(device)
out = embedding(input).to(device)
out.shape

torch.Size([8, 20, 512])

Building the Scaled Dot Product block

In [171]:
class ScaledDotProduct(nn.Module):
    def __init__(self, embed_len = embed_len, mask = None):
        super(ScaledDotProduct,self).__init__()
        self.embed_len = embed_len
        self.mask = mask
        self.dk = embed_len # dimension of keys and queries

        self.softmax = nn.Softmax(dim= 3)

    def forward(self, queries, keys, values):
        compatibility = torch.matmul(queries, torch.transpose(keys,2,3))
        compatibility = compatibility/math.sqrt(self.dk)

        compatibility = self.softmax(compatibility)

        # apply a mask for the decoder
        if self.mask is not None:
            compatibility = torch.tril(compatibility)


        return torch.matmul(compatibility, torch.transpose(values, 1, 2))


Building MultiHeaded implementation block.

In [173]:
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads = num_heads, embed_len = embed_len, batch_size= batch_size, mask = None):
        super(MultiHeadAttention, self).__init__()
        self.num_heads =num_heads
        self.embed_len = embed_len
        self.batch_size =  batch_size
        self.mask = mask
        self.head_length = int(self.embed_len/ self.num_heads)

        self.q_in = self.v_in = self.k_in = self.embed_len

        #Linear layers as input to the multihead attention
        self.q_linear = nn.Linear(int(self.q_in), int(self.q_in))
        self.k_linear = nn.Linear(int(self.k_in), int(self.k_in))
        self.v_linear = nn.Linear(int(self.v_in), int(self.v_in))

        #Activate mask for the decoder
        if self.mask is not None:
            self.attention = ScaledDotProduct(mask=True)
        else:
            self.attention =ScaledDotProduct()

        self.output_linear = nn.Linear(self.q_in, self.q_in)

    
    def forward(self, queries, keys, values):

        # queries shape = (8, 20, 512)
        # we need to reshape (batch_size, seq-len, embed-len) to (batch_size, seq_len, num_heads, head_length)
        # output should be reshaped into (8, 20, 8, 64)
        #after transpose the output would become(8,8,20,64)
        queries = self.q_linear(queries).reshape(
            self.batch_size, -1, self.num_heads, self.head_length)
        queries = queries.transpose(1,2)

        #keys shape = (batch_size, num_heads, seq_len, head_length)
        keys =self.k_linear(keys).reshape(
            self.batch_size, -1, self.num_heads, self.head_length)
        keys = keys.transpose(1,2)

        # values shape = (batch_size, num_heads, seq_len, head_lengh)
        values = self.v_linear(values).reshape(
            self.batch_size, -1, self.num_heads, self.head_length)
        # QK result dimension(batch_size, num_heads, seq_len) -> (8,8,20,20)
        # QK matmul V -> (8,8,20,20)  matmul (8,8,20,64) -> (8,8,20,64)
        # final shape should be (batch_size, seq_len, embed_len)
        sdp_output = self.attention.forward(queries, keys, values).transpose(1,2).reshape(
            self.batch_size, -1, self.num_heads*self.head_length
        )

        #output has size(8, 20,512)
        return self.output_linear(sdp_output)

In [174]:
class EncoderBlock(nn.Module):
    def __init__(self,embed_len= embed_len, dropout =
                 dropout):
        super(EncoderBlock, self).__init__()
        self.embed_len = embed_len
        self.dropout =dropout
        self.multihead =MultiHeadAttention()
        self.firstNorm = nn.LayerNorm(self.embed_len)
        self.secondNorm = nn.LayerNorm(self.embed_len)
        self.dropoutLayer = nn.Dropout(p = self.dropout)

        self.feedForward = nn.Sequential(
            nn.Linear(self.embed_len, self.embed_len*4),
            nn.ReLU(),            
            nn.Linear(self.embed_len*4, self.embed_len*4)
        )


    def forward(self, queries, keys, values):
        attention_output = self.multihead.forward(queries, keys, values)
        attention_output = self.dropoutLayer(attention_output)
        first_sublayer_output = self.firstNorm(attention_output +queries)

        ff_output = self.feedForward(first_sublayer_output)
        ff_output = self.dropoutLayer(ff_output)

        return self.secondNorm(ff_output + first_sublayer_output)

In [179]:
class DecoderBlock(nn.Module):
    def __init__(self,embed_len= embed_len, dropout =
                 dropout):
        super(DecoderBlock, self).__init__()
        self.embed_len = embed_len
        self.dropout =dropout
        self.multihead =MultiHeadAttention( mask= True)
        self.firstNorm = nn.LayerNorm(self.embed_len)
        self.dropoutLayer = nn.Dropout(p = self.dropout)

        self.encoderBlock = EncoderBlock()

    def forward(self, queries, keys, values):
        masked_multihead_output = self.multihead.forward(queries, queries, queries)
        masked_multihead_output = self.dropoutLayer(masked_multihead_output)
        first_sublayer_output = self.firstNorm(masked_multihead_output +queries)

        return self.encoderBlock(first_sublayer_output, keys, values)

implement full transformer

In [180]:
class Transformer(nn.Module):
    def __init__(self, embed_len= embed_len, stack_len = stack_len, device= device, output_vocab_size=output_vocab_size):
        super(Transformer, self).__init__()

        self.embed_len= embed_len
        self.stack_len = stack_len
        self.device= device
        self.output_vocab_size=output_vocab_size

        self.embedding = InputEmbedding().to(self.device)
        self.encstack = nn.ModuleList(EncoderBlock() for i in range(self.stack_len)).to(self.device)
        self.decstack = nn.ModuleList(DecoderBlock() for i in range(self.stack_len)).to(self.device)

        self.finalLinear = nn.Linear(self.embed_len, self.output_vocab_size).to(device)
        self.softmax = nn.Softmax()

    
    def forward(self, test_input, test_output):
        enc_output = self.embedding.forward(test_input)

        for enc_layer in self.encstack:
            enc_output = enc_layer.forward(enc_output,enc_output,enc_output)

        dec_output = self.embedding(test_output)
        for dec_layer in self.decstack:
            dec_output = dec_layer.forward(dec_output,enc_output,enc_output)

        final_output =self.finalLinear(dec_output)

        return self.softmax(final_output)


In [181]:
input_token = torch.randint(10, (batch_size, 30)).to(device)
output_target = torch.randint(10, (batch_size, 20)).to(device)

transformer = Transformer().to(device)
transformer_output = transformer.forward(input_token, output_target)

RuntimeError: The size of tensor a (2048) must match the size of tensor b (512) at non-singleton dimension 2