Merged
87 changes: 87 additions & 0 deletions test/cpp/test_ir.cpp
@@ -4,6 +4,7 @@
#include "torch_xla/csrc/ir.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/arithmetic_ir_ops.h"
#include "torch_xla/csrc/ops/dynamic_ir.h"
#include "torch_xla/csrc/ops/ops.h"
#include "torch_xla/csrc/ops/scalar.h"
#include "torch_xla/csrc/ops/select.h"
@@ -94,5 +95,91 @@ TEST(IrTest, TestScopePusherWithDebugging) {
FLAGS_torch_lazy_ir_debug = restore_FLAGS_torch_lazy_ir_debug;
}

TEST(IrTest, TestSizeNode) {
torch::lazy::NodePtr scalar_node =
ScalarOp(1.0, xla::ShapeUtil::MakeShape(xla::F32, {3, 4}));
Collaborator:
Perhaps a silly question: does the "Op" in "ScalarOp" stand for operator or operand? Since it's a scalar, it looks like an operand, but torch::lazy::Node is described as "an operation a Node can be associated to".

Collaborator Author:
It is a Node, check

return torch::lazy::MakeNode<Scalar>(value, std::move(shape));
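
Concretely, the "Op" names the IR operation node, not an operand: ScalarOp builds a Scalar node, and that node's output value is what downstream nodes consume as an operand. A minimal sketch mirroring the test below (no new API assumed):

// ScalarOp produces an IR operation node whose output carries the scalar value.
torch::lazy::NodePtr scalar_node =
    ScalarOp(1.0, xla::ShapeUtil::MakeShape(xla::F32, {3, 4}));
// SizeNode consumes that node's output as its operand and reads dimension 0 (static value 3).
torch::lazy::NodePtr size_node =
    torch::lazy::MakeNode<SizeNode>(scalar_node, /*dim=*/0);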

torch::lazy::NodePtr size_node_0 =
torch::lazy::MakeNode<SizeNode>(scalar_node, 0);
torch::lazy::NodePtr size_node_1 =
torch::lazy::MakeNode<SizeNode>(scalar_node, 1);
std::shared_ptr<torch::lazy::DimensionNode> dim_node_0 =
std::dynamic_pointer_cast<torch::lazy::DimensionNode>(size_node_0);
std::shared_ptr<torch::lazy::DimensionNode> dim_node_1 =
std::dynamic_pointer_cast<torch::lazy::DimensionNode>(size_node_1);

EXPECT_EQ(dim_node_0->getStaticValue(), 3);
EXPECT_EQ(dim_node_1->getStaticValue(), 4);

ForEachDevice([&](const torch::lazy::BackendDevice& device) {
// Lower the SizeNode and execute the GetDimensionSize.
auto results = ExecuteAndFetch({size_node_0, size_node_1}, device);
EXPECT_EQ(results[0].sum().item().toInt(), 3);
EXPECT_EQ(results[1].sum().item().toInt(), 4);
});
}

TEST(IrTest, TestSizeAddNode) {
torch::lazy::NodePtr scalar_node =
ScalarOp(1.0, xla::ShapeUtil::MakeShape(xla::F32, {3, 4}));
torch::lazy::NodePtr size_node_0 =
torch::lazy::MakeNode<SizeNode>(scalar_node, 0);
torch::lazy::NodePtr size_node_1 =
torch::lazy::MakeNode<SizeNode>(scalar_node, 1);
torch::lazy::NodePtr size_node_add =
torch::lazy::MakeNode<SizeAdd>(size_node_0, size_node_1);
std::shared_ptr<torch::lazy::DimensionNode> dim_node_add =
std::dynamic_pointer_cast<torch::lazy::DimensionNode>(size_node_add);

EXPECT_EQ(dim_node_add->getStaticValue(), 7);

ForEachDevice([&](const torch::lazy::BackendDevice& device) {
// Lower the SizeAddNode and execute the GetDimensionSize.
auto results = ExecuteAndFetch({size_node_add}, device);
EXPECT_EQ(results[0].sum().item().toInt(), 7);
});
}

TEST(IrTest, TestSizeMulNode) {
torch::lazy::NodePtr scalar_node =
ScalarOp(1.0, xla::ShapeUtil::MakeShape(xla::F32, {3, 4}));
torch::lazy::NodePtr size_node_0 =
torch::lazy::MakeNode<SizeNode>(scalar_node, 0);
torch::lazy::NodePtr size_node_1 =
torch::lazy::MakeNode<SizeNode>(scalar_node, 1);
torch::lazy::NodePtr size_node_mul =
torch::lazy::MakeNode<SizeMul>(size_node_0, size_node_1);
std::shared_ptr<torch::lazy::DimensionNode> dim_node_mul =
std::dynamic_pointer_cast<torch::lazy::DimensionNode>(size_node_mul);

EXPECT_EQ(dim_node_mul->getStaticValue(), 12);

ForEachDevice([&](const torch::lazy::BackendDevice& device) {
// Lower the SizeMulNode and execute the GetDimensionSize.
auto results = ExecuteAndFetch({size_node_mul}, device);
EXPECT_EQ(results[0].sum().item().toInt(), 12);
});
}

TEST(IrTest, TestSizeDivNode) {
torch::lazy::NodePtr scalar_node =
ScalarOp(1.0, xla::ShapeUtil::MakeShape(xla::F32, {12, 5}));
torch::lazy::NodePtr size_node_0 =
torch::lazy::MakeNode<SizeNode>(scalar_node, 0);
torch::lazy::NodePtr size_node_1 =
torch::lazy::MakeNode<SizeNode>(scalar_node, 1);
torch::lazy::NodePtr size_node_div =
torch::lazy::MakeNode<SizeDiv>(size_node_0, size_node_1);
std::shared_ptr<torch::lazy::DimensionNode> dim_node_div =
std::dynamic_pointer_cast<torch::lazy::DimensionNode>(size_node_div);

EXPECT_EQ(dim_node_div->getStaticValue(), 2);

ForEachDevice([&](const torch::lazy::BackendDevice& device) {
// Lower the SizeDivNode and execute the GetDimensionSize.
auto results = ExecuteAndFetch({size_node_div}, device);
EXPECT_EQ(results[0].sum().item().toInt(), 2);
});
}

} // namespace cpp_test
} // namespace torch_xla
55 changes: 24 additions & 31 deletions test/cpp/test_symint.cpp
@@ -6,6 +6,7 @@
#include "torch_xla/csrc/generated/LazyIr.h"
#include "torch_xla/csrc/ir.h"
#include "torch_xla/csrc/ops/dynamic_ir.h"
#include "torch_xla/csrc/ops/expand.h"
#include "torch_xla/csrc/ops/ops.h"
#include "torch_xla/csrc/torch_util.h"
using std::cerr;
@@ -58,55 +59,49 @@ TEST(SymintTest, TestSaticSymints) {
TEST(SymintTest, TestDynamicSymint) {
torch::lazy::Value scalar_value =
torch::lazy::Value(ScalarOp(1.0, xla::F32), 0);
// Manually assign the torch::lazy::shape to avoid calling shape fn in this
// test. Note that we have to use one of those codegen ops so they take
// lazy::shape in constructor.
std::vector<torch::lazy::Shape> abs_lazy_shapes = {
torch::lazy::Shape(torch::kFloat, {1})};
torch::lazy::NodePtr abs_node =
torch::lazy::MakeNode<Abs>(scalar_value, std::move(abs_lazy_shapes));
torch::lazy::Value abs_value = torch::lazy::Value(abs_node, 0);
std::vector<int64_t> target_size = {2, 3, 5};
torch::lazy::NodePtr expand_node =
torch::lazy::MakeNode<Expand>(scalar_value, target_size);
torch::lazy::Value expand_value = torch::lazy::Value(expand_node, 0);
torch::lazy::NodePtr size_node =
torch::lazy::MakeNode<SizeNode>(abs_value, /*dim=*/0);
torch::lazy::MakeNode<SizeNode>(expand_value, /*dim=*/0);
auto symint_node =
c10::make_intrusive<torch::lazy::SymIntNodeImpl>(size_node);
// This is not really a dynamic size per se but it is a symint that wraps
// around a SizeNode instead of a scalar.
// This is not a dynamic size from the XLA perspective but it is a symint that
// wraps around a SizeNode instead of a scalar.
c10::SymInt dynamic_symint = symint_node->toSymInt();

SymIntElements si_element(dynamic_symint);

std::vector<int64_t> upper_bound = si_element.GetUpperBounds();
EXPECT_EQ(upper_bound.size(), 1);
EXPECT_EQ(upper_bound[0], 1);
EXPECT_EQ(upper_bound[0], 2);

std::vector<bool> dynamic_dims = si_element.GetDynamicDims();
EXPECT_EQ(dynamic_dims.size(), 1);
EXPECT_EQ(dynamic_dims[0], true);

std::vector<torch::lazy::NodePtr> size_nodes = si_element.GetSizeNodes();
EXPECT_EQ(size_nodes.size(), 1);
EXPECT_TRUE(si_element.GetSizeNode(0) != nullptr);
EXPECT_EQ(si_element.GetSizeNode(0), size_node);
}

TEST(SymintTest, TestDynamicSymints) {
torch::lazy::Value scalar_value =
torch::lazy::Value(ScalarOp(1.0, xla::F32), 0);
// Assign an incorrect 3d shape for the test purpose
std::vector<torch::lazy::Shape> abs_lazy_shapes = {
torch::lazy::Shape(torch::kFloat, {10, 20, 30})};
torch::lazy::NodePtr abs_node =
torch::lazy::MakeNode<Abs>(scalar_value, std::move(abs_lazy_shapes));

std::vector<int64_t> target_size = {2, 3, 5};
torch::lazy::NodePtr expand_node =
torch::lazy::MakeNode<Expand>(scalar_value, target_size);
torch::lazy::Value expand_value = torch::lazy::Value(expand_node, 0);
std::vector<c10::SymInt> dynamic_symints;
std::vector<torch::lazy::NodePtr> size_nodes;
for (int i = 0; i < 3; i++) {
torch::lazy::Value abs_value = torch::lazy::Value(abs_node, 0);
torch::lazy::NodePtr size_node =
torch::lazy::MakeNode<SizeNode>(abs_value, /*dim=*/i);
torch::lazy::MakeNode<SizeNode>(expand_value, /*dim=*/i);
size_nodes.push_back(size_node);
auto symint_node =
c10::make_intrusive<torch::lazy::SymIntNodeImpl>(size_node);
// This is not really a dynamic size per se but it is a symint that wraps
// around a SizeNode instead of a scalar.
// This is not a dynamic size from the XLA perspective but it is a symint that
// wraps around a SizeNode instead of a scalar.
dynamic_symints.push_back(symint_node->toSymInt());
}

@@ -115,18 +110,16 @@ TEST(SymintTest, TestDynamicSymints) {

std::vector<int64_t> upper_bound = si_element.GetUpperBounds();
EXPECT_EQ(upper_bound.size(), 3);
EXPECT_EQ(upper_bound, std::vector<int64_t>({10, 20, 30}));
EXPECT_EQ(upper_bound, std::vector<int64_t>({2, 3, 5}));

std::vector<bool> dynamic_dims = si_element.GetDynamicDims();
EXPECT_EQ(dynamic_dims.size(), 3);
EXPECT_EQ(dynamic_dims, std::vector<bool>({true, true, true}));

std::vector<torch::lazy::NodePtr> size_nodes = si_element.GetSizeNodes();
EXPECT_EQ(size_nodes.size(), 3);
// look up the SizeNode for dimension 0
EXPECT_TRUE(si_element.GetSizeNode(0) != nullptr);
EXPECT_TRUE(si_element.GetSizeNode(1) != nullptr);
EXPECT_TRUE(si_element.GetSizeNode(2) != nullptr);
std::vector<torch::lazy::NodePtr> si_element_size_nodes =
si_element.GetSizeNodes();
EXPECT_EQ(si_element_size_nodes.size(), 3);
EXPECT_EQ(si_element_size_nodes, size_nodes);
}

} // namespace cpp_test
79 changes: 31 additions & 48 deletions torch_xla/csrc/ops/dynamic_ir.cpp
@@ -13,7 +13,7 @@ namespace torch_xla {

SizeNode::SizeNode(torch::lazy::Value input, size_t dim)
: XlaNode(torch::lazy::OpKind{c10::Symbol::fromQualString("aten::size")},
{input}, xla::ShapeUtil::MakeShape(xla::S32, {}), 1,
{input}, xla::ShapeUtil::MakeShape(xla::S64, {}), 1,
torch::lazy::MHash(dim)),
dim_(dim){};

@@ -23,95 +23,78 @@ XlaOpVector SizeNode::Lower(LoweringContext* loctx) const {
}

int64_t SizeNode::getStaticValue() const {
return operand(0).node->shape(0).size(dim_);
}

bool SizeNode::isSymbolic() const {
auto symbolic_vec = operand(0).node->shape(0).is_symbolic();
if (!symbolic_vec.has_value()) {
return true;
}
return symbolic_vec->at(dim_);
// Not all IR nodes have a torch::lazy::Shape yet; use xla::Shape to unblock
// development.
return dynamic_cast<const XlaNode*>(operand(0).node)
->xla_shape(operand(0).index)
.dimensions(dim_);
}

std::string SizeNode::ToString() const { return "SizeNode"; }

SizeAdd::SizeAdd(torch::lazy::Value a, torch::lazy::Value b)
: XlaNode(torch::lazy::OpKind{c10::Symbol::fromQualString("aten::add")},
{a, b}, xla::ShapeUtil::MakeShape(xla::S32, {}), 1){};
{a, b}, xla::ShapeUtil::MakeShape(xla::S64, {}), 1) {
// SizeAdd can only be performed between two DimensionNodes
XLA_CHECK(DimCast(operand(0)));
XLA_CHECK(DimCast(operand(1)));
};

int64_t SizeAdd::getStaticValue() const {
return dynamic_cast<const torch::lazy::DimensionNode*>(operand(0).node)
->getStaticValue() +
dynamic_cast<const torch::lazy::DimensionNode*>(operand(1).node)
->getStaticValue();
}

bool SizeAdd::isSymbolic() const {
return DimCast(operand(0))->isSymbolic() || DimCast(operand(1))->isSymbolic();
return DimCast(operand(0))->getStaticValue() +
DimCast(operand(1))->getStaticValue();
}

std::string SizeAdd::ToString() const { return "SizeAdd"; }

XlaOpVector SizeAdd::Lower(LoweringContext* loctx) const {
auto input1 = loctx->GetOutputOp(operand(0));
auto input2 = loctx->GetOutputOp(operand(1));
return ReturnOp(
(xla::GetDimensionSize(input1, 0) + xla::GetDimensionSize(input2, 0)),
loctx);
return ReturnOp((input1 + input2), loctx);
}

SizeMul::SizeMul(torch::lazy::Value a, torch::lazy::Value b)
: XlaNode(torch::lazy::OpKind{c10::Symbol::fromQualString("aten::mul")},
{a, b}, xla::ShapeUtil::MakeShape(xla::S32, {}), 1){};
{a, b}, xla::ShapeUtil::MakeShape(xla::S64, {}), 1) {
// SizeMul can only be performed between two DimensionNodes
XLA_CHECK(DimCast(operand(0)));
XLA_CHECK(DimCast(operand(1)));
};

int64_t SizeMul::getStaticValue() const {
return dynamic_cast<const torch::lazy::DimensionNode*>(operand(0).node)
->getStaticValue() *
dynamic_cast<const torch::lazy::DimensionNode*>(operand(1).node)
->getStaticValue();
}

bool SizeMul::isSymbolic() const {
return DimCast(operand(0))->isSymbolic() || DimCast(operand(1))->isSymbolic();
return DimCast(operand(0))->getStaticValue() *
DimCast(operand(1))->getStaticValue();
}

std::string SizeMul::ToString() const { return "SizeMul"; }

XlaOpVector SizeMul::Lower(LoweringContext* loctx) const {
auto input1 = loctx->GetOutputOp(operand(0));
auto input2 = loctx->GetOutputOp(operand(1));
return ReturnOp(xla::Mul(xla::GetDimensionSize(input1, 0),
xla::GetDimensionSize(input2, 0)),
loctx);
return ReturnOp(input1 * input2, loctx);
}

SizeDiv::SizeDiv(torch::lazy::Value a, torch::lazy::Value b)
: XlaNode(torch::lazy::OpKind{c10::Symbol::fromQualString("aten::div")},
{a, b}, xla::ShapeUtil::MakeShape(xla::S32, {}), 1){};
{a, b}, xla::ShapeUtil::MakeShape(xla::S64, {}), 1) {
// SizeDiv can only be performed between two DimensionNodes
XLA_CHECK(DimCast(operand(0)));
XLA_CHECK(DimCast(operand(1)));
};

int64_t SizeDiv::getStaticValue() const {
XLA_CHECK(dynamic_cast<const torch::lazy::DimensionNode*>(operand(1).node)
->getStaticValue() != 0)
XLA_CHECK(DimCast(operand(1))->getStaticValue() != 0)
<< "Can't divide a dimension by zero";
return dynamic_cast<const torch::lazy::DimensionNode*>(operand(0).node)
->getStaticValue() /
dynamic_cast<const torch::lazy::DimensionNode*>(operand(1).node)
->getStaticValue();
}

bool SizeDiv::isSymbolic() const {
return DimCast(operand(0))->isSymbolic() || DimCast(operand(1))->isSymbolic();
return DimCast(operand(0))->getStaticValue() /
DimCast(operand(1))->getStaticValue();
}

std::string SizeDiv::ToString() const { return "SizeDiv"; }

XlaOpVector SizeDiv::Lower(LoweringContext* loctx) const {
auto input1 = loctx->GetOutputOp(operand(0));
auto input2 = loctx->GetOutputOp(operand(1));
return ReturnOp(xla::Div(xla::GetDimensionSize(input1, 0),
xla::GetDimensionSize(input2, 0)),
loctx);
return ReturnOp(xla::Div(input1, input2), loctx);
}

} // namespace torch_xla
8 changes: 4 additions & 4 deletions torch_xla/csrc/ops/dynamic_ir.h
@@ -39,7 +39,7 @@ class SizeNode : public XlaNode, public torch::lazy::DimensionNode {
public:
SizeNode(torch::lazy::Value input, size_t dim);
int64_t getStaticValue() const override;
bool isSymbolic() const override;
bool isSymbolic() const override { return true; }
std::string ToString() const override;
virtual XlaOpVector Lower(LoweringContext* loctx) const override;

@@ -51,7 +51,7 @@ class SizeAdd : public XlaNode, public torch::lazy::DimensionNode {
public:
SizeAdd(torch::lazy::Value a, torch::lazy::Value b);
int64_t getStaticValue() const override;
bool isSymbolic() const override;
bool isSymbolic() const override { return true; }
std::string ToString() const override;
virtual XlaOpVector Lower(LoweringContext* loctx) const override;
};
@@ -60,7 +60,7 @@ class SizeMul : public XlaNode, public torch::lazy::DimensionNode {
public:
SizeMul(torch::lazy::Value a, torch::lazy::Value b);
int64_t getStaticValue() const override;
bool isSymbolic() const override;
bool isSymbolic() const override { return true; }
std::string ToString() const override;
virtual XlaOpVector Lower(LoweringContext* loctx) const override;
};
@@ -69,7 +69,7 @@ class SizeDiv : public XlaNode, public torch::lazy::DimensionNode {
public:
SizeDiv(torch::lazy::Value a, torch::lazy::Value b);
int64_t getStaticValue() const override;
bool isSymbolic() const override;
bool isSymbolic() const override { return true; }
std::string ToString() const override;
virtual XlaOpVector Lower(LoweringContext* loctx) const override;
};