# A proposal for defining `jax.save` and `jax.load`
... or perhaps something like `jax.save_to_dir` and `jax.load_from_dir`

### Objectives for the new API:

The new API should:
- be simple
- be strict by default (no pickle), with pickle available as an explicit opt-in
    - built-in values such as numbers (e.g., int, float), strings and bytes can still be stored without pickle
- not be a pickle replacement; it should be largely orthogonal to what pickle does
    - NOTE: the effort to support some form of pickling for JAX arrays is also important
- provide the ability for partial reads
- allow reading and writing arrays in a distributed, sharded context
    - this is already supported with tensorstore (and Yash's low-level API)
- have an `async` version for both save and load
- have the ability to read in just the structure of the data
- have a flat storage of leaf data that the user can access themselves
- support remote reading and writing from/to cloud storage, GCS and S3 at a minimum

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import tempfile
from pathlib import Path
from subprocess import check_output
from pprint import pprint
from dataclasses import dataclass
from typing import Any
import pickle
import collections

os.environ["JAX_ENABLE_X64"] = "1"

import jax
from jax.experimental.array_serialization import save, load, load_pytree
from jax.experimental.array_serialization import async_save, async_load
from jax.experimental.array_serialization import nonblocking_load, nonblocking_save
from jax import numpy as jnp, random

In [3]:
# Example pytree: nested built-in containers (list, dict, tuple) with str and int leaves.
data = ["hello", {"world": ["!", (1, 2)]}]

In [4]:
# Keep the TemporaryDirectory handle alive: taking `.name` from a throwaway
# instance lets the object be finalized immediately, which deletes the directory
# and leaves only a bare path that nothing ever cleans up.
tmpdir_handle = tempfile.TemporaryDirectory()
tmpdir = tmpdir_handle.name
# NOTE(review): `save` is the synchronous variant; the name `fut` is kept only
# for compatibility with the rest of the notebook -- confirm what it returns.
fut = save(data, tmpdir)

In [5]:
# Show the two parts of the checkpoint: the JSON tree definition and the
# directory listing that holds the leaf data.
print("The `pytreedef.json` is a human-readable pytree stored format")
print((Path(tmpdir) / "pytreedef.json").read_text())
print("-" * 80)
print("The leaf data is organized in a flat directory under `leaf_data`")
# `-n` turns colorization off so no ANSI escape codes end up in the captured text.
print(check_output(["tree", "-n", tmpdir]).decode())

The `pytreedef.json` is a human-readable pytree stored format
{
  "__jax_tree_repr": {
    "node_type": "builtins.list",
    "node_data_ref": null,
    "children": [
      {
        "node_type": "leaf",
        "node_data_ref": null,
        "children": [],
        "leaf_id": "str -> e8f29fdd-4cd4-49a4-82b5-937537fb9973"
      },
      {
        "node_type": "builtins.dict",
        "node_data_ref": [
          "world"
        ],
        "children": [
          {
            "node_type": "builtins.list",
            "node_data_ref": null,
            "children": [
              {
                "node_type": "leaf",
                "node_data_ref": null,
                "children": [],
                "leaf_id": "str -> d76788ae-c640-4d96-8f18-72e6f92d6f20"
              },
              {
                "node_type": "builtins.tuple",
                "node_data_ref": null,
                "children": [
                  {
                    "node_type": "leaf",
                    "n

In [6]:
# Read only the tree structure. In the printed result every leaf shows up as a
# "<type> -> <id>" placeholder string rather than the stored value.
print("-" * 80)
print("PyTree Structure:")
pytree_structure = load_pytree(tmpdir)
pprint(pytree_structure)

--------------------------------------------------------------------------------
PyTree Structure:
['str -> e8f29fdd-4cd4-49a4-82b5-937537fb9973',
 {'world': ['str -> d76788ae-c640-4d96-8f18-72e6f92d6f20',
            ('int -> 4dceae7b-649a-4225-a16a-8ae116afe1ab',
             'int -> cb1408ff-c2b6-46e8-bac4-a092044647ca')]}]


In [7]:
# Partial read: blank out every string leaf so that only the integers are loaded.
def _keep_unless_str(leaf_id):
  """Return `leaf_id` unchanged, or None when it refers to a `str` leaf."""
  if leaf_id.startswith("str"):
    return None
  return leaf_id

print("-" * 80)
print("Partial read of data:")
pytree_structure = jax.tree.map(_keep_unless_str, pytree_structure)
new_data = load(tmpdir, pytree=pytree_structure)
pprint(new_data)

--------------------------------------------------------------------------------
Partial read of data:
[None, {'world': [None, (1, 2)]}]


## Allowing custom nodes with `pickle`

In [8]:
#@functools.partial(jax.tree_util.register_dataclass, data_fields=["a", "c"], 
#                   meta_fields=["op"])
@jax.tree_util.register_pytree_node_class
@dataclass
class D:
  """Custom pytree node: `a` and `c` are children, `op` is carried as aux data."""
  op: str
  a: Any
  c: int

  def tree_flatten(self):
    """Split the instance into a `(children, aux_data)` pair."""
    children = (self.a, self.c)
    aux_data = self.op
    return (children, aux_data)

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    """Rebuild an instance from the aux data and children made by `tree_flatten`."""
    a, c = children
    return cls(op=aux_data, a=a, c=c)
  
# Keep the TemporaryDirectory handle alive: taking `.name` from a throwaway
# instance lets the object be finalized immediately, which deletes the directory
# and leaves only a bare path that nothing ever cleans up.
pickle_tmpdir_handle = tempfile.TemporaryDirectory()
tmpdir = pickle_tmpdir_handle.name

# Build the tree once and reuse it for both save attempts.
custom_tree = {"dataclass": D("tanh", random.normal(random.key(0), (7,)), 5),
               "a": 1}

# Without a pickle module, saving a tree that contains the custom node is
# expected to be rejected with a ValueError.
try:
  save(custom_tree, tmpdir)
except ValueError:
  print("Correctly refusing to serialize custom objects")

# Explicit opt-in to pickle.
save(custom_tree, tmpdir, pickle_module=pickle)
print("Serialized with pickle")

print("Extra folder: `node_data`")
# `-n` turns colorization off so no ANSI escape codes end up in the captured text.
print(check_output(["tree", "-n", "-L", "2", tmpdir]).decode())
print("-" * 80)
print((Path(tmpdir) / "pytreedef.json").read_text())

Correctly refusing to serialize custom objects
Serialized with pickle
Extra folder: `node_data`
[01;34m/tmp/tmph7le3ocf[0m
├── [01;34mleaf_data[0m
│   ├── [01;34marray_store.tensorstore[0m
│   └── [01;31mobj_data.zip[0m
├── [01;31mnode_data.zip[0m
└── [00mpytreedef.json[0m

3 directories, 3 files

--------------------------------------------------------------------------------
{
  "__jax_tree_repr": {
    "node_type": "builtins.dict",
    "node_data_ref": [
      "a",
      "dataclass"
    ],
    "children": [
      {
        "node_type": "leaf",
        "node_data_ref": null,
        "children": [],
        "leaf_id": "int -> 9347e37c-b722-417a-9672-0d8661315cb1"
      },
      {
        "node_type": "__main__.D",
        "node_data_ref": "b8338c1e-0b51-4562-9ca7-3b5b54d2131f",
        "children": [
          {
            "node_type": "leaf",
            "node_data_ref": null,
            "children": [],
            "leaf_id": "Array[[7], float64] -> 9ff32f46-6620-42a6-89

In [9]:
# Reading a checkpoint that holds a pickled custom node has to be opted into:
# the plain call is expected to raise ValueError.
try:
  load(tmpdir)
except ValueError:
  print("Correctly refuses to read without pickle")

print("Reads correctly with pickle")
restored_with_pickle = load(tmpdir, pickle_module=pickle)
print(restored_with_pickle)

Correctly refuses to read without pickle
Reads correctly with pickle
{'a': 1, 'dataclass': D(op='tanh', a=Array([ 0.33864229, -0.59818536,  2.29231856,  0.27937332,  1.88502002,
       -2.09015473,  0.0490814 ], dtype=float64), c=5)}


## Best-effort reading when pickled objects are no longer available

We can also attempt to load with `best_effort=True` without pickle or if the class definition / custom node registration has been lost.

This will print a warning, read the children of the former custom node, and organize them in a list.

No node data (e.g., static fields) is read.

In [10]:
# Best-effort load: no `pickle_module` is passed. In the printed result the
# custom node comes back as a plain list of its children.
print(load(tmpdir, best_effort=True))



{'a': 1, 'dataclass': [Array([ 0.33864229, -0.59818536,  2.29231856,  0.27937332,  1.88502002,
       -2.09015473,  0.0490814 ], dtype=float64), 5]}


In [11]:
# Structure-only read of the same checkpoint in best-effort mode.
load_pytree(tmpdir, best_effort=True)



{'a': 'int -> 9347e37c-b722-417a-9672-0d8661315cb1',
 'dataclass': ['Array[[7], float64] -> 9ff32f46-6620-42a6-8934-5bfde6a36a43',
  'int -> f117c458-63ec-4457-97d8-8b4bc3f12f97']}

Careful! PyTree utilities do not preserve `dict` key order (in the example below the keys come back sorted).

In [15]:
# Round-trip a plain dict through flatten/unflatten and compare the key order.
d = {"c": 1, "a": 2}
leaves, treedef = jax.tree.flatten(d)
d2 = jax.tree.unflatten(treedef, leaves)
print(d)
print(d2)

{'c': 1, 'a': 2}
{'a': 2, 'c': 1}


Using `OrderedDict` is necessary to preserve key order

In [21]:
from collections import OrderedDict as odict
# Same round-trip with an OrderedDict, to compare the key order again.
d = odict({"c": 1, "a": 2})
leaves, treedef = jax.tree.flatten(d)
d2 = jax.tree.unflatten(treedef, leaves)
print(d)
print(d2)

OrderedDict([('c', 1), ('a', 2)])
OrderedDict([('c', 1), ('a', 2)])


In [22]:
# Start a non-blocking save into the relative path "test_checkpoint". In the
# printed `fut.pytree` the array leaf is shown as a ShapeDtypeStruct.
fut = nonblocking_save(odict({"c": jnp.ones(100), "a": 2}), "test_checkpoint")
print(fut.pytree)

OrderedDict([('c', ShapeDtypeStruct(shape=(100,), dtype=float64)), ('a', 2)])


In [23]:
# NOTE(review): presumably reports whether the background save has finished -- confirm.
fut.done()

True

In [24]:
# NOTE(review): presumably waits for the save and re-raises any error from it -- confirm.
fut.result()

In [25]:
# Read back the checkpoint written by the non-blocking save above.
load("test_checkpoint")

OrderedDict([('c',
              Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                     1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                     1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                     1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                     1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                     1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],      dtype=float64)),
             ('a', 2)])

### Extended safe-modules

We probably want to be able to support non-JAX, but very standard collections
like e.g., flax's `FrozenDict` (and I can't think of anything else).

To do this programmatically, we can add **string** entries to
`new_api._EXTENDED_NODE_TYPES_MAP`, which are then imported on the fly via `importlib`.

The alternative is to allow this on-the-fly `importlib` import for all node types
that:
1. are not members of the `__main__` module
2. do not contain any non-JSON-serializable `node_data()`

But this would mean calling `importlib.import_module` on a data string, so it's
pretty unsafe.

In [26]:
# Round-trip an OrderedDict through a temporary checkpoint directory.
ordered_tree = collections.OrderedDict(a=1, b=jnp.ones(10))
with tempfile.TemporaryDirectory() as tmpdir:
  save(ordered_tree, tmpdir)
  restored_tree = load(tmpdir)

In [27]:
# Only the optional flax import is guarded: an ImportError raised inside
# `save`/`load` should surface instead of being silently swallowed.
try:
  from flax.core.frozen_dict import FrozenDict
except ImportError:
  print("flax is not installed; skipping the FrozenDict round-trip")
else:
  # FrozenDict is added to the list at the moment
  with tempfile.TemporaryDirectory() as tmpdir:
    save(FrozenDict(a=1, b=jnp.ones(10)), tmpdir)
    restored_tree = load(tmpdir)

### Notes

1. Must the resulting checkpoint be a directory? Can it not be a file?

> The underlying checkpoint is a directory, tensorstore doesn't really support
> writing single-file checkpoints that are well read-optimized.

> It's possible to zip the result, piece-by-piece without wasting disk space,
> which is probably a direction to explore. NOTE: tensorstore seems to support
> **reading** from Python zipfile handles directly.

2. How fast is saving the checkpoint given I/O can be slow?

> Thanks to the underlying async usage, it should be pretty fast.

3. Is RAM usage controlled?

> Not at the moment, but it's possible to improve this. We can use tensorstore
> to limit array writing memory usage and we can rewrite non-array writing to be
> non-buffered through a bytes or text object (they are buffered to more cleanly
> support file://, gcs://, s3:// alternatives).

4. Why is the `pytreedef.json` weird like that?

> The "cleanest" way to save a pytree structure is to just use a JSON
> representation with leaves replaced by their data reference ids. However,
> JSON doesn't distinguish between tuples and lists, so it doesn't really preserve
> the actual pytree, even if it's limited to only built-in types. Also, when the
> pytree contains custom nodes, we need a custom tree representation anyway.

5. Isn't overwriting a **directory** checkpoint extremely dangerous if the 
"checkpoint" path turns out to be e.g. "/usr/local"?

> Yes, but we first check for files and directories we didn't create and refuse
> to overwrite if there are any.

6. Why doesn't the Python LSP work with the synchronous versions, `save` and `load`?

> I don't know yet; this needs to be fixed.

7. Restored pytrees have dictionary node keys in a different order. Why?

> This is pytree behavior: dictionary key order is not preserved.
> [https://github.com/google/jax/issues/4085](https://github.com/google/jax/issues/4085)

8. Which host (process_id) writes the save directory and what part of it?

> Currently, if a remote path is detected, only the process with `jax.process_index() == 0`
> writes non-arrays. All processes write arrays, as that is what tensorstore expects.

9. TODO

> - passing sharding not tested

# Experiments

In [28]:
import time
import os
import shutil

from orbax import checkpoint as ocp
import numpy as np
import tensorstore as ts

def orbax_save(state, path, ocdbt_target_file_size: int = 2 * 1024 ** 3):
  """Save `state` to `path` with an orbax `PyTreeCheckpointer` and print the elapsed time.

  Args:
    state: the item to checkpoint.
    path: destination handed to the checkpointer's `save`.
    ocdbt_target_file_size: forwarded as `ocdbt_target_data_file_size`.
  """
  t0 = time.time()
  checkpointer = ocp.PyTreeCheckpointer(use_ocdbt=True, use_zarr3=True)
  save_args = ocp.args.PyTreeSave(
    item=state, ocdbt_target_data_file_size=ocdbt_target_file_size)
  checkpointer.save(path, save_args)
  elapsed = time.time() - t0
  print(f"Saved checkpoint to {path} in {elapsed:.2f} sec")

def orbax_load(path, shape_dtype):
  """Restore a checkpoint from `path` with orbax, print the elapsed time, return the state.

  Args:
    path: checkpoint location handed to the checkpointer's `restore`.
    shape_dtype: passed both as the restore item and to `construct_restore_args`.
      NOTE(review): presumably a pytree of shape/dtype structs -- confirm.

  Returns:
    Whatever the checkpointer's `restore` returns.
  """
  t0 = time.time()
  checkpointer = ocp.PyTreeCheckpointer(use_ocdbt=True, use_zarr3=True)
  restore_args = ocp.checkpoint_utils.construct_restore_args(shape_dtype)
  state = checkpointer.restore(
    path, ocp.args.PyTreeRestore(shape_dtype, restore_args=restore_args))
  elapsed = time.time() - t0
  print(f"Loaded checkpoint from {path} in {elapsed:.2f} sec")
  return state

In [32]:
# Seed a local generator so that re-running the notebook rebuilds the same
# benchmark state (array contents and the ragged lengths under "n").
rng = np.random.default_rng(0)
state = {"k": jnp.array([1, 2, 3]), "c": rng.standard_normal((100, 100)),
         "d": np.array(2 ** 48),  # a value that does not fit in 32 bits
         # 1000 one-dimensional arrays with lengths drawn from [0, 100)
         "n": [rng.standard_normal(rng.integers(0, 100)) for _ in range(1000)],
         "obj": "hello"
        }

In [33]:
# Checkpoint destination, resolved to an absolute path.
path = Path("test_chkpt2").absolute()
# Optional removal of a previous run's checkpoint, left disabled:
#if path.exists():
#  shutil.rmtree(path)
save(state, path)

In [43]:
# Total number of files anywhere under the checkpoint directory.
sum(len(filenames) for _root, _dirs, filenames in os.walk(path))

89

In [35]:
# Start a non-blocking load of the checkpoint written above.
fut = nonblocking_load(path)

In [46]:
# In the displayed result, array leaves appear as ShapeDtypeStruct placeholders.
fut.pytree

{'c': ShapeDtypeStruct(shape=(100, 100), dtype=float64),
 'd': ShapeDtypeStruct(shape=(), dtype=int64),
 'k': ShapeDtypeStruct(shape=(3,), dtype=int64),
 'n': [ShapeDtypeStruct(shape=(58,), dtype=float64),
  ShapeDtypeStruct(shape=(18,), dtype=float64),
  ShapeDtypeStruct(shape=(99,), dtype=float64),
  ShapeDtypeStruct(shape=(17,), dtype=float64),
  ShapeDtypeStruct(shape=(10,), dtype=float64),
  ShapeDtypeStruct(shape=(24,), dtype=float64),
  ShapeDtypeStruct(shape=(12,), dtype=float64),
  ShapeDtypeStruct(shape=(94,), dtype=float64),
  ShapeDtypeStruct(shape=(19,), dtype=float64),
  ShapeDtypeStruct(shape=(17,), dtype=float64),
  ShapeDtypeStruct(shape=(69,), dtype=float64),
  ShapeDtypeStruct(shape=(48,), dtype=float64),
  ShapeDtypeStruct(shape=(12,), dtype=float64),
  ShapeDtypeStruct(shape=(81,), dtype=float64),
  ShapeDtypeStruct(shape=(61,), dtype=float64),
  ShapeDtypeStruct(shape=(8,), dtype=float64),
  ShapeDtypeStruct(shape=(24,), dtype=float64),
  ShapeDtypeStruct(shape=(2

In [37]:
# Synchronous load of the full checkpoint.
restored = load(path)

In [38]:
# Inspect the restored 2 ** 48 scalar (displayed below as an int64 Array).
restored["d"]

Array(281474976710656, dtype=int64)

In [40]:
# Check that every array leaf survived the round trip. A plain loop replaces
# the side-effect-only list comprehension, and the explicit length check stands
# in for `jax.util.safe_zip`, which is not a documented public API.
state_leaves = jax.tree.leaves(state)
restored_leaves = jax.tree.leaves(restored)
assert len(state_leaves) == len(restored_leaves), (
  f"leaf count mismatch: {len(state_leaves)} != {len(restored_leaves)}")
for x, y in zip(state_leaves, restored_leaves):
  # Non-array leaves (e.g. the "obj" string) are skipped, as before.
  if isinstance(x, (np.ndarray, jax.Array)):
    np.testing.assert_allclose(x, y)