In [174]:
import torch
import math
import torch.nn as nn
from torch import cuda, device, load, save, Tensor, tensor
from torch.utils.data import DataLoader
from datasets import load_dataset
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Literal, NamedTuple, Callable, Tuple
from torch.backends import mps
from torch.optim import Adam, Optimizer
from tqdm import tqdm
from transformers import AutoTokenizer
import copy

# Utils 

In [175]:
TORCH_MODEL_STORAGE_PATH = Path(".models")

class ModelCheckpoint(NamedTuple):
    "Model checkpoint data."
    epoch: int
    train_loss: float
    val_loss: float
    state_dict: Dict[str, Any]

def count_params(model: nn.Module) -> int:
    """Count the number of model parameters."""
    return sum(len(p) for p in model.parameters())

def get_best_device(
        cuda_priority: Literal[1, 2, 3] = 1,
        mps_priority: Literal[1, 2, 3] = 2,
        cpu_priority: Literal[1, 2, 3] = 3,
    ) -> device:
    """Return the best device available on the machine."""
    device_priorities = sorted(
        (("cuda", cuda_priority), ("mps", mps_priority), ("cpu", cpu_priority)),
        key=lambda e: e[1]
    )
    for device_type, _ in device_priorities:
        if device_type == "cuda" and cuda.is_available():
            return device("cuda")
        elif device_type == "mps" and mps.is_available():
            return device("mps")
        elif device_type == "cpu":
            return device("cpu")

def save_model(model: nn.Module, name: str, loss: float) -> None:
    """Save models to disk."""
    if not TORCH_MODEL_STORAGE_PATH.exists():
        TORCH_MODEL_STORAGE_PATH.mkdir()
    model_dir = TORCH_MODEL_STORAGE_PATH / name
    if not model_dir.exists():
        model_dir.mkdir()
    timestamp = datetime.now().isoformat(timespec="seconds")
    loss_str = f"{loss:.4f}".replace(".", "_") if loss else ""
    filename = f"trained@{timestamp};loss={loss_str}.pt"
    model.to(device("cpu"))
    save(model, model_dir / filename)

def load_model(name: str, latest: bool = False) -> nn.Module:
    """Load model with best loss."""
    if not TORCH_MODEL_STORAGE_PATH.exists():
        TORCH_MODEL_STORAGE_PATH.mkdir()
    model_dir = TORCH_MODEL_STORAGE_PATH / name

    if not latest:
        stored_models = [
            (file_path, str(file_path).split("loss=")[1])
            for file_path in model_dir.glob("*.pt")
        ]
        model = sorted(stored_models, key=lambda e: e[1])[0][0]
    else:
        stored_models = [
            (file_path, str(file_path).split("trained@")[1][:19])
            for file_path in model_dir.glob("*.pt")
        ]
        model = sorted(stored_models, key=lambda e: datetime.fromisoformat(e[1]))[-1][0]

    print(f"loading {model}")
    model = load(model)
    return model

def _early_stop(train_loss: Dict[int, float], epoch_window: int = 3) -> bool:
    """Flag when training no longer improves loss."""
    if len(train_loss) < epoch_window + 1:
        return False
    else:
        losses = list(train_loss.values())
        current_loss = losses[-1]
        avg_window_loss = sum(losses[-(epoch_window + 1) : -1]) / epoch_window
        if current_loss >= avg_window_loss:
            return True
        else:
            return False

# RNN Language Generation

In [176]:
class NextWordPredictionRNN(nn.Module):
    """LSTM for predicting the next token in a sequence."""
    
    def __init__(self, size_vocab: int, size_embed: int, size_hidden: int):
        super().__init__()
        self._size_hidden = size_hidden
        self._embedding = nn.Embedding(size_vocab, size_embed)
        self._lstm = nn.LSTM(size_embed, size_hidden, batch_first=True)
        self._linear = nn.Linear(size_hidden, size_vocab)
    
    def forward(self, x: Tensor, hidden: Tensor, cell: Tensor) -> Tensor:
        # forward step to generate words
        out = self._embedding(x).unsqueeze(1)
        out, (hidden, cell) = self._lstm(out, (hidden, cell))
        out = self._linear(out).reshape(out.shape[0], -1)
        return out, hidden, cell

    def initialize(self, batch_size: int, device_: device) -> Tuple[Tensor, Tensor]:
        hidden = torch.zeros(1, batch_size, self._size_hidden, device=device_)
        cell = torch.zeros(1, batch_size, self._size_hidden, device=device_)
        return hidden, cell

