This notebook implements an LSTM-based sequence-to-sequence model with attention that serves as a bidirectional translator between Early Modern (Shakespearean) English and Modern English.

We begin by importing various libraries and modules, as well as changing the device for speed.

In [None]:
import os, re, gc, time
from pathlib import Path
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}")

We then import our parallel data called `data.tsv` where one column is the Shakespearean translation and the other column is the paired Modern English.

In [None]:
from google.colab import files
# Upload the parallel corpus through Colab's file-upload helper; the result is kept in `uploaded`.
uploaded = files.upload()

# The uploaded file must be named data.tsv (it is read from ./data.tsv below).

Here, we get rid of the line numbers from our data.

In [None]:
# NOTE(review): this pattern removes *every* run of digits in a line, not only
# a leading line number -- confirm the corpus has no meaningful digits.
_DIGIT_RUN = re.compile(r"\d+")

# Decode explicitly as UTF-8 (the notebook writes its other text files as
# UTF-8 as well) so the result does not depend on the platform's default.
raw_lines = Path("./data.tsv").read_text(encoding="utf-8").splitlines()

# Drop the digits, then trim the whitespace left behind at both ends.
cleaned = [_DIGIT_RUN.sub("", line).strip() for line in raw_lines]


We then split our data into Modern and Early Modern components.

In [None]:
# Parallel lists: shakespeare_lines[i] and modern_lines[i] are one sentence pair.
shakespeare_lines = []
modern_lines = []

for line in cleaned:
    parts = line.split("\t")
    # A usable row needs both columns. Blank lines and rows without a tab are
    # skipped instead of raising IndexError on parts[1].
    if len(parts) < 2:
        continue
    early_modern, modern = parts[0].strip(), parts[1].strip()
    # A pair with an empty side carries no translation signal.
    if not early_modern or not modern:
        continue
    shakespeare_lines.append(early_modern)
    modern_lines.append(modern)

Here, we train a SentencePiece model on the combined corpus using a `model_type` of `bpe`, then load the trained model and wrap it in a tokenizer function. This `tokenize` function takes a string and returns the list of subword units that SentencePiece outputs.

In [None]:
import sentencepiece as spm
from pathlib import Path

# Training corpus: every Early Modern line followed by every Modern line,
# one sentence per line.
corpus_sentences = shakespeare_lines + modern_lines
combined_text = "\n".join(corpus_sentences)
Path("combined.txt").write_text(combined_text, encoding="utf8")

# Options for the BPE model; the special-token ids mirror the ones used by
# the joint vocabulary (pad=0, bos=1, eos=2, unk=3).
trainer_options = dict(
    input='combined.txt',
    model_prefix='shakes_mod',
    vocab_size=2000,
    character_coverage=1.0,
    model_type='bpe',
    bos_id=1, eos_id=2, pad_id=0, unk_id=3,
    user_defined_symbols='',
)
spm.SentencePieceTrainer.train(**trainer_options)

# Load the model file written by the trainer (model_prefix + ".model").
sp = spm.SentencePieceProcessor(model_file='shakes_mod.model')

def tokenize(text):
    """Split *text* into SentencePiece sub-word pieces.

    The text is passed to the trained model unchanged (no lower-casing is
    applied here). Call ``text.lower()`` before tokenizing if a
    case-insensitive model is wanted.

    Returns:
        A list of piece strings.
    """
    return sp.encode(text, out_type=str)

This `build_joint_vocab` function builds a single shared vocabulary from the tokens of all sentences (both Modern and Shakespearean).

