Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add complex IValues #50883

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 15 additions & 0 deletions aten/src/ATen/core/ivalue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ TypePtr IValue::type() const {
return StorageType::get();
case Tag::Double:
return FloatType::get();
case Tag::ComplexDouble:
return ComplexDoubleType::get();
case Tag::Int:
return IntType::get();
case Tag::Bool:
Expand Down Expand Up @@ -284,6 +286,8 @@ IValue IValue::equals(const IValue& rhs) const {
return rhs.isStorage() && lhs.toStorage().unsafeGetStorageImpl() == rhs.toStorage().unsafeGetStorageImpl();
case Tag::Double:
return rhs.isDouble() && lhs.toDouble() == rhs.toDouble();
case Tag::ComplexDouble:
return rhs.isComplexDouble() && lhs.toComplexDouble() == rhs.toComplexDouble();
case Tag::Int:
return rhs.isInt() && lhs.toInt() == rhs.toInt();
case Tag::Bool:
Expand Down Expand Up @@ -352,6 +356,7 @@ size_t IValue::hash(const IValue& v) {
case Tag::Capsule:
case Tag::Generator:
case Tag::Quantizer:
case Tag::ComplexDouble:
case Tag::Enum:
case Tag::Stream:
case Tag::Uninitialized:
Expand Down Expand Up @@ -687,6 +692,16 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
<< std::setprecision(std::numeric_limits<double>::max_digits10)
<< v.toDouble()
<< std::setprecision(orig_prec);
} case IValue::Tag::ComplexDouble: {
c10::complex<double> d = v.toComplexDouble();
IValue real(d.real()), imag(std::abs(d.imag()));
auto sign = "";
if (d.imag() >= 0) {
sign = "+";
} else {
sign = "-";
}
return out << real << sign << imag << "j";
} case IValue::Tag::Int:
return out << v.toInt();
case IValue::Tag::Bool:
Expand Down
30 changes: 28 additions & 2 deletions aten/src/ATen/core/ivalue.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,19 @@ struct GenericDict;
struct Object;
struct PyObjectHolder;
struct EnumHolder;
// We need a ComplexHolder because currently the payloads in the Union
// only take 64 bits. Since ComplexDouble takes up 128 bits, and is too big
// to fit in the IValue directly, we indirect complex numbers through an intrusive
// pointer to ComplexHolder (which contains a c10::complex).
struct ComplexHolder : c10::intrusive_ptr_target {
public:
template <typename T>
ComplexHolder(c10::complex<T> c) {
val = convert<decltype(val), c10::complex<T>>(c);
}
ComplexHolder() {}
c10::complex<double> val;
};
} // namespace ivalue

// This is an owning wrapper for a c10::optional<std::vector<T>>
Expand Down Expand Up @@ -107,6 +120,7 @@ struct Capsule {
_(Tensor) \
_(Storage) \
_(Double) \
_(ComplexDouble) \
_(Int) \
_(Bool) \
_(Tuple) \
Expand Down Expand Up @@ -443,6 +457,12 @@ struct TORCH_API IValue final {
return payload.u.as_double;
}

// ComplexDouble
template <typename T>
IValue(c10::complex<T> c);
bool isComplexDouble() const { return Tag::ComplexDouble == tag; }
c10::complex<double> toComplexDouble() const;

// Future
IValue(c10::intrusive_ptr<ivalue::Future> v);
bool isFuture() const {
Expand Down Expand Up @@ -631,22 +651,28 @@ struct TORCH_API IValue final {
return i;
}

// Scalar, which gets encoded as either an Int or a Double
// Scalar, which gets encoded as either an Int, a Double or a ComplexDouble
IValue(at::Scalar s) : IValue() {
if (s.isFloatingPoint()) {
*this = s.toDouble();
} else if (s.isComplex()) {
*this = s.toComplexDouble();
} else {
*this = s.toLong();
}
}

bool isScalar() const {
return isDouble() || isInt();
return isDouble() || isInt() || isComplexDouble();
}

at::Scalar toScalar() const {
if (isDouble())
return toDouble();
else if (isInt())
return toInt();
else if (isComplexDouble())
return toComplexDouble();
throw std::runtime_error("IValue is not a Scalar");
}

Expand Down
13 changes: 13 additions & 0 deletions aten/src/ATen/core/ivalue_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,11 @@ inline c10::intrusive_ptr<ivalue::EnumHolder> IValue::toEnumHolder() const& {
TORCH_INTERNAL_ASSERT(isEnum(), "Expected Enum but got ", tagKind());
return toIntrusivePtr<ivalue::EnumHolder>();
}
inline c10::complex<double> IValue::toComplexDouble() const {
TORCH_INTERNAL_ASSERT(isComplexDouble(), "Expected ComplexDouble but got ", tagKind());
auto ptr = toIntrusivePtr<ivalue::ComplexHolder>();
return (*ptr).val;
}
inline at::Tensor IValue::toTensor() && {
AT_ASSERT(isTensor(), "Expected Tensor but got ", tagKind());
auto result = std::move(payload.as_tensor);
Expand Down Expand Up @@ -754,6 +759,7 @@ DEFINE_TO(at::Storage, toStorage)
DEFINE_TO(c10::Stream, toStream)
DEFINE_TO(float, toDouble)
DEFINE_TO(double, toDouble)
DEFINE_TO(c10::complex<double>, toComplexDouble)
DEFINE_TO(unsigned char, toInt)
DEFINE_TO(signed char, toInt)
DEFINE_TO(unsigned short, toInt)
Expand Down Expand Up @@ -1222,6 +1228,13 @@ inline IValue::IValue(c10::intrusive_ptr<at::Quantizer> v)
payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
}

template <typename T>
inline IValue::IValue(c10::complex<T> c)
: tag(Tag::ComplexDouble), is_intrusive_ptr(true) {
auto v = c10::make_intrusive<ivalue::ComplexHolder>(c);
payload.u.as_intrusive_ptr = v.release();
}

inline const std::string& IValue::toStringRef() const {
AT_ASSERT(isString(), "Expected String but got ", tagKind());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
Expand Down
34 changes: 33 additions & 1 deletion aten/src/ATen/core/jit_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,7 @@ using NumberTypePtr = std::shared_ptr<NumberType>;
// Subtype hierarchy for Number Types (NumberType as the base type):
// IntType <: NumberType
// FloatType <: NumberType
// ComplexDoubleType <:NumberType
anjali411 marked this conversation as resolved.
Show resolved Hide resolved
struct TORCH_API NumberType : public Type {
static NumberTypePtr create() {
return NumberTypePtr(new NumberType()); // NOLINT(modernize-make-shared)
Expand Down Expand Up @@ -1156,6 +1157,33 @@ struct TORCH_API FloatType : public NumberType {
}
};

struct ComplexDoubleType;
using ComplexDoubleTypePtr = std::shared_ptr<ComplexDoubleType>;
// This type represents a Python float number
struct TORCH_API ComplexDoubleType : public NumberType {
static ComplexDoubleTypePtr create() {
return ComplexDoubleTypePtr(new ComplexDoubleType()); // NOLINT(modernize-make-shared)
}
bool operator==(const Type& rhs) const override {
return rhs.kind() == kind();
}
std::string str() const override {
return "complex";
}
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override {
return rhs->kind() == TypeKind::NumberType || NumberType::isSubtypeOfExt(rhs, why_not);
anjali411 marked this conversation as resolved.
Show resolved Hide resolved
}
static const TypeKind Kind = TypeKind::ComplexDoubleType;
// global singleton
static ComplexDoubleTypePtr get();

private:
ComplexDoubleType() : NumberType(TypeKind::ComplexDoubleType) {}
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
return "complex";
}
};

struct IntType;
using IntTypePtr = std::shared_ptr<IntType>;
// This type represents a Python int number
Expand Down Expand Up @@ -1525,6 +1553,8 @@ inline TypePtr TensorType::fromNumberType(TypePtr typ) {
return TensorType::createContiguous(at::kFloat, at::kCPU, {});
} else if (typ->isSubtypeOf(BoolType::get())) {
return TensorType::createContiguous(at::kLong, at::kCPU, {});
} else if (typ->isSubtypeOf(ComplexDoubleType::get())) {
return TensorType::createContiguous(at::kComplexDouble, at::kCPU, {});
anjali411 marked this conversation as resolved.
Show resolved Hide resolved
}
TORCH_CHECK(false, "Unknown number type: ", typ->str());
}
Expand All @@ -1535,6 +1565,8 @@ inline TypePtr TensorType::fromBoolType() {
inline c10::optional<c10::ScalarType> tryScalarTypeFromJitType(const c10::TypePtr & type) {
if (type == FloatType::get()) {
return at::typeMetaToScalarType(c10::get_default_dtype());
} else if (type == ComplexDoubleType::get()) {
return at::typeMetaToScalarType(c10::get_default_complex_dtype());
anjali411 marked this conversation as resolved.
Show resolved Hide resolved
} else if (type == IntType::get()) {
return at::ScalarType::Long;
} else if (type == BoolType::get()) {
Expand All @@ -1547,7 +1579,7 @@ inline at::ScalarType scalarTypeFromJitType(const c10::TypePtr& type) {
auto result = tryScalarTypeFromJitType(type);
AT_ASSERTM(
result,
"Add new condition, expected Float, Int, or Bool but got",
"Add new condition, expected Float, Complex, Int, or Bool but got",
type->str());
return *result;
}
Expand Down
1 change: 1 addition & 0 deletions aten/src/ATen/core/jit_type_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ namespace c10 {
_(DictType) \
_(NumberType) \
_(FloatType) \
_(ComplexDoubleType) \
_(FutureType) \
_(RRefType) \
_(IntType) \
Expand Down
4 changes: 4 additions & 0 deletions aten/src/ATen/core/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ FloatTypePtr FloatType::get() {
static auto value = FloatType::create();
return value;
}
ComplexDoubleTypePtr ComplexDoubleType::get() {
static auto value = ComplexDoubleType::create();
return value;
}
BoolTypePtr BoolType::get() {
static auto value = BoolType::create();
return value;
Expand Down
46 changes: 45 additions & 1 deletion aten/src/ATen/test/ivalue_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,45 @@ TEST(IValueTest, TuplePrint) {
}
}

TEST(IValueTest, ComplexIValuePrint) {
{
IValue complex(c10::complex<double>(2, -3));
std::stringstream ss;
ss << complex;
ASSERT_EQ(ss.str(), "2.-3.j");
}

{
IValue complex(c10::complex<double>(2, 0));
std::stringstream ss;
ss << complex;
ASSERT_EQ(ss.str(), "2.+0.j");
}

{
IValue complex(c10::complex<double>(0, 3));
std::stringstream ss;
ss << complex;
ASSERT_EQ(ss.str(), "0.+3.j");
}
}

TEST(IValueTest, Complex) {
auto c = c10::complex<double>(2, 3);
auto c_ = c10::complex<double>(2, -3);
IValue c1(c), c2(c_), c3{at::Scalar(c)};

ASSERT_TRUE(c1.isComplexDouble());
ASSERT_TRUE(c3.isComplexDouble());

ASSERT_EQ(c, c1.toComplexDouble());
ASSERT_FALSE(c1 == c2);
ASSERT_TRUE(c1 == c3);

ASSERT_TRUE(c1.isScalar());
ASSERT_TRUE(c2.toScalar().equal(c_));
}

TEST(IValueTest, BasicFuture) {
auto f1 = c10::make_intrusive<ivalue::Future>(IntType::get());
ASSERT_FALSE(f1->completed());
Expand Down Expand Up @@ -484,7 +523,7 @@ TEST(IValueTest, IdentityComparisonAndHashing) {

TEST(IValueTest, getSubValues) {
// Scalars have no subvalues.
IValue integer(42), float_(1.5);
IValue integer(42), float_(1.5), complex(c10::complex<double>(2, 3));

IValue::HashAliasedIValues subvalues;

Expand All @@ -498,6 +537,11 @@ TEST(IValueTest, getSubValues) {

subvalues.clear();

complex.getSubValues(subvalues);
EXPECT_TRUE(subvalues.empty());

subvalues.clear();

at::Tensor t1(at::rand({3, 4})), t2(at::rand({3, 4}));
IValue tv1(t1), tv2(t2);
IValue list(std::vector<at::Tensor>{t1, t2});
Expand Down