diff --git a/tensorflow_graphics/geometry/convolution/graph_convolution.py b/tensorflow_graphics/geometry/convolution/graph_convolution.py index 344fb7292..8ec7e2b72 100644 --- a/tensorflow_graphics/geometry/convolution/graph_convolution.py +++ b/tensorflow_graphics/geometry/convolution/graph_convolution.py @@ -143,11 +143,15 @@ def feature_steered_convolution(data, x_sep = tf.gather(x_flat, adjacency_ind_1) q_m_list = tf.unstack(weights_q, axis=-1) w_m_list = tf.unstack(var_w, axis=0) + + x_flat_shape = tf.shape(input=x_flat) for q_m, w_m in zip(q_m_list, w_m_list): # Compute `y_i_m = sum_{j in neighborhood(i)} q_m(x_i, x_j) * w_m * x_j`. q_m = tf.expand_dims(q_m, axis=-1) - p_sum = utils.partition_sums_2d(q_m * x_sep, adjacency_ind_0, - adjacency.values) + p_sum = tf.math.unsorted_segment_sum( + data=(q_m * x_sep) * tf.expand_dims(adjacency.values, -1), + segment_ids=adjacency_ind_0, + num_segments=x_flat_shape[0]) y_i_m.append(tf.matmul(p_sum, w_m)) y_out = tf.add_n(inputs=y_i_m) + tf.reshape(var_b, [1, -1]) if data_ndims > 2: @@ -262,17 +266,22 @@ def edge_convolution_template(data, **edge_function_kwargs) if reduction == "weighted": - features = utils.partition_sums_2d(edge_features, adjacency_ind_0, - adjacency.values) + edge_features_weighted = edge_features * tf.expand_dims( + adjacency.values, -1) + features = tf.math.unsorted_segment_sum( + data=edge_features_weighted, + segment_ids=adjacency_ind_0, + num_segments=tf.shape(input=x_flat)[0]) elif reduction == "max": features = tf.math.segment_max(data=edge_features, segment_ids=adjacency_ind_0) - features.set_shape(features.shape.merge_with( - (tf.compat.v1.dimension_value(x_flat.shape[0]), - tf.compat.v1.dimension_value(edge_features.shape[-1])))) else: raise ValueError("The reduction method must be 'weighted' or 'max'") + features.set_shape(features.shape.merge_with( + (tf.compat.v1.dimension_value(x_flat.shape[0]), + tf.compat.v1.dimension_value(edge_features.shape[-1])))) + if data_ndims > 2: features = unflatten(features) return features diff --git a/tensorflow_graphics/geometry/convolution/tests/graph_convolution_test.py b/tensorflow_graphics/geometry/convolution/tests/graph_convolution_test.py index 1c9f92118..7d3c0ac85 100644 --- a/tensorflow_graphics/geometry/convolution/tests/graph_convolution_test.py +++ b/tensorflow_graphics/geometry/convolution/tests/graph_convolution_test.py @@ -663,6 +663,54 @@ def test_edge_convolution_template_output_shape(self, batch_size, with self.subTest(name="shape"): self.assertAllEqual(y_shape[:-1], data.shape[:-1]) + def test_edge_convolution_template_zero_neighbors(self): + """Check that vertices with no neighbors map to zeros in the output.""" + # We can reuse `self._edge_curvature_2d` as the curvature functional. + num_vertices = 500 + data, neighbors = self._circular_2d_data(num_vertices, include_normals=True) + + # Interleave the data with rows filled with random data, these rows will + # have no neighbors in the adjacency matrix so should map to all zeros in + # the output. + rows_odd = tf.expand_dims( + tf.range(start=1, limit=(2 * num_vertices), delta=2), -1) + rows_even = tf.expand_dims( + tf.range(start=0, limit=(2 * num_vertices + 1), delta=2), -1) + data_interleaved = tf.scatter_nd( + indices=rows_odd, updates=data, + shape=(2 * num_vertices + 1, tf.shape(input=data)[-1])) + random_data = tf.random.uniform(shape=(data.shape[0] + 1, data.shape[-1]), + dtype=data.dtype) + random_interleaved = tf.scatter_nd( + indices=rows_even, updates=random_data, + shape=(2 * num_vertices + 1, tf.shape(input=data)[-1])) + data_interleaved = data_interleaved + random_interleaved + neighbors_interleaved_indices = neighbors.indices * 2 + 1 + neighbors_interleaved = tf.SparseTensor( + indices=neighbors_interleaved_indices, + values=neighbors.values, + dense_shape=(2 * num_vertices + 1, 2 * num_vertices + 1)) + + # Convolve the interleaved data. + data_curvature = gc.edge_convolution_template( + data=data_interleaved, + neighbors=neighbors_interleaved, + sizes=None, + edge_function=self._edge_curvature_2d, + reduction="weighted", + edge_function_kwargs=dict()) + + self.assertEqual(data_curvature.shape, + (2 * num_vertices + 1, 1)) + + # The rows corresponding to the original input data measure the curvature. + # The curvature at any point on a circle of radius 1 should be 1. + # The interleaved rows of random data should map to zeros in the output. + self.assertAllClose(data_curvature[1::2, :], + np.ones(shape=(num_vertices, 1))) + self.assertAllClose(data_curvature[::2, :], + np.zeros(shape=(num_vertices + 1, 1))) + @parameterized.parameters( (1, 10, 3, True, "weighted"), (3, 6, 1, True, "weighted"), diff --git a/tensorflow_graphics/geometry/convolution/tests/utils_test.py b/tensorflow_graphics/geometry/convolution/tests/utils_test.py index de92ee62e..4ffc11885 100644 --- a/tensorflow_graphics/geometry/convolution/tests/utils_test.py +++ b/tensorflow_graphics/geometry/convolution/tests/utils_test.py @@ -648,83 +648,5 @@ def test_convert_to_block_diag_2d_jacobian_random(self): self.assert_jacobian_is_correct(sparse_val, sparse_val_init, y.values) -class UtilsPartitionSums2dTests(test_case.TestCase): - - def _numpy_ground_truth(self, data, group_ids, row_weights): - """Ground truth row sum in numpy.""" - out_weighted = np.zeros([np.max(group_ids) + 1, data.shape[1]]) - out_uniform = np.zeros_like(out_weighted) - for i, val in enumerate(group_ids): - out_uniform[val] += data[i] - out_weighted[val] += row_weights[i] * data[i] - return out_weighted, out_uniform - - @parameterized.parameters( - ("must have a rank of 2", ((3, 4, 5), (3,)), (tf.float32, tf.int32)), - ("must have a rank of 2", ((3,), (3,)), (tf.float32, tf.int32)), - ("must have a rank of 1", ((3, 3), (3, 1)), (tf.float32, tf.int32)), - ("must have a rank of 1", ((3, 3), (3,), (3, 1)), - (tf.float32, tf.int32, tf.float32)), - ) - def test_partition_sums_2d_input_shapes(self, error_msg, shapes, dtypes): - self.assert_exception_is_raised(utils.partition_sums_2d, error_msg, shapes, - dtypes) - - @parameterized.parameters( - ("'data' and 'row_weights' must have the same type.", - (tf.float32, tf.int32, tf.float64)), - ("'group_ids' must be an integer tensor.", - (tf.float32, tf.float32, tf.float32)), - ) - def test_partition_sums_2d_exception_raised_types(self, error_msg, dtypes): - """Check the exceptions with invalid input types.""" - x1 = tf.zeros(shape=(3, 3), dtype=dtypes[0]) - x2 = tf.zeros(shape=(3), dtype=dtypes[1]) - x3 = tf.zeros(shape=(3), dtype=dtypes[2]) - - with self.assertRaisesRegexp(TypeError, error_msg): - utils.partition_sums_2d(x1, x2, x3) - - def test_partition_sums_2d_single_output_dim(self): - """Test when there is a single output dimension.""" - data = np.ones([10, 10]) - group_ids = np.zeros([10], dtype=np.int32) - gt = np.sum(data, axis=0, keepdims=True) - - self.assertAllEqual(gt, utils.partition_sums_2d(data, group_ids)) - - def test_partition_sums_2d_random(self): - """Test with random inputs.""" - data_shape = np.random.randint(low=50, high=100, size=2) - data = np.random.uniform(size=data_shape) - group_ids = np.random.randint( - low=0, high=np.random.randint(low=50, high=150), size=data_shape[0]) - row_weights = np.random.uniform(size=group_ids.shape) - - gt_weighted, gt_uniform = self._numpy_ground_truth(data, group_ids, - row_weights) - weighted_sum = utils.partition_sums_2d( - data, group_ids, row_weights=row_weights) - uniform_sum = utils.partition_sums_2d(data, group_ids, row_weights=None) - - self.assertAllClose(weighted_sum, gt_weighted) - self.assertAllClose(uniform_sum, gt_uniform) - - def test_partition_sums_2d_jacobian_random(self): - """Test the jacobian with random inputs.""" - data_shape = np.random.randint(low=5, high=10, size=2) - data_init = np.random.uniform(size=data_shape) - data = tf.convert_to_tensor(value=data_init) - group_ids = np.random.randint( - low=0, high=np.random.randint(low=50, high=150), size=data_shape[0]) - row_weights_init = np.random.uniform(size=group_ids.shape) - row_weights = tf.convert_to_tensor(value=row_weights_init) - - y = utils.partition_sums_2d(data, group_ids, row_weights) - - self.assert_jacobian_is_correct(data, data_init, y) - self.assert_jacobian_is_correct(row_weights, row_weights_init, y) - - if __name__ == "__main__": test_case.main() diff --git a/tensorflow_graphics/geometry/convolution/utils.py b/tensorflow_graphics/geometry/convolution/utils.py index 9ca62284b..0a8cd42aa 100644 --- a/tensorflow_graphics/geometry/convolution/utils.py +++ b/tensorflow_graphics/geometry/convolution/utils.py @@ -455,53 +455,5 @@ def convert_to_block_diag_2d(data, return block_diag -def partition_sums_2d(data, group_ids, row_weights=None, name=None): - """Sum over subsets of rows in a 2-D tensor. - - Args: - data: 2-D tensor with shape `[D1, D2]`. - group_ids: 1-D `int` tensor with shape `[D1]`. - row_weights: 1-D tensor with shape `[D1]`. Can be `None`. - name: A name for this op. Defaults to 'utils_partition_sums_2d'. - - Returns: - A 2-D tensor with shape `[max(group_ids) + 1, D2]` where - `output[i, :] = sum(data[j, :] * weight[j] * 1(group_ids[j] == i)), 1(.)` - is the indicator function. - - Raises: - ValueError: if the inputs have invalid dimensions or types. - """ - with tf.compat.v1.name_scope(name, "utils_partition_sums_2d", - [data, group_ids, row_weights]): - data = tf.convert_to_tensor(value=data) - group_ids = tf.convert_to_tensor(value=group_ids) - if not group_ids.dtype.is_integer: - raise TypeError("'group_ids' must be an integer tensor.") - elif group_ids.dtype != tf.int64: - group_ids = tf.cast(group_ids, dtype=tf.int64) - if row_weights is None: - row_weights = tf.ones_like(group_ids, dtype=data.dtype) - else: - row_weights = tf.convert_to_tensor(value=row_weights) - - if row_weights.dtype != data.dtype: - raise TypeError("'data' and 'row_weights' must have the same type.") - shape.check_static(tensor=data, tensor_name="data", has_rank=2) - shape.check_static(tensor=group_ids, tensor_name="group_ids", has_rank=1) - shape.check_static( - tensor=row_weights, tensor_name="row_weights", has_rank=1) - shape.compare_dimensions( - tensors=(data, group_ids, row_weights), - tensor_names=("data", "group_ids", "row_weights"), - axes=0) - - num_rows = tf.size(input=group_ids, out_type=tf.int64) - sparse_indices = tf.stack((group_ids, tf.range(num_rows)), axis=1) - out_shape = (tf.reduce_max(input_tensor=group_ids) + 1, num_rows) - sparse = tf.SparseTensor(sparse_indices, row_weights, dense_shape=out_shape) - return tf.sparse.sparse_dense_matmul(sparse, data) - - # API contains all public functions and classes. __all__ = []