Skip to content

Commit

Permalink
Port of the perspectiveRH function from glm to TensorFlow
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 268436931
  • Loading branch information
julienvalentin authored and Copybara-Service committed Sep 11, 2019
1 parent cc4e010 commit 2796ce9
Show file tree
Hide file tree
Showing 6 changed files with 347 additions and 0 deletions.
1 change: 1 addition & 0 deletions tensorflow_graphics/rendering/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ py_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow_graphics/rendering/camera",
"//tensorflow_graphics/rendering/opengl",
"//tensorflow_graphics/rendering/reflectance",
"//tensorflow_graphics/util:export_api",
],
Expand Down
1 change: 1 addition & 0 deletions tensorflow_graphics/rendering/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import print_function

from tensorflow_graphics.rendering import camera
from tensorflow_graphics.rendering import opengl
from tensorflow_graphics.rendering import reflectance
from tensorflow_graphics.util import export_api as _export_api

Expand Down
67 changes: 67 additions & 0 deletions tensorflow_graphics/rendering/opengl/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Math functionalities for tf-graphics.

# google internal package dependency 8)
# google internal package dependency 5

licenses(["notice"]) # Apache 2.0

package(default_visibility = ["//visibility:public"])

py_library(
name = "opengl",
srcs = [
"__init__.py",
],
srcs_version = "PY2AND3",
# google internal rule 1
visibility = ["//visibility:public"],
deps = [
":math",
"//tensorflow_graphics/util:export_api",
],
)

py_library(
name = "math",
srcs = ["math.py"],
srcs_version = "PY2AND3",
# google internal rule 1
deps = [
# google internal package dependency 1,
"//tensorflow_graphics/util:asserts",
"//tensorflow_graphics/util:export_api",
"//tensorflow_graphics/util:shape",
],
)

py_test(
name = "math_test",
srcs = ["tests/math_test.py"],
srcs_version = "PY2AND3",
# google internal rule 1
# google internal rule 2
# google internal rule 3
# google internal rule 4
# google internal rule 5
# google internal rule 6
deps = [
":math",
# google internal package dependency 2
# google internal package dependency 6
# google internal package dependency 1,
"//tensorflow_graphics/util:test_case",
],
)
23 changes: 23 additions & 0 deletions tensorflow_graphics/rendering/opengl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""OpenGL module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow_graphics.rendering.opengl import math
from tensorflow_graphics.util import export_api as _export_api

# API contains submodules of tensorflow_graphics.rendering.
__all__ = _export_api.get_modules()
104 changes: 104 additions & 0 deletions tensorflow_graphics/rendering/opengl/math.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
#Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module implements math routines used by OpenGL."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import tensorflow as tf

from tensorflow_graphics.util import asserts
from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import shape


def perspective_right_handed(vertical_field_of_view,
aspect_ratio,
near,
far,
name=None):
"""Generates the matrix for a right handed perspective projection.
Note:
In the following, A1 to An are optional batch dimensions.
Args:
vertical_field_of_view: A tensor of shape `[A1, ..., An]`, where the last
dimension represents the vertical field of view of the frustum expressed
in radians. Note that values for `vertical_field_of_view` must be in the
range (0,pi).
aspect_ratio: A tensor of shape `[A1, ..., An, C]`, where the last dimension
stores the width over height ratio of the frustum. Note that values for
`aspect_ratio` must be non-negative.
near: A tensor of shape `[A1, ..., An, C]`, where the last dimension
captures the distance between the viewer and the near clipping plane. Note
that values for `near` must be non-negative.
far: A tensor of shape `[A1, ..., An, C]`, where the last dimension
captures the distance between the viewer and the far clipping plane. Note
that values for `far` must be greater than those of `near`.
name: A name for this op. Defaults to 'perspective_rh'.
Raises:
InvalidArgumentError: if any input contains data not in the specified range
of valid values.
ValueError: if the all the inputs are not of the same shape.
Returns:
A tensor of shape `[A1, ..., An, 4, 4]`, containing matrices of right
handed perspective-view frustum.
"""
with tf.compat.v1.name_scope(
name, "perspective_rh",
[vertical_field_of_view, aspect_ratio, near, far]):
vertical_field_of_view = tf.convert_to_tensor(value=vertical_field_of_view)
aspect_ratio = tf.convert_to_tensor(value=aspect_ratio)
near = tf.convert_to_tensor(value=near)
far = tf.convert_to_tensor(value=far)

