Skip to content

Commit

Permalink
Add complex IValues (#50883)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #50883

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D26003682

Pulled By: anjali411

fbshipit-source-id: f02967d2d236d740cd8647891f732f1d63098d3e
  • Loading branch information
anjali411 authored and facebook-github-bot committed Jan 22, 2021
1 parent 002d978 commit 9ac30d9
Show file tree
Hide file tree
Showing 10 changed files with 151 additions and 16 deletions.
15 changes: 15 additions & 0 deletions aten/src/ATen/core/ivalue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ TypePtr IValue::type() const {
return StorageType::get();
case Tag::Double:
return FloatType::get();
case Tag::ComplexDouble:
return ComplexDoubleType::get();
case Tag::Int:
return IntType::get();
case Tag::Bool:
Expand Down Expand Up @@ -284,6 +286,8 @@ IValue IValue::equals(const IValue& rhs) const {
return rhs.isStorage() && lhs.toStorage().unsafeGetStorageImpl() == rhs.toStorage().unsafeGetStorageImpl();
case Tag::Double:
return rhs.isDouble() && lhs.toDouble() == rhs.toDouble();
case Tag::ComplexDouble:
return rhs.isComplexDouble() && lhs.toComplexDouble() == rhs.toComplexDouble();
case Tag::Int:
return rhs.isInt() && lhs.toInt() == rhs.toInt();
case Tag::Bool:
Expand Down Expand Up @@ -352,6 +356,7 @@ size_t IValue::hash(const IValue& v) {
case Tag::Capsule:
case Tag::Generator:
case Tag::Quantizer:
case Tag::ComplexDouble:
case Tag::Enum:
case Tag::Stream:
case Tag::Uninitialized:
Expand Down Expand Up @@ -687,6 +692,16 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
<< std::setprecision(std::numeric_limits<double>::max_digits10)
<< v.toDouble()
<< std::setprecision(orig_prec);
} case IValue::Tag::ComplexDouble: {
c10::complex<double> d = v.toComplexDouble();
IValue real(d.real()), imag(std::abs(d.imag()));
auto sign = "";
if (d.imag() >= 0) {
sign = "+";
} else {
sign = "-";
}
return out << real << sign << imag << "j";
} case IValue::Tag::Int:
return out << v.toInt();
case IValue::Tag::Bool:
Expand Down
30 changes: 28 additions & 2 deletions aten/src/ATen/core/ivalue.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,19 @@ struct GenericDict;
struct Object;
struct PyObjectHolder;
struct EnumHolder;
// We need a ComplexHolder because currently the payloads in the Union
// only take 64 bits. Since ComplexDouble takes up 128 bits, and is too big
// to fit in the IValue directly, we indirect complex numbers through an intrusive
// pointer to ComplexHolder (which contains a c10::complex).
struct ComplexHolder : c10::intrusive_ptr_target {
public:
template <typename T>
ComplexHolder(c10::complex<T> c) {
val = convert<decltype(val), c10::complex<T>>(c);
}
ComplexHolder() {}
c10::complex<double> val;
};
} // namespace ivalue

// This is an owning wrapper for a c10::optional<std::vector<T>>
Expand Down Expand Up @@ -107,6 +120,7 @@ struct Capsule {
_(Tensor) \
_(Storage) \
_(Double) \
_(ComplexDouble) \
_(Int) \
_(Bool) \
_(Tuple) \
Expand Down Expand Up @@ -443,6 +457,12 @@ struct TORCH_API IValue final {
return payload.u.as_double;
}

// ComplexDouble
template <typename T>
IValue(c10::complex<T> c);
bool isComplexDouble() const { return Tag::ComplexDouble == tag; }
c10::complex<double> toComplexDouble() const;

// Future
IValue(c10::intrusive_ptr<ivalue::Future> v);
bool isFuture() const {
Expand Down Expand Up @@ -631,22 +651,28 @@ struct TORCH_API IValue final {
return i;
}

// Scalar, which gets encoded as either an Int or a Double
// Scalar, which gets encoded as either an Int, a Double or a ComplexDouble
IValue(at::Scalar s) : IValue() {
if (s.isFloatingPoint()) {
*this = s.toDouble();
} else if (s.isComplex()) {
*this = s.toComplexDouble();
} else {
*this = s.toLong();
}
}

bool isScalar() const {
return isDouble() || isInt();
return isDouble() || isInt() || isComplexDouble();
}

at::Scalar toScalar() const {
if (isDouble())
return toDouble();
else if (isInt())
return toInt();
else if (isComplexDouble())
return toComplexDouble();
throw std::runtime_error("IValue is not a Scalar");
}

Expand Down
13 changes: 13 additions & 0 deletions aten/src/ATen/core/ivalue_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,11 @@ inline c10::intrusive_ptr<ivalue::EnumHolder> IValue::toEnumHolder() const& {
TORCH_INTERNAL_ASSERT(isEnum(), "Expected Enum but got ", tagKind());
return toIntrusivePtr<ivalue::EnumHolder>();
}
inline c10::complex<double> IValue::toComplexDouble() const {
TORCH_INTERNAL_ASSERT(isComplexDouble(), "Expected ComplexDouble but got ", tagKind());
auto ptr = toIntrusivePtr<ivalue::ComplexHolder>();
return (*ptr).val;
}
inline at::Tensor IValue::toTensor() && {
AT_ASSERT(isTensor(), "Expected Tensor but got ", tagKind());
auto result = std::move(payload.as_tensor);
Expand Down Expand Up @@ -754,6 +759,7 @@ DEFINE_TO(at::Storage, toStorage)
DEFINE_TO(c10::Stream, toStream)
DEFINE_TO(float, toDouble)
DEFINE_TO(double, toDouble)
DEFINE_TO(c10::complex<double>, toComplexDouble)
DEFINE_TO(unsigned char, toInt)
DEFINE_TO(signed char, toInt)
DEFINE_TO(unsigned short, toInt)
Expand Down Expand Up @@ -1222,6 +1228,13 @@ inline IValue::IValue(c10::intrusive_ptr<at::Quantizer> v)
payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
}

template <typename T>
inline IValue::IValue(c10::complex<T> c)
: tag(Tag::ComplexDouble), is_intrusive_ptr(true) {
auto v = c10::make_intrusive<ivalue::ComplexHolder>(c);
payload.u.as_intrusive_ptr = v.release();
}

inline const std::string& IValue::toStringRef() const {
AT_ASSERT(isString(), "Expected String but got ", tagKind());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
Expand Down
34 changes: 31 additions & 3 deletions aten/src/ATen/core/jit_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,7 @@ using NumberTypePtr = std::shared_ptr<NumberType>;
// Subtype hierarchy for Number Types (NumberType as the base type):
// IntType <: NumberType
// FloatType <: NumberType
// ComplexDoubleType <:NumberType
struct TORCH_API NumberType : public Type {
static NumberTypePtr create() {
return NumberTypePtr(new NumberType()); // NOLINT(modernize-make-shared)
Expand Down Expand Up @@ -1156,6 +1157,33 @@ struct TORCH_API FloatType : public NumberType {
}
};

struct ComplexDoubleType;
using ComplexDoubleTypePtr = std::shared_ptr<ComplexDoubleType>;
// This type represents a Python float number
struct TORCH_API ComplexDoubleType : public NumberType {
static ComplexDoubleTypePtr create() {
return ComplexDoubleTypePtr(new ComplexDoubleType()); // NOLINT(modernize-make-shared)
}
bool operator==(const Type& rhs) const override {
return rhs.kind() == kind();
}
std::string str() const override {
return "complex";
}
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override {
return rhs->kind() == TypeKind::NumberType || NumberType::isSubtypeOfExt(rhs, why_not);
}
static const TypeKind Kind = TypeKind::ComplexDoubleType;
// global singleton
static ComplexDoubleTypePtr get();

private:
ComplexDoubleType() : NumberType(TypeKind::ComplexDoubleType) {}
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
return "complex";
}
};

struct IntType;
using IntTypePtr = std::shared_ptr<IntType>;
// This type represents a Python int number
Expand Down Expand Up @@ -1547,7 +1575,7 @@ inline at::ScalarType scalarTypeFromJitType(const c10::TypePtr& type) {
auto result = tryScalarTypeFromJitType(type);
AT_ASSERTM(
result,
"Add new condition, expected Float, Int, or Bool but got",
"Add new condition, expected Float, Complex, Int, or Bool but got",
type->str());
return *result;
}
Expand Down Expand Up @@ -2125,7 +2153,7 @@ struct TORCH_API ClassType : public NamedType {
torch::jit::Function* findForwardHook(const std::string& name) const;
const std::vector<torch::jit::Function*>& getForwardHooks() const;
const std::vector<torch::jit::Function*>& getForwardPreHooks() const;

void checkForwardPreHookSchema(
int pre_hook_idx,
const FunctionSchema& pre_hook_schema) const;
Expand Down Expand Up @@ -2206,7 +2234,7 @@ struct TORCH_API ClassType : public NamedType {
// List of hooks to be run before/after forward.
std::vector<torch::jit::Function*> forward_hooks_;
std::vector<torch::jit::Function*> forward_pre_hooks_;

// List of properties exposed by this class.
std::vector<Property> properties_;

Expand Down
1 change: 1 addition & 0 deletions aten/src/ATen/core/jit_type_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ namespace c10 {
_(DictType) \
_(NumberType) \
_(FloatType) \
_(ComplexDoubleType) \
_(FutureType) \
_(RRefType) \
_(IntType) \
Expand Down
4 changes: 4 additions & 0 deletions aten/src/ATen/core/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ FloatTypePtr FloatType::get() {
static auto value = FloatType::create();
return value;
}
ComplexDoubleTypePtr ComplexDoubleType::get() {
static auto value = ComplexDoubleType::create();
return value;
}
BoolTypePtr BoolType::get() {
static auto value = BoolType::create();
return value;
Expand Down
46 changes: 45 additions & 1 deletion aten/src/ATen/test/ivalue_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,45 @@ TEST(IValueTest, TuplePrint) {
}
}

TEST(IValueTest, ComplexIValuePrint) {
{
IValue complex(c10::complex<double>(2, -3));
std::stringstream ss;
ss << complex;
ASSERT_EQ(ss.str(), "2.-3.j");
}

{
IValue complex(c10::complex<double>(2, 0));
std::stringstream ss;
ss << complex;
ASSERT_EQ(ss.str(), "2.+0.j");
}

{
IValue complex(c10::complex<double>(0, 3));
std::stringstream ss;
ss << complex;
ASSERT_EQ(ss.str(), "0.+3.j");
}
}

TEST(IValueTest, Complex) {
auto c = c10::complex<double>(2, 3);
auto c_ = c10::complex<double>(2, -3);
IValue c1(c), c2(c_), c3{at::Scalar(c)};

ASSERT_TRUE(c1.isComplexDouble());
ASSERT_TRUE(c3.isComplexDouble());

ASSERT_EQ(c, c1.toComplexDouble());
ASSERT_FALSE(c1 == c2);
ASSERT_TRUE(c1 == c3);

ASSERT_TRUE(c1.isScalar());
ASSERT_TRUE(c2.toScalar().equal(c_));
}

TEST(IValueTest, BasicFuture) {
auto f1 = c10::make_intrusive<ivalue::Future>(IntType::get());
ASSERT_FALSE(f1->completed());
Expand Down Expand Up @@ -484,7 +523,7 @@ TEST(IValueTest, IdentityComparisonAndHashing) {

TEST(IValueTest, getSubValues) {
// Scalars have no subvalues.
IValue integer(42), float_(1.5);
IValue integer(42), float_(1.5), complex(c10::complex<double>(2, 3));

IValue::HashAliasedIValues subvalues;

Expand All @@ -498,6 +537,11 @@ TEST(IValueTest, getSubValues) {

subvalues.clear();

complex.getSubValues(subvalues);
EXPECT_TRUE(subvalues.empty());

subvalues.clear();

at::Tensor t1(at::rand({3, 4})), t2(at::rand({3, 4}));
IValue tv1(t1), tv2(t2);
IValue list(std::vector<at::Tensor>{t1, t2});
Expand Down
20 changes: 10 additions & 10 deletions aten/src/ATen/test/type_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ class OneForward(Interface):

TEST(TypeEquality, TupleEquality) {
// Tuples should be structurally typed
auto type = TupleType::create({IntType::get(), TensorType::get(), FloatType::get()});
auto type2 = TupleType::create({IntType::get(), TensorType::get(), FloatType::get()});
auto type = TupleType::create({IntType::get(), TensorType::get(), FloatType::get(), ComplexDoubleType::get()});
auto type2 = TupleType::create({IntType::get(), TensorType::get(), FloatType::get(), ComplexDoubleType::get()});

EXPECT_EQ(*type, *type2);
}
Expand All @@ -185,24 +185,24 @@ TEST(TypeEquality, NamedTupleEquality) {
// Named tuples should compare equal if they share a name and field names
auto type = TupleType::createNamed(
"MyNamedTuple",
{"a", "b", "c"},
{IntType::get(), TensorType::get(), FloatType::get()});
{"a", "b", "c", "d"},
{IntType::get(), TensorType::get(), FloatType::get(), ComplexDoubleType::get()});
auto type2 = TupleType::createNamed(
"MyNamedTuple",
{"a", "b", "c"},
{IntType::get(), TensorType::get(), FloatType::get()});
{"a", "b", "c", "d"},
{IntType::get(), TensorType::get(), FloatType::get(), ComplexDoubleType::get()});
EXPECT_EQ(*type, *type2);

auto differentName = TupleType::createNamed(
"WowSoDifferent",
{"a", "b", "c"},
{IntType::get(), TensorType::get(), FloatType::get()});
{"a", "b", "c", "d"},
{IntType::get(), TensorType::get(), FloatType::get(), ComplexDoubleType::get()});
EXPECT_NE(*type, *differentName);

auto differentField = TupleType::createNamed(
"MyNamedTuple",
{"wow", "so", "different"},
{IntType::get(), TensorType::get(), FloatType::get()});
{"wow", "so", "very", "different"},
{IntType::get(), TensorType::get(), FloatType::get(), ComplexDoubleType::get()});
EXPECT_NE(*type, *differentField);
}
} // namespace c10
2 changes: 2 additions & 0 deletions torch/csrc/jit/python/pybind_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,8 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
case TypeKind::AnyClassType:
case TypeKind::AnyEnumType:
break;
case TypeKind::ComplexDoubleType:
AT_ASSERT(false);
case TypeKind::EnumType:
EnumTypePtr enum_type = type->expect<EnumType>();
py::object py_obj = py::reinterpret_borrow<py::object>(obj);
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/jit/serialization/unpickler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
case AnyEnumType::Kind:
// no op, there is nothing to tag
break;
// TODO(@anjali411): Implement serialization/deserialization for complex numbers
case ComplexDoubleType::Kind:
case EnumType::Kind:
// TODO(gmagogsfm): Implement serialization/deserialization of Enum.
AT_ASSERT(false);
Expand Down

0 comments on commit 9ac30d9

Please sign in to comment.