
Commit afe425e

[torchax]: JittableModule statedict handling (#9195)
Co-authored-by: zmelumian <zmelumian@lightricks.com>
1 parent 0818753 commit afe425e

File tree

2 files changed, +137 -0 lines changed

torchax/test/test_statedict.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
import unittest
import torch
from torch.utils import _pytree as pytree

from torchax import (interop, mesh_util, tensor)


class Model(torch.nn.Module):

  def __init__(self):
    super(Model, self).__init__()
    self.linear = torch.nn.Linear(10, 5)

  def forward(self, x):
    return self.linear(x)


mesh = mesh_util.Mesh.fsdp_mesh()
model = interop.JittableModule(mesh.initialize_model_sharded(Model, ()))


class TestTensorStateDict(unittest.TestCase):

  def test_get_statedict(self):
    state_dict_cpu = model.cpu_state_dict()
    is_xla_tensor = pytree.tree_map(lambda t: isinstance(t, tensor.Tensor),
                                    state_dict_cpu)
    assert not any(
        is_xla_tensor.values()), "State dict should not contain XLA tensors"

  def test_load_statedict(self):
    state_dict_cpu = model.cpu_state_dict()
    state_dict_cpu = pytree.tree_map(torch.zeros_like, state_dict_cpu)
    model.load_state_dict(state_dict_cpu)
    is_zeros = pytree.tree_map(lambda t: torch.equal(t, torch.zeros_like(t)),
                               state_dict_cpu)
    assert all(is_zeros.values()), "State dict should be zeros"

  def test_load_statedict_partial(self):
    state_dict_cpu = model.cpu_state_dict()
    del state_dict_cpu['_model.linear.bias']
    state_dict_cpu = pytree.tree_map(torch.ones_like, state_dict_cpu)
    key_check = model.load_state_dict(state_dict_cpu, strict=False)
    assert key_check.missing_keys == [
        '_model.linear.bias'
    ], "Missing keys should be '_model.linear.bias'"
    linear_weight = model.state_dict()['_model.linear.weight']
    assert torch.equal(
        linear_weight,
        torch.ones_like(linear_weight)), "Linear weight should be ones"


if __name__ == '__main__':
  unittest.main()
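These tests exercise the save/restore flow that the new methods are built for; the `_model.` prefix in the state-dict keys reflects JittableModule holding the wrapped module as a submodule. A minimal round-trip sketch under the same setup as above (`Model` is the toy module from this test file, and the checkpoint filename is illustrative, not part of the commit):

import torch
from torchax import interop, mesh_util

# Same setup as the test file above; Model is the toy linear module.
mesh = mesh_util.Mesh.fsdp_mesh()
model = interop.JittableModule(mesh.initialize_model_sharded(Model, ()))

# cpu_state_dict() yields plain CPU torch tensors, so torch.save works as-is.
torch.save(model.cpu_state_dict(), "ckpt.pt")  # filename is illustrative

# load_state_dict() accepts the CPU state dict and moves each tensor back to
# the device/sharding of the parameter it replaces.
model.load_state_dict(torch.load("ckpt.pt"))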

torchax/torchax/interop.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
+from typing import Mapping, Any
 import collections
 import copy
 import functools
@@ -132,6 +133,88 @@ def call(*args, **kwargs):

     self._jitted[key] = call

+  def cpu_state_dict(self, *args, **kwargs):
+    """Wrapper for state_dict.
+
+    This function transfers all of the parameters to CPU, making it
+    easier to save the state dict with torch.save.
+
+    Returns:
+      Mapping[str, Any]: A mapping of parameter names to their values
+        (as torch CPU tensors).
+    """
+    state_dict = super().state_dict(*args, **kwargs)
+    state_dict = pytree.tree_map(lambda t: t.cpu(), state_dict)
+    return state_dict
+
+  def load_state_dict(self,
+                      state_dict: Mapping[str, Any],
+                      strict: bool = True,
+                      assign: bool = False):
+    """Wrapper for load_state_dict.
+
+    This function assumes a torch CPU state dict and transfers the
+    parameters to the correct device and dtype before loading them into
+    the model.
+
+    Args:
+      state_dict (Mapping[str, Any]): A mapping of parameter names to
+        their values (as torch CPU tensors).
+      strict (bool, optional): whether to strictly enforce that the keys
+        in :attr:`state_dict` match the keys returned by this module's
+        :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
+      assign (bool, optional): When set to ``False``, the properties of
+        the tensors in the current module are preserved, whereas setting
+        it to ``True`` preserves properties of the Tensors in the state
+        dict. The only exception is the ``requires_grad`` field of
+        :class:`~torch.nn.Parameter`, for which the value from the
+        module is preserved. Default: ``False``
+
+    Returns:
+      ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
+        * **missing_keys** is a list of str containing any keys that are
+          expected by this module but missing from the provided
+          ``state_dict``.
+        * **unexpected_keys** is a list of str containing the keys that
+          are not expected by this module but present in the provided
+          ``state_dict``.
+    """
+    # Move tensors to JAX to have an easier time extracting sharding
+    # information.
+    current_state_dict = super().state_dict()
+    current_state_dict = jax_view(current_state_dict)
+
+    # Create out shardings that either reuse the current state dict
+    # sharding or replicate the weights.
+    def extract_sharding_or_replicate(name):
+      if name in current_state_dict:
+        return current_state_dict[name].sharding
+      return jax.sharding.PartitionSpec()
+
+    output_shards = {
+        name: extract_sharding_or_replicate(name) for name in state_dict
+    }
+
+    def convert_to_xla_tensor_if_needed(t):
+      is_torch_tensor = isinstance(t, torch.Tensor)
+      is_xla_tensor = isinstance(t, torchax.tensor.Tensor)
+      if is_xla_tensor:
+        t = jax_view(t)
+      elif is_torch_tensor:
+        # Convert a plain torch tensor to a JAX array.
+        t = tensor.t2j(t)
+      return t
+
+    # Convert the state dict to JAX arrays and shard them.
+    state_dict = pytree.tree_map(
+        tensor.t2j,
+        state_dict,
+    )
+    # Convert the OrderedDict to a regular dict for pjit type-safety checks.
+    state_dict = dict(state_dict)
+    jitted = jax_jit(
+        lambda t: t, kwargs_for_jax_jit={"out_shardings": output_shards})
+    state_dict = jitted(state_dict)
+    # View the arrays as torch tensors again, so the ``assign`` path of
+    # load_state_dict can work with them if requested.
+    state_dict = torch_view(state_dict)
+
+    return super().load_state_dict(state_dict, strict, assign)
+

 class CompileMixin:
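The resharding step above leans on a common JAX pattern: jit an identity function and let out_shardings dictate where the output lives. A standalone sketch of that pattern in plain JAX (the mesh axis name "fsdp" and the array shape are illustrative assumptions, not part of this commit):

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Jitting an identity function with an out_shardings constraint returns the
# same values, resharded to the requested layout. Axis name and shape are
# illustrative; the leading dimension is kept divisible by the device count.
mesh = Mesh(jax.devices(), ("fsdp",))
sharding = NamedSharding(mesh, PartitionSpec("fsdp"))

reshard = jax.jit(lambda t: t, out_shardings=sharding)
x = jnp.ones((jax.device_count() * 2, 4))  # placed wherever JAX created it
y = reshard(x)                             # now sharded along the "fsdp" axis
print(y.sharding)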
