In [24]:
import torch
from torch import cuda, device, load, save, Tensor, tensor
from torch.utils.data import DataLoader
from datasets import load_dataset
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Literal, NamedTuple, Callable, Tuple
from torch.backends import mps
from torch.nn import LSTM, Embedding, Linear, Module, CrossEntropyLoss, Module
from torch.optim import Adam, Optimizer
from tqdm import tqdm
from transformers import AutoTokenizer

# Data Setup

In [25]:
# Root directory for persisted models; save_model/load_model use one sub-directory
# per model name beneath it. The path is relative, so it resolves against the
# current working directory.
TORCH_MODEL_STORAGE_PATH = Path(".models")

class ModelCheckpoint(NamedTuple):
    """Snapshot of a model taken at the end of one training epoch."""
    # Epoch the snapshot belongs to (train() numbers epochs starting at 1).
    epoch: int
    # Training loss recorded for that epoch.
    train_loss: float
    # Validation loss recorded for that epoch.
    val_loss: float
    # Parameter mapping in the format produced by Module.state_dict().
    state_dict: Dict[str, Any]

def count_params(model: Module) -> int:
    """Count the total number of scalar parameters in a model.

    Args:
        model: Any ``torch.nn.Module``.

    Returns:
        The sum of element counts over every tensor in ``model.parameters()``.
    """
    # numel() counts every element; len(p) would only measure the first
    # dimension (and raises for 0-dim parameters).
    return sum(p.numel() for p in model.parameters())

def get_best_device(
        cuda_priority: Literal[1, 2, 3] = 1,
        mps_priority: Literal[1, 2, 3] = 2,
        cpu_priority: Literal[1, 2, 3] = 3,
    ) -> device:
    """Pick the preferred compute device that is usable on this machine.

    Backends are tried in ascending priority order (lower number = tried
    first; ties keep the cuda, mps, cpu order). CUDA and MPS are only chosen
    when their availability check passes; CPU is always accepted.
    """
    ranking = {"cuda": cuda_priority, "mps": mps_priority, "cpu": cpu_priority}
    # sorted() is stable, so equal priorities keep the insertion order above.
    for backend in sorted(ranking, key=ranking.get):
        if backend == "cpu":
            return device("cpu")
        if backend == "cuda" and cuda.is_available():
            return device("cuda")
        if backend == "mps" and mps.is_available():
            return device("mps")

def save_model(model: Module, name: str, loss: float) -> None:
    """Pickle a whole model to ``TORCH_MODEL_STORAGE_PATH / name``.

    The file is named ``trained@<ISO timestamp>;loss=<loss>.pt`` where the
    loss is formatted to four decimals with ``.`` replaced by ``_``.
    ``load_model`` parses this naming scheme, so keep the two in sync.

    Args:
        model: Model to persist. Note that it is moved to the CPU in place
            before saving, so the caller's model ends up on the CPU too.
        name: Sub-directory (model family) to store the file under.
        loss: Loss value to embed in the filename; ``None`` leaves it blank.
    """
    model_dir = TORCH_MODEL_STORAGE_PATH / name
    # parents=True also creates the storage root; exist_ok=True makes the call
    # idempotent instead of checking for existence first.
    model_dir.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().isoformat(timespec="seconds")
    # Test against None rather than truthiness so a loss of exactly 0.0 is
    # still written into the filename.
    loss_str = f"{loss:.4f}".replace(".", "_") if loss is not None else ""
    filename = f"trained@{timestamp};loss={loss_str}.pt"
    model.to(device("cpu"))
    save(model, model_dir / filename)

def load_model(name: str, latest: bool = False) -> Module:
    """Load a model previously written by ``save_model``.

    Args:
        name: Sub-directory of ``TORCH_MODEL_STORAGE_PATH`` to search.
        latest: When False (default) pick the file with the lowest loss in
            its filename; when True pick the most recently trained file.

    Returns:
        The unpickled model.

    Raises:
        FileNotFoundError: If no ``*.pt`` file exists for ``name``.
    """
    model_dir = TORCH_MODEL_STORAGE_PATH / name
    candidates = list(model_dir.glob("*.pt"))
    if not candidates:
        raise FileNotFoundError(f"no saved models found in {model_dir}")

    def _trained_at(path: Path) -> datetime:
        # The ISO timestamp with second precision is exactly 19 characters.
        return datetime.fromisoformat(path.name.split("trained@")[1][:19])

    def _loss(path: Path) -> float:
        # Parse the loss numerically: comparing the raw strings would rank
        # "10_0000" ahead of "2_0000". A blank loss sorts last.
        loss_str = path.stem.split("loss=")[1]
        return float(loss_str.replace("_", ".")) if loss_str else float("inf")

    if latest:
        model_path = sorted(candidates, key=_trained_at)[-1]
    else:
        model_path = sorted(candidates, key=_loss)[0]

    print(f"loading {model_path}")
    # save_model pickles the entire Module, which requires weights_only=False
    # on PyTorch versions where torch.load defaults to weights-only loading.
    # SECURITY: this unpickles arbitrary objects; only load files you trust.
    return load(model_path, weights_only=False)