In [177]:
def _train_step(
    x_batch: Tensor,
    y_batch: Tensor,
    model: nn.Module,
    loss_fn: Callable[[Tensor, Tensor], Tensor],
    optimizer: Optimizer,
    device: device,
) -> Tensor:
    """One iteration of the training loop (for one batch)."""
    model.train()
    batch_size, sequence_length = x_batch.shape 
    
    loss_batch = tensor(0.0, device=device)
    optimizer.zero_grad(set_to_none=True)
    
    hidden, cell = model.initialize(batch_size, device)
    for n in range(sequence_length):
        y_pred, hidden, cell = model(x_batch[:, n], hidden, cell)
        loss_batch += loss_fn(y_pred, y_batch[:, n])
    loss_batch.backward()
    optimizer.step() 

    return loss_batch / sequence_length

@torch.no_grad()
def _val_step(
    x_batch: Tensor,
    y_batch: Tensor,
    model: nn.Module,
    loss_fn: Callable[[Tensor, Tensor], Tensor],
    device: device,
) -> Tensor:
    """One iteration of the validation loop (for one batch)."""
    model.eval()
    batch_size, sequence_length = x_batch.shape 
    
    loss_batch = tensor(0.0, device=device)

    hidden, cell = model.initialize(batch_size, device)
    for n in range(sequence_length):
        y_pred, hidden, cell = model(x_batch[:, n], hidden, cell)
        loss_batch += loss_fn(y_pred, y_batch[:, n])
    
    return loss_batch / sequence_length

In [178]:
def train(
    model: nn.Module,
    train_data: DataLoader,
    val_data: DataLoader,
    n_epochs: int,
    learning_rate: float = 0.001,
    random_seed: int =42,
    device: device = get_best_device(),
) -> Tuple[Dict[int, float], Dict[int, float], ModelCheckpoint]:
    """Training loop for LSTM flavoured RNNs on sequence data."""
    torch.manual_seed(random_seed)
    model.to(device)

    optimizer = Adam(model.parameters(), lr=learning_rate)
    loss_fn = nn.CrossEntropyLoss()
    
    train_losses: Dict[int, float] = {}
    val_losses: Dict[int, float] = {}
    
    for epoch in range(1, n_epochs + 1):
        loss_train = tensor(0.0).to(device)
        for i, (x_batch, y_batch) in enumerate((pbar := tqdm(train_data)), start=1):
            x = x_batch.to(device, non_blocking=True)
            y = y_batch.to(device, non_blocking=True)
            loss_train += _train_step(x, y, model, loss_fn, optimizer, device)
            pbar.set_description(f"epoch {epoch} training loss = {loss_train/i:.4f}")
        
        loss_val = tensor(0.0).to(device)
        for x_batch, y_batch in val_data:
            x = x_batch.to(device, non_blocking=True)
            y = y_batch.to(device, non_blocking=True)
            loss_val += _val_step(x, y, model, loss_fn, device)
        
        epoch_train_loss = loss_train.item() / len(train_data)
        epoch_val_loss = loss_val.item() / len(val_data)
        
        if epoch == 1 or epoch_val_loss < min(val_losses.values()):
            best_checkpoint = ModelCheckpoint(
                epoch, epoch_train_loss, epoch_val_loss, model.state_dict().copy()
            )
        train_losses[epoch] = epoch_train_loss 
        val_losses[epoch] = epoch_val_loss 

        if _early_stop(val_losses):
            break 
    
    print("\nbest model:")
    print(f"|-- epoch: {best_checkpoint.epoch}")
    print(f"|-- validation loss: {best_checkpoint.val_loss:.4f}")

    model.load_state_dict(best_checkpoint.state_dict)
    return train_losses, val_losses, best_checkpoint

