In [1]:
import math
from timeit import default_timer as timer

from torch import Tensor
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import Transformer
from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from torch.nn import TransformerEncoder, TransformerEncoderLayer, LayerNorm
from torch.nn import TransformerDecoder, TransformerDecoderLayer
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import multi30k, Multi30k
from torch.nn.utils.rnn import pad_sequence
from typing import Iterable, List

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


## Dataset and Data Processing

In [2]:
def yield_tokens(data_iter: Iterable, language: str) -> List[str]:
    """
    Generator function to yield a list of tokens for each sample in the dataset.
    """
    # Mapping of languages to their respective indices in the dataset
    language_index = {SRC_LANGUAGE: 0, TGT_LANGUAGE: 1}

    # Iterate over each data sample in the dataset
    for data_sample in data_iter:
        # Tokenize the sample using the specified language's token transformation function
        # and yield the list of tokens
        yield token_transform[language](data_sample[language_index[language]])


In [3]:
# We need to modify the URLs for the dataset since the links to the original dataset are broken
# Refer to https://github.com/pytorch/text/issues/1756#issuecomment-1163664163 for more info
multi30k.URL["train"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz"
multi30k.URL["valid"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz"

SRC_LANGUAGE = 'en'
TGT_LANGUAGE = 'de'

# Place-holders
token_transform = {}
vocab_transform = {}

In [4]:
token_transform[SRC_LANGUAGE] = get_tokenizer('spacy', language='en_core_web_sm')
token_transform[TGT_LANGUAGE] = get_tokenizer('spacy', language='de_core_news_sm')

# Define special symbols and indices
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
# Make sure the tokens are in order of their indices to properly insert them in vocab
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    # Training data Iterator
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    # Create torchtext's Vocab object
    vocab_transform[ln] = build_vocab_from_iterator(yield_tokens(train_iter, ln),
                                                    min_freq=1,
                                                    specials=special_symbols,
                                                    special_first=True)

# Set ``UNK_IDX`` as the default index. This index is returned when the token is not found.
# If not set, it throws ``RuntimeError`` when the queried token is not found in the Vocabulary.
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
  vocab_transform[ln].set_default_index(UNK_IDX)

In [5]:
def sequential_transforms(*transforms):
    """
    Compose several transforms sequentially.

    This function is a utility to apply a list of transformations sequentially to text data.
    It is used for data preprocessing steps.
    """
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func

def tensor_transform(token_ids: List[int]):
    """
    Add BOS (beginning of sequence) and EOS (end of sequence) tokens and convert to tensor.

    This function is used to preprocess the token sequence for the transformer model by adding
    special tokens and converting the list of token IDs into a tensor.
    """
    return torch.cat((torch.tensor([BOS_IDX]),
                      torch.tensor(token_ids),
                      torch.tensor([EOS_IDX])))


# ``src`` and ``tgt`` language text transforms to convert raw strings into tensors indices
text_transform = {}
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    text_transform[ln] = sequential_transforms(token_transform[ln], #Tokenization
                                               vocab_transform[ln], #Numericalization
                                               tensor_transform) # Add BOS/EOS and create tensor

In [6]:
def collate_fn(batch):
    """
    Function to collate data samples into batch tensors.

    This function is used during data loading to collate multiple data samples (source and target)
    into batch tensors. It also applies padding to ensure consistent tensor sizes.

    Applied within DataLoaders to ensure all samples are valid for the training process.
    """
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample in batch:
        # Apply text transformations and add to batches
        src_batch.append(text_transform[SRC_LANGUAGE](src_sample.rstrip("\n")))
        tgt_batch.append(text_transform[TGT_LANGUAGE](tgt_sample.rstrip("\n")))

    # Pad sequences to create uniformly sized batches
    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)

    return src_batch, tgt_batch

In [7]:
def generate_square_subsequent_mask(sz):
    """
    Generate a square mask for the sequence. The mask shows which entries should not be used.
    
    This mask is used in the decoder part of the transformer model to prevent the model from
    peeking at the subsequent positions in the sequence during training.
    """
    # Create an upper triangular matrix of ones
    mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
    
    # Replace 0s with '-inf' and 1s with 0.0. '-inf' values will be masked.
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

def create_mask(src, tgt):
    """
    Create masks for the source and target sequences.

    This function generates masks for the source and target to be used in the transformer model.
    These masks include padding masks for both source and target, and a target mask for the subsequent positions.
    """
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    # Create a subsequent mask for target sequences to prevent future peeks
    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    # Source mask is all zeros since full sequence can be attended to
    src_mask = torch.zeros((src_seq_len, src_seq_len), device=device).type(torch.bool)

    # Padding masks for source and target sequences
    # This marks positions with PAD_IDX as True so that these are not used in attention computations
    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)

    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

