In [1]:
import os
import torch
import tempfile
from typing import Dict, Any

In [2]:
def save_checkpoint(
    checkpoint_path: str,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: Any = None,
    epoch: int = 0,
    best_metric: float = None,
    extra: Dict[str, Any] = None,
    *,
    use_atomic: bool = True
) -> None:
    """
    Save a training checkpoint (model, optimizer, RNG and bookkeeping state).

    The checkpoint is a dict with the keys "epoch", "model_state",
    "optim_state", "best_metric", "rng_state", "cuda_rng_state", "extra" and,
    only when a scheduler is passed, "scheduler_state".

    Args:
        checkpoint_path: Destination file. Missing parent directories are created.
        model: Model to snapshot. If it exposes a ``module`` attribute (as
            DataParallel / DDP wrappers do) the wrapped module's state_dict is
            saved instead of the wrapper's.
        optimizer: Optimizer whose state_dict is stored under "optim_state".
        scheduler: Optional LR scheduler; skipped when None.
        epoch: Epoch index stored under "epoch".
        best_metric: Best metric so far, stored as-is (may be None).
        extra: Arbitrary user payload; an empty dict is stored when None.
        use_atomic: When True (default) the data is written to a temporary file
            in the destination directory, synced to disk and then moved over
            ``checkpoint_path`` with ``os.replace``, so an interrupted save never
            leaves a partial file at ``checkpoint_path``. When False the file is
            written in place.

    Raises:
        OSError: If the directory cannot be created or the file cannot be written.
    """
    if extra is None:
        extra = {}

    # If model is wrapped (DataParallel / DDP), get underlying module
    source = model.module if hasattr(model, "module") else model
    ckpt = {
        "epoch": epoch,
        "model_state": source.state_dict(),
        "optim_state": optimizer.state_dict(),
        "best_metric": best_metric,
        "rng_state": torch.get_rng_state(),
        "cuda_rng_state": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
        "extra": extra,
    }
    if scheduler is not None:
        ckpt["scheduler_state"] = scheduler.state_dict()

    # A bare filename has an empty dirname; treat it as the current directory.
    dirpath = os.path.dirname(checkpoint_path) or "."
    os.makedirs(dirpath, exist_ok=True)

    if not use_atomic:
        torch.save(ckpt, checkpoint_path)
        return

    # The temp file must live in the destination directory: os.replace is only
    # atomic when source and target are on the same filesystem.
    # NOTE: tempfile creates the file readable/writable by the owner only, and
    # os.replace keeps those permissions on the final checkpoint.
    with tempfile.NamedTemporaryFile(dir=dirpath, delete=False) as tmp:
        tmp_path = tmp.name
    try:
        torch.save(ckpt, tmp_path)
        # Push the data to stable storage before the rename; otherwise a crash
        # right after os.replace could expose a truncated checkpoint.
        fd = os.open(tmp_path, os.O_RDWR)
        try:
            os.fsync(fd)
        finally:
            os.close(fd)
        os.replace(tmp_path, checkpoint_path)  # atomic on POSIX
    finally:
        # Only still present if saving or replacing failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