In [179]:
@torch.no_grad()
def generate(
    model: NextWordPredictionRNN,
    prompt: str,
    tokenizer: AutoTokenizer,
    strategy: Literal["greedy", "sample", "topk"] = "greedy",
    output_length: int = 60,
    temperature: float = 1.0,
    random_seed: int = 42,
    device: device = get_best_device(),
    *,
    k: int = 2,
) -> str:
    """Generate new text conditional on a text prompt."""
    torch.manual_seed(random_seed)

    model.to(device)
    model.eval()
    
    encoded_prompt = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    prompt_tokens = encoded_prompt["input_ids"].to(device).squeeze(0).tolist()

    hidden, cell = model.initialize(batch_size=1, device_=device)

    # Prime model
    for token_id in prompt_tokens[:-1]:
        x = torch.tensor([token_id], device=device)
        _, hidden, cell = model(x, hidden, cell)

    token_sequence = prompt_tokens.copy()

    for _ in range(output_length):
        x = torch.tensor([token_sequence[-1]], device=device)
        logits, hidden, cell = model(x, hidden, cell)
        token_logits = logits[0] / max(temperature, 1e-6)

        if strategy == "greedy":
            token_pred = torch.argmax(token_logits)
        elif strategy in {"sample", "topk"}:
            if strategy == "topk":
                v, i = torch.topk(token_logits, k)
                filtered_logits = token_logits.clone()
                filtered_logits[filtered_logits < v[-1]] = -float("Inf")
            else:
                filtered_logits = token_logits
            probs = torch.softmax(filtered_logits, dim=-1)
            token_pred = torch.multinomial(probs, num_samples=1).item()
        else:
            raise ValueError(f"Unknown decoding strategy: {strategy}")

        token_sequence.append(token_pred if isinstance(token_pred, int) else token_pred.item())

    return tokenizer.decode(token_sequence, skip_special_tokens=True)

In [182]:
# Load your dataset
ds = load_dataset("sedthh/gutenberg_english", split="train").select(range(20))

# Load your tokenizer (still needed for accurate boundary detection)
tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")
block_size = 256

def tokenize_function(examples):
    return tokenizer(examples["TEXT"], return_special_tokens_mask=False)

tokenized = ds.map(tokenize_function, batched=True, remove_columns=["TEXT"])

def group_texts(examples):
    # Concatenate all texts
    concatenated = sum(examples["input_ids"], [])
    total_length = len(concatenated)
    total_length = (total_length // block_size) * block_size  # drop remainder
    # Split by chunks of block_size
    result = {
        "input_ids": [concatenated[i : i + block_size] for i in range(0, total_length, block_size)]
    }
    return result

lm_dataset = tokenized.map(group_texts, batched=True, batch_size=1, remove_columns=tokenized.column_names)

split_dataset = lm_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = split_dataset["train"]
val_dataset = split_dataset["test"]

def collate_fn(batch):
    # Each batch element is {"input_ids": [int, int, ...]}
    input_ids = torch.tensor([example["input_ids"] for example in batch], dtype=torch.long)
    x_batch = input_ids[:, :-1]
    y_batch = input_ids[:, 1:]
    return x_batch, y_batch

train_dataloader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn,
)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=16,
    shuffle=False,
    collate_fn=collate_fn,
)

In [183]:
SIZE_EMBED = 256
SIZE_HIDDEN = 512

MAX_EPOCHS = 5
BATCH_SIZE = 256
MAX_SEQ_LEN = 100
MIN_SEQ_LEN = 10
MIN_WORD_FREQ = 2
LEARNING_RATE = 0.005

model = NextWordPredictionRNN(
    tokenizer.vocab_size,
    SIZE_EMBED,
    SIZE_HIDDEN
)

In [33]:
train_losses, val_losses, best_checkpoint = train(
    model=model,
    train_data=train_dataloader,
    val_data=val_dataloader,
    n_epochs=MAX_EPOCHS,
    learning_rate=LEARNING_RATE
)

  0%|          | 0/650 [00:00<?, ?it/s]

epoch 1 training loss = 4.7395: 100%|██████████| 650/650 [04:20<00:00,  2.50it/s]
epoch 2 training loss = 3.8975: 100%|██████████| 650/650 [04:17<00:00,  2.52it/s]
epoch 3 training loss = 3.5853: 100%|██████████| 650/650 [04:18<00:00,  2.52it/s]
epoch 4 training loss = 3.3733: 100%|██████████| 650/650 [04:18<00:00,  2.52it/s]
epoch 5 training loss = 3.2074: 100%|██████████| 650/650 [04:18<00:00,  2.52it/s]



