torch.nn.functional.one_hot should gracefully skip negative and out-of-range indices #45352

@msbaines

🚀 Feature

Some useful algorithms would benefit from one_hot skipping negative and out-of-range indices. TensorFlow does not crash in these scenarios and instead emits an all-zero vector for each such index. PyTorch currently raises a RuntimeError:

>>> import torch
>>> import torch.nn.functional as F
>>> x = torch.tensor([1, 2, 8])
>>> F.one_hot(x, num_classes=4)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Class values must be smaller than num_classes.
>>> x = torch.tensor([1, 2, -1])
>>> F.one_hot(x, num_classes=4)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Class values must be non-negative.

Motivation

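Some algorithms can be expressed more simply if one_hot maps negative and out-of-range indices to an all-zero vector instead of raising an error.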

Pitch

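Have one_hot emit an all-zero vector for negative and out-of-range indices rather than raising a RuntimeError. This would also make it easy to port code and algorithms between TensorFlow and PyTorch.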
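
To make the requested semantics concrete, here is a minimal sketch that compares each index against the range of class ids. The function name `one_hot_by_comparison` is hypothetical and not part of torch; this only illustrates the desired output, not how F.one_hot would be changed internally.

```python
import torch


def one_hot_by_comparison(indices: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: one-hot encoding where invalid indices give all-zero rows."""
    # Class ids 0 .. num_classes-1 on the same device as the input.
    classes = torch.arange(num_classes, device=indices.device)
    # Broadcast (..., 1) against (num_classes,): an index that is negative or
    # >= num_classes matches no class id, so its row stays all zeros.
    return (indices.unsqueeze(-1) == classes).to(torch.int64)


x = torch.tensor([1, 2, 8, -1])
print(one_hot_by_comparison(x, num_classes=4))
```

With this sketch, the rows for 8 and -1 come out as all zeros, while 1 and 2 are encoded as usual.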

Alternatives

Explicitly replace negative and out-of-range values with a valid index before calling one_hot, then clean up the resulting one_hot vectors afterwards. This is less efficient and less elegant.
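
The workaround described above might look roughly like the following sketch. The helper name `one_hot_skip_invalid` is hypothetical, and mapping invalid indices to class 0 before encoding is just one possible choice of placeholder.

```python
import torch
import torch.nn.functional as F


def one_hot_skip_invalid(indices: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: call F.one_hot safely, then zero the invalid rows."""
    # True where the index is a valid class id in [0, num_classes).
    valid = (indices >= 0) & (indices < num_classes)
    # Step 1: replace invalid indices with 0 so that F.one_hot does not raise.
    safe = torch.where(valid, indices, torch.zeros_like(indices))
    encoded = F.one_hot(safe, num_classes=num_classes)
    # Step 2: clean up by zeroing the rows that came from invalid indices.
    return encoded * valid.unsqueeze(-1).to(encoded.dtype)


x = torch.tensor([1, 2, 8, -1])
print(one_hot_skip_invalid(x, num_classes=4))
```

The extra mask, the substitution, and the final multiply are the overhead this issue would like to avoid.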

Additional context

cc @albanD @mruberry

Labels

enhancement (Not as big of a feature, but technically not a bug. Should be easy to fix)
module: nn (Related to torch.nn)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
