Description
Motivation
I am a big fan of TensorClass: it helps define complex and/or large collections of structured Tensor data, and it gives a Module clear inputs and outputs. In a larger codebase with multiple contributors, it helps with discovery and as-you-type documentation through language-server auto-complete and in-code docstrings.
The current limitation of TensorClass compared to TensorDict is the lack of a TensorClassModuleBase equivalent to TensorDictModuleBase.
Having a way to write modules that take one specific TensorClass as input and return another as output would lead to cleaner code, and static linting tools could catch uses of non-existent keys before they become runtime errors.
Solution
- Implementation of a generic TensorClassModuleBase[InputTensorClass, OutputTensorClass] abstract base class to inherit from, with typed forward and __call__ methods
- Add a helper method to convert the child class into a TensorDictModuleBase to support ONNX exports
Example class signature
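from abc import ABC, abstractmethod
from typing import Generic, TypeVar, cast
from tensordict import TensorClass
from torch import nn

InputTensorClass = TypeVar("InputTensorClass", bound=TensorClass)
OutputTensorClass = TypeVar("OutputTensorClass", bound=TensorClass)

class TensorClassModule(Generic[InputTensorClass, OutputTensorClass], ABC, nn.Module):
    @abstractmethod
    def forward(self, x: InputTensorClass) -> OutputTensorClass:
        ...

    def __call__(self, x: InputTensorClass) -> OutputTensorClass:
        # Delegate to nn.Module.__call__ and narrow the return type for static checkers.
        return cast("OutputTensorClass", super().__call__(x))

    # Convert to a TensorDictModuleBase, e.g. for ONNX export.
    def to_tensordict_module(self) -> "TensorClassAsDictModule":
        ...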
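The signature above leaves to_tensordict_module unspecified. One way such a conversion could work is sketched below using only the standard library: plain dataclasses stand in for TensorClass, plain dicts stand in for TensorDict, and lists of floats stand in for tensors. The names DictAdapter, Pose, Features and pose_features are hypothetical and exist only for this illustration; they are not part of tensordict or of the prototype mentioned in this issue.

```python
from dataclasses import dataclass, fields
from typing import Any, Callable, Dict, Generic, List, Type, TypeVar

In = TypeVar("In")
Out = TypeVar("Out")


@dataclass
class Pose:
    """Stand-in for an input TensorClass; lists of floats replace tensors."""
    quaternion: List[float]
    positions: List[float]


@dataclass
class Features:
    """Stand-in for an output TensorClass."""
    norm_sq: float


class DictAdapter(Generic[In, Out]):
    """Hypothetical wrapper exposing a typed module through a dict-in, dict-out call."""

    def __init__(self, module: Callable[[In], Out], input_type: Type[In]) -> None:
        self.module = module
        self.input_type = input_type
        # Keys the wrapped module reads, derived from the declared input fields.
        self.in_keys = [f.name for f in fields(input_type)]

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Build the typed input from the dict entries; a missing key raises KeyError here.
        typed_input = self.input_type(**{k: data[k] for k in self.in_keys})
        output = self.module(typed_input)
        # Flatten the typed output back into a plain dict (shallow: nested values are kept as-is).
        return {f.name: getattr(output, f.name) for f in fields(output)}


def pose_features(pose: Pose) -> Features:
    # Toy computation: squared norm of the position vector.
    return Features(norm_sq=sum(p * p for p in pose.positions))


adapter = DictAdapter(pose_features, Pose)
result = adapter({"quaternion": [1.0, 0.0, 0.0, 0.0], "positions": [3.0, 4.0, 0.0]})
```

In this sketch the typed module stays the single source of truth: the dict keys are derived from the declared fields rather than listed by hand, so the dict-facing interface cannot drift from the typed one.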
Example usage
class PoseTensor(TensorClass):
    """Full orientation and position of a body in the world."""

    quaternion: Tensor
    """[..., 4]: Contains orientation as a unit quaternion."""
    positions: Tensor
    """[..., 3]: Contains world positions."""

class MyPolicyObservations(TensorClass):
    object_a: PoseTensor
    object_b: PoseTensor

class MyPolicyOutput(TensorClass):
    control_a: Tensor
    control_b: Tensor

class PoseFeaturesModule(TensorClassModule[PoseTensor, Tensor]):
    ...

class MyPolicy(TensorClassModule[MyPolicyObservations, MyPolicyOutput]):
    def __init__(self):
        super().__init__()  # required before registering submodules on an nn.Module
        self.pose_features_module = PoseFeaturesModule()
        self.mlp_a = MLP(...)
        self.mlp_b = MLP(...)

    def forward(self, x: MyPolicyObservations) -> MyPolicyOutput:
        features_a = self.pose_features_module(x.object_a)
        features_b = self.pose_features_module(x.object_b)
        mlp_input = torch.cat([features_a, features_b], dim=-1)
        return MyPolicyOutput(
            control_a=self.mlp_a(mlp_input),
            control_b=self.mlp_b(mlp_input),
        )
Alternatives
The alternative is to use TensorDict, TensorDictModule and TensorDictModuleBase, but then there is no static checking that keys exist, and the input and output signatures are less clear.
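The difference in static checking can be illustrated without tensordict at all. In the minimal sketch below, a plain dict stands in for TensorDict and a dataclass stands in for TensorClass; the Observations class and the misspelled object_c name are hypothetical and exist only for this illustration.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Observations:
    """Stand-in for a TensorClass with declared fields."""
    object_a: List[float]
    object_b: List[float]


obs_dict: Dict[str, List[float]] = {
    "object_a": [0.0, 0.0, 1.0],
    "object_b": [1.0, 0.0, 0.0],
}
obs_typed = Observations(**obs_dict)

# Dict-style access: a misspelled key is just another string, so a type checker
# accepts it and the mistake only surfaces as a KeyError at runtime.
try:
    obs_dict["object_c"]
    dict_error = None
except KeyError as exc:
    dict_error = type(exc).__name__

# Attribute-style access: the fields are declared on the class, so a static
# checker such as mypy or pyright reports the next access before the code runs.
try:
    obs_typed.object_c
    attr_error = None
except AttributeError as exc:
    attr_error = type(exc).__name__
```

Both mistakes fail at runtime, but only the attribute access is visible to linting tools and editor auto-complete, which is the gap a typed module base class is meant to close.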
Additional context
I have made a prototype implementation I can contribute in a new PR if this feature is accepted.
Checklist
- I have checked that there is no similar issue in the repo (required)