# training-utils.ipynb

> Utilities related to model training

In [None]:
#| default_exp training_utils

In [None]:
#| export
from pathlib import Path
import tempfile
from typing import Dict, Generic, List, Protocol, Tuple, TypeVar

In [None]:
#| export
import torch
from tqdm.auto import tqdm

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| hide
from fastcore.test import *

In [None]:
#| export
class CheckPointer:
    """Write sequentially numbered training checkpoints into one directory.

    Every checkpoint is stored as ``<output_dir>/<filename_stem>_<num>.pt``,
    where ``num`` is zero-padded to six digits and is incremented after each
    successful save.
    """

    def __init__(self, output_dir: Path, filename_stem: str, start_num: int = 0):
        """Configure the destination and the number of the first checkpoint.

        Args:
            output_dir: Directory that receives the checkpoint files. A plain
                string is accepted too and converted to a ``Path``.
            filename_stem: Prefix shared by all checkpoint file names.
            start_num: Number used for the first checkpoint that gets saved.
        """
        # Coerce so that `filename()` can always rely on the `/` operator,
        # even when the caller handed in a string.
        self.output_dir = Path(output_dir)
        self.filename_stem = filename_stem
        self.num = start_num

    def filename(self) -> Path:
        """Return the path that the next `save_checkpoint` call will write."""
        return self.output_dir / f'{self.filename_stem}_{self.num:06d}.pt'

    def save_checkpoint(
        self,
        iters: int,
        model: torch.nn.Module,
        train_loss: float,
        val_loss: float
    ) -> Path:
        """Serialise the model state and training metadata with `torch.save`.

        The saved object is a dict with the keys ``model_state_dict``,
        ``iters``, ``train_loss`` and ``val_loss``.

        Args:
            iters: Number of iterations trained so far.
            model: Module whose ``state_dict()`` is stored.
            train_loss: Training loss to record alongside the weights.
            val_loss: Validation loss to record alongside the weights.

        Returns:
            The path of the file that was written.
        """
        filename = self.filename()
        # Make sure the destination exists; saving into a missing directory
        # would otherwise fail.
        filename.parent.mkdir(parents=True, exist_ok=True)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "iters": iters,
                "train_loss": train_loss,
                "val_loss": val_loss,
            },
            filename,
        )
        # Only advance the counter once the file has been written.
        self.num += 1
        return filename

In [None]:
## Tests for CheckPointer
class TestModule(torch.nn.Module):
    """Minimal one-layer model used as a fixture for the checkpoint tests."""

    def __init__(self):
        super().__init__()
        # A single 10 -> 1 linear layer gives the state dict some content.
        self.some_param = torch.nn.Linear(in_features=10, out_features=1)

    def forward(self, x):
        """Apply the linear layer to `x` and return the result."""
        projected = self.some_param(x)
        return projected

# Exercise CheckPointer against a throwaway directory that is removed afterwards.
with tempfile.TemporaryDirectory() as tmpdirname:
    output_dir = Path(tmpdirname)
    test_model = TestModule()
    test_train_loss, test_val_loss = 0.123, 0.125
    checkpointer = CheckPointer(output_dir, 'test')

    # The first save uses number 0 and must produce a file on disk.
    filename = checkpointer.save_checkpoint(10, test_model, test_train_loss, test_val_loss)
    test_eq(filename, output_dir / 'test_000000.pt')
    test_eq(filename.exists(), True)

    # The stored payload carries the metadata that was passed in.
    checkpoint = torch.load(str(output_dir / 'test_000000.pt'))
    expected_metadata = {
        'iters': 10,
        'train_loss': test_train_loss,
        'val_loss': test_val_loss,
    }
    for key, expected_value in expected_metadata.items():
        test_eq(checkpoint[key], expected_value)
    test_eq('model_state_dict' in checkpoint, True)

    # A second save moves on to the next number.
    filename = checkpointer.save_checkpoint(11, test_model, test_train_loss, test_val_loss)
    test_eq(filename, output_dir / 'test_000001.pt')
    test_eq(filename.exists(), True)

    # Test start_num
    checkpointer = CheckPointer(output_dir, 'test', start_num=14)
    filename = checkpointer.save_checkpoint(11, test_model, test_train_loss, test_val_loss)
    test_eq(filename, output_dir / 'test_000014.pt')
    test_eq(filename.exists(), True)

