# Extending the TrainedModel Table to support centralized checkpointing

In [1]:
import os
import datajoint as dj

# Database credentials come from the environment so they never live in the notebook.
# A missing DJ_HOST / DJ_USER / DJ_PASS raises KeyError right away.
dj.config['database.host'] = os.environ['DJ_HOST']
dj.config['database.user'] = os.environ['DJ_USER']
dj.config['database.password'] = os.environ['DJ_PASS']
# Allow native Python objects (dicts, lists, ...) in blob attributes.
dj.config['enable_python_native_blobs'] = True

# Register an external S3-protocol store named "minio". Placeholder values are used
# when the MINIO_* variables are unset so the config is still well-formed.
# NOTE(review): presumably used by blob attributes of the tables below — confirm.
if "stores" not in dj.config:
    dj.config["stores"] = {}
dj.config["stores"]["minio"] = {
    "protocol": "s3",
    "endpoint": os.environ.get("MINIO_ENDPOINT", "DUMMY_ENDPOINT"),
    "bucket": "nnfabrik",
    "location": "dj-store",
    "access_key": os.environ.get("MINIO_ACCESS_KEY", "FAKEKEY"),
    "secret_key": os.environ.get("MINIO_SECRET_KEY", "FAKEKEY"),
}

# Imported only after configuring DataJoint: importing the example schema opens the
# database connection (see the "Connecting ..." output).
from nnfabrik.examples import nnfabrik
# Provides the example tables used below (Fabrikant, Seed, Dataset, Model, Trainer)
# and `schema`.
from nnfabrik.examples.nnfabrik import *

Connecting anix@134.2.168.16:3306


## (1) Replace the `TrainedModel` table

In [2]:
# Build a checkpointing-aware variant of the TrainedModel table in the example schema.
from nnfabrik.templates.checkpoint import TrainedModelChkptBase, my_checkpoint

# Create the Checkpoint table bound to our nnfabrik module; it stores per-epoch
# training states (epoch, score, state) for runs that have not finished yet.
Checkpoint = my_checkpoint(nnfabrik)

@nnfabrik.schema
class TrainedModelChkpt(TrainedModelChkptBase):
    # Table comment; presumably used in the generated table definition — confirm.
    table_comment = "My Trained models with checkpointing"
    # Module holding the upstream tables (Model, Dataset, Trainer, Seed, Fabrikant).
    nnfabrik = nnfabrik
    # Table the base class writes intermediate checkpoints to (and clears after training).
    checkpoint_table = Checkpoint


Connecting anix@134.2.168.16:3306


## (2) Setup Training

### Add some entries to the table (usual step)

In [3]:
# Register yourself as a contributor; skip_duplicates keeps the cell re-runnable.
fabrikant_info = {
    "fabrikant_name": "Your Name",
    "email": "your@email.com",
    "affiliation": "thelab",
    "dj_username": "yourname",
}
Fabrikant().insert1(fabrikant_info, skip_duplicates=True)

In [4]:
# Register the random seed. skip_duplicates makes the cell safe to re-run;
# a plain insert raises a duplicate-key error once seed 7 already exists.
Seed().insert([{'seed': 7}], skip_duplicates=True)

In [5]:
# Dataset loader, given as an importable dotted path.
dataset_fn = "nnfabrik.examples.mnist.dataset.mnist_dataset_fn"
# Only the loader's own options; inputs required by nnfabrik (e.g. the seed) are supplied by it.
dataset_config = {"batch_size": 64}

Dataset().add_entry(
    dataset_fn=dataset_fn,
    dataset_config=dataset_config,
    dataset_fabrikant="Your Name",
    dataset_comment="A comment about the dataset!",
);

In [6]:
# Model builder, given as an importable dotted path.
model_fn = "nnfabrik.examples.mnist.model.mnist_model_fn"
# Only the builder's own options; inputs required by nnfabrik are supplied by it.
model_config = {"h_dim": 5}

Model().add_entry(
    model_fn=model_fn,
    model_config=model_config,
    model_fabrikant="Your Name",
    model_comment="A comment about the model!",
);

### Use a trainer that supports checkpointing by calling the checkpointing callback
Only such a trainer makes use of the checkpointing feature; with any other trainer, `TrainedModelChkpt` behaves exactly like `TrainedModel`.

In [7]:
# Checkpoint-aware trainer, given as an importable dotted path.
trainer_fn = "nnfabrik.examples.mnist_checkpoint.trainer.chkpt_trainer_fn"
# Only the trainer's own options; inputs required by nnfabrik are supplied by it.
trainer_config = {"epochs": 3}

Trainer().add_entry(
    trainer_fn=trainer_fn,
    trainer_config=trainer_config,
    trainer_fabrikant="Your Name",
    trainer_comment="A comment about the trainer!",
);

## (3) Try out the checkpointing feature 
Populate `TrainedModelChkpt` and interrupt the training (e.g. via *Kernel → Interrupt*) some time after the first epoch has completed, but before training finishes.

