Changes from all commits
134 changes: 134 additions & 0 deletions benchmarks/graph_convolution_benchmark.py
@@ -0,0 +1,134 @@
from absl import logging
from absl import app, flags
import numpy as np
import tensorflow as tf
from tensorflow_graphics.nn.layer.graph_convolution import \
    FeatureSteeredConvolutionKerasLayer
from tensorflow_graphics.geometry.convolution.graph_convolution import \
    SparseImplementation

logging.info('Finished imports')

Lambda = tf.keras.layers.Lambda
Input = tf.keras.Input

flags.DEFINE_boolean('jit', default=False, help='use XLA jit compilation')
flags.DEFINE_boolean('sparse', default=False, help='use sparse implementation')
flags.DEFINE_boolean('sort', default=False, help='use sorted indices')
flags.DEFINE_boolean(
    'backward', default=False, help='benchmark forward and backward pass')
flags.DEFINE_integer(
    'burn_iters', default=10, help='number of burn-in iterations')
flags.DEFINE_integer('nv', default=100000, help='number of vertices')
flags.DEFINE_integer('ne',
                     default=-1,
                     help='number of edges; -1 results in using 10*nv')
flags.DEFINE_integer('min_iters',
                     default=20,
                     help='minimum number of iterations to benchmark')
flags.DEFINE_integer(
    'num_weight_matrices', default=8, help='number of weight matrices')
flags.DEFINE_integer(
    'num_output_channels', default=32, help='number of output channels')
flags.DEFINE_integer('num_layers', default=10, help='number of layers')


def summarize(result, print_fn=print):
  """Prints wall time and peak GPU memory from a benchmark result.

  Args:
    result: output of a `tf.test.Benchmark.run_op_benchmark` call.
    print_fn: print-like function.
  """
  print_fn('Wall time (ms): {}'.format(result['wall_time'] * 1000))
  gpu_mem = result['extras'].get('allocator_maximum_num_bytes_GPU_0_bfc', 0)
  print_fn('Memory (MiB): {}'.format(gpu_mem / 1024**2))


def get_data(num_vertices, num_edges, sort=True):
  """Returns random (vertices, edge indices, edge weights) for benchmarking."""
  if num_edges == -1:
    num_edges = 10 * num_vertices
  vertices = np.random.uniform(size=(num_vertices, 3)).astype(np.float32)
  # `replace=False` below causes memory issues for large graphs.
  indices = np.random.choice(num_vertices**2, num_edges, replace=True)
  if sort:
    indices.sort()
  i, j = np.unravel_index(indices, (num_vertices, num_vertices))  # pylint: disable=unbalanced-tuple-unpacking

  # Normalize edge weights so that the weights of edges leaving each vertex
  # sum to one. Vertices with no outgoing edge never appear in `i`, so
  # clamping their counts to 1 only avoids a divide-by-zero warning without
  # changing the result.
  counts = np.bincount(i, minlength=num_vertices)
  weights = (1. / np.maximum(counts, 1))[i].astype(np.float32)
  indices = np.stack((i, j), axis=-1)

  return vertices, indices, weights
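# For intuition (a hypothetical toy call): `get_data(num_vertices=4,
# num_edges=8)` returns `vertices` of shape [4, 3] (float32), `indices` of
# shape [8, 2] (int64 `(i, j)` pairs) and `weights` of shape [8] (float32),
# where the weights of all edges leaving a given vertex i sum to one.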


def main(_):
  FLAGS = flags.FLAGS
  tf.config.optimizer.set_jit(FLAGS.jit)
  tf.keras.backend.clear_session()
  vertices, indices, weights = get_data(FLAGS.nv, FLAGS.ne, sort=FLAGS.sort)

  with tf.Graph().as_default():
    vertices = tf.constant(vertices, dtype=tf.float32)
    indices = tf.constant(indices, dtype=tf.int64)
    weights = tf.constant(weights, dtype=tf.float32)
    # Batch size of 1.
    vertices, indices, weights = (
        tf.expand_dims(t, axis=0) for t in (vertices, indices, weights))

    data = Input(tensor=vertices)
    indices = Input(tensor=indices)
    weights = Input(tensor=weights)
    inputs = (data, indices, weights)
    data, indices, weights = tuple(
        Lambda(tf.squeeze, arguments=dict(axis=0),
               name='squeeze{}'.format(i))(t)
        for i, t in enumerate(inputs))
    nv = Lambda(lambda x: tf.shape(x, out_type=tf.int64)[0])(data)

    neighbors = Lambda(
        lambda args: tf.SparseTensor(args[0], args[1], (args[2], args[2])))(
            [indices, weights, nv])

    for _ in range(FLAGS.num_layers):
      layer = FeatureSteeredConvolutionKerasLayer(
          sparse_impl=(
              SparseImplementation.SPARSE_MATMUL if FLAGS.sparse else
              SparseImplementation.GATHER_SUM),
          num_weight_matrices=FLAGS.num_weight_matrices,
          num_output_channels=FLAGS.num_output_channels)
      data = layer([data, neighbors])
      data = tf.nn.relu(data)

    pred = data
    output = Lambda(tf.expand_dims, arguments=dict(axis=0))(pred)
    model = tf.keras.Model(inputs=inputs, outputs=output)

    if FLAGS.backward:
      loss = tf.reduce_sum(pred)
      optimizer = tf.keras.optimizers.SGD()

      model_weights = model.trainable_weights

      grads = optimizer.get_gradients(loss, model_weights)
      grads_and_vars = tuple(zip(grads, model_weights))
      train_op = optimizer.apply_gradients(grads_and_vars)
    else:
      train_op = pred

    bm = tf.test.Benchmark()
    with tf.compat.v1.Session() as sess:
      logging.info('Initializing variables...')
      sess.run(tf.compat.v1.global_variables_initializer())

      logging.info('Starting benchmarking...')
      result = bm.run_op_benchmark(
          sess, train_op, burn_iters=FLAGS.burn_iters,
          min_iters=FLAGS.min_iters)
      summarize(result)


if __name__ == '__main__':
  app.run(main)
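With the flags above, a typical invocation (flag values are illustrative) is `python benchmarks/graph_convolution_benchmark.py --sparse --jit --backward`, which stacks `num_layers` convolution layers, runs `burn_iters` warm-up iterations, and reports wall time and peak GPU memory via `summarize`.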
93 changes: 73 additions & 20 deletions tensorflow_graphics/geometry/convolution/graph_convolution.py
@@ -24,6 +24,48 @@
from tensorflow_graphics.util import shape


class SparseImplementation(object):
  """Valid values for the `sparse_impl` argument of graph convolutions."""
  GATHER_SUM = 'gather_sum'
  SPARSE_MATMUL = 'sparse_matmul'

  @classmethod
  def all(cls):
    return (cls.GATHER_SUM, cls.SPARSE_MATMUL)

  @classmethod
  def validate(cls, key):
    if key not in cls.all():
      raise ValueError('Invalid SparseImplementation key %s.' % key)
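# Usage sketch (hypothetical values): `SparseImplementation.validate(
# 'gather_sum')` passes silently, while `SparseImplementation.validate(
# 'dense')` raises ValueError.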


def _gather_sum(adjacency_ind_0, adjacency_ind_1, adjacency_values, weights_q,
                x_flat):
  """Computes neighborhood sums via `gather` + `unsorted_segment_sum`."""
  q_m_list = tf.unstack(weights_q, axis=-1)
  p_sums = []

  x_sep = tf.gather(x_flat, adjacency_ind_1)
  num_segments = tf.shape(input=x_flat)[0]
  for q_m in q_m_list:
    # Compute `p_m[i] = sum_{j in neighborhood(i)} q_m(x_i, x_j) * x_j`;
    # the multiplication by `w_m` happens in the caller.
    q_m = tf.expand_dims(q_m, axis=-1)
    p_sums.append(tf.math.unsorted_segment_sum(
        data=(q_m * x_sep) * tf.expand_dims(adjacency_values, -1),
        segment_ids=adjacency_ind_0,
        num_segments=num_segments))
  return p_sums


def _sparse_matmul(
    adjacency_indices, adjacency_values, dense_shape, weights_q, x_flat):
  """Computes the same neighborhood sums as `_gather_sum` via sparse matmuls."""
  p_sums = []
  q_m_list = tf.unstack(weights_q, axis=-1)
  for q_m in q_m_list:
    # Scale each edge weight by q_m and multiply the resulting sparse
    # adjacency matrix with the flattened vertex features.
    sp = tf.SparseTensor(
        adjacency_indices, adjacency_values * q_m, dense_shape)
    p_sums.append(tf.sparse.sparse_dense_matmul(sp, x_flat))
  return p_sums
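# A minimal sketch (not part of the module) checking that the two helpers
# agree on a toy graph, assuming TF 2.x eager execution and `numpy as np`:
#
#   indices = np.array([[0, 0], [0, 1], [1, 1]], dtype=np.int64)  # (i, j)
#   values = np.array([0.5, 0.5, 1.0], dtype=np.float32)  # row-normalized
#   x_flat = np.random.uniform(size=(2, 3)).astype(np.float32)  # [V, C]
#   weights_q = np.random.uniform(size=(3, 4)).astype(np.float32)  # [E, W]
#   gather = _gather_sum(indices[:, 0], indices[:, 1], values, weights_q,
#                        x_flat)
#   sparse = _sparse_matmul(indices, values, (2, 2), weights_q, x_flat)
#   np.testing.assert_allclose(
#       tf.stack(gather).numpy(), tf.stack(sparse).numpy(), rtol=1e-5)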


def feature_steered_convolution(data,
                                neighbors,
                                sizes,
@@ -32,7 +74,8 @@ def feature_steered_convolution(data,
                                var_c,
                                var_w,
                                var_b,
                                name=None,
                                sparse_impl=SparseImplementation.GATHER_SUM):
  # pyformat: disable
  """Implements the Feature Steered graph convolution.

@@ -42,7 +85,8 @@ def feature_steered_convolution(data,
    https://arxiv.org/abs/1706.05206

  The shorthands used below are
    `Vi`: The number of vertices in the input.
    `Vo`: The number of vertices in the output.
    `C`: The number of channels in the input data.
    `D`: The number of channels in the output after convolution.
    `W`: The number of weight matrices used in the convolution.
@@ -53,9 +97,9 @@
  In the following, A1 to An are optional batch dimensions.

  Args:
    data: A `float` tensor with shape `[A1, ..., An, Vi, C]`.
    neighbors: A `SparseTensor` with the same type as `data` and with shape
      `[A1, ..., An, Vo, Vi]` representing vertex neighborhoods. The neighborhood
      of a vertex defines the support region for convolution. For a mesh, a
      common choice for the neighborhood of vertex i would be the vertices in
      the K-ring of i (including i itself). Each vertex must have at least one
@@ -81,17 +125,21 @@
    var_c: A 1-D tensor with shape `[W]`.
    var_w: A 3-D tensor with shape `[W, C, D]`.
    var_b: A 1-D tensor with shape `[D]`.
    sparse_impl: A `SparseImplementation` value, one of 'gather_sum' or
      'sparse_matmul'. This choice affects computational performance; the
      results agree up to floating point error.
    name: A name for this op. Defaults to
      `graph_convolution_feature_steered_convolution`.

  Returns:
    Tensor with shape `[A1, ..., An, Vo, D]`.

  Raises:
    TypeError: if the input types are invalid.
    ValueError: if the input dimensions are invalid.
  """
  # pyformat: enable
  SparseImplementation.validate(sparse_impl)
  with tf.compat.v1.name_scope(
      name, "graph_convolution_feature_steered_convolution",
      [data, neighbors, sizes, var_u, var_v, var_c, var_w, var_b]):
@@ -131,28 +179,33 @@ def feature_steered_convolution(data,
    adjacency = neighbors
    x_u = tf.matmul(x_flat, var_u)
    x_v = tf.matmul(x_flat, var_v)
    adjacency_ind_0, adjacency_ind_1 = tf.unstack(adjacency.indices, axis=1)
    x_u_rep = tf.gather(x_u, adjacency_ind_0)
    x_v_sep = tf.gather(x_v, adjacency_ind_1)
    logits = x_u_rep + x_v_sep + tf.reshape(var_c, (1, -1))
    # Numerically stable softmax: subtracting the row-wise maximum leaves the
    # result unchanged but keeps `tf.exp` from overflowing. Compared to
    # `tf.nn.softmax`, this formulation gives better performance when JIT
    # compiled.
    logits = logits - tf.reduce_max(logits, axis=-1, keepdims=True)
    weights_q = tf.exp(logits)
    weights_q = weights_q / tf.reduce_sum(
        input_tensor=weights_q, axis=-1, keepdims=True)

    if sparse_impl == SparseImplementation.GATHER_SUM:
      p_sums = _gather_sum(
          adjacency_ind_0, adjacency_ind_1, adjacency.values, weights_q,
          x_flat)
    else:
      assert sparse_impl == SparseImplementation.SPARSE_MATMUL
      p_sums = _sparse_matmul(
          adjacency.indices, adjacency.values, adjacency.dense_shape,
          weights_q, x_flat)

    w_m_list = tf.unstack(var_w, axis=0)
    y_i_m = []

    for p_sum, w_m in zip(p_sums, w_m_list):
      # Apply the m-th weight matrix to the aggregated neighborhood features:
      # `y_i_m = sum_{j in neighborhood(i)} q_m(x_i, x_j) * w_m * x_j`.
      y_i_m.append(tf.matmul(p_sum, w_m))

    y_out = tf.add_n(inputs=y_i_m) + tf.reshape(var_b, [1, -1])
    if data_ndims > 2:
      y_out = unflatten(y_out)
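For reference, a minimal call of the updated function (a sketch, assuming TF 2.x eager execution and `numpy as np`; shapes follow the docstring above):

    V, C, W, D = 4, 3, 2, 5
    y = feature_steered_convolution(
        data=np.random.uniform(size=(V, C)).astype(np.float32),
        neighbors=tf.sparse.eye(V, dtype=tf.float32),  # self-loops only
        sizes=None,
        var_u=np.random.uniform(size=(C, W)).astype(np.float32),
        var_v=np.random.uniform(size=(C, W)).astype(np.float32),
        var_c=np.zeros((W,), dtype=np.float32),
        var_w=np.random.uniform(size=(W, C, D)).astype(np.float32),
        var_b=np.zeros((D,), dtype=np.float32),
        sparse_impl=SparseImplementation.SPARSE_MATMUL)
    print(y.shape)  # (4, 5), i.e. [Vo, D]

Swapping in `SparseImplementation.GATHER_SUM` should give the same values up to floating point error.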
@@ -333,13 +333,23 @@ def test_feature_steered_convolution_only_self_edges(self, batch_size,

  @parameterized.parameters(
      (((1.0,), (2.0,), (3.0,)), np.ones(shape=(3, 3)) / 3.0, ((0.5,),),
       ((1.3,),), (-0.7,), (((0.8,),),), (3.0,), ((4.6,), (4.6,), (4.6,)),
       'gather_sum'),
      (((1.0,), (2.0,), (3.0,)), np.ones(shape=(3, 3)) / 3.0, ((0.5, 0.2),),
       ((0.3, 0.4),), (-0.7, 0.15), (((0.8,),), ((1.1,),)), (3.0,),
       ((5.011706928844621,), (4.971030281984818,), (4.927388658982911,)),
       'gather_sum'),
      (((1.0,), (2.0,), (3.0,)), np.ones(shape=(3, 3)) / 3.0, ((0.5,),),
       ((1.3,),), (-0.7,), (((0.8,),),), (3.0,), ((4.6,), (4.6,), (4.6,)),
       'sparse_matmul'),
      (((1.0,), (2.0,), (3.0,)), np.ones(shape=(3, 3)) / 3.0, ((0.5, 0.2),),
       ((0.3, 0.4),), (-0.7, 0.15), (((0.8,),), ((1.1,),)), (3.0,),
       ((5.011706928844621,), (4.971030281984818,), (4.927388658982911,)),
       'sparse_matmul'),
  )
  def test_feature_steered_convolution_padding_preset(self, data, neighbors, u,
                                                      v, c, w, b, expected,
                                                      sparse_impl):
    """Test expected result for preset data and filter values."""
    array = (np.array(i) for i in (data, neighbors, expected))
    data, neighbors, expected = array
@@ -354,7 +364,8 @@ def test_feature_steered_convolution_padding_preset(self, data, neighbors, u,
        var_v=v,
        var_c=c,
        var_w=w,
        var_b=b,
        sparse_impl=sparse_impl)
    self.assertAllClose(y, expected)

  @parameterized.parameters(