Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions tensorflow_graphics/geometry/convolution/graph_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def feature_steered_convolution(data,
var_c,
var_w,
var_b,
ordered=False,
name=None):
# pyformat: disable
"""Implements the Feature Steered graph convolution.
Expand Down Expand Up @@ -81,6 +82,8 @@ def feature_steered_convolution(data,
var_c: A 1-D tensor with shape `[W]`.
var_w: A 3-D tensor with shape `[W, C, D]`.
var_b: A 1-D tensor with shape `[D]`.
ordered: if True, neighbors is assumed to be ordered in the canonical,
row-major ordering, e.g. via `tf.sparse.reorder`
name: A name for this op. Defaults to
`graph_convolution_feature_steered_convolution`.

Expand Down Expand Up @@ -146,8 +149,12 @@ def feature_steered_convolution(data,
for q_m, w_m in zip(q_m_list, w_m_list):
# Compute `y_i_m = sum_{j in neighborhood(i)} q_m(x_i, x_j) * w_m * x_j`.
q_m = tf.expand_dims(q_m, axis=-1)
p_sum = utils.partition_sums_2d(q_m * x_sep, adjacency_ind_0,
adjacency.values)
terms = q_m * x_sep * tf.expand_dims(adjacency.values, axis=-1)
if ordered:
p_sum = tf.math.segment_sum(terms, adjacency_ind_0)
else:
p_sum = tf.math.unsorted_segment_sum(
terms, adjacency_ind_0, tf.reduce_max(adjacency_ind_0) + 1)
y_i_m.append(tf.matmul(p_sum, w_m))
y_out = tf.add_n(inputs=y_i_m) + tf.reshape(var_b, [1, -1])
if data_ndims > 2:
Expand All @@ -161,6 +168,7 @@ def edge_convolution_template(data,
edge_function,
reduction,
edge_function_kwargs,
ordered=False,
name=None):
# pyformat: disable
r"""A template for edge convolutions.
Expand Down Expand Up @@ -219,7 +227,9 @@ def edge_convolution_template(data,
in the equation above. For 'max' the reduction is a max over features in
which case the weights $$w_{ij}$$ are ignored.
edge_function_kwargs: A dict containing any additional keyword arguments to
be passed to `edge_function`.
be passed to `edge_function`.
ordered: if True, neighbors is assumed to be ordered in the canonical,
row-major ordering, e.g. via `tf.sparse.reorder`
name: A name for this op. Defaults to
`graph_convolution_edge_convolution_template`.

Expand Down Expand Up @@ -262,17 +272,29 @@ def edge_convolution_template(data,
**edge_function_kwargs)

if reduction == "weighted":
features = utils.partition_sums_2d(edge_features, adjacency_ind_0,
adjacency.values)
edge_features = tf.expand_dims(adjacency.values, axis=-1) * edge_features
if ordered:
features = tf.math.segment_sum(edge_features, adjacency_ind_0)
else:
features = tf.math.unsorted_segment_sum(
data=edge_features,
segment_ids=adjacency_ind_0,
num_segments=tf.reduce_max(adjacency_ind_0) + 1)
elif reduction == "max":
features = tf.math.segment_max(data=edge_features,
segment_ids=adjacency_ind_0)
features.set_shape(features.shape.merge_with(
(tf.compat.v1.dimension_value(x_flat.shape[0]),
tf.compat.v1.dimension_value(edge_features.shape[-1]))))
if ordered:
features = tf.math.segment_max(
edge_features, adjacency_ind_0)
else:
features = tf.math.unsorted_segment_max(
data=edge_features,
segment_ids=adjacency_ind_0,
num_segments=tf.reduce_max(adjacency_ind_0) + 1)
else:
raise ValueError("The reduction method must be 'weighted' or 'max'")

features.set_shape(features.shape.merge_with(
(tf.compat.v1.dimension_value(x_flat.shape[0]),
tf.compat.v1.dimension_value(edge_features.shape[-1]))))
if data_ndims > 2:
features = unflatten(features)
return features
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,38 @@
from tensorflow_graphics.util import test_case


def _dense_to_sparse(data):
def _shuffle_sparse(st):
indices = st.indices
values = st.values
order = tf.range(tf.shape(indices, out_type=tf.int64)[-2])
order = tf.random.shuffle(order)
indices = tf.gather(indices, order, axis=-2)
values = tf.gather(values, order, axis=0)
return tf.SparseTensor(indices, values, st.shape)


def _dense_to_sparse(data, ordered=True):
"""Convert a numpy array to a tf.SparseTensor."""
indices = np.where(data)
return tf.SparseTensor(
st = tf.SparseTensor(
np.stack(indices, axis=-1), data[indices], dense_shape=data.shape)
if not ordered:
st = _shuffle_sparse(st)
return st


def _dummy_data(batch_size, num_vertices, num_channels):
def _dummy_data(batch_size, num_vertices, num_channels, ordered=True):
"""Create inputs for feature_steered_convolution."""
if batch_size > 0:
data = np.zeros(
shape=(batch_size, num_vertices, num_channels), dtype=np.float32)
neighbors = _dense_to_sparse(
np.tile(np.eye(num_vertices, dtype=np.float32), (batch_size, 1, 1)))
np.tile(np.eye(num_vertices, dtype=np.float32), (batch_size, 1, 1)),
ordered=ordered)
else:
data = np.zeros(shape=(num_vertices, num_channels), dtype=np.float32)
neighbors = _dense_to_sparse(np.eye(num_vertices, dtype=np.float32))
neighbors = _dense_to_sparse(
np.eye(num_vertices, dtype=np.float32), ordered=ordered)
return data, neighbors


Expand All @@ -64,7 +79,8 @@ def _random_data(batch_size,
only_self_edges,
data_type=np.float32,
neighbors_type=np.float32,
sizes_type=np.int32):
sizes_type=np.int32,
ordered=True):
"""Create random inputs for feature_steered_convolution."""

