# A proposal for defining `jax.save` and `jax.load`
... or perhaps something like `jax.save_to_dir` and `jax.load_from_dir`

## Objectives for the new API

The new API should:
- be simple
- be strict by default (no pickle), while still allowing pickle as an explicit opt-in
    - even in strict mode, it should store built-in values such as numbers (e.g., `int`, `float`), strings and bytes
- not be pickle: its behavior should be largely orthogonal to what pickle does
    - NOTE: the separate effort to support pickling JAX arrays is also important
- support partial reads
- support reading and writing arrays in a distributed, sharded context
    - this is already supported with tensorstore (and Yash's low-level API)
- provide an `async` version of both save and load
- support reading just the structure of the data
- keep the leaf data in flat storage that users can access themselves
- support reading from and writing to remote cloud storage, with GCS and S3 at a minimum
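
A minimal sketch of the strict-by-default, flat-storage objectives follows. It is plain Python with no JAX dependency and is not the prototype's implementation: `save_sketch`, `ALLOWED_LEAF_TYPES`, and the choice to store each leaf as a UUID-named JSON file (with bytes hex-encoded) are hypothetical. It only mirrors the `pytreedef.json` node fields (`node_type`, `node_data_ref`, `children`, `leaf_id`) visible in the example output in this notebook, and it rejects any leaf outside a small built-in whitelist instead of falling back to pickle.

```python
import json
import tempfile
import uuid
from pathlib import Path

# Hypothetical whitelist for the strict (no-pickle) default.
ALLOWED_LEAF_TYPES = {bool, int, float, str, bytes}


def _write_leaf(value, leaf_dir: Path) -> str:
    """Store one leaf as its own JSON file and return its leaf id."""
    kind = type(value).__name__
    key = str(uuid.uuid4())
    # Assumption: JSON has no bytes type, so bytes are stored as a hex string.
    payload = value.hex() if kind == "bytes" else value
    (leaf_dir / f"{key}.json").write_text(json.dumps(payload))
    return f"{kind} -> {key}"


def _encode(node, leaf_dir: Path) -> dict:
    """Recursively describe the pytree, writing leaves to the flat leaf directory."""
    if type(node) in (list, tuple):
        return {"node_type": f"builtins.{type(node).__name__}",
                "node_data_ref": None,
                "children": [_encode(child, leaf_dir) for child in node]}
    if type(node) is dict:
        keys = sorted(node)  # JAX also orders dict children by sorted key
        return {"node_type": "builtins.dict",
                "node_data_ref": keys,
                "children": [_encode(node[k], leaf_dir) for k in keys]}
    if type(node) not in ALLOWED_LEAF_TYPES:
        # Strict default: anything else would need an explicit pickle opt-in.
        raise TypeError(f"cannot store {type(node).__name__} without pickle")
    return {"node_type": "leaf", "node_data_ref": None, "children": [],
            "leaf_id": _write_leaf(node, leaf_dir)}


def save_sketch(tree, directory) -> None:
    """Hypothetical saver: `pytreedef.json` plus one file per leaf in `leaf_data/`."""
    root = Path(directory)
    leaf_dir = root / "leaf_data"
    leaf_dir.mkdir(parents=True, exist_ok=True)
    treedef = {"__jax_tree_repr": _encode(tree, leaf_dir)}
    (root / "pytreedef.json").write_text(json.dumps(treedef, indent=2))


if __name__ == "__main__":
    data = ["hello", {"world": ["!", (1, 2)]}]
    with tempfile.TemporaryDirectory() as tmpdir:
        save_sketch(data, tmpdir)
        print(sorted(p.name for p in Path(tmpdir).iterdir()))  # ['leaf_data', 'pytreedef.json']
        print(len(list(Path(tmpdir, "leaf_data").iterdir())))  # 4
```

Because this sketch writes leaves while it walks the tree, a rejected leaf can leave earlier leaf files behind; a real implementation would likely validate the whole tree before writing anything.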

```python
import tempfile
from pathlib import Path
from subprocess import check_output
from pprint import pprint

import jax
from jax.experimental.array_serialization.new_api import save, load, load_pytree
```

```python
data = ["hello", {"world": ["!", (1, 2)]}]
```

```python
with tempfile.TemporaryDirectory() as tmpdir:
  save(data, tmpdir)
  print("The `pytreedef.json` is a human-readable pytree stored format")
  print((Path(tmpdir) / "pytreedef.json").read_text())
  print("-" * 80)
  print("The leaf data is organized in a flat directory under `leaf_data`")
  # Requires the `tree` command-line tool to be installed.
  print(check_output(["tree", tmpdir]).decode())

  # Read only the data structure.
  print("-" * 80)
  print("PyTree Structure:")
  pytree_structure = load_pytree(tmpdir)
  pprint(pytree_structure)

  # Read only the integers back: replace every string leaf with None so `load` skips it.
  print("-" * 80)
  print("Partial read of data:")
  partial_structure = jax.tree.map(
      lambda x: None if x.startswith("str") else x, pytree_structure)
  new_data = load(tmpdir, pytree=partial_structure)
  pprint(new_data)
```

```
The `pytreedef.json` is a human-readable pytree stored format
{
  "__jax_tree_repr": {
    "node_type": "builtins.list",
    "node_data_ref": null,
    "children": [
      {
        "node_type": "leaf",
        "node_data_ref": null,
        "children": [],
        "leaf_id": "str -> 1d02b1a4-b93e-456f-bb43-0a96f522a751"
      },
      {
        "node_type": "builtins.dict",
        "node_data_ref": [
          "world"
        ],
        "children": [
          {
            "node_type": "builtins.list",
            "node_data_ref": null,
            "children": [
              {
                "node_type": "leaf",
                "node_data_ref": null,
                "children": [],
                "leaf_id": "str -> fb337e95-5bca-4abb-baa9-9f202bf384ec"
              },
              {
                "node_type": "builtins.tuple",
                "node_data_ref": null,
                "children": [
                  {
                    "node_type": "leaf",
                    "n
```
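
The structure-only and partial reads in the example above can be approximated by walking `pytreedef.json` directly, so no leaf file is opened until that leaf is actually requested. This is an assumption-laden illustration, not the proposed `load_pytree`/`load`: `load_pytree_sketch`, `load_sketch`, and `map_leaves` are hypothetical names, only the list, tuple, dict, and leaf node types seen in the output are handled, and the leaf file format (`leaf_data/<uuid>.json`, with bytes hex-encoded) is assumed because the output does not show it.

```python
import json
import tempfile
from pathlib import Path

# Sequence node types this sketch understands (taken from the example output).
_SEQUENCE_TYPES = {"builtins.list": list, "builtins.tuple": tuple}


def _decode(node: dict):
    """Rebuild the containers; each leaf becomes its `leaf_id` string."""
    kind = node["node_type"]
    if kind == "leaf":
        return node["leaf_id"]
    children = [_decode(child) for child in node["children"]]
    if kind == "builtins.dict":
        return dict(zip(node["node_data_ref"], children))
    if kind in _SEQUENCE_TYPES:
        return _SEQUENCE_TYPES[kind](children)
    raise ValueError(f"unsupported node type: {kind}")


def load_pytree_sketch(directory):
    """Hypothetical structure-only read: parses pytreedef.json, opens no leaf file."""
    text = (Path(directory) / "pytreedef.json").read_text()
    return _decode(json.loads(text)["__jax_tree_repr"])


def map_leaves(fn, tree):
    """Apply `fn` to every leaf; None entries are passed through unchanged."""
    if isinstance(tree, dict):
        return {k: map_leaves(fn, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(map_leaves(fn, v) for v in tree)
    return None if tree is None else fn(tree)


def load_sketch(directory, pytree):
    """Hypothetical partial read: load only leaves whose leaf_id is still present."""
    leaf_dir = Path(directory) / "leaf_data"

    def read_leaf(leaf_id: str):
        kind, key = leaf_id.split(" -> ", 1)
        # Assumption: one JSON file per leaf, bytes stored as a hex string.
        payload = json.loads((leaf_dir / f"{key}.json").read_text())
        return bytes.fromhex(payload) if kind == "bytes" else payload

    return map_leaves(read_leaf, pytree)


if __name__ == "__main__":
    def leaf(leaf_id):
        return {"node_type": "leaf", "node_data_ref": None, "children": [], "leaf_id": leaf_id}

    with tempfile.TemporaryDirectory() as tmpdir:
        root = Path(tmpdir)
        (root / "leaf_data").mkdir()
        (root / "leaf_data" / "a.json").write_text(json.dumps("hello"))
        (root / "leaf_data" / "b.json").write_text(json.dumps(1))
        treedef = {"node_type": "builtins.list", "node_data_ref": None,
                   "children": [leaf("str -> a"), leaf("int -> b")]}
        (root / "pytreedef.json").write_text(json.dumps({"__jax_tree_repr": treedef}))

        structure = load_pytree_sketch(root)
        print(structure)  # ['str -> a', 'int -> b']
        partial = map_leaves(lambda x: None if x.startswith("str") else x, structure)
        print(load_sketch(root, partial))  # [None, 1]
```

Treating `None` as "skip this leaf" mirrors the `jax.tree.map` call in the example, where `None` is an empty subtree rather than a leaf, so the loader never sees a leaf id at that position.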

## Allowing custom nodes with `pickle`

TODO

## Best-effort reading when pickled objects are no longer available

TODO