best model:
|-- epoch: 3
|-- validation loss: 3.9789


In [35]:
generate(
    model=model,
    prompt="this is the worst timeline given the circumstances. You must",
    tokenizer=tokenizer,
    strategy="topk",
)

'this is the worst timeline given the circumstances. You must have a little remarkable, and that they are not hers, but they are not really friendly to the other. ” “ I am not able to do, ” he cried. “ I don ’ t know what you mean to do it, ” said Alice, “ if you ’ d have been,'

# GPT

In [80]:
def clones(module: nn.Module, N: int) -> nn.ModuleList:
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

In [81]:
class PositionwiseFeedForward(nn.Module):
    "Implements FFN equation."

    def __init__(self, d_model: int, d_ff: int, dropout: float=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x: Tensor) -> Tensor:
        return self.w_2(self.dropout(self.w_1(x).relu()))

In [82]:
class Embeddings(nn.Module):
    def __init__(self, d_model: int, vocab: int):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model
    
    def forward(self, x: Tensor) -> Tensor:
        return self.lut(x) * math.sqrt(self.d_model)

In [83]:
class PositionalEncoding(nn.Module):
    "Implement the PE function."

    def __init__(self, d_model: int, dropout: float, max_len: int = 5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)
    
    def forward(self, x: Tensor) -> Tensor:
        x = x + self.pe[:, :x.size(1)].requires_grad_(False)
        return self.dropout(x)

In [87]:
class SublayerConnection(nn.Module):
    """Layernorm and Residual connection in between layers."""
    
    def __init__(self, d_model: int, dropout: float):
        super(SublayerConnection, self).__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x: Tensor, sublayer: nn.Module):
        return x + self.dropout(sublayer(self.norm(x)))

In [140]:
def attention(q, k, v, mask=None, dropout=None):
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
    if mask is not None:
        scores.masked_fill(mask == 0, 1e-9)
    p_attn = scores.softmax(dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, v)

In [141]:
class MultiHeadAttention(nn.Module):
    """Computes multi head attention."""
    
    def __init__(self, heads: int, d_model: int, dropout: float = 0.1):
        "Take in model size and number of heads"
        super(MultiHeadAttention, self).__init__()
        assert d_model % heads == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // heads
        self.heads = heads 
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.dropout = nn.Dropout(p=dropout)
    
    def forward(self, q, k, v, mask):
        nbatches = q.size(0)
        
        # 1) Do linear projections to Q, K, V
        q, k, v = [
            lin(x).view(nbatches, -1, self.heads, self.d_k).transpose(1, 2) 
            for lin, x in zip(self.linears, (q, k, v))
        ]
        
        # 2) Run attention on Q, K, V
        a = attention(q, k, v, mask=mask, dropout=self.dropout)
        
        # 3) Concatenate heads 
        a = (
            a.transpose(1, 2)
            .contiguous()
            .view(nbatches, -1, self.heads * self.d_k)
        )
        del q
        del k 
        del v 
        return self.linears[-1](a)