shape.compare_batch_dimensions(
tensors=(vertical_field_of_view, aspect_ratio, near, far),
last_axes=-1,
tensor_names=("vertical_field_of_view", "aspect_ratio", "near", "far"),
broadcast_compatible=False)

vertical_field_of_view = asserts.assert_all_in_range(
vertical_field_of_view, 0.0, math.pi, open_bounds=True)
aspect_ratio = asserts.assert_all_above(aspect_ratio, 0.0, open_bound=True)
near = asserts.assert_all_above(near, 0.0, open_bound=True)
far = asserts.assert_all_above(far, near, open_bound=True)

inverse_tan_half_vertical_field_of_view = 1.0 / tf.tan(
vertical_field_of_view * 0.5)
zero = tf.zeros_like(inverse_tan_half_vertical_field_of_view)
one = tf.ones_like(inverse_tan_half_vertical_field_of_view)

x = tf.stack((inverse_tan_half_vertical_field_of_view / aspect_ratio, zero,
zero, zero),
axis=-1)
y = tf.stack((zero, inverse_tan_half_vertical_field_of_view, zero, zero),
axis=-1)
near_minus_far = near - far
z = tf.stack(
(zero, zero,
(far + near) / near_minus_far, 2.0 * far * near / near_minus_far),
axis=-1)
w = tf.stack((zero, zero, -one, zero), axis=-1)

return tf.stack((x, y, z, w), axis=-2)


# API contains all public functions and classes.
__all__ = export_api.get_functions_and_classes()
151 changes: 151 additions & 0 deletions tensorflow_graphics/rendering/opengl/tests/math_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
#Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for OpenGL math routines."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from tensorflow_graphics.rendering.opengl import math as glm
from tensorflow_graphics.util import test_case


class MathTest(test_case.TestCase):

def test_perspective_right_handed_preset(self):
"""Tests that perspective_right_handed generates expected results.."""
vertical_field_of_view = (60.0 * math.pi / 180.0, 50.0 * math.pi / 180.0)
aspect_ratio = (1.5, 1.1)
near = (1.0, 1.2)
far = (10.0, 5.0)

pred = glm.perspective_right_handed(vertical_field_of_view, aspect_ratio,
near, far)
gt = (((1.15470052, 0.0, 0.0, 0.0), (0.0, 1.73205066, 0.0, 0.0),
(0.0, 0.0, -1.22222221, -2.22222233), (0.0, 0.0, -1.0, 0.0)),
((1.9495517, 0.0, 0.0, 0.0), (0.0, 2.14450693, 0.0, 0.0),
(0.0, 0.0, -1.63157892, -3.15789485), (0.0, 0.0, -1.0, 0.0)))
self.assertAllClose(pred, gt)

@parameterized.parameters(
((1,), (1,), (1,), (1,)),
((None, 2), (None, 2), (None, 2), (None, 2)),
)
def test_perspective_right_handed_exception_not_raised(self, *shapes):
"""Tests that the shape exceptions are not raised."""
self.assert_exception_is_not_raised(glm.perspective_right_handed, shapes)

@parameterized.parameters(
("Not all batch dimensions are identical", (3,), (3, 3), (3, 3), (3, 3)),
("Not all batch dimensions are identical", (2, 3), (3, 3), (3, 3),
(3, 3)),
)
def test_perspective_right_handed_shape_exception_raised(
self, error_msg, *shapes):
"""Tests that the shape exceptions are properly raised."""
self.assert_exception_is_raised(glm.perspective_right_handed, error_msg,
shapes)

