<a href="https://colab.research.google.com/github/rajaswa/feedback-and-memory-in-transformers/blob/main/Feedback_and_Memory_in_Transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook explores the role of Feedback and Memory in Transformer models. The notebook is based on this paper:

*Fan, A., Lavril, T., Grave, E., Joulin, A., & Sukhbaatar, S. (2020). [Addressing Some Limitations of Transformers with Feedback Memory](https://arxiv.org/abs/2002.09402). arXiv preprint arXiv:2002.09402.*

# Environment Setup

In [1]:
# Check GPU Allotment
!nvidia-smi

Thu Apr 29 09:29:30 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   35C    P8    28W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
# Install Dependencies
!pip -q install pytorch-lightning
!pip -q install torchmetrics

In [3]:
"""
IMPORTS
"""

import random
import string
import pandas as pd
import math
import copy
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
import torchmetrics

# SEED
random.seed(42)
torch.manual_seed(42)
torch.backends.cudnn.benchmark = True

# Understanding the Importance of Feedback and Memory in Transformer

### Top-down vs Bottom-up Processing

Most of the Deep Neural Network models follow a **Bottom-up processing** approach, where given a stimulus, high-level latent abstract representations are obtained. These representations are then eventually used for some downstream task.

On the other hand, **Top-down processing** relies on world-knowledge and previosuly known facts & beliefs in the memory. Feedback and Memory play a very important role in Top-down processing. Given below is a short video by **Khan Academy** explaining the difference between **Bottom-up & Top-down processing**, and their individual importance.

In [4]:
# Understanding the Role of Feedback and Top-down processing in Cognition & Perception
from IPython.display import HTML

HTML(
    '<iframe width="560" height="315" src="https://www.youtube.com/embed/aJy5_p_LAhQ" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>'
)

### Limitations of Transformers in Sequential Tasks

[Transformer models](http://jalammar.github.io/illustrated-transformer/) are quite scalable due to their parallel processing capabilities. They handle sequential processing by employing a self-attention mechanism over the entire input sequence. While this helps in capturing a (bi-directional) sequential context while processing the input at a given time-step, it blocks the access to high-level abstract representations from the past time-steps. While this still allows the normal Bottom-up processing, there's no scope for Top-down processing. This results in two major limitations of transformers in sequential tasks:


1.   **Limited Access to Higher Level Representations:** This allows the transformer model to perform only a limited number of state-updates to the input states.
2.   **Maintaining a Belief State:** This doesn't allow the transformer to work with longer sequences, which requires a good memory component.


We'll discuss both these limitations in detail by conducting experiments with a representative task for each of them. We'll also probe how introducding a Top-down feedback with memory helps tackle these limitations.

# Feedback Transformer

![Feedback Transformer](https://raw.githubusercontent.com/rajaswa/feedback-and-memory-in-transformers/main/figures/feedback_transformer.png)

**Feedback Transformer** addresses both the above mentioned limitations of **Vanilla Transformer** by performing **Sequential Processing**, instead of the usual Parallel Processing of input sequence. This is done by simply changing the **attention-mechanism** in the architecture:

1.   ***Self-Attention over Outputs from Previous Layers is Discarded***
2.   ***Attention is instead employed over the Memory states from past input-steps***
3.   ***Where, Memory states are obtained by a learnable weighted-summation of all the hidden representations for the particular time-step.***

This, enables the access to high-level representations across all the layers. This also allows feedback, where a sub-layer can feed itself via memory. While this is essentially a sequential model, it has **key differences with Recurrent Architectures like multi-layered RNNs and LSTMs**. Each layer of these models has recurrent connections to the same layer, but not to higher layers. Morever,  their internal state has a limited capacity determined by the number of layers and their hidden dimension. In contrast, the internal state of a Feedback Transformer is its whole memory, which can grow with the input length.

In [5]:
"""
FEEDBACK TRANSFORMER UTILITIES
"""


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=400):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[: x.size(0), :]
        return self.dropout(x)


class LinearWeightedAvg(nn.Module):
    def __init__(self, n_inputs):
        super(LinearWeightedAvg, self).__init__()
        self.weights = nn.Parameter(torch.randn(n_inputs))
        self.softmax = nn.Softmax(dim=0)

    def forward(self, input):
        res = 0
        weights = self.softmax(self.weights)
        for emb_idx, emb in enumerate(input):
            res += emb * weights[emb_idx]
        return res


class FeedforwardBlock(nn.Module):
    def __init__(self, d_model, dim_feedforward=2048, dropout=0.1, activation="gelu"):
        super(FeedforwardBlock, self).__init__()

        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout_projection = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.dropout_residual = nn.Dropout(dropout)
        self.norm_ff = nn.LayerNorm(d_model)
        self.dropout_ff = nn.Dropout(dropout)
        self.activation = self._get_activation_fn(activation)

    def forward(self, x):
        hidden_state = self.dropout_projection(
            self.activation(self.linear1(x))
        )  # projection
        ff_output = self.dropout_ff(self.linear2(hidden_state))  # feed-forward
        output = x + self.dropout_residual(ff_output)  # residual-connection
        output = self.norm_ff(output)
        return output

    def _get_activation_fn(self, activation):
        if activation == "relu":
            return F.relu
        elif activation == "gelu":
            return F.gelu
        raise RuntimeError("activation should be relu/gelu, not {}".format(activation))

In [6]:
"""
FEEDBACK TRANSFORMER ENCODER
"""


class FeedbackTransformerPointwiseEncoder(nn.Module):
    def __init__(
        self,
        d_model=256,
        nhead=8,
        num_layers=2,
        dim_feedforward=2048,
        dropout=0.1,
        activation="gelu",
    ):
        super(FeedbackTransformerPointwiseEncoder, self).__init__()

        # memory-attention
        self.memory_layer_wise_weighting = LinearWeightedAvg(n_inputs=num_layers)
        self.mem_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.mem_attn_dropout = nn.Dropout(dropout)
        self.mem_attn_norm = nn.LayerNorm(d_model)

        # feedforward
        self.feedforward_layers = self._get_clones(
            FeedforwardBlock(d_model, dim_feedforward, dropout, activation), num_layers
        )

    def forward(self, src_embed, memory_states, memory_key_padding_mask):
        # layer-wise memory-attention and feedforward
        layer_wise_outputs = []
        memory_states = torch.stack(memory_states)
        memory_key_padding_mask = torch.stack(memory_key_padding_mask, dim=0)
        output_embed = src_embed

        for feedforward in self.feedforward_layers:

            # memory-attention
            mem_attn_out, _ = self.mem_attn(
                query=output_embed, key=memory_states, value=memory_states, key_padding_mask=memory_key_padding_mask
            )
            output_embed = output_embed + self.mem_attn_dropout(mem_attn_out)
            output_embed = self.mem_attn_norm(output_embed)

            # feedforward
            output_embed = feedforward(output_embed)
            layer_wise_outputs.append(output_embed)

        # output memory-state for current time-step
        output_memory_state = self.memory_layer_wise_weighting(layer_wise_outputs)
        return output_embed, output_memory_state

    def _get_clones(self, module, N):
        return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


class FeedbackTransformerEncoder(nn.Module):
    def __init__(
        self,
        memory_context=16,
        d_model=256,
        nhead=8,
        num_layers=2,
        dim_feedforward=2048,
        dropout=0.1,
        activation="gelu",
    ):
        super(FeedbackTransformerEncoder, self).__init__()

        self.memory_context = memory_context
        self.d_model = d_model

        self.pointwise_encoder = FeedbackTransformerPointwiseEncoder(
            d_model=d_model,
            nhead=nhead,
            num_layers=num_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
        )

        self.norm = nn.LayerNorm(d_model)

    def forward(self, src, mask, src_key_padding_mask):
        # memory context
        bs = src.shape[1]
        memory_states = [
            torch.zeros(bs, self.d_model).to(src.device)
            for i in range(self.memory_context)
        ]
        memory_key_padding_mask = [True for i in range(self.memory_context)]

        # iterate over entire sequence-length
        pred_seq_logits = []
        for i in range(src.shape[0]):
            output_embed, output_memory_state = self.pointwise_encoder(
                torch.unsqueeze(src[i], dim=0), memory_states, memory_key_padding_mask*bs
            )
            pred_seq_logits.append(torch.squeeze(output_embed))
            memory_states = [torch.squeeze(output_memory_state)] + memory_states[:-1]
            memory_key_padding_mask = [False] + memory_key_padding_mask[:-1]

        pred_seq_logits = self.norm(torch.stack(pred_seq_logits))
        return pred_seq_logits

In [7]:
"""
FEEDBACK TRANSFORMER DECODER
"""


class FeedbackTransformerPointwiseDecoder(nn.Module):
    def __init__(
        self,
        d_model=256,
        nhead=8,
        num_layers=2,
        dim_feedforward=2048,
        dropout=0.1,
        activation="gelu",
    ):
        super(FeedbackTransformerPointwiseDecoder, self).__init__()

        # cross-attention encoder-decoder
        self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.cross_dropout = nn.Dropout(dropout)
        self.cross_norm = nn.LayerNorm(d_model)

        # memory-attention
        self.memory_layer_wise_weighting = LinearWeightedAvg(n_inputs=num_layers)
        self.mem_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.mem_attn_dropout = nn.Dropout(dropout)
        self.mem_attn_norm = nn.LayerNorm(d_model)

        # feedforward
        self.feedforward_layers = self._get_clones(
            FeedforwardBlock(d_model, dim_feedforward, dropout, activation), num_layers
        )

    def forward(self, tgt_embed, memory_states, encoder_outputs, memory_key_padding_mask):
        # layer-wise memory-attention and feedforward
        layer_wise_outputs = []
        memory_states = torch.stack(memory_states)
        memory_key_padding_mask = torch.stack(memory_key_padding_mask, dim=0)
        output_embed = tgt_embed

        for feedforward in self.feedforward_layers:

            # memory-attention
            mem_attn_out, _ = self.mem_attn(
                query=output_embed, key=memory_states, value=memory_states, key_padding_mask=memory_key_padding_mask
            )
            output_embed = output_embed + self.mem_attn_dropout(mem_attn_out)
            output_embed = self.mem_attn_norm(output_embed)

            # cross-attention to encoder outputs
            output_embed2, _ = self.cross_attn(
                output_embed, encoder_outputs, encoder_outputs
            )
            output_embed = output_embed + self.cross_dropout(output_embed2)
            output_embed = self.cross_norm(output_embed)

            # feedforward
            output_embed = feedforward(output_embed)
            layer_wise_outputs.append(output_embed)

        # output memory-state for current time-step
        output_memory_state = self.memory_layer_wise_weighting(layer_wise_outputs)
        return output_embed, output_memory_state

    def _get_clones(self, module, N):
        return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


class FeedbackTransformerDecoder(nn.Module):
    def __init__(
        self,
        memory_context=16,
        d_model=256,
        nhead=8,
        num_layers=2,
        dim_feedforward=2048,
        dropout=0.1,
        activation="gelu",
    ):
        super(FeedbackTransformerDecoder, self).__init__()

        self.memory_context = memory_context
        self.d_model = d_model

        self.pointwise_decoder = FeedbackTransformerPointwiseDecoder(
            d_model=d_model,
            nhead=nhead,
            num_layers=num_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
        )

        self.norm = nn.LayerNorm(d_model)

    def forward(
        self,
        tgt,
        encoder_outputs,
        tgt_mask,
        memory_mask,
        tgt_key_padding_mask,
        memory_key_padding_mask,
    ):
        # memory context
        bs = tgt.shape[1]
        memory_states = [
            torch.zeros(bs, self.d_model).to(tgt.device)
            for i in range(self.memory_context)
        ]
        memory_key_padding_mask = [True for i in range(self.memory_context)]

        # iterate over entire sequence-length
        pred_seq_logits = []
        for i in range(tgt.shape[0]):
            output_embed, output_memory_state = self.pointwise_decoder(
                torch.unsqueeze(tgt[i], dim=0), memory_states, encoder_outputs, memory_key_padding_mask*bs
            )
            pred_seq_logits.append(torch.squeeze(output_embed))
            memory_states = [torch.squeeze(output_memory_state)] + memory_states[:-1]
            memory_key_padding_mask = [False] + memory_key_padding_mask[:-1]

        pred_seq_logits = self.norm(torch.stack(pred_seq_logits))
        return pred_seq_logits

In [8]:
"""
END TO END FEEDBACK TRANSFORMER MODEL
"""


class FeedbackTransformerModel(nn.Module):
    def __init__(
        self,
        encoder_feedback=False,
        decoder_feedback=True,
        memory_context=16,
        input_vocab_size=11,
        output_vocab_size=11,
        d_model=256,
        nhead=8,
        num_layers=4,
        dim_feedforward=2048,
        dropout=0.1,
        max_seq_length=400,
        PAD_IDX=10,
        activation="gelu",
    ):
        super(FeedbackTransformerModel, self).__init__()

        self.d_model = d_model
        self.input_vocab_size = input_vocab_size
        self.output_vocab_size = output_vocab_size
        self.PAD_IDX = PAD_IDX

        # Embeddings
        self.pos_encoder = PositionalEncoding(
            d_model, dropout=dropout, max_len=max_seq_length
        )
        self.src_embedding = nn.Embedding(input_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(output_vocab_size, d_model)

        # Feedback Transformer
        if encoder_feedback:
            feedback_encoder = FeedbackTransformerEncoder(
                memory_context=memory_context,
                d_model=d_model,
                nhead=nhead,
                num_layers=num_layers,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                activation=activation,
            )
        else:
            feedback_encoder = None

        if decoder_feedback:
            feedback_decoder = FeedbackTransformerDecoder(
                memory_context=memory_context,
                d_model=d_model,
                nhead=nhead,
                num_layers=num_layers,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                activation=activation,
            )
        else:
            feedback_decoder = None

        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            custom_encoder=feedback_encoder,
            custom_decoder=feedback_decoder,
            dropout=dropout,
            activation=activation,
        )

        self.lm_layer = nn.Linear(d_model, output_vocab_size)

    def forward(self, input_seq, output_seq, flatten_lm_output=False):
        # Input Sequence (N,S) -> Permuted Input Sequence (S,N)
        input_seq = input_seq.permute(1, 0)
        output_seq = output_seq.permute(1, 0)

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = self.create_mask(
            input_seq, output_seq, self.PAD_IDX
        )

        src_embeddings = self.pos_encoder(
            self.src_embedding(input_seq) * math.sqrt(self.d_model)
        )
        tgt_embeddings = self.pos_encoder(
            self.tgt_embedding(output_seq) * math.sqrt(self.d_model)
        )

        transformer_outputs = self.transformer(
            src=src_embeddings,
            tgt=tgt_embeddings,
            src_mask=src_mask.to(src_embeddings.device),
            tgt_mask=tgt_mask.to(tgt_embeddings.device),
            src_key_padding_mask=src_padding_mask.to(src_embeddings.device),
            tgt_key_padding_mask=tgt_padding_mask.to(tgt_embeddings.device),
        )

        pred_seq_logits = self.lm_layer(transformer_outputs).permute(1, 0, 2)
        if flatten_lm_output:
            pred_seq_logits = pred_seq_logits.reshape(-1, self.output_vocab_size)
        return pred_seq_logits

    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones((sz, sz))) == 1).transpose(0, 1)
        mask = (
            mask.float()
            .masked_fill(mask == 0, float("-inf"))
            .masked_fill(mask == 1, float(0.0))
        )
        return mask

    def create_mask(self, src, tgt, PAD_IDX):
        src_seq_len = src.shape[0]
        tgt_seq_len = tgt.shape[0]

        tgt_mask = self.generate_square_subsequent_mask(tgt_seq_len)
        src_mask = torch.zeros((src_seq_len, src_seq_len)).type(torch.bool)

        src_padding_mask = (src == PAD_IDX).transpose(0, 1)
        tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)

        return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