In [142]:
class DecoderLayer(nn.Module):
    """Decoder is made of masked self-attention and feed forward"""
    def __init__(self, d_model: int, self_attn: MultiHeadAttention, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.d_model = d_model 
        self.self_attn = self_attn 
        self.feed_forward = feed_forward 
        self.sublayer = clones(SublayerConnection(d_model, dropout), 2)
    
    def forward(self, x: Tensor, mask: Tensor):
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        return self.sublayer[1](x, self.feed_forward)

In [143]:
class Decoder(nn.Module):
    "Generic N layer decoder with masking."

    def __init__(self, layer: DecoderLayer, N: int):
        super(Decoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = nn.LayerNorm(layer.d_model)
    
    def forward(
        self,
        x: Tensor, 
        mask: Tensor
    ):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

In [213]:
class Generator(nn.Module):
    "Define standard linear + softmax generation step. Used after decoder. So use NLLLoss() instead of CrossEntropyLoss()"

    def __init__(self, d_model: int, vocab_size: int):
        super(Generator, self).__init__()
        self.proj = nn.Linear(d_model, vocab_size)
    
    def forward(
        self,
        x: Tensor
    ):
        return self.proj(x)

In [155]:
class GPT(nn.Module):
    """GPT for language generation."""
    
    def __init__(
        self, 
        vocab_size: int, 
        N: int = 6, 
        d_model: int = 512,
        d_ff: int = 2048,
        heads: int = 8,
        dropout: float = 0.1
    ):
        super(GPT, self).__init__()
        attn = MultiHeadAttention(heads, d_model)
        ff = PositionwiseFeedForward(d_model, d_ff, dropout)
        position = PositionalEncoding(d_model, dropout)
        self.embed = nn.Sequential(Embeddings(d_model, vocab_size), position)
        self.decoder = Decoder(DecoderLayer(d_model, attn, ff, dropout), N)
        self.generator = Generator(d_model, vocab_size)
    
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    
    def forward(self, x, mask):
        return self.generator(self.decoder(self.embed(x), mask))

In [151]:
def subsequent_mask(size: int):
    "Mask out subsequent positions."
    attn_shape = (1, 1, size, size) # (1, sequence_length, sequence_length)
    subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1).type(
        torch.uint8
    )
    return (subsequent_mask == 0).to(float) # (1, sequence_length, sequence_length)

In [215]:
def inference_test():
    model = GPT(11, 2)
    model.eval() 
    src = torch.LongTensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])

    ys = torch.zeros(1, 1).type_as(src)
    for _ in range(9):
        out = nn.functional.log_softmax(model(ys, subsequent_mask(ys.size(1))), dim=-1)
        prob = torch.argmax(out[:,-1]).item()
        ys = torch.cat(
            [ys, torch.empty(1, 1).type_as(src).fill_(prob)], dim=1
        )
    
    print("Example Untrained Model Prediction:", ys)

def run_tests():
    for _ in range(10):
        inference_test()

run_tests()

Example Untrained Model Prediction: tensor([[ 0, 10,  4,  6, 10,  4,  6, 10,  4,  6]])
Example Untrained Model Prediction: tensor([[ 0,  2, 10, 10, 10, 10, 10, 10, 10, 10]])
Example Untrained Model Prediction: tensor([[0, 7, 7, 8, 8, 9, 4, 9, 4, 2]])
Example Untrained Model Prediction: tensor([[ 0, 10,  0, 10,  1,  8, 10,  1,  8, 10]])
Example Untrained Model Prediction: tensor([[0, 3, 3, 3, 3, 3, 3, 3, 1, 8]])
Example Untrained Model Prediction: tensor([[0, 8, 8, 8, 8, 8, 8, 8, 9, 2]])
Example Untrained Model Prediction: tensor([[0, 2, 3, 2, 1, 8, 0, 2, 1, 8]])
Example Untrained Model Prediction: tensor([[ 0, 10,  1,  1,  1,  1, 10,  1, 10,  1]])
Example Untrained Model Prediction: tensor([[0, 8, 8, 8, 8, 8, 8, 8, 8, 8]])
Example Untrained Model Prediction: tensor([[0, 0, 0, 5, 4, 4, 5, 8, 4, 5]])


In [216]:
def _train_step(
    x_batch: Tensor,
    y_batch: Tensor,
    model: nn.Module,
    loss_fn: Callable[[Tensor, Tensor], Tensor],
    optimizer: Optimizer,
    device: device,
) -> Tensor:
    """One iteration of the training loop (for one batch)."""
    model.train()
    _, sequence_length = x_batch.shape 
    
    loss_batch = tensor(0.0, device=device)
    optimizer.zero_grad(set_to_none=True)
    
    y_pred = model(x_batch, subsequent_mask(sequence_length))
    loss_batch += loss_fn(y_pred.permute(0, 2, 1), y_batch)
    
    loss_batch.backward()
    optimizer.step() 

    return loss_batch / sequence_length

@torch.no_grad()
def _val_step(
    x_batch: Tensor,
    y_batch: Tensor,
    model: nn.Module,
    loss_fn: Callable[[Tensor, Tensor], Tensor],
    device: device,
) -> Tensor:
    """One iteration of the validation loop (for one batch)."""
    model.eval()
    batch_size, sequence_length = x_batch.shape 
    
    loss_batch = tensor(0.0, device=device)

    y_pred = model(x_batch, subsequent_mask(sequence_length))
    loss_batch += loss_fn(y_pred, y_batch)
    
    return loss_batch / sequence_length

