9 changes: 8 additions & 1 deletion tensordict/nn/params.py
@@ -798,7 +798,14 @@ def __getattr__(self, item: str) -> Any:
             try:
                 return getattr(self.__dict__["_param_td"], item)
             except AttributeError:
-                return super().__getattr__(item)
+                try:
+                    return super().__getattr__(item)
+                except AttributeError as e:
+                    # During some state-dict loads, we may encounter cases where pytorch does a getattr
+                    # with the module name
+                    if item in self.keys():
+                        return TensorDictParams(self[item])
+                    raise e
         else:
             return super().__getattr__(item)
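
For context on the hunk above: during certain state-dict loads, PyTorch calls getattr on the module with the name of a sub-tensordict, which previously escaped as an AttributeError. A minimal sketch of the access pattern the fallback unblocks, assuming tensordict's public TensorDictParams/TensorDict constructors (the key names here are illustrative, not from the PR):

    import torch
    from tensordict import TensorDict
    from tensordict.nn import TensorDictParams

    # Hypothetical nested params; "layer" is a tensordict key, not an attribute.
    params = TensorDictParams(
        TensorDict({"layer": {"weight": torch.randn(4, 3)}}, [])
    )
    # Attribute lookup falls through the wrapped tensordict and nn.Module,
    # then hits the new fallback: the key is found and returned wrapped, so
    # state-dict machinery can keep recursing into the sub-tensordict.
    sub = params.layer
    assert isinstance(sub, TensorDictParams)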

59 changes: 59 additions & 0 deletions test/test_compile.py
@@ -878,6 +878,65 @@ def test_export_seq(self):
         out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y})
         torch.testing.assert_close(out.module()(x=x, y=y), tdm(x=x, y=y))
 
+    # This test passes but there are various things that need to be fixed:
+    # - we cannot use vmap directly
+    # - if we use strict=True, there's an error due to the fact that export ignores
+    #   the replacement of the params (i.e., params are still on "meta" and the values
+    #   after the call on the exported module don't match the original ones).
+    #   Currently this only works with strict=False, because export fails to see that
+    #   the params in the module have changed and are not "meta" anymore => this
+    #   is symptomatic of export failing to see the functional call
+    @pytest.mark.parametrize("strict", [False])  # , True])
+    def test_export_with_td_params(self, strict):
+        module = torch.nn.Sequential(
+            torch.nn.Linear(3, 4),
+            torch.nn.Linear(4, 5),
+        )
+        module_td = TensorDictParams(
+            TensorDict.from_module(module).data.expand(2).clone()
+        )
+        assert all(
+            isinstance(p, torch.nn.Parameter) for p in module_td.values(True, True)
+        )
+
+        class MyModule(torch.nn.Module):
+            def __init__(self, td_params):
+                super().__init__()
+                self.tdparams = td_params
+                self.arch = torch.nn.Sequential(
+                    torch.nn.Linear(3, 4, device="meta"),
+                    torch.nn.Linear(4, 5, device="meta"),
+                )
+
+            def forward(self, x):
+                # vmap with params currently fails
+                # return torch.vmap(self.batch_forward, (0, None))(self.tdparams, x)
+                return torch.stack(
+                    [self.batch_forward(p, x) for p in self.tdparams.unbind(0)]
+                )
+
+            def batch_forward(self, params, x):
+                with params.to_module(self.arch):
+                    return self.arch(x)
+                # This could be an option but dynamo doesn't know how to trace through state_dict ops
+                # sd = self.arch.state_dict()
+                # try:
+                #     self.arch.load_state_dict(params.flatten_keys().to_dict(), assign=True)
+                #     return self.arch(x)
+                # finally:
+                #     self.arch.load_state_dict(sd, assign=True)
+
+        m = MyModule(module_td)
+        x = torch.randn(3)
+        assert m(x).shape == (2, 5)
+        exported_module = torch.export.export(
+            m,
+            args=(),
+            kwargs={"x": x},
+            strict=strict,
+        )
+        torch.testing.assert_close(exported_module.module()(x=x), m(x))
+
 
 @pytest.mark.skipif(not _has_onnx, reason="ONNX is not available")
 class TestONNXExport:
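
The core of the new test is the functional call in batch_forward: each slice of a batched parameter tensordict is swapped into a stateless "meta" skeleton for the duration of one forward pass. A standalone sketch of that pattern, assuming only the documented from_module/to_module round trip (shapes and names are illustrative, not part of the PR):

    import torch
    from tensordict import TensorDict

    arch = torch.nn.Linear(3, 4, device="meta")  # weightless skeleton
    # Two detached copies of real weights, stacked along a new leading dim.
    params = TensorDict.from_module(torch.nn.Linear(3, 4)).data.expand(2).clone()

    x = torch.randn(3)
    outs = []
    for p in params.unbind(0):
        # Used as a context manager, to_module writes p into arch's
        # parameters and restores the previous (meta) tensors on exit.
        with p.to_module(arch):
            outs.append(arch(x))
    assert torch.stack(outs).shape == (2, 4)

This is the same loop the test's forward runs; the commented-out vmap line is the shape this should eventually take once vmap over TensorDictParams is supported.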