In [15]:
import math
from pathlib import Path
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset, DatasetDict
import sentencepiece as spm

pl.seed_everything(42, workers=True)

Seed set to 42


42

In [16]:
ds = load_dataset("opus_books", "de-en")

README.md: 0.00B [00:00, ?B/s]

de-en/train-00000-of-00001.parquet:   0%|          | 0.00/8.80M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/51467 [00:00<?, ? examples/s]

In [17]:
data_dir = Path("../data/seq2seq_nmt")
sp_dir = data_dir / "spm"
sp_dir.mkdir(parents=True, exist_ok=True)

# Dump raw text to files SentencePiece can read
def dump_corpus(split, lang, out_path):
    with out_path.open("w", encoding="utf-8") as f:
        for ex in ds[split]:
            f.write(ex["translation"][lang].strip() + "\n")

dump_corpus("train", "en", sp_dir / "train.en.txt")
dump_corpus("train", "de", sp_dir / "train.de.txt")

# Train two SPM models (unigram or bpe; paper used words, but subwords are practical)
VOCAB_EN = 8000
VOCAB_DE = 8000
SPECIALS = ["<pad>", "<bos>", "<eos>"]  # idx 0..2 (we'll enforce)

def train_spm(input_txt: Path, model_prefix: str, vocab_size: int):
    cmd = (
        f"--input={input_txt} --model_prefix={model_prefix} "
        f"--vocab_size={vocab_size - len(SPECIALS)} --character_coverage=1.0 "
        f"--pad_id=0 --pad_piece=<pad> "
        f"--bos_id=1 --bos_piece=<bos> "
        f"--eos_id=2 --eos_piece=<eos> "
        f"--unk_id=3 --model_type=unigram"
    )
    spm.SentencePieceTrainer.Train(cmd)

if not (sp_dir / "en.model").exists():
    train_spm(sp_dir / "train.en.txt", str(sp_dir / "en"), VOCAB_EN)
if not (sp_dir / "de.model").exists():
    train_spm(sp_dir / "train.de.txt", str(sp_dir / "de"), VOCAB_DE)

sp_en = spm.SentencePieceProcessor(model_file=str(sp_dir / "en.model"))
sp_de = spm.SentencePieceProcessor(model_file=str(sp_dir / "de.model"))

PAD, BOS, EOS, UNK = 0, 1, 2, 3
VOCAB_EN = sp_en.get_piece_size()
VOCAB_DE = sp_de.get_piece_size()

