Skip to content

Commit 9957243

Browse files
G4Gcopybara-github
authored andcommitted
Adds multiply function to dual_quaternion module.
PiperOrigin-RevId: 379060199
1 parent d4f2f6a commit 9957243

File tree

2 files changed

+81
-6
lines changed

2 files changed

+81
-6
lines changed

tensorflow_graphics/geometry/transformation/dual_quaternion.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,5 +78,51 @@ def conjugate(dual_quaternion: type_alias.TensorLike,
7878
axis=-1)
7979

8080

81+
def multiply(dual_quaternion1: type_alias.TensorLike,
82+
dual_quaternion2: type_alias.TensorLike,
83+
name: str = "dual_quaternion_multiply"):
84+
"""Multiplies two dual quaternions.
85+
86+
Note:
87+
In the following, A1 to An are optional batch dimensions.
88+
89+
Args:
90+
dual_quaternion1: A TensorLike of shape `[A1, ..., An, 8]`, where the last
91+
dimension represents a dual quaternion.
92+
dual_quaternion2: A TensorLike of shape `[A1, ..., An, 8]`, where the last
93+
dimension represents a dual quaternion.
94+
name: A name for this op that defaults to "dual_quaternion_multiply".
95+
96+
Returns:
97+
A tensor of shape `[A1, ..., An, 8]` representing dual quaternions.
98+
"""
99+
with tf.name_scope(name):
100+
dual_quaternion1 = tf.convert_to_tensor(value=dual_quaternion1)
101+
dual_quaternion2 = tf.convert_to_tensor(value=dual_quaternion2)
102+
103+
shape.check_static(
104+
tensor=dual_quaternion1,
105+
tensor_name="dual_quaternion1",
106+
has_dim_equals=(-1, 8))
107+
shape.check_static(
108+
tensor=dual_quaternion2,
109+
tensor_name="dual_quaternion2",
110+
has_dim_equals=(-1, 8))
111+
112+
dual_quaternion1_real, dual_quaternion1_dual = tf.split(
113+
dual_quaternion1, (4, 4), axis=-1)
114+
dual_quaternion2_real, dual_quaternion2_dual = tf.split(
115+
dual_quaternion2, (4, 4), axis=-1)
116+
117+
dual_quaternion_output_real = quaternion.multiply(dual_quaternion1_real,
118+
dual_quaternion2_real)
119+
dual_quaternion_output_dual = (
120+
quaternion.multiply(dual_quaternion1_real, dual_quaternion2_dual) +
121+
quaternion.multiply(dual_quaternion1_dual, dual_quaternion2_real))
122+
123+
return tf.concat((dual_quaternion_output_real, dual_quaternion_output_dual),
124+
axis=-1)
125+
126+
81127
# API contains all public functions and classes.
82128
__all__ = export_api.get_functions_and_classes()

tensorflow_graphics/geometry/transformation/tests/dual_quaternion_test.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from absl.testing import flagsaver
1717
from absl.testing import parameterized
18-
import tensorflow.compat.v2 as tf
18+
import tensorflow as tf
1919

2020
from tensorflow_graphics.geometry.transformation import dual_quaternion
2121
from tensorflow_graphics.geometry.transformation.tests import test_helpers
@@ -29,30 +29,25 @@ class DualQuaternionTest(test_case.TestCase):
2929
((None, 8),),
3030
)
3131
def test_conjugate_exception_not_raised(self, *shape):
32-
"""Tests that the shape exceptions of conjugate are not raised."""
3332
self.assert_exception_is_not_raised(dual_quaternion.conjugate, shape)
3433

3534
@parameterized.parameters(
3635
("must have exactly 8 dimensions", (3,)),)
3736
def test_conjugate_exception_raised(self, error_msg, *shape):
38-
"""Tests that the shape exceptions are raised."""
3937
self.assert_exception_is_raised(dual_quaternion.conjugate, error_msg, shape)
4038

4139
@flagsaver.flagsaver(tfg_add_asserts_to_graph=False)
4240
def test_conjugate_jacobian_preset(self):
43-
"""Tests the Jacobian of the conjugate function."""
4441
x_init = test_helpers.generate_preset_test_dual_quaternions()
4542
self.assert_jacobian_is_correct_fn(dual_quaternion.conjugate, [x_init])
4643

4744
@flagsaver.flagsaver(tfg_add_asserts_to_graph=False)
4845
def test_conjugate_jacobian_random(self):
49-
"""Tests the Jacobian of the conjugate function."""
5046
x_init = test_helpers.generate_random_test_dual_quaternions()
5147
self.assert_jacobian_is_correct_fn(dual_quaternion.conjugate, [x_init])
5248

5349
@flagsaver.flagsaver(tfg_add_asserts_to_graph=False)
5450
def test_conjugate_preset(self):
55-
"""Tests if the conjugate function is providing correct results."""
5651
x_init = test_helpers.generate_preset_test_dual_quaternions()
5752
x = tf.convert_to_tensor(value=x_init)
5853
y = tf.convert_to_tensor(value=x_init)
@@ -68,3 +63,37 @@ def test_conjugate_preset(self):
6863

6964
self.assertAllEqual(x_real, y_real)
7065
self.assertAllEqual(x_dual, y_dual)
66+
67+
@parameterized.parameters(
68+
((8,), (8,)),
69+
((None, 8), (None, 8)),
70+
)
71+
def test_multiply_exception_not_raised(self, *shapes):
72+
self.assert_exception_is_not_raised(dual_quaternion.multiply, shapes)
73+
74+
@parameterized.parameters(
75+
("must have exactly 8 dimensions", (5,), (6,)),
76+
("must have exactly 8 dimensions", (7,), (8,)),
77+
)
78+
def test_multiply_exception_raised(self, error_msg, *shape):
79+
self.assert_exception_is_raised(dual_quaternion.multiply, error_msg, shape)
80+
81+
@flagsaver.flagsaver(tfg_add_asserts_to_graph=False)
82+
def test_multiply_jacobian_preset(self):
83+
x_1_init = test_helpers.generate_preset_test_dual_quaternions()
84+
x_2_init = test_helpers.generate_preset_test_dual_quaternions()
85+
86+
self.assert_jacobian_is_correct_fn(dual_quaternion.multiply,
87+
[x_1_init, x_2_init])
88+
89+
@flagsaver.flagsaver(tfg_add_asserts_to_graph=False)
90+
def test_multiply_jacobian_random(self):
91+
x_1_init = test_helpers.generate_random_test_dual_quaternions()
92+
x_2_init = test_helpers.generate_random_test_dual_quaternions()
93+
94+
self.assert_jacobian_is_correct_fn(dual_quaternion.multiply,
95+
[x_1_init, x_2_init])
96+
97+
98+
if __name__ == "__main__":
99+
test_case.main()

0 commit comments

Comments
 (0)