Skip to content

Commit

Permalink
jaxable data class updates
Browse files Browse the repository at this point in the history
  • Loading branch information
planes committed May 13, 2022
1 parent 7538af9 commit 5253351
Showing 1 changed file with 40 additions and 6 deletions.
46 changes: 40 additions & 6 deletions trunk/SUAVE/Core/Data.py
Expand Up @@ -66,8 +66,35 @@ def tree_flatten(self):
Properties Used:
N/A
"""
children = self.values()
aux_data = self.keys()

#Get all keys and values from the data class
keys = list(self.keys())
values = self.values()

# Make a dictionary of the strings in self
aux_dict = dict()
for key, value in zip(keys,values):
if type(value) is str:
aux_dict[key] = self.pop(key)

# Some children classes might have "static_keys" that are marked as immutable
if hasattr(self,'static_keys'):
for k in self.static_keys:
aux_dict[k] = self.pop(k)

# static_keys is also a static key...
aux_dict['static_keys'] = self.pop('static_keys')

#Get all keys and values from the data class, now that they don't have strings or static_keys
stringless_keys = list(self.keys())
children = self.values()

# Put back what I have taken away
self.update(aux_dict)

# Pack up the results
aux_data = [stringless_keys,aux_dict]

return (children, aux_data)

@classmethod
Expand All @@ -88,12 +115,19 @@ def tree_unflatten(cls, aux_data, children):
Properties Used:
N/A
"""
"""
# Create the class
recreated = cls()
length = len(aux_data)
keys = list(aux_data)

# add an string data
recreated.update(aux_data[1])

# keys
keys = aux_data[0]
length = len(keys)
keys = list(keys)
for ii in range(length):
recreated.append(children[ii],keys[ii])
recreated[keys[ii]] = children[ii]

return recreated

Expand Down

0 comments on commit 5253351

Please sign in to comment.