Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions tensorflow_graphics/rendering/framebuffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright 2020 The TensorFlow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Storage classes for framebuffers and related data."""

from typing import Dict, Optional

import dataclasses
import tensorflow as tf


@dataclasses.dataclass
class RasterizedAttribute(object):
"""A single rasterized attribute and optionally its screen-space derivatives.

Tensors are expected to have shape [batch, height, width, channels] or
[batch, num_layers, height, width, channels].

Immutable once created.
"""
value: tf.Tensor
d_dx: Optional[tf.Tensor] = None
d_dy: Optional[tf.Tensor] = None

def __post_init__(self):
# Checks that all input tensors have the same shape and rank.
tensors = [self.value, self.d_dx, self.d_dy]
shapes = [
tensor.shape.as_list() for tensor in tensors if tensor is not None
]
ranks = [len(shape) for shape in shapes]
if not all(rank == ranks[0] for rank in ranks):
raise ValueError(
"Expected value and derivatives to be of the same rank, but found"
f" ranks {shapes}")

same_as_value = True
static_shapes = [self.value.shape]
if self.d_dx is not None:
same_as_value = tf.logical_and(
same_as_value, tf.equal(tf.shape(self.value), tf.shape(self.d_dx)))
static_shapes.append(self.d_dx.shape)
if self.d_dy is not None:
same_as_value = tf.logical_and(
same_as_value, tf.equal(tf.shape(self.value), tf.shape(self.d_dy)))
static_shapes.append(self.d_dy.shape)
tf.debugging.assert_equal(
same_as_value,
True,
message="Expected all input shapes to be the same but found: " +
", ".join([str(s) for s in static_shapes]))


@dataclasses.dataclass
class Framebuffer(object):
"""A framebuffer holding rasterized values required for deferred shading.

Tensors are expected to have shape [batch, height, width, channels] or
[batch, num_layers, height, width, channels].

For now, the fields are specialized for triangle rendering. Other primitives
may be supported in the future.

Immutable once created. Uses cached_property to avoid creating redundant
tf ops when properties are accessed multiple times.
"""
# The barycentric weights of the pixel centers in the covering triangle.
barycentrics: RasterizedAttribute
# The index of the triangle covering this pixel. Not differentiable.
triangle_id: tf.Tensor
# The indices of the vertices of the triangle covering this pixel.
# Not differentiable.
vertex_ids: tf.Tensor
# A mask of the pixels covered by a triangle. 1 if covered, 0 if background.
# Not differentiable.
foreground_mask: tf.Tensor

# Other rasterized attribute values (e.g., colors, UVs, normals, etc.).
attributes: Dict[str, RasterizedAttribute] = dataclasses.field(
default_factory=dict)

def __post_init__(self):
# Checks that all buffers have rank and same shape up to the
# number of channels.
values = [self.barycentrics.value, self.triangle_id, self.vertex_ids,
self.foreground_mask]
values += [v.value for k, v in self.attributes.items()]

ranks = [len(v.shape) for v in values]
shapes = [tf.shape(v) for v in values]
if not all(rank == ranks[0] for rank in ranks):
raise ValueError(
f"Expected all inputs to have the same rank, but found {shapes}")

same_as_first = [
tf.reduce_all(tf.equal(shapes[0][:-1], s[:-1])) for s in shapes[1:]
]
all_same_as_first = tf.reduce_all(same_as_first)
tf.debugging.assert_equal(
all_same_as_first,
True,
message="Expected all input shapes to be the same "
"(up to channels), but found: " + ", ".join([str(s) for s in shapes]))
36 changes: 26 additions & 10 deletions tensorflow_graphics/rendering/opengl/rasterization_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import tensorflow as tf

from tensorflow_graphics.rendering import framebuffer as fb
from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import shape

Expand Down Expand Up @@ -92,7 +93,7 @@ def _dim_value(dim):
out vec4 output_color;

void main() {
output_color = vec4(round(triangle_index + 1.0), barycentric_coordinates, 1.0);
output_color = vec4(round(triangle_index), barycentric_coordinates, 1.0);
}
"""

