Labels
enhancement (not as big of a feature, but technically not a bug; should be easy to fix)
high priority
module: cuda (related to torch.cuda and CUDA support in general)
module: nn (related to torch.nn)
triaged (this issue has been looked at by a team member and triaged into an appropriate module)
Description
Only a LongTensor may be used as the target parameter of nn.NLLLoss.forward. Runtime errors occur when a Char/Byte/ShortTensor is used instead, as shown by the following code snippet.
import math
import torch

# A 1x2 row of log-probabilities and a matching class-index target.
e = torch.Tensor([[math.log(3/4), math.log(1/4)]])
t = torch.LongTensor([0])
torch.nn.NLLLoss()(e, t)  # works: target is a LongTensor

t = torch.CharTensor([0])
torch.nn.NLLLoss()(e, t)  # RuntimeError: expected scalar type Long but found Char

t = torch.ByteTensor([0])
torch.nn.NLLLoss()(e, t)  # RuntimeError: expected scalar type Long but found Byte
I really want to store pre-computed targets as byte tensors to save memory. Is it possible to have NLLLoss accept additional types of targets?
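For reference, the workaround I am using today (a minimal sketch, not part of the request) is to keep the targets stored as bytes and cast them to long at loss time, which saves memory at rest but allocates a temporary LongTensor on every call:

import math
import torch

e = torch.Tensor([[math.log(3/4), math.log(1/4)]])
stored = torch.ByteTensor([0])                # compact at-rest storage
loss = torch.nn.NLLLoss()(e, stored.long())  # temporary LongTensor per call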
I am requesting only a byte-typed target because each additional supported type would enlarge the binary by one more C++ template instantiation, and a byte range suffices: a wide softmax can be divided into a hierarchy of much shorter (say, 256-way) softmaxes.
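To illustrate that last point, here is a hypothetical sketch (the two-level factorization and all names are my own, not an existing API) of how a 65536-way softmax could be decomposed into two 256-way ones, so that byte-sized targets would suffice at every level:

import torch
import torch.nn as nn

# Hypothetical two-level hierarchical softmax: a class id in [0, 65536)
# is split into a high byte and a low byte, each predicted by its own
# 256-way log-softmax; total log-prob = log p(hi) + log p(lo | hi).
log_probs_hi = nn.LogSoftmax(dim=1)(torch.randn(1, 256))
log_probs_lo = nn.LogSoftmax(dim=1)(torch.randn(1, 256))

target = 777                         # full class id
hi, lo = target >> 8, target & 0xFF  # byte-sized sub-targets

nll = nn.NLLLoss()
# Today both sub-targets must still be passed as LongTensors:
loss = nll(log_probs_hi, torch.LongTensor([hi])) + \
       nll(log_probs_lo, torch.LongTensor([lo]))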
cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @anjali411 @Varal7 @ngimel @albanD @mruberry