Skip to content

Commit 23153e2

Browse files
authored
Add getDynamicValue for SizeMul and SizeDiv nodes. (#4042)
1 parent 39445aa commit 23153e2

File tree

3 files changed

+70
-0
lines changed

3 files changed

+70
-0
lines changed

test/cpp/test_ir.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ TEST_F(IrTest, TestSizeMulNode) {
250250
std::dynamic_pointer_cast<torch::lazy::DimensionNode>(size_node_mul);
251251

252252
EXPECT_EQ(dim_node_mul->getStaticValue(), 12);
253+
EXPECT_EQ(dim_node_mul->getDynamicValue(), 12);
253254

254255
ForEachDevice([&](const torch::lazy::BackendDevice& device) {
255256
// Lower the SizeAddNode and execute the GetDimensionSize.
@@ -258,6 +259,31 @@ TEST_F(IrTest, TestSizeMulNode) {
258259
});
259260
}
260261

262+
TEST_F(IrTest, TestSizeMulNodeDynamic) {
  // Build a nonzero node over a 10x10 input holding a single non-zero
  // element, so dim 0 is dynamic: static upper bound 100, actual value 1.
  int64_t nnz = 1;
  int64_t rows = 10;
  int64_t cols = 10;
  torch::lazy::NodePtr nonzero_node = CreateNonZeroNode2d(nnz, rows, cols);
  torch::lazy::Value dynamic_input = torch::lazy::Value(nonzero_node, 0);

  // Dim 0: static value 100, dynamic value 1.
  torch::lazy::NodePtr size0 =
      torch::lazy::MakeNode<SizeNode>(dynamic_input, 0);
  // Dim 1: static value 2, dynamic value 2.
  torch::lazy::NodePtr size1 =
      torch::lazy::MakeNode<SizeNode>(dynamic_input, 1);

  // SizeMul multiplies static and dynamic values independently.
  torch::lazy::NodePtr product_node = torch::lazy::MakeNode<SizeMul>(
      torch::lazy::Value(size0, 0), torch::lazy::Value(size1, 0));

  std::shared_ptr<torch::lazy::DimensionNode> product_dim =
      std::dynamic_pointer_cast<torch::lazy::DimensionNode>(product_node);
  EXPECT_EQ(product_dim->getStaticValue(), 200);   // 100 * 2
  EXPECT_EQ(product_dim->getDynamicValue(), 2);    // 1 * 2
}
286+
261287
TEST_F(IrTest, TestSizeDivNode) {
262288
torch::lazy::NodePtr scalar_node =
263289
ScalarOp(1.0, xla::ShapeUtil::MakeShape(xla::F32, {12, 5}));
@@ -271,6 +297,7 @@ TEST_F(IrTest, TestSizeDivNode) {
271297
std::dynamic_pointer_cast<torch::lazy::DimensionNode>(size_node_div);
272298

273299
EXPECT_EQ(dim_node_div->getStaticValue(), 2);
300+
EXPECT_EQ(dim_node_div->getDynamicValue(), 2);
274301

275302
ForEachDevice([&](const torch::lazy::BackendDevice& device) {
276303
// Lower the SizeAddNode and execute the GetDimensionSize.
@@ -279,5 +306,30 @@ TEST_F(IrTest, TestSizeDivNode) {
279306
});
280307
}
281308

309+
TEST_F(IrTest, TestSizeDivNodeDynamic) {
  // A 10x10 input with exactly one non-zero element: nonzero's dim 0 is
  // dynamic (static upper bound 100, runtime value 1), dim 1 is fixed at 2.
  int64_t nnz = 1;
  int64_t rows = 10;
  int64_t cols = 10;
  torch::lazy::NodePtr nonzero_node = CreateNonZeroNode2d(nnz, rows, cols);
  torch::lazy::Value dynamic_input = torch::lazy::Value(nonzero_node, 0);

  // Dim 0: static value 100, dynamic value 1.
  torch::lazy::NodePtr numerator =
      torch::lazy::MakeNode<SizeNode>(dynamic_input, 0);
  // Dim 1: static value 2, dynamic value 2.
  torch::lazy::NodePtr denominator =
      torch::lazy::MakeNode<SizeNode>(dynamic_input, 1);

  // SizeDiv performs integer division on both static and dynamic values.
  torch::lazy::NodePtr quotient_node = torch::lazy::MakeNode<SizeDiv>(
      torch::lazy::Value(numerator, 0), torch::lazy::Value(denominator, 0));

  std::shared_ptr<torch::lazy::DimensionNode> quotient_dim =
      std::dynamic_pointer_cast<torch::lazy::DimensionNode>(quotient_node);
  EXPECT_EQ(quotient_dim->getStaticValue(), 50);   // 100 / 2
  EXPECT_EQ(quotient_dim->getDynamicValue(), 0);   // 1 / 2 truncates to 0
}
333+
282334
} // namespace cpp_test
283335
} // namespace torch_xla

torch_xla/csrc/ops/dynamic_ir.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,14 @@ SizeMul::SizeMul(torch::lazy::Value a, torch::lazy::Value b)
9898
upper_bound_ = dim_node_0->getStaticValue() * dim_node_1->getStaticValue();
9999
};
100100

101+
int64_t SizeMul::getDynamicValue() const {
102+
const torch::lazy::DimensionNode* dim_node_0 = DimCast(operand(0));
103+
const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
104+
XLA_CHECK(dim_node_0);
105+
XLA_CHECK(dim_node_1);
106+
return dim_node_0->getDynamicValue() * dim_node_1->getDynamicValue();
107+
}
108+
101109
std::string SizeMul::ToString() const { return "SizeMul"; }
102110

103111
XlaOpVector SizeMul::Lower(LoweringContext* loctx) const {
@@ -121,6 +129,14 @@ SizeDiv::SizeDiv(torch::lazy::Value a, torch::lazy::Value b)
121129
upper_bound_ = dim_node_0->getStaticValue() / dim_node_1->getStaticValue();
122130
};
123131

132+
int64_t SizeDiv::getDynamicValue() const {
133+
const torch::lazy::DimensionNode* dim_node_0 = DimCast(operand(0));
134+
const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
135+
XLA_CHECK(dim_node_0);
136+
XLA_CHECK(dim_node_1);
137+
return dim_node_0->getDynamicValue() / dim_node_1->getDynamicValue();
138+
}
139+
124140
std::string SizeDiv::ToString() const { return "SizeDiv"; }
125141

126142
XlaOpVector SizeDiv::Lower(LoweringContext* loctx) const {

torch_xla/csrc/ops/dynamic_ir.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ class SizeAdd : public XlaNode, public torch::lazy::DimensionNode {
6868
class SizeMul : public XlaNode, public torch::lazy::DimensionNode {
6969
public:
7070
SizeMul(torch::lazy::Value a, torch::lazy::Value b);
71+
int64_t getDynamicValue() const override;
7172
int64_t getStaticValue() const override { return upper_bound_; }
7273
bool isSymbolic() const override { return true; }
7374
std::string ToString() const override;
@@ -80,6 +81,7 @@ class SizeMul : public XlaNode, public torch::lazy::DimensionNode {
8081
class SizeDiv : public XlaNode, public torch::lazy::DimensionNode {
8182
public:
8283
SizeDiv(torch::lazy::Value a, torch::lazy::Value b);
84+
int64_t getDynamicValue() const override;
8385
int64_t getStaticValue() const override { return upper_bound_; }
8486
bool isSymbolic() const override { return true; }
8587
std::string ToString() const override;

0 commit comments

Comments
 (0)