def _random_data_2d(padding):
Expand Down Expand Up @@ -94,15 +110,15 @@ def _random_data_2d(padding):
neighbors = np.stack([i[1] for i in list_2d], 0).astype(neighbors_type)
if padding:
sizes = np.stack([i[2] for i in list_2d], 0).astype(sizes_type)
return data, _dense_to_sparse(neighbors), sizes
return data, _dense_to_sparse(neighbors, ordered=ordered), sizes
else:
return data, _dense_to_sparse(neighbors)
return data, _dense_to_sparse(neighbors, ordered=ordered)
else:
if padding:
raise ValueError("Padding only allowed with batched data.")
data, neighbors = _random_data_2d(padding=False)
return data.astype(data_type), _dense_to_sparse(
neighbors.astype(neighbors_type))
neighbors.astype(neighbors_type), ordered=ordered)


def _random_variables(in_channels,
Expand Down Expand Up @@ -358,21 +374,26 @@ def test_feature_steered_convolution_padding_preset(self, data, neighbors, u,
self.assertAllClose(y, expected)

@parameterized.parameters(
(1, 5, 1, 1, 1),
(2, 6, 3, 6, 5),
(5, 15, 6, 12, 8),
(1, 5, 1, 1, 1, True),
(2, 6, 3, 6, 5, True),
(5, 15, 6, 12, 8, True),
(1, 5, 1, 1, 1, False),
(2, 6, 3, 6, 5, False),
(5, 15, 6, 12, 8, False),
)
def test_feature_steered_convolution_padding_random(self, batch_size,
num_vertices, in_channels,
out_channels,
num_weight_matrices):
num_weight_matrices,
ordered):
"""Test mixed topology batches (random vertices and neighbors)."""
data, neighbors, sizes = _random_data(
batch_size,
num_vertices,
in_channels,
padding=True,
only_self_edges=False)
only_self_edges=False,
ordered=ordered)
u, v, c, w, b = _random_variables(in_channels, out_channels,
num_weight_matrices)

Expand Down Expand Up @@ -509,7 +530,7 @@ def _pass_through(self, vertex_features, neighbor_features):
"""A callable for `edge_convolution_template`."""
return neighbor_features