sentencepiece_trainer.cc(178) LOG(INFO) Running command: --input=../data/seq2seq_nmt/spm/train.en.txt --model_prefix=../data/seq2seq_nmt/spm/en --vocab_size=7997 --character_coverage=1.0 --pad_id=0 --pad_piece=<pad> --bos_id=1 --bos_piece=<bos> --eos_id=2 --eos_piece=<eos> --unk_id=3 --model_type=unigram
sentencepiece_trainer.cc(78) LOG(INFO) Starts training with : 
trainer_spec {
  input: ../data/seq2seq_nmt/spm/train.en.txt
  input_format: 
  model_prefix: ../data/seq2seq_nmt/spm/en
  model_type: UNIGRAM
  vocab_size: 7997
  self_test_sample_size: 0
  character_coverage: 1
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  pretokenization_delimiter: 
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 0
  required_char

In [18]:
SRC_LANG = "en"   # source language (encoder input)
TGT_LANG = "de"   # target language (decoder output)

print("Available splits:", list(ds.keys()))

# If "validation" is missing, carve it out of train
if "validation" not in ds:
    split = ds["train"].train_test_split(test_size=0.05, seed=42)
    ds = DatasetDict({"train": split["train"], "validation": split["test"], **({} if "test" not in ds else {"test": ds["test"]})})

print("Now have splits:", list(ds.keys()))


Available splits: ['train']
Now have splits: ['train', 'validation']


In [19]:
@dataclass
class Example:
    src_ids: list
    src_len: int
    tgt_in: list
    tgt_out: list
    tgt_len: int

class MTDataset(Dataset):
    def __init__(self, split: str, reverse_source: bool = True, max_len: int = 100):
        self.data = []
        self.reverse_source = reverse_source
        self.max_len = max_len
        for ex in ds[split]:
            src = ex["translation"]["en"].strip()
            tgt = ex["translation"]["de"].strip()

            src_ids = sp_en.encode(src, out_type=int)
            tgt_ids = sp_de.encode(tgt, out_type=int)

            if len(src_ids) == 0 or len(tgt_ids) == 0:
                continue
            if len(src_ids) > max_len or len(tgt_ids) > max_len:
                continue

            if reverse_source:
                src_ids = list(reversed(src_ids))

            # decoder inputs/outputs
            tgt_in  = [BOS] + tgt_ids
            tgt_out = tgt_ids + [EOS]

            self.data.append(Example(src_ids, len(src_ids), tgt_in, tgt_out, len(tgt_out)))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        e = self.data[i]
        return e

def collate(batch):
    # pad to batch max
    src_max = max(e.src_len for e in batch)
    tgt_max = max(e.tgt_len for e in batch)  # for tgt_in/out they share length

    bs = len(batch)
    src = torch.full((bs, src_max), PAD, dtype=torch.long)
    src_lens = torch.tensor([e.src_len for e in batch], dtype=torch.long)

    tgt_in = torch.full((bs, tgt_max), PAD, dtype=torch.long)
    tgt_out = torch.full((bs, tgt_max), PAD, dtype=torch.long)
    tgt_lens = torch.tensor([e.tgt_len for e in batch], dtype=torch.long)

    for i, e in enumerate(batch):
        src[i, :e.src_len] = torch.tensor(e.src_ids)
        tgt_in[i, :len(e.tgt_in)] = torch.tensor(e.tgt_in)
        tgt_out[i, :len(e.tgt_out)] = torch.tensor(e.tgt_out)

    return {
        "src": src, "src_lens": src_lens,
        "tgt_in": tgt_in, "tgt_out": tgt_out, "tgt_lens": tgt_lens
    }

train_ds = MTDataset("train", reverse_source=True, max_len=80)
valid_ds = MTDataset("validation", reverse_source=True, max_len=80)

len(train_ds), len(valid_ds)

(46214, 2438)

In [20]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, emb_dim=512, hidden=512, num_layers=3, dropout=0.2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim, padding_idx=PAD)
        self.lstm = nn.LSTM(emb_dim, hidden, num_layers=num_layers, batch_first=True,
                            dropout=dropout if num_layers > 1 else 0.0)
    def forward(self, src, src_lens):
        emb = self.embed(src)
        packed = nn.utils.rnn.pack_padded_sequence(emb, src_lens.cpu(), batch_first=True, enforce_sorted=False)
        _, (h, c) = self.lstm(packed)
        return h, c  # [L, B, H]

class Decoder(nn.Module):
    def __init__(self, vocab_size, emb_dim=512, hidden=512, num_layers=3, dropout=0.2, tie_weights=False):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim, padding_idx=PAD)
        self.lstm = nn.LSTM(emb_dim, hidden, num_layers=num_layers, batch_first=True,
                            dropout=dropout if num_layers > 1 else 0.0)
        self.proj = nn.Linear(hidden, vocab_size, bias=False)
        if tie_weights:
            assert emb_dim == hidden
            self.proj.weight = self.embed.weight

    def forward(self, tgt_in, h0, c0):
        emb = self.embed(tgt_in)         # [B,T,E]
        out, (h, c) = self.lstm(emb, (h0, c0))
        logits = self.proj(out)          # [B,T,V]
        return logits, (h, c)

