# training-utils.ipynb

> Utilities related to model training

In [None]:
#| default_exp training_utils

In [None]:
#| export
from pathlib import Path
import tempfile
from typing import Dict, Generic, Protocol, Tuple, TypeVar

In [None]:
#| export
import torch
from tqdm.auto import tqdm

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| hide
from fastcore.test import *

In [None]:
#| export
class CheckPointer:
    "Save numbered checkpoints named `{filename_stem}_{num:06d}.pt` inside `output_dir`."
    def __init__(self, output_dir: Path, filename_stem: str, start_num: int = 0):
        # Coerce so callers may pass a plain `str`; `filename` relies on `Path.__truediv__`.
        self.output_dir = Path(output_dir)
        self.filename_stem = filename_stem
        # Index of the *next* checkpoint to be written; advanced after every save.
        self.num = start_num

    def filename(self) -> Path:
        "Path the next checkpoint will be written to (index zero-padded to 6 digits)."
        return self.output_dir / f'{self.filename_stem}_{self.num:06d}.pt'

    def save_checkpoint(
        self,
        iters: int,
        model: torch.nn.Module,
        train_loss: float,
        val_loss: float
    ) -> Path:
        """Write `model`'s state dict plus training metadata to `self.filename()`.

        The saved dict has keys `model_state_dict`, `iters`, `train_loss` and
        `val_loss`. The output directory (and any missing parents) is created on
        demand. The counter is then advanced so the next call writes a new file.
        Returns the path that was written."""
        path = self.filename()
        # `torch.save` fails if the parent directory is missing, so create it first.
        path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "iters": iters,
                "train_loss": train_loss,
                "val_loss": val_loss,
            },
            path,
        )
        self.num += 1
        return path

In [None]:
## Tests for CheckPointer
class TestModule(torch.nn.Module):
    "Tiny single-layer model used only to exercise `CheckPointer`."
    def __init__(self):
        super().__init__()
        self.some_param = torch.nn.Linear(10, 1)

    def forward(self, x):
        return self.some_param(x)

with tempfile.TemporaryDirectory() as tmpdirname:
    output_dir = Path(tmpdirname)
    checkpointer = CheckPointer(output_dir, 'test')

    test_model = TestModule()
    test_train_loss = 0.123
    test_val_loss = 0.125

    checkpointer.save_checkpoint(10, test_model, test_train_loss, test_val_loss)
    test_eq((output_dir / 'test_000000.pt').exists(), True)

    checkpoint = torch.load(str(output_dir / 'test_000000.pt'))
    test_eq(checkpoint['iters'], 10)
    test_eq(checkpoint['train_loss'], test_train_loss)
    test_eq(checkpoint['val_loss'], test_val_loss)
    test_eq('model_state_dict' in checkpoint, True)

    # The saved weights must round-trip into a freshly initialised model.
    restored_model = TestModule()
    restored_model.load_state_dict(checkpoint['model_state_dict'])
    restored_state = restored_model.state_dict()
    for param_name, original in test_model.state_dict().items():
        test_eq(torch.equal(restored_state[param_name], original), True)

    checkpointer.save_checkpoint(11, test_model, test_train_loss, test_val_loss)
    test_eq((output_dir / 'test_000001.pt').exists(), True)
    # Each save advances the counter by one.
    test_eq(checkpointer.num, 2)

    # Test start_num
    checkpointer = CheckPointer(output_dir, 'test', start_num=14)
    checkpointer.save_checkpoint(11, test_model, test_train_loss, test_val_loss)
    test_eq((output_dir / 'test_000014.pt').exists(), True)

In [None]:
#| hide
# Regenerate the Python modules from the `#| export` cells of this notebook.
import nbdev
nbdev.nbdev_export()