# A proposal for defining `jax.tree.save` and `jax.tree.load`
... or perhaps something like `jax.tree.save_to_dir` and `jax.tree.load_from_dir`

### Objectives for the new API:

The new API should:
- be simple
- be strict by default (no pickle), with pickle available only as an explicit opt-in
    - still allow storing built-in values such as numbers (e.g., int, float), strings and bytes
- not be a pickle replacement; it should be largely orthogonal to what pickle does
- support partial reads and writes
- work in a distributed, sharded context
    - already supported with tensorstore (and Yash's low-level API)
- offer nonblocking versions of both save and load
- allow reading just the structure of the data
- support remote reading and writing from/to cloud storage, GCS and S3 at a minimum
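
To make the "strict by default" objective concrete, here is a minimal pure-Python sketch of how a leaf could be accepted or rejected. The names `SAFE_LEAF_TYPES` and `encode_leaf` and the `allow_pickle` flag are hypothetical and only illustrate the idea; they are not part of the proposed API.

```python
import pickle

# Hypothetical allow-list of built-in leaf types that can be stored without
# pickle (taken from the objectives above, plus None).
SAFE_LEAF_TYPES = (int, float, str, bytes, type(None))


def encode_leaf(leaf, allow_pickle=False):
  """Returns (format_tag, payload); rejects unsafe leaves unless opted in."""
  if isinstance(leaf, SAFE_LEAF_TYPES):
    return ("builtin", leaf)
  if allow_pickle:
    # Explicit opt-in only: pickle can execute arbitrary code when loaded.
    return ("pickle", pickle.dumps(leaf))
  raise ValueError(f"Refusing to serialize leaf of type {type(leaf).__name__}")
```

With the default settings, `encode_leaf({1, 2})` raises `ValueError`, while `encode_leaf({1, 2}, allow_pickle=True)` falls back to pickle.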

In [1]:
%load_ext autoreload
%autoreload 2

import os
import tempfile
from pathlib import Path
from subprocess import check_output
from pprint import pprint
import functools
import json
import time
from dataclasses import dataclass
from typing import Any
import pickle
import collections
from types import SimpleNamespace

os.environ["JAX_ENABLE_X64"] = "1"

import jax
from jax.experimental.array_serialization.pytree_serialization import (
  save, load, load_pytreedef, nonblocking_save, nonblocking_load)
from jax.experimental.array_serialization.pytree_serialization_utils import (
  register_pytree_node_serialization, 
  #register_pytree_leaf_serialization
)
from jax.experimental.array_serialization import pytree_serialization_utils as utils
from jax import numpy as jnp, random
from jax import tree
import numpy as np

tree.save = save
tree.load = load
tree.load_pytreedef = load_pytreedef
tree.nonblocking_load = nonblocking_load
tree.nonblocking_save = nonblocking_save

## Register a custom node

In [None]:
def json_dumps(aux_data):
  print(f"hello from json_dumps: {aux_data}")
  return json.dumps(aux_data).encode("utf-8")

def json_loads(aux_data):
  return json.loads(aux_data)

@functools.partial(register_pytree_node_serialization, 
                   serialized_name="class_D",
                   serialize_auxdata=json_dumps, 
                   deserialize_auxdata=json_loads)
@functools.partial(jax.tree_util.register_dataclass, data_fields=["a", "b"], 
                   meta_fields=[])
@dataclass
class D:
  a: int
  b: int

In [None]:
def loads(x):
  return json.loads(x)

dumps = lambda x: json.dumps(x).encode("utf-8")
print(f"loads hash = {hash(loads)}")

In [None]:
tempdir = tempfile.TemporaryDirectory().name
save(bytearray(b"hello"), tempdir)
out = load(tempdir)

In [None]:
tempdir = tempfile.TemporaryDirectory().name
print(tempdir)
save(D(1, 2), tempdir)
load_pytreedef(tempdir)
load(tempdir)

## Simple use-case

In [2]:
jax.typeof(jnp.ones(10))

ShapedArray(float64[10])

In [3]:
#data = ["hello", {"world": ["!", (1, 2)]}, None, jnp.ones(5)]
data = [jnp.array(1), {"world": [jnp.ones((2, 1024, 1024)), (jnp.zeros(3), jnp.ones(4))]}, jnp.ones(5)]

tempdir = tempfile.TemporaryDirectory().name
print(tempdir)
%time fut = tree.save(data, tempdir)
%time restored_data = tree.load(tempdir)
print(restored_data)

jax.tree.map(lambda x, y: jnp.all(x == y), data, restored_data)

/tmp/tmpd_q_p7_z
CPU times: user 42.8 ms, sys: 40.4 ms, total: 83.1 ms
Wall time: 94.3 ms
CPU times: user 13.2 ms, sys: 12.3 ms, total: 25.5 ms
Wall time: 19.6 ms
[Array(1, dtype=int64), {'world': [Array([[[1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        ...,
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.]],

       [[1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        ...,
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.]]], dtype=float64), (Array([0., 0., 0.], dtype=float64), Array([1., 1., 1., 1.], dtype=float64))]}, Array([1., 1., 1., 1., 1.], dtype=float64)]


[Array(True, dtype=bool),
 {'world': [Array(True, dtype=bool),
   (Array(True, dtype=bool), Array(True, dtype=bool))]},
 Array(True, dtype=bool)]

In [4]:
tree.load_pytreedef(tempdir)

[ShapeDtypeStruct(shape=(), dtype=int64),
 {'world': [ShapeDtypeStruct(shape=(2, 1024, 1024), dtype=float64),
   (ShapeDtypeStruct(shape=(3,), dtype=float64),
    ShapeDtypeStruct(shape=(4,), dtype=float64))]},
 ShapeDtypeStruct(shape=(5,), dtype=float64)]

In [5]:
print("The `pytreedef.json` is a human-readable pytree stored format")
print((Path(tempdir) / "pytreedef.json").read_text())
print("-" * 80)
print("The leaf data is organized in a flat directory under `leaf_data`")
print(check_output(["tree", tempdir]).decode()) 

The `pytreedef.json` is a human-readable pytree stored format
{
  "__jax_pytreedef_repr": "BAAAAJD///8IAAAAAAAAAwMAAACQAAAAGAAAAAQAAACA////AAAKABAADwAIAAQACgAAAAwAAAAcAAAAAAAABAEAAAAEAAAABQAAAHdvcmxkAAAAAQAAAAQAAADo////CAAAAAAAAAMCAAAAMAAAAAwAAAAIAAwACwAEAAgAAAAIAAAAAAAAAgIAAAAMAAAABAAAAPT////4/////P///wQABAAEAAAA",
  "__jax_leaf_ids": [
    "Array(int64[]) -> 0",
    "Array(float64[2, 1024, 1024]) -> 1",
    "Array(float64[3]) -> 2",
    "Array(float64[4]) -> 3",
    "Array(float64[5]) -> 4"
  ]
}
--------------------------------------------------------------------------------
The leaf data is organized in a flat directory under `leaf_data`
/tmp/tmpd_q_p7_z
├── array_store.tensorstore
│   ├── d
│   │   ├── 0f97594e20094eaac7ce29d7b5c0b315
│   │   ├── bfe3326f332b44066f7c6d524af98a7c
│   │   └── ea4c038215fcfe6a5f22594f8c5366c9
│   └── manifest.ocdbt
└── pytreedef.json

3 directories, 5 files



In [11]:
# read only the data structure
print("-" * 80)
print("PyTree Structure:")
pytree_structure = tree.load_pytreedef(tempdir)
pprint(pytree_structure)

# read only integers back
print("-" * 80)
print("Partial read of data:")
#pytree_structure = jax.tree.map(lambda x: x if x.startswith("int") else None, 
#                                pytree_structure)
new_data = tree.load(tempdir, pytree=pytree_structure)
pprint(new_data)

--------------------------------------------------------------------------------
PyTree Structure:
[ShapeDtypeStruct(shape=(), dtype=int64),
 {'world': [ShapeDtypeStruct(shape=(2, 1024, 1024), dtype=float64),
            (ShapeDtypeStruct(shape=(3,), dtype=float64),
             ShapeDtypeStruct(shape=(4,), dtype=float64))]},
 ShapeDtypeStruct(shape=(5,), dtype=float64)]
--------------------------------------------------------------------------------
Partial read of data:
[Array(1, dtype=int64),
 {'world': [Array([[[1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        ...,
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.]],

       [[1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        ...,
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 

In [None]:
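# `condition_fn` is user-supplied: it maps each leaf entry of the pytreedef to
# itself (load this leaf) or to None (skip it).
pytree_mask = jax.tree.map(condition_fn, tree.load_pytreedef(tempdir))
new_data = tree.load(tempdir, pytree=pytree_mask)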

## Custom node registration

In [None]:
tempdir = tempfile.TemporaryDirectory().name
arr = jnp.ones(10)
fut = nonblocking_save(arr, tempdir)
while not fut.done():
  time.sleep(1e-3)
fut = nonblocking_load(tempdir)
print(fut.pytree)

In [None]:
tempdir = tempfile.TemporaryDirectory().name
@functools.partial(register_pytree_node_serialization,
                   serialized_name="CustomNode",
                   serialize_auxdata=json.dumps,
                   deserialize_auxdata=json.loads)
@functools.partial(jax.tree_util.register_dataclass, data_fields=["a", "b"], 
                   meta_fields=["op"])
@dataclass 
class CustomNode:
  a: Any
  b: Any
  op: str

data = ["hello", {"world": ["!", (1, b"data")]}, None, (CustomNode(1, 2, "hi"))]
tree.save(data, tempdir)
out = tree.load(tempdir)
print(out)

In [None]:
@functools.partial(register_pytree_node_serialization,
                   serialized_name="CustomLeaf",
                   serialize_auxdata=lambda p: json.dumps(p.a),
                   deserialize_auxdata=lambda x: CustomLeaf(json.loads(x)))
@jax.tree_util.register_static
@dataclass
class CustomLeaf:
  a: int = 2
  
@functools.partial(register_pytree_node_serialization,
                   serialized_name="CustomNode",
                   serialize_auxdata=json.dumps,
                   deserialize_auxdata=json.loads)
@functools.partial(jax.tree_util.register_dataclass, data_fields=["a", "b"], 
                   meta_fields=["op"])
@dataclass 
class CustomNode:
  a: Any
  b: Any
  op: str

def serialize_D(aux_data):
  return json.dumps(aux_data)

def deserialize_D(aux_data):
  return json.loads(aux_data)

data = ["hello", {"world": ["!", (1, 2)]}, None, (CustomLeaf(), CustomNode(1, 2, "hi"))]

jax.tree.flatten(data)

save(data, "/tmp/hello")
out = load("/tmp/hello")

## Custom Leaf

In [12]:
tempdir = tempfile.TemporaryDirectory().name
class AutonomousDrivingMap:
  def __init__(self, chunks: int):
    self.lines = [["blob" for _ in range(j)] for j in range(chunks)]

  @staticmethod
  def serialize(self):
    return json.dumps(self.lines)

  @staticmethod
  def deserialize(data):
    (a := AutonomousDrivingMap(0)).lines = json.loads(data)
    return a

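# NOTE: `pybind_mod` is a placeholder for a compiled extension module exposing
# the class; it is not defined in this notebook, hence the NameError below.
AutonomousDrivingMap = pybind_mod.cpp_class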

register_pytree_leaf_serialization(
  AutonomousDrivingMap, serialized_name="AutonomousDrivingMap",
  serialize_leaf=AutonomousDrivingMap.serialize,
  deserialize_leaf=AutonomousDrivingMap.deserialize)

data = ["hello", {"world": ["!", (1, b"data")]}, None, (AutonomousDrivingMap(5))]
tree.save(data, tempdir)
print(tree.load(tempdir))

NameError: name 'pybind_mod' is not defined

## Fallback serialization `pickle` and `json`

In [13]:
@functools.partial(jax.tree_util.register_dataclass, data_fields=["a", "c"], 
                   meta_fields=["op"])
@dataclass
class UnregisteredCustomNode:
  op: str
  a: Any
  c: int
  
tmpdir = tempfile.TemporaryDirectory().name
try:
  save({"dataclass": UnregisteredCustomNode(
    "tanh", random.normal(random.key(0), (7,)), 5), "a": 1}, tmpdir)
except ValueError:
  print("Correctly refusing to serialize custom objects")

Correctly refusing to serialize custom objects


## Incremental writing is supported

In [14]:
tempdir = tempfile.TemporaryDirectory().name
incremental_tree = [None, None, None]
save(incremental_tree, tempdir)
incremental_tree[0] = 1
save(incremental_tree, tempdir, partial_write=True)
ret = load(tempdir)
print(ret)
assert ret[0] == 1 and ret[1] is None and ret[2] is None
incremental_tree[0] = None
incremental_tree[1] = 4
save(incremental_tree, tempdir, partial_write=True)
ret = load(tempdir)
print(ret)
assert ret[0] == 1 and ret[1] == 4 and ret[2] is None
incremental_tree[0], incremental_tree[2] = None, jnp.ones(4)
save(incremental_tree, tempdir, partial_write=True)
ret = load(tempdir)
print(ret)
#assert (ret[0] == 1 and ret[1] is None and (np.testing.assert_allclose(ret[2], jnp.ones(4)) is None))

AttributeError: 'NoneType' object has no attribute 'dtype'

In [None]:
tempdir = tempfile.TemporaryDirectory().name
tree.save([None, None, None], tempdir)
print(tree.load(tempdir))
tree.save([1, None, None], tempdir, partial_write=True)
print(tree.load(tempdir))
tree.save([None, 4, None], tempdir, partial_write=True)
print(tree.load(tempdir))
tree.save([None, None, jnp.ones(10)], tempdir, partial_write=True)
print(tree.load(tempdir))
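
The asserts in the first incremental-writing cell suggest that, with `partial_write=True`, a `None` leaf means "leave the stored value alone" and any other leaf overwrites it. Below is a minimal pure-Python sketch of that merge rule, assuming both trees have the same structure. `merge_partial` is a hypothetical helper for illustration and says nothing about how `save` is actually implemented.

```python
def merge_partial(stored, update):
  """Returns `stored` with every non-None leaf of `update` written over it."""
  if isinstance(update, (list, tuple)):
    return type(update)(merge_partial(s, u) for s, u in zip(stored, update))
  if isinstance(update, dict):
    return {k: merge_partial(stored[k], u) for k, u in update.items()}
  # Leaf: None in the update keeps whatever was stored before.
  return stored if update is None else update


state = [None, None, None]
state = merge_partial(state, [1, None, None])
state = merge_partial(state, [None, 4, None])
print(state)  # [1, 4, None]
```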

Careful! PyTree utilities do not preserve `dict` key order: flattening visits the keys in sorted order.

In [None]:
d = {"c": 1, "a": 2}
d2 = jax.tree.unflatten(jax.tree.structure(d), jax.tree.flatten(d)[0])
print(d)
print(d2)

Using `OrderedDict` is necessary to keep the insertion order

In [None]:
from collections import OrderedDict as odict
d = odict({"c": 1, "a": 2})
d2 = jax.tree.unflatten(jax.tree.structure(d), jax.tree.flatten(d)[0])
print(d)
print(d2)

In [None]:
tempdir = tempfile.TemporaryDirectory().name
fut = nonblocking_save(odict({"c": jnp.ones(100), "a": 2, "none": None}), tempdir)
print(fut.pytree)

In [None]:
load(tempdir)

In [None]:
pytree_def = load_pytreedef(tempdir)
arr_tree = jax.tree.map(lambda x: x if x.startswith("Array") else None, pytree_def)
other_tree = jax.tree.map(lambda x: x if not x.startswith("Array") else None, pytree_def)

In [None]:
print(load(tempdir, pytree=arr_tree))
print("---")
print(load(tempdir, pytree=other_tree))
print("---")
tree_together = jax.tree.map(
  lambda x, y: x if x is not None else y,
  load(tempdir, pytree=arr_tree),
  load(tempdir, pytree=other_tree),
  is_leaf=lambda x: x is None)
print(tree_together)

In [None]:
ret = nonblocking_load(tempdir)
while not ret.done():
  time.sleep(1e-3)  # poll without spinning the CPU
ret = ret.result()
ret

### Extended safe-modules

We probably want to support very standard non-JAX collections, e.g., flax's
`FrozenDict` (I can't think of anything else).

To do this programmatically, we can add **string** entries to
`new_api._EXTENDED_NODE_TYPES_MAP`, which we then import with `importlib` on
the fly.

The alternative is to allow this on-the-fly `importlib` import for all node
types that:
1. are not defined in the `__main__` module
2. do not contain any non-JSON-serializable `node_data()`

But this would mean calling `importlib.import_module` on a string read from the
data, so it's pretty unsafe.
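
A minimal sketch of the first (allow-list) option, using only the standard library. The set name `EXTENDED_NODE_TYPES` and the helper `resolve_node_type` are hypothetical stand-ins for whatever `_EXTENDED_NODE_TYPES_MAP` ends up looking like. The point is that the entries are plain strings, so nothing is imported unless a checkpoint actually names an allow-listed type.

```python
import importlib

# Hypothetical allow-list: only these fully qualified names may be imported
# while loading a checkpoint.
EXTENDED_NODE_TYPES = {
  "collections.OrderedDict",
  "flax.core.frozen_dict.FrozenDict",
}


def resolve_node_type(qualified_name):
  """Imports and returns an allow-listed class; rejects every other name."""
  if qualified_name not in EXTENDED_NODE_TYPES:
    raise ValueError(f"{qualified_name!r} is not an allow-listed node type")
  module_name, _, class_name = qualified_name.rpartition(".")
  return getattr(importlib.import_module(module_name), class_name)
```

For example, `resolve_node_type("collections.OrderedDict")` returns the class, while `resolve_node_type("os.system")` raises `ValueError` before any import happens.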

In [None]:
with tempfile.TemporaryDirectory() as tmpdir:
  save(collections.OrderedDict(a=1, b=jnp.ones(10)), tmpdir)
  restored_tree = load(tmpdir)

In [None]:
try:
  from flax.core.frozen_dict import FrozenDict
  # FrozenDict is added to the list at the moment
  with tempfile.TemporaryDirectory() as tmpdir:
    save(FrozenDict(a=1, b=jnp.ones(10)), tmpdir)
    restored_tree = load(tmpdir)
except ImportError:
  pass

In [None]:
d = FrozenDict(dict(a=1))

### Notes

1. Must the resulting checkpoint be a directory? Can it not be a file?

> The underlying checkpoint is a directory; tensorstore doesn't really support
> writing single-file checkpoints that are well optimized for reading.

> It's possible to zip the result, piece-by-piece without wasting disk space,
> which is probably a direction to explore. NOTE: tensorstore seems to support
> **reading** from Python zipfile handles directly.
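
As a rough illustration of the "piece-by-piece" idea, the sketch below moves a checkpoint directory into a zip archive one file at a time, deleting each original right after it is archived, so the extra disk space needed is about one file. The function name is hypothetical, it assumes `zip_path` lies outside `checkpoint_dir`, and it makes no claim about what tensorstore can read back.

```python
import os
import zipfile


def zip_checkpoint_in_place(checkpoint_dir, zip_path):
  """Moves every file under checkpoint_dir into zip_path, one file at a time."""
  # ZIP_STORED: no compression, so members stay cheap to read back.
  with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_STORED) as zf:
    # Walk bottom-up so emptied subdirectories can be removed as we go.
    for root, _, files in os.walk(checkpoint_dir, topdown=False):
      for name in files:
        path = os.path.join(root, name)
        zf.write(path, arcname=os.path.relpath(path, checkpoint_dir))
        os.remove(path)  # free the original once its copy is in the archive
      os.rmdir(root)
```

A production version would want to flush and verify the archive before deleting anything; this sketch skips that for brevity.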

2. How fast is saving the checkpoint given I/O can be slow?

> Thanks to the underlying async usage, it should be pretty fast.

3. Is RAM usage controlled?

> Not at the moment, but it's possible to improve this. We can use tensorstore
> to limit memory usage when writing arrays, and we can rewrite non-array writing
> so it is no longer buffered through a bytes or text object (it is currently
> buffered to more cleanly support the file://, gcs://, s3:// alternatives).

4. Why is the `pytreedef.json` weird like that?

> The "cleanest" way to save a pytree structure is to just use a JSON
> representation with leaves replaced by their data reference id. However,
> JSON doesn't distinguish between tuple and list, so it doesn't really preserve
> the actual pytree, even if it's limited to only built-in types. Also, when the
> pytree contains custom nodes, we need a custom tree representation anyway.
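
The tuple/list problem is easy to reproduce with the standard `json` module alone. The nested value below is just an illustrative stand-in for a pytree made of built-in containers:

```python
import json

nested = {"world": ["!", (1, 2)]}
restored = json.loads(json.dumps(nested))

print(restored)            # {'world': ['!', [1, 2]]}
print(restored == nested)  # False: the tuple came back as a list
```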

5. Isn't overwriting a **directory** checkpoint extremely dangerous if the 
"checkpoint" path turns out to be e.g. "/usr/local"?

> Yes, but we first check for files and directories we didn't create and refuse
> to overwrite if there are any.
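
A minimal sketch of such a guard, assuming the writer knows which top-level entries it creates. `EXPECTED_ENTRIES` and `check_safe_to_overwrite` are hypothetical names; the entry names simply mirror the directory listing shown earlier in this notebook.

```python
import os

# Hypothetical: the top-level entries a checkpoint writer creates.
EXPECTED_ENTRIES = {"pytreedef.json", "array_store.tensorstore"}


def check_safe_to_overwrite(path):
  """Raises if `path` holds anything a checkpoint writer would not create."""
  if not os.path.exists(path):
    return  # nothing to overwrite
  if not os.path.isdir(path):
    raise RuntimeError(f"{path} exists and is not a directory")
  foreign = set(os.listdir(path)) - EXPECTED_ENTRIES
  if foreign:
    raise RuntimeError(f"Refusing to overwrite {path}: found {sorted(foreign)}")
```

Called on something like "/usr/local", this raises before anything is deleted, because the directory contains entries the writer never created.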

6. Why doesn't the Python LSP work with the synchronous versions, `save` and `load`?

> I don't know; I need to fix it.

7. Restored pytrees have dictionary node keys in a different order. Why?

> This is pytree behavior: dictionary key order is not preserved.
> [https://github.com/google/jax/issues/4085](https://github.com/google/jax/issues/4085)

8. Which host (process_id) writes the save directory and what part of it?

> Currently, if a remote path is detected, only the process with
> `jax.process_index() == 0` writes non-arrays. All processes write arrays, as
> that is what tensorstore expects.
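
The rule above can be sketched as a small pure function. Everything here is illustrative: `is_remote_path`, its URL prefixes, and `should_write_non_arrays` are hypothetical, and the behavior for local paths (every process writes) is my assumption, since the note only covers the remote case. In a real multi-host run, the index would come from `jax.process_index()`.

```python
# Hypothetical URL prefixes used to detect remote storage.
REMOTE_PREFIXES = ("gs://", "s3://")


def is_remote_path(path):
  """Returns True if `path` looks like a cloud-storage URL."""
  return path.startswith(REMOTE_PREFIXES)


def should_write_non_arrays(process_index, path):
  """Decides whether this process writes the non-array files."""
  if is_remote_path(path):
    # All hosts see the same storage, so a single writer is enough.
    return process_index == 0
  # Assumption: on a local path each host has its own filesystem.
  return True
```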

9. TODO

> - passing a sharding is not yet tested