In [None]:
import torch
from datetime import datetime
import os

### Save and load models


In [None]:
def print_state_dict(model, title=None):
    """Print each entry of ``model.state_dict()`` as ``name:`` followed by its value.

    Args:
        model: Any object exposing a ``state_dict()`` method; this notebook passes
            both ``torch.nn.Module`` and ``torch.optim.Optimizer`` instances.
        title: Optional banner text printed first; skipped when falsy (None or "").
    """
    if title:
        print(f"------------ {title} ------------")
    entries = model.state_dict()
    for key in entries:
        # Keep the f-string: for 0-dim tensors f"{t}" differs from str(t).
        print(f"{key}:\n{entries[key]}\n")

In [None]:
# Seed PyTorch's RNG so the randomly initialized weights (and every state_dict
# printed in this notebook) are reproducible under Restart Kernel -> Run All.
torch.manual_seed(0)

# Create a simple linear model: input size 3, output size 3
model = torch.nn.Linear(3, 3)

# Set up Adam optimizer for the model's parameters
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

print_state_dict(model, "Model")
print_state_dict(optimizer, "Optimizer")

In [None]:
# Bundle both state_dicts into a single dictionary so one file holds
# everything needed to restore the model and its optimizer.
checkpoint = dict(
    model_state_dict=model.state_dict(),
    optimizer_state_dict=optimizer.state_dict(),
)

# Write the bundle to 'cp_test.pth' in the current working directory
torch.save(checkpoint, "cp_test.pth")

In [None]:
# Initialize a new model and optimizer whose values differ from the checkpoint.
# The optimizer must wrap model_new's parameters: passing model.parameters() here
# would make optimizer_new update the original `model` instead of `model_new`.
# lr=1 (vs. the saved lr=0.01) makes it visible that loading restores the optimizer.
model_new = torch.nn.Linear(3, 3)
optimizer_new = torch.optim.Adam(model_new.parameters(), lr=1)
assert optimizer_new.param_groups[0]["params"][0] is model_new.weight
print_state_dict(model_new, "Model")
print_state_dict(optimizer_new, "Optimizer")


In [None]:
# Load the checkpoint (a dictionary containing model and optimizer state_dicts)
# Load the checkpoint (a dictionary containing model and optimizer state_dicts).
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered .pth file cannot execute arbitrary code while being loaded.
checkpoint_loaded = torch.load("cp_test.pth", weights_only=True)

# Restore the new model's parameters from the checkpoint
model_new.load_state_dict(checkpoint_loaded["model_state_dict"])

# Restore the new optimizer's state from the checkpoint
optimizer_new.load_state_dict(checkpoint_loaded["optimizer_state_dict"])

# Sanity checks: the restored model/optimizer must match what was saved.
assert all(
    torch.equal(model_new.state_dict()[name], saved)
    for name, saved in checkpoint_loaded["model_state_dict"].items()
)
assert (
    optimizer_new.param_groups[0]["lr"]
    == checkpoint_loaded["optimizer_state_dict"]["param_groups"][0]["lr"]
)

print_state_dict(model_new, "Model")
print_state_dict(optimizer_new, "Optimizer")

### Reusable checkpoint helper class