@parameterized.parameters(
((1.0,),
(1.0,), np.random.uniform(-1.0, 0.0, size=(1,)).astype(np.float32),
(1.0,)),
((1.0,), (1.0,), (0.0,), (1.0,)),
((1.0,), np.random.uniform(-1.0, 0.0, size=(1,)).astype(np.float32),
(0.1,), (1.0,)),
((1.0,), (0.0,), (0.1,), (1.0,)),
((1.0,),
(1.0,), np.random.uniform(1.0, 2.0, size=(1,)).astype(np.float32),
np.random.uniform(0.1, 0.5, size=(1,)).astype(np.float32)),
((1.0,), (1.0,), (0.1,), (0.1,)),
(np.random.uniform(-math.pi, 0.0, size=(1,)).astype(np.float32), (1.0,),
(0.1,), (1.0,)),
(np.random.uniform(math.pi, 2.0 * math.pi, size=(1,)).astype(np.float32),
(1.0,), (0.1,), (1.0,)),
((0.0,), (1.0,), (0.1,), (1.0,)),
((math.pi,), (1.0,), (0.1,), (1.0,)),
)
def test_perspective_right_handed_valid_range_exception_raised(
self, vertical_field_of_view, aspect_ratio, near, far):
"""Tests that an exception is raised with out of bounds values."""
with self.assertRaises(tf.errors.InvalidArgumentError):
self.evaluate(
glm.perspective_right_handed(vertical_field_of_view, aspect_ratio,
near, far))

def test_perspective_right_handed_cross_jacobian_preset(self):
"""Tests the Jacobian of perspective_right_handed."""
vertical_field_of_view_init = np.array((1.0,))
aspect_ratio_init = np.array((1.0,))
near_init = np.array((1.0,))
far_init = np.array((10.0,))

# Wrap with tf.identity because some assert_* ops look at the constant
# tensor value and mark it as unfeedable.
vertical_field_of_view_tensor = tf.identity(
tf.convert_to_tensor(value=vertical_field_of_view_init))
aspect_ratio_tensor = tf.identity(
tf.convert_to_tensor(value=aspect_ratio_init))
near_tensor = tf.identity(tf.convert_to_tensor(value=near_init))
far_tensor = tf.identity(tf.convert_to_tensor(value=far_init))

y = glm.perspective_right_handed(vertical_field_of_view_tensor,
aspect_ratio_tensor, near_tensor,
far_tensor)

self.assert_jacobian_is_correct(vertical_field_of_view_tensor,
vertical_field_of_view_init, y)
self.assert_jacobian_is_correct(aspect_ratio_tensor, aspect_ratio_init, y)
self.assert_jacobian_is_correct(near_tensor, near_init, y)
self.assert_jacobian_is_correct(far_tensor, far_init, y)

def test_perspective_right_handed_cross_jacobian_random(self):
"""Tests the Jacobian of perspective_right_handed."""
tensor_size = np.random.randint(1, 3)
tensor_shape = np.random.randint(1, 5, size=(tensor_size)).tolist()
eps = np.finfo(np.float64).eps
vertical_field_of_view_init = np.random.uniform(
eps, math.pi - eps, size=tensor_shape)
aspect_ratio_init = np.random.uniform(eps, 100.0, size=tensor_shape)
near_init = np.random.uniform(eps, 10.0, size=tensor_shape)
far_init = np.random.uniform(10 + eps, 100.0, size=tensor_shape)

# Wrap with tf.identity because some assert_* ops look at the constant
# tensor value and mark it as unfeedable.
vertical_field_of_view_tensor = tf.identity(
tf.convert_to_tensor(value=vertical_field_of_view_init))
aspect_ratio_tensor = tf.identity(
tf.convert_to_tensor(value=aspect_ratio_init))
near_tensor = tf.identity(tf.convert_to_tensor(value=near_init))
far_tensor = tf.identity(tf.convert_to_tensor(value=far_init))

y = glm.perspective_right_handed(vertical_field_of_view_tensor,
aspect_ratio_tensor, near_tensor,
far_tensor)

self.assert_jacobian_is_correct(vertical_field_of_view_tensor,
vertical_field_of_view_init, y)
self.assert_jacobian_is_correct(aspect_ratio_tensor, aspect_ratio_init, y)
self.assert_jacobian_is_correct(near_tensor, near_init, y)
self.assert_jacobian_is_correct(far_tensor, far_init, y)


if __name__ == "__main__":
test_case.main()

0 comments on commit 2796ce9

Please sign in to comment.