<a href="https://colab.research.google.com/github/zubejda/attention_is_all_you_need_attempt/blob/main/eng_to_cz_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import torch.optim as optim
import numpy as np
import string
import math

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Parallel corpus: eng_lines[i] is the English counterpart of cz_lines[i].
eng_lines = []
cz_lines = []
# Characters trimmed from both ends of every sentence.
chars_to_remove = ',\":``'
# Decode explicitly as UTF-8 so Czech letters do not depend on the platform default encoding.
with open('ces.txt', 'r', encoding='utf-8') as f:
    # Stream the file line by line instead of materialising it with readlines().
    for line in f:
        fields = line.rstrip('\n').split('\t')
        if len(fields) < 2:
            # Blank or malformed line: there is no (english, czech) pair to read.
            continue
        eng, cz = fields[:2]
        # str.strip returns a NEW string (strings are immutable), so the result
        # must be stored; calling it for its side effect does nothing.
        eng_lines.append(eng.strip(chars_to_remove))
        cz_lines.append(cz.strip(chars_to_remove))

# cz_lines = np.array(cz_lines)
# eng_lines = np.array(eng_lines)

In [None]:
# Reserved vocabulary entries: each special token is paired with its fixed integer id.
PAD_TOKEN, PAD_INDEX = '<PAD>', 0  # filler used to bring sequences to a common length
SOS_TOKEN, SOS_INDEX = '<SOS>', 1  # start-of-sequence marker
EOS_TOKEN, EOS_INDEX = '<EOS>', 2  # end-of-sequence marker

In [None]:
class english_czech_dataset(Dataset):
    """Character-level parallel dataset of English sentences and their Czech translations.

    Every sentence is lower-cased, mapped character by character to integer ids,
    and padded or truncated to ``max_length``. The special tokens come from the
    module-level ``PAD_*``, ``SOS_*`` and ``EOS_*`` constants; this class never
    inserts SOS/EOS itself.
    """

    def __init__(self, english_lines, czech_lines, max_length=128,
                 extra_chars='áčďéěíňóřšťúůýž'):
        """
        Args:
            english_lines (list or np.array): List or an array of English lines.
            czech_lines (list or np.array): List or an array of Czech lines. - targets
            max_length (int): Length every encoded sequence is padded/truncated to.
            extra_chars (str): Characters added to the ASCII vocabulary. Defaults to
                the lower-case Czech letters with diacritics, which would otherwise
                be silently dropped by ``encode_string``. Pass ``''`` for ASCII only.

        Raises:
            ValueError: If the two corpora do not have the same number of lines.
        """
        if len(english_lines) != len(czech_lines):
            raise ValueError(
                f"english_lines and czech_lines must be the same length, "
                f"got {len(english_lines)} and {len(czech_lines)}"
            )

        self.english_lines = english_lines
        self.czech_lines = czech_lines
        self.max_length = max_length

        base_chars = string.ascii_lowercase + string.digits + string.punctuation + ' '
        # dict.fromkeys keeps first-seen order while removing duplicates, so a
        # character passed twice still gets exactly one id.
        self.chars = ''.join(dict.fromkeys(base_chars + extra_chars))

        special_tokens = {PAD_TOKEN: PAD_INDEX, SOS_TOKEN: SOS_INDEX, EOS_TOKEN: EOS_INDEX}
        # Character ids start AFTER the highest special id; starting at 1 would
        # make the first characters share ids with <SOS> and <EOS>.
        first_char_index = max(special_tokens.values()) + 1

        self.char_to_idx = dict(special_tokens)
        self.char_to_idx.update(
            {ch: idx + first_char_index for idx, ch in enumerate(self.chars)}
        )
        self.idx_to_char = {idx: ch for ch, idx in self.char_to_idx.items()}

    def __len__(self):
        """Return the total number of samples in the dataset"""
        return len(self.czech_lines)

    def vocab_len(self):
        """Return the vocabulary size (special tokens plus characters)."""
        return len(self.char_to_idx)

    def __getitem__(self, idx):
        """
        Return the ``idx``-th sentence pair as a tuple ``(english, czech)`` of
        int32 tensors, each of length ``max_length``.
        """
        eng_encoded = self.encode_string(self.english_lines[idx])
        cz_encoded = self.encode_string(self.czech_lines[idx])

        eng_padded = self.pad_or_truncate(eng_encoded)
        cz_padded = self.pad_or_truncate(cz_encoded)

        eng_tensor = torch.tensor(eng_padded, dtype=torch.int32)
        cz_tensor = torch.tensor(cz_padded, dtype=torch.int32)

        return eng_tensor, cz_tensor

    def encode_string(self, s):
        """Map the lower-cased string to a list of ids, skipping unknown characters."""
        return [self.char_to_idx[c] for c in s.lower() if c in self.char_to_idx]

    def decode_string(self, indices):
        """Join the vocabulary entries of ``indices`` into a string, skipping unknown ids.

        Special-token ids are rendered as their token text (e.g. ``'<PAD>'``).
        """
        return ''.join([self.idx_to_char[i] for i in indices if i in self.idx_to_char])

    def pad_or_truncate(self, encoded_sequence):
        """
        Pad or truncate a sequence to the specified maximum length.

        Args:
            encoded_sequence (list): List of encoded integers.

        Returns:
            List: Padded or truncated sequence.
        """
        if len(encoded_sequence) > self.max_length:
            return encoded_sequence[:self.max_length]

        return encoded_sequence + [PAD_INDEX] * (self.max_length - len(encoded_sequence))