In [None]:
def build_joint_vocab(texts, min_freq=1):
    """Map every sufficiently frequent token in *texts* to an integer id.

    Args:
        texts: iterable of sentences; each is split with ``tokenize``.
        min_freq: tokens seen fewer than this many times are left out.

    Returns:
        A dict from token to id. Ids 0-3 are reserved for ``<PAD>``,
        ``<SOS>``, ``<EOS>`` and ``<UNK>``; the remaining tokens are numbered
        in order of first appearance.
    """
    piece_counts = Counter(
        piece for sentence in texts for piece in tokenize(sentence)
    )

    vocab = {'<PAD>': 0, '<SOS>': 1, '<EOS>': 2, '<UNK>': 3}
    for piece, count in piece_counts.items():
        if count < min_freq:
            continue
        vocab[piece] = len(vocab)
    return vocab

# One vocabulary shared by both sides of the corpus; min_freq=1 keeps every token seen.
joint_vocab = build_joint_vocab(shakespeare_lines + modern_lines, min_freq=1)
# Number of entries, including the four special tokens.
vocab_size = len(joint_vocab)

This `to_ids` function converts sentences into a list of ids from the joint vocabulary, then adds `<SOS>` and `<EOS>` token ids as needed around the sentence. We then build lists of source ids and target ids, where they both contain the Modern and Shakespeare data, but in a parallel and opposite order.

In [None]:
def to_ids(text):
    """Encode *text* as joint-vocabulary ids wrapped in ``<SOS>`` ... ``<EOS>``.

    Tokens missing from ``joint_vocab`` are mapped to the ``<UNK>`` id.
    """
    pieces = tokenize(text)
    unk_id = joint_vocab['<UNK>']

    ids = [joint_vocab['<SOS>']]
    ids.extend(joint_vocab.get(piece, unk_id) for piece in pieces)
    ids.append(joint_vocab['<EOS>'])
    return ids

# Both translation directions go into one dataset: the first half maps
# Early Modern -> Modern, the second half maps Modern -> Early Modern.
src_ids = list(map(to_ids, shakespeare_lines + modern_lines))
tgt_ids = list(map(to_ids, modern_lines + shakespeare_lines))

Here, we generate our training (90%) and validation (10%) split.

In [None]:
# src_ids / tgt_ids contain every sentence pair twice: index i is pair i in
# the Early Modern -> Modern direction, and index i + n_pairs is the same pair
# reversed. Slicing off the last 10% of the flat list would therefore put
# only Modern -> Early Modern examples in validation, while the very same
# sentence pairs stay in training in the other direction. Split by *pair*
# instead, so both directions of a pair land in the same subset.
total = len(src_ids)
n_pairs = total // 2
n_train_pairs = int(0.9 * n_pairs)

train_idx = (list(range(n_train_pairs))
             + list(range(n_pairs, n_pairs + n_train_pairs)))
val_idx = (list(range(n_train_pairs, n_pairs))
           + list(range(n_pairs + n_train_pairs, total)))

# Number of training examples (both directions counted).
train_sz = len(train_idx)

train_src = [src_ids[i] for i in train_idx]
train_tgt = [tgt_ids[i] for i in train_idx]

val_src = [src_ids[i] for i in val_idx]
val_tgt = [tgt_ids[i] for i in val_idx]

The `TranslationDataset` class below wraps the data.

In [None]:
class TranslationDataset(Dataset):
    """Pairs of id sequences exposed as ``(source, target)`` long tensors."""

    def __init__(self, src_seqs, tgt_seqs):
        # Parallel sequences: src_seqs[i] is translated by tgt_seqs[i].
        self.src, self.tgt = src_seqs, tgt_seqs

    def __len__(self):
        return len(self.src)

    def __getitem__(self, idx):
        src_tensor = torch.tensor(self.src[idx], dtype=torch.long)
        tgt_tensor = torch.tensor(self.tgt[idx], dtype=torch.long)
        return src_tensor, tgt_tensor

def collate_fn(batch):
    """Stack a list of ``(src, tgt)`` tensor pairs into two padded batches.

    Each side is padded with the ``<PAD>`` id and returned batch-first.
    """
    pad_id = joint_vocab['<PAD>']

    def pad(sequences):
        return pad_sequence(sequences, batch_first=True, padding_value=pad_id)

    sources, targets = zip(*batch)
    return pad(sources), pad(targets)