In [None]:
class CheckpointHandler:
    """Static helpers for saving, loading and locating PyTorch training checkpoints.

    A checkpoint written by ``save`` is a dict with the keys ``"epoch"``,
    ``"model_state_dict"``, ``"optimizer_state_dict"`` and ``"val_loss"``.
    """

    @staticmethod
    def list_saved_files(root=".", extensions=(".pth", ".pt")):
        """Print (and return) every checkpoint file found under ``root``, recursively.

        Args:
            root: Directory to search.
            extensions: File-name suffix (str) or iterable of suffixes identifying
                checkpoint files. The leading dot is part of the match, so names
                such as ``"depth"`` or ``"script"`` are not mistaken for checkpoints.

        Returns:
            List of matching paths using forward slashes, in ``os.walk`` order.
        """
        suffixes = (extensions,) if isinstance(extensions, str) else tuple(extensions)
        found = []
        # Distinct loop name so the ``root`` argument is not shadowed.
        for dirpath, _dirnames, filenames in os.walk(root):
            for filename in filenames:
                if filename.endswith(suffixes):
                    # Normalize Windows separators so output looks the same on every OS.
                    filepath = os.path.join(dirpath, filename).replace('\\', '/')
                    print(filepath)
                    found.append(filepath)
        return found

    @staticmethod
    def make_dir(folder_path):
        """Create ``folder_path`` (and missing parents); no error if it already exists."""
        os.makedirs(folder_path, exist_ok=True)

    @staticmethod
    def get_dt(fmt="%Y-%m-%d_%H-%M"):
        """Return the current local time formatted with ``fmt``.

        The default is filename-safe but has minute resolution: two calls within the
        same minute return the same string, so checkpoints named from it would
        overwrite each other. Pass e.g. ``"%Y-%m-%d_%H-%M-%S"`` if that matters.
        """
        return datetime.now().strftime(fmt)

    @staticmethod
    def save(save_path, model, optimizer=None, epoch=None, val_loss=None):
        """Serialize model (and optionally optimizer) state plus metadata to ``save_path``.

        Args:
            save_path: Destination file path passed to ``torch.save``.
            model: Object whose ``state_dict()`` is stored under ``"model_state_dict"``.
            optimizer: Optional optimizer; when None, ``"optimizer_state_dict"`` is None.
            epoch: Optional epoch number stored as metadata.
            val_loss: Optional validation loss stored as metadata.
        """
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            # optimizer defaults to None, so guard instead of calling .state_dict() on None.
            "optimizer_state_dict": optimizer.state_dict() if optimizer is not None else None,
            "val_loss": val_loss,
        }
        torch.save(checkpoint, save_path)

    @staticmethod
    def load(save_path, model, optimizer=None, map_location=None):
        """Restore a checkpoint into ``model`` (and ``optimizer``) in place.

        Args:
            save_path: Checkpoint file to read.
            model: Object receiving ``checkpoint["model_state_dict"]``.
            optimizer: Optional optimizer receiving ``checkpoint["optimizer_state_dict"]``.
            map_location: Forwarded to ``torch.load`` (e.g. ``"cpu"``).

        Returns:
            Tuple ``(model, optimizer, epoch, val_loss)``; ``epoch`` and ``val_loss``
            are None when the checkpoint does not record them.

        Raises:
            KeyError: if ``optimizer`` is given but the checkpoint holds no optimizer state.
        """
        # weights_only=True limits unpickling to tensors and plain containers, so a
        # file cannot run arbitrary code on load. Checkpoints from ``save`` qualify as
        # long as epoch/val_loss are plain Python numbers, None, or tensors.
        checkpoint = torch.load(save_path, map_location=map_location, weights_only=True)
        model.load_state_dict(checkpoint["model_state_dict"])
        if optimizer is not None:
            optimizer_state = checkpoint.get("optimizer_state_dict")
            if optimizer_state is None:
                raise KeyError(f"{save_path!r} contains no optimizer state to restore")
            optimizer.load_state_dict(optimizer_state)
        # .get() so checkpoints saved with only model/optimizer state still load.
        return model, optimizer, checkpoint.get("epoch"), checkpoint.get("val_loss")

In [None]:
model = torch.nn.Linear(3, 3)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
print_state_dict(model, "Model")
print_state_dict(optimizer, "Optimizer")

# CheckpointHandler exposes only static methods; the instance is just a short handle.
cph = CheckpointHandler()
cph.make_dir("./checkpoints")
dt = cph.get_dt()

# File name is the get_dt() timestamp, so saves in different minutes get different files.
save_path = f"./checkpoints/{dt}.pth"
cph.save(save_path, model, optimizer=optimizer)

In [None]:
# Re-create model/optimizer so they start from fresh random weights;
# the next cell restores them from the checkpoint saved above.
model = torch.nn.Linear(in_features=3, out_features=3)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
print_state_dict(model, "Model")
print_state_dict(optimizer, "Optimizer")

In [None]:
# Show which checkpoint files exist, then load the one saved above into the
# re-created model/optimizer (load_state_dict updates them in place).
cph.list_saved_files()
cph.load(save_path, model, optimizer)

print_state_dict(model, "Model")
print_state_dict(optimizer, "Optimizer")