Skip to content

Commit

Permalink
Add complex32, complex64 and complex128 dtypes (#11173)
Browse files Browse the repository at this point in the history
Summary:
We don't generate a corresponding Type implementations for them,
so this doesn't do anything at the moment.

We don't plan on supporting complex32 in the near future, but
it is added to reserve the name and number in case we do at
some point in the future.

Pull Request resolved: #11173

Reviewed By: SsnL

Differential Revision: D9627477

Pulled By: ezyang

fbshipit-source-id: f49a44ab1c92d8a33130c249ac7b234f210a65e6
  • Loading branch information
ezyang authored and facebook-github-bot committed Sep 4, 2018
1 parent c5b021c commit cd4c326
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 39 deletions.
6 changes: 6 additions & 0 deletions aten/src/ATen/DLConvertor.cpp
Expand Up @@ -37,6 +37,12 @@ static DLDataType getDLDataType(const Type& type) {
case ScalarType::Half:
dtype.code = DLDataTypeCode::kDLFloat;
break;
case ScalarType::ComplexHalf:
throw std::logic_error("ComplexHalf is not supported by dlpack");
case ScalarType::ComplexFloat:
throw std::logic_error("ComplexFloat is not supported by dlpack");
case ScalarType::ComplexDouble:
throw std::logic_error("ComplexDouble is not supported by dlpack");
case ScalarType::Undefined:
throw std::logic_error("Undefined is not a valid ScalarType");
case ScalarType::NumOptions:
Expand Down
8 changes: 8 additions & 0 deletions aten/src/ATen/core/Half.h
Expand Up @@ -68,6 +68,14 @@ struct alignas(2) Half {
#endif
};

// This is just a placeholder for whatever complex representation we
// end up deciding to use for half-precision complex numbers.
struct alignas(4) ComplexHalf {
Half real_;
Half imag_;
ComplexHalf() = default;
};

template <typename To, typename From>
To convert(From f) {
return static_cast<To>(f);
Expand Down
63 changes: 44 additions & 19 deletions aten/src/ATen/core/ScalarType.h
Expand Up @@ -6,20 +6,34 @@

#include <cstdint>
#include <iostream>
#include <complex>

namespace at {

// NB: Order matters for this macro; it is relied upon in
// _promoteTypesLookup and the serialization format.
#define AT_FORALL_SCALAR_TYPES(_) \
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \
_(uint8_t,Byte,i) /* 0 */ \
_(int8_t,Char,i) /* 1 */ \
_(int16_t,Short,i) /* 2 */ \
_(int,Int,i) /* 3 */ \
_(int64_t,Long,i) /* 4 */ \
_(at::Half,Half,d) /* 5 */ \
_(float,Float,d) /* 6 */ \
_(double,Double,d) /* 7 */
_(double,Double,d) /* 7 */ \
_(at::ComplexHalf,ComplexHalf,z) /* 8 */ \
_(std::complex<float>,ComplexFloat,z) /* 9 */ \
_(std::complex<double>,ComplexDouble,z) /* 10 */

#define AT_FORALL_SCALAR_TYPES(_) \
_(uint8_t,Byte,i) \
_(int8_t,Char,i) \
_(int16_t,Short,i) \
_(int,Int,i) \
_(int64_t,Long,i) \
_(at::Half,Half,d) \
_(float,Float,d) \
_(double,Double,d)

#define AT_FORALL_SCALAR_TYPES_EXCEPT_HALF(_) \
_(uint8_t,Byte,i) \
Expand All @@ -33,9 +47,9 @@ _(double,Double,d)
enum class ScalarType {
#define DEFINE_ENUM(_1,n,_2) \
n,
AT_FORALL_SCALAR_TYPES(DEFINE_ENUM)
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ENUM)
#undef DEFINE_ENUM
Undefined, // 8
Undefined,
NumOptions
};

Expand All @@ -44,7 +58,7 @@ static inline DataType scalarTypeToDataType(ScalarType scalar_type) {
case ScalarType:: name : return caffe2::TypeMeta::Id<ctype>();

switch(scalar_type) {
AT_FORALL_SCALAR_TYPES(DEFINE_CASE)
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
case ScalarType::Undefined: return DataType::uninitialized();
default: AT_ERROR("Unrecognized Scalartype ", scalar_type, " (please report this error)");
}
Expand All @@ -56,7 +70,7 @@ static inline ScalarType dataTypeToScalarType(DataType dtype) {
if (dtype == caffe2::TypeMeta::Id<ctype>()) { \
return ScalarType:: name; \
}
AT_FORALL_SCALAR_TYPES(DEFINE_IF)
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_IF)
#undef DEFINE_IF
if (dtype == at::DataType::uninitialized()) {
return ScalarType::Undefined;
Expand All @@ -67,15 +81,15 @@ static inline ScalarType dataTypeToScalarType(DataType dtype) {
#define DEFINE_CONSTANT(_,name,_2) \
constexpr ScalarType k##name = ScalarType::name;

AT_FORALL_SCALAR_TYPES(DEFINE_CONSTANT)
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CONSTANT)
#undef DEFINE_CONSTANT

static inline const char * toString(ScalarType t) {
#define DEFINE_CASE(_,name,_2) \
case ScalarType:: name : return #name;

switch(t) {
AT_FORALL_SCALAR_TYPES(DEFINE_CASE)
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
default:
return "UNKNOWN_SCALAR";
}
Expand All @@ -87,7 +101,7 @@ static inline size_t elementSize(ScalarType t) {
case ScalarType:: name : return sizeof(ctype);

switch(t) {
AT_FORALL_SCALAR_TYPES(CASE_ELEMENTSIZE_CASE)
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(CASE_ELEMENTSIZE_CASE)
default:
AT_ERROR("Unknown ScalarType");
}
Expand All @@ -108,6 +122,12 @@ static inline bool isFloatingType(ScalarType t) {
t == ScalarType::Half);
}

static inline bool isComplexType(ScalarType t) {
return (t == ScalarType::ComplexHalf ||
t == ScalarType::ComplexFloat ||
t == ScalarType::ComplexDouble);
}

static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
// This is generated according to NumPy's promote_types
constexpr auto u1 = ScalarType::Byte;
Expand All @@ -119,19 +139,24 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
constexpr auto f4 = ScalarType::Float;
constexpr auto f8 = ScalarType::Double;
constexpr auto ud = ScalarType::Undefined;
if (a == ud || b == ud) {
return ScalarType::Undefined;
}
if (isComplexType(a) || isComplexType(b)) {
AT_ERROR("promoteTypes with complex numbers is not handled yet; figure out what the correct rules should be");
}
static constexpr ScalarType _promoteTypesLookup
[static_cast<int>(ScalarType::NumOptions)]
[static_cast<int>(ScalarType::NumOptions)] = {
/* u1 i1 i2 i4 i8 f2 f4 f8, ud */
/* u1 */ { u1, i2, i2, i4, i8, f2, f4, f8, ud },
/* i1 */ { i2, i1, i2, i4, i8, f2, f4, f8, ud },
/* i2 */ { i2, i2, i2, i4, i8, f4, f4, f8, ud },
/* i4 */ { i4, i4, i4, i4, i8, f8, f4, f8, ud },
/* i8 */ { i8, i8, i8, i8, i8, f8, f4, f8, ud },
/* f2 */ { f2, f2, f4, f8, f8, f2, f4, f8, ud },
/* f4 */ { f4, f4, f4, f4, f4, f4, f4, f8, ud },
/* f8 */ { f8, f8, f8, f8, f8, f8, f8, f8, ud },
/* ud */ { ud, ud, ud, ud, ud, ud, ud, ud, ud },
/* u1 i1 i2 i4 i8 f2 f4 f8 */
/* u1 */ { u1, i2, i2, i4, i8, f2, f4, f8 },
/* i1 */ { i2, i1, i2, i4, i8, f2, f4, f8 },
/* i2 */ { i2, i2, i2, i4, i8, f4, f4, f8 },
/* i4 */ { i4, i4, i4, i4, i8, f8, f4, f8 },
/* i8 */ { i8, i8, i8, i8, i8, f8, f4, f8 },
/* f2 */ { f2, f2, f4, f8, f8, f2, f4, f8 },
/* f4 */ { f4, f4, f4, f4, f4, f4, f4, f8 },
/* f8 */ { f8, f8, f8, f8, f8, f8, f8, f8 },
};
return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];
}
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/core/ScalarTypeUtils.h
Expand Up @@ -13,7 +13,7 @@ template <> \
struct CTypeToScalarType<ct> { \
static inline at::ScalarType to() { return at::ScalarType::st; } \
};
AT_FORALL_SCALAR_TYPES(DEFINE_TO_SCALAR_TYPE)
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_TO_SCALAR_TYPE)
#undef DEFINE_TO_SCALAR_TYPE

} // namespace at
40 changes: 22 additions & 18 deletions aten/src/ATen/core/typeid.h
Expand Up @@ -10,6 +10,7 @@
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <complex>
#ifdef __GXX_RTTI
#include <typeinfo>
#endif
Expand Down Expand Up @@ -466,26 +467,29 @@ CAFFE_DECLARE_KNOWN_TYPE(4, int64_t)
CAFFE_DECLARE_KNOWN_TYPE(5, at::Half)
CAFFE_DECLARE_KNOWN_TYPE(6, float)
CAFFE_DECLARE_KNOWN_TYPE(7, double)
// 8 = undefined type id

CAFFE_DECLARE_KNOWN_TYPE(9, Tensor)
CAFFE_DECLARE_KNOWN_TYPE(10, std::string)
CAFFE_DECLARE_KNOWN_TYPE(11, bool)
CAFFE_DECLARE_KNOWN_TYPE(12, uint16_t)
CAFFE_DECLARE_KNOWN_TYPE(13, char)
CAFFE_DECLARE_KNOWN_TYPE(14, std::unique_ptr<std::mutex>)
CAFFE_DECLARE_KNOWN_TYPE(15, std::unique_ptr<std::atomic<bool>>)
CAFFE_DECLARE_KNOWN_TYPE(16, std::vector<int32_t>)
CAFFE_DECLARE_KNOWN_TYPE(17, std::vector<int64_t>)
CAFFE_DECLARE_KNOWN_TYPE(18, std::vector<unsigned long>)
CAFFE_DECLARE_KNOWN_TYPE(19, bool*)
CAFFE_DECLARE_KNOWN_TYPE(20, char*)
CAFFE_DECLARE_KNOWN_TYPE(21, int*)
CAFFE_DECLARE_KNOWN_TYPE(8, at::ComplexHalf)
CAFFE_DECLARE_KNOWN_TYPE(9, std::complex<float>)
CAFFE_DECLARE_KNOWN_TYPE(10, std::complex<double>)
// 10 = undefined type id

CAFFE_DECLARE_KNOWN_TYPE(12, Tensor)
CAFFE_DECLARE_KNOWN_TYPE(13, std::string)
CAFFE_DECLARE_KNOWN_TYPE(14, bool)
CAFFE_DECLARE_KNOWN_TYPE(15, uint16_t)
CAFFE_DECLARE_KNOWN_TYPE(16, char)
CAFFE_DECLARE_KNOWN_TYPE(17, std::unique_ptr<std::mutex>)
CAFFE_DECLARE_KNOWN_TYPE(18, std::unique_ptr<std::atomic<bool>>)
CAFFE_DECLARE_KNOWN_TYPE(19, std::vector<int32_t>)
CAFFE_DECLARE_KNOWN_TYPE(20, std::vector<int64_t>)
CAFFE_DECLARE_KNOWN_TYPE(21, std::vector<unsigned long>)
CAFFE_DECLARE_KNOWN_TYPE(22, bool*)
CAFFE_DECLARE_KNOWN_TYPE(23, char*)
CAFFE_DECLARE_KNOWN_TYPE(24, int*)

#ifdef CAFFE2_UNIQUE_LONG_TYPEMETA
CAFFE_DECLARE_KNOWN_TYPE(22, long)
CAFFE_DECLARE_KNOWN_TYPE(23, std::vector<long>)
CAFFE_DECLARE_KNOWN_TYPE(25, long)
CAFFE_DECLARE_KNOWN_TYPE(26, std::vector<long>)
#endif // CAFFE2_UNIQUE_LONG_TYPEMETA

CAFFE_DECLARE_KNOWN_TYPE(24, _CaffeHighestPreallocatedTypeId)
CAFFE_DECLARE_KNOWN_TYPE(27, _CaffeHighestPreallocatedTypeId)
} // namespace caffe2
3 changes: 3 additions & 0 deletions test/cpp/api/serialization.cpp
Expand Up @@ -52,6 +52,9 @@ TEST_CASE("serialization") {
// XXX can't serialize half tensors at the moment since contiguous() is
// not implemented for this type;
continue;
} else if (at::isComplexType(static_cast<torch::Dtype>(i))) {
// Not supported yet
continue;
} else if (i == static_cast<int>(torch::Dtype::Undefined)) {
// We can't construct a tensor for this type. This is tested in
// serialization/undefined anyway.
Expand Down
8 changes: 7 additions & 1 deletion torch/csrc/utils/tensor_dtypes.cpp
Expand Up @@ -30,6 +30,12 @@ static std::pair<std::string, std::string> getDtypeNames(at::ScalarType scalarTy
return std::make_pair("int16", "short");
case at::ScalarType::Half:
return std::make_pair("float16", "half");
case at::ScalarType::ComplexHalf:
return std::make_pair("complex32", "");
case at::ScalarType::ComplexFloat:
return std::make_pair("complex64", "");
case at::ScalarType::ComplexDouble:
return std::make_pair("complex128", "");
default:
throw std::runtime_error("Unimplemented scalar type");
}
Expand All @@ -42,7 +48,7 @@ void initializeDtypes() {
#define DEFINE_SCALAR_TYPE(_1,n,_2) at::ScalarType::n,

at::ScalarType all_scalar_types[] = {
AT_FORALL_SCALAR_TYPES(DEFINE_SCALAR_TYPE)
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_SCALAR_TYPE)
};

for (at::ScalarType scalarType: all_scalar_types) {
Expand Down

0 comments on commit cd4c326

Please sign in to comment.