BATCH_SIZE = 32

# Only the training data is shuffled; validation keeps its order.
_train_set = TranslationDataset(train_src, train_tgt)
_val_set = TranslationDataset(val_src, val_tgt)

train_loader = DataLoader(_train_set, batch_size=BATCH_SIZE,
                          shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(_val_set, batch_size=BATCH_SIZE,
                        shuffle=False, collate_fn=collate_fn)

The following cell builds our model.

* The `Encoder` class defines an embedding layer with dimension 256, a 2-layer bidirectional LSTM with hidden size 128, and forward logic that returns the entire sequence of encoder states with final hidden and cell states.
* The `Attention` class computes Bahdanau/additive attention using learned linear layers that transform the concatenation of encoder and decoder states, with projections to scalar alignment scores and a softmax normalization.
* The `Decoder` class embeds the previous token, computes an attention context vector over encoder outputs, concatenates the embedding with the context vector, and passes the end result through a 2-layer unidirectional LSTM. A linear layer then maps the concatenation of the LSTM output, context vector, and embedding to vocabulary logits.
* The `Seq2Seq` class serves as a wrapper that coordinates the encoder and decoder, combining the bidirectional encoder states before initializing the decoder. During training, it uses teacher forcing. The loss is computed outside the model, in the training loop, which also incorporates gradient clipping and early stopping based on validation loss.

In [None]:
class Encoder(nn.Module):
    """Embedding followed by a multi-layer bidirectional LSTM."""

    def __init__(self, vocab_sz, embed_dim, hidden_dim,
                 num_layers, dropout=0.3):
        super().__init__()
        pad_id = joint_vocab['<PAD>']
        self.embedding = nn.Embedding(vocab_sz, embed_dim, padding_idx=pad_id)
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=True,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """Encode a batch of token ids.

        Args:
            src: [batch, src_len] token ids.

        Returns:
            ``(outputs, hidden, cell)`` where outputs is
            [batch, src_len, hidden*2] (both directions concatenated) and
            hidden / cell are the LSTM's final states.
        """
        token_vectors = self.embedding(src)
        token_vectors = self.dropout(token_vectors)
        outputs, final_state = self.lstm(token_vectors)
        hidden, cell = final_state
        return outputs, hidden, cell


class Attention(nn.Module):
    """Additive attention over the encoder outputs."""

    def __init__(self, hidden_dim):
        super().__init__()
        # Input is the decoder state (hidden) concatenated with a
        # bidirectional encoder output (hidden * 2).
        self.attn = nn.Linear(hidden_dim * 3, hidden_dim)
        self.v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        """Return attention weights of shape [batch, src_len].

        Args:
            hidden: [batch, hidden] current decoder state.
            encoder_outputs: [batch, src_len, hidden*2].
        """
        n_positions = encoder_outputs.size(1)

        # Pair the decoder state with every source position.
        tiled_state = hidden.unsqueeze(1).expand(-1, n_positions, -1)
        scorer_input = torch.cat((tiled_state, encoder_outputs), dim=2)

        energy = torch.tanh(self.attn(scorer_input))
        scores = self.v(energy).squeeze(2)          # [batch, src_len]
        return F.softmax(scores, dim=1)             # rows sum to 1


class Decoder(nn.Module):
    """Single-step attention decoder built on a unidirectional LSTM."""

    def __init__(self, vocab_sz, embed_dim, enc_hidden_dim,
                 num_layers, attention, dropout=0.3):
        super().__init__()
        self.attention = attention
        # The encoder is bidirectional, so a context vector is twice as wide.
        context_dim = enc_hidden_dim * 2

        self.embedding = nn.Embedding(vocab_sz, embed_dim,
                                      padding_idx=joint_vocab['<PAD>'])
        self.lstm = nn.LSTM(
            input_size=embed_dim + context_dim,
            hidden_size=enc_hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
        )
        # Predicts from LSTM output + context + embedding of the input token.
        self.fc_out = nn.Linear(enc_hidden_dim + context_dim + embed_dim,
                                vocab_sz)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_tok, hidden, cell, encoder_outputs):
        """Run one decoding step.

        Args:
            input_tok: [batch] ids of the previous token.
            hidden, cell: current LSTM state.
            encoder_outputs: [batch, src_len, hidden*2].

        Returns:
            ``(prediction, hidden, cell)``: vocabulary logits for this step
            and the updated LSTM state.
        """
        embedded = self.dropout(self.embedding(input_tok.unsqueeze(1)))  # [batch, 1, embed]

        # Attend with the top layer's hidden state.
        weights = self.attention(hidden[-1], encoder_outputs).unsqueeze(1)  # [batch, 1, src_len]
        context = torch.bmm(weights, encoder_outputs)                       # [batch, 1, hidden*2]

        step_input = torch.cat((embedded, context), dim=2)
        output, (hidden, cell) = self.lstm(step_input, (hidden, cell))

        features = torch.cat(
            (output.squeeze(1), context.squeeze(1), embedded.squeeze(1)),
            dim=1,
        )
        return self.fc_out(features), hidden, cell


