Skip to content

Commit

Permalink
Merge 1a726a6 into 2a525a9
Browse files Browse the repository at this point in the history
  • Loading branch information
copybara-service[bot] committed Jul 2, 2021
2 parents 2a525a9 + 1a726a6 commit fa0e3cd
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 0 deletions.
38 changes: 38 additions & 0 deletions tensorflow_graphics/geometry/transformation/dual_quaternion.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,5 +211,43 @@ def norm(dual_quaternion: type_alias.TensorLike,
return tf.concat((quaternion_real_norm, normalized_dot_product), axis=-1)


def is_normalized(dual_quaternion: type_alias.TensorLike,
atol: tf.float32 = 1e-3,
name: str = "dual_quaternion_is_normalized") -> bool:
"""Determines if a dual quaternion is normalized or not.
Note:
In the following, A1 to An are optional batch dimensions.
Args:
dual_quaternion: A `[A1, ..., An, 8]`-tensor, where the last dimension
represents a dual quaternion.
atol: The absolute tolerance parameter.
name: A name for this op that defaults to "dual_quaternion_is_normalized".
Returns:
A `[A1, ..., An, 1]`-tensor of type `bool`, where False indicates that the
dual quaternion is not normalized.
Raises:
ValueError: If the shape of `dual_quaternion` is not supported.
"""
with tf.name_scope(name):
dual_quaternion = tf.convert_to_tensor(value=dual_quaternion)

shape.check_static(
tensor=dual_quaternion,
tensor_name="dual_quaternion",
has_dim_equals=(-1, 8))

norms = norm(dual_quaternion)

return tf.expand_dims(
tf.math.logical_and(
tf.abs(norms[..., 0] - 1.) < atol,
tf.abs(norms[..., 1] - 0.) < atol),
axis=-1)


# API contains all public functions and classes.
__all__ = export_api.get_functions_and_classes()
Original file line number Diff line number Diff line change
Expand Up @@ -190,5 +190,38 @@ def test_norm_correct_preset_non_unit(self):

self.assertAllClose(norms, norms_gt, rtol=1e-3)

@parameterized.parameters(
((8,),),
((None, 8),),
)
def test_is_normalized_exception_not_raised(self, *shape):
self.assert_exception_is_not_raised(dual_quaternion.is_normalized, shape)

@parameterized.parameters(
("must have exactly 8 dimensions", (1, 5)),)
def test_is_normalized_exception_raised(self, error_msg, *shape):
self.assert_exception_is_raised(dual_quaternion.is_normalized,
error_msg,
shape)

def test_is_normalized_random(self):
rnd_dual_quaternion = test_helpers.generate_random_test_dual_quaternions()
tensor_shape = rnd_dual_quaternion.shape[:-1]

unnormalized_rnd_dual_quaternion = rnd_dual_quaternion * 1.01
rnd_dual_quaternion = tf.convert_to_tensor(rnd_dual_quaternion)
unnormalized_rnd_dual_quaternion = tf.convert_to_tensor(
unnormalized_rnd_dual_quaternion)
dual_quaternions = tf.concat(
(rnd_dual_quaternion, unnormalized_rnd_dual_quaternion), axis=0)
mask = tf.concat(
(tf.ones(shape=tensor_shape + (1,), dtype=bool),
tf.zeros(shape=tensor_shape + (1,), dtype=bool)),
axis=0)
is_normalized = dual_quaternion.is_normalized(dual_quaternions)

self.assertAllEqual(mask, is_normalized)


if __name__ == "__main__":
test_case.main()

0 comments on commit fa0e3cd

Please sign in to comment.