def _early_stop(train_loss: Dict[int, float], epoch_window: int = 3) -> bool:
    """Flag when training no longer improves loss."""
    if len(train_loss) < epoch_window + 1:
        return False
    else:
        losses = list(train_loss.values())
        current_loss = losses[-1]
        avg_window_loss = sum(losses[-(epoch_window + 1) : -1]) / epoch_window
        if current_loss >= avg_window_loss:
            return True
        else:
            return False

# RNN Language Generation

In [27]:
# Load the training split of the Gutenberg English corpus and keep only its
# first 10 records (presumably to keep the notebook fast; raise for real runs).
ds = load_dataset("sedthh/gutenberg_english", split="train").select(range(10))

# The pretrained BERT (cased) tokenizer maps text to the token ids the model is
# trained on; its vocab_size also sizes the model's embedding/output layers.
tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")
# Number of tokens per fixed-length block produced by group_texts.
block_size = 256

def tokenize_function(examples):
    """Tokenize the ``TEXT`` field of a batch with the notebook-level tokenizer."""
    texts = examples["TEXT"]
    encoded = tokenizer(texts, return_special_tokens_mask=False)
    return encoded

# Tokenize the corpus in batches and drop the raw "TEXT" column, leaving only
# the tokenizer's output columns.
tokenized = ds.map(tokenize_function, batched=True, remove_columns=["TEXT"])


Map:   0%|          | 0/10 [00:00<?, ? examples/s]Token indices sequence length is longer than the specified maximum sequence length for this model (615 > 512). Running this sequence through the model will result in indexing errors
Map: 100%|██████████| 10/10 [00:01<00:00,  5.35 examples/s]


In [33]:
def group_texts(examples):
    """Concatenate tokenized texts and cut them into ``block_size`` chunks.

    Args:
        examples: Batch mapping whose ``"input_ids"`` entry is a list of
            token-id lists (one per text).

    Returns:
        ``{"input_ids": [...]}`` where every element is a list of exactly
        ``block_size`` token ids. Tokens left over after the last full block
        are dropped.
    """
    from itertools import chain  # local import keeps the cell self-contained

    # chain.from_iterable flattens in linear time; sum(lists, []) copies the
    # growing accumulator once per list, which is quadratic.
    concatenated = list(chain.from_iterable(examples["input_ids"]))
    # Round down to a whole number of blocks (drops the remainder).
    usable_length = (len(concatenated) // block_size) * block_size
    return {
        "input_ids": [
            concatenated[start : start + block_size]
            for start in range(0, usable_length, block_size)
        ]
    }

# batch_size=1 hands group_texts one document per call, so blocks never span
# two documents and each document's trailing partial block is dropped. All
# tokenizer columns are replaced by the chunked "input_ids".
lm_dataset = tokenized.map(group_texts, batched=True, batch_size=1, remove_columns=tokenized.column_names)

Map: 100%|██████████| 10/10 [00:00<00:00, 47.15 examples/s]


In [36]:
# Hold out 10% of the blocks for validation; the fixed seed keeps the split
# reproducible across runs.
split_dataset = lm_dataset.train_test_split(seed=42, test_size=0.1)
train_dataset, val_dataset = split_dataset["train"], split_dataset["test"]

In [39]:
def collate_fn(batch):
    """Turn a list of ``{"input_ids": [...]}`` examples into (inputs, targets).

    The examples are stacked into one ``long`` tensor; inputs are every token
    but the last and targets are every token but the first, so
    ``targets[:, t]`` is the token that follows ``inputs[:, t]``.
    """
    rows = [example["input_ids"] for example in batch]
    tokens = torch.tensor(rows, dtype=torch.long)
    inputs, targets = tokens[:, :-1], tokens[:, 1:]
    return inputs, targets

# Training batches are reshuffled every epoch.
# NOTE(review): batch_size=16 is hard-coded for both loaders; the BATCH_SIZE
# constant defined in a later cell is not used here.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn,
)

