Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion tensorflow_graphics/geometry/transformation/dual_quaternion.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
from __future__ import division
from __future__ import print_function

from typing import Tuple

import tensorflow as tf
from tensorflow_graphics.geometry.transformation import quaternion
from tensorflow_graphics.math import vector
Expand Down Expand Up @@ -271,7 +273,7 @@ def from_rotation_translation(

Returns:
A `[A1, ..., An, 8]`-tensor, where the last dimension represents a
normalized quaternion.
normalized dual quaternion.

Raises:
ValueError: If the shape of `rotation_matrix` is not supported.
Expand Down Expand Up @@ -304,5 +306,35 @@ def from_rotation_translation(
return tf.concat((quaternion_rotation, dual_quaternion_dual_part), axis=-1)


def to_rotation_translation(
dual_quaternion: type_alias.TensorLike,
name: str = "dual_quaternion_to_rot_trans") -> Tuple[tf.Tensor, tf.Tensor]:
"""Converts a dual quaternion into a quaternion for rotation and translation.

Args:
dual_quaternion: A `[A1, ..., An, 8]`-tensor, where the last dimension
represents a qual quaternion.
name: A name for this op that defaults to "dual_quaternion_to_rot_trans".

Returns:
A `[A1, ..., An, 7]`-tensor, where the last dimension represents a
normalized quaternion and a translation vector, in that order.
"""
with tf.name_scope(name):
dual_quaternion = tf.convert_to_tensor(value=dual_quaternion)

shape.check_static(
tensor=dual_quaternion,
tensor_name="dual_quaternion",
has_dim_equals=(-1, 8))

rotation = dual_quaternion[..., 0:4]
translation = 2 * quaternion.multiply(
dual_quaternion[..., 4:8], quaternion.inverse(rotation))
translation = translation[..., 0:3]

return rotation, translation


# API contains all public functions and classes.
__all__ = export_api.get_functions_and_classes()
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from absl.testing import flagsaver
from absl.testing import parameterized

import numpy as np
import tensorflow as tf

Expand Down Expand Up @@ -265,5 +266,43 @@ def test_from_rotation_matrix_random(self):
self.assertAllClose(rotation_gt, rotation)
self.assertAllClose(translation_gt, translation)

@flagsaver.flagsaver(tfg_add_asserts_to_graph=False)
def test_to_rotation_translation_jacobian_preset(self):
pre_dual_quaternion = test_helpers.generate_preset_test_dual_quaternions()

def to_rotation(input_dual_quaternion):
rotation, _ = dual_quaternion.to_rotation_translation(
input_dual_quaternion)
return rotation

self.assert_jacobian_is_finite_fn(to_rotation, [pre_dual_quaternion])

@flagsaver.flagsaver(tfg_add_asserts_to_graph=False)
def test_to_rotation_translation_jacobian_random(self):
rnd_dual_quaternion = test_helpers.generate_random_test_dual_quaternions()

def to_translation(input_dual_quaternion):
_, translation = dual_quaternion.to_rotation_translation(
input_dual_quaternion)
return translation

self.assert_jacobian_is_finite_fn(to_translation, [rnd_dual_quaternion])

def test_to_rotation_matrix_random(self):
(euler_angles_gt, translation_gt
) = test_helpers.generate_random_test_euler_angles_translations()
rotation_gt = rotation_matrix_3d.from_quaternion(
quaternion.from_euler(euler_angles_gt))

dual_quaternion_output = dual_quaternion.from_rotation_translation(
rotation_gt, translation_gt)
rotation, translation = dual_quaternion.to_rotation_translation(
dual_quaternion_output)

self.assertAllClose(rotation_gt,
rotation_matrix_3d.from_quaternion(rotation))
self.assertAllClose(translation_gt, translation)


if __name__ == "__main__":
test_case.main()