In [1]:
import math
import copy

import torch
from torch import nn
from torch import optim
from torch.utils import data
from torch.nn import functional as F

In [2]:
def get_device():
    # if torch.cuda.is_available():
    #     return torch.device("cuda")
    # elif torch.backends.mps.is_available():
    #     return torch.device("mps")
    return torch.device("cpu")

In [3]:
class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        input_dim,
        num_heads,
    ):
        super(MultiHeadAttention, self).__init__()
        # asserting if input dimension is divisible by number of heads
        assert (
            input_dim % num_heads == 0
        ), "input dimensions must be divisible by number of heads"

        self.input_dim = input_dim
        self.num_heads = num_heads
        self.dim_key_query_value = input_dim // num_heads

        # linear layers for transforming inputs
        self.W_q = nn.Linear(self.input_dim, self.input_dim)  # query transformation
        self.W_k = nn.Linear(self.input_dim, self.input_dim)  # key transformation
        self.W_v = nn.Linear(self.input_dim, self.input_dim)  # value transformation
        self.W_o = nn.Linear(self.input_dim, self.input_dim)  # output transformation

    def split_heads(
        self,
        x,
    ):
        batch_size, sequence_len, input_dim = x.shape
        return x.view(
            batch_size, sequence_len, self.num_heads, self.dim_key_query_value
        ).transpose(1, 2)

    def scaled_dot_product_attention(
        self,
        query,
        key,
        value,
        mask=None,
    ):
        # this calculation shows scores for how a word is relevant to others in the sentence
        attention_scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(
            self.dim_key_query_value
        )
        if mask is not None:
            # this basically masks the words that represents future words
            attention_scores = attention_scores.masked_fill(mask == 0, float("-inf"))
        # this calculation shows probabilities for how a word is relevant to others in the sentence
        attention_probabilities = torch.softmax(attention_scores, dim=-1)
        output = torch.matmul(attention_probabilities, value)
        return output

    def combine_heads(
        self,
        x,
    ):
        batch_size, num_heads, sequence_len, dim_key_query_value = x.shape
        return (
            x.transpose(1, 2).contiguous().view(batch_size, sequence_len, self.input_dim)
        )

    def forward(
        self,
        query,
        key,
        value,
        mask=None,
    ):
        # apply linear transformation and split heads
        query = self.split_heads(self.W_q(query))
        key = self.split_heads(self.W_k(key))
        value = self.split_heads(self.W_v(value))

        # scaled dot-product attention
        attention_output = self.scaled_dot_product_attention(query, key, value, mask)
        output = self.W_o(self.combine_heads(attention_output))
        return output

In [4]:
class PositionWiseFeedForward(nn.Module):
    def __init__(
        self,
        input_dim,
        hidden_size,
    ):
        super(PositionWiseFeedForward, self).__init__()
        self.ff_layer = nn.Sequential(
            nn.Linear(input_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, input_dim),
        )

    def forward(
        self,
        x,
    ):
        return self.ff_layer(x)

In [5]:
class PositionalEncoding(nn.Module):
    def __init__(
        self,
        input_dim,
        max_sequence_len,
    ):
        super(PositionalEncoding, self).__init__()
        position_encodings = torch.zeros(max_sequence_len, input_dim)
        positions = torch.arange(0, max_sequence_len, dtype=torch.float32).unsqueeze(
            dim=1
        )
        div_term = torch.exp(
            torch.arange(0, input_dim, 2).float() * -(math.log(10000.0) / input_dim)
        )
        position_encodings[:, 0::2] = torch.sin(positions * div_term)
        position_encodings[:, 1::2] = torch.cos(positions * div_term)
        self.register_buffer("position_encodings", position_encodings.squeeze(dim=0))

    def forward(
        self,
        x,
    ):
        # get the positional encodings as per the shape of the input, it should not exceed the input_dim
        return x + self.position_encodings[: x.size(1)]

In [6]:
class EncoderLayer(nn.Module):
    def __init__(
        self,
        input_dim,
        num_heads,
        ff_hidden_dim,
        dropout=0.2,
    ):
        super(EncoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(input_dim, num_heads)
        self.feed_forward = PositionWiseFeedForward(input_dim, ff_hidden_dim)
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x,
        mask,
    ):
        attention_output = self.self_attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attention_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x