## Model
- We use a 6L6L Transformer as the base model.
- Embedding dimension is 1024. Each multi-head attention layer has 16 heads, with a dimension of 1024 for 
QKV if combining all the heads.
- The hidden projection dimension in FFNs is 4096.
- Dropout layers has a dropout rate of 0.1.
- We use a batch size of 1024.

In [8]:
class PositionalEncoding(nn.Module):
    """
    PositionalEncoding module injects some information about the relative or absolute position 
    of the tokens in the sequence. The positional encodings have the same dimension as 
    the embeddings so that the two can be summed. Here, we use sine and cosine functions 
    of different frequencies.
    """

    def __init__(self, emb_size: int, dropout: float, maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()

        # Create constant 'denominator' part of the positional encoding formula
        den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))

        # Apply sine to even indices in the array; 2i
        pos_embedding[:, 0::2] = torch.sin(pos * den)

        # Apply cosine to odd indices in the array; 2i+1
        pos_embedding[:, 1::2] = torch.cos(pos * den)

        # Reshape for adding to token embeddings
        pos_embedding = pos_embedding.unsqueeze(-2)

        # Dropout layer
        self.dropout = nn.Dropout(dropout)

        # Register pos_embedding as a buffer
        # A buffer is a persistent state for the module (not a parameter, so it's not updated during backprop)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        # Add positional encoding to token embedding, and apply dropout
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])


In [9]:
class TokenEmbedding(nn.Module):
    """
    TokenEmbedding module converts input indices (token IDs) into corresponding embeddings.
    It's a wrapper around the PyTorch Embedding layer, scaling the embeddings by the square root 
    of the embedding size to normalize their variance.
    """

    def __init__(self, vocab_size: int, emb_size: int):
        super(TokenEmbedding, self).__init__()

        # Create an embedding layer with given vocabulary size and embedding size
        self.embedding = nn.Embedding(vocab_size, emb_size)

        # Store the embedding size, used for scaling the embeddings
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        # Convert token indices to embeddings
        # Scale embeddings by the square root of the embedding size
        # This scaling helps maintain a balance in the magnitude of the embeddings, 
        # especially useful in models that involve a lot of matrix multiplications (like Transformers)
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)


In [10]:
class SixLTransformer(nn.Module):
    """
    SixLTransformer is a custom implementation of the Transformer model.
    It consists of a Transformer, embedding layers for source and target, 
    a positional encoding layer, and a final linear layer (generator) for output.
    """

    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 1024,
                 dropout: float = 0.1):
        super(SixLTransformer, self).__init__()

        # Transformer model
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)

        # Final linear layer to generate output
        self.generator = nn.Linear(emb_size, tgt_vocab_size)

        # Embedding layers for source and target sequences
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)

        # Positional Encoding layer
        self.positional_encoding = PositionalEncoding(emb_size, dropout)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):

        # Apply embeddings and positional encoding to the source and target input
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))

        # Pass through the transformer model
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)

        # Generate final output
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        """
        Encode source sequence into context vectors.
        """
        return self.transformer.encoder(self.positional_encoding(
                            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        """
        Decode target sequence from the encoded context.
        """
        return self.transformer.decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)


In [11]:
torch.manual_seed(0)

SRC_VOCAB_SIZE = len(vocab_transform[SRC_LANGUAGE])
TGT_VOCAB_SIZE = len(vocab_transform[TGT_LANGUAGE])
EMB_SIZE = 1024
NHEAD = 16
FFN_HID_DIM = 4096
BATCH_SIZE = 128
NUM_ENCODER_LAYERS = 6
NUM_DECODER_LAYERS = 6

model = SixLTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)

for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

model = model.to(device)

## Optimizer
- Adam optimizer (Kingma & Ba, 2014) is used with β1 = 0.9 and β2 = 0.98.
- No weight decay is applied.
- Base learning rate is 0.001. 

In [12]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