def _circular_2d_data(self, num_vertices, include_normals=False):
def _circular_2d_data(self, num_vertices, include_normals=False, ordered=True):
"""Create data for a circle graph."""
# Vertices are points distributed uniformly on a circle, with each point
# connected to its closest neighbor on either side.
Expand All @@ -520,7 +541,7 @@ def _circular_2d_data(self, num_vertices, include_normals=False):
eye = np.eye(num_vertices)
neighbors = np.maximum(np.roll(eye, 1, axis=1), np.roll(eye, -1,
axis=1)) * 0.5
return data, _dense_to_sparse(neighbors)
return data, _dense_to_sparse(neighbors, ordered=ordered)

def _edge_curvature_2d(self, vertex_features, neighbor_features):
"""A callable for `edge_convolution_template` that computes curvature."""
Expand Down Expand Up @@ -664,16 +685,35 @@ def test_edge_convolution_template_output_shape(self, batch_size,
self.assertAllEqual(y_shape[:-1], data.shape[:-1])

@parameterized.parameters(
(1, 10, 3, True, "weighted"),
(3, 6, 1, True, "weighted"),
(0, 10, 5, False, "weighted"),
(1, 10, 3, False, "max"),
(3, 6, 1, False, "max"),
(0, 10, 5, False, "max"),
(1, 10, 3, True, "weighted", True),
(3, 6, 1, True, "weighted", True),
(0, 10, 5, False, "weighted", True),
(1, 10, 3, False, "max", True),
(3, 6, 1, False, "max", True),
(0, 10, 5, False, "max", True),
(1, 10, 3, True, "weighted", True),
(3, 6, 1, True, "weighted", True),
(0, 10, 5, False, "weighted", True),
(1, 10, 3, False, "max", True),
(3, 6, 1, False, "max", True),
(0, 10, 5, False, "max", True),
(1, 10, 3, True, "weighted", False),
(3, 6, 1, True, "weighted", False),
(0, 10, 5, False, "weighted", False),
(1, 10, 3, False, "max", False),
(3, 6, 1, False, "max", False),
(0, 10, 5, False, "max", False),
(1, 10, 3, True, "weighted", False),
(3, 6, 1, True, "weighted", False),
(0, 10, 5, False, "weighted", False),
(1, 10, 3, False, "max", False),
(3, 6, 1, False, "max", False),
(0, 10, 5, False, "max", False),
)
def test_edge_convolution_template_jacobian_random(self, batch_size,
num_vertices, in_channels,
padding, reduction):
padding, reduction,
ordered):
"""Test the jacobian for random input data."""
random_data = _random_data(
batch_size,
Expand All @@ -682,7 +722,8 @@ def test_edge_convolution_template_jacobian_random(self, batch_size,
padding,
only_self_edges=False,
data_type=np.float64,
neighbors_type=np.float64)
neighbors_type=np.float64,
ordered=ordered)
data_init = random_data[0]
neighbors = random_data[1]
sizes = None if not padding else random_data[2]
Expand All @@ -694,7 +735,8 @@ def test_edge_convolution_template_jacobian_random(self, batch_size,
sizes=sizes,
edge_function=self._pass_through,
reduction=reduction,
edge_function_kwargs=dict())
edge_function_kwargs=dict(),
ordered=ordered)

self.assert_jacobian_is_correct(data, data_init, y)