class Seq2SeqLM(pl.LightningModule):
    def __init__(self, src_vocab, tgt_vocab, emb_dim=512, hidden=512, num_layers=3, dropout=0.2,
                 lr=3e-4, label_smoothing=0.1, tie_weights=False):
        super().__init__()
        self.save_hyperparameters()
        self.encoder = Encoder(src_vocab, emb_dim, hidden, num_layers, dropout)
        self.decoder = Decoder(tgt_vocab, emb_dim, hidden, num_layers, dropout, tie_weights)
        self.crit = nn.CrossEntropyLoss(ignore_index=PAD, label_smoothing=label_smoothing)

    def forward(self, batch):
        h0, c0 = self.encoder(batch["src"], batch["src_lens"])
        logits, _ = self.decoder(batch["tgt_in"], h0, c0)
        return logits

    def _step(self, batch, stage):
        logits = self(batch)
        B, T, V = logits.shape
        loss = self.crit(logits.view(B*T, V), batch["tgt_out"].view(B*T))
        self.log(f"{stage}/loss", loss, prog_bar=(stage=="train"), on_epoch=True, on_step=(stage=="train"))
        return loss

    def training_step(self, batch, idx):   return self._step(batch, "train")
    def validation_step(self, batch, idx): return self._step(batch, "val")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

    @torch.no_grad()
    def greedy_decode(self, src_ids, src_len, max_len=80):
        self.eval()
        h, c = self.encoder(src_ids.unsqueeze(0), src_len.unsqueeze(0))
        y = torch.tensor([[BOS]], device=self.device)
        outs = []
        for _ in range(max_len):
            logits, (h, c) = self.decoder(y, h, c)
            nxt = logits[:, -1].softmax(-1).argmax(-1)
            if nxt.item() == EOS: break
            outs.append(nxt.item())
            y = torch.cat([y, nxt.unsqueeze(1)], dim=1)
        return outs


In [21]:
BATCH_SIZE = 32

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, collate_fn=collate)
valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, collate_fn=collate)

In [23]:
model = Seq2SeqLM(
    src_vocab=VOCAB_EN,     # English SPM size
    tgt_vocab=VOCAB_DE,     # German  SPM size
    emb_dim=512, hidden=512, num_layers=3, dropout=0.2,
    lr=3e-4, label_smoothing=0.1, tie_weights=False
)

trainer = pl.Trainer(
    max_epochs=10,
    accelerator="auto",
    devices="auto",
    log_every_n_steps=50,
)
trainer.fit(model, train_loader, valid_loader)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params | Mode 
-----------------------------------------------------
0 | encoder | Encoder          | 10.4 M | train
1 | decoder | Decoder          | 14.5 M | train
2 | crit    | CrossEntropyLoss | 0      | train
-----------------------------------------------------
24.9 M    Trainable params
0         Non-trainable params
24.9 M    Total params
99.564    Total estimated model params size (MB)
8         Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


In [24]:
def detok_de(ids):
    # strip at EOS if present
    toks = [t for t in ids if t not in (PAD, BOS)]
    if EOS in toks:
        toks = toks[:toks.index(EOS)]
    return sp_de.decode(toks)

@torch.no_grad()
def translate_examples(n=5):
    model.eval()
    for i in range(min(n, len(valid_ds))):
        ex = valid_ds[i]
        src_ids = torch.tensor(ex.src_ids, device=model.device)
        src_len = torch.tensor(ex.src_len, device=model.device)
        out_ids = model.greedy_decode(src_ids, src_len, max_len=80)
        hyp = sp_de.decode(out_ids)

        # gold
        gold = detok_de(ex.tgt_out)
        # remember: we reversed source for training; display the unreversed source
        src_original = sp_en.decode(list(reversed(ex.src_ids)))
        print(f"SRC: {src_original}\nHYP: {hyp}\nREF: {gold}\n---")

In [25]:
translate_examples(5)

SRC: I turned, and Miss Ingram darted forwards from her sofa: the others, too, looked up from their several occupations; for at the same time a crunching of wheels and a splashing tramp of horse-hoofs became audible on the wet gravel.
HYP: Ich erhob sich und blickte in die Höhe, die sich in der Kirche zu sehen schien.
REF: Ich wandte mich um und sah, wie Miß Ingram mit der größten Eilfertigkeit von ihrem Sofa aufsprang. Auch die Übrigen blickten von ihren verschiedenen Beschäftigungen auf, denn im selben Augenblick wurde ein Knirschen von Rädern und platschende Huftritte draußen auf dem durchweichten Kieswege vor dem Hause hörbar.
---
SRC: I found that I could never let anyone else deal with this sort of work unless I wanted to harm both the client and the job I had taken on.
HYP: Ich hatte nicht mehr zu sehen, als ob ich mich in der That zu verrühren.
REF: Ich fand, daß ich diese Arbeit niemandem überlassen dürfe, wenn ich mich nicht an meinen Klienten und an der Aufgabe, die ich über