From 2538486bacc74f43a0697fc061d0ade2807a4ccb Mon Sep 17 00:00:00 2001 From: Cengiz Oztireli Date: Fri, 9 Jul 2021 07:00:26 -0700 Subject: [PATCH] Adds the to_rotation_translation function to the dual_quaternion module. PiperOrigin-RevId: 383838512 --- .../transformation/dual_quaternion.py | 34 +++++++++++++++- .../tests/dual_quaternion_test.py | 39 +++++++++++++++++++ 2 files changed, 72 insertions(+), 1 deletion(-) diff --git a/tensorflow_graphics/geometry/transformation/dual_quaternion.py b/tensorflow_graphics/geometry/transformation/dual_quaternion.py index fa81249fc..5deb0e740 100644 --- a/tensorflow_graphics/geometry/transformation/dual_quaternion.py +++ b/tensorflow_graphics/geometry/transformation/dual_quaternion.py @@ -33,6 +33,8 @@ from __future__ import division from __future__ import print_function +from typing import Tuple + import tensorflow as tf from tensorflow_graphics.geometry.transformation import quaternion from tensorflow_graphics.math import vector @@ -271,7 +273,7 @@ def from_rotation_translation( Returns: A `[A1, ..., An, 8]`-tensor, where the last dimension represents a - normalized quaternion. + normalized dual quaternion. Raises: ValueError: If the shape of `rotation_matrix` is not supported. @@ -304,5 +306,35 @@ def from_rotation_translation( return tf.concat((quaternion_rotation, dual_quaternion_dual_part), axis=-1) +def to_rotation_translation( + dual_quaternion: type_alias.TensorLike, + name: str = "dual_quaternion_to_rot_trans") -> Tuple[tf.Tensor, tf.Tensor]: + """Converts a dual quaternion into a quaternion for rotation and translation. + + Args: + dual_quaternion: A `[A1, ..., An, 8]`-tensor, where the last dimension + represents a qual quaternion. + name: A name for this op that defaults to "dual_quaternion_to_rot_trans". + + Returns: + A `[A1, ..., An, 7]`-tensor, where the last dimension represents a + normalized quaternion and a translation vector, in that order. + """ + with tf.name_scope(name): + dual_quaternion = tf.convert_to_tensor(value=dual_quaternion) + + shape.check_static( + tensor=dual_quaternion, + tensor_name="dual_quaternion", + has_dim_equals=(-1, 8)) + + rotation = dual_quaternion[..., 0:4] + translation = 2 * quaternion.multiply( + dual_quaternion[..., 4:8], quaternion.inverse(rotation)) + translation = translation[..., 0:3] + + return rotation, translation + + # API contains all public functions and classes. __all__ = export_api.get_functions_and_classes() diff --git a/tensorflow_graphics/geometry/transformation/tests/dual_quaternion_test.py b/tensorflow_graphics/geometry/transformation/tests/dual_quaternion_test.py index 2bd9bdf9b..5fcf44344 100644 --- a/tensorflow_graphics/geometry/transformation/tests/dual_quaternion_test.py +++ b/tensorflow_graphics/geometry/transformation/tests/dual_quaternion_test.py @@ -15,6 +15,7 @@ from absl.testing import flagsaver from absl.testing import parameterized + import numpy as np import tensorflow as tf @@ -265,5 +266,43 @@ def test_from_rotation_matrix_random(self): self.assertAllClose(rotation_gt, rotation) self.assertAllClose(translation_gt, translation) + @flagsaver.flagsaver(tfg_add_asserts_to_graph=False) + def test_to_rotation_translation_jacobian_preset(self): + pre_dual_quaternion = test_helpers.generate_preset_test_dual_quaternions() + + def to_rotation(input_dual_quaternion): + rotation, _ = dual_quaternion.to_rotation_translation( + input_dual_quaternion) + return rotation + + self.assert_jacobian_is_finite_fn(to_rotation, [pre_dual_quaternion]) + + @flagsaver.flagsaver(tfg_add_asserts_to_graph=False) + def test_to_rotation_translation_jacobian_random(self): + rnd_dual_quaternion = test_helpers.generate_random_test_dual_quaternions() + + def to_translation(input_dual_quaternion): + _, translation = dual_quaternion.to_rotation_translation( + input_dual_quaternion) + return translation + + self.assert_jacobian_is_finite_fn(to_translation, [rnd_dual_quaternion]) + + def test_to_rotation_matrix_random(self): + (euler_angles_gt, translation_gt + ) = test_helpers.generate_random_test_euler_angles_translations() + rotation_gt = rotation_matrix_3d.from_quaternion( + quaternion.from_euler(euler_angles_gt)) + + dual_quaternion_output = dual_quaternion.from_rotation_translation( + rotation_gt, translation_gt) + rotation, translation = dual_quaternion.to_rotation_translation( + dual_quaternion_output) + + self.assertAllClose(rotation_gt, + rotation_matrix_3d.from_quaternion(rotation)) + self.assertAllClose(translation_gt, translation) + + if __name__ == "__main__": test_case.main()