In [3]:
#pip3 install torch torchvision torchaudio

In [19]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy

In [20]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  The model dimension is split evenly across ``n_heads`` heads. Queries, keys
  and values are each passed through their own linear projection, attention is
  computed independently per head, and the concatenated heads are passed
  through a final output projection.

  Args:
    model_dimension: Size of the last axis of the inputs and of the output.
    n_heads: Number of attention heads; must divide ``model_dimension``.
  """

  def __init__(self,model_dimension,n_heads):
    super(MultiHeadAttention,self).__init__()
    assert model_dimension % n_heads == 0, "model_dimension must be divisible by n_heads"

    #initialize the dimension
    self.model_dimension = model_dimension
    self.n_heads = n_heads
    #now, we determine the dimension of each head
    self.dimension = model_dimension // n_heads

    #we use linear layers for transforming inputs into Q,K,V and output
    self.q = nn.Linear(model_dimension,model_dimension)
    self.k = nn.Linear(model_dimension,model_dimension)
    self.v = nn.Linear(model_dimension,model_dimension)
    self.o = nn.Linear(model_dimension,model_dimension)

  def scaled_dot_product_attention(self,Q,K,V,mask=None):
    """Return softmax(Q K^T / sqrt(head_dim)) V.

    Positions where ``mask == 0`` receive a score of -1e9 before the softmax,
    so they get (numerically) zero attention weight. ``mask`` must be
    broadcastable to the score tensor.
    """
    attn_scores = torch.matmul(Q,K.transpose(-2,-1))/math.sqrt(self.dimension)
    if mask is not None:
      # masked_fill is out-of-place: the result has to be assigned back,
      # otherwise the mask is silently ignored.
      attn_scores = attn_scores.masked_fill(mask==0,-1e9)
    attn_prob = torch.softmax(attn_scores,dim=-1)
    output = torch.matmul(attn_prob,V)
    return output

  def split_heads(self,x):
    """Reshape (batch, seq, model_dimension) -> (batch, n_heads, seq, head_dim)."""
    batch_size, seq_length, _ = x.size()
    return x.view(batch_size,seq_length,self.n_heads,self.dimension).transpose(1,2)

  def combine_heads(self,x):
    """Inverse of split_heads: (batch, n_heads, seq, head_dim) -> (batch, seq, model_dimension)."""
    batch_size, _, seq_length, _ = x.size()
    # transpose() yields a non-contiguous tensor; view() needs contiguous memory.
    return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.model_dimension)

  def forward(self,Q,K,V,mask=None):
    """Apply multi-head attention.

    Args:
      Q, K, V: Tensors of shape (batch, seq, model_dimension). K and V must
        share their sequence length; Q may have a different one.
      mask: Optional tensor broadcastable to (batch, n_heads, q_len, k_len);
        entries equal to 0 are excluded from attention.

    Returns:
      Tensor of shape (batch, q_len, model_dimension).
    """
    # Each input goes through its own projection before being split into heads.
    Q = self.split_heads(self.q(Q))
    K = self.split_heads(self.k(K))
    V = self.split_heads(self.v(V))

    attn_output = self.scaled_dot_product_attention(Q,K,V,mask)
    output = self.o(self.combine_heads(attn_output))
    return output

In [21]:
class PositionWiseFeedForward(nn.Module):
  """Two-layer feed-forward network: Linear -> ReLU -> Linear.

  Args:
    model_dimension: Size of the last axis of the input and of the output.
    ff_dimension: Size of the hidden (inner) layer.
  """

  def __init__(self,model_dimension,ff_dimension):
    super().__init__()
    # expand to the inner width, then project back to the model width
    self.fc1 = nn.Linear(in_features=model_dimension, out_features=ff_dimension)
    self.fc2 = nn.Linear(in_features=ff_dimension, out_features=model_dimension)
    self.relu = nn.ReLU()

  def forward(self,x):
    """Return fc2(relu(fc1(x))); the last axis of x must be model_dimension."""
    hidden = self.fc1(x)
    activated = self.relu(hidden)
    return self.fc2(activated)


In [22]:
class PositionalEncoding(nn.Module):
  """Adds fixed sinusoidal position information to a sequence of embeddings.

  A table of shape (1, max_seq_length, model_dimension) is precomputed once
  and stored as the non-trainable buffer ``pe``.

  Args:
    model_dimension: Size of the embedding axis (even or odd).
    max_seq_length: Number of positions for which encodings are precomputed.
  """

  def __init__(self,model_dimension,max_seq_length):
    super(PositionalEncoding, self).__init__()
    pe = torch.zeros(max_seq_length,model_dimension)
    position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, model_dimension, 2).float() * -(math.log(10000.0) / model_dimension))
    angles = position * div_term

    #The sine function is applied to the even indices and the cosine function to the odd indices of pe.
    pe[:, 0::2] = torch.sin(angles)
    # For an odd model_dimension there is one fewer odd column than even
    # column, so the cosine part is trimmed to fit (no-op for even sizes).
    pe[:, 1::2] = torch.cos(angles[:, :model_dimension // 2])
    # register_buffer: stored with the module but not a trainable parameter.
    self.register_buffer('pe',pe.unsqueeze(0))

  def forward(self,x):
    """Return x plus the encodings of the first x.size(1) positions.

    x has shape (batch, seq, model_dimension) with seq <= max_seq_length.
    """
    return x + self.pe[:,:x.size(1)]

In [23]:
class EncoderLayer(nn.Module):
  """One encoder block: self-attention then a feed-forward network.

  Each sub-layer output goes through dropout, is added to the sub-layer's
  input (residual connection) and is then layer-normalised.

  Args:
    model_dimension: Size of the feature axis.
    n_heads: Number of attention heads.
    ff_dimension: Hidden size of the feed-forward network.
    dropout: Dropout probability applied to each sub-layer output.
  """

  def __init__(self,model_dimension,n_heads,ff_dimension,dropout):
    super().__init__()
    self.self_attn = MultiHeadAttention(model_dimension,n_heads)
    self.feed_forward = PositionWiseFeedForward(model_dimension,ff_dimension)
    self.norm1 = nn.LayerNorm(model_dimension)
    self.norm2 = nn.LayerNorm(model_dimension)
    self.dropout = nn.Dropout(dropout)

  def forward(self,x,mask):
    """Run the block on x; mask is forwarded to the self-attention."""
    # sub-layer 1: self-attention (queries, keys and values are all x)
    attended = self.self_attn(x,x,x,mask)
    after_attention = self.norm1(x + self.dropout(attended))

    # sub-layer 2: position-wise feed-forward
    transformed = self.feed_forward(after_attention)
    return self.norm2(after_attention + self.dropout(transformed))


Now, we will be working on the decoder blocks

In [24]:
class DecoderLayer(nn.Module):
  """One decoder block: self-attention, cross-attention, feed-forward.

  Each sub-layer output goes through dropout, is added to the sub-layer's
  input (residual connection) and is then layer-normalised.

  Args:
    model_dimension: Size of the feature axis.
    n_heads: Number of attention heads.
    ff_dimension: Hidden size of the feed-forward network.
    dropout: Dropout probability applied to each sub-layer output.
  """

  def __init__(self,model_dimension,n_heads,ff_dimension,dropout):
    super().__init__()
    self.self_attn = MultiHeadAttention(model_dimension,n_heads)
    self.cross_attn = MultiHeadAttention(model_dimension,n_heads)
    self.feed_forward = PositionWiseFeedForward(model_dimension,ff_dimension)
    self.norm1 = nn.LayerNorm(model_dimension)
    self.norm2 = nn.LayerNorm(model_dimension)
    self.norm3 = nn.LayerNorm(model_dimension)
    self.dropout = nn.Dropout(dropout)

  def forward(self,x,encoder_output,src_mask,target_mask):
    """Run the block.

    Args:
      x: Decoder-side input.
      encoder_output: Encoder result, used as keys and values in cross-attention.
      src_mask: Mask passed to the cross-attention.
      target_mask: Mask passed to the decoder self-attention.
    """
    # sub-layer 1: self-attention over the decoder input
    self_attended = self.self_attn(x,x,x,target_mask)
    after_self = self.norm1(x + self.dropout(self_attended))

    # sub-layer 2: queries from the decoder, keys/values from the encoder
    cross_attended = self.cross_attn(after_self,encoder_output,encoder_output,src_mask)
    after_cross = self.norm2(after_self + self.dropout(cross_attended))

    # sub-layer 3: position-wise feed-forward
    transformed = self.feed_forward(after_cross)
    return self.norm3(after_cross + self.dropout(transformed))

Combining the Encoder and Decoder layers to create the complete Transformer network

In [25]:
class Transformer(nn.Module):
  """Encoder-decoder Transformer for sequence-to-sequence modelling.

  Token id 0 is treated as padding when the attention masks are built.

  Args:
    src_vocab_size: Number of entries in the source embedding table.
    tgt_vocab_size: Number of entries in the target embedding table; also the
      size of the output logits.
    model_dimension: Embedding / hidden size used throughout the model.
    n_heads: Number of attention heads per attention module.
    num_layers: Number of encoder layers and of decoder layers.
    ff_dimension: Hidden size of the feed-forward networks.
    max_seq_length: Number of positions covered by the positional encoding.
    dropout: Dropout probability.
  """

  def __init__(self,src_vocab_size,tgt_vocab_size,model_dimension,n_heads,num_layers,ff_dimension,max_seq_length,dropout):
    super(Transformer,self).__init__()
    self.encoder_embedding = nn.Embedding(src_vocab_size,model_dimension)
    self.decoder_embedding = nn.Embedding(tgt_vocab_size,model_dimension)
    self.positional_encoding = PositionalEncoding(model_dimension,max_seq_length)

    # The layer stacks get their own attribute names so that they do not
    # overwrite the embedding tables defined above.
    self.encoder_layers = nn.ModuleList([EncoderLayer(model_dimension,n_heads,ff_dimension,dropout) for _ in range(num_layers)])
    self.decoder_layers = nn.ModuleList([DecoderLayer(model_dimension,n_heads,ff_dimension,dropout) for _ in range(num_layers)])

    #self.fc: Final fully connected (linear) layer mapping to target vocabulary size.
    self.fc = nn.Linear(model_dimension,tgt_vocab_size)
    self.dropout = nn.Dropout(dropout)

  def generate_mask(self, src, tgt):
    """Build the attention masks for a batch of token ids.

    Args:
      src: Integer tensor of shape (batch, src_len).
      tgt: Integer tensor of shape (batch, tgt_len).

    Returns:
      (src_mask, tgt_mask): boolean tensors of shape (batch, 1, 1, src_len)
      and (batch, 1, tgt_len, tgt_len). src_mask is False at padding (id 0)
      positions. tgt_mask is False for padded target rows and for every
      position that lies after the current one.
    """
    src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
    tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
    seq_length = tgt.size(1)
    # Lower-triangular "no peek" mask, created on the same device as tgt so
    # that the `&` below also works when the inputs are not on the CPU.
    nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length, device=tgt.device), diagonal=1)).bool()
    tgt_mask = tgt_mask & nopeak_mask
    return src_mask, tgt_mask

  def forward(self,src,tgt):
    """Compute target-vocabulary logits.

    Args:
      src: Integer tensor of source token ids, shape (batch, src_len).
      tgt: Integer tensor of target token ids, shape (batch, tgt_len).

    Returns:
      Tensor of shape (batch, tgt_len, tgt_vocab_size).
    """
    src_mask, tgt_mask = self.generate_mask(src, tgt)
    src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
    tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

    enc_output = src_embedded
    for enc_layer in self.encoder_layers:
      enc_output = enc_layer(enc_output, src_mask)

    dec_output = tgt_embedded
    for dec_layer in self.decoder_layers:
      dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

    output = self.fc(dec_output)
    return output

Training the Model

In [26]:
# Model hyperparameters
src_vocab_size = 5000
tgt_vocab_size = 5000
model_dimension = 512
n_heads = 8
num_layers = 6
ff_dimension = 2048
max_seq_length = 100
dropout = 0.1

# Seed the global RNG so the synthetic data below (and any weights
# initialised afterwards) are reproducible across kernel restarts.
seed = 42
torch.manual_seed(seed)

# The model is constructed once, in the following cell; building it here as
# well only allocated a second copy that was immediately discarded.

# Generate random sample data. Ids are drawn from [1, vocab_size) so that id 0,
# which the masks and the loss treat as padding, never occurs.
batch_size = 64
src_data = torch.randint(1, src_vocab_size, (batch_size, max_seq_length))  # (batch_size, seq_length)
tgt_data = torch.randint(1, tgt_vocab_size, (batch_size, max_seq_length))  # (batch_size, seq_length)

In [27]:
# Build the model from the hyperparameters defined above.
transformer = Transformer(
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    model_dimension=model_dimension,
    n_heads=n_heads,
    num_layers=num_layers,
    ff_dimension=ff_dimension,
    max_seq_length=max_seq_length,
    dropout=dropout,
)

In [29]:
# Loss and optimiser; target positions whose id is 0 do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

transformer.train()

# The decoder is fed the target without its last token and is trained to
# predict the target shifted left by one. Both slices are the same on every
# iteration, so they are prepared once up front.
decoder_input = tgt_data[:, :-1]
flat_targets = tgt_data[:, 1:].reshape(-1)

for epoch in range(100):
    optimizer.zero_grad()
    output = transformer(src_data, decoder_input)
    # flatten (batch, seq, vocab) -> (batch * seq, vocab) for the loss
    flat_logits = output.reshape(-1, tgt_vocab_size)
    loss = criterion(flat_logits, flat_targets)
    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")