## Scheduler
- We adopt a three-stage training scheme, where the learning rate (LR) of each stage decreases from base to zero following a cosine decay.
- The first LR cycle has 50000 steps, others have 88339 steps.
- A quantization event starts at the beginning of each stage.
- **We first train the model in float.**
- In the second stage, all weights will be binarized.
- In the last stage, both weights and activations will be binarized.

In [13]:
# Scheduler First Cycle
scheduler = CosineAnnealingLR(optimizer, T_max=50000)  # Adjust T_max based on training stages

# Scheduler Second and Third Cycle
# scheduler = CosineAnnealingLR(optimizer, T_max=88339)

## Cross Entropy Loss

In [14]:
# This is the normal Cross Entropy Loss, used by default when training the Transformer Architecture
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

## Proof of Concept Knowledge Distillation Loss
- The authors apply knowledge distillation (KD) during training, learning from a more advanced model.
- KD can be implemented by replacing the ground truth label in the cross-entropy loss function with the softmaxed logits from the teacher model, so it is optional for users.

Implementing Knowledge Distillation is quite a specific task - it generally requires having access to a good model trained on the task at hand, and which is also efficient enough to run during the training process, whilst providing a valuable loss to the model. The computation requirements can get out of hand given given that 2 models should run at the same time.

In [15]:
def custom_loss(output, target, teacher_model=None, alpha=0.5, temperature=2.0):
    """
    Custom loss function for knowledge distillation.
    
    Parameters:
    output: Logits from the student model.
    target: Ground truth labels or logits from the teacher model.
    teacher_model: Pre-trained teacher model. If None, standard training is performed.
    alpha: Weighting factor for combining the teacher and student loss.
    temperature: Temperature for softening probabilities.
    """
    student_loss = F.cross_entropy(output, target.view(-1))
    
    if teacher_model:
        with torch.no_grad():
            teacher_output = teacher_model(target)
        teacher_output = teacher_output.view_as(output)

        # Soften the outputs of both student and teacher models
        soft_student_probs = F.log_softmax(output / temperature, dim=1)
        soft_teacher_probs = F.softmax(teacher_output / temperature, dim=1)

        # Calculate the distillation loss
        distillation_loss = F.kl_div(soft_student_probs, soft_teacher_probs, reduction='batchmean')

        # Combine the student and distillation loss
        return alpha * distillation_loss + (1 - alpha) * student_loss
    else:
        return student_loss


## Training and Evaluation

In [16]:
def train_epoch(model, optimizer):
    # Set the model to training mode (enables dropout, batch normalization, etc.)
    model.train()

    # Variable to accumulate losses for computing average loss later
    losses = 0

    # Load the training dataset
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    # Loop over each batch in the dataset
    for src, tgt in train_dataloader:
        # Move source and target tensors to the appropriate device (GPU or CPU)
        src = src.to(device)
        tgt = tgt.to(device)

        # Prepare the input and output for the target sequence
        # The target input excludes the last token, and the output excludes the first token
        tgt_input = tgt[:-1, :]
        tgt_out = tgt[1:, :]

        # Create masks for the source and target sequences
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        # Forward pass: compute predictions
        logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

        # Zero the gradients before backpropagation
        optimizer.zero_grad()

        # Compute the loss between the predictions and the true target output
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        
        # Backpropagation
        loss.backward()

        # Update model parameters
        optimizer.step()

        # Accumulate the loss
        losses += loss.item()
    
    # Update the learning rate if using a learning rate scheduler
    scheduler.step()

    # Return the average loss
    return losses / len(list(train_dataloader))

In [17]:
def evaluate(model):
    """
    Evaluate the model.
    """
    # Set the model to evaluation mode (disables dropout, batch normalization, etc.)
    model.eval()

    # Variable to accumulate losses for computing average loss later
    losses = 0

    # Load the validation dataset
    val_iter = Multi30k(split='valid', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    val_dataloader = DataLoader(val_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    # Loop over each batch in the dataset
    for src, tgt in val_dataloader:
        # Move source and target tensors to the appropriate device
        src = src.to(device)
        tgt = tgt.to(device)

        # Prepare the input and output for the target sequence
        tgt_input = tgt[:-1, :]
        tgt_out = tgt[1:, :]

        # Create masks for the source and target sequences
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        # Forward pass: compute predictions
        logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

        # Compute the loss between the predictions and the true target output
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))

        # Accumulate the loss
        losses += loss.item()

    # Return the average loss
    return losses / len(list(val_dataloader))

