# A proposal for defining `jax.save` and `jax.load`
... or perhaps something like `jax.save_to_dir` and `jax.load_from_dir`

### Objectives for the new API:

The new API should be:
- simple
- strict by default (no pickle), but allow pickle
    - while still allowing built-in values like numbers (e.g., int, float), strings, and bytes
- not pickle itself: it should be largely orthogonal to what pickle does
    - NOTE: the separate effort to support (some) pickling of JAX arrays is also important
- able to perform partial reads
- able to read and write arrays in a distributed, sharded context
    - this is already supported with tensorstore (and Yash's low-level API)
- available in an `async` version for both save and load
- able to read in just the structure of the data
- backed by flat storage of leaf data that users can access themselves
- able to read from and write to remote cloud storage, at a minimum GCS and S3
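
One way to enforce "strict by default, but allow pickle" is an allow-list check that runs before anything is written. The sketch below is a hypothetical illustration, not the proposal's implementation. `check_strict`, `ALLOWED_LEAF_TYPES` and `ALLOWED_NODE_TYPES` are made-up names. Arrays, which the proposal stores via tensorstore, are left out so the sketch stays self-contained.

```python
from typing import Any, Optional

# Hypothetical allow-lists for strict (no pickle) mode: built-in scalars,
# strings and bytes as leaves; built-in containers as nodes.
ALLOWED_LEAF_TYPES = (bool, int, float, complex, str, bytes, type(None))
ALLOWED_NODE_TYPES = (list, tuple, dict)


def check_strict(tree: Any, pickle_module: Optional[Any] = None) -> None:
    """Raise ValueError if saving `tree` would need pickle but none was given."""
    if pickle_module is not None:
        return  # the caller explicitly opted into pickle
    if type(tree) in ALLOWED_LEAF_TYPES:
        return
    if type(tree) in ALLOWED_NODE_TYPES:
        children = tree.values() if type(tree) is dict else tree
        for child in children:
            check_strict(child)
        return
    # Exact type checks on purpose: subclasses (e.g. namedtuples) and
    # registered dataclasses are custom nodes and therefore need pickle.
    raise ValueError(
        f"Refusing to serialize {type(tree).__qualname__} without a pickle_module"
    )
```

The check uses exact types rather than `isinstance` so that any custom node falls through to the error branch instead of being treated as a built-in container.
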

In [1]:
import tempfile
from pathlib import Path
from subprocess import check_output
from pprint import pprint
from dataclasses import dataclass
import functools
from typing import Any
import pickle

import jax
from jax.experimental.array_serialization.new_api import save, load, load_pytree
from jax import numpy as jnp, random

In [2]:
data = ["hello", {"world": ["!", (1, 2)]}]

In [3]:
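# Keep only a fresh, unique path: the discarded TemporaryDirectory object is
# collected right away (in CPython) and removes its directory, so `save` creates it.
tmpdir = tempfile.TemporaryDirectory().name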
save(data, tmpdir)

In [4]:
print("`pytreedef.json` stores the pytree structure in a human-readable format")
print((Path(tmpdir) / "pytreedef.json").read_text())
print("-" * 80)
print("The leaf data is organized in a flat directory under `leaf_data`")
print(check_output(["tree", tmpdir]).decode())

`pytreedef.json` stores the pytree structure in a human-readable format
{
  "__jax_tree_repr": {
    "node_type": "builtins.list",
    "node_data_ref": null,
    "children": [
      {
        "node_type": "leaf",
        "node_data_ref": null,
        "children": [],
        "leaf_id": "str -> e4b2fd01-19b1-4a93-ada8-e07efaeb4867"
      },
      {
        "node_type": "builtins.dict",
        "node_data_ref": [
          "world"
        ],
        "children": [
          {
            "node_type": "builtins.list",
            "node_data_ref": null,
            "children": [
              {
                "node_type": "leaf",
                "node_data_ref": null,
                "children": [],
                "leaf_id": "str -> 575bfa86-ed48-47f5-bac8-a799afc17809"
              },
              {
                "node_type": "builtins.tuple",
                "node_data_ref": null,
                "children": [
                  {
                    "node_type": "leaf",
                    "n

In [5]:
# read only the data structure
print("-" * 80)
print("PyTree Structure:")
pytree_structure = load_pytree(tmpdir)
pprint(pytree_structure)

--------------------------------------------------------------------------------
PyTree Structure:
['str -> e4b2fd01-19b1-4a93-ada8-e07efaeb4867',
 {'world': ['str -> 575bfa86-ed48-47f5-bac8-a799afc17809',
            ('int -> 29e85017-058f-4bcc-9464-f5baa9d1924d',
             'int -> 60706ed3-be9e-42dd-aff1-5d630e42ef1a')]}]
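
The structure printed above shows the storage model. Each leaf is replaced by a `"<type> -> <uuid>"` reference, and the leaf values live separately in the flat `leaf_data` directory. Below is a minimal stdlib-only sketch of that flattening step. It assumes only built-in `list`, `tuple` and `dict` nodes. `flatten_with_refs` is a hypothetical helper and is not part of the proposed API. Unlike JAX, it treats `None` as a leaf.

```python
import uuid
from typing import Any, Dict, Tuple


def flatten_with_refs(tree: Any) -> Tuple[Any, Dict[str, Any]]:
    """Swap every leaf for a "<type> -> <uuid>" reference (hypothetical helper).

    Returns the structure (same containers, leaves replaced by references) and
    a flat dict mapping each reference to its leaf value, i.e. roughly what
    ends up in `pytreedef.json` and `leaf_data/` respectively.
    """
    leaves: Dict[str, Any] = {}

    def visit(node: Any) -> Any:
        if type(node) is dict:
            return {key: visit(value) for key, value in node.items()}
        if type(node) in (list, tuple):
            # Rebuild the same container type so tuples stay tuples.
            return type(node)(visit(child) for child in node)
        ref = f"{type(node).__name__} -> {uuid.uuid4()}"
        leaves[ref] = node
        return ref

    return visit(tree), leaves


# Same data as the notebook cell above; the references are random on each run.
structure, leaves = flatten_with_refs(["hello", {"world": ["!", (1, 2)]}])
```
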


In [6]:
# read only integers back
print("-" * 80)
print("Partial read of data:")
pytree_structure = jax.tree.map(lambda x: None if x.startswith("str") else x, pytree_structure)
new_data = load(tmpdir, pytree=pytree_structure)
pprint(new_data)

--------------------------------------------------------------------------------
Partial read of data:
[None, {'world': [None, (1, 2)]}]
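
Partial reads work because the structure returned by `load_pytree` names every leaf. Replacing a reference with `None` means that leaf never has to be read. Below is a stdlib-only sketch of the idea. The hypothetical `read_leaf` callable stands in for reading one file from `leaf_data/`, and `load_partial` is likewise made up.

```python
from typing import Any, Callable


def load_partial(structure: Any, read_leaf: Callable[[str], Any]) -> Any:
    """Rebuild a pytree, reading only the leaves whose reference was kept.

    `structure` is a pytree of "<type> -> <id>" references in which some
    entries may have been replaced by None by the caller.
    """
    if structure is None:
        return None  # leaf deliberately skipped: it is never read
    if type(structure) is dict:
        return {k: load_partial(v, read_leaf) for k, v in structure.items()}
    if type(structure) in (list, tuple):
        return type(structure)(load_partial(v, read_leaf) for v in structure)
    return read_leaf(structure)  # a reference string: fetch that single leaf
```

With the `str` references pruned, as in the cell above, only the two `int` leaves are fetched.
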


## Allowing custom nodes with `pickle`

In [7]:
@functools.partial(jax.tree_util.register_dataclass, data_fields=["a", "c"], 
                   meta_fields=["op"])
@dataclass
class D:
  op: str
  a: Any
  c: int
  
tmpdir = tempfile.TemporaryDirectory().name
try:
  save({"dataclass": D("tanh", random.normal(random.key(0), (7,)), 5), 
        "a": 1}, tmpdir)
except ValueError:
  print("Correctly refusing to serialize custom objects")

save({"dataclass": D("tanh", random.normal(random.key(0), (7,)), 5), 
      "a": 1}, tmpdir, pickle_module=pickle)
print("Serialized with pickle")

print("Extra folder: `node_data`")
print(check_output(["tree", "-L", "2", tmpdir]).decode()) 
print("-" * 80)
print((Path(tmpdir) / "pytreedef.json").read_text())

Correctly refusing to serialize custom objects
Serialized with pickle
Extra folder: `node_data`
[01;34m/tmp/tmpi3o49fav[0m
├── [01;34mleaf_data[0m
│   ├── [00m2b2a436a-afca-4fbf-a5cd-6a81c3b53c69.json[0m
│   ├── [01;34m8c1c43cf-5d57-493b-877e-aa396afef9d1.tensorstore[0m
│   └── [00mce4aec25-642a-41d8-9557-cce14798da21.json[0m
├── [01;34mnode_data[0m
│   └── [00mef09d15c-5698-4b13-b8e4-19fc97fb7dbf.pickle[0m
└── [00mpytreedef.json[0m

4 directories, 4 files

--------------------------------------------------------------------------------
{
  "__jax_tree_repr": {
    "node_type": "builtins.dict",
    "node_data_ref": [
      "a",
      "dataclass"
    ],
    "children": [
      {
        "node_type": "leaf",
        "node_data_ref": null,
        "children": [],
        "leaf_id": "int -> 2b2a436a-afca-4fbf-a5cd-6a81c3b53c69"
      },
      {
        "node_type": "__main__.D",
        "node_data_ref": "ef09d15c-5698-4b13-b8e4-19fc97fb7dbf",
        "children": [
         

In [8]:
try:
  load(tmpdir)
except ValueError:
  print("Correctly refuses to read without pickle")

print("Reads correctly with pickle")
print(load(tmpdir, pickle_module=pickle))

Correctly refuses to read without pickle
Reads correctly with pickle
{'a': 1, 'dataclass': D(op='tanh', a=Array([ 0.08086783, -0.38624713, -0.37565565,  0.58691907, -1.2758198 ,
        2.1192005 , -0.85821223], dtype=float32), c=5)}


## Best-effort reading when pickled objects are no longer available

We can also attempt to load with `best_effort=True`, either without pickle or when the class definition / custom node registration has been lost.

This prints a warning, reads the children of each former custom node, and organizes them in a list.

Node data (e.g., static fields) is not read.
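
Below is a sketch of how best-effort decoding of the `pytreedef.json` node format could behave. It assumes only the fields visible in the outputs above: `node_type`, `node_data_ref`, `children` and `leaf_id`. `decode_best_effort` and its `read_leaf` callable are hypothetical. The real loader also handles arrays and pickled node data.

```python
import warnings
from typing import Any, Callable, Dict


def decode_best_effort(node: Dict[str, Any], read_leaf: Callable[[str], Any]) -> Any:
    """Decode one node; unknown custom nodes degrade to a list of children."""
    node_type = node["node_type"]
    if node_type == "leaf":
        return read_leaf(node["leaf_id"])
    children = [decode_best_effort(child, read_leaf) for child in node["children"]]
    if node_type == "builtins.list":
        return children
    if node_type == "builtins.tuple":
        return tuple(children)
    if node_type == "builtins.dict":
        # For dicts, node_data_ref holds the (sorted) keys.
        return dict(zip(node["node_data_ref"], children))
    # Custom node: its node data (e.g. static fields) would come from a pickle
    # we cannot or may not read, so only the children survive.
    warnings.warn(f"Reading custom node {node_type!r} as a list of its children")
    return children
```
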

In [9]:
print(load(tmpdir, best_effort=True))



{'a': 1, 'dataclass': [Array([ 0.08086783, -0.38624713, -0.37565565,  0.58691907, -1.2758198 ,
        2.1192005 , -0.85821223], dtype=float32), 5]}


### Notes

1. Must the resulting checkpoint be a directory? Can it not be a file?

> The underlying checkpoint is a directory; tensorstore doesn't really support
> writing single-file checkpoints that are well optimized for reading.

> It's possible to zip the result piece by piece without wasting disk space,
> which is probably a direction worth exploring. NOTE: tensorstore seems to support
> **reading** from Python zipfile handles directly.
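
Below is a minimal sketch of zipping a checkpoint directory piece by piece with Python's `zipfile`. Each file is deleted as soon as it has been added, so peak extra disk usage stays around one file rather than a full copy. `zip_checkpoint_in_place` is a hypothetical helper. Storing entries uncompressed (`ZIP_STORED`) is an assumption aimed at keeping reads cheap, not something the proposal specifies.

```python
import zipfile
from pathlib import Path


def zip_checkpoint_in_place(ckpt_dir: str, zip_path: str) -> None:
    """Move every file of `ckpt_dir` into a zip archive, one file at a time.

    `zip_path` must lie outside `ckpt_dir`. The directory is removed at the end.
    """
    root = Path(ckpt_dir)
    files = sorted(p for p in root.rglob("*") if p.is_file())
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_STORED) as zf:
        for path in files:
            zf.write(path, arcname=path.relative_to(root).as_posix())
            path.unlink()  # free the space as soon as the file is archived
    # All files are gone; remove the empty directories, deepest first.
    dirs = sorted((p for p in root.rglob("*") if p.is_dir()),
                  key=lambda p: len(p.parts), reverse=True)
    for d in dirs:
        d.rmdir()
    root.rmdir()
```
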

2. How CPU efficient is saving the checkpoint?

> Because the underlying I/O is asynchronous, it should be fairly efficient.
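
The hypothetical `write_leaves_async` below roughly illustrates why async I/O helps. It writes small JSON leaves concurrently in worker threads via `asyncio.to_thread` (Python 3.9+). This is only a sketch of the idea. In the proposal, array writes go through tensorstore, whose own asynchronous machinery is not shown here.

```python
import asyncio
import json
from pathlib import Path
from typing import Any, Dict


async def write_leaves_async(leaf_dir: str, leaves: Dict[str, Any]) -> None:
    """Write each leaf to `<leaf_dir>/<leaf_id>.json`, all writes in flight at once."""
    root = Path(leaf_dir)
    root.mkdir(parents=True, exist_ok=True)

    def write_one(leaf_id: str, value: Any) -> None:
        (root / f"{leaf_id}.json").write_text(json.dumps(value))

    # Each blocking write runs in a thread; gather waits for all of them.
    await asyncio.gather(
        *(asyncio.to_thread(write_one, leaf_id, value) for leaf_id, value in leaves.items())
    )
```

A synchronous `save` could then simply call `asyncio.run(write_leaves_async(...))`, which is one way the sync and async variants might share an implementation.
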

3. Is RAM usage controlled?

> Not at the moment, but this can be improved. We can use tensorstore to limit
> memory usage when writing arrays, and we can rewrite non-array writes so they
> are not buffered through an in-memory bytes or text object (they are currently
> buffered to more cleanly support file://, gcs://, and s3:// alternatives).

4. Why does `pytreedef.json` look so unusual?

> The "cleanest" way to save a pytree structure is to just use a JSON
> representation with leaves replaced by their data reference ids. However,
> JSON doesn't distinguish between tuples and lists, so it doesn't really preserve
> the actual pytree, even if the pytree is limited to built-in types. Also, when the
> pytree contains custom nodes, we need a custom tree representation anyway.
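
The tuple/list problem is easy to reproduce with the standard `json` module, and recording an explicit `node_type` per node fixes it. Below is a simplified sketch of such an encoding for built-in containers only. Its field names follow the `pytreedef.json` output above. `to_node_repr` and `from_node_repr` are hypothetical, and they use list indices as leaf ids instead of `"<type> -> <uuid>"` references.

```python
import json
from typing import Any, Dict, List

# JSON has no tuple type, so a plain round trip silently turns tuples into lists.
assert json.loads(json.dumps([1, (2, 3)])) == [1, [2, 3]]

_BUILTIN_NODES = {list: "builtins.list", tuple: "builtins.tuple", dict: "builtins.dict"}


def to_node_repr(tree: Any, leaves: List[Any]) -> Dict[str, Any]:
    """Encode a pytree of built-in containers; leaf values are appended to `leaves`."""
    node_type = _BUILTIN_NODES.get(type(tree))
    if node_type is None:
        leaves.append(tree)
        return {"node_type": "leaf", "node_data_ref": None, "children": [],
                "leaf_id": str(len(leaves) - 1)}
    if node_type == "builtins.dict":
        keys = sorted(tree)  # sorted keys, as in the output above
        return {"node_type": node_type, "node_data_ref": keys,
                "children": [to_node_repr(tree[k], leaves) for k in keys]}
    return {"node_type": node_type, "node_data_ref": None,
            "children": [to_node_repr(child, leaves) for child in tree]}


def from_node_repr(node: Dict[str, Any], leaves: List[Any]) -> Any:
    """Inverse of `to_node_repr`; rejects node types it does not know."""
    if node["node_type"] == "leaf":
        return leaves[int(node["leaf_id"])]
    children = [from_node_repr(child, leaves) for child in node["children"]]
    if node["node_type"] == "builtins.dict":
        return dict(zip(node["node_data_ref"], children))
    if node["node_type"] == "builtins.tuple":
        return tuple(children)
    if node["node_type"] == "builtins.list":
        return children
    raise ValueError(f"unknown node_type {node['node_type']!r}")
```
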

5. Isn't overwriting a **directory** checkpoint extremely dangerous if the
"checkpoint" path turns out to be, e.g., "/usr/local"?

> Yes, but we first check for files and directories we didn't create and refuse
> to overwrite if there are any.
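
Below is a minimal sketch of such a check. It assumes a checkpoint only ever contains the top-level entries seen in the directory listings above: `pytreedef.json`, `leaf_data` and `node_data`. `check_safe_to_overwrite` and `EXPECTED_ENTRIES` are hypothetical names. A stricter version could also validate what is inside `leaf_data` and `node_data`.

```python
from pathlib import Path

# Hypothetical: the top-level entries a save() call creates, per the listings above.
EXPECTED_ENTRIES = {"pytreedef.json", "leaf_data", "node_data"}


def check_safe_to_overwrite(path: str) -> None:
    """Raise ValueError if `path` holds anything a save() would not have written."""
    root = Path(path)
    if not root.exists():
        return  # nothing to clobber
    if not root.is_dir():
        raise ValueError(f"{root} exists and is not a checkpoint directory")
    unexpected = sorted(p.name for p in root.iterdir() if p.name not in EXPECTED_ENTRIES)
    if unexpected:
        raise ValueError(f"Refusing to overwrite {root}: unexpected entries {unexpected}")
```
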

6. Why doesn't the Python LSP work with the synchronous versions `save` and `load`?

> I don't know, I need to fix it.