diff --git a/tensordict/nn/functional_modules.py b/tensordict/nn/functional_modules.py index 663d9920d..57d64ad78 100644 --- a/tensordict/nn/functional_modules.py +++ b/tensordict/nn/functional_modules.py @@ -128,11 +128,14 @@ def __init__(self): def __enter__(self): for tdtype in PYTREE_REGISTERED_TDS + PYTREE_REGISTERED_LAZY_TDS: - self.tdnodes[tdtype] = SUPPORTED_NODES.pop(tdtype) + node = SUPPORTED_NODES.pop(tdtype, None) + if node is None: + continue + self.tdnodes[tdtype] = node def __exit__(self, exc_type, exc_val, exc_tb): - for tdtype in PYTREE_REGISTERED_TDS + PYTREE_REGISTERED_LAZY_TDS: - SUPPORTED_NODES[tdtype] = self.tdnodes[tdtype] + for tdtype, node in self.tdnodes.items(): + SUPPORTED_NODES[tdtype] = node def set(self): self.__enter__()