In [18]:
NUM_EPOCHS = 20

for epoch in range(1, NUM_EPOCHS+1):
    start_time = timer()
    train_loss = train_epoch(model, optimizer)
    end_time = timer()
    val_loss = evaluate(model)
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))

Epoch: 1, Train loss: 5.637, Val loss: 5.068, Epoch time = 73.065s
Epoch: 2, Train loss: 4.427, Val loss: 4.279, Epoch time = 72.763s
Epoch: 3, Train loss: 3.847, Val loss: 3.974, Epoch time = 72.677s
Epoch: 4, Train loss: 3.528, Val loss: 3.869, Epoch time = 72.738s
Epoch: 5, Train loss: 3.325, Val loss: 3.724, Epoch time = 72.916s
Epoch: 6, Train loss: 3.146, Val loss: 3.636, Epoch time = 73.246s
Epoch: 7, Train loss: 3.007, Val loss: 3.614, Epoch time = 73.627s
Epoch: 8, Train loss: 2.884, Val loss: 3.531, Epoch time = 73.844s
Epoch: 9, Train loss: 2.753, Val loss: 3.466, Epoch time = 73.328s
Epoch: 10, Train loss: 2.632, Val loss: 3.459, Epoch time = 73.405s
Epoch: 11, Train loss: 2.537, Val loss: 3.468, Epoch time = 73.592s
Epoch: 12, Train loss: 2.419, Val loss: 3.392, Epoch time = 73.231s
Epoch: 13, Train loss: 2.304, Val loss: 3.328, Epoch time = 74.232s
Epoch: 14, Train loss: 2.204, Val loss: 3.312, Epoch time = 73.543s
Epoch: 15, Train loss: 2.105, Val loss: 3.320, Epoch time

```
Epoch: 1, Train loss: 5.637, Val loss: 5.068, Epoch time = 73.065s
Epoch: 2, Train loss: 4.427, Val loss: 4.279, Epoch time = 72.763s
Epoch: 3, Train loss: 3.847, Val loss: 3.974, Epoch time = 72.677s
Epoch: 4, Train loss: 3.528, Val loss: 3.869, Epoch time = 72.738s
Epoch: 5, Train loss: 3.325, Val loss: 3.724, Epoch time = 72.916s
Epoch: 6, Train loss: 3.146, Val loss: 3.636, Epoch time = 73.246s
Epoch: 7, Train loss: 3.007, Val loss: 3.614, Epoch time = 73.627s
Epoch: 8, Train loss: 2.884, Val loss: 3.531, Epoch time = 73.844s
Epoch: 9, Train loss: 2.753, Val loss: 3.466, Epoch time = 73.328s
Epoch: 10, Train loss: 2.632, Val loss: 3.459, Epoch time = 73.405s
Epoch: 11, Train loss: 2.537, Val loss: 3.468, Epoch time = 73.592s
Epoch: 12, Train loss: 2.419, Val loss: 3.392, Epoch time = 73.231s
Epoch: 13, Train loss: 2.304, Val loss: 3.328, Epoch time = 74.232s
Epoch: 14, Train loss: 2.204, Val loss: 3.312, Epoch time = 73.543s
Epoch: 15, Train loss: 2.105, Val loss: 3.320, Epoch time = 73.151s
Epoch: 16, Train loss: 1.991, Val loss: 3.333, Epoch time = 73.269s
Epoch: 17, Train loss: 1.891, Val loss: 3.313, Epoch time = 73.090s
Epoch: 18, Train loss: 1.802, Val loss: 3.273, Epoch time = 73.098s
Epoch: 19, Train loss: 1.699, Val loss: 3.251, Epoch time = 74.058s
Epoch: 20, Train loss: 1.595, Val loss: 3.261, Eoch time = 74.433s

```

## Testing it Works

