In [8]:
from secretnote.instrumentation.tree_util import pytree_snapshot, recursive_defaultdict


In [9]:
import random

from jax.tree_util import (
    GetAttrKey,
    register_pytree_with_keys_class,
    tree_unflatten,
    tree_structure,
)


@register_pytree_with_keys_class
class MyContainer:
    def __init__(self, x):
        self.x = x
        self.r = random.random()

    def tree_flatten_with_keys(self):
        return (
            (GetAttrKey("x"), self.x),
            (GetAttrKey("r"), self.r),
        ), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        self = cls(children[0])
        self.r = children[1]
        return self

    def __hash__(self):
        return hash((self.x, self.r))


In [10]:
c = MyContainer(1)

data = {
    1: [1, 2, 3],
    2: MyContainer(1),
    3: {MyContainer(2): [4, 5, 6]},
}


In [11]:
serialized = pytree_snapshot(data)


In [18]:
import json


structure = recursive_defaultdict()

for kp in serialized.keypaths:
    s = {}
    for k in kp[:-1]:
        s = structure[k]
    s[kp[-1]] = True

tree_unflatten(tree_structure(json.loads(json.dumps(structure))), serialized.leaves)


{'1': {'0': PyTreeLeafSnapshot(id='python/id/0x104ef5d10', type='builtins.int', snapshot='1'),
  '1': PyTreeLeafSnapshot(id='python/id/0x104ef5d30', type='builtins.int', snapshot='2'),
  '2': PyTreeLeafSnapshot(id='python/id/0x104ef5d50', type='builtins.int', snapshot='3')},
 '2': {'0': PyTreeLeafSnapshot(id='python/id/0x104ef5d10', type='builtins.int', snapshot='1'),
  '1': PyTreeLeafSnapshot(id='python/id/0x128bda370', type='builtins.float', snapshot='0.050695655302517295')},
 '3': {},
 'python/id/0x2b44005b0': {'0': PyTreeLeafSnapshot(id='python/id/0x104ef5d70', type='builtins.int', snapshot='4'),
  '1': PyTreeLeafSnapshot(id='python/id/0x104ef5d90', type='builtins.int', snapshot='5'),
  '2': PyTreeLeafSnapshot(id='python/id/0x104ef5db0', type='builtins.int', snapshot='6')}}