Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tensordict/nn/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,9 @@ def __init__(
self._locked_tensordicts = []
self._get_post_hook = []

def __iter__(self):
yield from self._param_td.__iter__()

def register_get_post_hook(self, hook):
"""Register a hook to be called after any get operation on leaf tensors."""
if not callable(hook):
Expand Down
31 changes: 25 additions & 6 deletions tensordict/nn/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@
import re
import warnings

try:
from enum import StrEnum
except ImportError:
from .utils import StrEnum
from textwrap import indent
from typing import Any, Dict, List, Optional

Expand All @@ -30,6 +26,16 @@

from torch.utils._contextlib import _DecoratorContextManager

try:
from torch.compiler import is_compiling
except ImportError:
from torch._dynamo import is_compiling

try:
from enum import StrEnum
except ImportError:
from .utils import StrEnum

__all__ = ["ProbabilisticTensorDictModule", "ProbabilisticTensorDictSequential"]


Expand Down Expand Up @@ -350,11 +356,13 @@ def get_dist(self, tensordict: TensorDictBase) -> D.Distribution:
if isinstance(dist_key, tuple):
dist_key = dist_key[-1]
dist_kwargs[dist_key] = tensordict.get(td_key)
dist = self.distribution_class(**dist_kwargs, **self.distribution_kwargs)
dist = self.distribution_class(
**dist_kwargs, **_dynamo_friendly_to_dict(self.distribution_kwargs)
)
except TypeError as err:
if "an unexpected keyword argument" in str(err):
raise TypeError(
"distribution keywords and tensordict keys indicated by ProbabilisticTensorDictModule.dist_keys must match."
"distribution keywords and tensordict keys indicated by ProbabilisticTensorDictModule.dist_keys must match. "
f"Got this error message: \n{indent(str(err), 4 * ' ')}\nwith dist_keys={self.dist_keys}"
)
elif re.search(r"missing.*required positional arguments", str(err)):
Expand Down Expand Up @@ -623,3 +631,14 @@ def forward(
) -> TensorDictBase:
tensordict_out = self.get_dist_params(tensordict, tensordict_out, **kwargs)
return self.module[-1](tensordict_out, _requires_sample=self._requires_sample)


def _dynamo_friendly_to_dict(data):
if not is_compiling():
return data
if isinstance(data, TensorDictBase):
items = list(data.items())
if not items:
return {}
return dict(items)
return data
23 changes: 23 additions & 0 deletions test/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
)
from tensordict.nn import (
CudaGraphModule,
InteractionType,
ProbabilisticTensorDictModule as Prob,
TensorDictModule,
TensorDictModule as Mod,
TensorDictSequential as Seq,
Expand Down Expand Up @@ -662,6 +664,27 @@ def test_dispatch_tensor(self, mode):
mod_compile = torch.compile(mod, fullgraph=_v2_5, mode=mode)
torch.testing.assert_close(mod(x=x, y=y), mod_compile(x=x, y=y))

def test_prob_module_with_kwargs(self, mode):
kwargs = TensorDictParams(TensorDict(scale=1.0), no_convert=True)
dist_cls = torch.distributions.Normal
mod = Mod(torch.nn.Linear(3, 3), in_keys=["inp"], out_keys=["loc"])
prob_mod = Seq(
mod,
Prob(
in_keys=["loc"],
out_keys=["sample"],
return_log_prob=True,
distribution_class=dist_cls,
distribution_kwargs=kwargs,
default_interaction_type=InteractionType.RANDOM,
),
)
# check that the scale is in the buffers
assert len(list(prob_mod.buffers())) == 1
prob_mod(TensorDict(inp=torch.randn(3)))
prob_mod_c = torch.compile(prob_mod, fullgraph=True, mode=mode)
prob_mod_c(TensorDict(inp=torch.randn(3)))


@pytest.mark.skipif(
TORCH_VERSION <= version.parse("2.4.0"), reason="requires torch>2.4"
Expand Down
Loading