
Improve one_hot #15457

Open · zasdfgbnm opened this issue Dec 21, 2018 · 9 comments
Labels
function request (A request for a new function or the addition of new arguments/modes to an existing function.) · module: nn (Related to torch.nn) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

zasdfgbnm (Collaborator) commented Dec 21, 2018

As discussed in #15208, the following improvements can be made to torch.nn.functional.one_hot (a sketch of the proposed API follows the list):

  • Make the num_classes parameter int64_t? instead of int64_t num_classes = -1
  • Add optional dtype argument
  • Add optional dim argument
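
A rough sketch of what the Python-level call could look like after these changes, plus how each piece can be emulated today. The dtype and dim parameters are hypothetical; only num_classes exists in the current signature:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([0, 2, 1])

# Hypothetical future call (dtype and dim do not exist in the current API):
#   F.one_hot(x, num_classes=3, dtype=torch.float32, dim=0)

# Emulating the same thing with today's API:
out = F.one_hot(x, num_classes=3)   # shape (3, 3), always int64
out = out.to(torch.float32)         # emulates dtype=torch.float32
out = out.transpose(0, 1)           # emulates dim=0 (class dimension first)
```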

cc: @zou3519 @vadimkantorov

cc @albanD @mruberry

zou3519 added the todo label ("Not as important as medium or high priority tasks, but we will work on these.") on Dec 27, 2018
deepaks4077 commented

@zou3519: Mind if I take a look at this?

zou3519 (Contributor) commented Jan 7, 2019

I think the conclusion from @zasdfgbnm on this was that it's very easy to make this API change on the PyTorch side, but making the JIT support it as well was a little more involved.

zasdfgbnm (Collaborator, Author) commented

@zou3519 Yes and no. The optional dtype relies on #15154, and the optional int relies on #15234; those were not merged when I wrote #15208. They are now available, so that part should not be a problem.

ssnl (Collaborator) commented Jul 16, 2019

An optional dtype is pretty important IMO. Returning a long tensor is useless in many cases.
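
An illustration of the current behavior (not a proposed API): the int64 result has to be cast before it can be used in most float math:

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([1, 0, 2])
onehot = F.one_hot(labels, num_classes=3)  # always int64 today
floats = onehot.to(torch.float32)          # extra cast (and copy) before any float math
```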

fmassa added the enhancement, high priority, module: operators, and triaged labels and removed the todo label on Jul 17, 2019
fmassa (Member) commented Jul 17, 2019

Constructor functions should natively expose dtype and device. This is currently inferred from the tensor that is passed, but for consistency we should add it as well.

I'm not sure this deserves to be high-pri though, as one_hot is conceptually not very different from diag, which also exposes neither dtype nor device.

ssnl (Collaborator) commented Jul 17, 2019

@fmassa I would say that this is not a ctor function. However, it should have a dtype arg (similar to how F.softmax has one) because of common precision issues and conversion needs.
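
For reference, F.softmax already exposes this pattern; its dtype argument lets the computation run in the requested precision:

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 10, dtype=torch.float16)
# softmax exposes dtype; the op is computed in the requested precision:
probs = F.softmax(x, dim=-1, dtype=torch.float32)
```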

vadimkantorov (Contributor) commented Feb 6, 2020

Also, in #33046 I propose changing the default dtype from long to bool/byte.
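
A small illustration of the motivation (not the #33046 proposal itself): today the compact layout is only reachable by converting after the fact, at 8x the memory of the final result:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([0, 1, 3])
dense = F.one_hot(x, num_classes=4)  # int64: 8 bytes per element
mask = dense.bool()                  # bool: 1 byte per element, directly usable as a mask
```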

vadimkantorov (Contributor) commented
It would also be good if one_hot supported receiving the class index as a Python scalar; currently only torch.tensor(class_idx) works.
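
An illustration of the current behavior:

```python
import torch
import torch.nn.functional as F

print(F.one_hot(torch.tensor(3), num_classes=5))  # works: tensor([0, 0, 0, 1, 0])
# F.one_hot(3, num_classes=5)                     # TypeError: input must be a Tensor, not int
```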

mruberry added the function request and module: nn labels and removed the enhancement and module: operators (deprecated) labels on Oct 10, 2020
myron commented Oct 10, 2022

It would also be good to support 2D and 3D inputs, to create a multi-channel 2D or 3D array (without the need for view/reshape).
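
Note that one_hot already accepts multi-dimensional index tensors, but it always appends the class dimension last, so channel-first layouts still need a permute; a dim argument (as proposed above) would cover this:

```python
import torch
import torch.nn.functional as F

labels = torch.randint(0, 4, (2, 8, 8))           # (N, H, W) segmentation labels
onehot = F.one_hot(labels, num_classes=4)         # (N, H, W, C): class dim appended last
onehot = onehot.permute(0, 3, 1, 2).contiguous()  # (N, C, H, W): the layout conv layers expect
```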