Expand Down Expand Up @@ -123,13 +124,17 @@ def rasterize(vertices,
name: A name for this op. Defaults to 'rasterization_backend_rasterize'.

Returns:
A tuple of 3 elements. The first one of shape `[A1, ..., An, H, W, 1]`
representing the triangle index associated with each pixel. If no triangle
is associated to a pixel, the index is set to -1.
The second element in the tuple is of shape `[A1, ..., An, H, W, 3]` and
correspond to barycentric coordinates per pixel. The last element in the
tuple is of shape `[A1, ..., An, H, W, 1]` and stores a value of `0` of the
pixel is assciated with the background, and `1` with the foreground.
A Framebuffer containing the rasterized values: barycentrics, triangle_id,
foreground_mask, vertex_ids. Returned Tensors have shape
[batch, num_layers, height, width, channels]
Note: triangle_id contains the triangle id value for each pixel in the
output image. For pixels within the mesh, this is the integer value in the
range [0, num_vertices] from triangles. For vertices outside the mesh this
is 0; 0 can either indicate belonging to triangle 0, or being outside the
mesh. This ensures all returned triangle ids will validly index into the
vertex array, enabling the use of tf.gather with indices from this tensor.
The barycentric coordinates can be used to determine pixel validity instead.
See framebuffer.py for a description of the Framebuffer fields.
"""
with tf.compat.v1.name_scope(name, "rasterization_backend_rasterize",
(vertices, triangles, view_projection_matrices)):
Expand Down Expand Up @@ -182,7 +187,7 @@ def rasterize(vertices,
geometry_shader=geometry_shader,
fragment_shader=fragment_shader)

triangle_index = tf.cast(rasterized[..., 0], tf.int32) - 1
triangle_index = tf.cast(rasterized[..., 0], tf.int32)
# Slicing of the tensor will result in all batch dimensions being
# `None` for tensorflow graph mode, therefore we have to fix it in order to
# have explicit shape.
Expand All @@ -197,7 +202,18 @@ def rasterize(vertices,
mask = tf.cast(rasterized[..., 3], tf.int32)
mask = tf.reshape(mask, common_batch_shape + [height, width, 1])

return triangle_index, barycentric_coordinates, mask
triangles_batch = tf.broadcast_to(triangles,
common_batch_shape + triangles.shape)
vertex_ids = tf.gather(
triangles_batch, triangle_index[..., 0],
batch_dims=len(common_batch_shape))

return fb.Framebuffer(
foreground_mask=mask,
triangle_id=triangle_index,
vertex_ids=vertex_ids,
barycentrics=fb.RasterizedAttribute(
value=barycentric_coordinates, d_dx=None, d_dy=None))


# API contains all public functions and classes.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,10 @@ def test_rasterize_exception_not_raised(self, shapes, dtypes):
def test_rasterize_batch_vertices_only(self):
triangles = np.array(((0, 1, 2),), np.int32)
vertices, view_projection_matrix = _generate_vertices_and_view_matrices()
_, _, mask = rasterization_backend.rasterize(vertices, triangles,
view_projection_matrix[0],
(_IMAGE_WIDTH, _IMAGE_HEIGHT))
predicted_fb = rasterization_backend.rasterize(
vertices, triangles, view_projection_matrix[0],
(_IMAGE_WIDTH, _IMAGE_HEIGHT))
mask = predicted_fb.foreground_mask
self.assertAllEqual(mask[0, ...], tf.ones_like(mask[0, ...]))

gt_layer_1 = np.zeros((_IMAGE_HEIGHT, _IMAGE_WIDTH, 1), np.float32)
Expand All @@ -103,11 +104,13 @@ def test_rasterize_batch_view_only(self):
triangles = np.array(((0, 1, 2),), np.int32)
vertices, view_projection_matrix = _generate_vertices_and_view_matrices()

_, _, mask = rasterization_backend.rasterize(vertices[0], triangles,
view_projection_matrix,
(_IMAGE_WIDTH, _IMAGE_HEIGHT))
self.assertAllEqual(mask[0, ...], tf.ones_like(mask[0, ...]))
self.assertAllEqual(mask[1, ...], tf.zeros_like(mask[1, ...]))
predicted_fb = rasterization_backend.rasterize(
vertices[0], triangles, view_projection_matrix,
(_IMAGE_WIDTH, _IMAGE_HEIGHT))
self.assertAllEqual(predicted_fb.foreground_mask[0, ...],
tf.ones_like(predicted_fb.foreground_mask[0, ...]))
self.assertAllEqual(predicted_fb.foreground_mask[1, ...],
tf.zeros_like(predicted_fb.foreground_mask[1, ...]))

def test_rasterize_preset(self):
camera_origin = (0.0, 0.0, 0.0)
Expand All @@ -132,22 +135,22 @@ def test_rasterize_preset(self):
(0.0, -_TRIANGLE_SIZE, depth))
triangles = np.array(((1, 2, 0), (0, 2, 3)), np.int32)

predicted_triangle_index, predicted_barycentrics, predicted_mask = rasterization_backend.rasterize(
predicted_fb = rasterization_backend.rasterize(
vertices, triangles, view_projection_matrix,
(_IMAGE_WIDTH, _IMAGE_HEIGHT))

with self.subTest(name="triangle_index"):
groundtruth_triangle_index = np.zeros((_IMAGE_HEIGHT, _IMAGE_WIDTH, 1),
dtype=np.int32)
groundtruth_triangle_index[..., :_IMAGE_WIDTH // 2, 0] = -1
groundtruth_triangle_index[..., :_IMAGE_WIDTH // 2, 0] = 0
groundtruth_triangle_index[:_IMAGE_HEIGHT // 2, _IMAGE_WIDTH // 2:, 0] = 1
self.assertAllEqual(groundtruth_triangle_index, predicted_triangle_index)
self.assertAllEqual(groundtruth_triangle_index, predicted_fb.triangle_id)

with self.subTest(name="mask"):
groundtruth_mask = np.ones((_IMAGE_HEIGHT, _IMAGE_WIDTH, 1),
dtype=np.int32)
groundtruth_mask[..., :_IMAGE_WIDTH // 2, 0] = 0
self.assertAllEqual(groundtruth_mask, predicted_mask)
self.assertAllEqual(groundtruth_mask, predicted_fb.foreground_mask)

attributes = np.array(
((1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0))).astype(np.float32)
Expand All @@ -162,7 +165,9 @@ def test_rasterize_preset(self):
barycentrics_gt_0 = perspective_correct_interpolation(
geometry_0, pixels_0)
self.assertAllClose(
barycentrics_gt_0, predicted_barycentrics[2:, 3:, :], atol=1e-3)
barycentrics_gt_0,
predicted_fb.barycentrics.value[2:, 3:, :],
atol=1e-3)

with self.subTest(name="barycentric_coordinates_triangle_1"):
geometry_1 = tf.gather(vertices, triangles[1, :])
Expand All @@ -171,4 +176,6 @@ def test_rasterize_preset(self):
barycentrics_gt_1 = perspective_correct_interpolation(
geometry_1, pixels_1)
self.assertAllClose(
barycentrics_gt_1, predicted_barycentrics[0:2, 3:, :], atol=1e-3)
barycentrics_gt_1,
predicted_fb.barycentrics.value[0:2, 3:, :],
atol=1e-3)
18 changes: 11 additions & 7 deletions tensorflow_graphics/rendering/rasterization_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,17 @@ def rasterize(vertices,
Supported options are defined in the RasterizationBackends enum.

Returns:
A tuple of 3 elements. The first one of shape `[A1, ..., An, H, W, 1]`
representing the triangle index associated with each pixel. If no triangle
is associated to a pixel, the index is set to -1.
The second element in the tuple is of shape `[A1, ..., An, H, W, 3]` and
correspond to barycentric coordinates per pixel. The last element in the
tuple is of shape `[A1, ..., An, H, W]` and stores a value of `0` of the
pixel is assciated with the background, and `1` with the foreground.
A Framebuffer containing the rasterized values: barycentrics, triangle_id,
foreground_mask, vertex_ids. Returned Tensors have shape
[batch, num_layers, height, width, channels]
Note: triangle_id contains the triangle id value for each pixel in the
output image. For pixels within the mesh, this is the integer value in the
range [0, num_vertices] from triangles. For vertices outside the mesh this
is 0; 0 can either indicate belonging to triangle 0, or being outside the
mesh. This ensures all returned triangle ids will validly index into the
vertex array, enabling the use of tf.gather with indices from this tensor.
The barycentric coordinates can be used to determine pixel validity instead.
See framebuffer.py for a description of the Framebuffer fields.
"""
return _BACKENDS[backend].rasterize(vertices, triangles,
view_projection_matrices, image_size)
Expand Down
53 changes: 53 additions & 0 deletions tensorflow_graphics/rendering/tests/framebuffer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2020 The TensorFlow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for tensorflow_graphics.rendering.framebuffer."""

import tensorflow as tf
from tensorflow_graphics.rendering import framebuffer as fb
from tensorflow_graphics.util import test_case


class FramebufferTest(test_case.TestCase):

def test_initialize_rasterized_attribute_with_wrong_rank(self):
with self.assertRaisesRegex(
ValueError, "Expected value and derivatives to be of the same rank"):
fb.RasterizedAttribute(
tf.ones([4, 4, 1]), tf.ones([4, 3]), tf.ones([3, 4, 4, 5, 5]))

def test_initialize_rasterized_attribute_with_wrong_shapes(self):
with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
"Expected all input shapes to be the same"):
fb.RasterizedAttribute(tf.ones([1, 1, 4, 4, 1]), tf.ones([1, 1, 4, 3, 1]))

def test_initialize_framebuffer_with_wrong_rank(self):
with self.assertRaisesRegex(ValueError,
"Expected all inputs to have the same rank"):
fb.Framebuffer(
fb.RasterizedAttribute(tf.ones([1, 1, 4, 4, 1])), tf.ones([4, 3]),
tf.ones([3, 4, 4, 5, 5]), tf.ones([3, 4, 4, 5, 5]))

def test_initialize_framebuffer_with_wrong_shapes(self):
with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
"Expected all input shapes to be the same"):
fb.Framebuffer(
fb.RasterizedAttribute(tf.ones([1, 1, 4, 4, 3])),
tf.ones([1, 1, 4, 4, 1]), tf.ones([1, 1, 4, 4, 3]),
tf.ones([1, 1, 4, 4, 1]),
{"an_attr": fb.RasterizedAttribute(tf.ones([1, 1, 4, 3, 4]))})


if __name__ == "__main__":
tf.test.main()
28 changes: 12 additions & 16 deletions tensorflow_graphics/rendering/triangle_rasterizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,36 +130,32 @@ def rasterize(vertices,

view_projection_matrix = tf.linalg.matmul(perspective_matrix,
model_to_eye_matrix)
triangle_index, _, mask = rasterization_backend.rasterize(
vertices, triangles, view_projection_matrix, image_size_backend,
backend)
outputs = {"mask": mask, "triangle_indices": triangle_index}
rasterized = rasterization_backend.rasterize(vertices, triangles,
view_projection_matrix,
image_size_backend, backend)
outputs = {
"mask": rasterized.foreground_mask,
"triangle_indices": rasterized.triangle_id
}

vertices = tf.gather(vertices, triangles, axis=-2)

# Gather does not work on negative indices, which is the case for the pixel
# associated to the background.
triangle_index = triangle_index * mask
# Extract batch shape in order to make sure it is preserved after `gather`
# operation.
batch_shape = triangle_index.shape[:-3]
batch_shape = rasterized.triangle_id.shape[:-3]
batch_shape = [_dim_value(dim) for dim in batch_shape]
# Remove last dimension of `triangle_index` in order to make it compatible
# with gather operations.
triangle_index_lean = tf.squeeze(triangle_index, axis=-1)

vertices_per_pixel = tf.gather(
vertices, triangle_index_lean, axis=-3, batch_dims=len(batch_shape))
vertices, rasterized.vertex_ids, batch_dims=len(batch_shape))
barycentrics = _perspective_correct_barycentrics(vertices_per_pixel,
model_to_eye_matrix,
perspective_matrix,
image_size_float)
mask_float = tf.cast(mask, vertices.dtype)
mask_float = tf.cast(rasterized.foreground_mask, vertices.dtype)
outputs["barycentrics"] = mask_float * barycentrics

for key, attribute in attributes.items():
attribute = tf.convert_to_tensor(value=attribute)
outputs[key] = mask_float * _perspective_correct_attributes(
attribute, barycentrics, triangles, triangle_index_lean,
attribute, barycentrics, triangles, rasterized.triangle_id[..., 0],
len(batch_shape))

return outputs
Expand Down