From 42ee754cb7ec3f90329ad66a7542d0eaa373c4e1 Mon Sep 17 00:00:00 2001 From: Dominic Jack Date: Tue, 21 Jan 2020 15:28:13 +1000 Subject: [PATCH] sparse_matmul implementation, softmax stability, layer configs --- benchmarks/graph_convolution_benchmark.py | 134 ++++++++++++++++++ .../geometry/convolution/graph_convolution.py | 93 +++++++++--- .../tests/graph_convolution_test.py | 19 ++- .../nn/layer/graph_convolution.py | 51 +++++-- 4 files changed, 263 insertions(+), 34 deletions(-) create mode 100644 benchmarks/graph_convolution_benchmark.py diff --git a/benchmarks/graph_convolution_benchmark.py b/benchmarks/graph_convolution_benchmark.py new file mode 100644 index 000000000..8a1d41339 --- /dev/null +++ b/benchmarks/graph_convolution_benchmark.py @@ -0,0 +1,134 @@ +from absl import logging +from absl import app, flags +import numpy as np +import tensorflow as tf +from tensorflow_graphics.nn.layer.graph_convolution import \ + FeatureSteeredConvolutionKerasLayer +from tensorflow_graphics.geometry.convolution.graph_convolution import \ + SparseImplementation + +logging.info('Finished imports') + +Lambda = tf.keras.layers.Lambda +Input = tf.keras.Input + +flags.DEFINE_boolean('jit', default=False, help='use XLA jit compilation') +flags.DEFINE_boolean('sparse', default=False, help='use sparse implementation') +flags.DEFINE_boolean('sort', default=False, help='use sorted indices') +flags.DEFINE_boolean( + 'backward', default=False, help='benchmark forward and backward pass') +flags.DEFINE_integer( + 'burn_iters', default=10, help='number of burn in iterations') +flags.DEFINE_integer('nv', default=100000, help='number of vertices') +flags.DEFINE_integer('ne', + default=-1, + help='number of edges, -1 will result in using 10*nv') +flags.DEFINE_integer('min_iters', + default=20, + help='minimum number of iterations to benchmark') +flags.DEFINE_integer( + 'num_weight_matrices', default=8, help='number of weight matrices') +flags.DEFINE_integer( + 'num_output_channels', default=32, help='number of output channels') +flags.DEFINE_integer('num_layers', default=10, help='number of layers') + + +def summarize(result, print_fn=print): + """ + Args: + result: output of a tf.test.Benchmark.run_op_benchmark call. + print_fn: print-like function. + """ + print_fn('Wall time (ms): {}'.format(result['wall_time'] * 1000)) + gpu_mem = result['extras'].get('allocator_maximum_num_bytes_GPU_0_bfc', 0) + print_fn('Memory (Mb): {}'.format(gpu_mem / 1024**2)) + + +def get_data(num_vertices, num_edges, sort=True): + if num_edges == -1: + num_edges = 10 * num_vertices + vertices = np.random.uniform(size=(num_vertices, 3)).astype(np.float32) + # replace=False below gives memory issues + indices = np.random.choice(num_vertices**2, num_edges, replace=True) + if sort: + indices.sort() + i, j = np.unravel_index(indices, (num_vertices, num_vertices)) # pylint: disable=unbalanced-tuple-unpacking + + counts = np.zeros((num_vertices,), dtype=np.int64) + for ii in i: + counts[ii] += 1 + weights = (1. / counts)[i].astype(np.float32) + indices = np.stack((i, j), axis=-1) + + return vertices, indices, weights + + +def main(_): + FLAGS = flags.FLAGS + tf.config.optimizer.set_jit(FLAGS.jit) + tf.keras.backend.clear_session() + vertices, indices, weights = get_data(FLAGS.nv, FLAGS.ne, sort=FLAGS.sort) + nv = vertices.shape[0] + + with tf.Graph().as_default(): + vertices = tf.constant(vertices, dtype=tf.float32) + indices = tf.constant(indices, dtype=tf.int64) + weights = tf.constant(weights, dtype=tf.float32) + # batch size of 1 + vertices, indices, weights = ( + tf.expand_dims(t, axis=0) for t in (vertices, indices, weights)) + + data = Input(tensor=vertices) + indices = Input(tensor=indices) + weights = Input(tensor=weights) + inputs = (data, indices, weights) + data, indices, weights = tuple( + Lambda(tf.squeeze, arguments=dict(axis=0), name='squeeze{}'.format(i))(t) + for i, t in enumerate(inputs)) + nv = Lambda(lambda x: tf.shape(x, out_type=tf.int64)[0])(data) + + neighbors = Lambda( + lambda args: tf.SparseTensor(args[0], args[1], (args[2], args[2])))( + [indices, weights, nv]) + + for _ in range(FLAGS.num_layers): + layer = FeatureSteeredConvolutionKerasLayer( + sparse_impl=( + SparseImplementation.SPARSE_MATMUL if FLAGS.sparse else + SparseImplementation.GATHER_SUM), + num_weight_matrices=FLAGS.num_weight_matrices, + num_output_channels=FLAGS.num_output_channels) + data = layer([data, neighbors]) + data = tf.nn.relu(data) + + pred = data + output = Lambda(tf.expand_dims, arguments=dict(axis=0))(pred) + model = tf.keras.Model(inputs=inputs, outputs=output) + + if FLAGS.backward: + loss = tf.reduce_sum(pred) + optimizer = tf.keras.optimizers.SGD() + + model_weights = model.trainable_weights + + grads = optimizer.get_gradients(loss, model_weights) + grads_and_vars = tuple(zip(grads, model_weights)) + train_op = optimizer.apply_gradients(grads_and_vars) + else: + train_op = pred + + bm = tf.test.Benchmark() + with tf.compat.v1.Session() as sess: + logging.info('Initializing variables...') + + sess.run(tf.compat.v1.global_variables_initializer()) + + logging.info('Starting benchmarking...') + result = bm.run_op_benchmark( + sess, train_op, burn_iters=FLAGS.burn_iters, + min_iters=FLAGS.min_iters) + summarize(result) + + +if __name__ == '__main__': + app.run(main) diff --git a/tensorflow_graphics/geometry/convolution/graph_convolution.py b/tensorflow_graphics/geometry/convolution/graph_convolution.py index 8ec7e2b72..640a62f27 100644 --- a/tensorflow_graphics/geometry/convolution/graph_convolution.py +++ b/tensorflow_graphics/geometry/convolution/graph_convolution.py @@ -24,6 +24,48 @@ from tensorflow_graphics.util import shape +class SparseImplementation(object): + GATHER_SUM = 'gather_sum' + SPARSE_MATMUL = 'sparse_matmul' + + @classmethod + def all(cls): + return (SparseImplementation.GATHER_SUM, SparseImplementation.SPARSE_MATMUL) + + @classmethod + def validate(cls, key): + if key not in cls.all(): + raise ValueError('Invalid SparseImplementation Key %s.' % key) + + +def _gather_sum(adjacency_ind_0, adjacency_ind_1, adjacency_values, weights_q, + x_flat): + q_m_list = tf.unstack(weights_q, axis=-1) + p_sums = [] + + x_sep = tf.gather(x_flat, adjacency_ind_1) + num_segments = tf.shape(input=x_flat)[0] + for q_m in q_m_list: + # Compute `y_i_m = sum_{j in neighborhood(i)} q_m(x_i, x_j) * w_m * x_j`. + q_m = tf.expand_dims(q_m, axis=-1) + p_sums.append(tf.math.unsorted_segment_sum( + data=(q_m * x_sep) * tf.expand_dims(adjacency_values, -1), + segment_ids=adjacency_ind_0, + num_segments=num_segments)) + return p_sums + + +def _sparse_matmul( + adjacency_indices, adjacency_values, dense_shape, weights_q, x_flat): + p_sums = [] + q_m_list = tf.unstack(weights_q, axis=-1) + for q_m in q_m_list: + sp = tf.SparseTensor( + adjacency_indices, adjacency_values * q_m, dense_shape) + p_sums.append(tf.sparse.sparse_dense_matmul(sp, x_flat)) + return p_sums + + def feature_steered_convolution(data, neighbors, sizes, @@ -32,7 +74,8 @@ def feature_steered_convolution(data, var_c, var_w, var_b, - name=None): + name=None, + sparse_impl=SparseImplementation.GATHER_SUM): # pyformat: disable """Implements the Feature Steered graph convolution. @@ -42,7 +85,8 @@ def feature_steered_convolution(data, https://arxiv.org/abs/1706.05206 The shorthands used below are - `V`: The number of vertices. + `Vi`: The number of vertices in the input. + `Vo`: The number of vertices in the output. `C`: The number of channels in the input data. `D`: The number of channels in the output after convolution. `W`: The number of weight matrices used in the convolution. @@ -53,9 +97,9 @@ def feature_steered_convolution(data, In the following, A1 to An are optional batch dimensions. Args: - data: A `float` tensor with shape `[A1, ..., An, V, C]`. + data: A `float` tensor with shape `[A1, ..., An, Vi, C]`. neighbors: A `SparseTensor` with the same type as `data` and with shape - `[A1, ..., An, V, V]` representing vertex neighborhoods. The neighborhood + `[A1, ..., An, Vo, Vi]` representing vertex neighborhoods. The neighborhood of a vertex defines the support region for convolution. For a mesh, a common choice for the neighborhood of vertex i would be the vertices in the K-ring of i (including i itself). Each vertex must have at least one @@ -81,17 +125,21 @@ def feature_steered_convolution(data, var_c: A 1-D tensor with shape `[W]`. var_w: A 3-D tensor with shape `[W, C, D]`. var_b: A 1-D tensor with shape `[D]`. + sparse_impl: `SparseImplementation` prop, ('gather_sum', 'sparse_matmul'). + This influences the computational performance. Results should be the same + except for floating point errors. name: A name for this op. Defaults to `graph_convolution_feature_steered_convolution`. Returns: - Tensor with shape `[A1, ..., An, V, D]`. + Tensor with shape `[A1, ..., An, Vo, D]`. Raises: TypeError: if the input types are invalid. ValueError: if the input dimensions are invalid. """ # pyformat: enable + SparseImplementation.validate(sparse_impl) with tf.compat.v1.name_scope( name, "graph_convolution_feature_steered_convolution", [data, neighbors, sizes, var_u, var_v, var_c, var_w, var_b]): @@ -131,28 +179,33 @@ def feature_steered_convolution(data, adjacency = neighbors x_u = tf.matmul(x_flat, var_u) x_v = tf.matmul(x_flat, var_v) - adjacency_ind_0 = adjacency.indices[:, 0] - adjacency_ind_1 = adjacency.indices[:, 1] + adjacency_ind_0, adjacency_ind_1 = tf.unstack(adjacency.indices, axis=1) x_u_rep = tf.gather(x_u, adjacency_ind_0) x_v_sep = tf.gather(x_v, adjacency_ind_1) - weights_q = tf.exp(x_u_rep + x_v_sep + tf.reshape(var_c, (1, -1))) - weights_q_sum = tf.reduce_sum( + logits = x_u_rep + x_v_sep + tf.reshape(var_c, (1, -1)) + # numerically stable softmax + # compared to tf.nn.softmax, this gives better performance when JIT compiled + logits = logits - tf.reduce_max(logits, axis=-1, keepdims=True) + weights_q = tf.exp(logits) + weights_q = weights_q / tf.reduce_sum( input_tensor=weights_q, axis=-1, keepdims=True) - weights_q = weights_q / weights_q_sum - y_i_m = [] - x_sep = tf.gather(x_flat, adjacency_ind_1) - q_m_list = tf.unstack(weights_q, axis=-1) + + if sparse_impl == SparseImplementation.GATHER_SUM: + p_sums = _gather_sum( + adjacency_ind_0, adjacency_ind_1, adjacency.values, weights_q, x_flat) + else: + assert(sparse_impl == SparseImplementation.SPARSE_MATMUL) + p_sums = _sparse_matmul( + adjacency.indices, adjacency.values, adjacency.dense_shape, weights_q, + x_flat) + w_m_list = tf.unstack(var_w, axis=0) + y_i_m = [] - x_flat_shape = tf.shape(input=x_flat) - for q_m, w_m in zip(q_m_list, w_m_list): + for p_sum, w_m in zip(p_sums, w_m_list): # Compute `y_i_m = sum_{j in neighborhood(i)} q_m(x_i, x_j) * w_m * x_j`. - q_m = tf.expand_dims(q_m, axis=-1) - p_sum = tf.math.unsorted_segment_sum( - data=(q_m * x_sep) * tf.expand_dims(adjacency.values, -1), - segment_ids=adjacency_ind_0, - num_segments=x_flat_shape[0]) y_i_m.append(tf.matmul(p_sum, w_m)) + y_out = tf.add_n(inputs=y_i_m) + tf.reshape(var_b, [1, -1]) if data_ndims > 2: y_out = unflatten(y_out) diff --git a/tensorflow_graphics/geometry/convolution/tests/graph_convolution_test.py b/tensorflow_graphics/geometry/convolution/tests/graph_convolution_test.py index 7d3c0ac85..d30685fe8 100644 --- a/tensorflow_graphics/geometry/convolution/tests/graph_convolution_test.py +++ b/tensorflow_graphics/geometry/convolution/tests/graph_convolution_test.py @@ -333,13 +333,23 @@ def test_feature_steered_convolution_only_self_edges(self, batch_size, @parameterized.parameters( (((1.0,), (2.0,), (3.0,)), np.ones(shape=(3, 3)) / 3.0, ((0.5,),), - ((1.3,),), (-0.7,), (((0.8,),),), (3.0,), ((4.6,), (4.6,), (4.6,))), + ((1.3,),), (-0.7,), (((0.8,),),), (3.0,), ((4.6,), (4.6,), (4.6,)), + 'gather_sum'), (((1.0,), (2.0,), (3.0,)), np.ones(shape=(3, 3)) / 3.0, ((0.5, 0.2),), ((0.3, 0.4),), (-0.7, 0.15), (((0.8,),), ((1.1,),)), (3.0,), - ((5.011706928844621,), (4.971030281984818,), (4.927388658982911,))), + ((5.011706928844621,), (4.971030281984818,), (4.927388658982911,)), + 'gather_sum'), + (((1.0,), (2.0,), (3.0,)), np.ones(shape=(3, 3)) / 3.0, ((0.5,),), + ((1.3,),), (-0.7,), (((0.8,),),), (3.0,), ((4.6,), (4.6,), (4.6,)), + 'sparse_matmul'), + (((1.0,), (2.0,), (3.0,)), np.ones(shape=(3, 3)) / 3.0, ((0.5, 0.2),), + ((0.3, 0.4),), (-0.7, 0.15), (((0.8,),), ((1.1,),)), (3.0,), + ((5.011706928844621,), (4.971030281984818,), (4.927388658982911,)), + 'sparse_matmul'), ) def test_feature_steered_convolution_padding_preset(self, data, neighbors, u, - v, c, w, b, expected): + v, c, w, b, expected, + sparse_impl): """Test expected result for preset data and filter values.""" array = (np.array(i) for i in (data, neighbors, expected)) data, neighbors, expected = array @@ -354,7 +364,8 @@ def test_feature_steered_convolution_padding_preset(self, data, neighbors, u, var_v=v, var_c=c, var_w=w, - var_b=b) + var_b=b, + sparse_impl=sparse_impl) self.assertAllClose(y, expected) @parameterized.parameters( diff --git a/tensorflow_graphics/nn/layer/graph_convolution.py b/tensorflow_graphics/nn/layer/graph_convolution.py index b4da09d8c..9ac9cae42 100644 --- a/tensorflow_graphics/nn/layer/graph_convolution.py +++ b/tensorflow_graphics/nn/layer/graph_convolution.py @@ -31,6 +31,7 @@ def feature_steered_convolution_layer( num_weight_matrices=8, num_output_channels=None, initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.1), + sparse_impl=gc.SparseImplementation.GATHER_SUM, name=None, var_name=None): # pyformat: disable @@ -135,6 +136,7 @@ def feature_steered_convolution_layer( var_c=var_c, var_w=var_w, var_b=var_b, + sparse_impl=sparse_impl, name=name) @@ -146,6 +148,7 @@ def __init__(self, num_weight_matrices=8, num_output_channels=None, initializer=None, + sparse_impl=gc.SparseImplementation.GATHER_SUM, name=None, **kwargs): """Initializes FeatureSteeredConvolutionKerasLayer. @@ -171,7 +174,19 @@ def __init__(self, if initializer is None: self._initializer = tf.compat.v1.truncated_normal_initializer(stddev=0.1) else: - self._initializer = initializer + self._initializer = tf.keras.initializers.get(initializer) + gc.SparseImplementation.validate(sparse_impl) + self._sparse_impl = sparse_impl + + def get_config(self): + config = super(FeatureSteeredConvolutionKerasLayer, self).get_config() + config.update(dict( + initializer=tf.keras.initializers.serialize(self._initializer), + num_weight_matrics=self._num_weight_matrics, + num_output_channels=self._num_output_channels, + sparse_impl=self._sparse_impl, + )) + return config def build(self, input_shape): """Initializes the trainable weights.""" @@ -265,7 +280,8 @@ def call(self, inputs, sizes=None): var_v=self.var_v, var_c=self.var_c, var_w=self.var_w, - var_b=self.var_b) + var_b=self.var_b, + sparse_impl=self._sparse_impl) class DynamicGraphConvolutionKerasLayer(tf.keras.layers.Layer): @@ -331,15 +347,30 @@ def __init__(self, name=name, **kwargs) self._num_output_channels = num_output_channels self._reduction = reduction - self._activation = activation + self._activation = tf.keras.activations.get(activation) self._use_bias = use_bias - self._kernel_initializer = kernel_initializer - self._bias_initializer = bias_initializer - self._kernel_regularizer = kernel_regularizer - self._bias_regularizer = bias_regularizer - self._activity_regularizer = activity_regularizer - self._kernel_constraint = kernel_constraint - self._bias_constraint = bias_constraint + self._kernel_initializer = tf.keras.initializers.get(kernel_initializer) + self._bias_initializer = tf.keras.initializers.get(bias_initializer) + self._kernel_regularizer = tf.keras.initializers.get(kernel_regularizer) + self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer) + self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer) + self._kernel_constraint = tf.keras.constraints.get(kernel_constraint) + self._bias_constraint = tf.keras.constraints.get(bias_constraint) + + def get_config(self): + config = super(DynamicGraphConvolutionKerasLayer, self).get_config() + config.update( + num_output_channels=self._num_output_channels, + reduction=self._reduction, + use_bias=self._use_bias, + kernel_initializer=tf.keras.initializers.serialize(self._kernel_initializer), + bias_initializer=tf.keras.initializers.serialize(self._bias_initializer), + kernel_regularizer=tf.keras.regularizers.serialize(self._kernel_regularizer), + bias_regularizer=tf.keras.regularizers.serialize(self._bias_regularizer), + kernel_constraint=tf.keras.constraints.serialize(self._kernel_constraint), + bias_constraint=tf.keras.constraints.serialize(self._bias_constraint), + ) + return config def build(self, input_shape): """Initializes the layer weights."""