     unravel_key,
 )
 from tensordict.base import NO_DEFAULT
-from tensordict.utils import _getitem_batch_size, NestedKey
+from tensordict.utils import _getitem_batch_size, is_non_tensor, NestedKey
 from torchrl._utils import _make_ordinal_device, get_binary_env_var, implement_for
 
 DEVICE_TYPING = Union[torch.device, str, int]
@@ -582,6 +582,16 @@ def clear_device_(self) -> T:
         """
         return self
 
+    @abc.abstractmethod
+    def cardinality(self) -> int:
+        """The cardinality of the spec.
+
+        This refers to the number of possible outcomes in a spec. It is assumed that the cardinality of a composite
+        spec is the Cartesian product of all possible outcomes.
+
+        """
+        ...
+
     def encode(
         self,
         val: np.ndarray | torch.Tensor | TensorDictBase,
@@ -1515,6 +1525,9 @@ def __init__(
     def n(self):
         return self.space.n
 
+    def cardinality(self) -> int:
+        return self.n
+
     def update_mask(self, mask):
         """Sets a mask to prevent some of the possible outcomes when a sample is taken.
 
@@ -2107,6 +2120,9 @@ def enumerate(self) -> Any:
             f"enumerate is not implemented for spec of class {type(self).__name__}."
         )
 
+    def cardinality(self) -> int:
+        return float("inf")
+
     def __eq__(self, other):
         return (
             type(other) == type(self)
@@ -2426,8 +2442,11 @@ def __init__(
             shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs
         )
 
+    def cardinality(self) -> Any:
+        raise RuntimeError("Cannot enumerate a NonTensorSpec.")
+
     def enumerate(self) -> Any:
-        raise NotImplementedError("Cannot enumerate a NonTensorSpec.")
+        raise RuntimeError("Cannot enumerate a NonTensorSpec.")
 
     def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor:
         if isinstance(dest, torch.dtype):
@@ -2466,10 +2485,10 @@ def one(self, shape=None):
             data=None, batch_size=(*shape, *self._safe_shape), device=self.device
         )
 
-    def is_in(self, val: torch.Tensor) -> bool:
+    def is_in(self, val: Any) -> bool:
         shape = torch.broadcast_shapes(self._safe_shape, val.shape)
         return (
-            isinstance(val, NonTensorData)
+            is_non_tensor(val)
             and val.shape == shape
             # We relax constrains on device as they're hard to enforce for non-tensor
             # tensordicts and pointless
@@ -2832,6 +2851,9 @@ def __init__(
         )
         self.update_mask(mask)
 
+    def cardinality(self) -> int:
+        return torch.as_tensor(self.nvec).prod()
+
    def enumerate(self) -> torch.Tensor:
        nvec = self.nvec
        enum_disc = self.to_categorical_spec().enumerate()
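Assuming this hunk lands in the multi-one-hot spec class (it references ``self.nvec`` and ``to_categorical_spec``), a minimal usage sketch of the new method could look like the following; since ``torch.as_tensor(...).prod()`` is used, the exact return value is likely a scalar tensor rather than a plain int:

    >>> from torchrl.data import MultiOneHot
    >>> spec = MultiOneHot([3, 4])
    >>> spec.cardinality()  # Cartesian product of the sub-spaces: 3 * 4
    tensor(12)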
@@ -3220,13 +3242,20 @@ class Categorical(TensorSpec):
     The spec will have the shape defined by the ``shape`` argument: if a singleton dimension is
     desired for the training dimension, one should specify it explicitly.
 
+    Attributes:
+        n (int): The number of possible outcomes.
+        shape (torch.Size): The shape of the variable.
+        device (torch.device): The device of the tensors.
+        dtype (torch.dtype): The dtype of the tensors.
+
     Args:
-        n (int): number of possible outcomes.
+        n (int): number of possible outcomes. If set to -1, the cardinality of the categorical spec is undefined,
+            and `set_provisional_n` must be called before sampling from this spec.
         shape: (torch.Size, optional): shape of the variable, default is "torch.Size([])".
-        device (str, int or torch.device, optional): device of the tensors.
-        dtype (str or torch.dtype, optional): dtype of the tensors.
-        mask (torch.Tensor or None): mask some of the possible outcomes when a
-            sample is taken. See :meth:`~.update_mask` for more information.
+        device (str, int or torch.device, optional): the device of the tensors.
+        dtype (str or torch.dtype, optional): the dtype of the tensors.
+        mask (torch.Tensor or None): A boolean mask to prevent some of the possible outcomes when a sample is taken.
+            See :meth:`~.update_mask` for more information.
 
     Examples:
         >>> categ = Categorical(3)
@@ -3249,6 +3278,13 @@ class Categorical(TensorSpec):
             domain=discrete)
         >>> categ.rand()
         tensor([1])
+        >>> categ = Categorical(-1)
+        >>> categ.set_provisional_n(5)
+        >>> categ.rand()
+        tensor(3)
+
+    .. note:: When n is set to -1, calling `rand` without first setting a provisional n using `set_provisional_n`
+        will raise a ``RuntimeError``.
 
     """
 
@@ -3276,16 +3312,31 @@ def __init__(
             shape=shape, space=space, device=device, dtype=dtype, domain="discrete"
         )
         self.update_mask(mask)
+        self._provisional_n = None
 
     def enumerate(self) -> torch.Tensor:
-        arange = torch.arange(self.n, dtype=self.dtype, device=self.device)
+        dtype = self.dtype
+        if dtype is torch.bool:
+            dtype = torch.uint8
+        arange = torch.arange(self.n, dtype=dtype, device=self.device)
         if self.ndim:
             arange = arange.view(-1, *(1,) * self.ndim)
         return arange.expand(self.n, *self.shape)
 
     @property
     def n(self):
-        return self.space.n
+        n = self.space.n
+        if n == -1:
+            n = self._provisional_n
+            if n is None:
+                raise RuntimeError(
+                    f"Undefined cardinality for {type(self)}. Please call "
+                    f"spec.set_provisional_n(int)."
+                )
+        return n
+
+    def cardinality(self) -> int:
+        return self.n
 
     def update_mask(self, mask):
         """Sets a mask to prevent some of the possible outcomes when a sample is taken.
@@ -3316,13 +3367,33 @@ def update_mask(self, mask):
                 raise ValueError("Only boolean masks are accepted.")
         self.mask = mask
 
+    def set_provisional_n(self, n: int):
+        """Set the cardinality of the Categorical spec temporarily.
+
+        This method must be called before sampling from the spec when n is -1.
+
+        Args:
+            n (int): The cardinality of the Categorical spec.
+
+        """
+        self._provisional_n = n
+
     def rand(self, shape: torch.Size = None) -> torch.Tensor:
+        if self.space.n < 0:
+            if self._provisional_n is None:
+                raise RuntimeError(
+                    "Cannot generate random categorical samples for undefined cardinality (n=-1). "
+                    "To sample from this class, first call Categorical.set_provisional_n(n) before calling rand()."
+                )
+            n = self._provisional_n
+        else:
+            n = self.space.n
         if shape is None:
             shape = _size([])
         if self.mask is None:
             return torch.randint(
                 0,
-                self.space.n,
+                n,
                 _size([*shape, *self.shape]),
                 device=self.device,
                 dtype=self.dtype,
@@ -3334,6 +3405,12 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor:
         else:
             mask_flat = mask
         shape_out = mask.shape[:-1]
+        # Check that the mask has the right size
+        if mask_flat.shape[-1] != n:
+            raise ValueError(
+                "The last dimension of the mask must match the number of actions allowed by the "
+                f"Categorical spec. Got mask.shape={self.mask.shape} and n={n}."
+            )
         out = torch.multinomial(mask_flat.float(), 1).reshape(shape_out)
         return out
 
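For context (not part of the diff), here is a hedged sketch of the masked sampling path that the new size check guards: sampling draws from ``torch.multinomial`` over the boolean mask, and with this change a mask whose trailing dimension disagrees with the (possibly provisional) ``n`` now raises a ``ValueError`` instead of silently sampling from the wrong support. The drawn value below is random, shown only for illustration:

    >>> import torch
    >>> from torchrl.data import Categorical
    >>> spec = Categorical(3)
    >>> spec.update_mask(torch.tensor([True, False, True]))
    >>> spec.rand()  # index 1 is masked out, so only 0 or 2 can be drawn
    tensor(2)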
@@ -3360,6 +3437,8 @@ def is_in(self, val: torch.Tensor) -> bool:
             dtype_match = val.dtype == self.dtype
             if not dtype_match:
                 return False
+            if self.space.n == -1:
+                return True
             return (0 <= val).all() and (val < self.space.n).all()
         shape = self.mask.shape
         shape = _size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]])
@@ -3607,7 +3686,7 @@ def __init__(
         device: Optional[DEVICE_TYPING] = None,
         dtype: Union[str, torch.dtype] = torch.int8,
     ):
-        if n is None and not shape:
+        if n is None and shape is None:
             raise TypeError("Must provide either n or shape.")
         if n is None:
             n = shape[-1]
@@ -3813,6 +3892,9 @@ def enumerate(self) -> torch.Tensor:
         arange = arange.expand(arange.shape[0], *self.shape)
         return arange
 
+    def cardinality(self) -> int:
+        return self.nvec._base.prod()
+
     def update_mask(self, mask):
         """Sets a mask to prevent some of the possible outcomes when a sample is taken.
 
@@ -4373,7 +4455,7 @@ def set(self, name, spec):
             shape = spec.shape
             if shape[: self.ndim] != self.shape:
                 if (
-                    isinstance(spec, Composite)
+                    isinstance(spec, (Composite, NonTensor))
                     and spec.ndim < self.ndim
                     and self.shape[: spec.ndim] == spec.shape
                 ):
@@ -4382,7 +4464,7 @@ def set(self, name, spec):
                     spec.shape = self.shape
                 else:
                     raise ValueError(
-                        "The shape of the spec and the Composite mismatch: the first "
+                        f"The shape of the spec {type(spec).__name__} and the Composite {type(self).__name__} mismatch: the first "
                         f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and "
                         f"Composite.shape={self.shape}."
                     )
@@ -4798,6 +4880,18 @@ def clone(self) -> Composite:
             shape=self.shape,
         )
 
+    def cardinality(self) -> int:
+        n = None
+        for spec in self.values():
+            if spec is None:
+                continue
+            if n is None:
+                n = 1
+            n = n * spec.cardinality()
+        if n is None:
+            n = 0
+        return n
+
     def enumerate(self) -> TensorDictBase:
         # We are going to use meshgrid to create samples of all the subspecs in here
         # but first let's get rid of the batch size, we'll put it back later
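A hedged sketch of the intended ``Composite.cardinality`` semantics, following the abstract-method docstring above: the Cartesian product over the leaf specs, with 0 returned for an empty composite. The spec classes used here are the ones touched elsewhere in this diff:

    >>> from torchrl.data import Categorical, Composite, OneHot
    >>> spec = Composite(action=Categorical(3), flag=OneHot(4))
    >>> spec.cardinality()  # 3 * 4 joint outcomes
    12
    >>> Composite().cardinality()
    0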