In [19]:
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """
    Generate an output sequence using a greedy decoding algorithm.
    """
    # Move source sequence and mask to the appropriate device
    src = src.to(device)
    src_mask = src_mask.to(device)

    # Encode the source sequence
    memory = model.encode(src, src_mask)

    # Initialize the target sequence with the start symbol
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(device)

    # Generate the sequence token by token
    for i in range(max_len-1):
        # Move the memory tensor to the appropriate device (GPU or CPU)
        memory = memory.to(device)
    
        # Generate a square subsequent mask for the target sequence so far.
        # This mask ensures that the prediction for position i can depend only on the known outputs at positions less than i.
        tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                    .type(torch.bool)).to(device)
    
        # Decode the output sequence so far to get the next output
        out = model.decode(ys, memory, tgt_mask)
    
        # Transpose the output to make it suitable for generating probabilities from the generator
        out = out.transpose(0, 1)
    
        # Pass the last output token through the generator to get logit probabilities for the next word
        prob = model.generator(out[:, -1])
    
        # Select the token with the highest probability as the next word
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()  # Convert to a Python int
    
        # Append the predicted next word to the sequence (ys)
        # The sequence (ys) holds all the tokens predicted so far and is used in the next iteration to predict the following token
        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
    
        # Break the loop if the end-of-sequence token is predicted
        if next_word == EOS_IDX:
            break
    
    return ys


In [20]:
def translate(model: torch.nn.Module, src_sentence: str):
    """
    Translate an input sentence into the target language using the model.
    """
    # Ensure the model is in evaluation mode
    model.eval()

    # Transform the source sentence into a tensor
    src = text_transform[SRC_LANGUAGE](src_sentence).view(-1, 1)

    # Create a source mask with all zeros
    num_tokens = src.shape[0]
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)

    # Generate the target sequence using the greedy_decode function
    tgt_tokens = greedy_decode(
        model, src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()

    # Convert the generated sequence of indices back to a string
    return " ".join(vocab_transform[TGT_LANGUAGE].lookup_tokens(list(tgt_tokens.cpu().numpy()))).replace("<bos>", "").replace("<eos>", "")


In [21]:
print(translate(model, "It is cold outside .")) 

 Ein Bauarbeiter wartet auf den nächsten Zug . 


Prediction: `Ein Bauarbeiter wartet auf den nächsten Zug .`

Actual translation: `A construction worker waits for the next train`

**Translation is off, but at least words form coherent sentences. We would probably need to train a lot more for actual translations.**

## Binarization of Transformer Model
The function of casting floating-point values into binary Binarized Neural Machine Translation values is summarized as follows:

$$
\text{clip}(x, x_{\min}, x_{\max}) := \min(x_{\max}, \max(x_{\min}, x))
$$

$$
x_b := \left\lfloor \text{clip}\left( x, -1 + \varepsilon, 1 - \varepsilon \right) + 0.5 \right\rfloor \times B
$$

where $x$ is the input tensor, $\varepsilon$ is a small floating-point number that prevents overflow when taking the floor, and $B$ is the binarization bound.

### Note 1
In my implementation, I only binarize the linear layer. In the actual paper, it would seem like the authors made a very custom implementation of each layer, and binarized each layer, and operations. I found it to be quite an involved process, so this serves as a demonstration.

### Note 2
The architecture I implemented below vs. the previous one: I added multiple Layer Norm layers in places where the original Transformer paper authors also placed the Layer Norms.

The binarized linear layer also has a scaling factor involved, as mentioned in the paper. In this case, given that the feed forward layer is 4096, square root of that results to 64, so that's the default scaling factor.

In [22]:
# Function to binarize weights
def binarize_weights(weights, epsilon=0.1, B=2):
    with torch.no_grad():
        clipped = weights.clamp(-1 + epsilon, 1 - epsilon)
        return torch.floor((clipped + 1 - epsilon) + 0.5) * B - 1

In [23]:
class BinarizedLinear(nn.Linear):
    def forward(self, input, scaling_factor=64):
        # Clone the original weights to preserve them for the backward pass
        original_weight = self.weight.data.clone()

        # Binarize the weights for the forward pass
        # This converts the weights to -1 or +1, based on the binarization logic
        self.weight.data = binarize_weights(self.weight.data)

        # Forward pass with binarized weights and binarized input
        # Binarizing the input as well to maintain consistency with weights
        output = super(BinarizedLinear, self).forward(binarize_weights(input))

        # Restore the original weights for the backward pass
        # This ensures that the actual weights are used for gradient calculation
        self.weight.data = original_weight

        # Divide the output by a scaling factor
        # This is typically done to counter the effect of binarization,
        # which can significantly change the magnitude of the output
        return output / scaling_factor

In [39]:
class BinarizedSixLTransformer(nn.Module):
    def __init__(self,
                num_encoder_layers: int,
                num_decoder_layers: int,
                emb_size: int,
                nhead: int,
                src_vocab_size: int,
                tgt_vocab_size: int,
                dim_feedforward: int = 1024,
                dropout: float = 0.1):
        super(BinarizedSixLTransformer, self).__init__()
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        self.generator = BinarizedLinear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout)

        # Layer Normalization
        self.src_tok_norm = LayerNorm(emb_size)
        self.tgt_tok_norm = LayerNorm(emb_size)
        self.final_norm = LayerNorm(emb_size)
    
    def forward(self, src, trg, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))

        # Apply LayerNorm after positional encoding
        src_emb = self.src_tok_norm(src_emb)
        tgt_emb = self.tgt_tok_norm(tgt_emb)

        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None, src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        
        # Apply LayerNorm before generator
        outs = self.final_norm(outs)

        return self.generator(outs)

    def encode(self, src, src_mask):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        src_emb = self.src_tok_norm(src_emb)
        return self.transformer.encoder(src_emb, src_mask)

    def decode(self, tgt, memory, tgt_mask):
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(tgt))
        tgt_emb = self.tgt_tok_norm(tgt_emb)
        return self.transformer.decoder(tgt_emb, memory, tgt_mask)