In [8]:
# Start training. Interrupt the kernel after the first epoch completes to simulate a crash;
# the KeyboardInterrupt below is the intended outcome of this cell.
TrainedModelChkpt.populate()

  0%|          | 4/938 [00:00<00:28, 32.71it/s]

Epoch 0


100%|██████████| 938/938 [00:20<00:00, 46.74it/s]
  1%|          | 5/938 [00:00<00:21, 44.23it/s]

Nothing to delete
Epoch 1


  7%|▋         | 65/938 [00:01<00:19, 44.89it/s]


KeyboardInterrupt: 

### Now, let's check the tables.

In [9]:
# The interrupted run never finished, so no trained model has been inserted yet.
TrainedModelChkpt()

model_fn  name of the model function,model_hash  hash of the model configuration,dataset_fn  name of the dataset loader function,dataset_hash  hash of the configuration object,trainer_fn  name of the Trainer loader function,trainer_hash  hash of the configuration object,seed  Random seed that is passed to the model- and dataset-builder,comment  short description,score  loss,output  trainer object's output,fabrikant_name  Name of the contributor that added this entry,trainedmodel_ts  UTZ timestamp at time of insertion
,,,,,,,,,,,


In [10]:
# The state after the last completed epoch is stored here as a checkpoint.
Checkpoint()

trainer_fn  name of the Trainer loader function,trainer_hash  hash of the configuration object,dataset_fn  name of the dataset loader function,dataset_hash  hash of the configuration object,model_fn  name of the model function,model_hash  hash of the model configuration,seed  Random seed that is passed to the model- and dataset-builder,epoch  epoch of creation,score  current score at epoch,state  current state,fabrikant_name  Name of the contributor that added this entry,trainedmodel_ts  UTZ timestamp at time of insertion
nnfabrik.examples.mnist_checkpoint.trainer.chkpt_trainer_fn,91bc1788b17e9db4c5e16a286a35c0d0,nnfabrik.examples.mnist.dataset.mnist_dataset_fn,9aee736870714f8b7c3cc084087ce886,nnfabrik.examples.mnist.model.mnist_model_fn,24922759b843076328c4f3b9df3f88d0,7,0,73.7033,=BLOB=,,2020-11-12 01:16:04


As you can see, the current state of our training was saved in the `Checkpoint` table.

### Continue the interrupted training
Training resumes automatically from the latest checkpoint when `populate` is called again, once the error state of the corresponding job has been cleared.

In [11]:
# Clear any error record the interrupted populate left for *this* table, so the job
# can be picked up again. Restricting on table_name avoids wiping error records of
# unrelated tables that share the schema's jobs table.
(schema.jobs & {"status": "error", "table_name": TrainedModelChkpt.table_name}).delete()

In [12]:
# Re-run populate: training resumes from the stored checkpoint (it starts at epoch 1, not 0)
# and the intermediate checkpoints are deleted once training finishes.
TrainedModelChkpt.populate()

  1%|          | 5/938 [00:00<00:21, 43.80it/s]

Epoch 1


100%|██████████| 938/938 [00:19<00:00, 48.68it/s]
  1%|          | 5/938 [00:00<00:21, 43.56it/s]

Nothing to delete
Epoch 2


100%|██████████| 938/938 [00:19<00:00, 47.83it/s]


Deleting intermediate checkpoints...


In [13]:
# The completed run is now stored as a trained model.
TrainedModelChkpt()

model_fn  name of the model function,model_hash  hash of the model configuration,dataset_fn  name of the dataset loader function,dataset_hash  hash of the configuration object,trainer_fn  name of the Trainer loader function,trainer_hash  hash of the configuration object,seed  Random seed that is passed to the model- and dataset-builder,comment  short description,score  loss,output  trainer object's output,fabrikant_name  Name of the contributor that added this entry,trainedmodel_ts  UTZ timestamp at time of insertion
nnfabrik.examples.mnist.model.mnist_model_fn,24922759b843076328c4f3b9df3f88d0,nnfabrik.examples.mnist.dataset.mnist_dataset_fn,9aee736870714f8b7c3cc084087ce886,nnfabrik.examples.mnist_checkpoint.trainer.chkpt_trainer_fn,91bc1788b17e9db4c5e16a286a35c0d0,7,A comment about the trainer!.A comment about the model!.A comment about the dataset!,84.2067,=BLOB=,Arne Nix,2020-11-12 01:16:58


In [14]:
# No checkpoints remain for the completed run.
Checkpoint()

trainer_fn  name of the Trainer loader function,trainer_hash  hash of the configuration object,dataset_fn  name of the dataset loader function,dataset_hash  hash of the configuration object,model_fn  name of the model function,model_hash  hash of the model configuration,seed  Random seed that is passed to the model- and dataset-builder,epoch  epoch of creation,score  current score at epoch,state  current state,fabrikant_name  Name of the contributor that added this entry,trainedmodel_ts  UTZ timestamp at time of insertion
,,,,,,,,,,,


As you can see, the entries in `Checkpoint` corresponding to our training run were deleted once our training finished.