Skip to content

Commit 4311310

Browse files
podlipenskycopybara-github
authored andcommitted
Rasterization backend golden tests.
PiperOrigin-RevId: 363596956
1 parent e110a12 commit 4311310

File tree

5 files changed

+238
-5
lines changed

5 files changed

+238
-5
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright 2020 The TensorFlow Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Tests for tensorflow_graphics.rendering.tests.interpolate."""
15+
16+
import tensorflow as tf
17+
18+
from tensorflow_graphics.rendering import interpolate
19+
from tensorflow_graphics.rendering import rasterization_backend
20+
from tensorflow_graphics.rendering.tests import rasterization_test_utils
21+
from tensorflow_graphics.util import test_case
22+
23+
24+
class RasterizeTest(test_case.TestCase):
25+
26+
def setUp(self):
27+
super(RasterizeTest, self).setUp()
28+
29+
self.test_data_directory = (
30+
'google3/research/vision/viscam/diffren/mesh/test_data/')
31+
32+
self.cube_vertex_positions = tf.constant(
33+
[[[-1, -1, 1], [-1, -1, -1], [-1, 1, -1], [-1, 1, 1], [1, -1, 1],
34+
[1, -1, -1], [1, 1, -1], [1, 1, 1]]],
35+
dtype=tf.float32)
36+
self.cube_triangles = tf.constant(
37+
[[0, 1, 2], [2, 3, 0], [3, 2, 6], [6, 7, 3], [7, 6, 5], [5, 4, 7],
38+
[4, 5, 1], [1, 0, 4], [5, 6, 2], [2, 1, 5], [7, 4, 0], [0, 3, 7]],
39+
dtype=tf.int32)
40+
41+
self.image_width = 640
42+
self.image_height = 480
43+
perspective = rasterization_test_utils.make_perspective_matrix(
44+
self.image_width, self.image_height)
45+
projection = tf.matmul(
46+
perspective,
47+
rasterization_test_utils.make_look_at_matrix(
48+
camera_origin=(2.0, 3.0, 6.0)))
49+
# Add batch dimension.
50+
self.projection = tf.expand_dims(projection, axis=0)
51+
52+
def test_renders_colored_cube(self):
53+
"""Renders a simple colored cube in two viewpoints."""
54+
num_layers = 1
55+
rasterized = rasterization_backend.rasterize(
56+
self.cube_vertex_positions,
57+
self.cube_triangles,
58+
self.projection, (self.image_width, self.image_height),
59+
num_layers=num_layers,
60+
enable_cull_face=False,
61+
backend=rasterization_backend.RasterizationBackends.CPU)
62+
63+
vertex_rgb = (self.cube_vertex_positions * 0.5 + 0.5)
64+
vertex_rgba = tf.concat([vertex_rgb, tf.ones([1, 8, 1])], axis=-1)
65+
rendered = interpolate.interpolate_vertex_attribute(vertex_rgba,
66+
rasterized).value
67+
68+
baseline_image = rasterization_test_utils.load_baseline_image(
69+
'Unlit_Cube_0_0.png', rendered.shape)
70+
71+
images_near, error_message = rasterization_test_utils.compare_images(
72+
self, baseline_image, rendered)
73+
self.assertTrue(images_near, msg=error_message)
74+
75+
76+
if __name__ == '__main__':
77+
tf.test.main()

tensorflow_graphics/rendering/tests/rasterization_backend_test_base.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,16 @@
1717
import tensorflow as tf
1818

1919
from tensorflow_graphics.rendering import rasterization_backend
20+
from tensorflow_graphics.rendering.tests import rasterization_test_utils
2021
from tensorflow_graphics.util import test_case
2122

2223

2324
class RasterizationBackendTestBase(test_case.TestCase):
2425
"""Base class for CPU/GPU rasterization backend tests."""
2526

27+
IMAGE_WIDTH = 640
28+
IMAGE_HEIGHT = 480
29+
2630
def setUp(self):
2731
super().setUp()
2832
self._backend = rasterization_backend.RasterizationBackends.OPENGL
@@ -75,8 +79,9 @@ def test_rasterizer_return_correct_batch_shapes(self, shapes, dtypes,
7579
"""Tests that supported backends return correct shape."""
7680
placeholders = self._create_placeholders(shapes, dtypes)
7781
frame_buffer = rasterization_backend.rasterize(
78-
placeholders[0], placeholders[1], placeholders[2], (600, 800),
79-
enable_cull_face, self._num_layers, self._backend)
82+
placeholders[0], placeholders[1], placeholders[2],
83+
(self.IMAGE_WIDTH, self.IMAGE_HEIGHT), enable_cull_face,
84+
self._num_layers, self._backend)
8085
batch_size = shapes[0][0]
8186
self.assertEqual([batch_size],
8287
frame_buffer.triangle_id.get_shape().as_list()[:-3])
@@ -94,7 +99,8 @@ def test_rasterizer_rasterize_exception_raised(self, shapes, dtypes, backend):
9499
placeholders = self._create_placeholders(shapes, dtypes)
95100
with self.assertRaisesRegexp(KeyError, 'Backend is not supported'):
96101
rasterization_backend.rasterize(placeholders[0], placeholders[1],
97-
placeholders[2], (600, 800),
102+
placeholders[2],
103+
(self.IMAGE_WIDTH, self.IMAGE_HEIGHT),
98104
self._enable_cull_face, self._num_layers,
99105
backend)
100106

@@ -105,8 +111,9 @@ def test_rasterizer_all_vertices_visible(self):
105111
triangles = tf.convert_to_tensor([[0, 1, 2]], dtype=tf.int32)
106112
view_projection_matrix = tf.expand_dims(tf.eye(4), axis=0)
107113
frame_buffer = rasterization_backend.rasterize(
108-
vertices, triangles, view_projection_matrix, (100, 100),
109-
self._enable_cull_face, self._num_layers, self._backend)
114+
vertices, triangles, view_projection_matrix,
115+
(self.IMAGE_WIDTH, self.IMAGE_HEIGHT), self._enable_cull_face,
116+
self._num_layers, self._backend)
110117
self.assertAllEqual(frame_buffer.triangle_id.shape[:-1],
111118
frame_buffer.vertex_ids.shape[:-1])
112119
# Assert that triangle is visible.
@@ -115,3 +122,28 @@ def test_rasterizer_all_vertices_visible(self):
115122
# Assert that all three vertices are visible.
116123
self.assertAllLess(frame_buffer.triangle_id, 3)
117124
self.assertAllGreaterEqual(frame_buffer.triangle_id, 0)
125+
126+
def test_render_simple_triangle(self):
127+
"""Directly renders a rasterized triangle's barycentric coordinates."""
128+
w_vector = tf.constant([1.0, 1.0, 1.0], dtype=tf.float32)
129+
clip_init = tf.constant(
130+
[[[-0.5, -0.5, 0.8], [0.0, 0.5, 0.3], [0.5, -0.5, 0.3]]],
131+
dtype=tf.float32)
132+
clip_coordinates = clip_init * tf.reshape(w_vector, [1, 3, 1])
133+
triangles = tf.constant([[0, 1, 2]], dtype=tf.int32)
134+
135+
face_culling_enabled = False
136+
framebuffer = rasterization_backend.rasterize(
137+
clip_coordinates, triangles, tf.eye(4, batch_shape=[1]),
138+
(self.IMAGE_WIDTH, self.IMAGE_HEIGHT), face_culling_enabled,
139+
self._num_layers, self._backend)
140+
ones_image = tf.ones([1, self.IMAGE_HEIGHT, self.IMAGE_WIDTH, 1])
141+
rendered_coordinates = tf.concat(
142+
[framebuffer.barycentrics.value, ones_image], axis=-1)
143+
144+
baseline_image = rasterization_test_utils.load_baseline_image(
145+
'Simple_Triangle.png', rendered_coordinates.shape)
146+
147+
images_near, error_message = rasterization_test_utils.compare_images(
148+
self, baseline_image, rendered_coordinates)
149+
self.assertTrue(images_near, msg=error_message)
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Copyright 2020 The TensorFlow Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Util functions for rasterization tests."""
15+
16+
import os
17+
18+
import numpy as np
19+
import tensorflow as tf
20+
21+
from tensorflow_graphics.geometry.transformation import look_at
22+
from tensorflow_graphics.rendering.camera import perspective
23+
from tensorflow_graphics.util import shape
24+
25+
26+
def make_perspective_matrix(image_width=None, image_height=None):
27+
"""Generates perspective matrix for a given image size.
28+
29+
Args:
30+
image_width: int representing image width.
31+
image_height: int representing image height.
32+
33+
Returns:
34+
Perspective matrix, tensor of shape [4, 4].
35+
36+
Note: Golden tests require image size to be fixed and equal to the size of
37+
golden image examples.
38+
"""
39+
40+
field_of_view = (40 * np.math.pi / 180,)
41+
near_plane = (0.01,)
42+
far_plane = (10.0,)
43+
return perspective.right_handed(field_of_view,
44+
(float(image_width) / float(image_height),),
45+
near_plane, far_plane)
46+
47+
48+
def make_look_at_matrix(
49+
camera_origin=(0.0, 0.0, 0.0), look_at_point=(0.0, 0.0, 0.0)):
50+
camera_up = (0.0, 1.0, 0.0)
51+
return look_at.right_handed(camera_origin, look_at_point, camera_up)
52+
53+
54+
def compare_images(test_case,
55+
baseline_image,
56+
image,
57+
max_outlier_fraction=0.005,
58+
pixel_error_threshold=0.04):
59+
"""Compares two image arrays.
60+
61+
The comparison is soft: the images are considered identical if fewer than
62+
max_outlier_fraction of the pixels differ by more than pixel_error_threshold
63+
of the full color value.
64+
65+
Differences in JPEG encoding can produce pixels with pretty large variation,
66+
so by default we use 0.04 (4%) for pixel_error_threshold and 0.005 (0.5%) for
67+
max_outlier_fraction.
68+
69+
Args:
70+
test_case: test_case.TestCase instance this util function is used in.
71+
baseline_image: tensor of shape [batch, height, width, channels] containing
72+
the baseline image.
73+
image: tensor of shape [batch, height, width, channels] containing the
74+
result image.
75+
max_outlier_fraction: fraction of pixels that may vary by more than the
76+
error threshold. 0.005 means 0.5% of pixels. Number of outliers are
77+
computed and compared per image.
78+
pixel_error_threshold: pixel values are considered to differ if their
79+
difference exceeds this amount. Range is 0.0 - 1.0.
80+
81+
Returns:
82+
Tuple of a boolean and string error message. Boolean indicates
83+
whether images are close to each other or not. Error message contains
84+
details of two images mismatch.
85+
"""
86+
tf.assert_equal(baseline_image.shape, image.shape)
87+
if baseline_image.dtype != image.dtype:
88+
return False, ("Image types %s and %s do not match" %
89+
(baseline_image.dtype, image.dtype))
90+
shape.check_static(
91+
tensor=baseline_image, tensor_name="baseline_image", has_rank=4)
92+
# Flatten height, width and channels dimensions since we're interested in
93+
# error per image.
94+
image_height, image_width = image.shape[1:3]
95+
baseline_image = tf.reshape(baseline_image, [baseline_image.shape[0]] + [-1])
96+
image = tf.reshape(image, [image.shape[0]] + [-1])
97+
abs_diff = tf.abs(baseline_image - image)
98+
outliers = tf.math.greater(abs_diff, pixel_error_threshold)
99+
num_outliers = tf.math.reduce_sum(tf.cast(outliers, tf.int32))
100+
perc_outliers = num_outliers / (image_height * image_width)
101+
error_msg = "{:.2%} pixels are not equal to baseline image pixels.".format(
102+
test_case.evaluate(perc_outliers) * 100.0)
103+
return test_case.evaluate(perc_outliers < max_outlier_fraction), error_msg
104+
105+
106+
def load_baseline_image(filename, image_shape=None):
107+
"""Loads baseline image and makes sure it is of the right shape.
108+
109+
Args:
110+
filename: file name of the image to load.
111+
image_shape: expected shape of the image.
112+
113+
Returns:
114+
tf.Tensor with baseline image
115+
"""
116+
image_path = tf.compat.v1.resource_loader.get_path_to_datafile(
117+
os.path.join("test_data", filename))
118+
file = tf.io.read_file(image_path)
119+
baseline_image = tf.cast(tf.image.decode_image(file), tf.float32) / 255.0
120+
baseline_image = tf.expand_dims(baseline_image, axis=0)
121+
if image_shape is not None:
122+
# Graph-mode requires image shape to be known in advance.
123+
baseline_image = tf.reshape(baseline_image, image_shape)
124+
return baseline_image
16.9 KB
Loading
48.2 KB
Loading

0 commit comments

Comments
 (0)