In [None]:
# | export
# Type of the model handled by Trainer; bounded to torch.nn.Module.
# Declared contravariant because EstimateLossFunction only ever takes it as a
# call argument.
TModel = TypeVar("TModel", bound=torch.nn.Module, contravariant=True)


class EstimateLossFunction(Protocol[TModel]):
    """Callable that evaluates a model and reports its losses by name."""

    def __call__(self, model: TModel) -> Dict[str, float]:
        """Return a mapping from loss name to loss value for `model`."""


class GetBatchFunction(Protocol):
    """Callable that produces one batch for the named data split."""

    def __call__(self, split: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return a pair of tensors for `split`."""


class OnBatchTrainedHandler(Protocol):
    """Callback signature for handlers notified after a batch was trained."""

    def __call__(self, iters_trained: int, batch: torch.Tensor) -> None:
        """Receive the iteration count and the batch tensor."""


class OnCheckpointSavedHandler(Protocol):
    """Callback signature for handlers notified after a checkpoint was saved."""

    def __call__(self, iters_trained: int, checkpoint_file: Path) -> None:
        """Receive the iteration count and the path of the checkpoint file."""


class Trainer(Generic[TModel]):
    """Run optimisation steps on a model with periodic evaluation and checkpoints.

    The trainer keeps a running count of trained iterations in
    `iters_trained`. Registered handlers are notified after every trained
    batch and after every checkpoint that is written.
    """

    def __init__(
        self,
        model: TModel,
        checkpointer: CheckPointer,
        get_batch_func: GetBatchFunction,
        estimate_loss_func: EstimateLossFunction[TModel],
        iters_trained: int = 0,
    ):
        """Store the collaborators and start with empty handler lists.

        Args:
            model: Module to train. It is called as ``model(xb, yb)`` and must
                return a pair whose second item is the loss.
            checkpointer: Used to write a checkpoint after each evaluation.
            get_batch_func: Called with ``split="train"`` to fetch each batch.
            estimate_loss_func: Called with the model at every evaluation; its
                result must provide the ``"train"`` and ``"val"`` entries.
            iters_trained: Iterations already trained before this trainer
                starts counting (for example when resuming).
        """
        self.model = model
        self.checkpointer = checkpointer
        self.get_batch_func = get_batch_func
        self.estimate_loss_func = estimate_loss_func
        self.iters_trained = iters_trained
        self.on_batch_trained_handlers: List[OnBatchTrainedHandler] = []
        self.on_checkpoint_saved_handlers: List[OnCheckpointSavedHandler] = []

    def add_on_batch_trained_handler(self, handler: OnBatchTrainedHandler) -> None:
        """Register `handler` to be called after every trained batch."""
        self.on_batch_trained_handlers.append(handler)

    def add_on_checkpoint_saved_handler(self, handler: OnCheckpointSavedHandler) -> None:
        """Register `handler` to be called after every saved checkpoint."""
        self.on_checkpoint_saved_handlers.append(handler)

    def fire_on_batch_trained(self, batch: torch.Tensor) -> None:
        """Call each batch handler, in registration order, with the current count and `batch`."""
        for handler in self.on_batch_trained_handlers:
            handler(self.iters_trained, batch)

    def fire_on_checkpoint_saved(self, checkpoint_file: Path) -> None:
        """Call each checkpoint handler, in registration order, with the current count and the file."""
        for handler in self.on_checkpoint_saved_handlers:
            handler(self.iters_trained, checkpoint_file)

    def train(
        self,
        n_iters: int,
        optimizer: torch.optim.Optimizer,
        eval_interval: int = 500,
        disable_progress_bar: bool = False,
        disable_output: bool = False,
    ) -> None:
        """Train for `n_iters` steps, evaluating and checkpointing on a schedule.

        After each optimiser step `iters_trained` is incremented and the batch
        handlers are fired with the input tensor of the batch. Whenever
        `iters_trained` is a multiple of `eval_interval`, the losses are
        estimated, optionally printed, written to a checkpoint, and the
        checkpoint handlers are fired.

        Args:
            n_iters: Number of optimisation steps to run in this call.
            optimizer: Optimiser that updates the model parameters.
            eval_interval: Evaluate and checkpoint every this many iterations,
                counted on `iters_trained`.
            disable_progress_bar: Turn off the tqdm progress bar.
            disable_output: Suppress the printed loss line.
        """
        self.model.train()
        for _ in tqdm(range(n_iters), disable=disable_progress_bar):
            xb, yb = self.get_batch_func(split="train")

            _, loss = self.model(xb, yb)
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            self.iters_trained += 1
            self.fire_on_batch_trained(xb)

            if self.iters_trained % eval_interval == 0:
                losses = self.estimate_loss_func(self.model)
                # The evaluation callback may have switched the model to eval
                # mode; make sure the following steps train in train mode.
                self.model.train()
                if not disable_output:
                    # Report the same iteration count that triggered the
                    # evaluation and is stored in the checkpoint.
                    print(
                        f"step {self.iters_trained}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
                    )
                checkpoint_filename = self.checkpointer.save_checkpoint(
                    self.iters_trained,
                    self.model,
                    losses["train"],
                    losses["val"],
                )
                self.fire_on_checkpoint_saved(checkpoint_filename)

In [None]:
# Tests for Trainer
class TestModule(torch.nn.Module):
    """Stand-in model whose forward pass yields an output and a random loss."""

    def __init__(self):
        super().__init__()
        self.some_param = torch.nn.Linear(in_features=10, out_features=1)

    def forward(self, x, y):
        """Return the linear output for `x` plus a random (1, 1) loss; `y` is unused."""
        prediction = self.some_param(x)
        # The loss is random but requires grad so `backward()` can be called on it.
        fake_loss = torch.randn(1, 1, requires_grad=True)
        return prediction, fake_loss


# Drive a Trainer for ten steps and check everything its handlers receive.
with tempfile.TemporaryDirectory() as tmpdirname:
    output_dir = Path(tmpdirname)
    checkpointer = CheckPointer(output_dir, "test")
    test_model = TestModule()

    # Every batch handed to the trainer is remembered here so it can be
    # compared with what the batch handler receives.
    xbs = []

    def get_batch_func(split: str):
        xb, yb = torch.randn(1, 10), torch.randn(1, 1)
        xbs.append(xb)
        return xb, yb

    def estimate_loss_func(model: TestModule):
        return dict(train=0.123, val=0.125)

    # Start from a non-zero count to check that the trainer continues from it.
    iters_trained_start = 10
    trainer = Trainer(
        test_model,
        checkpointer,
        get_batch_func,
        estimate_loss_func,
        iters_trained=iters_trained_start,
    )

    on_batch_trained_data = []

    def record_batch(iters_trained, batch):
        on_batch_trained_data.append((iters_trained, batch))

    trainer.add_on_batch_trained_handler(record_batch)

    on_checkpoint_saved_data = []

    def record_checkpoint(iters_trained, checkpoint_file):
        on_checkpoint_saved_data.append((iters_trained, checkpoint_file))

    trainer.add_on_checkpoint_saved_handler(record_checkpoint)

    trainer.train(
        n_iters=10,
        optimizer=torch.optim.Adam(test_model.parameters()),
        eval_interval=5,
        disable_progress_bar=True,
        disable_output=True,
    )

    # Ten steps with an interval of five: ten batch events, two checkpoints.
    test_eq(len(on_batch_trained_data), 10)
    test_eq(len(on_checkpoint_saved_data), 2)

    for step, (iters_trained, batch) in enumerate(on_batch_trained_data, start=1):
        test_eq(iters_trained, iters_trained_start + step)
        test_eq(batch.shape, (1, 10))
        test_eq(batch, xbs[step - 1])

    for i, (iters_trained, checkpoint_file) in enumerate(on_checkpoint_saved_data):
        test_eq(iters_trained, iters_trained_start + 5 * (i + 1))
        test_eq(checkpoint_file, output_dir / f"test_{i:06d}.pt")

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()