In [None]:
import torch
from datetime import datetime
import os


class CheckpointHandler:
    """Write and restore a training checkpoint stored in a single file.

    The checkpoint is a dict with the keys ``"epoch"``, ``"model_state_dict"``,
    ``"optimizer_state_dict"`` and ``"val_loss"``, serialized with
    ``torch.save``.

    Attributes:
        checkpoint_file: File name of the checkpoint.
        checkpoint_path: ``folder_path`` joined with ``checkpoint_file``.
        model: Model passed to the most recent ``save()`` call, or ``None``.
        optimizer: Optimizer passed to the most recent ``save()`` call, or
            ``None``.
    """

    def __init__(self, folder_path, filename=None):
        """Create the checkpoint folder and fix the checkpoint file path.

        Args:
            folder_path: Directory that holds the checkpoint file. It is
                created if it does not exist yet.
            filename: Name of the checkpoint file. When omitted, a name of the
                form ``cp_<YYYY-MM-DD_HH-MM>.pth`` is built from the current
                local time.
        """
        if filename is None:
            dt = datetime.now().strftime("%Y-%m-%d_%H-%M")
            filename = f"cp_{dt}.pth"
        # An explicit filename is used as given instead of being ignored.
        self.checkpoint_file = filename
        os.makedirs(folder_path, exist_ok=True)
        self.checkpoint_path = os.path.join(folder_path, self.checkpoint_file)
        # Filled in by save(); used by load() when it is called without arguments.
        self.model = None
        self.optimizer = None

    def save(self, model, optimizer, epoch, val_loss):
        """Save model, optimizer state, current epoch, and validation loss.

        Any file already present at ``checkpoint_path`` is overwritten. The
        model and optimizer are also remembered on the handler so that a later
        ``load()`` without arguments restores into the same objects.

        Args:
            model: Object exposing ``state_dict()``.
            optimizer: Object exposing ``state_dict()``.
            epoch: Epoch value stored under ``"epoch"``.
            val_loss: Validation loss stored under ``"val_loss"``.
        """
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "val_loss": val_loss,
        }
        torch.save(checkpoint, self.checkpoint_path)
        self.model = model
        self.optimizer = optimizer
        print(f"Checkpoint saved at epoch {epoch} with val_loss {val_loss}")

    def load(self, model=None, optimizer=None):
        """Load checkpoint and restore model & optimizer states.

        Args:
            model: Object whose ``load_state_dict`` receives the saved model
                state. Defaults to the handler's ``model`` attribute.
            optimizer: Object whose ``load_state_dict`` receives the saved
                optimizer state. Defaults to the handler's ``optimizer``
                attribute.

        Returns:
            Tuple ``(model, optimizer, epoch, val_loss)``: the two restored
            objects followed by the stored epoch and validation loss.

        Raises:
            ValueError: If no model or optimizer was passed and none is
                available on the handler.
        """
        if model is None:
            model = self.model
        if optimizer is None:
            optimizer = self.optimizer
        if model is None or optimizer is None:
            raise ValueError(
                "load() needs a model and an optimizer: pass them in or call save() first"
            )

        checkpoint = torch.load(self.checkpoint_path)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        epoch = checkpoint["epoch"]
        val_loss = checkpoint["val_loss"]
        print(
            f"Checkpoint loaded. Resuming from epoch {epoch} with val_loss {val_loss}"
        )
        return model, optimizer, epoch, val_loss

In [12]:
# Small linear layer plus an Adam optimizer, used to inspect what a state dict holds.
model = torch.nn.Linear(3, 3)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Print every entry of the model's state dict, then of the optimizer's.
for stateful in (model, optimizer):
    state_dict = stateful.state_dict()
    for key, entry in state_dict.items():
        print(key, entry)

weight tensor([[-0.2786,  0.4071, -0.4067],
        [ 0.0814,  0.2762,  0.1213],
        [ 0.2234,  0.0923, -0.1042]])
bias tensor([ 0.2175,  0.3831, -0.0288])
state {}
param_groups [{'lr': 0.01, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'decoupled_weight_decay': False, 'params': [0, 1]}]


In [4]:
# The handler is built from the target folder only; the model and optimizer are
# handed to save() each time a checkpoint is written.
cp_handler = CheckpointHandler(folder_path="cp")
cp_handler.save(model, optimizer, epoch=10, val_loss=0.1)

Checkpoint saved at epoch 10 with val_loss 0.1


In [None]:
# load() without arguments restores into the handler's `model` / `optimizer`
# attributes, so attach the objects to restore into before calling it.
cp_handler.model, cp_handler.optimizer = model, optimizer
model, optimizer, epoch, val_loss = cp_handler.load()