In [217]:
def train(
    model: nn.Module,
    train_data: DataLoader,
    val_data: DataLoader,
    n_epochs: int,
    learning_rate: float = 0.001,
    random_seed: int =42,
    device: device = "cpu",
) -> Tuple[Dict[int, float], Dict[int, float], ModelCheckpoint]:
    """Training loop for LSTM flavoured RNNs on sequence data."""
    torch.manual_seed(random_seed)
    model.to(device)

    optimizer = Adam(model.parameters(), lr=learning_rate)
    loss_fn = nn.CrossEntropyLoss()
    
    train_losses: Dict[int, float] = {}
    val_losses: Dict[int, float] = {}
    
    for epoch in range(1, n_epochs + 1):
        loss_train = tensor(0.0).to(device)
        for i, (x_batch, y_batch) in enumerate((pbar := tqdm(train_data)), start=1):
            x = x_batch.to(device, non_blocking=True)
            y = y_batch.to(device, non_blocking=True)
            loss_train += _train_step(x, y, model, loss_fn, optimizer, device)
            pbar.set_description(f"epoch {epoch} training loss = {loss_train/i:.4f}")
        
        loss_val = tensor(0.0).to(device)
        for x_batch, y_batch in val_data:
            x = x_batch.to(device, non_blocking=True)
            y = y_batch.to(device, non_blocking=True)
            loss_val += _val_step(x, y, model, loss_fn, device)
        
        epoch_train_loss = loss_train.item() / len(train_data)
        epoch_val_loss = loss_val.item() / len(val_data)
        
        if epoch == 1 or epoch_val_loss < min(val_losses.values()):
            best_checkpoint = ModelCheckpoint(
                epoch, epoch_train_loss, epoch_val_loss, model.state_dict().copy()
            )
        train_losses[epoch] = epoch_train_loss 
        val_losses[epoch] = epoch_val_loss 

        if _early_stop(val_losses):
            break 
    
    print("\nbest model:")
    print(f"|-- epoch: {best_checkpoint.epoch}")
    print(f"|-- validation loss: {best_checkpoint.val_loss:.4f}")

    model.load_state_dict(best_checkpoint.state_dict)
    return train_losses, val_losses, best_checkpoint

In [259]:
model = GPT(tokenizer.vocab_size, 4)

In [260]:
train_losses, val_losses, best_checkpoint = train(
    model=model,
    train_data=train_dataloader,
    val_data=val_dataloader,
    n_epochs=MAX_EPOCHS,
    learning_rate=LEARNING_RATE
)

epoch 1 training loss = 0.0273:  15%|█▌        | 98/650 [03:47<21:20,  2.32s/it]


KeyboardInterrupt: 

In [261]:
@torch.no_grad()
def generate(
    model: NextWordPredictionRNN,
    prompt: str,
    tokenizer: AutoTokenizer,
    strategy: Literal["greedy", "sample", "topk"] = "greedy",
    output_length: int = 60,
    temperature: float = 1.0,
    random_seed: int = 42,
    device: device = "cpu",
    *,
    k: int = 2,
) -> str:
    """Generate new text conditional on a text prompt."""
    torch.manual_seed(random_seed)

    model.to(device)
    model.eval()
    
    encoded_prompt = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    prompt_tokens = encoded_prompt["input_ids"].to(device).squeeze(0).tolist()

    token_sequence = tensor(prompt_tokens.copy()).unsqueeze(0)

    for _ in range(output_length):
        out = nn.functional.log_softmax(model(token_sequence, subsequent_mask(token_sequence.size(1))), dim=-1)
        prob = torch.argmax(out[:,-1]).item()
        token_sequence = torch.cat(
            [token_sequence, torch.empty(1, 1).type_as(token_sequence).fill_(prob)], dim=1
        )

    return tokenizer.decode(token_sequence.flatten(), skip_special_tokens=True)

In [262]:
generate(
    model=model,
    prompt="this is the worst timeline given the circumstances. You must",
    tokenizer=tokenizer,
    strategy="topk",
)

'this is the worst timeline given the circumstances. You must and the o o o o o o o o o o o o o o o o o o o o o o o o o o o o o o o o o o o o o o o o o o o o o o o o o o o o o o o o o o'