Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/reference/tensorclass.rst
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ Here is an example:
:template: td_template.rst

tensorclass
TensorClass
NonTensorData
NonTensorStack

Expand Down
7 changes: 6 additions & 1 deletion tensordict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@
)
from tensordict.memmap import MemoryMappedTensor
from tensordict.persistent import PersistentTensorDict
from tensordict.tensorclass import NonTensorData, NonTensorStack, tensorclass
from tensordict.tensorclass import (
NonTensorData,
NonTensorStack,
tensorclass,
TensorClass,
)
from tensordict.utils import (
assert_allclose_td,
assert_close,
Expand Down
48 changes: 46 additions & 2 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations

import abc
import concurrent
import ctypes

Expand All @@ -24,6 +25,7 @@
from dataclasses import dataclass
from pathlib import Path
from textwrap import indent

from typing import Any, Callable, get_type_hints, List, Sequence, Type, TypeVar

import numpy as np
Expand Down Expand Up @@ -361,14 +363,16 @@ def __init__(self, autocast: bool = False, frozen: bool = False):
self.frozen = frozen

@dataclass_transform()
def __call__(self, cls):
def __call__(self, cls: T) -> T:
clz = _tensorclass(cls, frozen=self.frozen)
clz.autocast = self.autocast
return clz


@dataclass_transform()
def tensorclass(cls=None, /, *, autocast: bool = False, frozen: bool = False):
def tensorclass(
cls: T = None, /, *, autocast: bool = False, frozen: bool = False
) -> T | None:
"""A decorator to create :obj:`tensorclass` classes.

``tensorclass`` classes are specialized :func:`dataclasses.dataclass` instances that
Expand Down Expand Up @@ -3300,3 +3304,43 @@ def _update_shared_nontensor(nontensor, val):
raise NotImplementedError(
f"Updating {type(nontensor).__name__} within a shared/memmaped structure is not supported."
)


class _TensorClassMeta(abc.ABCMeta):
def __new__(mcs, name, bases, namespace, **kwargs):
# Create the class using the ABCMeta's __new__ method
cls = super().__new__(mcs, name, bases, namespace, **kwargs)

# Apply the dataclass decorator to the class
cls = _tensorclass(cls, frozen=False)

return cls


class TensorClass(metaclass=_TensorClassMeta):
"""TensorClass is the inheritance-based version of the @tensorclass decorator.

TensorClass allows you to code dataclasses that are better type-checked and more pythonic than those built with
the @tensorclass decorator.

Examples:
>>> from typing import Any
>>> import torch
>>> from tensordict import TensorClass
>>> class Foo(TensorClass):
... tensor: torch.Tensor
... non_tensor: Any
... nested: Any = None
>>> foo = Foo(tensor=torch.randn(3), non_tensor="a string!", nested=None, batch_size=[3])
>>> print(foo)
Foo(
non_tensor=NonTensorData(data=a string!, batch_size=torch.Size([3]), device=None),
tensor=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
nested=None,
batch_size=torch.Size([3]),
device=None,
is_shared=False)

"""

...
Loading
Loading