From 678965b53e7ec326c63a78c9f9ae20a66508d7fd Mon Sep 17 00:00:00 2001 From: Tom Begley Date: Tue, 22 Nov 2022 14:38:02 +0000 Subject: [PATCH] Rename ProbabilisticTensorDictModule keys --- tensordict/nn/probabilistic.py | 56 +++++++++++++++++----------------- tensordict/nn/sequence.py | 4 +-- test/test_tensordictmodules.py | 50 +++++++++++++++--------------- 3 files changed, 55 insertions(+), 55 deletions(-) diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index 15137efbd..ae85af910 100644 --- a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -78,13 +78,13 @@ class ProbabilisticTensorDictModule(TensorDictModule): module (nn.Module): a nn.Module used to map the input to the output parameter space. Can be a functional module (FunctionalModule or FunctionalModuleWithBuffers), in which case the :obj:`forward` method will expect the params (and possibly) buffers keyword arguments. - dist_param_keys (str or iterable of str or dict): key(s) that will be produced + dist_in_keys (str or iterable of str or dict): key(s) that will be produced by the inner TDModule and that will be used to build the distribution. Importantly, if it's an iterable of string or a string, those keys must match the keywords used by the distribution class of interest, e.g. :obj:`"loc"` and :obj:`"scale"` for the Normal distribution - and similar. If dist_param_keys is a dictionary,, the keys are the keys of the distribution and the values are the + and similar. If dist_in_keys is a dictionary,, the keys are the keys of the distribution and the values are the keys in the tensordict that will get match to the corresponding distribution keys. - out_key_sample (str or iterable of str): keys where the sampled values will be + sample_out_key (str or iterable of str): keys where the sampled values will be written. Importantly, if this key is part of the :obj:`out_keys` of the inner model, the sampling step will be skipped. default_interaction_mode (str, optional): default method to be used to retrieve the output value. Should be one of: @@ -119,8 +119,8 @@ class of interest, e.g. :obj:`"loc"` and :obj:`"scale"` for the Normal distribut >>> module = TensorDictModule(fnet, in_keys=["input", "hidden"], out_keys=["loc", "scale"]) >>> td_module = ProbabilisticTensorDictModule( ... module=module, - ... dist_param_keys=["loc", "scale"], - ... out_key_sample=["action"], + ... dist_in_keys=["loc", "scale"], + ... sample_out_key=["action"], ... distribution_class=Normal, ... return_log_prob=True, ... ) @@ -160,8 +160,8 @@ class of interest, e.g. :obj:`"loc"` and :obj:`"scale"` for the Normal distribut def __init__( self, module: TensorDictModule, - dist_param_keys: Union[str, Sequence[str], dict], - out_key_sample: Union[str, Sequence[str]], + dist_in_keys: Union[str, Sequence[str], dict], + sample_out_key: Union[str, Sequence[str]], default_interaction_mode: str = "mode", distribution_class: Type = Delta, distribution_kwargs: Optional[dict] = None, @@ -173,27 +173,27 @@ def __init__( # if the module returns the sampled key we wont be sampling it again # then ProbabilisticTensorDictModule is presumably used to return the distribution using `get_dist` - if isinstance(dist_param_keys, str): - dist_param_keys = [dist_param_keys] - if isinstance(out_key_sample, str): - out_key_sample = [out_key_sample] - if not isinstance(dist_param_keys, dict): - dist_param_keys = {param_key: param_key for param_key in dist_param_keys} - for key in dist_param_keys.values(): + if isinstance(dist_in_keys, str): + dist_in_keys = [dist_in_keys] + if isinstance(sample_out_key, str): + sample_out_key = [sample_out_key] + if not isinstance(dist_in_keys, dict): + dist_in_keys = {param_key: param_key for param_key in dist_in_keys} + for key in dist_in_keys.values(): if key not in module.out_keys: raise RuntimeError( f"The key {key} could not be found in the wrapped module `{type(module)}.out_keys`." ) module_out_keys = module.out_keys - self.out_key_sample = out_key_sample - _check_all_str(self.out_key_sample) - out_key_sample = [key for key in out_key_sample if key not in module_out_keys] - self._requires_sample = bool(len(out_key_sample)) - out_keys = out_key_sample + module_out_keys + self.sample_out_key = sample_out_key + _check_all_str(self.sample_out_key) + sample_out_key = [key for key in sample_out_key if key not in module_out_keys] + self._requires_sample = bool(len(sample_out_key)) + out_keys = sample_out_key + module_out_keys super().__init__(module=module, in_keys=in_keys, out_keys=out_keys) - self.dist_param_keys = dist_param_keys - _check_all_str(self.dist_param_keys.keys()) - _check_all_str(self.dist_param_keys.values()) + self.dist_in_keys = dist_in_keys + _check_all_str(self.dist_in_keys.keys()) + _check_all_str(self.dist_in_keys.values()) self.default_interaction_mode = default_interaction_mode if isinstance(distribution_class, str): @@ -258,17 +258,17 @@ def get_dist( def build_dist_from_params(self, tensordict_out: TensorDictBase) -> d.Distribution: try: - selected_td_out = tensordict_out.select(*self.dist_param_keys.values()) + selected_td_out = tensordict_out.select(*self.dist_in_keys.values()) dist_kwargs = { dist_key: selected_td_out[td_key] - for dist_key, td_key in self.dist_param_keys.items() + for dist_key, td_key in self.dist_in_keys.items() } dist = self.distribution_class(**dist_kwargs) except TypeError as err: if "an unexpected keyword argument" in str(err): raise TypeError( - "distribution keywords and tensordict keys indicated by ProbabilisticTensorDictModule.dist_param_keys must match." - f"Got this error message: \n{indent(str(err), 4 * ' ')}\nwith dist_param_keys={self.dist_param_keys}" + "distribution keywords and tensordict keys indicated by ProbabilisticTensorDictModule.dist_in_keys must match." + f"Got this error message: \n{indent(str(err), 4 * ' ')}\nwith dist_in_keys={self.dist_in_keys}" ) elif re.search(r"missing.*required positional arguments", str(err)): raise TypeError( @@ -299,13 +299,13 @@ def forward( if isinstance(out_tensors, Tensor): out_tensors = (out_tensors,) tensordict_out.update( - {key: value for key, value in zip(self.out_key_sample, out_tensors)} + {key: value for key, value in zip(self.sample_out_key, out_tensors)} ) if self.return_log_prob: log_prob = dist.log_prob(*out_tensors) tensordict_out.set("sample_log_prob", log_prob) elif self.return_log_prob: - out_tensors = [tensordict_out.get(key) for key in self.out_key_sample] + out_tensors = [tensordict_out.get(key) for key in self.sample_out_key] log_prob = dist.log_prob(*out_tensors) tensordict_out.set("sample_log_prob", log_prob) # raise RuntimeError( diff --git a/tensordict/nn/sequence.py b/tensordict/nn/sequence.py index db9eacac2..61ee496d4 100644 --- a/tensordict/nn/sequence.py +++ b/tensordict/nn/sequence.py @@ -68,8 +68,8 @@ class TensorDictSequential(TensorDictModule): >>> fmodule1 = TensorDictModule(fnet1, in_keys=["input"], out_keys=["loc", "scale"]) >>> td_module1 = ProbabilisticTensorDictModule( ... module=fmodule1, - ... dist_param_keys=["loc", "scale"], - ... out_key_sample=["hidden"], + ... dist_in_keys=["loc", "scale"], + ... sample_out_key=["hidden"], ... distribution_class=Normal, ... return_log_prob=True, ... ) diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index 6b8785fef..dbc339356 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -57,16 +57,16 @@ def test_stateful_probabilistic(self, lazy, interaction_mode, out_keys): kwargs = {"distribution_class": Normal} if out_keys == ["loc", "scale"]: - dist_param_keys = ["loc", "scale"] + dist_in_keys = ["loc", "scale"] elif out_keys == ["loc_1", "scale_1"]: - dist_param_keys = {"loc": "loc_1", "scale": "scale_1"} + dist_in_keys = {"loc": "loc_1", "scale": "scale_1"} else: raise NotImplementedError tensordict_module = ProbabilisticTensorDictModule( module=net, - dist_param_keys=dist_param_keys, - out_key_sample=["out"], + dist_in_keys=dist_in_keys, + sample_out_key=["out"], **kwargs, ) @@ -108,8 +108,8 @@ def test_functional_probabilistic(self): kwargs = {"distribution_class": Normal} tensordict_module = ProbabilisticTensorDictModule( module=tdnet, - dist_param_keys=["loc", "scale"], - out_key_sample=["out"], + dist_in_keys=["loc", "scale"], + sample_out_key=["out"], **kwargs, ) @@ -130,8 +130,8 @@ def test_functional_probabilistic_laterconstruct(self): kwargs = {"distribution_class": Normal} tensordict_module = ProbabilisticTensorDictModule( module=tdnet, - dist_param_keys=["loc", "scale"], - out_key_sample=["out"], + dist_in_keys=["loc", "scale"], + sample_out_key=["out"], **kwargs, ) tensordict_module, ( @@ -175,8 +175,8 @@ def test_functional_with_buffer_probabilistic(self): tdmodule = ProbabilisticTensorDictModule( module=tdnet, - dist_param_keys=["loc", "scale"], - out_key_sample=["out"], + dist_in_keys=["loc", "scale"], + sample_out_key=["out"], **kwargs, ) @@ -197,8 +197,8 @@ def test_functional_with_buffer_probabilistic_laterconstruct(self): kwargs = {"distribution_class": Normal} tdmodule = ProbabilisticTensorDictModule( module=tdnet, - dist_param_keys=["loc", "scale"], - out_key_sample=["out"], + dist_in_keys=["loc", "scale"], + sample_out_key=["out"], **kwargs, ) tdmodule, (params, buffers) = tdmodule.make_functional_with_buffers() @@ -255,8 +255,8 @@ def test_vmap_probabilistic(self): tdmodule = ProbabilisticTensorDictModule( module=tdnet, - dist_param_keys=["loc", "scale"], - out_key_sample=["out"], + dist_in_keys=["loc", "scale"], + sample_out_key=["out"], **kwargs, ) @@ -293,8 +293,8 @@ def test_vmap_probabilistic_laterconstruct(self): kwargs = {"distribution_class": Normal} tdmodule = ProbabilisticTensorDictModule( module=tdnet, - dist_param_keys=["loc", "scale"], - out_key_sample=["out"], + dist_in_keys=["loc", "scale"], + sample_out_key=["out"], **kwargs, ) tdmodule, (params, buffers) = tdmodule.make_functional_with_buffers() @@ -435,8 +435,8 @@ def test_stateful_probabilistic(self, lazy): ) tdmodule2 = ProbabilisticTensorDictModule( module=net2, - dist_param_keys=["loc", "scale"], - out_key_sample=["out"], + dist_in_keys=["loc", "scale"], + sample_out_key=["out"], **kwargs, ) tdmodule = TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2) @@ -529,7 +529,7 @@ def test_functional_probabilistic(self): fdummy_net, in_keys=["hidden"], out_keys=["hidden"] ) tdmodule2 = ProbabilisticTensorDictModule( - fnet2, dist_param_keys=["loc", "scale"], out_key_sample=["out"], **kwargs + fnet2, dist_in_keys=["loc", "scale"], sample_out_key=["out"], **kwargs ) tdmodule = TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2) @@ -629,7 +629,7 @@ def test_functional_with_buffer_probabilistic(self): fdummy_net, in_keys=["hidden"], out_keys=["hidden"] ) tdmodule2 = ProbabilisticTensorDictModule( - fnet2, dist_param_keys=["loc", "scale"], out_key_sample=["out"], **kwargs + fnet2, dist_in_keys=["loc", "scale"], sample_out_key=["out"], **kwargs ) tdmodule = TensorDictSequential(tdmodule1, dummy_tdmodule, tdmodule2) @@ -672,8 +672,8 @@ def test_functional_with_buffer_probabilistic_laterconstruct(self): tdmodule1 = TensorDictModule(net1, in_keys=["in"], out_keys=["hidden"]) tdmodule2 = ProbabilisticTensorDictModule( net2, - dist_param_keys=["loc", "scale"], - out_key_sample=["out"], + dist_in_keys=["loc", "scale"], + sample_out_key=["out"], **kwargs, ) tdmodule = TensorDictSequential(tdmodule1, tdmodule2) @@ -761,7 +761,7 @@ def test_vmap_probabilistic(self): kwargs = {"distribution_class": Normal} tdmodule1 = TensorDictModule(fnet1, in_keys=["in"], out_keys=["hidden"]) tdmodule2 = ProbabilisticTensorDictModule( - fnet2, out_key_sample=["out"], dist_param_keys=["loc", "scale"], **kwargs + fnet2, sample_out_key=["out"], dist_in_keys=["loc", "scale"], **kwargs ) tdmodule = TensorDictSequential(tdmodule1, tdmodule2) @@ -856,10 +856,10 @@ def test_sequential_partial(self, stack, functional): tdmodule1 = TensorDictModule(fnet1, in_keys=["a"], out_keys=["hidden"]) tdmodule2 = ProbabilisticTensorDictModule( - fnet2, out_key_sample=["out"], dist_param_keys=["loc", "scale"], **kwargs + fnet2, sample_out_key=["out"], dist_in_keys=["loc", "scale"], **kwargs ) tdmodule3 = ProbabilisticTensorDictModule( - fnet3, out_key_sample=["out"], dist_param_keys=["loc", "scale"], **kwargs + fnet3, sample_out_key=["out"], dist_in_keys=["loc", "scale"], **kwargs ) tdmodule = TensorDictSequential( tdmodule1, tdmodule2, tdmodule3, partial_tolerant=True