In [9]:
"""
LIGHTNING MODULE FOR SEQ2SEQ TASKS
"""


class Seq2SeqModel(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.loss = nn.CrossEntropyLoss(ignore_index=0)
        self.accuracy = torchmetrics.Accuracy()

    def training_step(self, batch, batch_idx):
        input_seq, output_seq = batch
        pred_seq_logits = self.model(input_seq, output_seq[:, :-1], flatten_lm_output=True)
        loss = self.loss(pred_seq_logits, output_seq[:, 1:].reshape(-1))

        self.log(
            "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        return loss

    def validation_step(self, batch, batch_idx):
        input_seq, output_seq = batch
        pred_seq_logits = self.model(input_seq, output_seq[:, :-1], flatten_lm_output=True)
        loss = self.loss(pred_seq_logits, output_seq[:, 1:].reshape(-1))

        self.log(
            "val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
        )
        return {
            "loss": loss,
            "pred": pred_seq_logits,
            "ground_truth": output_seq[:, 1:].reshape(-1),
        }

    def validation_epoch_end(self, validation_step_outputs):
        preds, ground_truths = [], []
        for out in validation_step_outputs:
            preds += torch.argmax(out["pred"], dim=1).tolist()
            ground_truths += out["ground_truth"].tolist()
        accuracy = self.accuracy(torch.tensor(preds), torch.tensor(ground_truths))

        self.log(
            "epoch_val_accuracy", accuracy, on_epoch=True, prog_bar=True, logger=True
        )

    def test_step(self, batch, batch_idx):
        input_seq, output_seq = batch
        pred_seq_logits = self.model(input_seq, output_seq[:, :-1], flatten_lm_output=True)
        loss = self.loss(pred_seq_logits, output_seq[:, 1:].reshape(-1))

        self.log(
            "test_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
        )
        return {
            "loss": loss,
            "pred": pred_seq_logits,
            "ground_truth": output_seq[:, 1:].reshape(-1),
        }

    def test_epoch_end(self, test_step_outputs):
        preds, ground_truths = [], []
        for out in test_step_outputs:
            preds += torch.argmax(out["pred"], dim=1).tolist()
            ground_truths += out["ground_truth"].tolist()
        accuracy = self.accuracy(torch.tensor(preds), torch.tensor(ground_truths))

        self.log(
            "epoch_test_accuracy", accuracy, on_epoch=True, prog_bar=True, logger=True
        )

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=3e-3, weight_decay=1e-5)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.5, patience=1
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": "epoch_val_accuracy",
        }