# Validation batches keep a fixed order (shuffle=False).
val_dataloader = DataLoader(
    val_dataset,
    batch_size=16,
    shuffle=False,
    collate_fn=collate_fn,
)

In [49]:
class NextWordPredictionRNN(Module):
    """Single-layer LSTM that predicts the next token, one time step per call.

    Pipeline per step: token id -> embedding -> LSTM -> linear layer that
    produces one logit per vocabulary entry.

    Args:
        size_vocab: Number of distinct token ids (embedding rows / logits).
        size_embed: Embedding dimension fed to the LSTM.
        size_hidden: LSTM hidden/cell state dimension.
    """

    def __init__(self, size_vocab: int, size_embed: int, size_hidden: int):
        super().__init__()
        self._size_hidden = size_hidden
        self._embedding = Embedding(size_vocab, size_embed)
        self._lstm = LSTM(size_embed, size_hidden, batch_first=True)
        self._linear = Linear(size_hidden, size_vocab)

    def forward(
        self, x: Tensor, hidden: Tensor, cell: Tensor
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Advance the LSTM by one time step.

        Args:
            x: Token ids for the current step, shape ``(batch,)``.
            hidden: LSTM hidden state, shape ``(1, batch, size_hidden)``.
            cell: LSTM cell state, shape ``(1, batch, size_hidden)``.

        Returns:
            ``(logits, hidden, cell)`` where ``logits`` has shape
            ``(batch, size_vocab)`` and the states are the updated LSTM states.
        """
        # Insert a length-1 sequence axis: the LSTM is batch_first, so it
        # expects (batch, seq, feature).
        out = self._embedding(x).unsqueeze(1)
        out, (hidden, cell) = self._lstm(out, (hidden, cell))
        # Collapse the length-1 sequence axis again -> (batch, size_vocab).
        logits = self._linear(out).reshape(out.shape[0], -1)
        return logits, hidden, cell

    def initialize(self, batch_size: int, device_: device) -> Tuple[Tensor, Tensor]:
        """Return zero-filled ``(hidden, cell)`` states for a new sequence batch."""
        hidden = torch.zeros(1, batch_size, self._size_hidden, device=device_)
        cell = torch.zeros(1, batch_size, self._size_hidden, device=device_)
        return hidden, cell

In [47]:
def _train_step(
    x_batch: Tensor,
    y_batch: Tensor,
    model: Module,
    loss_fn: Callable[[Tensor, Tensor], Tensor],
    optimizer: Optimizer,
    device: device,
) -> Tensor:
    """One iteration of the training loop (for one batch)."""
    model.train()
    batch_size, sequence_length = x_batch.shape 
    
    loss_batch = tensor(0.0, device=device)
    optimizer.zero_grad(set_to_none=True)
    
    hidden, cell = model.initialize(batch_size, device)
    for n in range(sequence_length):
        y_pred, hidden, cell = model(x_batch[:, n], hidden, cell)
        loss_batch += loss_fn(y_pred, y_batch[:, n])
    loss_batch.backward()
    optimizer.step() 

    return loss_batch / sequence_length

@torch.no_grad()
def _val_step(
    x_batch: Tensor,
    y_batch: Tensor,
    model: Module,
    loss_fn: Callable[[Tensor, Tensor], Tensor],
    device: device,
) -> Tensor:
    """One iteration of the validation loop (for one batch)."""
    model.eval()
    batch_size, sequence_length = x_batch.shape 
    
    loss_batch = tensor(0.0, device=device)

    hidden, cell = model.initialize(batch_size, device)
    for n in range(sequence_length):
        y_pred, hidden, cell = model(x_batch[:, n], hidden, cell)
        loss_batch += loss_fn(y_pred, y_batch[:, n])
    
    return loss_batch / sequence_length

In [42]:
def train(
    model: Module,
    train_data: DataLoader,
    val_data: DataLoader,
    n_epochs: int,
    learning_rate: float = 0.001,
    random_seed: int = 42,
    device: device = 'cpu',
    pad_token_idx: "int | None" = None,
) -> Tuple[Dict[int, float], Dict[int, float], ModelCheckpoint]:
    """Training loop for LSTM flavoured RNNs on sequence data.

    Trains with Adam and cross-entropy loss, validates after every epoch,
    remembers the weights with the lowest validation loss and stops early
    once ``_early_stop`` fires on the validation losses. On return the model
    holds the best weights found.

    Args:
        model: Model to train (moved to ``device``).
        train_data: Loader yielding ``(x_batch, y_batch)`` training pairs.
        val_data: Loader yielding ``(x_batch, y_batch)`` validation pairs.
        n_epochs: Maximum number of epochs; must be at least 1.
        learning_rate: Adam learning rate.
        random_seed: Seed passed to ``torch.manual_seed``.
        device: Device to train on.
        pad_token_idx: Target id ignored by the loss. When None, a notebook
            global ``PAD_TOKEN_IDX`` is used if one exists, otherwise -100
            (``CrossEntropyLoss``'s own default, i.e. nothing is ignored for
            ordinary non-negative token ids).

    Returns:
        ``(train_losses, val_losses, best_checkpoint)`` where the two dicts
        map epoch number to mean per-batch loss.

    Raises:
        ValueError: If ``n_epochs`` is smaller than 1.
    """
    from copy import deepcopy  # local import keeps the cell self-contained

    if n_epochs < 1:
        # Without at least one epoch there is no checkpoint to report/restore.
        raise ValueError("n_epochs must be at least 1")

    torch.manual_seed(random_seed)
    model.to(device)

    if pad_token_idx is None:
        # PAD_TOKEN_IDX is not guaranteed to be defined in the notebook, so
        # look it up defensively instead of failing with a NameError.
        pad_token_idx = globals().get("PAD_TOKEN_IDX", -100)
    optimizer = Adam(model.parameters(), lr=learning_rate)
    loss_fn = CrossEntropyLoss(ignore_index=pad_token_idx)

    train_losses: Dict[int, float] = {}
    val_losses: Dict[int, float] = {}

    for epoch in range(1, n_epochs + 1):
        loss_train = tensor(0.0).to(device)
        for i, (x_batch, y_batch) in enumerate((pbar := tqdm(train_data)), start=1):
            x = x_batch.to(device, non_blocking=True)
            y = y_batch.to(device, non_blocking=True)
            # detach(): the running total is for reporting only and must not
            # hold on to each batch's autograd graph.
            loss_train += _train_step(x, y, model, loss_fn, optimizer, device).detach()
            pbar.set_description(f"epoch {epoch} training loss = {loss_train/i:.4f}")

        loss_val = tensor(0.0).to(device)
        for x_batch, y_batch in val_data:
            x = x_batch.to(device, non_blocking=True)
            y = y_batch.to(device, non_blocking=True)
            loss_val += _val_step(x, y, model, loss_fn, device)

        epoch_train_loss = loss_train.item() / len(train_data)
        epoch_val_loss = loss_val.item() / len(val_data)

        if epoch == 1 or epoch_val_loss < min(val_losses.values()):
            # Deep-copy the weights: state_dict() tensors share storage with
            # the live parameters, so a shallow dict copy would be silently
            # overwritten by later epochs.
            best_checkpoint = ModelCheckpoint(
                epoch, epoch_train_loss, epoch_val_loss, deepcopy(model.state_dict())
            )
        train_losses[epoch] = epoch_train_loss
        val_losses[epoch] = epoch_val_loss

        if _early_stop(val_losses):
            break

    print("\nbest model:")
    print(f"|-- epoch: {best_checkpoint.epoch}")
    print(f"|-- validation loss: {best_checkpoint.val_loss:.4f}")

    # Restore the best weights (later epochs may have degraded the model).
    model.load_state_dict(best_checkpoint.state_dict)
    return train_losses, val_losses, best_checkpoint

In [64]:
@torch.no_grad()
def generate(
    model: NextWordPredictionRNN,
    prompt: str,
    tokenizer: AutoTokenizer,
    strategy: Literal["greedy", "sample", "topk"] = "greedy",
    output_length: int = 60,
    temperature: float = 1.0,
    random_seed: int = 42,
    device: device = "cpu",
    *,
    k: int = 2,
) -> str:
    """Generate new text conditional on a text prompt.

    Args:
        model: Trained next-token model (moved to ``device``, set to eval).
        prompt: Text to condition on; must encode to at least one token.
        tokenizer: Tokenizer used to encode the prompt and decode the result.
        strategy: ``"greedy"`` takes the arg-max token, ``"sample"`` samples
            from the full softmax, ``"topk"`` samples among the ``k`` highest
            logits.
        output_length: Number of tokens to append to the prompt.
        temperature: Logits are divided by this value (floored at 1e-6).
        random_seed: Seed passed to ``torch.manual_seed`` for sampling.
        device: Device to run the model on.
        k: Number of candidates kept by the ``"topk"`` strategy; values above
            the vocabulary size are clamped to it.

    Returns:
        The decoded prompt followed by the generated continuation.

    Raises:
        ValueError: If the prompt encodes to no tokens, if ``strategy`` is
            unknown, or if ``k`` is smaller than 1 with ``"topk"``.
    """
    torch.manual_seed(random_seed)

    model.to(device)
    model.eval()

    encoded_prompt = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    prompt_tokens = encoded_prompt["input_ids"].squeeze(0).tolist()
    if not prompt_tokens:
        # Generation is seeded with the last prompt token, so there must be one.
        raise ValueError("prompt must contain at least one token")

    hidden, cell = model.initialize(batch_size=1, device_=device)

    # Prime the recurrent state with every prompt token except the last; the
    # last one is fed in by the first generation step below.
    for token_id in prompt_tokens[:-1]:
        x = torch.tensor([token_id], device=device)
        _, hidden, cell = model(x, hidden, cell)

    token_sequence = list(prompt_tokens)

    for _ in range(output_length):
        x = torch.tensor([token_sequence[-1]], device=device)
        logits, hidden, cell = model(x, hidden, cell)
        token_logits = logits[0] / max(temperature, 1e-6)

        if strategy == "greedy":
            next_token = torch.argmax(token_logits).item()
        elif strategy in {"sample", "topk"}:
            if strategy == "topk":
                if k < 1:
                    raise ValueError("k must be at least 1 for the 'topk' strategy")
                # Clamp k so it can never exceed the number of logits.
                top_values = torch.topk(token_logits, min(k, token_logits.numel())).values
                # Mask everything below the k-th largest logit (ties are kept).
                filtered_logits = token_logits.masked_fill(
                    token_logits < top_values[-1], -float("Inf")
                )
            else:
                filtered_logits = token_logits
            probs = torch.softmax(filtered_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()
        else:
            raise ValueError(f"Unknown decoding strategy: {strategy}")

        token_sequence.append(next_token)

    return tokenizer.decode(token_sequence, skip_special_tokens=True)

In [50]:
# Model dimensions: embedding width and LSTM hidden-state width.
SIZE_EMBED = 256
SIZE_HIDDEN = 512

# Training settings. MAX_EPOCHS and LEARNING_RATE are passed to train() below.
# NOTE(review): BATCH_SIZE, MAX_SEQ_LEN, MIN_SEQ_LEN and MIN_WORD_FREQ are not
# referenced by any cell visible in this notebook (the DataLoaders hard-code
# batch_size=16 and the sequence length comes from block_size) — confirm
# whether they are leftovers.
MAX_EPOCHS = 30
BATCH_SIZE = 256
MAX_SEQ_LEN = 100
MIN_SEQ_LEN = 10
MIN_WORD_FREQ = 2
LEARNING_RATE = 0.005

# The tokenizer's vocabulary size fixes both the embedding table and the
# number of output logits.
model = NextWordPredictionRNN(
    tokenizer.vocab_size,
    SIZE_EMBED,
    SIZE_HIDDEN
)

In [None]:
# Train for at most MAX_EPOCHS epochs (train() may stop earlier).
# NOTE(review): no device is passed, so train() falls back to its 'cpu'
# default; get_best_device() is defined above but not used here.
train_losses, val_losses, best_checkpoint = train(
    model=model,
    train_data=train_dataloader,
    val_data=val_dataloader,
    n_epochs=MAX_EPOCHS,
    learning_rate=LEARNING_RATE
)

epoch 1 training loss = 3.8046:  29%|██▊       | 75/261 [05:34<15:07,  4.88s/it]

In [67]:
# Sample a continuation of the prompt with top-k decoding; k, temperature,
# output length, seed and device are left at generate()'s defaults.
generate(
    model=model,
    prompt="this is the worst timeline given the circumstances. You must",
    tokenizer=tokenizer,
    strategy="topk",
)

'this is the worst timeline given the circumstances. You must she could not, ” Alice said, Thou art a man, she could not, and said, I am the LORD ’ s house, and she could not, and she said, I am the LORD, and said, I am the LORD ’ s name ’'