Added type promotion logic for complex numbers (#34093)
Summary:
Issue: #33780
After this PR:
1. dtype promotion logic now works correctly for ops involving complex scalars
2. added aliases for complex64 (cfloat) and complex128 (cdouble)
3. added an internal function get_complex_default_dtype (intentionally not exposed in the public API)
   - sets the default complex dtype to double when the default dtype is double, and to float otherwise (#34093 (comment))
>>> 1j*torch.ones(2)
tensor([(0.0000 + 1.0000j), (0.0000 + 1.0000j)], dtype=torch.complex64)

>>> torch.set_default_dtype(torch.float64)
>>> 1j*torch.ones(2)
tensor([(0.0000 + 1.0000j), (0.0000 + 1.0000j)], dtype=torch.complex128)

>>> 1j + torch.ones(2)
tensor([(1.0000 + 1.0000j), (1.0000 + 1.0000j)], dtype=torch.complex128)

>>> torch.tensor(1j) + torch.ones(2,2)
tensor([[(1.0000 + 1.0000j), (1.0000 + 1.0000j)],
        [(1.0000 + 1.0000j), (1.0000 + 1.0000j)]], dtype=torch.complex128)
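
A quick check of the new dtype aliases (a sketch of an interactive session; assumes a build that includes this change):

>>> torch.cfloat is torch.complex64
True
>>> torch.cdouble is torch.complex128
True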
Pull Request resolved: #34093

Differential Revision: D20537125

Pulled By: anjali411

fbshipit-source-id: 05fb1f81b8ba039d0b698cdd2c0bbf8b0ce0b767
anjali411 authored and facebook-github-bot committed Mar 25, 2020
1 parent 361eed6 commit c73e970
Showing 9 changed files with 104 additions and 27 deletions.
19 changes: 13 additions & 6 deletions aten/src/ATen/native/TypeProperties.cpp
@@ -57,10 +57,13 @@ static inline ScalarType promote_skip_undefined(ScalarType a, ScalarType b) {


static inline ScalarType combine_categories(ScalarType higher, ScalarType lower) {
if (isFloatingType(higher)) {
if(isComplexType(higher)) {
return higher;
}
if (higher == ScalarType::Bool || isFloatingType(lower)) {
else if(!isComplexType(lower) && isFloatingType(higher)) {
return higher;
}
if (higher == ScalarType::Bool || isFloatingType(lower) || isComplexType(lower)) {
return promote_skip_undefined(higher, lower);
}
if (higher != ScalarType::Undefined) {
@@ -75,8 +78,14 @@ ResultTypeState update_result_type_state(const Tensor& tensor, const ResultTypeS
}
ResultTypeState new_state = in_state;
ScalarType current = tensor.scalar_type();
if (tensor.unsafeGetTensorImpl()->is_wrapped_number() && isFloatingType(current)) {
current = typeMetaToScalarType(at::get_default_dtype());
if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
auto current_default = typeMetaToScalarType(at::get_default_dtype());
if(isComplexType(current)) {
current = typeMetaToScalarType(at::get_default_complex_dtype());
}
else if(isFloatingType(current)) {
current = current_default;
}
}
if ( tensor.dim() > 0 ) {
new_state.dimResult = promote_skip_undefined(in_state.dimResult, current);
Expand All @@ -85,7 +94,6 @@ ResultTypeState update_result_type_state(const Tensor& tensor, const ResultTypeS
} else {
new_state.zeroResult = promote_skip_undefined(in_state.zeroResult, current);
}

return new_state;
}

@@ -98,7 +106,6 @@ ScalarType result_type(TensorList tensors) {
for (const Tensor& tensor : tensors) {
state = update_result_type_state(tensor, state);
}

return result_type(state);
}

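For intuition, the category ordering that combine_categories enforces after this change is complex > floating > integral > boolean. A small self-contained Python model of just that ordering (illustrative only, not the ATen implementation; the names below are made up):

# Toy model of the promotion categories only (not dtype sizes); purely illustrative.
_RANK = {"bool": 0, "int": 1, "float": 2, "complex": 3}

def winning_category(a, b):
    # The result's category is the higher of the two operands' categories.
    return a if _RANK[a] >= _RANK[b] else b

assert winning_category("complex", "float") == "complex"
assert winning_category("float", "int") == "float"
assert winning_category("complex", "bool") == "complex"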
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/Copy.cu
@@ -61,7 +61,7 @@ void copy_device_to_device(TensorIterator& iter, bool non_blocking) {
copy_stream));
}
} else {
AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, iter.dtype(0), "copy_", [&] {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBool, kBFloat16, iter.dtype(0), "copy_", [&] {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t x) { return x; });
});
}
9 changes: 9 additions & 0 deletions c10/core/DefaultDtype.cpp
@@ -3,12 +3,21 @@

namespace c10 {
static auto default_dtype = caffe2::TypeMeta::Make<float>();
static auto default_complex_dtype = caffe2::TypeMeta::Make<std::complex<float>>();

void set_default_dtype(caffe2::TypeMeta dtype) {
default_dtype = std::move(dtype);
if(dtype == caffe2::TypeMeta::Make<double>()) {
default_complex_dtype = std::move(caffe2::TypeMeta::Make<std::complex<double>>());
} else {
default_complex_dtype = std::move(caffe2::TypeMeta::Make<std::complex<float>>());
}
}

const caffe2::TypeMeta& get_default_dtype() {
return default_dtype;
}
const caffe2::TypeMeta& get_default_complex_dtype() {
return default_complex_dtype;
}
} // namespace c10
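The user-visible effect of default_complex_dtype tracking the real default dtype (a sketch of an interactive session; assumes this commit):

>>> torch.set_default_dtype(torch.float64)
>>> (1j * torch.ones(2)).dtype        # complex default follows the real default
torch.complex128
>>> torch.ones(2).dtype               # real default behaves as before
torch.float64
>>> torch.set_default_dtype(torch.float32)   # restore the default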
1 change: 1 addition & 0 deletions c10/core/DefaultDtype.h
@@ -9,4 +9,5 @@ class TypeMeta;
namespace c10 {
C10_API void set_default_dtype(caffe2::TypeMeta dtype);
C10_API const caffe2::TypeMeta& get_default_dtype();
C10_API const caffe2::TypeMeta& get_default_complex_dtype();
} // namespace c10
15 changes: 12 additions & 3 deletions docs/source/tensor_attributes.rst
@@ -15,13 +15,15 @@ torch.dtype
.. class:: torch.dtype

A :class:`torch.dtype` is an object that represents the data type of a
:class:`torch.Tensor`. PyTorch has nine different data types:
:class:`torch.Tensor`. PyTorch has eleven different data types:

======================== =========================================== ===========================
Data type dtype Tensor types
Data type dtype Legacy Constructors
======================== =========================================== ===========================
32-bit floating point ``torch.float32`` or ``torch.float`` ``torch.*.FloatTensor``
64-bit floating point ``torch.float64`` or ``torch.double`` ``torch.*.DoubleTensor``
64-bit complex ``torch.complex64`` or ``torch.cfloat``
128-bit complex ``torch.complex128`` or ``torch.cdouble``
16-bit floating point ``torch.float16`` or ``torch.half`` ``torch.*.HalfTensor``
8-bit integer (unsigned) ``torch.uint8`` ``torch.*.ByteTensor``
8-bit integer (signed) ``torch.int8`` ``torch.*.CharTensor``
@@ -34,13 +36,16 @@ Boolean ``torch.bool`` ``torch
To find out if a :class:`torch.dtype` is a floating point data type, the property :attr:`is_floating_point`
can be used, which returns ``True`` if the data type is a floating point data type.

To find out if a :class:`torch.dtype` is a complex data type, the property :attr:`is_complex`
can be used, which returns ``True`` if the data type is a complex data type.

.. _type-promotion-doc:

When the dtypes of inputs to an arithmetic operation (`add`, `sub`, `div`, `mul`) differ, we promote
by finding the minimum dtype that satisfies the following rules:

* If the type of a scalar operand is of a higher category than tensor operands
(where floating > integral > boolean), we promote to a type with sufficient size to hold
(where complex > floating > integral > boolean), we promote to a type with sufficient size to hold
all scalar operands of that category.
* If a zero-dimension tensor operand has a higher category than dimensioned operands,
we promote to a type with sufficient size and category to hold all zero-dim tensor operands of
@@ -57,6 +62,8 @@ Promotion Examples::

>>> float_tensor = torch.ones(1, dtype=torch.float)
>>> double_tensor = torch.ones(1, dtype=torch.double)
>>> complex_float_tensor = torch.ones(1, dtype=torch.complex64)
>>> complex_double_tensor = torch.ones(1, dtype=torch.complex128)
>>> int_tensor = torch.ones(1, dtype=torch.int)
>>> long_tensor = torch.ones(1, dtype=torch.long)
>>> uint_tensor = torch.ones(1, dtype=torch.uint8)
@@ -81,6 +88,8 @@ Promotion Examples::
torch.uint8
>>> (float_tensor + double_tensor).dtype
torch.float64
>>> (complex_float_tensor + complex_double_tensor).dtype
torch.complex128
>>> (bool_tensor + int_tensor).dtype
torch.int32
# Since long is a different kind than float, result dtype only needs to be large enough
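One more promotion example in the same spirit, using the tensors defined above (a sketch; the shown output assumes the default dtype is torch.float32):

>>> # complex is a higher category than floating, so a complex Python scalar
>>> # promotes a float tensor to the default complex dtype
>>> (1j * float_tensor).dtype
torch.complex64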
74 changes: 61 additions & 13 deletions test/test_type_promotion.py
@@ -28,7 +28,6 @@ def wrapped_fn(*args, **kwargs):

return wrapped_fn


class TestTypePromotion(TestCase):

# In-place operations don't promote.
@@ -81,14 +80,44 @@ def test_int_promotion(self, device):

@float_double_default_dtype
def test_float_promotion(self, device):
a = torch.ones([4, 4, 4], dtype=torch.float, device=device)
b = torch.ones([4, 4, 4], dtype=torch.double, device=device)
c = a + b
self.assertEqual(c, b + b)
self.assertEqual(c.dtype, torch.double)
c = b + a
self.assertEqual(c, b + b)
self.assertEqual(c.dtype, torch.double)
def test_promotion(dtype_float, dtype_double):
a = torch.ones([4, 4, 4], dtype=dtype_float, device=device)
b = torch.ones([4, 4, 4], dtype=dtype_double, device=device)
c = a + b
self.assertEqual(c, b + b)
self.assertEqual(c.dtype, dtype_double)
c = b + a
self.assertEqual(c, b + b)
self.assertEqual(c.dtype, dtype_double)
test_promotion(torch.float, torch.double)

@float_double_default_dtype
def test_complex_promotion(self, device):
def test_promotion(dtype_float, dtype_double):
a = torch.ones([4, 4, 4], dtype=dtype_float, device=device)
b = torch.ones([4, 4, 4], dtype=dtype_double, device=device)
c = a + b
self.assertEqual(c, b + b)
self.assertEqual(c.dtype, dtype_double)
c = b + a
self.assertEqual(c, b + b)
self.assertEqual(c.dtype, dtype_double)

test_promotion(torch.complex64, torch.complex128)

a = torch.randn(3, dtype=torch.complex64, device=device)
self.assertEqual((a * 5).dtype, torch.complex64)
# not a "wrapped number"
other = torch.tensor(5.5, dtype=torch.double, device=device)
self.assertEqual((a + other).dtype, torch.complex64)

@float_double_default_dtype
def test_complex_scalar_mult_tensor_promotion(self, device):
a = 1j * torch.ones(2, device=device)
a = a + 1j
b = torch.tensor([2j, 2j], device=device)
self.assertEqual(a, b)
self.assertEqual(a.dtype, b.dtype)

@float_double_default_dtype
def test_add_wrapped(self, device):
@@ -176,7 +205,17 @@ def _get_test_tensor(self, device, dtype, remove_zeros=False):
shape = [5, 5, 5]
if dtype == torch.bool:
tensor = torch.randint(int(remove_zeros), 2, shape, device=device, dtype=dtype)
elif dtype.is_floating_point:
elif dtype.is_complex:
# "_th_normal_ not supported on CPUType for Half" so simpler create and convert
tensor = torch.randn(shape, dtype=dtype, device=device)
if remove_zeros:
tensor_abs = torch.abs(tensor)
if dtype == torch.complex64:
tensor_abs = tensor_abs.to(torch.float)
elif dtype == torch.complex128:
tensor_abs = tensor_abs.to(torch.double)
tensor[tensor_abs < 0.05] = 5
elif dtype.is_floating_point or dtype.is_complex:
# "_th_normal_ not supported on CPUType for Half" so simpler create and convert
tensor = torch.randn(shape, device=device)
tensor = tensor.to(dtype)
@@ -194,8 +233,9 @@ def _get_test_tensor(self, device, dtype, remove_zeros=False):
def test_many_promotions(self, device):
# Can also include half on CPU in cases where it will be promoted to a
# supported dtype
dtypes1 = torch.testing.get_all_math_dtypes('cuda')
dtypes2 = torch.testing.get_all_math_dtypes(device)
complex_dtypes = torch.testing.get_all_complex_dtypes()
dtypes1 = torch.testing.get_all_math_dtypes('cuda') + complex_dtypes
dtypes2 = torch.testing.get_all_math_dtypes(device) + complex_dtypes
ops = [torch.add, torch.sub, torch.mul, torch.div, torch.rsub]
for dt1, dt2 in itertools.product(dtypes1, dtypes2):
for op, non_contiguous in itertools.product(ops, [True, False]):
@@ -217,7 +257,15 @@ def test_many_promotions(self, device):
self.assertEqual(not first.is_contiguous(), non_contiguous)
self.assertEqual(not second.is_contiguous(), non_contiguous)
result = op(first, second)
expected = op(first.to(common_dtype), second.to(common_dtype))
# TODO: copy_() for complex on cuda issues on github: #33567 #35284
if common_dtype.is_complex and first.is_cuda:
first_ = torch.zeros(first.size(), dtype=common_dtype, device=device)
second_ = torch.zeros(second.size(), dtype=common_dtype, device=device)
first_.add_(first)
second_.add_(second)
expected = op(first_, second_)
else:
expected = op(first.to(common_dtype), second.to(common_dtype))
self.assertEqual(result.dtype, expected.dtype, message='{} with {}, {}'.format(op.__name__, dt1, dt2))
self.assertEqual(result, expected, message='{} with {}, {}'.format(op.__name__, dt1, dt2))

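The cast in test_many_promotions avoids .to(common_dtype) for complex CUDA tensors because of the copy_() issues referenced in the TODO (#33567, #35284). A minimal sketch of that workaround pattern (the helper name here is made up):

import torch

def cast_via_add(t, dtype):
    # Hypothetical helper mirroring the test's workaround: materialize `t` in
    # `dtype` by adding it into a zeros tensor instead of calling t.to(dtype),
    # which hit complex copy_() issues on CUDA (#33567, #35284).
    out = torch.zeros(t.size(), dtype=dtype, device=t.device)
    out.add_(t)
    return out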
4 changes: 2 additions & 2 deletions torch/csrc/utils/tensor_dtypes.cpp
@@ -36,9 +36,9 @@ static std::pair<std::string, std::string> getDtypeNames(
case at::ScalarType::ComplexHalf:
return std::make_pair("complex32", "");
case at::ScalarType::ComplexFloat:
return std::make_pair("complex64", "");
return std::make_pair("complex64", "cfloat");
case at::ScalarType::ComplexDouble:
return std::make_pair("complex128", "");
return std::make_pair("complex128", "cdouble");
case at::ScalarType::Bool:
return std::make_pair("bool", "");
case at::ScalarType::QInt8:
4 changes: 2 additions & 2 deletions torch/onnx/symbolic_helper.py
@@ -492,8 +492,8 @@ def _set_operator_export_type(operator_export_type):
'int64_t': 'Long',
'int16_t': 'Short',
'bool': 'Bool',
'complex64': '',
'complex128': ''
'complex64': 'ComplexFloat',
'complex128': 'ComplexDouble'
}


3 changes: 3 additions & 0 deletions torch/testing/__init__.py
@@ -102,6 +102,9 @@ def get_all_math_dtypes(device):

return dtypes

def get_all_complex_dtypes():
dtypes = [torch.complex64, torch.complex128]
return dtypes

def get_all_device_types():
return ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
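Typical use of the new helper when parametrizing a test over dtypes (a sketch mirroring how test_many_promotions builds its dtype lists):

import itertools
from torch.testing import get_all_math_dtypes, get_all_complex_dtypes

dtypes = get_all_math_dtypes('cpu') + get_all_complex_dtypes()
for dt1, dt2 in itertools.product(dtypes, dtypes):
    # e.g. build operands of dt1 and dt2 here and check the promoted result dtype
    pass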
