From 68971a8ffb4e1a1d8a7f69fd789f8a60f5f0d357 Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Thu, 18 Jul 2019 20:07:56 +0200
Subject: [PATCH] Fix get_all_math_dtypes for device='cuda' returning None

list.append mutates the list in place and returns None, so the non-cpu
branch always returned None instead of the dtype list. Append float16
first and then return the list, and match 'cuda:N' device strings as
well via startswith('cuda').
---
 torch/testing/__init__.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/torch/testing/__init__.py b/torch/testing/__init__.py
index 9fe6e8377ec7..fbf7178e9e08 100644
--- a/torch/testing/__init__.py
+++ b/torch/testing/__init__.py
@@ -96,10 +96,10 @@ def get_all_math_dtypes(device):
               torch.float32, torch.float64]
 
     # torch.float16 is a math dtype on cuda but not cpu.
-    if device == 'cpu':
-        return dtypes
-    else:
-        return dtypes.append(torch.float16)
+    if device.startswith('cuda'):
+        dtypes.append(torch.float16)
+
+    return dtypes
 
 
 def get_all_device_types():
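
A minimal sketch of the pitfall this patch fixes, runnable without torch
(the dtype strings below are stand-ins for the real torch dtypes, not the
actual values used in torch/testing/__init__.py):

    def get_all_math_dtypes_old(device):
        dtypes = ['int64', 'float32', 'float64']  # stand-ins for torch dtypes
        if device == 'cpu':
            return dtypes
        else:
            # list.append mutates in place and returns None,
            # so this branch always returned None.
            return dtypes.append('float16')

    def get_all_math_dtypes_new(device):
        dtypes = ['int64', 'float32', 'float64']
        # 'cuda', 'cuda:0', 'cuda:1', ... all start with 'cuda'.
        if device.startswith('cuda'):
            dtypes.append('float16')
        return dtypes

    print(get_all_math_dtypes_old('cuda'))    # None  (the bug)
    print(get_all_math_dtypes_new('cuda'))    # [..., 'float16']
    print(get_all_math_dtypes_new('cuda:0'))  # [..., 'float16']
    print(get_all_math_dtypes_new('cpu'))     # no 'float16' on cpu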