class Seq2Seq(nn.Module):
    """Ties the encoder and decoder together and unrolls the decoder."""

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, tgt, teacher_forcing_ratio=0.5):
        """Decode ``tgt.size(1) - 1`` steps and return the logits.

        Args:
            src: [batch, src_len] source token ids.
            tgt: [batch, tgt_len] target token ids; column 0 is the first
                decoder input.
            teacher_forcing_ratio: probability, drawn once per step for the
                whole batch, of feeding the gold token instead of the
                model's own prediction.

        Returns:
            [batch, tgt_len, vocab] logits; position 0 is left at zero.
        """
        batch_sz = src.size(0)
        tgt_len = tgt.size(1)
        vocab_sz = self.decoder.fc_out.out_features

        outputs = torch.zeros(batch_sz, tgt_len, vocab_sz).to(self.device)

        enc_outputs, hidden, cell = self.encoder(src)

        # Final states arrive as [num_layers*2, batch, hidden]; separate the
        # two directions of each layer.
        n_layers = self.encoder.lstm.num_layers
        fwd_hidden, bwd_hidden = hidden.view(n_layers, 2, batch_sz, -1).unbind(dim=1)
        fwd_cell = cell.view(n_layers, 2, batch_sz, -1)[:, 0]

        # Hidden state: average of both directions.
        # Cell state: forward direction only.
        hidden = (fwd_hidden + bwd_hidden) / 2
        cell = fwd_cell.contiguous()

        input_tok = tgt[:, 0]
        for t in range(1, tgt_len):
            step_logits, hidden, cell = self.decoder(input_tok, hidden, cell,
                                                     enc_outputs)
            outputs[:, t, :] = step_logits

            use_gold = torch.rand(1).item() < teacher_forcing_ratio
            input_tok = tgt[:, t] if use_gold else step_logits.argmax(1)

        return outputs

We then define our hyperparameters. `PATIENCE` is used for early stopping, since we don't expect all of the epochs to be effective due to not having very much data.

In [None]:
EMBED_DIM = 256     # embedding width passed to both Encoder and Decoder
HIDDEN_DIM = 128    # LSTM hidden size per direction
NUM_LAYERS = 2      # stacked LSTM layers in encoder and decoder
DROPOUT = 0.3       # dropout probability handed to Encoder and Decoder
LR = 5e-4           # Adam learning rate
EPOCHS = 50         # upper bound on training epochs
MAX_LEN = 128       # batches are cut to this many positions by maybe_truncate
PATIENCE = 3        # epochs without validation improvement before stopping

As declared before, we now initialize our model, optimizer (using Adam) and loss (cross-entropy).

