diff --git a/tensorflow_graphics/geometry/transformation/dual_quaternion.py b/tensorflow_graphics/geometry/transformation/dual_quaternion.py index fc01d8b96..3c10d75d2 100644 --- a/tensorflow_graphics/geometry/transformation/dual_quaternion.py +++ b/tensorflow_graphics/geometry/transformation/dual_quaternion.py @@ -78,5 +78,51 @@ def conjugate(dual_quaternion: type_alias.TensorLike, axis=-1) +def multiply(dual_quaternion1: type_alias.TensorLike, + dual_quaternion2: type_alias.TensorLike, + name: str = "dual_quaternion_multiply"): + """Multiplies two dual quaternions. + + Note: + In the following, A1 to An are optional batch dimensions. + + Args: + dual_quaternion1: A TensorLike of shape `[A1, ..., An, 8]`, where the last + dimension represents a dual quaternion. + dual_quaternion2: A TensorLike of shape `[A1, ..., An, 8]`, where the last + dimension represents a dual quaternion. + name: A name for this op that defaults to "dual_quaternion_multiply". + + Returns: + A tensor of shape `[A1, ..., An, 8]` representing dual quaternions. + """ + with tf.name_scope(name): + dual_quaternion1 = tf.convert_to_tensor(value=dual_quaternion1) + dual_quaternion2 = tf.convert_to_tensor(value=dual_quaternion2) + + shape.check_static( + tensor=dual_quaternion1, + tensor_name="dual_quaternion1", + has_dim_equals=(-1, 8)) + shape.check_static( + tensor=dual_quaternion2, + tensor_name="dual_quaternion2", + has_dim_equals=(-1, 8)) + + dual_quaternion1_real, dual_quaternion1_dual = tf.split( + dual_quaternion1, (4, 4), axis=-1) + dual_quaternion2_real, dual_quaternion2_dual = tf.split( + dual_quaternion2, (4, 4), axis=-1) + + dual_quaternion_output_real = quaternion.multiply(dual_quaternion1_real, + dual_quaternion2_real) + dual_quaternion_output_dual = ( + quaternion.multiply(dual_quaternion1_real, dual_quaternion2_dual) + + quaternion.multiply(dual_quaternion1_dual, dual_quaternion2_real)) + + return tf.concat((dual_quaternion_output_real, dual_quaternion_output_dual), + axis=-1) + + # API contains all public functions and classes. __all__ = export_api.get_functions_and_classes() diff --git a/tensorflow_graphics/geometry/transformation/tests/dual_quaternion_test.py b/tensorflow_graphics/geometry/transformation/tests/dual_quaternion_test.py index c794228e8..9933f9cd5 100644 --- a/tensorflow_graphics/geometry/transformation/tests/dual_quaternion_test.py +++ b/tensorflow_graphics/geometry/transformation/tests/dual_quaternion_test.py @@ -15,7 +15,7 @@ from absl.testing import flagsaver from absl.testing import parameterized -import tensorflow.compat.v2 as tf +import tensorflow as tf from tensorflow_graphics.geometry.transformation import dual_quaternion from tensorflow_graphics.geometry.transformation.tests import test_helpers @@ -29,30 +29,25 @@ class DualQuaternionTest(test_case.TestCase): ((None, 8),), ) def test_conjugate_exception_not_raised(self, *shape): - """Tests that the shape exceptions of conjugate are not raised.""" self.assert_exception_is_not_raised(dual_quaternion.conjugate, shape) @parameterized.parameters( ("must have exactly 8 dimensions", (3,)),) def test_conjugate_exception_raised(self, error_msg, *shape): - """Tests that the shape exceptions are raised.""" self.assert_exception_is_raised(dual_quaternion.conjugate, error_msg, shape) @flagsaver.flagsaver(tfg_add_asserts_to_graph=False) def test_conjugate_jacobian_preset(self): - """Tests the Jacobian of the conjugate function.""" x_init = test_helpers.generate_preset_test_dual_quaternions() self.assert_jacobian_is_correct_fn(dual_quaternion.conjugate, [x_init]) @flagsaver.flagsaver(tfg_add_asserts_to_graph=False) def test_conjugate_jacobian_random(self): - """Tests the Jacobian of the conjugate function.""" x_init = test_helpers.generate_random_test_dual_quaternions() self.assert_jacobian_is_correct_fn(dual_quaternion.conjugate, [x_init]) @flagsaver.flagsaver(tfg_add_asserts_to_graph=False) def test_conjugate_preset(self): - """Tests if the conjugate function is providing correct results.""" x_init = test_helpers.generate_preset_test_dual_quaternions() x = tf.convert_to_tensor(value=x_init) y = tf.convert_to_tensor(value=x_init) @@ -68,3 +63,37 @@ def test_conjugate_preset(self): self.assertAllEqual(x_real, y_real) self.assertAllEqual(x_dual, y_dual) + + @parameterized.parameters( + ((8,), (8,)), + ((None, 8), (None, 8)), + ) + def test_multiply_exception_not_raised(self, *shapes): + self.assert_exception_is_not_raised(dual_quaternion.multiply, shapes) + + @parameterized.parameters( + ("must have exactly 8 dimensions", (5,), (6,)), + ("must have exactly 8 dimensions", (7,), (8,)), + ) + def test_multiply_exception_raised(self, error_msg, *shape): + self.assert_exception_is_raised(dual_quaternion.multiply, error_msg, shape) + + @flagsaver.flagsaver(tfg_add_asserts_to_graph=False) + def test_multiply_jacobian_preset(self): + x_1_init = test_helpers.generate_preset_test_dual_quaternions() + x_2_init = test_helpers.generate_preset_test_dual_quaternions() + + self.assert_jacobian_is_correct_fn(dual_quaternion.multiply, + [x_1_init, x_2_init]) + + @flagsaver.flagsaver(tfg_add_asserts_to_graph=False) + def test_multiply_jacobian_random(self): + x_1_init = test_helpers.generate_random_test_dual_quaternions() + x_2_init = test_helpers.generate_random_test_dual_quaternions() + + self.assert_jacobian_is_correct_fn(dual_quaternion.multiply, + [x_1_init, x_2_init]) + + +if __name__ == "__main__": + test_case.main()