In [40]:
torch.manual_seed(0)

SRC_VOCAB_SIZE = len(vocab_transform[SRC_LANGUAGE])
TGT_VOCAB_SIZE = len(vocab_transform[TGT_LANGUAGE])
EMB_SIZE = 1024
NHEAD = 16
FFN_HID_DIM = 4096
BATCH_SIZE = 128
NUM_ENCODER_LAYERS = 6
NUM_DECODER_LAYERS = 6

# Initialize the model
binarized_model = BinarizedSixLTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE, NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)
binarized_model = binarized_model.to(device)

In [41]:
optimizer = torch.optim.Adam(binarized_model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9)

In [42]:
for p in binarized_model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

In [43]:
NUM_EPOCHS = 10

for epoch in range(1, NUM_EPOCHS+1):
    start_time = timer()
    train_loss = train_epoch(binarized_model, optimizer)
    end_time = timer()
    val_loss = evaluate(binarized_model)
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))

Epoch: 1, Train loss: 9.861, Val loss: 9.859, Epoch time = 30.313s
Epoch: 2, Train loss: 9.319, Val loss: 7.010, Epoch time = 29.779s
Epoch: 3, Train loss: 6.222, Val loss: 6.772, Epoch time = 29.691s
Epoch: 4, Train loss: 5.861, Val loss: 6.492, Epoch time = 29.832s
Epoch: 5, Train loss: 5.752, Val loss: 6.281, Epoch time = 29.754s
Epoch: 6, Train loss: 5.700, Val loss: 6.202, Epoch time = 29.781s
Epoch: 7, Train loss: 5.667, Val loss: 6.190, Epoch time = 29.793s
Epoch: 8, Train loss: 5.635, Val loss: 6.108, Epoch time = 29.834s
Epoch: 9, Train loss: 5.592, Val loss: 6.053, Epoch time = 29.856s
Epoch: 10, Train loss: 5.558, Val loss: 6.015, Epoch time = 30.001s


We also notice how the epochs take 30s as opposed to an average of 70-80s in the previous training operations. This is no coincidence - PyTorch optimizations seem to be happening in the background as the forward-passes are now much simpler to compute.

In [44]:
print(translate(binarized_model, "It is cold outside .")) 

 Ein Mann in . 


Actual translation: `A man in .`

## Recap

- Went over the generation of text embeddings for a sequence to sequence transformer (on 32 bits), with the purpose of Machine Translation.
- Created a Transformers architecture suitable for machine translation, amd trained it in a similar fashion to what was presented within the paper.
- Generated predictions using the trained model.
- Successfully binarized a linear layer. Made sure that the binarization only takes place during the forward pass.
- Trained a transformer architecture which includes the binarized linear layer and layer norms.
- Outputted predictions using the binarized model.

## Room for improvement
- Binarizing all layers mentioned in the paper - this means custom implementations of all layers, including activations.
- Training on the WMT 2014 data, as opposed to the Multi30k (larger dataset, longer training, better predictions).
- Performing knowledge distillation given a pre-trained model for Machine Translation.