In [None]:
# Modules are built in the order encoder -> attention -> decoder.
encoder = Encoder(vocab_sz=vocab_size, embed_dim=EMBED_DIM,
                  hidden_dim=HIDDEN_DIM, num_layers=NUM_LAYERS,
                  dropout=DROPOUT)
attention = Attention(hidden_dim=HIDDEN_DIM)
decoder = Decoder(vocab_sz=vocab_size, embed_dim=EMBED_DIM,
                  enc_hidden_dim=HIDDEN_DIM, num_layers=NUM_LAYERS,
                  attention=attention, dropout=DROPOUT)

model = Seq2Seq(encoder, decoder, device)
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=LR)
# Padding positions do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=joint_vocab['<PAD>'])

The function `maybe_truncate` discards anything after `MAX_LEN` in a batch.

In [None]:
def maybe_truncate(tensor):
    """Cut dimension 1 of *tensor* to ``MAX_LEN`` when it is longer.

    Tensors that already fit are returned as-is.
    """
    return tensor[:, :MAX_LEN] if tensor.size(1) > MAX_LEN else tensor

This function `train_epoch` takes our sequence-to-sequence model, a `loader` that gives batches of pairs, an `optimizer` to update parameters, a loss function (`criterion`), and a limit on the gradient vector (`clip`). It loops over all batches (pairs of data) to train the model on this epoch's loss and update parameters.

In [None]:
def train_epoch(model, loader, optimizer, criterion, clip=1.0):
    """Train *model* for one pass over *loader*.

    Args:
        model: the sequence-to-sequence model.
        loader: yields ``(src, tgt)`` batches.
        optimizer: updates the model parameters after each batch.
        criterion: loss over flattened logits and targets.
        clip: maximum gradient norm.

    Returns:
        The mean of the per-batch losses.
    """
    model.train()
    running_loss = 0

    for src_batch, tgt_batch in loader:
        src_batch = maybe_truncate(src_batch).to(device)
        tgt_batch = maybe_truncate(tgt_batch).to(device)

        optimizer.zero_grad()
        logits = model(src_batch, tgt_batch)      # [batch, tgt_len, vocab]

        # Position 0 corresponds to <SOS> and is never predicted.
        flat_logits = logits[:, 1:, :].reshape(-1, vocab_size)
        flat_targets = tgt_batch[:, 1:].reshape(-1)

        batch_loss = criterion(flat_logits, flat_targets)
        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        running_loss += batch_loss.item()

    return running_loss / len(loader)

This function `evaluate` takes our sequence-to-sequence model, a `loader` that gives batches of pairs, and a loss function (`criterion`). It loops over all batches (pairs of data), without updating any parameters, to find the loss of the epoch on the validation set.

In [None]:
def evaluate(model, loader, criterion):
    """Return the mean per-batch loss of *model* over *loader*.

    Runs in eval mode without gradients and with teacher forcing turned
    off, so the decoder is fed its own predictions.
    """
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for src_batch, tgt_batch in loader:
            src_batch = maybe_truncate(src_batch).to(device)
            tgt_batch = maybe_truncate(tgt_batch).to(device)

            logits = model(src_batch, tgt_batch, teacher_forcing_ratio=0.0)

            # Position 0 corresponds to <SOS> and is never predicted.
            flat_logits = logits[:, 1:, :].reshape(-1, vocab_size)
            flat_targets = tgt_batch[:, 1:].reshape(-1)

            total_loss += criterion(flat_logits, flat_targets).item()

    return total_loss / len(loader)

We then go through the training loop in the cell below to put this all together. We go through the declared number of epochs and update parameters while getting training and validation loss unless there has been no improvement for `PATIENCE` times.

In [None]:
best_val = float('inf')
no_improve = 0      # consecutive epochs without a better validation loss