In [7]:
class DecoderLayer(nn.Module):
    def __init__(
        self,
        input_dim,
        num_heads,
        ff_hidden_dim,
        dropout=0.2,
    ):
        super(DecoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(input_dim, num_heads)
        self.cross_attention = MultiHeadAttention(input_dim, num_heads)
        self.feed_forward = PositionWiseFeedForward(input_dim, ff_hidden_dim)
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)
        self.norm3 = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x,
        encoded_output,
        source_mask,
        target_mask,
    ):
        attention_output = self.self_attention(x, x, x, target_mask)
        x = self.norm1(x + self.dropout(attention_output))
        attention_output = self.cross_attention(
            x, encoded_output, encoded_output, source_mask
        )
        x = self.norm2(x + self.dropout(attention_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x

In [8]:
class Transformer(nn.Module):
    def __init__(
        self,
        source_vocab_size,
        target_vocab_size,
        input_dim,
        num_heads,
        num_layers,
        ff_hidden_dim,
        max_seq_length,
        dropout,
        device,
    ):
        super(Transformer, self).__init__()
        self.device = device
        self.encoder_embeddings = nn.Embedding(source_vocab_size, input_dim)
        self.decoder_embeddings = nn.Embedding(target_vocab_size, input_dim)
        self.positional_encoding = PositionalEncoding(input_dim, max_seq_length)
        self.encoder_layers = nn.ModuleList(
            [
                EncoderLayer(input_dim, num_heads, ff_hidden_dim, dropout)
                for _ in range(num_layers)
            ]
        )
        self.decoder_layers = nn.ModuleList(
            [
                DecoderLayer(input_dim, num_heads, ff_hidden_dim, dropout)
                for _ in range(num_layers)
            ]
        )
        self.fc_layer = nn.Linear(input_dim, target_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def generate_mask(
        self,
        src,
        tgt,
    ):
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
        seq_length = tgt.size(1)
        nopeak_mask = torch.tril(
            torch.ones(1, seq_length, seq_length), diagonal=1
        ).bool().to(self.device)
        tgt_mask = tgt_mask & nopeak_mask
        return src_mask, tgt_mask

    def forward(
        self,
        source,
        target,
    ):
        source_mask, target_mask = self.generate_mask(source, target)
        source_embedded = self.dropout(
            self.positional_encoding(self.encoder_embeddings(source))
        )
        target_embedded = self.dropout(
            self.positional_encoding(self.decoder_embeddings(target))
        )
        encoded_output = source_embedded
        for encode_layer in self.encoder_layers:
            encoded_output = encode_layer(
                encoded_output,
                source_mask,
            )

        decoded_output = target_embedded
        for decode_layer in self.decoder_layers:
            decoded_output = decode_layer(
                decoded_output,
                encoded_output,
                source_mask,
                target_mask,
            )
        output = self.fc_layer(decoded_output)
        return output

In [9]:
src_vocab_size = 5000
tgt_vocab_size = 5000
input_dim = 512
num_heads = 8
num_layers = 6
ff_dim = 2048
max_seq_length = 100
dropout = 0.1
batch_size = 64
learning_rate = 1e-3
device = get_device()

In [10]:
transformer = Transformer(
    src_vocab_size,
    tgt_vocab_size,
    input_dim,
    num_heads,
    num_layers,
    ff_dim,
    max_seq_length,
    dropout,
    device,
).to(device)

In [11]:
# sample dummy data
src_data = torch.randint(1, src_vocab_size, (64, max_seq_length)).to(device)
tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length)).to(device)

In [12]:
loss_function = nn.CrossEntropyLoss(ignore_index=0)
optimiser = optim.AdamW(transformer.parameters(), lr=learning_rate)

In [13]:
transformer.train()
for epoch in range(100):
    optimiser.zero_grad()
    output = transformer(src_data, tgt_data[:, :-1])
    loss = loss_function(output.reshape(-1, output.shape[-1]), tgt_data[:, 1:].reshape(-1))
    loss.backward()
    optimiser.step()
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

Epoch: 1, Loss: 8.688248634338379
Epoch: 2, Loss: 8.505934715270996
Epoch: 3, Loss: 8.272440910339355
Epoch: 4, Loss: 8.581503868103027
Epoch: 5, Loss: 8.18865966796875
Epoch: 6, Loss: 8.151589393615723
Epoch: 7, Loss: 8.118162155151367
Epoch: 8, Loss: 8.092368125915527
Epoch: 9, Loss: 8.048149108886719
Epoch: 10, Loss: 7.892206192016602
Epoch: 11, Loss: 7.7260212898254395
Epoch: 12, Loss: 8.1525239944458
Epoch: 13, Loss: 8.063343048095703
Epoch: 14, Loss: 7.964765548706055
Epoch: 15, Loss: 7.694516658782959
Epoch: 16, Loss: 7.56680154800415
Epoch: 17, Loss: 7.284770965576172
Epoch: 18, Loss: 7.541484355926514
Epoch: 19, Loss: 7.086330890655518
Epoch: 20, Loss: 6.932593822479248
Epoch: 21, Loss: 6.794715404510498
Epoch: 22, Loss: 6.4900054931640625
Epoch: 23, Loss: 6.193963050842285
Epoch: 24, Loss: 5.933253288269043
Epoch: 25, Loss: 5.922804832458496
Epoch: 26, Loss: 6.284256935119629
Epoch: 27, Loss: 6.0267252922058105
Epoch: 28, Loss: 5.376728534698486
Epoch: 29, Loss: 5.00747632980

In [14]:
transformer.eval()
val_src_data = torch.randint(1, src_vocab_size, (64, max_seq_length)).to(device)
val_tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length)).to(device)

with torch.no_grad():
    val_output = transformer(val_src_data, val_tgt_data[:, :-1])
    val_loss = loss_function(val_output.contiguous().view(-1, tgt_vocab_size), val_tgt_data[:, 1:].contiguous().view(-1))
    print(f"Validation Loss: {val_loss.item()}")

Validation Loss: 11.252229690551758
