<a href="https://colab.research.google.com/github/zubejda/attention_is_all_you_need_attempt/blob/main/eng_to_cz_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import string
import math

In [None]:
eng_lines = []
cz_lines = []
# Characters stripped from both ends of every sentence (comma, double quote,
# colon, backtick).
chars_to_remove = ',\":``'
# Each line of ces.txt is tab-separated: English sentence first, Czech
# translation second; any further columns are ignored.
# UTF-8 is given explicitly so Czech text decodes identically on every platform.
with open('ces.txt', 'r', encoding='utf-8') as f:
    lines = f.readlines()
    for line in lines:
        # Drop the line terminator first so it cannot end up in the last field.
        fields = line.rstrip('\r\n').split('\t')
        if len(fields) < 2:
            # Blank or malformed line: skip it instead of failing on unpacking.
            continue
        eng, cz = fields[:2]
        # str.strip returns a new string; it must be reassigned (bare
        # `cz.rstrip(...)` calls would leave the text unchanged).
        eng = eng.strip(chars_to_remove)
        cz = cz.strip(chars_to_remove)
        eng_lines.append(eng)
        cz_lines.append(cz)

# cz_lines = np.array(cz_lines)
# eng_lines = np.array(eng_lines)

In [None]:
class DataLoader(Dataset):
    """Character-level dataset of aligned English -> Czech sentence pairs.

    NOTE: despite its name this is a ``torch.utils.data.Dataset``, not a
    loader. Defining it shadows the ``DataLoader`` imported from
    ``torch.utils.data`` at the top of this notebook; refer to
    ``torch.utils.data.DataLoader`` explicitly if batching is needed.

    Index ``PAD_INDEX`` (0) is reserved for padding and the vocabulary
    characters map to 1..len(chars). Text is lower-cased and any character
    outside the vocabulary is silently dropped. With the default ASCII
    vocabulary that includes Czech letters with diacritics (e.g. "č", "ř").
    """

    def __init__(self, english_lines, czech_lines, max_length=128, chars=None):
        """
        Args:
            english_lines (list or np.array): English sentences (model inputs).
            czech_lines (list or np.array): Czech sentences (targets), aligned
                index-by-index with ``english_lines``.
            max_length (int): Length every encoded sequence is padded or
                truncated to.
            chars (str, optional): Vocabulary characters, in index order.
                Defaults to ASCII lowercase letters, digits, punctuation and
                space. Duplicate characters are ignored.
        """
        self.english_lines = english_lines
        self.czech_lines = czech_lines
        self.max_length = max_length
        if chars is None:
            chars = string.ascii_lowercase + string.digits + string.punctuation + ' '
        # dict.fromkeys removes duplicates while keeping first-occurrence
        # order, so indices stay contiguous and vocab_len() stays exact.
        self.chars = ''.join(dict.fromkeys(chars))
        self.PAD_TOKEN = '<PAD>'
        self.PAD_INDEX = 0

        self.char_to_idx = {self.PAD_TOKEN: self.PAD_INDEX}
        self.char_to_idx.update({ch: idx + 1 for idx, ch in enumerate(self.chars)})
        self.idx_to_char = {idx: ch for ch, idx in self.char_to_idx.items()}

    def __len__(self):
        """Return the total number of samples in the dataset."""
        return len(self.czech_lines)

    def vocab_len(self):
        """Return the embedding-table size needed for this vocabulary.

        The padding token is counted as well as every character, so every
        index produced by this dataset is strictly less than this value and
        the result can be passed directly to ``nn.Embedding``.
        """
        return len(self.char_to_idx)

    def __getitem__(self, idx):
        """Return ``(english, czech)`` as 1-D int32 tensors of length ``max_length``."""
        eng_padded = self.pad_or_truncate(self.encode_string(self.english_lines[idx]))
        cz_padded = self.pad_or_truncate(self.encode_string(self.czech_lines[idx]))

        eng_tensor = torch.tensor(eng_padded, dtype=torch.int32)
        cz_tensor = torch.tensor(cz_padded, dtype=torch.int32)

        return eng_tensor, cz_tensor

    def encode_string(self, s):
        """Lower-case ``s`` and map each in-vocabulary character to its index.

        Characters missing from the vocabulary are dropped.
        """
        return [self.char_to_idx[c] for c in s.lower() if c in self.char_to_idx]

    def decode_string(self, indices, skip_pad=True):
        """Map a sequence of indices back to text.

        Args:
            indices: Iterable of indices. Python ints, numpy integers and
                tensor elements are all accepted.
            skip_pad (bool): Omit padding indices instead of emitting
                ``PAD_TOKEN`` for each of them.

        Returns:
            str: The decoded text. Unknown indices are ignored.
        """
        decoded = []
        for raw in indices:
            # Tensor elements hash by identity, so normalise to a plain int
            # before the dictionary lookup.
            i = int(raw)
            if skip_pad and i == self.PAD_INDEX:
                continue
            ch = self.idx_to_char.get(i)
            if ch is not None:
                decoded.append(ch)
        return ''.join(decoded)

    def pad_or_truncate(self, encoded_sequence):
        """
        Pad or truncate a sequence to exactly ``max_length`` elements.

        Args:
            encoded_sequence (list): List of encoded integers.

        Returns:
            list: A new list, truncated or right-padded with ``PAD_INDEX``.
        """
        clipped = encoded_sequence[:self.max_length]
        return clipped + [self.PAD_INDEX] * (self.max_length - len(clipped))

