@@ -250,6 +250,7 @@ TEST_F(IrTest, TestSizeMulNode) {
250250 std::dynamic_pointer_cast<torch::lazy::DimensionNode>(size_node_mul);
251251
252252 EXPECT_EQ (dim_node_mul->getStaticValue (), 12 );
253+ EXPECT_EQ (dim_node_mul->getDynamicValue (), 12 );
253254
254255 ForEachDevice ([&](const torch::lazy::BackendDevice& device) {
255256 // Lower the SizeAddNode and execute the GetDimensionSize.
@@ -258,6 +259,31 @@ TEST_F(IrTest, TestSizeMulNode) {
258259 });
259260}
260261
262+ TEST_F (IrTest, TestSizeMulNodeDynamic) {
263+ int64_t num_non_zero_element = 1 ;
264+ int64_t num_row = 10 ;
265+ int64_t num_col = 10 ;
266+ torch::lazy::NodePtr nonzero_node =
267+ CreateNonZeroNode2d (num_non_zero_element, num_row, num_col);
268+ torch::lazy::Value node_with_dynamism = torch::lazy::Value (nonzero_node, 0 );
269+
270+ // static value = 100, dynamic value = 1
271+ torch::lazy::NodePtr size_node_nonzero_0 =
272+ torch::lazy::MakeNode<SizeNode>(node_with_dynamism, 0 );
273+ // static value = 2, dynamic value = 2
274+ torch::lazy::NodePtr size_node_nonzero_1 =
275+ torch::lazy::MakeNode<SizeNode>(node_with_dynamism, 1 );
276+
277+ torch::lazy::NodePtr node_mul = torch::lazy::MakeNode<SizeMul>(
278+ torch::lazy::Value (size_node_nonzero_0, 0 ),
279+ torch::lazy::Value (size_node_nonzero_1, 0 ));
280+
281+ std::shared_ptr<torch::lazy::DimensionNode> dim_node_mul =
282+ std::dynamic_pointer_cast<torch::lazy::DimensionNode>(node_mul);
283+ EXPECT_EQ (dim_node_mul->getStaticValue (), 200 );
284+ EXPECT_EQ (dim_node_mul->getDynamicValue (), 2 );
285+ }
286+
261287TEST_F (IrTest, TestSizeDivNode) {
262288 torch::lazy::NodePtr scalar_node =
263289 ScalarOp (1.0 , xla::ShapeUtil::MakeShape (xla::F32, {12 , 5 }));
@@ -271,6 +297,7 @@ TEST_F(IrTest, TestSizeDivNode) {
271297 std::dynamic_pointer_cast<torch::lazy::DimensionNode>(size_node_div);
272298
273299 EXPECT_EQ (dim_node_div->getStaticValue (), 2 );
300+ EXPECT_EQ (dim_node_div->getDynamicValue (), 2 );
274301
275302 ForEachDevice ([&](const torch::lazy::BackendDevice& device) {
276303 // Lower the SizeAddNode and execute the GetDimensionSize.
@@ -279,5 +306,30 @@ TEST_F(IrTest, TestSizeDivNode) {
279306 });
280307}
281308
309+ TEST_F (IrTest, TestSizeDivNodeDynamic) {
310+ int64_t num_non_zero_element = 1 ;
311+ int64_t num_row = 10 ;
312+ int64_t num_col = 10 ;
313+ torch::lazy::NodePtr nonzero_node =
314+ CreateNonZeroNode2d (num_non_zero_element, num_row, num_col);
315+ torch::lazy::Value node_with_dynamism = torch::lazy::Value (nonzero_node, 0 );
316+
317+ // static value = 100, dynamic value = 1
318+ torch::lazy::NodePtr size_node_nonzero_0 =
319+ torch::lazy::MakeNode<SizeNode>(node_with_dynamism, 0 );
320+ // static value = 2, dynamic value = 2
321+ torch::lazy::NodePtr size_node_nonzero_1 =
322+ torch::lazy::MakeNode<SizeNode>(node_with_dynamism, 1 );
323+
324+ torch::lazy::NodePtr node_div = torch::lazy::MakeNode<SizeDiv>(
325+ torch::lazy::Value (size_node_nonzero_0, 0 ),
326+ torch::lazy::Value (size_node_nonzero_1, 0 ));
327+
328+ std::shared_ptr<torch::lazy::DimensionNode> dim_node_div =
329+ std::dynamic_pointer_cast<torch::lazy::DimensionNode>(node_div);
330+ EXPECT_EQ (dim_node_div->getStaticValue (), 50 );
331+ EXPECT_EQ (dim_node_div->getDynamicValue (), 0 );
332+ }
333+
282334} // namespace cpp_test
283335} // namespace torch_xla
0 commit comments