"multinomial_kernel_cuda" not implemented for 'Half' #29211

@nottombrown

Description

It appears that, even after the 1.3.1 fixes, torch.distributions.Categorical no longer works with half (fp16) dtypes on CUDA.

To Reproduce

pip install -U --pre torch==1.4.0.dev20191104 torchvision -f https://download.pytorch.org/whl/nightly/cu101/torch_nightly.html

import torch


def test_fp16_categorical():
    logits_fp16 = torch.randn(20).cuda().half()

    # These are fine
    torch.argmax(logits_fp16)
    torch.max(logits_fp16)

    # This is also fine
    logits_fp32 = logits_fp16.float()
    sample = torch.distributions.Categorical(logits=logits_fp32).sample()
    print(sample)

    # This fails
    sample = torch.distributions.Categorical(logits=logits_fp16).sample()
    print(sample)


if __name__ == "__main__":
    test_fp16_categorical()

File "/opt/conda/lib/python3.7/site-packages/torch/distributions/categorical.py", line 107, in sample
    sample_2d = torch.multinomial(probs_2d, 1, True)
RuntimeError: "multinomial_kernel_cuda" not implemented for 'Half'

Environment

Collecting environment information...
PyTorch version: 1.4.0.dev20191104
Is debug build: No
CUDA used to build PyTorch: 10.1

OS: Ubuntu 16.04.5 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.11) 5.4.0 20160609
CMake version: version 3.11.1

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration:
GPU 0: Tesla V100-SXM2-16GB
GPU 1: Tesla V100-SXM2-16GB
GPU 2: Tesla V100-SXM2-16GB
GPU 3: Tesla V100-SXM2-16GB
GPU 4: Tesla V100-SXM2-16GB
GPU 5: Tesla V100-SXM2-16GB
GPU 6: Tesla V100-SXM2-16GB
GPU 7: Tesla V100-SXM2-16GB

Nvidia driver version: 418.87.01
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.4.2

Additional context

This is not blocking us, because we can convert to fp32 before sampling.
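
For anyone else hitting this, here is a minimal sketch of the fp32 workaround. The helper name `sample_categorical` is hypothetical (it is not part of PyTorch); it only wraps the same `.float()` conversion used in the repro above, and it falls back to CPU when no GPU is available so it can be run anywhere.

```python
import torch


def sample_categorical(logits: torch.Tensor) -> torch.Tensor:
    """Sample class indices from logits, upcasting fp16 input to fp32 first.

    Hypothetical helper for illustration only, not a PyTorch API.
    """
    if logits.dtype == torch.float16:
        # Upcast so that torch.multinomial (called inside Categorical.sample)
        # receives fp32 probabilities instead of Half.
        logits = logits.float()
    return torch.distributions.Categorical(logits=logits).sample()


if __name__ == "__main__":
    # Use the GPU when present; otherwise run the same code path on CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logits_fp16 = torch.randn(20).to(device).half()
    print(sample_categorical(logits_fp16))
```

The returned sample is an integer index tensor, so the upcast does not change the dtype of anything downstream. The cost is one temporary fp32 copy of the logits per call.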
