diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index a99811f04..18cb89e10 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -4003,7 +4003,6 @@ def _update_inv_op_kwargs(self, tensor: Tensor) -> dict[str, Any]: def _stack_onto_( self, - # key: str, list_item: list[CompatibleType], dim: int, ) -> T: diff --git a/tensordict/_td.py b/tensordict/_td.py index c83c4cfbe..dd422963b 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -235,13 +235,13 @@ class TensorDict(TensorDictBase): def __init__( self, - source: T | dict[str, CompatibleType] = None, + source: T | dict[NestedKey, CompatibleType] = None, batch_size: Sequence[int] | torch.Size | int | None = None, device: DeviceType | None = None, names: Sequence[str] | None = None, - non_blocking: bool = None, + non_blocking: bool | None = None, lock: bool = False, - **kwargs, + **kwargs: dict[str, Any] | None, ) -> None: if (source is not None) and kwargs: raise ValueError( @@ -304,14 +304,14 @@ def __init__( @classmethod def _new_unsafe( cls, - source: T | dict[str, CompatibleType] = None, + source: T | dict[NestedKey, CompatibleType] = None, batch_size: Sequence[int] | torch.Size | int | None = None, device: DeviceType | None = None, names: Sequence[str] | None = None, - non_blocking: bool = None, + non_blocking: bool | None = None, lock: bool = False, nested: bool = True, - **kwargs, + **kwargs: dict[str, Any] | None, ) -> TensorDict: if is_dynamo_compiling(): return TensorDict(