@@ -831,6 +831,66 @@ TEST_F(AtenXlaTensorTest, TestSumInDimsKeep) {
831831 });
832832}
833833
834+ TEST_F (AtenXlaTensorTest, TestNorm) {
835+ at::Tensor a = at::rand ({4 , 3 , 4 }, at::TensorOptions (at::kFloat ));
836+ at::Tensor b = at::norm (a);
837+ ForEachDevice ([&](const Device& device) {
838+ at::Tensor xla_a = bridge::CreateXlaTensor (a, device);
839+ at::Tensor xla_b = at::norm (xla_a);
840+ AllClose (b, xla_b);
841+ });
842+ }
843+
844+ TEST_F (AtenXlaTensorTest, TestNormInDim) {
845+ at::Tensor a = at::rand ({4 , 3 , 4 }, at::TensorOptions (at::kFloat ));
846+ at::Tensor b = at::norm (a, 2 , {1 }, /* keepdim=*/ false );
847+ ForEachDevice ([&](const Device& device) {
848+ at::Tensor xla_a = bridge::CreateXlaTensor (a, device);
849+ at::Tensor xla_b = at::norm (xla_a, 2 , {1 }, /* keepdim=*/ false );
850+ AllClose (b, xla_b);
851+ });
852+ }
853+
854+ TEST_F (AtenXlaTensorTest, TestNormInDims) {
855+ at::Tensor a = at::rand ({4 , 3 , 4 }, at::TensorOptions (at::kFloat ));
856+ at::Tensor b = at::norm (a, 2 , {1 , 2 }, /* keepdim=*/ false );
857+ ForEachDevice ([&](const Device& device) {
858+ at::Tensor xla_a = bridge::CreateXlaTensor (a, device);
859+ at::Tensor xla_b = at::norm (xla_a, 2 , {1 , 2 }, /* keepdim=*/ false );
860+ AllClose (b, xla_b);
861+ });
862+ }
863+
864+ TEST_F (AtenXlaTensorTest, TestNormInDimsKeep) {
865+ at::Tensor a = at::rand ({4 , 3 , 4 }, at::TensorOptions (at::kFloat ));
866+ at::Tensor b = at::norm (a, 2 , {1 , 2 }, /* keepdim=*/ true );
867+ ForEachDevice ([&](const Device& device) {
868+ at::Tensor xla_a = bridge::CreateXlaTensor (a, device);
869+ at::Tensor xla_b = at::norm (xla_a, 2 , {1 , 2 }, /* keepdim=*/ true );
870+ AllClose (b, xla_b);
871+ });
872+ }
873+
874+ TEST_F (AtenXlaTensorTest, TestNormGeneral) {
875+ at::Tensor a = at::rand ({4 , 3 , 4 }, at::TensorOptions (at::kFloat ));
876+ at::Tensor b = at::norm (a, 3.5 );
877+ ForEachDevice ([&](const Device& device) {
878+ at::Tensor xla_a = bridge::CreateXlaTensor (a, device);
879+ at::Tensor xla_b = at::norm (xla_a, 3.5 );
880+ AllClose (b, xla_b);
881+ });
882+ }
883+
884+ TEST_F (AtenXlaTensorTest, TestNormNuclear) {
885+ at::Tensor a = at::rand ({4 , 3 , 4 }, at::TensorOptions (at::kFloat ));
886+ at::Tensor b = at::norm (a, 1 );
887+ ForEachDevice ([&](const Device& device) {
888+ at::Tensor xla_a = bridge::CreateXlaTensor (a, device);
889+ at::Tensor xla_b = at::norm (xla_a, 1 );
890+ AllClose (b, xla_b);
891+ });
892+ }
893+
834894TEST_F (AtenXlaTensorTest, TestProd) {
835895 at::Tensor a = at::rand ({4 , 3 , 4 }, at::TensorOptions (at::kFloat ));
836896 at::Tensor b = at::prod (a);
0 commit comments