# COGS Benchmark

In [10]:
"""
RAW DATA
"""

!git clone https://github.com/najoungkim/COGS.git
BASE_DIR = "./COGS/data/"
TRAIN_PATH = str("{}train.tsv".format(BASE_DIR))
TRAIN_100_PATH = str("{}train_100.tsv".format(BASE_DIR))
VALID_PATH = str("{}dev.tsv".format(BASE_DIR))
TEST_PATH = str("{}test.tsv".format(BASE_DIR))
GEN_PATH = str("{}gen.tsv".format(BASE_DIR))

Cloning into 'COGS'...
remote: Enumerating objects: 264, done.[K
remote: Counting objects: 100% (264/264), done.[K
remote: Compressing objects: 100% (222/222), done.[K
remote: Total 264 (delta 38), reused 251 (delta 32), pack-reused 0[K
Receiving objects: 100% (264/264), 2.45 MiB | 13.28 MiB/s, done.
Resolving deltas: 100% (38/38), done.


In [11]:
"""
PYTORCH DATASET CLASS
"""


class COGSDataset(Dataset):
    def __init__(self, PATH):
        super().__init__()

        self.PAD_IDX = 0
        self.BOS_IDX = 1
        self.EOS_IDX = 2

        # load data
        (
            src_lines,
            tgt_lines,
            src_vocab,
            tgt_vocab,
            self.codes,
        ) = self.getCOGSParallelData(PATH)

        # tokenize data
        self.PAD_IDX = 0
        self.src_lines_tokenized, self.src_token2id = self.tokenize(
            src_lines, src_vocab
        )
        print(
            "Tokenized Source Lines. {} Unique Tokens in Source Data".format(
                len(self.src_token2id)
            )
        )
        self.tgt_lines_tokenized, self.tgt_token2id = self.tokenize(
            tgt_lines, tgt_vocab
        )
        print(
            "Tokenized Target Lines. {} Unique Tokens in Target Data".format(
                len(self.tgt_token2id)
            )
        )

    def getCOGSParallelData(self, PATH):
        # read raw file
        with open(PATH) as f:
            data = f.readlines()

        src_vocab, tgt_vocab = set(), set()
        src_lines, tgt_lines, codes = [], [], []

        for line in data:
            source, target, code = line.rstrip("\n").split("\t")
            src_lines.append(source)
            tgt_lines.append(target)
            codes.append(code)
            src_vocab.update(source.split())
            tgt_vocab.update(target.split())

        return src_lines, tgt_lines, src_vocab, tgt_vocab, codes

    def tokenize(self, lines, vocab):
        vocab = list(vocab)
        # create dictionary
        token2id = {"[PAD]": 0, "[BOS]": 1, "[EOS]": 2}
        for i in range(len(vocab)):
            token2id[vocab[i]] = i + 3

        # tokenize lines
        tokenized_lines = []
        for line in lines:
            tokenized_line = [self.BOS_IDX] + [token2id[item] for item in line.split()] + [self.EOS_IDX]
            tokenized_lines.append(tokenized_line)

        return tokenized_lines, token2id

    def __len__(self):
        return len(self.src_lines_tokenized)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        input_seq = torch.tensor(self.src_lines_tokenized[idx])
        output_seq = torch.tensor(self.tgt_lines_tokenized[idx])
        return input_seq, output_seq

    def pad_tensor(self, tensor, max_length):
        tensor = torch.cat(
            [tensor, torch.zeros(max_length - tensor.shape[0], dtype=torch.int32)],
            dim=0,
        )
        return tensor

    def collate_fn(self, batch):
        # find longest sequences
        max_len_input = max([sample[0].shape[0] for sample in batch])
        max_len_output = max([sample[1].shape[0] for sample in batch])

        # pad according to max_length
        input_seq = [self.pad_tensor(sample[0], max_len_input) for sample in batch]
        output_seq = [self.pad_tensor(sample[1], max_len_output) for sample in batch]

        # stack all
        input_seq = torch.stack(input_seq, dim=0)
        output_seq = torch.stack(output_seq, dim=0)
        return input_seq, output_seq

In [12]:
"""
LIGHTNING DATAMODULE FOR COGS BENCHMARK
"""


class COGSDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=32, num_workers=2, use_100=True):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.use_100 = use_100

    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            if self.use_100:
                self.train_dataset = COGSDataset(PATH=TRAIN_100_PATH)
            else:
                self.train_dataset = COGSDataset(PATH=TRAIN_PATH)
            self.valid_dataset = COGSDataset(PATH=VALID_PATH)
        if stage == "test" or stage is None:
            self.test_dataset = COGSDataset(PATH=TEST_PATH)

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            drop_last=True,
            collate_fn=self.train_dataset.collate_fn,
        )

    def val_dataloader(self):
        return DataLoader(
            self.valid_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            collate_fn=self.valid_dataset.collate_fn,
            drop_last=True,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            collate_fn=self.test_dataset.collate_fn,
            drop_last=True,
        )

# INSTANTITATE A DATAMODULE
datamodule = COGSDataModule(batch_size=64, num_workers=2, use_100=True)

In [13]:
"""
TRAINING SETUP FOR COGS BENCHMARK
"""

trainer_flags = {
    "amp_backend": "native",
    "benchmark": False,
    "deterministic": False,
    "callbacks": [
        ModelCheckpoint(monitor="epoch_val_accuracy"),
        EarlyStopping(monitor="epoch_val_accuracy", mode="max", patience=3),
    ],
    "gpus": 1,
    "log_every_n_steps": 10,
    "logger": TensorBoardLogger(save_dir="logs/", name="cogs_benchmark_logs"),
    "max_epochs": 100,
    "progress_bar_refresh_rate": 20,
}

### Vanilla Transformer

In [None]:
"""
TRAIN SEQ2SEQ VANILLA TRANSFORMER FOR COGS BENCHMARK
"""

model = Seq2SeqModel(
    model=FeedbackTransformerModel(
        encoder_feedback=False,
        decoder_feedback=False,
        memory_context=8,
        input_vocab_size=800,
        output_vocab_size=800,
        d_model=32,
        nhead=4,
        num_layers=4,
        dim_feedforward=64,
        max_seq_length=800,
        dropout=0.1,
        PAD_IDX=0,
        activation="gelu",
    )
)

trainer = pl.Trainer(**trainer_flags)
trainer.fit(model=model, datamodule=datamodule)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


Tokenized Source Lines. 747 Unique Tokens in Source Data


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Tokenized Target Lines. 687 Unique Tokens in Target Data
Tokenized Source Lines. 584 Unique Tokens in Source Data
Tokenized Target Lines. 579 Unique Tokens in Target Data



  | Name     | Type                     | Params
------------------------------------------------------
0 | model    | FeedbackTransformerModel | 163 K 
1 | loss     | CrossEntropyLoss         | 0     
2 | accuracy | Accuracy                 | 0     
------------------------------------------------------
163 K     Trainable params
0         Non-trainable params
163 K     Total params
0.653     Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

In [None]:
"""
TEST SEQ2SEQ VANILLA TRANSFORMER FOR COGS BENCHMARK
"""

trainer.test()

### Feedback Transformer

In [None]:
"""
TRAIN SEQ2SEQ FEEDBACK TRANSFORMER FOR COGS BENCHMARK
"""

model = Seq2SeqModel(
    model=FeedbackTransformerModel(
        encoder_feedback=False,
        decoder_feedback=True,
        memory_context=8,
        input_vocab_size=800,
        output_vocab_size=800,
        d_model=32,
        nhead=4,
        num_layers=4,
        dim_feedforward=64,
        max_seq_length=800,
        dropout=0.1,
        PAD_IDX=0,
        activation="gelu",
    )
)

trainer = pl.Trainer(**trainer_flags)
trainer.fit(model=model, datamodule=datamodule)

In [None]:
"""
TEST SEQ2SEQ FEEDBACK TRANSFORMER FOR COGS BENCHMARK
"""

trainer.test()

### Logs

In [None]:
%load_ext tensorboard
%tensorboard --logdir logs

# Sequence Copy & Reverse Task

In [None]:
"""
PYTORCH DATASET CLASS
"""

class SequenceCopyDataset(Dataset):
    def __init__(self, num_samples=10000, max_length=40, reverse=True):
        super().__init__()
        self.PAD_IDX = 0
        self.BOS_IDX = 1
        self.EOS_IDX = 2
        self.sequence_pairs = self.generate_samples(num_samples, max_length, reverse)

    def generate_samples(self, num_samples=1000, max_length=40, reverse=False):
        sequence_pairs = []
        for i in range(num_samples):
            input_sequence = [
                random.choice([3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
                for _ in range(max_length)
            ]
            if reverse:
                output_sequence = [self.BOS_IDX] + [
                    input_sequence[-1 * i] for i in range(1, max_length + 1)
                ] + [self.EOS_IDX]
            else:
                output_sequence = [self.BOS_IDX] + input_sequence + [self.EOS_IDX]
            sequence_pairs.append({"input": input_sequence, "output": output_sequence})
        return sequence_pairs

    def __len__(self):
        return len(self.sequence_pairs)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        sample = (
            torch.tensor(self.sequence_pairs[idx]["input"]),
            torch.tensor(self.sequence_pairs[idx]["output"]),
        )
        return sample

In [None]:
"""
LIGHTNING DATAMODULE FOR SEQUENCE COPY & REVERSE TASK
"""

class SequenceCopyDataModule(pl.LightningDataModule):
    def __init__(
        self,
        batch_size=32,
        num_workers=4,
        num_samples_train=10000,
        num_samples_eval=2000,
        max_length_train=40,
        max_length_eval=60,
        reverse=False,
    ):
        super().__init__()
        self.num_samples_train = num_samples_train
        self.num_samples_eval = num_samples_eval
        self.max_length_train = max_length_train
        self.max_length_eval = max_length_eval
        self.reverse = reverse
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            self.train_dataset = SequenceCopyDataset(
                num_samples=self.num_samples_train,
                max_length=self.max_length_train,
                reverse=self.reverse,
            )
            self.valid_dataset = SequenceCopyDataset(
                num_samples=self.num_samples_eval,
                max_length=self.max_length_eval,
                reverse=self.reverse,
            )
        if stage == "test" or stage is None:
            self.test_dataset = SequenceCopyDataset(
                num_samples=self.num_samples_eval,
                max_length=self.max_length_eval,
                reverse=self.reverse,
            )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            drop_last=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.valid_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            drop_last=True,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            drop_last=True,
        )

# INSTANTITATE A DATAMODULE
datamodule = SequenceCopyDataModule(
    batch_size=64,
    num_workers=2,
    num_samples_train=10000,
    num_samples_eval=1000,
    max_length_train=10,
    max_length_eval=50,
    reverse=True,
)

In [None]:
"""
TRAINING SETUP FOR SEQUENCE COPY & REVERSE TASK
"""

trainer_flags = {
    "amp_backend": "native",
    "benchmark": False,
    "deterministic": False,
    "callbacks": [
        ModelCheckpoint(monitor="epoch_val_accuracy"),
        EarlyStopping(monitor="epoch_val_accuracy", mode="max", patience=3),
    ],
    "gpus": 1,
    "log_every_n_steps": 10,
    "logger": TensorBoardLogger(save_dir="logs/", name="sequence_copy_reverse_logs"),
    "max_epochs": 1,
    "progress_bar_refresh_rate": 20,
}

### Vanilla Transformer

In [None]:
"""
TRAIN SEQ2SEQ VANILLA TRANSFORMER FOR SEQUENCE COPY & REVERSE TASK
"""

model = Seq2SeqModel(
    model=FeedbackTransformerModel(
        encoder_feedback=False,
        decoder_feedback=False,
        memory_context=16,
        input_vocab_size=13,
        output_vocab_size=13,
        d_model=128,
        nhead=8,
        num_layers=4,
        dim_feedforward=256,
        max_seq_length=203,
        dropout=0.1,
        PAD_IDX=0,
        activation="gelu",
    )
)

trainer = pl.Trainer(**trainer_flags)
trainer.fit(model=model, datamodule=datamodule)

In [None]:
"""
TEST SEQ2SEQ VANILLA TRANSFORMER FOR SEQUENCE COPY & REVERSE TASK
"""

trainer.test()

### Feedback Transformer

In [None]:
"""
TRAIN SEQ2SEQ FEEDBACK TRANSFORMER FOR SEQUENCE COPY & REVERSE TASK
"""

model = Seq2SeqModel(
    model=FeedbackTransformerModel(
        encoder_feedback=False,
        decoder_feedback=True,
        memory_context=16,
        input_vocab_size=13,
        output_vocab_size=13,
        d_model=128,
        nhead=8,
        num_layers=4,
        dim_feedforward=256,
        max_seq_length=203,
        dropout=0.1,
        PAD_IDX=0,
        activation="gelu",
    )
)

trainer = pl.Trainer(**trainer_flags)
trainer.fit(model=model, datamodule=datamodule)

In [None]:
"""
TEST SEQ2SEQ FEEDBACK TRANSFORMER FOR SEQUENCE COPY & REVERSE TASK
"""

trainer.test()

### Logs

In [None]:
%load_ext tensorboard
%tensorboard --logdir logs