From 53d60da7965e0fd18ce36d4cada4afb1fb908274 Mon Sep 17 00:00:00 2001
From: Dominic Jack
Date: Fri, 27 Sep 2019 15:40:20 +1000
Subject: [PATCH 1/3] added different feature_steered_convolution
 implementations and benchmarks

---
 .../feature_steered_conv_benchmark.py         | 126 ++++++
 .../feature_steered_model_benchmark.py        | 252 +++++++++++
 .../geometry/convolution/graph_convolution.py | 410 +++++++++++++++---
 .../tests/graph_convolution_test.py           |  22 +-
 .../nn/layer/graph_convolution.py             |  33 +-
 5 files changed, 765 insertions(+), 78 deletions(-)
 create mode 100644 tensorflow_graphics/benchmarks/feature_steered_conv_benchmark.py
 create mode 100644 tensorflow_graphics/benchmarks/feature_steered_model_benchmark.py

diff --git a/tensorflow_graphics/benchmarks/feature_steered_conv_benchmark.py b/tensorflow_graphics/benchmarks/feature_steered_conv_benchmark.py
new file mode 100644
index 000000000..7e5046eaf
--- /dev/null
+++ b/tensorflow_graphics/benchmarks/feature_steered_conv_benchmark.py
@@ -0,0 +1,126 @@
+"""
+Benchmarking script for various feature_steered_convolution implementations.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function

+
+import functools
+import numpy as np
+import tensorflow as tf
+import tensorflow_graphics.geometry.convolution.graph_convolution as gc
+from tensorflow_graphics.geometry.convolution.tests.graph_convolution_test \
+    import _random_data, _random_variables
+
+from absl import app
+from absl import flags
+
+flags.DEFINE_integer('batch_size', 8, help='size of batch')
+flags.DEFINE_integer('num_vertices', 500, help='number of vertices')
+flags.DEFINE_integer('in_channels', 32, help='number of input channels')
+flags.DEFINE_integer('out_channels', 32, help='number of output channels')
+flags.DEFINE_integer('num_filters',
+                     8,
+                     help='number of filters (W, or M in paper)')
+flags.DEFINE_float('sparsity', 0.25, help='sparsity of neighbors')
+flags.DEFINE_bool('mem_only',
+                  default=False,
+                  help='memory efficient implementations only')
+
+FLAGS = flags.FLAGS
+
+
+def main(args):
+  random_state = np.random.RandomState(123)
+
+  data, neighbors = _random_data(FLAGS.batch_size,
+                                 FLAGS.num_vertices,
+                                 FLAGS.in_channels,
+                                 padding=False,
+                                 only_self_edges=False,
+                                 sparsity=FLAGS.sparsity,
+                                 random_state=random_state)
+  sizes = None
+  data = tf.convert_to_tensor(value=data, dtype=tf.float32)
+
+  u, v, c, w, b = _random_variables(FLAGS.in_channels,
+                                    FLAGS.out_channels,
+                                    FLAGS.num_filters,
+                                    random_state=random_state)
+
+  # v1_p2d is the original implementation
+  fast = dict(memory_efficient=False)
+  names, kwargs = zip(*(
+      ('v1_p2d', dict(version='v1', segment_sum_impl='partition2d', **fast)),
+      ('v1_sorted', dict(version='v1', segment_sum_impl='sorted', **fast)),
+      ('v1_unsorted', dict(version='v1', segment_sum_impl='unsorted', **fast)),
+      ('v1_p2d_mem', dict(version='v1', segment_sum_impl='partition2d')),
+      ('v1_sorted_mem', dict(version='v1', segment_sum_impl='sorted')),
+      ('v1_unsorted_mem', dict(version='v1', segment_sum_impl='unsorted')),
+      ('v2_default', dict(version='v2')),  # will be same as one of the below
+      ('v2_first', dict(version='v2', transform_data_first=True)),
+      ('v2_last', dict(version='v2', transform_data_first=False)),
+      ('v3', dict(version='v3', **fast)),
+      ('v3_mem', dict(version='v3')),
+  ))
+  if FLAGS.mem_only:
+    names, kwargs = zip(*(
+        (name, kw) for name, kw in zip(names, kwargs) if 'mem' in name))
+
+  vals = [
+      gc.feature_steered_convolution(data, neighbors, sizes, u, v, c, w, b,
+                                     
**kw) for kw in kwargs + ] + grads = [ + tf.gradients(val, (data, neighbors.values, u, v, c, w, b)) for val in vals + ] + + errs = [tf.reduce_max(tf.abs(val - vals[0])) for val in vals[1:]] + + with tf.Session() as sess: + errs = sess.run(errs) + + times = [] + memories = [] + for name, v, g in zip(names, vals, grads): + print('------------') + print(name) + bm = tf.test.Benchmark() + result = bm.run_op_benchmark(sess, (v, g)) + + times.append(result['wall_time']) + memories.append(result['extras']['allocator_maximum_num_bytes_GPU_0_bfc']) + + print('*************') + print('** SUMMARY **') + print('*************') + print('{:15s}: {}'.format('batch_size', FLAGS.batch_size)) + print('{:15s}: {}'.format('num_vertices', FLAGS.num_vertices)) + print('{:15s}: {}'.format('in_channels', FLAGS.in_channels)) + print('{:15s}: {}'.format('out_channels', FLAGS.out_channels)) + print('{:15s}: {}'.format('num_filters', FLAGS.num_filters)) + print('{:15s}: {}'.format('sparsity', FLAGS.sparsity)) + + times = np.array(times) + # ti = np.argmin(times) + ti = 0 + tmin = times[ti] + print('Baseline time: {}, {}s'.format(names[ti], tmin)) + print('rel times:') + for name, time in zip(names, times): + print('{:15s} {:.3f}'.format(name, time / tmin)) + memories = np.array(memories) + # mi = np.argmin(memories) + mi = 0 + mmin = memories[mi] + print('Baseline memory: {}, {}mb'.format(names[mi], mmin / 1024**2)) + for name, memory in zip(names, memories): + print('{:15s} {:.3f}'.format(name, memory / mmin)) + + print('Errors w.r.t {}'.format(names[0])) + for name, err in zip(names[1:], errs): + print('{:10s}: {}'.format(name, err)) + + +if __name__ == '__main__': + app.run(main) diff --git a/tensorflow_graphics/benchmarks/feature_steered_model_benchmark.py b/tensorflow_graphics/benchmarks/feature_steered_model_benchmark.py new file mode 100644 index 000000000..5129c804c --- /dev/null +++ b/tensorflow_graphics/benchmarks/feature_steered_model_benchmark.py @@ -0,0 +1,252 @@ +""" +Benchmarking script for various feature_steered_convolution implementations. + +Runs training operation on models from `notebooks/mesh_segmentation_demo.ipynb` +with differing convolution kwargs. Reported memory usage/timings are for the +entire model, not just the convolution implementations. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import os +import tensorflow as tf + +from tensorflow_graphics.geometry.convolution import utils +from tensorflow_graphics.nn.layer import graph_convolution as graph_conv +from tensorflow_graphics.notebooks import mesh_segmentation_dataio as dataio + +from absl import app +from absl import flags + +FLAGS = flags.FLAGS + +flags.DEFINE_integer( + 'num_filters', help='number of filters (M in paper, W in code)', default=8) +flags.DEFINE_bool( + 'mem_only', help='memory efficient implementations only', default=False) + +path_to_data_zip = tf.keras.utils.get_file( + 'data.zip', + origin='https://storage.googleapis.com/tensorflow-graphics/notebooks/mesh_segmentation/data.zip', + extract=True) + +test_data_files = [ + os.path.join( + os.path.dirname(path_to_data_zip), + 'data/Dancer_test_sequence.tfrecords') +] + +MODEL_PARAMS = { + 'num_filters': 8, + 'num_classes': 16, + 'encoder_filter_dims': [32, 64, 128], + 'learning_rate': 1e-3, + 'beta': 0.9, + 'adam_epsilon': 1e-8, + 'preprocess_neighbors': True +} + + +def mesh_encoder( + batch_mesh_data, num_filters, output_dim, conv_layer_dims, conv_kwargs, + preprocess_neighbors=True): + """A mesh encoder using feature steered graph convolutions. + + The shorthands used below are + `B`: Batch size. + `V`: The maximum number of vertices over all meshes in the batch. + `D`: The number of dimensions of input vertex features, D=3 if vertex + positions are used as features. + + Args: + batch_mesh_data: A mesh_data dict with following keys + 'vertices': A [B, V, D] `float32` tensor of vertex features, possibly + 0-padded. + 'neighbors': A [B, V, V] `float32` sparse tensor of edge weights. + 'num_vertices': A [B] `int32` tensor of number of vertices per mesh. + num_filters: The number of weight matrices to be used in feature steered + graph conv. + output_dim: A dimension of output per vertex features. + conv_layer_dims: A list of dimensions used in graph convolution layers. + + Returns: + vertex_features: A [B, V, output_dim] `float32` tensor of per vertex + features. + """ + batch_vertices = batch_mesh_data['vertices'] + neighbors = batch_mesh_data['neighbors'] + num_vertices = batch_mesh_data['num_vertices'] + + # Linear: N x D --> N x 16. + vertex_features = tf.keras.layers.Conv1D(16, 1, name='lin16')(batch_vertices) + + if preprocess_neighbors: + num_vertices_square = tf.stack((num_vertices, num_vertices), axis=-1) + neighbors = utils.convert_to_block_diag_2d(neighbors, num_vertices_square) + sizes = None + vertex_features, unflatten = utils.flatten_batch_to_2d( + vertex_features, num_vertices) + else: + sizes = num_vertices + unflatten = None + + # graph convolution layers + for dim in conv_layer_dims: + with tf.variable_scope('conv_%d' % dim): + vertex_features = graph_conv.feature_steered_convolution_layer( + vertex_features, + neighbors, + sizes=sizes, + num_weight_matrices=num_filters, + num_output_channels=dim, + **conv_kwargs) + vertex_features = tf.nn.relu(vertex_features) + + if unflatten is not None: + vertex_features = unflatten(vertex_features) + # Linear: N x 128 --> N x 256. + vertex_features = tf.keras.layers.Conv1D( + 256, 1, name='lin256')( + vertex_features) + vertex_features = tf.nn.relu(vertex_features) + + # Linear: N x 256 --> N x output_dim. 
+ vertex_features = tf.keras.layers.Conv1D( + output_dim, 1, name='lin_output')( + vertex_features) + + return vertex_features + + +def model_fn(features, labels, mode, params): + """Returns a mesh segmentation model_fn for use with tf.Estimator.""" + logits = mesh_encoder(features, params['num_filters'], params['num_classes'], + params['encoder_filter_dims'], + params.get('conv_kwargs'), + params.get('preprocess_neighbors', True)) + predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) + outputs = { + 'vertices': features['vertices'], + 'triangles': features['triangles'], + 'num_vertices': features['num_vertices'], + 'num_triangles': features['num_triangles'], + 'predictions': predictions, + } + # For predictions, return the outputs. + if mode == tf.estimator.ModeKeys.PREDICT: + outputs['labels'] = features['labels'] + return tf.estimator.EstimatorSpec(mode=mode, predictions=outputs) + # Loss + # Weight the losses by masking out padded vertices/labels. + vertex_ragged_sizes = features['num_vertices'] + mask = tf.sequence_mask(vertex_ragged_sizes, tf.shape(labels)[-1]) + loss_weights = tf.cast(mask, dtype=tf.float32) + loss = tf.losses.sparse_softmax_cross_entropy( + logits=logits, labels=labels, weights=loss_weights) + # For training, build the optimizer. + if mode == tf.estimator.ModeKeys.TRAIN: + optimizer = tf.train.AdamOptimizer( + learning_rate=params['learning_rate'], + beta1=params['beta'], + epsilon=params['adam_epsilon']) + update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) + with tf.control_dependencies(update_ops): + train_op = optimizer.minimize( + loss=loss, global_step=tf.train.get_global_step()) + return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op) + + # For eval, return eval metrics. + eval_ops = { + 'mean_loss': + tf.metrics.mean(loss), + 'accuracy': + tf.metrics.accuracy( + labels=labels, predictions=predictions, weights=loss_weights) + } + return tf.estimator.EstimatorSpec( + mode=mode, loss=loss, eval_metric_ops=eval_ops) + +test_io_params = { + 'is_training': False, + 'sloppy': False, + 'shuffle': True, + 'repeat': False +} +test_tfrecords = test_data_files + + +def run_benchmark(conv_kwargs, **kwargs): + with tf.Graph().as_default(): + features, labels = dataio.create_input_from_dataset( + dataio.create_dataset_from_tfrecords, test_tfrecords, test_io_params) + params = MODEL_PARAMS.copy() + params['conv_kwargs'] = conv_kwargs + params.update(kwargs) + spec = model_fn(features, labels, tf.estimator.ModeKeys.TRAIN, params) + init = tf.compat.v1.global_variables_initializer() + + print('--------------') + for k in sorted(conv_kwargs): + print('{:10s}: {}'.format(k, conv_kwargs[k])) + with tf.Session() as sess: + sess.run(init) + bm = tf.test.Benchmark() + result = bm.run_op_benchmark(sess, spec.train_op) + return result + + +def main(args): + num_filters = flags.FLAGS.num_filters + # v1_p2d is the original implementation + fast = dict(memory_efficient=False) + names, kwargs = zip(*( + ('v1_p2d', dict(version='v1', segment_sum_impl='partition2d', **fast)), + ('v1_sorted', dict(version='v1', segment_sum_impl='sorted', **fast)), + ('v1_unsorted', dict(version='v1', segment_sum_impl='unsorted', **fast)), + ('v1_p2d_mem', dict(version='v1', segment_sum_impl='partition2d')), + ('v1_sorted_mem', dict(version='v1', segment_sum_impl='sorted')), + ('v1_unsorted_mem', dict(version='v1', segment_sum_impl='unsorted')), + ('v2_default', dict(version='v2')), # will be same as one of the below + # ('v2_first', dict(version='v2', 
transform_data_first=True)), + # ('v2_last', dict(version='v2', transform_data_first=False)), + ('v3', dict(version='v3', **fast)), + ('v3_mem', dict(version='v3')), + )) + if FLAGS.mem_only: + names, kwargs = zip(*( + (name, kw) for name, kw in zip(names, kwargs) if 'mem' in name)) + times = [] + memories = [] + for kw in kwargs: + result = run_benchmark(kw, num_filters=num_filters) + times.append(result['wall_time']) + memories.append( + result['extras']['allocator_maximum_num_bytes_GPU_0_bfc']) + + print('*************') + print('** SUMMARY **') + print('*************') + print('num_filters = {}'.format(num_filters)) + + times = np.array(times) + # ti = np.argmin(times) + ti = 0 + tmin = times[ti] + print('Baseline time: {}, {}s'.format(names[ti], tmin)) + print('rel times:') + for name, time in zip(names, times): + print('{:15s} {:.3f}'.format(name, time / tmin)) + memories = np.array(memories) + # mi = np.argmin(memories) + mi = 0 + mmin = memories[mi] + print('Baseline memory: {}, {}mb'.format(names[mi], mmin / 1024**2)) + for name, memory in zip(names, memories): + print('{:15s} {:.3f}'.format(name, memory / mmin)) + + +if __name__ == '__main__': + app.run(main) diff --git a/tensorflow_graphics/geometry/convolution/graph_convolution.py b/tensorflow_graphics/geometry/convolution/graph_convolution.py index ab66e3db4..cc31beee9 100644 --- a/tensorflow_graphics/geometry/convolution/graph_convolution.py +++ b/tensorflow_graphics/geometry/convolution/graph_convolution.py @@ -24,6 +24,321 @@ from tensorflow_graphics.util import shape +def _prepare_feature_steered_args( + data, neighbors, sizes, var_u, var_v, var_c, var_w, var_b): + data = tf.convert_to_tensor(value=data) + neighbors = tf.compat.v1.convert_to_tensor_or_sparse_tensor(value=neighbors) + if sizes is not None: + sizes = tf.convert_to_tensor(value=sizes) + var_u = tf.convert_to_tensor(value=var_u) + var_v = tf.convert_to_tensor(value=var_v) + var_c = tf.convert_to_tensor(value=var_c) + var_w = tf.convert_to_tensor(value=var_w) + var_b = tf.convert_to_tensor(value=var_b) + + data_ndims = data.shape.ndims + utils.check_valid_graph_convolution_input(data, neighbors, sizes) + shape.compare_dimensions( + tensors=(data, var_u, var_v, var_w), + tensor_names=("data", "var_u", "var_v", "var_w"), + axes=(-1, 0, 0, 1)) + shape.compare_dimensions( + tensors=(var_u, var_v, var_c, var_w), + tensor_names=("var_u", "var_v", "var_c", "var_w"), + axes=(1, 1, 0, 0)) + shape.compare_dimensions( + tensors=(var_w, var_b), tensor_names=("var_w", "var_b"), axes=-1) + + # Flatten the batch dimensions and remove any vertex padding. + if data_ndims > 2: + if sizes is not None: + sizes_square = tf.stack((sizes, sizes), axis=-1) + else: + sizes_square = None + x_flat, unflatten = utils.flatten_batch_to_2d(data, sizes) + adjacency = utils.convert_to_block_diag_2d(neighbors, sizes_square) + else: + x_flat = data + adjacency = neighbors + unflatten = None + + return x_flat, adjacency, var_u, var_v, var_c, var_w, var_b, unflatten + + +def feature_steered_convolution_v1(data, + neighbors, + sizes, + var_u, + var_v, + var_c, + var_w, + var_b, + memory_efficient=True, + segment_sum_impl='partition2d', + name=None): + """Implements the Feature Steered graph convolution. + + FeaStNet: Feature-Steered Graph Convolutions for 3D Shape Analysis + Nitika Verma, Edmond Boyer, Jakob Verbeek + CVPR 2018 + https://arxiv.org/abs/1706.05206 + + Original implementation with some tweaks. Original version recovered with + `memory_efficient=False, segment_sum_impl='partition2d'`. 
+
+  Additional args:
+    memory_efficient: bool, if True uses a `foldl` implementation which is
+      slightly slower (~10% in experiments) but significantly more memory
+      efficient (~2-4x less memory).
+    segment_sum_impl: one of 'partition2d', 'sorted', 'unsorted',
+      corresponding to `utils.partition_sums_2d`, `tf.math.segment_sum` or
+      `tf.math.unsorted_segment_sum` respectively. If 'sorted', `neighbors`
+      must be ordered - see `tf.sparse.reorder`.
+  """
+  with tf.compat.v1.name_scope(
+      name, "graph_convolution_feature_steered_convolution",
+      [data, neighbors, sizes, var_u, var_v, var_c, var_w, var_b]):
+    x_flat, adjacency, var_u, var_v, var_c, var_w, var_b, unflatten = \
+        _prepare_feature_steered_args(data, neighbors, sizes, var_u, var_v,
+                                      var_c, var_w, var_b)
+    x_u = tf.matmul(x_flat, var_u)
+    x_v = tf.matmul(x_flat, var_v)
+    adjacency_ind_0 = adjacency.indices[:, 0]
+    adjacency_ind_1 = adjacency.indices[:, 1]
+    x_u_rep = tf.gather(x_u, adjacency_ind_0)
+    x_v_sep = tf.gather(x_v, adjacency_ind_1)
+    weights_q = tf.exp(x_u_rep + x_v_sep + tf.reshape(var_c, (1, -1)))
+    weights_q_sum = tf.reduce_sum(
+        input_tensor=weights_q, axis=-1, keepdims=True)
+    weights_q = weights_q / weights_q_sum
+    x_sep = tf.gather(x_flat, adjacency_ind_1)
+    V = tf.shape(x_flat)[0]
+
+    def get_mth_term(q_m, w_m):
+      if segment_sum_impl == 'partition2d':
+        q_m = tf.expand_dims(q_m, axis=-1)
+        p_sum = utils.partition_sums_2d(q_m * x_sep, adjacency_ind_0,
+                                        adjacency.values)
+      else:
+        args = (x_sep * tf.expand_dims(q_m * adjacency.values, axis=-1),
+                adjacency_ind_0)
+        if segment_sum_impl == 'sorted':
+          p_sum = tf.math.segment_sum(*args)
+        elif segment_sum_impl == 'unsorted':
+          p_sum = tf.math.unsorted_segment_sum(*args, num_segments=V)
+        else:
+          raise ValueError(
+              'Invalid segment_sum_impl "{}" - must be one of "partition2d", '
+              '"sorted", "unsorted"'.format(segment_sum_impl))
+      return tf.matmul(p_sum, w_m)
+
+    if memory_efficient:
+      y_out = tf.foldl(
+          lambda acc, args: acc + get_mth_term(*args),
+          (tf.transpose(weights_q, (1, 0)), var_w),
+          tf.tile(tf.expand_dims(var_b, axis=0), (tf.shape(x_flat)[0], 1)))
+    else:
+      q_ms = tf.unstack(weights_q, axis=-1)
+      w_ms = tf.unstack(var_w, axis=0)
+      y_out = tf.add_n(
+          [get_mth_term(*args) for args in zip(q_ms, w_ms)]) + var_b
+
+    if unflatten is not None:
+      y_out = unflatten(y_out)
+    return y_out
+
+
+def feature_steered_convolution_v2(data,
+                                   neighbors,
+                                   sizes,
+                                   var_u,
+                                   var_v,
+                                   var_c,
+                                   var_w,
+                                   var_b,
+                                   transform_data_first=None,
+                                   name=None):
+  """Implements the Feature Steered graph convolution.
+
+  FeaStNet: Feature-Steered Graph Convolutions for 3D Shape Analysis
+    Nitika Verma, Edmond Boyer, Jakob Verbeek
+    CVPR 2018
+    https://arxiv.org/abs/1706.05206
+
+  This implementation is based on splitting the exponential term in the
+  softmax into products of exponentials, which allows the neighborhood
+  summation to be implemented as a sparse-dense matrix product. This means
+  per-edge features (other than the softmax values) need not be explicitly
+  created, so memory usage is lower and computation is faster.
+
+  Extra channels ("W", or "M" in paper) are broadcast to create a large feature
+  matrix of shape [V, D, M] before reduction. This is slightly faster at the
+  cost of a larger memory footprint. See `feature_steered_convolution_v3` for
+  a slightly slower, more memory-efficient implementation.
+
+  For base arg/return descriptions, see `feature_steered_convolution`.
+
+  Additional args:
+    transform_data_first: if True, performs transformation of features from
+      [V, C] -> [V, D, W] via `var_w` before other multiplications.
+      Defaults to `C > D`.
+  """
+  with tf.compat.v1.name_scope(
+      name, "graph_convolution_feature_steered_convolution_v2",
+      [data, neighbors, sizes, var_u, var_v, var_c, var_w, var_b]):
+    x_flat, adjacency, var_u, var_v, var_c, var_w, var_b, unflatten = \
+        _prepare_feature_steered_args(data, neighbors, sizes, var_u, var_v,
+                                      var_c, var_w, var_b)
+    x_u = tf.matmul(x_flat, var_u)  # [V, W]
+    x_v = tf.matmul(x_flat, var_v)  # [V, W]
+    x_uc = x_u + var_c  # [V, W]
+    # apply per-term stabilization
+    x_uc = x_uc - tf.reduce_max(x_uc, axis=-1, keepdims=True)
+    x_v = x_v - tf.reduce_max(x_v, axis=-1, keepdims=True)
+
+    e_uc = tf.exp(x_uc)
+    e_v = tf.exp(x_v)
+
+    i, j = tf.unstack(adjacency.indices, axis=-1)
+    # E == num_edges
+    q_vals = tf.gather(e_uc, i) * tf.gather(e_v, j)  # [E, W]
+    weights = adjacency.values / tf.reduce_sum(q_vals, axis=-1)  # [E]
+
+    weighted_adjacency = tf.SparseTensor(
+        adjacency.indices, weights, dense_shape=adjacency.dense_shape)
+
+    # `tf.einsum` implementations are arguably easier to understand and
+    # possibly more efficient, but we avoid them until the following issue
+    # is resolved: https://github.com/tensorflow/tensorflow/issues/31022
+    # (seems limited to cases where indices are repeated but not summed).
+
+    W, C, D = var_w.shape
+    assert(C is not None and D is not None)
+    if transform_data_first is None:
+      transform_data_first = C > D
+
+    if transform_data_first:
+      # x_flat = tf.einsum('vc,wcd->vdw', x_flat, var_w)
+      x_flat = tf.reduce_sum(tf.multiply(
+          tf.reshape(x_flat, (-1, 1, C, 1)),  # V 1 C 1
+          tf.transpose(var_w, (2, 1, 0)),  # D C W
+      ), axis=2)  # V D W
+      F = D
+    else:
+      F = C
+      x_flat = tf.expand_dims(x_flat, axis=-1)  # V C 1
+
+    # ef = tf.einsum('vw,vfw->vfw', e_v, data)
+    ef = tf.expand_dims(e_v, axis=-2) * x_flat  # [V, F, W]
+    ef = tf.reshape(ef, (-1, F * W))  # [V, F * W]
+    summed_ef = tf.sparse.sparse_dense_matmul(weighted_adjacency, ef)
+    summed_ef = tf.reshape(summed_ef, (-1, F, W))  # [V, F, W]
+
+    if transform_data_first:
+      # ym = tf.einsum('vfw,vw->vf', summed_ef, e_uc)
+      ym_flat = summed_ef * tf.expand_dims(e_uc, axis=1)
+      y_flat = tf.reduce_sum(ym_flat, axis=-1)
+    else:
+      # y_flat = tf.einsum('vfw,vw,wcd->vd', summed_ef, e_uc, var_w)
+      ym_flat = summed_ef * tf.expand_dims(e_uc, axis=1)
+      y_flat = tf.matmul(
+          tf.reshape(ym_flat, (-1, C * W)),
+          tf.reshape(tf.transpose(var_w, (1, 0, 2)), (C * W, D))
+      )
+    y_flat = y_flat + var_b
+    if unflatten is not None:
+      return unflatten(y_flat)
+    else:
+      return y_flat
+
+
+def feature_steered_convolution_v3(data,
+                                   neighbors,
+                                   sizes,
+                                   var_u,
+                                   var_v,
+                                   var_c,
+                                   var_w,
+                                   var_b,
+                                   memory_efficient=True,
+                                   name=None):
+  """Implements the Feature Steered graph convolution.
+
+  FeaStNet: Feature-Steered Graph Convolutions for 3D Shape Analysis
+    Nitika Verma, Edmond Boyer, Jakob Verbeek
+    CVPR 2018
+    https://arxiv.org/abs/1706.05206
+
+  This implementation is similar to `feature_steered_convolution_v2` except
+  that it loops over entries of `var_w`, avoiding the need to have a single
+  [V, D, M] tensor in memory at any point in time.
+
+  For base arg/return descriptions, see `feature_steered_convolution`.
+
+  Additional args:
+    memory_efficient: bool, if True uses a `foldl` implementation which is
+      slightly slower (~10% in experiments) but significantly more memory
+      efficient (~2-4x less memory).
+  """
+  with tf.compat.v1.name_scope(
+      name, "graph_convolution_feature_steered_convolution",
+      [data, neighbors, sizes, var_u, var_v, var_c, var_w, var_b]):
+    x_flat, adjacency, var_u, var_v, var_c, var_w, var_b, unflatten = \
+        _prepare_feature_steered_args(data, neighbors, sizes, var_u, var_v,
+                                      var_c, var_w, var_b)
+    x_u = tf.matmul(x_flat, var_u)  # [V, W]
+    x_v = tf.matmul(x_flat, var_v)  # [V, W]
+    x_uc = x_u + var_c  # [V, W]
+
+    # apply per-term stabilization
+    x_uc = x_uc - tf.reduce_max(x_uc, axis=-1, keepdims=True)
+    x_v = x_v - tf.reduce_max(x_v, axis=-1, keepdims=True)
+
+    e_uc = tf.exp(x_uc)
+    e_v = tf.exp(x_v)
+
+    i, j = tf.unstack(adjacency.indices, axis=-1)
+    # E == num_edges
+    q_vals = tf.gather(e_uc, i) * tf.gather(e_v, j)  # [E, W]
+    weights = adjacency.values / tf.reduce_sum(q_vals, axis=-1)  # [E]
+
+    weighted_adjacency = tf.SparseTensor(
+        adjacency.indices, weights, dense_shape=adjacency.dense_shape)
+
+    W, C, D = var_w.shape
+    assert(C is not None and D is not None)
+
+    def get_mth_term(wm, e_ucm, e_vm):
+      summed_ef = tf.sparse.sparse_dense_matmul(
+          weighted_adjacency, tf.expand_dims(e_vm, axis=-1) * x_flat)
+      return tf.matmul(tf.expand_dims(e_ucm, axis=-1) * summed_ef, wm)
+
+    if memory_efficient:
+      y_flat = tf.foldl(
+          lambda acc, args: acc + get_mth_term(*args),
+          (var_w, tf.transpose(e_uc, (1, 0)), tf.transpose(e_v, (1, 0))),
+          tf.tile(tf.expand_dims(var_b, axis=0), (tf.shape(e_uc)[0], 1)))
+    else:
+      unstacked = [
+          tf.unstack(var_w, axis=0),
+          tf.unstack(e_uc, axis=1),
+          tf.unstack(e_v, axis=1),
+      ]
+      y_flat = tf.add_n([get_mth_term(*a) for a in zip(*unstacked)]) + var_b
+
+    if unflatten is not None:
+      return unflatten(y_flat)
+    else:
+      return y_flat
+
+
+_versions = {
+    'v1': feature_steered_convolution_v1,
+    'v2': feature_steered_convolution_v2,
+    'v3': feature_steered_convolution_v3,
+}
+
+
 def feature_steered_convolution(data,
                                 neighbors,
                                 sizes,
@@ -32,8 +347,9 @@ def feature_steered_convolution(data,
                                 var_c,
                                 var_w,
                                 var_b,
-                                name=None):
-  # pyformat: disable
+                                version='v1',
+                                name=None,
+                                **kwargs):
   """Implements the Feature Steered graph convolution.
 
   FeaStNet: Feature-Steered Graph Convolutions for 3D Shape Analysis
@@ -81,8 +397,21 @@ def feature_steered_convolution(data,
     var_c: A 1-D tensor with shape `[W]`.
     var_w: A 3-D tensor with shape `[W, C, D]`.
     var_b: A 1-D tensor with shape `[D]`.
+    version: string indicating implementation version, one of "v1", "v2", "v3".
+      See `feature_steered_convolution_v1` / `feature_steered_convolution_v2`
+      etc.
     name: A name for this op. Defaults to
       `graph_convolution_feature_steered_convolution`.
+    **kwargs: version-specific kwargs.
+      segment_sum_impl (default 'partition2d'): for "v1", selects the
+        segment sum implementation - one of 'partition2d', 'sorted' or
+        'unsorted'. See `feature_steered_convolution_v1`.
+      memory_efficient (default True): for "v1", "v3", uses a more memory
+        efficient implementation at the cost of slightly slower runtime.
+      transform_data_first: for "v2", if True transforms data from
+        [V, C] -> [V, D, W] before some transformations. This does not affect
+        the results (aside from floating point errors), but may result in
+        a performance difference. Defaults to `C > D`.
 
   Returns:
     Tensor with shape `[A1, ..., An, V, D]`.
 
   Raises:
     TypeError: if the input types are invalid.
     ValueError: if the input dimensions are invalid.
""" - # pyformat: enable - with tf.compat.v1.name_scope( - name, "graph_convolution_feature_steered_convolution", - [data, neighbors, sizes, var_u, var_v, var_c, var_w, var_b]): - data = tf.convert_to_tensor(value=data) - neighbors = tf.compat.v1.convert_to_tensor_or_sparse_tensor(value=neighbors) - if sizes is not None: - sizes = tf.convert_to_tensor(value=sizes) - var_u = tf.convert_to_tensor(value=var_u) - var_v = tf.convert_to_tensor(value=var_v) - var_c = tf.convert_to_tensor(value=var_c) - var_w = tf.convert_to_tensor(value=var_w) - var_b = tf.convert_to_tensor(value=var_b) - - data_ndims = data.shape.ndims - utils.check_valid_graph_convolution_input(data, neighbors, sizes) - shape.compare_dimensions( - tensors=(data, var_u, var_v, var_w), - tensor_names=("data", "var_u", "var_v", "var_w"), - axes=(-1, 0, 0, 1)) - shape.compare_dimensions( - tensors=(var_u, var_v, var_c, var_w), - tensor_names=("var_u", "var_v", "var_c", "var_w"), - axes=(1, 1, 0, 0)) - shape.compare_dimensions( - tensors=(var_w, var_b), tensor_names=("var_w", "var_b"), axes=-1) - - # Flatten the batch dimensions and remove any vertex padding. - if data_ndims > 2: - if sizes is not None: - sizes_square = tf.stack((sizes, sizes), axis=-1) - else: - sizes_square = None - x_flat, unflatten = utils.flatten_batch_to_2d(data, sizes) - adjacency = utils.convert_to_block_diag_2d(neighbors, sizes_square) - else: - x_flat = data - adjacency = neighbors - x_u = tf.matmul(x_flat, var_u) - x_v = tf.matmul(x_flat, var_v) - adjacency_ind_0 = adjacency.indices[:, 0] - adjacency_ind_1 = adjacency.indices[:, 1] - x_u_rep = tf.gather(x_u, adjacency_ind_0) - x_v_sep = tf.gather(x_v, adjacency_ind_1) - weights_q = tf.exp(x_u_rep + x_v_sep + tf.reshape(var_c, (1, -1))) - weights_q_sum = tf.reduce_sum( - input_tensor=weights_q, axis=-1, keepdims=True) - weights_q = weights_q / weights_q_sum - y_i_m = [] - x_sep = tf.gather(x_flat, adjacency_ind_1) - q_m_list = tf.unstack(weights_q, axis=-1) - w_m_list = tf.unstack(var_w, axis=0) - for q_m, w_m in zip(q_m_list, w_m_list): - # Compute `y_i_m = sum_{j in neighborhood(i)} q_m(x_i, x_j) * w_m * x_j`. - q_m = tf.expand_dims(q_m, axis=-1) - p_sum = utils.partition_sums_2d(q_m * x_sep, adjacency_ind_0, - adjacency.values) - y_i_m.append(tf.matmul(p_sum, w_m)) - y_out = tf.add_n(inputs=y_i_m) + tf.reshape(var_b, [1, -1]) - if data_ndims > 2: - y_out = unflatten(y_out) - return y_out + if version not in _versions: + raise ValueError( + 'Invalid version {}. 
Must be one of {}'.format( + version, sorted(_versions))) + return _versions[version]( + data, + neighbors, + sizes, + var_u, + var_v, + var_c, + var_w, + var_b, + name=name, + **kwargs) def edge_convolution_template(data, diff --git a/tensorflow_graphics/geometry/convolution/tests/graph_convolution_test.py b/tensorflow_graphics/geometry/convolution/tests/graph_convolution_test.py index 27f28cdbb..70909eef8 100644 --- a/tensorflow_graphics/geometry/convolution/tests/graph_convolution_test.py +++ b/tensorflow_graphics/geometry/convolution/tests/graph_convolution_test.py @@ -64,19 +64,24 @@ def _random_data(batch_size, only_self_edges, data_type=np.float32, neighbors_type=np.float32, - sizes_type=np.int32): + sizes_type=np.int32, + sparsity=0.25, + random_state=None): """Create random inputs for feature_steered_convolution.""" + if random_state is None: + random_state = np.random def _random_data_2d(padding): - size = num_vertices if not padding else np.random.randint( + size = num_vertices if not padding else random_state.randint( low=1, high=num_vertices + 1) - data = np.random.uniform(size=(size, num_channels)).astype(data_type) + data = random_state.uniform(size=(size, num_channels)).astype(data_type) if only_self_edges: neighbors = np.eye(size, dtype=neighbors_type) else: - random = np.random.uniform(size=(size, size)).astype(neighbors_type) + random = random_state.uniform(size=(size, size)).astype(neighbors_type) neighbors = np.maximum( - np.where(random > 0.75, np.ones_like(random), np.zeros_like(random)), + np.where(random < sparsity, + np.ones_like(random), np.zeros_like(random)), np.eye(size, dtype=neighbors_type)) neighbors = neighbors / np.sum(neighbors, axis=1, keepdims=True) if padding: @@ -108,11 +113,14 @@ def _random_data_2d(padding): def _random_variables(in_channels, out_channels, num_weight_matrices, - dtype=np.float32): + dtype=np.float32, + random_state=None): """Create random variables for feature_steered_convolution.""" + if random_state is None: + random_state = np.random def _random_constant(shape, dtype): - return tf.constant(np.random.uniform(size=shape).astype(dtype)) + return tf.constant(random_state.uniform(size=shape).astype(dtype)) var_u = _random_constant([in_channels, num_weight_matrices], dtype) var_v = _random_constant([in_channels, num_weight_matrices], dtype) diff --git a/tensorflow_graphics/nn/layer/graph_convolution.py b/tensorflow_graphics/nn/layer/graph_convolution.py index cda07bdbb..90e49a36b 100644 --- a/tensorflow_graphics/nn/layer/graph_convolution.py +++ b/tensorflow_graphics/nn/layer/graph_convolution.py @@ -32,7 +32,8 @@ def feature_steered_convolution_layer( num_output_channels=None, initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.1), name=None, - var_name=None): + var_name=None, + **kwargs): # pyformat: disable """Wraps the function `feature_steered_convolution` as a TensorFlow layer. @@ -79,6 +80,7 @@ def feature_steered_convolution_layer( feature_steered_convolution(). var_name: A (var_scope) name for the variables. Defaults to `graph_convolution_feature_steered_convolution_weights`. + **kwargs: passed to `feature_steered_convolution` Returns: Tensor with shape `[A1, ..., An, V, num_output_channels]`. 
@@ -135,7 +137,8 @@ def feature_steered_convolution_layer(
         var_c=var_c,
         var_w=var_w,
         var_b=var_b,
-        name=name)
+        name=name,
+        **kwargs)
 
 
 class FeatureSteeredConvolutionKerasLayer(tf.keras.layers.Layer):
@@ -147,7 +150,8 @@ def __init__(self,
                num_output_channels=None,
                initializer=None,
                name=None,
-               **kwargs):
+               impl_kwargs=None,
+               **layer_kwargs):
     """Initializes FeatureSteeredConvolutionKerasLayer.
 
     Args:
@@ -161,17 +165,31 @@ def __init__(self,
       initializer: An initializer for the trainable variables. If `None`,
         defaults to `tf.compat.v1.truncated_normal_initializer(stddev=0.1)`.
       name: A name for this layer.
-      **kwargs: Additional keyword arguments passed to the base layer.
+      impl_kwargs: A dict of additional keyword arguments passed to
+        `feature_steered_convolution`.
+      **layer_kwargs: Additional keyword arguments passed to the base layer.
     """
     super(FeatureSteeredConvolutionKerasLayer, self).__init__(
-        name=name, **kwargs)
+        name=name, **layer_kwargs)
     self._num_weight_matrices = num_weight_matrices
     self._num_output_channels = num_output_channels
     self._translation_invariant = translation_invariant
     if initializer is None:
       self._initializer = tf.compat.v1.truncated_normal_initializer(stddev=0.1)
     else:
-      self._initializer = initializer
+      self._initializer = tf.keras.initializers.get(initializer)
+    self._impl_kwargs = impl_kwargs or {}
+
+  def get_config(self):
+    config = super(FeatureSteeredConvolutionKerasLayer, self).get_config()
+    config.update(dict(
+        num_weight_matrices=self._num_weight_matrices,
+        num_output_channels=self._num_output_channels,
+        translation_invariant=self._translation_invariant,
+        initializer=tf.keras.utils.serialize_keras_object(self._initializer),
+        impl_kwargs=self._impl_kwargs.copy(),
+    ))
+    return config
 
   def build(self, input_shape):
     """Initializes the trainable weights."""
@@ -265,7 +283,8 @@ def call(self, inputs, sizes=None):
         var_v=self.var_v,
         var_c=self.var_c,
         var_w=self.var_w,
-        var_b=self.var_b)
+        var_b=self.var_b,
+        **self._impl_kwargs)
 
 
 class DynamicGraphConvolutionKerasLayer(tf.keras.layers.Layer):

From 3323d1c1b2c2ff2cef4e26ac3c8adc71d54f78f4 Mon Sep 17 00:00:00 2001
From: Dominic Jack
Date: Fri, 27 Sep 2019 17:16:33 +1000
Subject: [PATCH 2/3] added transform_data_first option to all implementations

---
 .../feature_steered_conv_benchmark.py         | 10 +++--
 .../feature_steered_model_benchmark.py        |  6 +--
 .../geometry/convolution/graph_convolution.py | 45 ++++++++++++-------
 3 files changed, 39 insertions(+), 22 deletions(-)

diff --git a/tensorflow_graphics/benchmarks/feature_steered_conv_benchmark.py b/tensorflow_graphics/benchmarks/feature_steered_conv_benchmark.py
index 7e5046eaf..aeb0eedb8 100644
--- a/tensorflow_graphics/benchmarks/feature_steered_conv_benchmark.py
+++ b/tensorflow_graphics/benchmarks/feature_steered_conv_benchmark.py
@@ -50,18 +50,22 @@ def main(args):
 
   # v1_p2d is the original implementation
   fast = dict(memory_efficient=False)
+  bad = dict(transform_data_first=FLAGS.in_channels <= FLAGS.out_channels)
   names, kwargs = zip(*(
       ('v1_p2d', dict(version='v1', segment_sum_impl='partition2d', **fast)),
+      ('v1_p2d_bad', dict(
+          version='v1', segment_sum_impl='partition2d', **fast, **bad)),
       ('v1_sorted', dict(version='v1', segment_sum_impl='sorted', **fast)),
      ('v1_unsorted', dict(version='v1', segment_sum_impl='unsorted', **fast)),
       ('v1_p2d_mem', dict(version='v1', segment_sum_impl='partition2d')),
       ('v1_sorted_mem', dict(version='v1', segment_sum_impl='sorted')),
       ('v1_unsorted_mem', dict(version='v1', segment_sum_impl='unsorted')),
-      ('v2_default', dict(version='v2')),  # will be same as one of the below
-      ('v2_first', dict(version='v2', transform_data_first=True)),
-      ('v2_last', dict(version='v2', transform_data_first=False)),
+      ('v2', dict(version='v2')),
+      ('v2_bad', dict(version='v2', **bad)),
       ('v3', dict(version='v3', **fast)),
       ('v3_mem', dict(version='v3')),
+      ('v3_bad', dict(version='v3', **bad, **fast)),
+      ('v3_mem_bad', dict(version='v3', **bad)),
   ))
   if FLAGS.mem_only:
     names, kwargs = zip(*(
diff --git a/tensorflow_graphics/benchmarks/feature_steered_model_benchmark.py b/tensorflow_graphics/benchmarks/feature_steered_model_benchmark.py
index 5129c804c..dc991b49c 100644
--- a/tensorflow_graphics/benchmarks/feature_steered_model_benchmark.py
+++ b/tensorflow_graphics/benchmarks/feature_steered_model_benchmark.py
@@ -209,9 +209,7 @@ def main(args):
       ('v1_p2d_mem', dict(version='v1', segment_sum_impl='partition2d')),
       ('v1_sorted_mem', dict(version='v1', segment_sum_impl='sorted')),
       ('v1_unsorted_mem', dict(version='v1', segment_sum_impl='unsorted')),
-      ('v2_default', dict(version='v2')),  # will be same as one of the below
-      # ('v2_first', dict(version='v2', transform_data_first=True)),
-      # ('v2_last', dict(version='v2', transform_data_first=False)),
+      ('v2', dict(version='v2')),
       ('v3', dict(version='v3', **fast)),
       ('v3_mem', dict(version='v3')),
   ))
@@ -229,7 +227,7 @@ def main(args):
   print('*************')
   print('** SUMMARY **')
   print('*************')
-  print('num_filters = {}'.format(num_filters))
+  print('{:15s}: {}'.format('num_filters', num_filters))
 
   times = np.array(times)
   # ti = np.argmin(times)
diff --git a/tensorflow_graphics/geometry/convolution/graph_convolution.py b/tensorflow_graphics/geometry/convolution/graph_convolution.py
index cc31beee9..6e5aa990a 100644
--- a/tensorflow_graphics/geometry/convolution/graph_convolution.py
+++ b/tensorflow_graphics/geometry/convolution/graph_convolution.py
@@ -75,6 +75,7 @@ def feature_steered_convolution_v1(data,
                                    var_b,
                                    memory_efficient=True,
                                    segment_sum_impl='partition2d',
+                                   transform_data_first=None,
                                    name=None):
   """Implements the Feature Steered graph convolution.
 
@@ -114,13 +115,19 @@ def feature_steered_convolution_v1(data,
     x_sep = tf.gather(x_flat, adjacency_ind_1)
     V = tf.shape(x_flat)[0]
 
+    W, C, D = var_w.shape
+    if transform_data_first is None:
+      transform_data_first = C > D
+
     def get_mth_term(q_m, w_m):
+      x = tf.matmul(x_sep, w_m) if transform_data_first else x_sep
+
       if segment_sum_impl == 'partition2d':
         q_m = tf.expand_dims(q_m, axis=-1)
-        p_sum = utils.partition_sums_2d(q_m * x_sep, adjacency_ind_0,
+        p_sum = utils.partition_sums_2d(q_m * x, adjacency_ind_0,
                                         adjacency.values)
       else:
-        args = (x_sep * tf.expand_dims(q_m * adjacency.values, axis=-1),
+        args = (x * tf.expand_dims(q_m * adjacency.values, axis=-1),
                 adjacency_ind_0)
         if segment_sum_impl == 'sorted':
           p_sum = tf.math.segment_sum(*args)
@@ -130,7 +137,9 @@ def get_mth_term(q_m, w_m):
           raise ValueError(
               'Invalid segment_sum_impl "{}" - must be one of "partition2d", '
               '"sorted", "unsorted"'.format(segment_sum_impl))
-      return tf.matmul(p_sum, w_m)
+      if not transform_data_first:
+        p_sum = tf.matmul(p_sum, w_m)
+      return p_sum
 
     if memory_efficient:
       y_out = tf.foldl(
@@ -177,11 +186,6 @@ def feature_steered_convolution_v2(data,
   a slightly slower, more memory-efficient implementation.
 
   For base arg/return descriptions, see `feature_steered_convolution`.
-
-  Additional args:
-    transform_data_first: if True, performs transformation of features from
-      [V, C] -> [V, D, W] via `var_w` before other multiplications.
-      Defaults to `C > D`.
   """
   with tf.compat.v1.name_scope(
       name, "graph_convolution_feature_steered_convolution_v2",
       [data, neighbors, sizes, var_u, var_v, var_c, var_w, var_b]):
@@ -261,6 +265,7 @@ def feature_steered_convolution_v3(data,
                                    var_w,
                                    var_b,
                                    memory_efficient=True,
+                                   transform_data_first=None,
                                    name=None):
   """Implements the Feature Steered graph convolution.
 
@@ -308,10 +313,17 @@ def feature_steered_convolution_v3(data,
     W, C, D = var_w.shape
     assert(C is not None and D is not None)
 
+    if transform_data_first is None:
+      transform_data_first = C > D
+
     def get_mth_term(wm, e_ucm, e_vm):
-      summed_ef = tf.sparse.sparse_dense_matmul(
-          weighted_adjacency, tf.expand_dims(e_vm, axis=-1) * x_flat)
-      return tf.matmul(tf.expand_dims(e_ucm, axis=-1) * summed_ef, wm)
+      x = tf.matmul(x_flat, wm) if transform_data_first else x_flat
+      ex = tf.expand_dims(e_vm, axis=-1) * x
+      summed_ex = tf.sparse.sparse_dense_matmul(weighted_adjacency, ex)
+      summed_qx = tf.expand_dims(e_ucm, axis=-1) * summed_ex
+      if not transform_data_first:
+        summed_qx = tf.matmul(summed_qx, wm)
+      return summed_qx
 
     if memory_efficient:
       y_flat = tf.foldl(
@@ -347,6 +359,7 @@ def feature_steered_convolution(data,
                                 var_c,
                                 var_w,
                                 var_b,
+                                transform_data_first=None,
                                 version='v1',
                                 name=None,
                                 **kwargs):
@@ -397,6 +410,11 @@ def feature_steered_convolution(data,
     var_c: A 1-D tensor with shape `[W]`.
     var_w: A 3-D tensor with shape `[W, C, D]`.
     var_b: A 1-D tensor with shape `[D]`.
+    transform_data_first: bool influencing the order of matrix operations.
+      The summation can essentially be written as `N @ x @ w_m`; this
+      parameter determines whether it is computed as `N @ (x @ w_m)` or as
+      `(N @ x) @ w_m`. Defaults to doing the multiplication in the
+      lower-dimensional space, i.e. to `C > D`.
     version: string indicating implementation version, one of "v1", "v2", "v3".
       See `feature_steered_convolution_v1` / `feature_steered_convolution_v2`
       etc.
@@ -408,10 +426,6 @@ def feature_steered_convolution(data,
         'unsorted'. See `feature_steered_convolution_v1`.
       memory_efficient (default True): for "v1", "v3", uses a more memory
         efficient implementation at the cost of slightly slower runtime.
-      transform_data_first: for "v2", if True transforms data from
-        [V, C] -> [V, D, W] before some transformations. This does not affect
-        the results (aside from floating point errors), but may result in
-        a performance difference. Defaults to `C > D`.
 
   Returns:
     Tensor with shape `[A1, ..., An, V, D]`.
@@ -433,6 +447,7 @@ def feature_steered_convolution(data, var_c, var_w, var_b, + transform_data_first=transform_data_first, name=name, **kwargs) From f9f412ebcd43c8422d21eb2a09631a08c7d83cef Mon Sep 17 00:00:00 2001 From: Dominic Jack Date: Sat, 28 Sep 2019 00:08:44 +1000 Subject: [PATCH 3/3] fixed shape invariance issue --- .../geometry/convolution/graph_convolution.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/tensorflow_graphics/geometry/convolution/graph_convolution.py b/tensorflow_graphics/geometry/convolution/graph_convolution.py index 6e5aa990a..11719fc70 100644 --- a/tensorflow_graphics/geometry/convolution/graph_convolution.py +++ b/tensorflow_graphics/geometry/convolution/graph_convolution.py @@ -65,6 +65,17 @@ def _prepare_feature_steered_args( return x_flat, adjacency, var_u, var_v, var_c, var_w, var_b, unflatten +def _sum_reducer(term_fn): + def f(acc, args): + term = term_fn(*args) + # the following is needed because foldl doesn't expose shape_invariants + # and partition_sums_2d doesn't set size. + if acc.shape[0] is not None: + term.set_shape((acc.shape[0], acc.shape[1])) + return acc + term + return f + + def feature_steered_convolution_v1(data, neighbors, sizes, @@ -142,8 +153,9 @@ def get_mth_term(q_m, w_m): return p_sum if memory_efficient: + y_out = tf.foldl( - lambda acc, args: acc + get_mth_term(*args), + _sum_reducer(get_mth_term), (tf.transpose(weights_q, (1, 0)), var_w), tf.tile(tf.expand_dims(var_b, axis=0), (tf.shape(x_flat)[0], 1))) else: @@ -327,7 +339,7 @@ def get_mth_term(wm, e_ucm, e_vm): if memory_efficient: y_flat = tf.foldl( - lambda acc, args: acc + get_mth_term(*args), + _sum_reducer(get_mth_term), (var_w, tf.transpose(e_uc, (1, 0)), tf.transpose(e_v, (1, 0))), tf.tile(tf.expand_dims(var_b, axis=0), (tf.shape(e_uc)[0], 1))) else:
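

Usage sketch for the extended API above. This is a minimal, self-contained
example (not part of the patches) of how the new `version` argument and the
per-version kwargs are intended to be called; it assumes the three patches are
applied and TF 1.x graph mode. All shapes and values here are illustrative;
the different versions should agree up to floating point error.

import numpy as np
import tensorflow as tf
import tensorflow_graphics.geometry.convolution.graph_convolution as gc

num_vertices, in_channels, out_channels, num_filters = 4, 3, 5, 2
rng = np.random.RandomState(0)

data = tf.constant(
    rng.uniform(size=(num_vertices, in_channels)), dtype=tf.float32)
# Self-loop-only adjacency with unit weights, already row-normalized.
indices = np.stack([np.arange(num_vertices)] * 2, axis=-1).astype(np.int64)
neighbors = tf.SparseTensor(indices, np.ones((num_vertices,), np.float32),
                            (num_vertices, num_vertices))
var_u = tf.constant(
    rng.uniform(size=(in_channels, num_filters)), dtype=tf.float32)
var_v = tf.constant(
    rng.uniform(size=(in_channels, num_filters)), dtype=tf.float32)
var_c = tf.constant(rng.uniform(size=(num_filters,)), dtype=tf.float32)
var_w = tf.constant(
    rng.uniform(size=(num_filters, in_channels, out_channels)),
    dtype=tf.float32)
var_b = tf.constant(rng.uniform(size=(out_channels,)), dtype=tf.float32)

# Original behaviour (v1 with the original segment sum, no foldl).
y_v1 = gc.feature_steered_convolution(
    data, neighbors, None, var_u, var_v, var_c, var_w, var_b,
    version='v1', memory_efficient=False, segment_sum_impl='partition2d')
# Sparse-matmul based variant.
y_v2 = gc.feature_steered_convolution(
    data, neighbors, None, var_u, var_v, var_c, var_w, var_b, version='v2')

with tf.compat.v1.Session() as sess:
  out_v1, out_v2 = sess.run((y_v1, y_v2))
  np.testing.assert_allclose(out_v1, out_v2, rtol=1e-5, atol=1e-5)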
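The factorization that v2/v3 rely on, checked in plain NumPy: since
exp(u.x_i + v.x_j + c) = exp(u.x_i + c) * exp(v.x_j), the per-edge softmax
weights factor into per-node terms, and the softmax-weighted neighborhood sum
for each filter reduces to a single (sparse-able) matrix product. The dense
toy check below is illustrative only and omits the per-node max-subtraction
stabilization, which leaves the softmax over filters unchanged.

import numpy as np

rng = np.random.RandomState(0)
num_vertices, in_channels, num_filters = 5, 3, 2
x = rng.uniform(size=(num_vertices, in_channels))
u = rng.uniform(size=(in_channels, num_filters))
v = rng.uniform(size=(in_channels, num_filters))
c = rng.uniform(size=(num_filters,))
adj = rng.uniform(size=(num_vertices, num_vertices))  # dense stand-in

# Direct per-edge softmax over filters, as in v1: q[i, j, m].
logits = x.dot(u)[:, None, :] + x.dot(v)[None, :, :] + c  # [V, V, W]
q = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)

# Factorized node terms, as in v2/v3.
e_uc = np.exp(x.dot(u) + c)  # [V, W]
e_v = np.exp(x.dot(v))       # [V, W]
denom = e_uc.dot(e_v.T)      # [V, V], the softmax normalizer per edge
assert np.allclose(
    q, e_uc[:, None, :] * e_v[None, :, :] / denom[..., None])

# The neighborhood sum for one filter m becomes a single matmul against a
# reweighted adjacency (what v3's get_mth_term does with a sparse matmul).
m = 0
direct = np.einsum('ij,ij,jc->ic', adj, q[..., m], x)
factorized = e_uc[:, m:m + 1] * ((adj / denom).dot(e_v[:, m:m + 1] * x))
assert np.allclose(direct, factorized)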
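On the shape-invariance fix in PATCH 3/3: `tf.foldl` is built on
`tf.while_loop`, which requires the accumulator to keep the same static shape
on every iteration. If a per-filter term loses its static shape (as the
`utils.partition_sums_2d` result does), `acc + term` broadcasts to an unknown
shape and graph construction typically fails with a shape-invariant error.
The sketch below is a standalone illustration of the pattern and the fix, not
the patch code itself; it assumes TF 1.x, and `term_fn` deliberately discards
static shape information to stand in for `partition_sums_2d`.

import tensorflow as tf

def make_reducer(term_fn):
  def reducer(acc, elem):
    term = term_fn(elem)
    # Re-assert the invariant foldl/while_loop expects; without this,
    # building the loop below typically raises a shape-invariant error.
    term.set_shape(acc.shape)
    return acc + term
  return reducer

x = tf.compat.v1.placeholder(tf.float32, shape=(None, 3))  # [V, C]
w = tf.compat.v1.placeholder(tf.float32, shape=(2, 3, 5))  # [W, C, D]
b = tf.zeros((5,))

def term_fn(wm):
  term = tf.matmul(x, wm)  # [V, D]
  # Simulate an op whose result has unknown static dimensions.
  return tf.reshape(term, tf.shape(term))

# y = sum_m x @ w_m + b, with only one [V, D] term alive at a time.
y = tf.foldl(
    make_reducer(term_fn), w,
    initializer=tf.tile(tf.expand_dims(b, 0), (tf.shape(x)[0], 1)))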