In [None]:
# Build the dataset, then sanity-check it: print the vocabulary size and the
# shape of the first English sample (a 1-D tensor of length max_length).
dataset = DataLoader(english_lines=eng_lines, czech_lines=cz_lines)
print(dataset.vocab_len())
print(dataset[0][0].size())

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings (Vaswani et al., 2017) to embeddings."""

    def __init__(self, embedding_dim, max_length=128):
        """
        Args:
            embedding_dim: Dimensionality of the embeddings (should match the model's embedding_dim).
            max_length: Maximum sequence length the precomputed table covers.
        """
        super(PositionalEncoding, self).__init__()
        self.embedding_dim = embedding_dim
        self.max_length = max_length

        pe = torch.zeros(self.max_length, self.embedding_dim)
        position = torch.arange(0, self.max_length, dtype=torch.float).unsqueeze(1)
        # 1 / 10000^(2i / d) for every even dimension index 2i.
        denominator = torch.exp(torch.arange(0, self.embedding_dim, 2).float() * (-math.log(10000.0) / self.embedding_dim))
        angles = position * denominator  # [max_length, ceil(embedding_dim / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # With an odd embedding_dim there is one fewer odd column than even column.
        pe[:, 1::2] = torch.cos(angles[:, : self.embedding_dim // 2])
        # A buffer follows the module through .to()/.cuda(). persistent=False
        # keeps it out of state_dict, as with the previous plain attribute.
        self.register_buffer('pe', pe.unsqueeze(0), persistent=False)  # Shape: [1, max_length, embedding_dim]

    def forward(self, x):
        """Add the encodings for the first ``x.size(-2)`` positions to ``x``.

        Args:
            x: Tensor whose last two dims are (seq_len, embedding_dim),
                e.g. [batch, seq_len, embedding_dim].

        Returns:
            ``x + pe`` broadcast together. For 3-D input the shape is
            unchanged; a 2-D input gains a leading axis of size 1.

        Raises:
            ValueError: if seq_len exceeds ``max_length``.
        """
        seq_len = x.size(-2)
        if seq_len > self.max_length:
            raise ValueError(
                f"sequence length {seq_len} exceeds positional-encoding max_length {self.max_length}"
            )
        # Slice the sequence axis; axis 0 of self.pe is the size-1 batch axis.
        return x + self.pe[:, :seq_len, :].to(x.device)

In [None]:
class Encoder(nn.Module):
    """Single-head self-attention layer over token indices.

    Embeds the indices, adds positional encodings, applies one scaled
    dot-product self-attention head and adds the result back as a residual.
    The residual ``x + attention`` requires ``hidden_size == embedding_size``.
    Index 0 is treated as padding when ``masking`` is enabled.
    """

    def __init__(self, vocab_size, embedding_size, hidden_size, masking=False, pe_length=128):
        """
        Args:
            vocab_size: Rows in the embedding table; every input index must be < vocab_size.
            embedding_size: Embedding dimension.
            hidden_size: Width of the query/key/value projections (d_k).
            masking: If True, padding positions (index 0) are excluded as attention keys.
            pe_length: Maximum input sequence length covered by the positional encoding.
        """
        super(Encoder, self).__init__()
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.pe_length = pe_length  # maximum length of a single input sequence
        self.masking = masking  # decide whether a padding mask is applied before computing softmax

        self.embedding = nn.Embedding(self.vocab_size, self.embedding_size)
        self.pos_encoding = PositionalEncoding(self.embedding_size, max_length=self.pe_length)
        self.q1 = nn.Linear(self.embedding_size, self.hidden_size)
        self.k1 = nn.Linear(self.embedding_size, self.hidden_size)
        self.v1 = nn.Linear(self.embedding_size, self.hidden_size)

    def forward(self, x):
        """Encode a batch of index sequences.

        Args:
            x: Integer index tensor, intended as [batch, seq_len].

        Returns:
            Embeddings plus attention output, [batch, seq_len, embedding_size].
        """
        # Non-padding key positions, shaped [batch, 1, seq_len] so the mask
        # broadcasts across every query row of the [batch, seq, seq] scores.
        key_mask = (x != 0).unsqueeze(-2)
        x = self.embedding(x)
        x = self.pos_encoding(x)
        q_vals = self.q1(x)
        k_vals = self.k1(x)
        v_vals = self.v1(x)
        # Scale by sqrt(d_k), the projection width, as in "Attention Is All You Need".
        scores = torch.matmul(q_vals, k_vals.transpose(-2, -1)) / math.sqrt(self.hidden_size)
        if self.masking:
            scores = scores.masked_fill(~key_mask, float('-inf'))

        weights = torch.softmax(scores, dim=-1)
        if self.masking:
            # A sequence that is entirely padding has every score at -inf,
            # which softmax turns into NaN; give such rows zero attention.
            weights = weights.masked_fill(~(key_mask.any(dim=-1, keepdim=True)), 0.0)

        attention = torch.matmul(weights, v_vals)
        return x + attention

In [None]:
# Sanity check: run the encoder on one English sample and print the output shape.
# The module is invoked through __call__ (not .forward) so hooks run. The
# sample gets an explicit batch axis of size 1 because the encoder is written
# for [batch, seq_len] input. no_grad avoids building an autograd graph for
# what is only a shape check.
encoder = Encoder(dataset.vocab_len(), 256, 256, masking=True, pe_length=dataset.max_length)
with torch.no_grad():
    print(encoder(dataset[0][0].unsqueeze(0)).shape)