Expand Down Expand Up @@ -753,7 +795,8 @@ def test_edge_convolution_template_jacobian_preset(self, num_vertices,

self.assert_jacobian_is_correct(data, data_init, y)

def test_edge_convolution_template_laplacian_smoothing(self):
@parameterized.parameters((True,), (False))
def test_edge_convolution_template_laplacian_smoothing(self, ordered):
r"""Test the expected result with laplacian smoothing.

Laplacian smoothing for meshes is defined as
Expand All @@ -780,7 +823,7 @@ def test_edge_convolution_template_laplacian_smoothing(self):

with self.subTest(name="circular_2d"):
num_vertices = 500
data, neighbors = self._circular_2d_data(num_vertices)
data, neighbors = self._circular_2d_data(num_vertices, ordered=ordered)

data_smoothed = gc.edge_convolution_template(
data=data,
Expand All @@ -794,7 +837,8 @@ def test_edge_convolution_template_laplacian_smoothing(self):

self.assertAllClose(data, data_smoothed_normalized)

def test_edge_convolution_template_curvature(self):
@parameterized.parameters((True,), (False))
def test_edge_convolution_template_curvature(self, ordered):
r"""Test the expected result with curvature.

(Approximate) curvature for meshes is defined as
Expand All @@ -811,7 +855,8 @@ def test_edge_convolution_template_curvature(self):
"""
# We can reuse `self._edge_curvature_2d` as the curvature functional.
num_vertices = 500
data, neighbors = self._circular_2d_data(num_vertices, include_normals=True)
data, neighbors = self._circular_2d_data(
num_vertices, include_normals=True, ordered=ordered)

data_curvature = gc.edge_convolution_template(
data=data,
Expand Down
32 changes: 18 additions & 14 deletions tensorflow_graphics/geometry/convolution/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,13 +456,15 @@ def convert_to_block_diag_2d(data,
return block_diag


def partition_sums_2d(data, group_ids, row_weights=None, name=None):
def partition_sums_2d(
data, group_ids, row_weights=None, is_sorted=False, name=None):
"""Sum over subsets of rows in a 2-D tensor.

Args:
data: 2-D tensor with shape `[D1, D2]`.
group_ids: 1-D `int` tensor with shape `[D1]`.
row_weights: 1-D tensor with shape `[D1]`. Can be `None`.
is_sorted: if True, group_ids are known to be sorted.
name: A name for this op. Defaults to 'utils_partition_sums_2d'.

Returns:
Expand All @@ -481,27 +483,29 @@ def partition_sums_2d(data, group_ids, row_weights=None, name=None):
raise TypeError("'group_ids' must be an integer tensor.")
elif group_ids.dtype != tf.int64:
group_ids = tf.cast(group_ids, dtype=tf.int64)
if row_weights is None:
row_weights = tf.ones_like(group_ids, dtype=data.dtype)
else:
row_weights = tf.convert_to_tensor(value=row_weights)

if row_weights.dtype != data.dtype:
raise TypeError("'data' and 'row_weights' must have the same type.")
shape.check_static(tensor=data, tensor_name="data", has_rank=2)
shape.check_static(tensor=group_ids, tensor_name="group_ids", has_rank=1)
shape.check_static(

if row_weights is None:
shape.compare_dimensions(
tensors=(data, group_ids),
tensor_names=("data", "group_ids"),
axes=0)
else:
row_weights = tf.convert_to_tensor(value=row_weights)
shape.check_static(
tensor=row_weights, tensor_name="row_weights", has_rank=1)
shape.compare_dimensions(
if row_weights.dtype != data.dtype:
raise TypeError("'data' and 'row_weights' must have the same type.")
shape.compare_dimensions(
tensors=(data, group_ids, row_weights),
tensor_names=("data", "group_ids", "row_weights"),
axes=0)
data = data * tf.expand_dims(row_weights, axis=-1)

num_rows = tf.size(input=group_ids, out_type=tf.int64)
sparse_indices = tf.stack((group_ids, tf.range(num_rows)), axis=1)
out_shape = (tf.reduce_max(input_tensor=group_ids) + 1, num_rows)
sparse = tf.SparseTensor(sparse_indices, row_weights, dense_shape=out_shape)
return tf.sparse.sparse_dense_matmul(sparse, data)
return tf.math.unsorted_segment_sum(
data, group_ids, tf.reduce_max(group_ids) + 1)


# API contains all public functions and classes.
Expand Down
6 changes: 3 additions & 3 deletions tensorflow_graphics/util/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,9 +367,9 @@ def compare_dimensions(tensors, axes, tensor_names=None):
tensor_names = _give_default_names(tensors, 'tensor')
if not tf.executing_eagerly():
dimensions = [
int(tensor.shape[axis]) if tensor.shape[axis].value is not None else 1
for tensor, axis in zip(tensors, axes)
]
1 if dim is None else dim for dim in
(tf.compat.v1.dimension_value(tensor.shape[axis])
for tensor, axis in zip(tensors, axes))]
else:
# In eager mode tensor.shape[axis].value doesn't exist if v2 behavior is
# enabled. Therefore we can use tf.shape() to return the shape as a tensor.
Expand Down