In [None]:
# Padded sequence length; the notebook later reuses this value as the embedding width too.
max_seq_length = 128
batch_size = 32

dataset = english_czech_dataset(eng_lines, cz_lines, max_length=max_seq_length)
# Sanity check: shape of the first encoded English sentence.
print(dataset[0][0].shape)

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
# print(dataset.decode_string(dataset[0][0].tolist()))

# Pull only the first batch; eng_batch / cz_batch are reused by later cells.
batch_idx, (eng_batch, cz_batch) = next(enumerate(dataloader))
print(eng_batch.shape, cz_batch.shape)

In [None]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding table.

    Even feature columns hold ``sin(pos * w_k)`` and odd columns ``cos(pos * w_k)``
    with ``w_k = 10000 ** (-2k / embedding_dim)``. ``forward`` only RETURNS the
    encoding slice for the input's sequence length; the caller adds it to the
    embeddings.
    """

    def __init__(self, embedding_dim, max_length=128):
        """
        Args:
            embedding_dim: Dimensionality of the embeddings (should match the model's embedding_dim).
            max_length: Maximum length of the input sequence (should cover the longest sequence you expect).
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.max_length = max_length

        position = torch.arange(max_length, dtype=torch.float).unsqueeze(1)
        denominator = torch.exp(
            torch.arange(0, embedding_dim, 2, dtype=torch.float)
            * (-math.log(10000.0) / embedding_dim)
        )
        angles = position * denominator  # [max_length, ceil(embedding_dim / 2)]

        pe = torch.zeros(max_length, embedding_dim)
        pe[:, 0::2] = torch.sin(angles)
        # For an odd embedding_dim there is one fewer odd column than even
        # columns, so drop the surplus frequency for the cosine half.
        pe[:, 1::2] = torch.cos(angles[:, : embedding_dim // 2])

        # Registered as a buffer so the table follows the module through
        # .to()/.cuda(); non-persistent so state_dict keys stay unchanged.
        self.register_buffer('pe', pe.unsqueeze(0), persistent=False)  # [1, max_length, embedding_dim]

    def forward(self, x):
        """Return the encoding for the first ``x.size(1)`` positions.

        Args:
            x: Tensor whose dimension 1 is the sequence length.

        Returns:
            Tensor of shape ``[1, x.size(1), embedding_dim]`` on ``x``'s device.

        Raises:
            ValueError: If the sequence is longer than ``max_length``.
        """
        seq_len = x.size(1)
        if seq_len > self.max_length:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_length {self.max_length}"
            )

        # .to() is a no-op when the buffer already lives on x's device.
        return self.pe[:, :seq_len, :].to(x.device)

In [None]:
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention followed by an output projection, a residual
    connection and layer normalisation.

    NOTE: despite the name, a single attention head over the full embedding is
    computed; ``num_heads`` is stored but not used yet.
    """

    def __init__(self, embedding_dim, num_heads):
        """
        Args:
            embedding_dim: Size of the last dimension of all inputs and of the output.
            num_heads: Requested number of heads (currently unused, see class note).
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads # TODO: implement the variant with multiple heads

        self.q = nn.Linear(self.embedding_dim, self.embedding_dim)
        self.k = nn.Linear(self.embedding_dim, self.embedding_dim)
        self.v = nn.Linear(self.embedding_dim, self.embedding_dim)
        self.linear = nn.Linear(self.embedding_dim, self.embedding_dim)
        self.norm = nn.LayerNorm(self.embedding_dim)

    def forward(self, q_x, k_x, v_x, mask=None):
        """
        Args:
            q_x: Query input, ``[..., query_len, embedding_dim]``.
            k_x: Key input, ``[..., key_len, embedding_dim]``.
            v_x: Value input, ``[..., key_len, embedding_dim]``.
            mask: Optional additive mask broadcastable to
                ``[..., query_len, key_len]``; ``-inf`` entries remove a key.

        Returns:
            ``LayerNorm(q_x + projection(attention))`` with the shape of ``q_x``.
        """
        q_vals = self.q(q_x)
        k_vals = self.k(k_x)
        v_vals = self.v(v_x)

        scores = torch.matmul(q_vals, k_vals.transpose(-2, -1)) / math.sqrt(self.embedding_dim)
        if mask is not None:
            # Out-of-place add: never mutate a tensor autograd may still need.
            scores = scores + mask

        weights = torch.softmax(scores, dim=-1)
        context = torch.matmul(weights, v_vals)
        projected = self.linear(context)

        # The residual belongs to the QUERY stream. Adding v_x is only equivalent
        # for self-attention; for cross-attention it would add the encoder output
        # and fail outright when source and target lengths differ.
        return self.norm(q_x + projected)

In [None]:
class Encoder(nn.Module):
    """Single-layer Transformer-style encoder: embedding + positional encoding,
    one self-attention block, and a position-wise feed-forward block with a
    residual connection and layer normalisation.
    """

    def __init__(self, vocab_size, embedding_size, max_length=128, num_heads=8, masking=False,
                 padding_idx=0):
        """
        Args:
            vocab_size: Amount of characters in the vocabulary.
            embedding_size: Size of the embeddings.
            max_length: Longest sequence the positional encoding covers.
            num_heads: Head count forwarded to the attention block.
            masking: Whether a padding mask is applied before computing softmax.
            padding_idx: Token id treated as padding when building the mask.
        """
        super().__init__()
        self.vocab_size = vocab_size # amount of characters in the vocabulary
        self.embedding_size = embedding_size # size of the embeddings
        self.masking = masking # decide whether a padding mask is applied before computing softmax
        self.max_length = max_length
        self.num_heads = num_heads
        self.padding_idx = padding_idx

        self.embedding = nn.Embedding(self.vocab_size, self.embedding_size)
        self.pos_encoding = PositionalEncoding(self.embedding_size, max_length=self.max_length)
        # Forward the constructor argument instead of a hard-coded head count.
        self.att_block = MultiHeadAttention(self.embedding_size, self.num_heads)
        self.fc1 = nn.Linear(self.embedding_size, self.embedding_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(self.embedding_size, self.embedding_size)
        self.ff_net = nn.Sequential(
            self.fc1,
            self.relu,
            self.fc2
        ) # feedforward network on top of the attention layer
        self.norm = nn.LayerNorm(self.embedding_size)

    def forward(self, x):
        """
        Args:
            x: Integer token ids, ``[batch, seq_len]`` or a single ``[seq_len]`` sequence.

        Returns:
            Tuple ``(encoder_output, mask)``: the encoded sequence
            ``[batch, seq_len, embedding_size]`` and the additive padding mask
            ``[batch, 1, seq_len]`` (``None`` when masking is disabled).
        """
        # Add the batch dimension BEFORE building the mask so the mask always
        # has shape [batch, 1, seq_len] and masks keys, not query rows.
        if x.dim() == 1:
            x = x.unsqueeze(0)

        mask = None
        if self.masking:
            # Built directly on x's device: 0 for real tokens, -inf for padding.
            mask = torch.zeros(x.shape, dtype=self.embedding.weight.dtype, device=x.device)
            mask = mask.masked_fill(x == self.padding_idx, float('-inf')).unsqueeze(1)

        x = self.embedding(x)
        x = x + self.pos_encoding(x) # add the positional encoding to the embedded sequence
        x = self.att_block(x, x, x, mask=mask)

        x = self.ff_net(x) + x

        encoder_output = self.norm(x)

        return encoder_output, mask

In [None]:
# Smoke test of the encoder on the first batch.
# The embedding width is set equal to the padded sequence length here.
embed_dim = max_seq_length
encoder = Encoder(dataset.vocab_len(), embed_dim, masking=True)
encoder = encoder.to(device)
print('eng', eng_batch[0, :5])

enc_output, enc_mask = encoder(eng_batch.to(device))
print(enc_output.shape)

In [None]:
class Decoder(nn.Module):
    """Single-layer Transformer-style decoder: causally masked self-attention,
    cross-attention over the encoder output, a feed-forward block, and an output
    projection onto the vocabulary whose weights are tied to the embedding.
    """

    def __init__(self, vocab_size, embedding_size, embedding_layer, max_length=128, masking=False,
                 num_heads=8):
        """
        Args:
            vocab_size: Number of entries in the vocabulary (output logits per position).
            embedding_size: Size of the embeddings.
            embedding_layer: ``nn.Embedding`` used for the target tokens; its weight
                is also used as the weight of the output projection.
            max_length: Longest sequence the positional encoding covers.
            masking: Whether the causal mask and the supplied cross-attention mask are applied.
            num_heads: Head count forwarded to both attention blocks.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.masking = masking
        self.max_length = max_length
        self.num_heads = num_heads

        self.embedding = embedding_layer
        self.pos_encoding = PositionalEncoding(self.embedding_size, max_length=self.max_length)
        self.mmha = MultiHeadAttention(embedding_size, self.num_heads)
        self.mha = MultiHeadAttention(embedding_size, self.num_heads)
        self.fc1 = nn.Linear(self.embedding_size, self.embedding_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(self.embedding_size, self.embedding_size)
        self.ff_net = nn.Sequential(
            self.fc1,
            self.relu,
            self.fc2
        ) # feedforward network on top of the attention layer
        self.norm = nn.LayerNorm(self.embedding_size)

        self.linear = nn.Linear(self.embedding_size, self.vocab_size) # linear layer on top of the decoder
        self.linear.weight = self.embedding.weight # weight sharing for efficiency
        self.softmax = nn.Softmax(dim=2)

    def forward(self, target, encoder_output, mha_mask=None):
        """
        Args:
            target: Integer target token ids, ``[batch, target_len]``.
            encoder_output: Encoder output attended to by the cross-attention block.
            mha_mask: Optional additive mask for the cross-attention scores
                (ignored when masking is disabled).

        Returns:
            Tuple ``(logits, output)``: raw vocabulary scores and their softmax
            over dimension 2, both ``[batch, target_len, vocab_size]``.
        """
        mmha_mask = None
        if self.masking:
            # Causal mask sized by the TARGET LENGTH (not the embedding width) and
            # created on the input's device: position i may only attend to <= i.
            # unsqueeze for using it across a batch, if we had multiple attention
            # heads one more unsqueeze would be necessary
            seq_len = target.size(1)
            mmha_mask = torch.triu(
                torch.full((seq_len, seq_len), float('-inf'), device=target.device),
                diagonal=1,
            ).unsqueeze(0)
        else:
            mha_mask = None

        x = self.embedding(target)
        x = x + self.pos_encoding(x)
        mmha_output = self.mmha(x, x, x, mask=mmha_mask)

        x = self.mha(mmha_output, encoder_output, encoder_output, mask=mha_mask)

        x = self.ff_net(x) + x

        decoder_output = self.norm(x)

        logits = self.linear(decoder_output)
        output = self.softmax(logits)

        return logits, output

In [None]:
# Smoke test of the decoder; it reuses the encoder's embedding layer.
# NOTE(review): the English batch is fed as the decoder target here, so this
# checks shapes only; confirm that is intended.
embed_dim = max_seq_length
decoder = Decoder(dataset.vocab_len(), embed_dim, encoder.embedding, masking=True)
decoder = decoder.to(device)
print('eng', eng_batch[0, :5])

logits, dec_output = decoder(
    eng_batch.to(device),
    enc_output.to(device),
    enc_mask.to(device),
)
print(dec_output[0, :5, :5])
print(dec_output.shape)

In [None]:
def count_parameters(model):
    """Return the total element count of the parameters of `model` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# The decoder reuses the encoder's embedding layer, so summing the two counts
# separately would count that matrix twice. Module.parameters() yields every
# shared tensor once, hence the temporary ModuleList wrapper.
print(f"Total trainable parameters: {count_parameters(nn.ModuleList([encoder, decoder]))}")

In [None]:
# Hyper-parameters of the training run.
training_params = {
    'batch_size': 32,
    'num_epochs': 20
}

dataset = english_czech_dataset(eng_lines, cz_lines, max_length=max_seq_length)

# 80 / 20 train / validation split.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size  # Remainder goes to validation

train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_dataloader = DataLoader(train_dataset, batch_size=training_params['batch_size'], shuffle=True, drop_last=True)
val_dataloader = DataLoader(val_dataset, batch_size=training_params['batch_size'], shuffle=False, drop_last=False)

# The decoder shares the encoder's embedding weight. Concatenating both
# parameter lists would hand that tensor to the optimizer twice, so collect the
# parameters through one ModuleList, whose parameters() yields each tensor once.
optimizer = optim.Adam(nn.ModuleList([encoder, decoder]).parameters(), lr=0.001)
# Padded target positions do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=PAD_INDEX)

In [None]:
# Teacher-forced training: the decoder receives the target shifted right by one
# (<SOS> first) and must predict the unshifted target followed by <EOS>.
for epoch in range(training_params['num_epochs']):
    epoch_loss = 0.0
    num_batches = 0

    # Train on the training split only (not the loader over the full dataset).
    for src, tgt in train_dataloader:

        src = src.to(device)
        # CrossEntropyLoss expects int64 class indices.
        tgt = tgt.to(device).long()
        # Use the real batch size so a shorter batch cannot break the concat.
        n_rows, seq_len = tgt.shape

        sos_column = torch.full((n_rows, 1), SOS_INDEX, dtype=torch.long, device=device)
        tgt_input = torch.cat([sos_column, tgt[:, :-1]], dim=1)  # prepend <sos>

        # Position i of tgt_input holds token i-1, so the label at position i is
        # token i: the target itself. <EOS> goes right after the last real token
        # (padding is always trailing); rows filling the whole length get none.
        tgt_output = tgt.clone()
        lengths = (tgt != PAD_INDEX).sum(dim=1)
        rows = torch.nonzero(lengths < seq_len, as_tuple=True)[0]
        tgt_output[rows, lengths[rows]] = EOS_INDEX  # append <eos>

        optimizer.zero_grad()

        encoder_output, enc_mask = encoder(src)

        logits, predictions = decoder(tgt_input, encoder_output, enc_mask)

        logits = logits.reshape(-1, logits.size(-1))  # Flatten for CrossEntropy
        tgt_output = tgt_output.reshape(-1)  # Flatten target

        loss = criterion(logits, tgt_output)
        loss.backward()

        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # Report the mean loss over the epoch; max() guards against an empty loader.
    mean_loss = epoch_loss / max(num_batches, 1)
    print(f"Epoch [{epoch+1}/{training_params['num_epochs']}], Loss: {mean_loss:.4f}")