# In [None]:
import os

import torch

class CheckpointManager:
    """Save and restore a training checkpoint (model, optimizer, epoch, val loss).

    Args:
        model: Object whose ``state_dict()`` / ``load_state_dict()`` are used
            to save and restore its state.
        optimizer: Object whose ``state_dict()`` / ``load_state_dict()`` are
            used to save and restore its state.
        checkpoint_path: Destination passed to ``torch.save`` / ``torch.load``;
            a filesystem path (str or path-like) or a file-like object.
    """

    def __init__(self, model, optimizer, checkpoint_path='checkpoint.pth'):
        self.model = model
        self.optimizer = optimizer
        self.checkpoint_path = checkpoint_path

    def _write_checkpoint(self, checkpoint):
        """Write *checkpoint* to ``checkpoint_path``, atomically when it is a path."""
        path = self.checkpoint_path
        if not isinstance(path, (str, bytes, os.PathLike)):
            # File-like target (torch.save accepts these): no rename possible.
            torch.save(checkpoint, path)
            return
        final_path = os.fsdecode(path)
        tmp_path = final_path + '.tmp'
        try:
            torch.save(checkpoint, tmp_path)
            # os.replace is atomic on the same filesystem, so readers see either
            # the old complete checkpoint or the new complete one, never a
            # partially written file.
            os.replace(tmp_path, final_path)
        except BaseException:
            # Best-effort cleanup of the partial temp file; the original error
            # is always re-raised.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def save(self, epoch, val_loss):
        """Save model, optimizer state, current epoch, and validation loss.

        When ``checkpoint_path`` is a filesystem path, the data is first
        written to ``<checkpoint_path>.tmp`` and then moved into place, so an
        interrupted save does not destroy the previous checkpoint.

        Args:
            epoch: Epoch number stored under the ``'epoch'`` key.
            val_loss: Validation loss stored under the ``'val_loss'`` key.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'val_loss': val_loss,
        }
        self._write_checkpoint(checkpoint)
        print(f'Checkpoint saved at epoch {epoch} with val_loss {val_loss}')

    def load(self, map_location=None):
        """Load checkpoint and restore model & optimizer states.

        Args:
            map_location: Forwarded to ``torch.load`` to remap tensor storages,
                e.g. ``'cpu'`` to load a checkpoint saved from GPU tensors on a
                CPU-only machine. ``None`` (default) keeps ``torch.load``'s
                default behavior.

        Returns:
            tuple: ``(epoch, val_loss)`` as stored by :meth:`save`.

        Raises:
            FileNotFoundError: If ``checkpoint_path`` is a path that does not exist.
            KeyError: If the loaded checkpoint lacks one of the expected keys.
        """
        # torch.load unpickles the file; only load checkpoints from trusted sources.
        checkpoint = torch.load(self.checkpoint_path, map_location=map_location)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        val_loss = checkpoint['val_loss']
        print(f'Checkpoint loaded. Resuming from epoch {epoch} with val_loss {val_loss}')
        return epoch, val_loss