print("Starting training...")
for epoch in range(1, EPOCHS + 1):
    # Free cached GPU memory and unreachable objects before each epoch.
    torch.cuda.empty_cache()
    gc.collect()

    start = time.time()
    train_loss = train_epoch(model, train_loader, optimizer, criterion)
    val_loss = evaluate(model, val_loader, criterion)
    elapsed = time.time() - start
    mins, secs = divmod(int(elapsed), 60)

    improved = val_loss < best_val
    if not improved:
        no_improve += 1
        status = f"No improvement ({no_improve}/{PATIENCE})"
    else:
        # New best: checkpoint the weights and reset the patience counter.
        best_val = val_loss
        torch.save(model.state_dict(), 'best_joint_model.pt')
        no_improve = 0
        status = "Saved best"

    # Perplexity is exp(loss).
    train_ppl = torch.exp(torch.tensor(train_loss))
    val_ppl = torch.exp(torch.tensor(val_loss))

    print(f"Epoch {epoch:02} | {mins}m {secs}s")
    print(f"   Train loss: {train_loss:.4f} | PPL: {train_ppl:.2f}")
    print(f"   Val   loss: {val_loss:.4f} | PPL: {val_ppl:.2f}  {status}")

    if no_improve >= PATIENCE:
        print(f"Early stopping after {epoch} epochs (val loss hasn't improved).")
        break

Whether or not training stopped early, we load the best checkpoint below to prepare for inference.

In [None]:
# Restore the weights saved at the lowest validation loss, mapped onto the current device.
model.load_state_dict(torch.load('best_joint_model.pt', map_location=device))
# Switch to eval mode for inference.
model.eval()

We then create a few functions to simplify inference. The `translate` function tokenizes an input sentence, encodes it, and autoregressively generates an output sequence without teacher forcing until either an `<EOS>` token is produced or the maximum length is reached.

In [None]:
def _pieces_to_text(pieces):
    """Join SentencePiece pieces back into plain text.

    SentencePiece marks the start of each word with U+2581, so the pieces are
    concatenated without separators and the marker becomes a space.
    """
    return "".join(pieces).replace("\u2581", " ").strip()


def translate(sentence: str,
              src_vocab: dict,
              tgt_vocab: dict,
              max_out_len: int = 50) -> str:
    """Greedily decode a translation of *sentence* with the global model.

    Args:
        sentence: input text in either language.
        src_vocab: unused; the input is encoded with ``to_ids``, which always
            uses the joint vocabulary. Kept so existing callers still work.
        tgt_vocab: token -> id mapping used for ``<SOS>`` / ``<EOS>`` and to
            turn predicted ids back into tokens (the joint vocabulary).
        max_out_len: maximum number of decoding steps.

    Returns:
        The decoded sentence as plain text. Generation stops at the first
        ``<EOS>`` or after ``max_out_len`` steps.

    NOTE(review): the model's train/eval mode is not changed here; call
    ``model.eval()`` first so dropout is disabled.
    """
    src_ids = to_ids(sentence)               # uses joint_vocab internally
    src_tensor = torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device)

    eos_id = tgt_vocab['<EOS>']
    # Ids that carry no text and must not leak into the output string.
    silent_ids = {tgt_vocab.get('<PAD>'), tgt_vocab.get('<SOS>')}

    generated = []
    with torch.no_grad():
        enc_outputs, hidden, cell = model.encoder(src_tensor)

        # Collapse the bidirectional states exactly as Seq2Seq.forward does
        # (batch size is 1 here): hidden is averaged over both directions,
        # cell keeps the forward direction.
        n_layers = model.encoder.lstm.num_layers
        hidden = hidden.view(n_layers, 2, 1, -1)
        cell = cell.view(n_layers, 2, 1, -1)
        hidden = (hidden[:, 0, :, :] + hidden[:, 1, :, :]) / 2
        cell = cell[:, 0, :, :].contiguous()

        # first decoder input = <SOS>
        input_tok = torch.tensor([tgt_vocab['<SOS>']], device=device)

        for _ in range(max_out_len):
            out, hidden, cell = model.decoder(input_tok, hidden, cell, enc_outputs)
            pred_id = out.argmax(1).item()
            if pred_id == eos_id:
                break
            if pred_id not in silent_ids:
                generated.append(pred_id)
            # The prediction is always fed back, even when it is not emitted.
            input_tok = torch.tensor([pred_id], device=device)

    inv_vocab = {idx: tok for tok, idx in tgt_vocab.items()}
    pieces = [inv_vocab.get(i, '<UNK>') for i in generated]
    # Joining the pieces with spaces would leave the word-boundary markers in
    # the text and split words apart, so detokenize instead.
    return _pieces_to_text(pieces)


