Skip to content

Commit

Permalink
Adds multiply function to dual_quaternion module.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 380158815
  • Loading branch information
G4G authored and Copybara-Service committed Jun 18, 2021
1 parent d4f2f6a commit f4f04cb
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 6 deletions.
46 changes: 46 additions & 0 deletions tensorflow_graphics/geometry/transformation/dual_quaternion.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,5 +78,51 @@ def conjugate(dual_quaternion: type_alias.TensorLike,
axis=-1)


def multiply(dual_quaternion1: type_alias.TensorLike,
dual_quaternion2: type_alias.TensorLike,
name: str = "dual_quaternion_multiply"):
"""Multiplies two dual quaternions.
Note:
In the following, A1 to An are optional batch dimensions.
Args:
dual_quaternion1: A TensorLike of shape `[A1, ..., An, 8]`, where the last
dimension represents a dual quaternion.
dual_quaternion2: A TensorLike of shape `[A1, ..., An, 8]`, where the last
dimension represents a dual quaternion.
name: A name for this op that defaults to "dual_quaternion_multiply".
Returns:
A tensor of shape `[A1, ..., An, 8]` representing dual quaternions.
"""
with tf.name_scope(name):
dual_quaternion1 = tf.convert_to_tensor(value=dual_quaternion1)
dual_quaternion2 = tf.convert_to_tensor(value=dual_quaternion2)

shape.check_static(
tensor=dual_quaternion1,
tensor_name="dual_quaternion1",
has_dim_equals=(-1, 8))
shape.check_static(
tensor=dual_quaternion2,
tensor_name="dual_quaternion2",
has_dim_equals=(-1, 8))

dual_quaternion1_real, dual_quaternion1_dual = tf.split(
dual_quaternion1, (4, 4), axis=-1)
dual_quaternion2_real, dual_quaternion2_dual = tf.split(
dual_quaternion2, (4, 4), axis=-1)

dual_quaternion_output_real = quaternion.multiply(dual_quaternion1_real,
dual_quaternion2_real)
dual_quaternion_output_dual = (
quaternion.multiply(dual_quaternion1_real, dual_quaternion2_dual) +
quaternion.multiply(dual_quaternion1_dual, dual_quaternion2_real))

return tf.concat((dual_quaternion_output_real, dual_quaternion_output_dual),
axis=-1)


# API contains all public functions and classes.
__all__ = export_api.get_functions_and_classes()
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from absl.testing import flagsaver
from absl.testing import parameterized
import tensorflow.compat.v2 as tf
import tensorflow as tf

from tensorflow_graphics.geometry.transformation import dual_quaternion
from tensorflow_graphics.geometry.transformation.tests import test_helpers
Expand All @@ -29,30 +29,25 @@ class DualQuaternionTest(test_case.TestCase):
((None, 8),),
)
def test_conjugate_exception_not_raised(self, *shape):
"""Tests that the shape exceptions of conjugate are not raised."""
self.assert_exception_is_not_raised(dual_quaternion.conjugate, shape)

@parameterized.parameters(
("must have exactly 8 dimensions", (3,)),)
def test_conjugate_exception_raised(self, error_msg, *shape):
"""Tests that the shape exceptions are raised."""
self.assert_exception_is_raised(dual_quaternion.conjugate, error_msg, shape)

@flagsaver.flagsaver(tfg_add_asserts_to_graph=False)
def test_conjugate_jacobian_preset(self):
"""Tests the Jacobian of the conjugate function."""
x_init = test_helpers.generate_preset_test_dual_quaternions()
self.assert_jacobian_is_correct_fn(dual_quaternion.conjugate, [x_init])

@flagsaver.flagsaver(tfg_add_asserts_to_graph=False)
def test_conjugate_jacobian_random(self):
"""Tests the Jacobian of the conjugate function."""
x_init = test_helpers.generate_random_test_dual_quaternions()
self.assert_jacobian_is_correct_fn(dual_quaternion.conjugate, [x_init])

@flagsaver.flagsaver(tfg_add_asserts_to_graph=False)
def test_conjugate_preset(self):
"""Tests if the conjugate function is providing correct results."""
x_init = test_helpers.generate_preset_test_dual_quaternions()
x = tf.convert_to_tensor(value=x_init)
y = tf.convert_to_tensor(value=x_init)
Expand All @@ -68,3 +63,37 @@ def test_conjugate_preset(self):

self.assertAllEqual(x_real, y_real)
self.assertAllEqual(x_dual, y_dual)

@parameterized.parameters(
((8,), (8,)),
((None, 8), (None, 8)),
)
def test_multiply_exception_not_raised(self, *shapes):
self.assert_exception_is_not_raised(dual_quaternion.multiply, shapes)

@parameterized.parameters(
("must have exactly 8 dimensions", (5,), (6,)),
("must have exactly 8 dimensions", (7,), (8,)),
)
def test_multiply_exception_raised(self, error_msg, *shape):
self.assert_exception_is_raised(dual_quaternion.multiply, error_msg, shape)

@flagsaver.flagsaver(tfg_add_asserts_to_graph=False)
def test_multiply_jacobian_preset(self):
x_1_init = test_helpers.generate_preset_test_dual_quaternions()
x_2_init = test_helpers.generate_preset_test_dual_quaternions()

self.assert_jacobian_is_correct_fn(dual_quaternion.multiply,
[x_1_init, x_2_init])

@flagsaver.flagsaver(tfg_add_asserts_to_graph=False)
def test_multiply_jacobian_random(self):
x_1_init = test_helpers.generate_random_test_dual_quaternions()
x_2_init = test_helpers.generate_random_test_dual_quaternions()

self.assert_jacobian_is_correct_fn(dual_quaternion.multiply,
[x_1_init, x_2_init])


if __name__ == "__main__":
test_case.main()

0 comments on commit f4f04cb

Please sign in to comment.