In [3]:
def load_check_point(
    checkpoint_path: str,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer = None,
    scheduler: Any = None,
    device: torch.device = None,
    strict: bool = True,
    load_optimizer: bool = True,
    load_scheduler: bool = True,
) -> Dict[str, Any]:
    """
    Load a checkpoint and restore model, optimizer, scheduler and RNG state.

    Args:
        checkpoint_path: File to read.
        model: Model to load weights into. If it exposes a ``module`` attribute
            (DataParallel / DDP wrappers) the weights go into the wrapped module.
        optimizer: Optional optimizer; restored when given, ``load_optimizer``
            is True and the checkpoint has an "optim_state" entry.
        scheduler: Optional scheduler; restored when given, ``load_scheduler``
            is True and the checkpoint has a "scheduler_state" entry.
        device: ``map_location`` for ``torch.load``. Defaults to CUDA when
            available, else CPU.
        strict: Forwarded to ``load_state_dict`` of the model.
        load_optimizer: Set False to skip the optimizer state.
        load_scheduler: Set False to skip the scheduler state.

    Returns:
        The loaded checkpoint dict, for further inspection by the caller.

    Raises:
        KeyError: If the checkpoint has no "model_state" entry.
        Errors from ``torch.load`` and from the model's ``load_state_dict``
        propagate. Optimizer and scheduler restore failures are only printed
        as warnings.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(security): torch.load is pickle-based; only load checkpoints from
    # sources you trust.
    ckpt = torch.load(checkpoint_path, map_location=device)

    # If model is wrapped, load into underlying module
    target_model = model.module if hasattr(model, "module") else model
    target_model.load_state_dict(ckpt["model_state"], strict=strict)

    if optimizer is not None and load_optimizer and "optim_state" in ckpt:
        try:
            optimizer.load_state_dict(ckpt["optim_state"])
        except Exception as e:
            # Optimizer state may be incompatible across versions; warn and continue
            print(f"Warning: failed to load optimizer state: {e}")

    if scheduler is not None and load_scheduler and "scheduler_state" in ckpt:
        try:
            scheduler.load_state_dict(ckpt["scheduler_state"])
        except Exception as e:
            print(f"Warning: failed to load scheduler state: {e}")

    # Restore RNG states if present. map_location may have moved the saved
    # state tensors onto the GPU, but the RNG setters require CPU ByteTensors,
    # so move them back explicitly.
    rng_state = ckpt.get("rng_state")
    if rng_state is not None:
        torch.set_rng_state(rng_state.cpu())
    cuda_rng_state = ckpt.get("cuda_rng_state")
    if torch.cuda.is_available() and cuda_rng_state is not None:
        torch.cuda.set_rng_state_all([state.cpu() for state in cuda_rng_state])

    return ckpt


# Alias that matches the naming of ``save_checkpoint``; the original
# ``load_check_point`` name is kept for existing callers.
load_checkpoint = load_check_point
    

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

In [5]:
# Dummy model for illustration (the dummy data is generated inside train_example)
class SimpleNet(nn.Module):
    """Small MLP used for the checkpointing demo: Linear -> ReLU -> Linear.

    Args:
        in_dim: Size of each input feature vector.
        out_dim: Number of output values per sample.
        hidden_dim: Width of the hidden layer (defaults to 64, the value that
            was previously hard-coded).
    """

    def __init__(self, in_dim=10, out_dim=2, hidden_dim=64):
        super().__init__()
        # The attribute name ``net`` and the layer order determine the
        # state_dict keys (net.0.*, net.2.*); keep them stable so existing
        # checkpoints still load.
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        """Apply the network to ``x`` (last dimension ``in_dim``) and return the result."""
        return self.net(x)

In [6]:
def train_example(checkpoint_path="checkpoints/ckpt.pt", resume=False, num_epochs=20):
    """
    Train ``SimpleNet`` on random data and write a checkpoint after every epoch.

    Args:
        checkpoint_path: File the checkpoint is written to (and read from when
            resuming). A bare filename writes into the current directory.
        resume: When True and ``checkpoint_path`` exists, restore model,
            optimizer, scheduler and RNG state from it and continue with the
            epoch after the stored one.
        num_epochs: Total number of epochs (defaults to 20, the value that was
            previously hard-coded). Training runs from the start epoch up to,
            but not including, this value.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SimpleNet().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
    # Build the loss once instead of once per batch.
    criterion = nn.CrossEntropyLoss()

    start_epoch = 0
    best_val = float("inf")

    if resume and os.path.exists(checkpoint_path):
        # The loader in this notebook is defined as ``load_check_point``.
        ckpt = load_check_point(checkpoint_path, model, optimizer, scheduler, device=device)
        start_epoch = ckpt.get("epoch", 0) + 1
        # A checkpoint may store best_metric=None; keep the +inf default then,
        # otherwise the "<" comparison below would raise TypeError.
        saved_best = ckpt.get("best_metric")
        if saved_best is not None:
            best_val = saved_best
        print(f"Resumed from epoch {start_epoch}, best_val={best_val}")

    # Dummy dataset
    x = torch.randn(1000, 10)
    y = torch.randint(0, 2, (1000,))
    ds = TensorDataset(x, y)
    loader = DataLoader(ds, batch_size=32, shuffle=True)

    # Create the target directory once. dirname is "" for a bare filename, and
    # os.makedirs("") raises, so skip it in that case.
    ckpt_dir = os.path.dirname(checkpoint_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    for epoch in range(start_epoch, num_epochs):
        model.train()
        total_loss = 0.0
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            logits = model(xb)
            loss = criterion(logits, yb)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        scheduler.step()

        val_metric = total_loss / len(loader)  # placeholder for real val metric
        print(f"Epoch {epoch} loss {val_metric:.4f}")

        # Track the best (lowest) metric seen so far; it is stored in the checkpoint.
        if val_metric < best_val:
            best_val = val_metric

        # Save a checkpoint every epoch.
        save_checkpoint(
            checkpoint_path,
            model,
            optimizer,
            scheduler=scheduler,
            epoch=epoch,
            best_metric=best_val,
            extra={"notes": "example run"}
        )

In [7]:
# Fresh run (no resume); a checkpoint is written to checkpoints/ckpt.pt after each epoch.
train_example(
    checkpoint_path="checkpoints/ckpt.pt",
    resume=False,
)

<torch.utils.data.dataloader.DataLoader object at 0x110ba42f0>
Epoch 0 loss 0.7088
Epoch 1 loss 0.6899
Epoch 2 loss 0.6870
Epoch 3 loss 0.6824
Epoch 4 loss 0.6807
Epoch 5 loss 0.6784
Epoch 6 loss 0.6789
Epoch 7 loss 0.6771
Epoch 8 loss 0.6760
Epoch 9 loss 0.6750
Epoch 10 loss 0.6755
Epoch 11 loss 0.6729
Epoch 12 loss 0.6728
Epoch 13 loss 0.6723
Epoch 14 loss 0.6726
Epoch 15 loss 0.6707
Epoch 16 loss 0.6719
Epoch 17 loss 0.6709
Epoch 18 loss 0.6725
Epoch 19 loss 0.6714