# Helper wrappers for the two directions (they both reuse the joint vocab)
def shakespeare_to_modern(sentence):
    """Translate an Early Modern English *sentence* with the joint model.

    NOTE(review): no direction marker is passed to ``translate``; this call
    is identical to the one in ``modern_to_shakespeare``.
    """
    return translate(sentence, src_vocab=joint_vocab, tgt_vocab=joint_vocab)

def modern_to_shakespeare(sentence):
    """Translate a Modern English *sentence* with the joint model.

    NOTE(review): no direction marker is passed to ``translate``; this call
    is identical to the one in ``shakespeare_to_modern``.
    """
    return translate(sentence, src_vocab=joint_vocab, tgt_vocab=joint_vocab)

# Testing Output Setup
(Code explained in the Marian MT notebook. Only difference is not using the `translate` function since we have `shakespeare_to_modern` and `modern_to_shakespeare` in this context.)

In [None]:
from google.colab import files
# Upload the held-out test set through Colab's file-upload helper.
uploaded = files.upload()

# The uploaded file must be named test.tsv (it is read from ./test.tsv below).

In [None]:
from pathlib import Path
import re
from datasets import load_dataset

# Read test.tsv as a two-column, tab-separated table; every row is placed
# under the single split name "full".
testset = load_dataset(
    "csv",
    data_files={"full": str(Path("./test.tsv"))},
    delimiter="\t",
    column_names=["shakespeare", "modern"]
)


In [None]:
import re

def remove_numbers(row):
    """Strip every run of digits from both text columns of *row*.

    Each value is converted with ``str()`` first. The row is modified in
    place and also returned.
    """
    for column in ("shakespeare", "modern"):
        row[column] = re.sub(r"\d+", "", str(row[column]))
    return row

# Apply remove_numbers to the rows of the "full" split.
testset["full"] = testset["full"].map(remove_numbers)


In [None]:
def too_long(testset):
    """Return True when both sides of a row have at most 25 words.

    Despite its name, a True result means the row is short enough to keep;
    *testset* is a single row, and words are whitespace-separated.
    """
    word_counts = [len(str(testset[column]).split())
                   for column in ("shakespeare", "modern")]
    return all(count <= 25 for count in word_counts)

# Keep only the rows for which too_long returns True (both sides at most 25 words).
testset["full"] = testset["full"].filter(too_long)


In [None]:
import csv

# Translate every Early Modern test sentence and store it next to its
# reference Modern sentence.
with open("modern_test.tsv", "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f, delimiter="\t")
    writer.writerow(["original modern", "translated from early modern"])

    for i, row in enumerate(testset["full"]):
        original = str(row["shakespeare"])
        translated = shakespeare_to_modern(original)

        # Progress output: row index, input, model output.
        print(i)
        print(original)
        print(translated)

        writer.writerow([str(row["modern"]), translated])


In [None]:
# Translate every Modern test sentence and store it next to its reference
# Early Modern sentence.
with open("shakespeare_test.tsv", "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f, delimiter="\t")
    writer.writerow(["original early modern", "translated from modern"])

    for row in testset["full"]:
        original = str(row["modern"])
        translated = modern_to_shakespeare(original)
        writer.writerow([str(row["shakespeare"]), translated])
