In [1]:
#meta:tag=hide
%env METAFLOW_PROFILE=dev-valay
%env METAFLOW_UI_URL=


env: METAFLOW_PROFILE=dev-valay
env: METAFLOW_UI_URL=


In [2]:
#meta:tag=hide
import os
os.makedirs("temp_files", exist_ok=True)


# `@checkpoint`


<!-- START doctoc -->
<!-- END doctoc -->


## Introduction 

Metaflow naturally allows passing state between `@step`s through [Data Artifacts](https://docs.metaflow.org/metaflow/client#accessing-data) (i.e. values assigned to `self`). When the user code in a `@step` completes, every value assigned to `self` is automatically saved as a Data Artifact and becomes accessible to subsequent `@step`s. This persistence of state allows users to ["resume" code execution](https://docs.metaflow.org/metaflow/debugging#how-to-use-the-resume-command) from a particular `@step` if the user code crashes.

However, there may be scenarios where a `@step` runs for a very long time and has no natural boundaries at which to break it up. In such cases, the user may want to save the state of the function periodically so that, if the `@step` crashes, execution resumes from the last saved state rather than from the beginning. This is particularly useful in combination with the `@retry` decorator.

The `@checkpoint` decorator provides this functionality by allowing users to save intermediate state during `@step` execution. For example, consider a user training a large deep learning model that takes several hours to train. The user may want to save the model weights periodically during training so that, if training crashes, it can resume from the last saved checkpoint rather than starting from scratch.
 
Another common use case is processing large datasets in batches. Users may want to checkpoint after each batch so that, if a failure occurs, they can resume from the last successfully processed batch rather than reprocessing everything from the beginning.

The `@checkpoint` decorator is particularly powerful when combined with `@retry`: if a step fails, the retried attempt starts with the last checkpoint loaded, so the user code can continue from there rather than from the very beginning of the step. This can significantly reduce recovery time and resource usage for long-running computations.
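The batch pattern can be sketched without Metaflow at all. The plain-Python example below is a hypothetical stand-in, not the `@checkpoint` implementation: a local JSON file plays the role of the checkpoint, a retry loop plays the role of `@retry`, and `process_batches`, `run_with_retries`, and `FlakyError` are illustrative names only.

```python
import json
import os
import tempfile


class FlakyError(RuntimeError):
    """Hypothetical transient failure injected partway through processing."""


def process_batches(batches, state_path, processed, fail_at=None):
    """Sum all batches, persisting progress to `state_path` after each batch.

    If a state file already exists, processing resumes from the saved index.
    """
    start, total = 0, 0
    if os.path.exists(state_path):
        with open(state_path) as f:
            state = json.load(f)
        start, total = state["next_index"], state["total"]

    for i in range(start, len(batches)):
        if i == fail_at:
            raise FlakyError(f"simulated failure on batch {i}")
        total += sum(batches[i])
        processed.append(i)
        # Checkpoint only after the batch is fully processed.
        with open(state_path, "w") as f:
            json.dump({"next_index": i + 1, "total": total}, f)
    return total


def run_with_retries(batches, state_path, times=2, fail_at=3):
    """Stand-in for @retry: re-run the work; only the first attempt fails."""
    processed = []
    for attempt in range(times + 1):
        try:
            inject = fail_at if attempt == 0 else None
            return process_batches(batches, state_path, processed, inject), processed
        except FlakyError:
            if attempt == times:
                raise


with tempfile.TemporaryDirectory() as tmp:
    batches = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
    total, processed = run_with_retries(batches, os.path.join(tmp, "state.json"))

print(total, processed)  # 55 [0, 1, 2, 3, 4]
```

Because progress is written only after a batch finishes, a crash mid-batch costs at most that one batch: here the second attempt resumes at batch 3, and each batch is processed exactly once.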

## Functional Overview

The `@checkpoint` decorator provides a simple way to save and load checkpoints within your Metaflow steps. It's particularly useful for machine learning workflows where you need to save model states, handle failures, and resume training. This can also be useful for any other form of long-running computation where users want to save the state of the computation periodically so that it can be resumed later. 

The `@checkpoint` decorator injects a `checkpoint` object into Metaflow's `current` object. This `checkpoint` object provides a `save` method that can _save any state present on disk_. In case the `@step` crashes and retries, the previously saved checkpoint is loaded into a directory exposed via the `current.checkpoint.directory` property. All checkpoints are scoped to the `@step` in which they are created and are saved under the execution's [namespace](https://docs.metaflow.org/scaling/tagging#namespaces). This means that whatever checkpoint is reloaded in subsequent retries comes from the same `@step` and the same namespace.
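The save/restore contract and its scoping can be modeled with local directories. The sketch below assumes a hypothetical `LocalCheckpointStore` that copies a working directory into a store keyed by `(namespace, step)`; it only illustrates the scoping rules in the previous paragraph and is not how Metaflow's datastore is actually organized.

```python
import os
import shutil
import tempfile


class LocalCheckpointStore:
    """Hypothetical local stand-in for a checkpoint datastore.

    Saved copies are keyed by (namespace, step), so a step can only reload
    state that the same step saved under the same namespace.
    """

    def __init__(self, root):
        self.root = root
        self.latest = {}  # (namespace, step) -> path of the latest saved copy
        self._count = 0

    def save(self, namespace, step, directory):
        """Copy the contents of `directory` into the store and return its path."""
        self._count += 1
        dest = os.path.join(self.root, f"ckpt-{self._count}")
        shutil.copytree(directory, dest)
        self.latest[(namespace, step)] = dest
        return dest

    def load_into(self, namespace, step, directory):
        """Restore the latest copy for (namespace, step); return True if one existed."""
        src = self.latest.get((namespace, step))
        if src is None:
            return False
        shutil.copytree(src, directory, dirs_exist_ok=True)
        return True


with tempfile.TemporaryDirectory() as root:
    store = LocalCheckpointStore(os.path.join(root, "store"))

    # First attempt: write state to a working directory and save it.
    attempt0 = os.path.join(root, "attempt0")
    os.makedirs(attempt0)
    with open(os.path.join(attempt0, "counter.txt"), "w") as f:
        f.write("9")
    store.save("user:alice", "count_values", attempt0)

    # Retry of the same step in the same namespace: state is restored.
    attempt1 = os.path.join(root, "attempt1")
    os.makedirs(attempt1)
    same_step = store.load_into("user:alice", "count_values", attempt1)
    with open(os.path.join(attempt1, "counter.txt")) as f:
        restored = f.read()

    # A different step, or a different namespace, finds nothing to load.
    other = os.path.join(root, "other")
    os.makedirs(other)
    other_step = store.load_into("user:alice", "end", other)
    other_namespace = store.load_into("user:bob", "count_values", other)

print(same_step, restored, other_step, other_namespace)  # True 9 False False
```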


### Simple Example

Consider the example below. The `count_values` `@step` increments a `counter` ten times per attempt; at the end of the `@step`, if the counter has not reached a threshold value, the `@step` crashes. A `@retry` decorator is added to the `@step` so that it is retried when it crashes. The `@checkpoint` decorator is used to save the state of the counter during the iterations, and if the `@step` crashes, `current.checkpoint.directory` is populated with the last saved checkpoint on the next attempt. The step then reads the counter value from the file in this directory and continues counting from there. If a checkpoint is loaded for a `@step`, the `current.checkpoint.is_loaded` property is set to `True`, and information about the checkpoint is accessible via the `current.checkpoint.info` property.

Calling `current.checkpoint.save()` saves the contents of `current.checkpoint.directory` to the datastore and returns a reference to the checkpoint that was created. This reference can be loaded in later steps or within the same step upon retries. The `save` method also accepts the following optional arguments:
- `path`: A custom path to a directory/file that will be saved as a checkpoint
- `metadata`: A dictionary of metadata to be saved with the checkpoint
- `name`: A custom name for the checkpoint, to distinguish between different checkpoints created during a step
- `latest`: A boolean flag indicating whether this checkpoint should be marked as the latest checkpoint. The latest checkpoint is the one reloaded by default unless the user specifies different settings in the `@checkpoint` decorator.
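To make the interaction between `path`, `metadata`, `name`, and `latest` concrete, here is a toy in-memory model. `CheckpointRegistry` and `CheckpointRecord` are hypothetical, and the model assumes that `latest` defaults to `True` and that the default `name` is `"default"`; both assumptions should be checked against the Metaflow API reference. An omitted `path` falls back to the checkpoint directory, mirroring the description above.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class CheckpointRecord:
    """Hypothetical record describing one saved checkpoint."""
    name: str
    path: str
    metadata: Dict[str, object] = field(default_factory=dict)


class CheckpointRegistry:
    """Toy model of the documented `save` arguments; not Metaflow's implementation."""

    def __init__(self, default_dir: str):
        self.default_dir = default_dir  # plays the role of current.checkpoint.directory
        self.records: List[CheckpointRecord] = []
        self.latest: Optional[CheckpointRecord] = None

    def save(self, path=None, metadata=None, name="default", latest=True):
        # Assumption: `latest` defaults to True and `name` to "default" in this model.
        record = CheckpointRecord(
            name=name,
            path=path if path is not None else self.default_dir,
            metadata=dict(metadata or {}),
        )
        self.records.append(record)
        if latest:
            # Only checkpoints marked latest become the default reload target.
            self.latest = record
        return record


registry = CheckpointRegistry("/work/checkpoint_dir")
epoch_ckpt = registry.save(metadata={"epoch": 3}, name="epoch")
best_ckpt = registry.save(path="/work/best_model", metadata={"val_loss": 0.12},
                          name="best", latest=False)
print(registry.latest.name, best_ckpt.path)  # epoch /work/best_model
```

Saving the "best" model with `latest=False` keeps it retrievable by name without changing which checkpoint a retry would reload by default.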



In [3]:
%%writefile temp_files/checkpoint_basic.py
#meta:tag=hide_output
from metaflow import FlowSpec, current, step, retry, checkpoint
import os
import time 

MAX_RETRIES = 5

class CheckpointSimpleFlow(FlowSpec):

    continue_until = 15

    @step #meta_hide_line
    def start(self):#meta_hide_line
        self.next(self.count_values)#meta_hide_line

    @checkpoint
    @retry(times=4)
    @step
    def count_values(self):
        counter = 0
        _file_path = os.path.join(
            current.checkpoint.directory, 
            "counter.txt"
        )
        if current.checkpoint.is_loaded and os.path.exists(_file_path):
            # Load the file written by a previous execution of the step
            checkpoint_info = current.checkpoint.info
            file = _read_file(_file_path)
            print("Loaded a checkpoint from pathspec %s, attempt %s" % (checkpoint_info.pathspec, checkpoint_info.attempt))
            print(
                "reading the counter value from the file",
                file,
            )
            counter = int(file)

        per_retry_range = 10
        for i in range(counter, counter + per_retry_range):
            counter = i
            _write_file(_file_path, str(i))
            # `current.checkpoint.save` will save everything in the
            # `current.checkpoint.directory`
            # Saving implies the objects are saved in the datastore
            # Saving a checkpoint will return a reference to the checkpoint
            self.final_checkpoint = current.checkpoint.save()
            time.sleep(0.1)
            
        print("Current value of counter", counter)
        if self.continue_until > counter:
            raise ValueError("retry")

        self.next(self.end)

    @step #meta_hide_line
    def end(self):#meta_hide_line
        pass #meta_hide_line

def _read_file(path):#meta_hide_line
    with open(path, "r") as f:#meta_hide_line
        return f.read().strip()#meta_hide_line

def _write_file(path, contents):#meta_hide_line
    with open(path, "w") as f:#meta_hide_line
        f.write(contents)#meta_hide_line

if __name__ == "__main__": #meta_hide_line
    CheckpointSimpleFlow() #meta_hide_line


Overwriting temp_files/checkpoint_basic.py


In [4]:
#meta:tag=hide_input
#meta:show_steps=count_values
! python temp_files/checkpoint_basic.py run 

[35m[1mMetaflow 2.12.36.post9-git09d02cb-dirty+obcheckpoint(0.1.4);ob(v1)[0m[35m[22m executing [0m[31m[1mCheckpointSimpleFlow[0m[35m[22m[0m[35m[22m for [0m[31m[1muser:valay@outerbounds.co[0m[35m[22m[K[0m[35m[22m[0m


[35m[22mValidating your flow...[K[0m[35m[22m[0m
[32m[1m    The graph looks good![K[0m[32m[1m[0m
[35m[22mRunning pylint...[K[0m[35m[22m[0m


[32m[1m    Pylint is happy![K[0m[32m[1m[0m


[35m2024-12-11 06:09:18.422 [0m[1mWorkflow starting (run-id 7453):[0m


[35m2024-12-11 06:09:19.557 [0m[32m[7453/start/47439 (pid 2203169)] [0m[1mTask is starting.[0m


[35m2024-12-11 06:09:21.586 [0m[32m[7453/start/47439 (pid 2203169)] [0m[1mTask finished successfully.[0m


[35m2024-12-11 06:09:21.843 [0m[32m[7453/count_values/47440 (pid 2203250)] [0m[1mTask is starting.[0m


[35m2024-12-11 06:09:36.330 [0m[32m[7453/count_values/47440 (pid 2203250)] [0m[22mCurrent value of counter 9[0m


[35m2024-12-11 06:09:41.164 [0m[32m[7453/count_values/47440 (pid 2203250)] [0m[22m<flow CheckpointSimpleFlow step count_values> failed:[0m


[35m2024-12-11 06:09:43.680 [0m[32m[7453/count_values/47440 (pid 2203250)] [0m[22mInternal error[0m
[35m2024-12-11 06:09:43.681 [0m[32m[7453/count_values/47440 (pid 2203250)] [0m[22mTraceback (most recent call last):[0m
[35m2024-12-11 06:09:43.682 [0m[32m[7453/count_values/47440 (pid 2203250)] [0m[22mFile "/home/ubuntu/metaflow/metaflow/cli.py", line 1167, in main[0m
[35m2024-12-11 06:09:43.682 [0m[32m[7453/count_values/47440 (pid 2203250)] [0m[22mstart(auto_envvar_prefix="METAFLOW", obj=state)[0m
[35m2024-12-11 06:09:43.682 [0m[32m[7453/count_values/47440 (pid 2203250)] [0m[22mFile "/home/ubuntu/metaflow/metaflow/tracing/tracing_modules.py", line 111, in wrapper_func[0m
[35m2024-12-11 06:09:43.682 [0m[32m[7453/count_values/47440 (pid 2203250)] [0m[22mreturn func(args, kwargs)[0m
[35m2024-12-11 06:09:43.682 [0m[32m[7453/count_values/47440 (pid 2203250)] [0m[22m^^^^^^^^^^^^^^^^^^^^^[0m
[35m2024-12-11 06:09:43.682 [0m[32m[7453/count_val

[35m2024-12-11 06:09:43.868 [0m[32m[7453/count_values/47440 (pid 2203250)] [0m[22m^^^^^^^^^^^^^^^^^^^^^^^^^^[0m
[35m2024-12-11 06:09:43.868 [0m[32m[7453/count_values/47440 (pid 2203250)] [0m[22mFile "/home/ubuntu/metaflow/metaflow/_vendor/click/core.py", line 782, in main[0m
[35m2024-12-11 06:09:43.868 [0m[32m[7453/count_values/47440 (pid 2203250)] [0m[22mrv = self.invoke(ctx)[0m
[35m2024-12-11 06:09:43.868 [0m[32m[7453/count_values/47440 (pid 2203250)] [0m[22m^^^^^^^^^^^^^^^^[0m
[35m2024-12-11 06:09:43.868 [0m[32m[7453/count_values/47440 (pid 2203250)] [0m[22mFile "/home/ubuntu/metaflow/metaflow/_vendor/click/core.py", line 1259, in invoke[0m
[35m2024-12-11 06:09:43.868 [0m[32m[7453/count_values/47440 (pid 2203250)] [0m[22mreturn _process_result(sub_ctx.command.invoke(sub_ctx))[0m
[35m2024-12-11 06:09:43.868 [0m[32m[7453/count_values/47440 (pid 2203250)] [0m[22m^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^[0m
[35m2024-12-11 06:09:43.868 [0m[32m[74

[35m2024-12-11 06:09:43.975 [0m[32m[7453/count_values/47440 (pid 2203250)] [0m[1mTask failed.[0m


[35m2024-12-11 06:09:44.102 [0m[32m[7453/count_values/47440 (pid 2204262)] [0m[1mTask is starting (retry).[0m


[35m2024-12-11 06:09:45.287 [0m[32m[7453/count_values/47440 (pid 2204262)] [0m[22m[@checkpoint] Loading the following checkpoint:[0m


[35m2024-12-11 06:09:46.579 [0m[32m[7453/count_values/47440 (pid 2204262)] [0m[22m[pathspec] CheckpointSimpleFlow/7453/count_values/47440[0m
[35m2024-12-11 06:09:46.579 [0m[32m[7453/count_values/47440 (pid 2204262)] [0m[22m[key] mf.checkpoints/checkpoints/artifacts/CheckpointSimpleFlow/count_values/26ec4b03ee0e/6ed314b634c7/1e2df857.0.mfchckpt.9[0m
[35m2024-12-11 06:09:46.579 [0m[32m[7453/count_values/47440 (pid 2204262)] [0m[22m[created on] 2024-12-11T06:09:36.055073[0m
[35m2024-12-11 06:09:46.579 [0m[32m[7453/count_values/47440 (pid 2204262)] [0m[22m[url] s3://obp-475b0e-metaflow/metaflow/mf.checkpoints/checkpoints/artifacts/CheckpointSimpleFlow/count_values/26ec4b03ee0e/6ed314b634c7/1e2df857.0.mfchckpt.9[0m
[35m2024-12-11 06:09:46.579 [0m[32m[7453/count_values/47440 (pid 2204262)] [0m[22mLoaded a checkpoint from pathspec CheckpointSimpleFlow/7453/count_values/47440, attempt 0[0m


[35m2024-12-11 06:09:59.759 [0m[32m[7453/count_values/47440 (pid 2204262)] [0m[22mreading the counter value from the file 9[0m
[35m2024-12-11 06:09:59.759 [0m[32m[7453/count_values/47440 (pid 2204262)] [0m[22mCurrent value of counter 18[0m


[35m2024-12-11 06:10:09.146 [0m[32m[7453/count_values/47440 (pid 2204262)] [0m[1mTask finished successfully.[0m


[35m2024-12-11 06:10:09.503 [0m[32m[7453/end/47448 (pid 2205202)] [0m[1mTask is starting.[0m


[35m2024-12-11 06:10:12.237 [0m[32m[7453/end/47448 (pid 2205202)] [0m[1mTask finished successfully.[0m


[35m2024-12-11 06:10:12.420 [0m[1mDone![0m
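The two counter values in the log (9 on the first attempt, 18 after the retry) follow directly from the loop bounds. The hypothetical `simulate_counter_flow` below replays that arithmetic in plain Python, with a variable standing in for the checkpointed `counter.txt` and an outer loop standing in for `@retry(times=4)`; it does not use Metaflow.

```python
def simulate_counter_flow(continue_until=15, per_retry_range=10, retries=4):
    """Replay the arithmetic of CheckpointSimpleFlow.count_values.

    `saved` stands in for the last counter value written to counter.txt and
    checkpointed; it is the only state that survives from one attempt to the next.
    """
    saved = None  # the first attempt starts with nothing loaded
    history = []  # (attempt, final counter value, succeeded)
    for attempt in range(retries + 1):
        counter = int(saved) if saved is not None else 0
        for i in range(counter, counter + per_retry_range):
            counter = i
            saved = str(i)  # checkpoint after every iteration
        succeeded = counter >= continue_until
        history.append((attempt, counter, succeeded))
        if succeeded:
            break
    return history


print(simulate_counter_flow())  # [(0, 9, False), (1, 18, True)]
```

Because the retried attempt starts its range at the restored value, 9 is written on both attempts; the threshold of 15 is reached on the first retry, well within the four allowed.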


### Load Policies and Accessing Past Checkpoints

The `@checkpoint` decorator provides a `load_policy` argument which alters the checkpoint loading behavior in Metaflow. The `load_policy` argument can take the following values:
- `fresh` (default): Only checkpoints created within the current task are loaded. As a result, no checkpoint is loaded on the very first attempt of a task; on subsequent retries, the latest checkpoint created by the task is loaded.
- `eager`: The latest available checkpoint associated with the step is loaded, even if it was created in a previous execution.
- `none`: No checkpoint is ever loaded automatically. It's left to the user to choose and load a checkpoint explicitly within the user code, by calling the `current.checkpoint.list` method to select a checkpoint and then the `current.checkpoint.load` method to load it.
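The three policies differ only in which previously saved checkpoints are eligible for automatic loading. The function below is a hypothetical model of that selection rule, not Metaflow's code; `SavedCheckpoint`, `checkpoint_to_load`, and the integer `created` ordering are assumptions for illustration.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class SavedCheckpoint:
    run_id: str
    task_id: str  # retries of a task share the same task id
    created: int  # hypothetical creation order; larger means newer


def checkpoint_to_load(policy: str, current_task: str,
                       saved: List[SavedCheckpoint]) -> Optional[SavedCheckpoint]:
    """Model which checkpoint each documented load_policy would reload.

    `saved` is every checkpoint recorded for this step in this namespace.
    """
    if policy == "none":
        return None  # user code picks one explicitly via list/load
    if policy == "fresh":
        candidates = [c for c in saved if c.task_id == current_task]
    elif policy == "eager":
        candidates = list(saved)
    else:
        raise ValueError(f"unknown load_policy: {policy!r}")
    return max(candidates, key=lambda c: c.created, default=None)


# A checkpoint left behind by task t1 of an earlier run.
saved = [SavedCheckpoint(run_id="100", task_id="t1", created=1)]

# First attempt of task t2 in a new run: nothing has been saved by t2 yet.
first_fresh = checkpoint_to_load("fresh", "t2", saved)  # None
first_eager = checkpoint_to_load("eager", "t2", saved)  # run 100's checkpoint

# t2 saves a checkpoint, crashes, and retries.
saved.append(SavedCheckpoint(run_id="101", task_id="t2", created=2))
retry_fresh = checkpoint_to_load("fresh", "t2", saved)  # t2's own checkpoint
```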

The `current.checkpoint.list` method returns a list of checkpoints associated with the current step. All listed checkpoints are scoped to the current step and the current task's namespace. The `list` method accepts several optional arguments, such as:
- `name`: A string to filter checkpoints by name
- `task`: A Metaflow pathspec string or [Metaflow Task object](https://docs.metaflow.org/api/client#task) to filter checkpoints by task
- `attempt`: An integer to filter checkpoints by attempt number
- `within_task`: A boolean flag indicating whether to list only checkpoints from the currently running task or checkpoints from all previous tasks of the step
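The filters can be thought of as a simple conjunction, as in the sketch below. `list_checkpoints` and `CheckpointRef` are hypothetical, the `TrainFlow` pathspecs are made up, and the `within_task=True` default is an assumption rather than the documented default. In a real flow with `load_policy="none"`, the selected checkpoint would then be loaded with `current.checkpoint.load`.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class CheckpointRef:
    name: str
    task: str  # task pathspec, e.g. "TrainFlow/12/train/345" (made up)
    attempt: int


def list_checkpoints(refs: List[CheckpointRef], current_task: str,
                     name: Optional[str] = None, task: Optional[str] = None,
                     attempt: Optional[int] = None,
                     within_task: bool = True) -> List[CheckpointRef]:
    """Apply the documented filters to checkpoints already scoped to this step."""
    result = []
    for ref in refs:
        if within_task and ref.task != current_task:
            continue
        if name is not None and ref.name != name:
            continue
        if task is not None and ref.task != task:
            continue
        if attempt is not None and ref.attempt != attempt:
            continue
        result.append(ref)
    return result


current_task = "TrainFlow/12/train/345"
refs = [
    CheckpointRef("epoch", current_task, 0),
    CheckpointRef("epoch", current_task, 1),
    CheckpointRef("best", current_task, 1),
    CheckpointRef("epoch", "TrainFlow/11/train/301", 0),  # from an earlier run
]

mine = list_checkpoints(refs, current_task)
epoch_retry = list_checkpoints(refs, current_task, name="epoch", attempt=1)
all_epochs = list_checkpoints(refs, current_task, name="epoch", within_task=False)
print(len(mine), len(epoch_retry), len(all_epochs))  # 3 1 3
```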


### Saving Checkpoints within Subprocesses
[TODO]

### Saving / Loading Checkpoints within a foreach
[TODO]

### Saving / Loading Checkpoints for Gang Scheduled `@parallel` steps
[TODO]