In [None]:
import torch
import train
# Select the 'train-api' namespace before any model is created.
# NOTE(review): presumably this scopes where checkpoints/logs are stored - confirm against the train module.
train.set_namespace('train-api')

# Placeholder data loaders: both are None in this notebook and are passed as-is
# to every cp.train() call below. Replace them with real loaders to actually train.
train_dataloader = None
val_dataloader = None

Insights:
- It should only be possible to call cp.train() at most once per block
- cp.train() is always what loads all object states - the model parameters as well as the optimizer and scheduler states
- A `with model.checkpoint(...)` block should only ever be executed once while writing the notebook. It should not be modified afterwards (unless we call delete)

In [None]:
def build_model():
    """Build the demo network: one linear layer mapping 5 input features to 2 outputs."""
    n_inputs, n_outputs = 5, 2
    return torch.nn.Linear(n_inputs, n_outputs)

# Wrap a freshly built network in a train.Model.
# NOTE(review): the two string arguments look like a model name and a free-text description - confirm against train.Model.
model = train.Model('model-001', build_model(), 'A demonstration of the API')

# Checkpoint 1: configure the loss and optimizer, then train with Adam.
with model.checkpoint(id=1, description='Adam') as cp:
    # This builds the objects and stores them on the model level, so later checkpoints can reuse them.
    cp.setup(
        loss_cls=torch.nn.MSELoss, 
        optimizer_cls=torch.optim.Adam,
        optimizer_args=dict(lr=5e-4, weight_decay=5e-4)
    )
    # If a checkpoint file exists, training will not occur. Instead, we simply load the results from the file.
    # Also, the post-training state of the optimizer and schedulers which were built by cp.setup() is loaded.
    cp.train(train_dataloader, val_dataloader, epochs=10, watch='val_accuracy', load_best=True, warmup=1, metrics=[train.metric_accuracy])
    # NOTE(review): presumably plots the metrics recorded for this checkpoint - confirm against the train module.
    cp.plot_metrics()

# Checkpoint 2: continue training without calling cp.setup() again. Per the notes in this
# notebook, cp.train() loads the stored model, optimizer and scheduler states, so this
# block carries on with the objects that were configured for checkpoint 1.
with model.checkpoint(id=2, description='Train some more') as cp:
    cp.train(train_dataloader, val_dataloader, epochs=10, watch='val_accuracy', load_best=True, metrics=[train.metric_accuracy])
    cp.plot_metrics()

# Checkpoint 3: switch the optimizer to SGD and keep training.
with model.checkpoint(id=3, description='SGD') as cp:
    # We don't have to provide all objects to cp.setup() every time. Only what we define will override what's currently stored in the model.
    # NOTE(review): no optimizer_args are passed here; whether the args stored for checkpoint 1
    # (lr, weight_decay) are reused for SGD is not visible in this notebook - confirm.
    cp.setup(optimizer_cls=torch.optim.SGD)
    # Here we use load_best=False, which means that the checkpoint will contain the parameters at the last training epoch.
    # However, if the best results were achieved prior to the last training epoch, a backup checkpoint is created.
    cp.train(train_dataloader, val_dataloader, epochs=10, watch='val_accuracy', load_best=False, metrics=[train.metric_accuracy])
    cp.plot_metrics()

# Checkpoint 3.1: continue from the state left by checkpoint 3. Because checkpoint 3 was
# trained with load_best=False, that state holds the last-epoch parameters (not the best ones).
with model.checkpoint(id='3.1', description='Continue SGD from last') as cp:
    cp.train(train_dataloader, val_dataloader, epochs=10, watch='val_accuracy', load_best=True, metrics=[train.metric_accuracy])
    cp.plot_metrics()

# Here we are reverting the model back to checkpoint 3, loading it from the backup checkpoint - i.e. the best results
# achieved during training, not the last results.
model.load_checkpoint(id=3, from_backup=True)

# Here we are effectively forking the training process, continuing from the best rather than the last epoch achieved during checkpoint 3.
# This block must run after the load_checkpoint() call above so that it starts from the restored state.
with model.checkpoint(id='3.2', description='Continue SGD from best') as cp:
    cp.train(train_dataloader, val_dataloader, epochs=10, watch='val_accuracy', load_best=True, metrics=[train.metric_accuracy])
    cp.plot_metrics()

# We delete checkpoint '3.1'. Presumably this means that if we rerun the notebook, the code under
# `with model.checkpoint(id='3.1', ...)` will have to rerun.
# NOTE(review): this comment previously described deleting checkpoint 3 and its backup, while the
# call below targets '3.1' - confirm which checkpoint is meant to be deleted.
model.delete_checkpoint('3.1')

# Delete all model checkpoints and tensorboard data
model.delete()
