From 3f9697b10e8e68f06e516d84a62b4f1698664de1 Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Fri, 6 Nov 2020 16:57:51 -0800 Subject: [PATCH] Correctly compare Stream IValues (#47303) Summary: Stream IValue equality comparison was comparing wrong object. Pull Request resolved: https://github.com/pytorch/pytorch/pull/47303 Test Plan: Added a new C++ test Fixes #{issue number} Reviewed By: bdhirsh Differential Revision: D24752434 Pulled By: gmagogsfm fbshipit-source-id: 78bc7a812740485ebbc7cf0c06c2e671a7ccd26f --- aten/src/ATen/core/ivalue.cpp | 2 +- aten/src/ATen/test/ivalue_test.cpp | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp index 207899a850dc..e27b326163ea 100644 --- a/aten/src/ATen/core/ivalue.cpp +++ b/aten/src/ATen/core/ivalue.cpp @@ -273,7 +273,7 @@ IValue IValue::equals(const IValue& rhs) const { case Tag::Tuple: return rhs.isTuple() && *lhs.toTuple() == *rhs.toTuple(); case Tag::Stream: - return rhs.isStream() && lhs.toStream() == lhs.toStream(); + return rhs.isStream() && lhs.toStream() == rhs.toStream(); case Tag::Device: return rhs.isDevice() && lhs.toDevice() == rhs.toDevice(); case Tag::GenericList: diff --git a/aten/src/ATen/test/ivalue_test.cpp b/aten/src/ATen/test/ivalue_test.cpp index 6474aa45d4dd..14e75205aa66 100644 --- a/aten/src/ATen/test/ivalue_test.cpp +++ b/aten/src/ATen/test/ivalue_test.cpp @@ -259,6 +259,18 @@ TEST(IValueTest, ListNestedEquality) { EXPECT_NE(c2, c3); } +TEST(IValueTest, StreamEquality) { + at::Device device1 = at::Device(kCUDA, 0); + at::Device device2 = at::Device(kCUDA, 1); + c10::Stream stream1 = c10::Stream(c10::Stream::Default::DEFAULT, device1); + c10::Stream stream2 = c10::Stream(c10::Stream::Default::DEFAULT, device2); + IValue lhs(stream1); + IValue rhs_different(stream2); + IValue rhs_same(stream1); + EXPECT_FALSE(lhs.equals(rhs_different).toBool()); + EXPECT_TRUE(lhs.equals(rhs_same).toBool()); +} + TEST(IValueTest, EnumEquality) { auto cu = std::make_shared(); IValue int_ivalue_1(1);