Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions tensorflow_graphics/geometry/convolution/graph_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,15 @@ def feature_steered_convolution(data,
x_sep = tf.gather(x_flat, adjacency_ind_1)
q_m_list = tf.unstack(weights_q, axis=-1)
w_m_list = tf.unstack(var_w, axis=0)

x_flat_shape = tf.shape(input=x_flat)
for q_m, w_m in zip(q_m_list, w_m_list):
# Compute `y_i_m = sum_{j in neighborhood(i)} q_m(x_i, x_j) * w_m * x_j`.
q_m = tf.expand_dims(q_m, axis=-1)
p_sum = utils.partition_sums_2d(q_m * x_sep, adjacency_ind_0,
adjacency.values)
p_sum = tf.math.unsorted_segment_sum(
data=(q_m * x_sep) * tf.expand_dims(adjacency.values, -1),
segment_ids=adjacency_ind_0,
num_segments=x_flat_shape[0])
y_i_m.append(tf.matmul(p_sum, w_m))
y_out = tf.add_n(inputs=y_i_m) + tf.reshape(var_b, [1, -1])
if data_ndims > 2:
Expand Down Expand Up @@ -262,17 +266,22 @@ def edge_convolution_template(data,
**edge_function_kwargs)

if reduction == "weighted":
features = utils.partition_sums_2d(edge_features, adjacency_ind_0,
adjacency.values)
edge_features_weighted = edge_features * tf.expand_dims(
adjacency.values, -1)
features = tf.math.unsorted_segment_sum(
data=edge_features_weighted,
segment_ids=adjacency_ind_0,
num_segments=tf.shape(input=x_flat)[0])
elif reduction == "max":
features = tf.math.segment_max(data=edge_features,
segment_ids=adjacency_ind_0)
features.set_shape(features.shape.merge_with(
(tf.compat.v1.dimension_value(x_flat.shape[0]),
tf.compat.v1.dimension_value(edge_features.shape[-1]))))
else:
raise ValueError("The reduction method must be 'weighted' or 'max'")

features.set_shape(features.shape.merge_with(
(tf.compat.v1.dimension_value(x_flat.shape[0]),
tf.compat.v1.dimension_value(edge_features.shape[-1]))))

if data_ndims > 2:
features = unflatten(features)
return features
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,54 @@ def test_edge_convolution_template_output_shape(self, batch_size,
with self.subTest(name="shape"):
self.assertAllEqual(y_shape[:-1], data.shape[:-1])

def test_edge_convolution_template_zero_neighbors(self):
"""Check that vertices with no neighbors map to zeros in the output."""
# We can reuse `self._edge_curvature_2d` as the curvature functional.
num_vertices = 500
data, neighbors = self._circular_2d_data(num_vertices, include_normals=True)

# Interleave the data with rows filled with random data, these rows will
# have no neighbors in the adjacency matrix so should map to all zeros in
# the output.
rows_odd = tf.expand_dims(
tf.range(start=1, limit=(2 * num_vertices), delta=2), -1)
rows_even = tf.expand_dims(
tf.range(start=0, limit=(2 * num_vertices + 1), delta=2), -1)
data_interleaved = tf.scatter_nd(
indices=rows_odd, updates=data,
shape=(2 * num_vertices + 1, tf.shape(input=data)[-1]))
random_data = tf.random.uniform(shape=(data.shape[0] + 1, data.shape[-1]),
dtype=data.dtype)
random_interleaved = tf.scatter_nd(
indices=rows_even, updates=random_data,
shape=(2 * num_vertices + 1, tf.shape(input=data)[-1]))
data_interleaved = data_interleaved + random_interleaved
neighbors_interleaved_indices = neighbors.indices * 2 + 1
neighbors_interleaved = tf.SparseTensor(
indices=neighbors_interleaved_indices,
values=neighbors.values,
dense_shape=(2 * num_vertices + 1, 2 * num_vertices + 1))

# Convolve the interleaved data.
data_curvature = gc.edge_convolution_template(
data=data_interleaved,
neighbors=neighbors_interleaved,
sizes=None,
edge_function=self._edge_curvature_2d,
reduction="weighted",
edge_function_kwargs=dict())

self.assertEqual(data_curvature.shape,
(2 * num_vertices + 1, 1))

# The rows corresponding to the original input data measure the curvature.
# The curvature at any point on a circle of radius 1 should be 1.
# The interleaved rows of random data should map to zeros in the output.
self.assertAllClose(data_curvature[1::2, :],
np.ones(shape=(num_vertices, 1)))
self.assertAllClose(data_curvature[::2, :],
np.zeros(shape=(num_vertices + 1, 1)))

@parameterized.parameters(
(1, 10, 3, True, "weighted"),
(3, 6, 1, True, "weighted"),
Expand Down
78 changes: 0 additions & 78 deletions tensorflow_graphics/geometry/convolution/tests/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,83 +648,5 @@ def test_convert_to_block_diag_2d_jacobian_random(self):
self.assert_jacobian_is_correct(sparse_val, sparse_val_init, y.values)


class UtilsPartitionSums2dTests(test_case.TestCase):

def _numpy_ground_truth(self, data, group_ids, row_weights):
"""Ground truth row sum in numpy."""
out_weighted = np.zeros([np.max(group_ids) + 1, data.shape[1]])
out_uniform = np.zeros_like(out_weighted)
for i, val in enumerate(group_ids):
out_uniform[val] += data[i]
out_weighted[val] += row_weights[i] * data[i]
return out_weighted, out_uniform

@parameterized.parameters(
("must have a rank of 2", ((3, 4, 5), (3,)), (tf.float32, tf.int32)),
("must have a rank of 2", ((3,), (3,)), (tf.float32, tf.int32)),
("must have a rank of 1", ((3, 3), (3, 1)), (tf.float32, tf.int32)),
("must have a rank of 1", ((3, 3), (3,), (3, 1)),
(tf.float32, tf.int32, tf.float32)),
)
def test_partition_sums_2d_input_shapes(self, error_msg, shapes, dtypes):
self.assert_exception_is_raised(utils.partition_sums_2d, error_msg, shapes,
dtypes)

@parameterized.parameters(
("'data' and 'row_weights' must have the same type.",
(tf.float32, tf.int32, tf.float64)),
("'group_ids' must be an integer tensor.",
(tf.float32, tf.float32, tf.float32)),
)
def test_partition_sums_2d_exception_raised_types(self, error_msg, dtypes):
"""Check the exceptions with invalid input types."""
x1 = tf.zeros(shape=(3, 3), dtype=dtypes[0])
x2 = tf.zeros(shape=(3), dtype=dtypes[1])
x3 = tf.zeros(shape=(3), dtype=dtypes[2])

with self.assertRaisesRegexp(TypeError, error_msg):
utils.partition_sums_2d(x1, x2, x3)

def test_partition_sums_2d_single_output_dim(self):
"""Test when there is a single output dimension."""
data = np.ones([10, 10])
group_ids = np.zeros([10], dtype=np.int32)
gt = np.sum(data, axis=0, keepdims=True)

self.assertAllEqual(gt, utils.partition_sums_2d(data, group_ids))

def test_partition_sums_2d_random(self):
"""Test with random inputs."""
data_shape = np.random.randint(low=50, high=100, size=2)
data = np.random.uniform(size=data_shape)
group_ids = np.random.randint(
low=0, high=np.random.randint(low=50, high=150), size=data_shape[0])
row_weights = np.random.uniform(size=group_ids.shape)

gt_weighted, gt_uniform = self._numpy_ground_truth(data, group_ids,
row_weights)
weighted_sum = utils.partition_sums_2d(
data, group_ids, row_weights=row_weights)
uniform_sum = utils.partition_sums_2d(data, group_ids, row_weights=None)

self.assertAllClose(weighted_sum, gt_weighted)
self.assertAllClose(uniform_sum, gt_uniform)

def test_partition_sums_2d_jacobian_random(self):
"""Test the jacobian with random inputs."""
data_shape = np.random.randint(low=5, high=10, size=2)
data_init = np.random.uniform(size=data_shape)
data = tf.convert_to_tensor(value=data_init)
group_ids = np.random.randint(
low=0, high=np.random.randint(low=50, high=150), size=data_shape[0])
row_weights_init = np.random.uniform(size=group_ids.shape)
row_weights = tf.convert_to_tensor(value=row_weights_init)

y = utils.partition_sums_2d(data, group_ids, row_weights)

self.assert_jacobian_is_correct(data, data_init, y)
self.assert_jacobian_is_correct(row_weights, row_weights_init, y)


if __name__ == "__main__":
test_case.main()
48 changes: 0 additions & 48 deletions tensorflow_graphics/geometry/convolution/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,53 +455,5 @@ def convert_to_block_diag_2d(data,
return block_diag


def partition_sums_2d(data, group_ids, row_weights=None, name=None):
"""Sum over subsets of rows in a 2-D tensor.

Args:
data: 2-D tensor with shape `[D1, D2]`.
group_ids: 1-D `int` tensor with shape `[D1]`.
row_weights: 1-D tensor with shape `[D1]`. Can be `None`.
name: A name for this op. Defaults to 'utils_partition_sums_2d'.

Returns:
A 2-D tensor with shape `[max(group_ids) + 1, D2]` where
`output[i, :] = sum(data[j, :] * weight[j] * 1(group_ids[j] == i)), 1(.)`
is the indicator function.

Raises:
ValueError: if the inputs have invalid dimensions or types.
"""
with tf.compat.v1.name_scope(name, "utils_partition_sums_2d",
[data, group_ids, row_weights]):
data = tf.convert_to_tensor(value=data)
group_ids = tf.convert_to_tensor(value=group_ids)
if not group_ids.dtype.is_integer:
raise TypeError("'group_ids' must be an integer tensor.")
elif group_ids.dtype != tf.int64:
group_ids = tf.cast(group_ids, dtype=tf.int64)
if row_weights is None:
row_weights = tf.ones_like(group_ids, dtype=data.dtype)
else:
row_weights = tf.convert_to_tensor(value=row_weights)

if row_weights.dtype != data.dtype:
raise TypeError("'data' and 'row_weights' must have the same type.")
shape.check_static(tensor=data, tensor_name="data", has_rank=2)
shape.check_static(tensor=group_ids, tensor_name="group_ids", has_rank=1)
shape.check_static(
tensor=row_weights, tensor_name="row_weights", has_rank=1)
shape.compare_dimensions(
tensors=(data, group_ids, row_weights),
tensor_names=("data", "group_ids", "row_weights"),
axes=0)

num_rows = tf.size(input=group_ids, out_type=tf.int64)
sparse_indices = tf.stack((group_ids, tf.range(num_rows)), axis=1)
out_shape = (tf.reduce_max(input_tensor=group_ids) + 1, num_rows)
sparse = tf.SparseTensor(sparse_indices, row_weights, dense_shape=out_shape)
return tf.sparse.sparse_dense_matmul(sparse, data)


# API contains all public functions and classes.
__all__ = []