
[Feature Request] TensorClassModuleBase generic base class, an nn.Module base class with TensorClass input and output #1355

@az0uz

Description

Motivation

I am a big fan of TensorClass: it helps define complex and/or large collections of structured tensor data, and it makes the inputs and outputs of a Module explicit. On a larger codebase with multiple contributors, it also helps with discovery and as-you-type documentation through language-server autocompletion and in-code docstrings.

The main limitation of TensorClass compared to TensorDict is the lack of a TensorClassModuleBase, an equivalent of TensorDictModuleBase.

Being able to write modules that take a specific TensorClass as input and return another as output would make code cleaner, and static linting tools could then catch accesses to non-existent keys that would otherwise only fail at runtime.

Solution

  • Implement a generic TensorClassModuleBase[InputTensorClass, OutputTensorClass] abstract base class to inherit from, with typed forward and __call__ methods
  • Add a helper method that converts the child class into a TensorDictModuleBase, to support ONNX export
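One way the conversion helper could bridge the two worlds is to map each (possibly nested) field of a class to a tuple key, which matches the tuple-based nested-key convention TensorDict uses. The sketch below is framework-free and rests on assumptions: plain dataclasses stand in for TensorClass, lists stand in for tensors, and flatten_fields / unflatten_fields are hypothetical helper names, not part of tensordict.

```python
from dataclasses import dataclass, fields, is_dataclass
from typing import Any, Dict, Tuple, Type, TypeVar

T = TypeVar("T")
FlatKey = Tuple[str, ...]


def flatten_fields(obj: Any, prefix: FlatKey = ()) -> Dict[FlatKey, Any]:
    """Map every leaf field of a (possibly nested) dataclass to a tuple key."""
    flat: Dict[FlatKey, Any] = {}
    for f in fields(obj):
        value = getattr(obj, f.name)
        key = prefix + (f.name,)
        if is_dataclass(value):
            # Nested structure (e.g. a pose inside the observations): recurse.
            flat.update(flatten_fields(value, key))
        else:
            flat[key] = value
    return flat


def unflatten_fields(cls: Type[T], flat: Dict[FlatKey, Any], prefix: FlatKey = ()) -> T:
    """Rebuild an instance of cls from tuple-keyed leaves."""
    kwargs = {}
    for f in fields(cls):
        key = prefix + (f.name,)
        if is_dataclass(f.type):
            # Field annotated with a nested dataclass: rebuild it from its sub-keys.
            kwargs[f.name] = unflatten_fields(f.type, flat, key)
        else:
            kwargs[f.name] = flat[key]
    return cls(**kwargs)


# Hypothetical stand-ins for PoseTensor / MyPolicyObservations (lists, not tensors).
@dataclass
class Pose:
    quaternion: Any
    positions: Any


@dataclass
class Observations:
    object_a: Pose
    object_b: Pose


obs = Observations(Pose([1, 0, 0, 0], [0, 0, 0]), Pose([0, 1, 0, 0], [1, 2, 3]))
flat = flatten_fields(obs)
restored = unflatten_fields(Observations, flat)
```

The flat keys, such as ("object_a", "quaternion"), are what a dict-based wrapper would declare as its input keys, while the class remains the single source of truth for the structure.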

Example class signature

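from abc import ABC, abstractmethod
from typing import Generic, TypeVar, cast

from tensordict import TensorClass
from torch import nn

InputTensorClass = TypeVar("InputTensorClass", bound=TensorClass)
OutputTensorClass = TypeVar("OutputTensorClass", bound=TensorClass)

class TensorClassModule(Generic[InputTensorClass, OutputTensorClass], ABC, nn.Module):

    @abstractmethod
    def forward(self, x: InputTensorClass) -> OutputTensorClass:
        ...

    def __call__(self, x: InputTensorClass) -> OutputTensorClass:
        # nn.Module.__call__ is typed as returning Any; cast narrows it for static type checkers only.
        return cast("OutputTensorClass", super().__call__(x))

    # TensorClassAsDictModule: proposed wrapper class, not shown here (string annotation avoids a NameError).
    def to_tensordict_module(self) -> "TensorClassAsDictModule":
        ...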
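The typing pattern itself can be tried without torch or tensordict. In this sketch, Module is a hypothetical stand-in whose __call__ dispatches to forward (mirroring what nn.Module.__call__ does, minus hooks), Structured is a hypothetical stand-in for TensorClass, and the Celsius/Fahrenheit classes are purely illustrative.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Generic, Optional, TypeVar, cast


class Module:
    """Hypothetical stand-in for nn.Module: __call__ dispatches to forward()."""

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.forward(*args, **kwargs)

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError


@dataclass
class Structured:
    """Hypothetical stand-in for TensorClass, used as the TypeVar bound."""


InputT = TypeVar("InputT", bound=Structured)
OutputT = TypeVar("OutputT", bound=Structured)


class TypedModule(Generic[InputT, OutputT], ABC, Module):
    @abstractmethod
    def forward(self, x: InputT) -> OutputT: ...

    def __call__(self, x: InputT) -> OutputT:
        # No runtime effect: cast only narrows Module.__call__'s Any for type checkers.
        return cast(OutputT, super().__call__(x))


@dataclass
class Celsius(Structured):
    value: float


@dataclass
class Fahrenheit(Structured):
    value: float


class ToFahrenheit(TypedModule[Celsius, Fahrenheit]):
    def forward(self, x: Celsius) -> Fahrenheit:
        return Fahrenheit(value=x.value * 9 / 5 + 32)


out = ToFahrenheit()(Celsius(100.0))  # a type checker infers out: Fahrenheit

# The abstract forward prevents instantiating the base class directly.
abstract_error: Optional[TypeError] = None
try:
    TypedModule()  # type: ignore[abstract]
except TypeError as err:
    abstract_error = err
```

The base order (Generic, ABC, Module) mirrors the proposed signature; the metaclass resolves to ABCMeta, so a subclass that forgets to implement forward fails at instantiation rather than on first call.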

Example usage

import torch
from tensordict import TensorClass
from torch import Tensor

class PoseTensor(TensorClass):
    """Full orientation and position of a body in the world."""
    quaternion: Tensor
    """[..., 4]: Contains orientation as a unit quaternion."""
    positions: Tensor
    """[..., 3]: Contains world positions."""


class MyPolicyObservations(TensorClass):
    object_a: PoseTensor
    object_b: PoseTensor


class MyPolicyOutput(TensorClass):
    control_a: Tensor
    control_b: Tensor


# Note: Tensor does not satisfy bound=TensorClass, so the output TypeVar would
# need a looser bound (e.g. TensorClass | Tensor) for this line to type-check.
class PoseFeaturesModule(TensorClassModule[PoseTensor, Tensor]):
    ...


class MyPolicy(TensorClassModule[MyPolicyObservations, MyPolicyOutput]):
    def __init__(self):
        super().__init__()  # nn.Module must be initialized before submodules are assigned
        self.pose_features_module = PoseFeaturesModule()
        self.mlp_a = MLP(...)
        self.mlp_b = MLP(...)

    def forward(self, x: MyPolicyObservations) -> MyPolicyOutput:
        features_a = self.pose_features_module(x.object_a)
        features_b = self.pose_features_module(x.object_b)
        mlp_input = torch.cat([features_a, features_b], dim=-1)
        return MyPolicyOutput(
            control_a=self.mlp_a(mlp_input),
            control_b=self.mlp_b(mlp_input),
            batch_size=mlp_input.shape[:-1],
        )
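For to_tensordict_module to rebuild the right input class from dictionary keys, it needs to know which concrete classes a subclass such as MyPolicy was parameterized with. One possible approach, assuming the subclass binds both type parameters directly (as MyPolicy does), is to read them back from __orig_bases__ with typing.get_origin and typing.get_args. TypedModule, io_classes and the demo classes below are hypothetical stand-ins, with nn.Module omitted.

```python
from typing import Generic, Tuple, TypeVar, get_args, get_origin

InputT = TypeVar("InputT")
OutputT = TypeVar("OutputT")


class TypedModule(Generic[InputT, OutputT]):
    """Hypothetical stand-in for the proposed generic base (nn.Module omitted)."""

    @classmethod
    def io_classes(cls) -> Tuple[type, type]:
        """Return the concrete (input, output) classes this subclass was declared with."""
        # Walk the MRO so subclasses of an already-parameterized class also resolve.
        for klass in cls.__mro__:
            for base in klass.__dict__.get("__orig_bases__", ()):
                if get_origin(base) is TypedModule:
                    in_cls, out_cls = get_args(base)
                    if not isinstance(in_cls, TypeVar) and not isinstance(out_cls, TypeVar):
                        return in_cls, out_cls
        raise TypeError(f"{cls.__name__} does not bind both type parameters")


class Observations: ...


class Output: ...


class Policy(TypedModule[Observations, Output]):
    pass


class TunedPolicy(Policy):  # inherits the parameterization from Policy
    pass


in_cls, out_cls = TunedPolicy.io_classes()
```

Reading the parameters from the class declaration keeps them as the single place where the input and output types are stated, so the wrapper does not need them passed a second time.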

Alternatives

The alternative is to use TensorDict, TensorDictModule and TensorDictModuleBase, but then there is no static checking that keys exist, and the input and output signatures are less clear.
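To make the difference concrete, here is a framework-free sketch in which a plain dict stands in for a TensorDict and a dataclass stands in for a TensorClass (both are assumptions for illustration). A misspelled string key is only discovered when the line executes, whereas the equivalent attribute typo on the dataclass would be reported by a static checker such as mypy or pyright before running.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

# A plain dict stands in for a TensorDict: string keys are only checked at runtime.
obs_dict: Dict[str, Any] = {"object_a": [1.0, 2.0], "object_b": [3.0, 4.0]}


def first_coord_dict(obs: Dict[str, Any]) -> float:
    return obs["objet_a"][0]  # misspelled key: a type checker cannot flag this for Dict[str, Any]


# A dataclass stands in for a TensorClass: its fields are known to the type checker.
@dataclass
class Obs:
    object_a: List[float]
    object_b: List[float]


def first_coord_typed(obs: Obs) -> float:
    # Writing obs.objet_a here would be reported by mypy/pyright before running.
    return obs.object_a[0]


# The dict typo only surfaces when the line actually executes.
dict_error: Optional[KeyError] = None
try:
    first_coord_dict(obs_dict)
except KeyError as err:
    dict_error = err

typed_result = first_coord_typed(Obs(object_a=[1.0, 2.0], object_b=[3.0, 4.0]))
```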

Additional context

I have a prototype implementation that I can contribute in a new PR if this feature is accepted.

Checklist

  • I have checked that there is no similar issue in the repo (required)

Metadata

Labels

enhancement (New feature or request)
