
@jackd
Contributor

@jackd jackd commented Sep 2, 2019

The current partition_sums_2d is, from what I can tell, almost always slower and less memory-efficient than tf.math.unsorted_segment_sum (see the benchmark below, which shows a ~2x speed-up and a ~10x memory reduction). This PR:

  • replaces the current implementation with a light wrapper around tf.math.unsorted_segment_sum
  • removes direct calls to partition_sums_2d to make the max/weighted reduction code branches more similar (segment_max vs. segment_sum respectively)
  • fixes a bug that allowed unsorted indices to be used in tf.segment_max (and adds tests demonstrating the possible errors)
  • changes a reference to dimension.value that was annoying me (I've taken to using tf.enable_v2_tensorshape, which wasn't compatible with this code) (maybe this shouldn't be part of this PR... happy to take it out)
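
The tf.segment_max bug mentioned above can be illustrated with a simplified model of a one-pass sorted segment kernel (a sketch with assumed semantics, not TF's actual implementation):

```python
import numpy as np

def segment_max_run_based(data, segment_ids):
    """Mimics a one-pass kernel that assumes non-decreasing segment_ids:
    it flushes a result whenever the id changes, so a repeated id later in
    the input silently overwrites the earlier run's result."""
    out = np.full(int(max(segment_ids)) + 1, -np.inf)
    best, current = data[0], segment_ids[0]
    for v, s in zip(data[1:], segment_ids[1:]):
        if s == current:
            best = max(best, v)
        else:
            out[current] = best  # flush the finished run
            best, current = v, s
    out[current] = best
    return out

def unsorted_segment_max_ref(data, segment_ids, num_segments):
    """Reference semantics of an unsorted segment max: indexes directly,
    so input order is irrelevant."""
    out = np.full(num_segments, -np.inf)
    for v, s in zip(data, segment_ids):
        out[s] = max(out[s], v)
    return out

data = np.array([9.0, 1.0, 5.0])
ids = np.array([0, 1, 0])  # NOT sorted: id 0 appears in two runs
print(segment_max_run_based(data, ids))        # [5. 1.] -- the 9 is lost
print(unsorted_segment_max_ref(data, ids, 2))  # [9. 1.] -- correct
```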

An is_sorted kwarg has been added to certain methods for potential optimization, though from what I can tell it makes minimal difference to performance. Strictly speaking this is a breaking change (it adds a non-final kwarg; I could put it after name to make it less breaking, but there seems to be a convention that name always goes last).

Changes that I haven't made, to keep the PR minimal and mostly non-breaking, but that I'll float anyway:

  • reduction in ('max', 'weighted') seems clunky. Could the weighting be considered separately from the reduction, with reduction just one of ('max', 'sum')? Better yet, could reduction be one of (tf.math.unsorted_segment_sum, tf.math.segment_sum, tf.unsorted_segment_max, tf.segment_max), doing away with the need for the is_sorted kwarg introduced in this PR?
  • if partition_sums_2d is just a thin wrapper, should it be deprecated?
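
The floated pluggable-reduction idea might look something like the following (hypothetical API and pure-Python stand-ins for the TF segment ops, not the actual tensorflow_graphics interface):

```python
# Sketch: `reduction` is the segment-op callable itself, so weighting is
# orthogonal and no is_sorted kwarg is needed. Names here are illustrative.

def segment_sum_ref(data, segment_ids, num_segments):
    # pure-Python stand-in for tf.math.unsorted_segment_sum on 1-D data
    out = [0.0] * num_segments
    for v, s in zip(data, segment_ids):
        out[s] += v
    return out

def segment_max_ref(data, segment_ids, num_segments):
    # pure-Python stand-in for tf.math.unsorted_segment_max on 1-D data
    out = [float("-inf")] * num_segments
    for v, s in zip(data, segment_ids):
        out[s] = max(out[s], v)
    return out

def pool(data, group_ids, reduction, weights=None):
    # weighting handled separately; `reduction` is any segment-op callable
    if weights is not None:
        data = [d * w for d, w in zip(data, weights)]
    return reduction(data, group_ids, max(group_ids) + 1)

print(pool([1.0, 2.0, 3.0], [0, 0, 1], segment_sum_ref))  # [3.0, 3.0]
print(pool([1.0, 2.0, 3.0], [0, 0, 1], segment_max_ref))  # [2.0, 3.0]
```

Swapping the reduction then never touches the pooling logic, sorted or not.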

Benchmark demonstrating performance improvement:

import numpy as np
import tensorflow as tf
from tensorflow_graphics.util import shape
import functools


def partition_sums_2d_original(data, group_ids, row_weights=None, name=None):
  """Original implementation."""
  with tf.compat.v1.name_scope(name, "utils_partition_sums_2d",
                               [data, group_ids, row_weights]):
    data = tf.convert_to_tensor(value=data)
    group_ids = tf.convert_to_tensor(value=group_ids)
    if not group_ids.dtype.is_integer:
      raise TypeError("'group_ids' must be an integer tensor.")
    elif group_ids.dtype != tf.int64:
      group_ids = tf.cast(group_ids, dtype=tf.int64)
    if row_weights is None:
      row_weights = tf.ones_like(group_ids, dtype=data.dtype)
    else:
      row_weights = tf.convert_to_tensor(value=row_weights)

    if row_weights.dtype != data.dtype:
      raise TypeError("'data' and 'row_weights' must have the same type.")
    shape.check_static(tensor=data, tensor_name="data", has_rank=2)
    shape.check_static(tensor=group_ids, tensor_name="group_ids", has_rank=1)
    shape.check_static(
        tensor=row_weights, tensor_name="row_weights", has_rank=1)
    shape.compare_dimensions(
        tensors=(data, group_ids, row_weights),
        tensor_names=("data", "group_ids", "row_weights"),
        axes=0)

    num_rows = tf.size(input=group_ids, out_type=tf.int64)
    sparse_indices = tf.stack((group_ids, tf.range(num_rows)), axis=1)
    out_shape = (tf.reduce_max(input_tensor=group_ids) + 1, num_rows)
    sparse = tf.SparseTensor(sparse_indices, row_weights, dense_shape=out_shape)
    return tf.sparse.sparse_dense_matmul(sparse, data)


def partition_sums_2d_new(
    data, group_ids, row_weights=None, is_sorted=False, name=None):
  """Implementation in this PR (is_sorted reserved for potential optimization)."""
  with tf.compat.v1.name_scope(name, "utils_partition_sums_2d",
                               [data, group_ids, row_weights]):
    data = tf.convert_to_tensor(value=data)
    group_ids = tf.convert_to_tensor(value=group_ids)
    if not group_ids.dtype.is_integer:
      raise TypeError("'group_ids' must be an integer tensor.")
    elif group_ids.dtype != tf.int64:
      group_ids = tf.cast(group_ids, dtype=tf.int64)
    
    shape.check_static(tensor=data, tensor_name="data", has_rank=2)
    shape.check_static(tensor=group_ids, tensor_name="group_ids", has_rank=1)

    if row_weights is None:
      shape.compare_dimensions(
        tensors=(data, group_ids),
        tensor_names=("data", "group_ids"),
        axes=0)
    else:
      row_weights = tf.convert_to_tensor(value=row_weights)
      shape.check_static(
        tensor=row_weights, tensor_name="row_weights", has_rank=1)
      if row_weights.dtype != data.dtype:
        raise TypeError("'data' and 'row_weights' must have the same type.")
      shape.compare_dimensions(
        tensors=(data, group_ids, row_weights),
        tensor_names=("data", "group_ids", "row_weights"),
        axes=0)
      data = data * tf.expand_dims(row_weights, axis=-1)
    
    return tf.math.unsorted_segment_sum(
      data, group_ids, tf.reduce_max(group_ids) + 1)


if __name__ == '__main__':
  r = np.random.RandomState(123)
  num_edges = 10000
  num_vertices = 1000
  num_features = 100
  ordered = False

  group_ids = r.randint(0, num_vertices, size=(num_edges,))
  if ordered:
    group_ids = np.sort(group_ids)
  data = r.uniform(size=(num_edges, num_features)).astype(np.float32)
  row_weights = r.uniform(size=(num_edges,)).astype(np.float32)


  for name, fn in (
      ('original', partition_sums_2d_original),
      ('new', functools.partial(partition_sums_2d_new, is_sorted=ordered))):
    with tf.Graph().as_default():
      op = fn(data, group_ids, row_weights)
      with tf.compat.v1.Session() as sess:
        print('--------------------------')
        print(name)
        tf.test.Benchmark().run_op_benchmark(sess, op)

Output

--------------------------
original
entry {
  name: "TensorFlowBenchmark.run_op_benchmark"
  iters: 10
  wall_time: 0.0004891157150268555
  extras {
    key: "allocator_maximum_num_bytes_GPU_0_bfc"
    value {
      double_value: 4600016.0
    }
  }
}
--------------------------
new
entry {
  name: "TensorFlowBenchmark.run_op_benchmark"
  iters: 10
  wall_time: 0.0002224445343017578
  extras {
    key: "allocator_maximum_num_bytes_GPU_0_bfc"
    value {
      double_value: 400000.0
    }
  }
}
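
As a sanity check, the two formulations compute the same quantity: a row-weighted sum of the rows of data, grouped by group_ids. A quick NumPy cross-check (independent of the TF code above):

```python
import numpy as np

rng = np.random.RandomState(0)
data = rng.uniform(size=(8, 3)).astype(np.float64)
group_ids = np.array([0, 2, 1, 0, 2, 2, 1, 0])
row_weights = rng.uniform(size=(8,))

# Sparse-matmul formulation: S[g, r] = w_r iff group_ids[r] == g; result is S @ data.
num_groups = group_ids.max() + 1
sparse = np.zeros((num_groups, len(group_ids)))
sparse[group_ids, np.arange(len(group_ids))] = row_weights
via_matmul = sparse @ data

# Segment-sum formulation: weight each row first, then sum rows per group.
weighted = data * row_weights[:, None]
via_segment = np.zeros((num_groups, data.shape[1]))
np.add.at(via_segment, group_ids, weighted)  # handles repeated group ids

assert np.allclose(via_matmul, via_segment)
```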

@jackd jackd changed the title "optimized partition_sums_2d, removed direct calls" to "optimized partition_sums_2d" Sep 2, 2019
@julienvalentin
Contributor

Hi Jack,

Thanks for the PR! We were super heads down recently, but will look at this shortly.

Best.

@jackd
Contributor Author

jackd commented Sep 28, 2019

@julienvalentin this other PR has some additional benchmarks related to this idea. It's primarily focused on other ideas, but incorporating this different summation implementation was easy enough and allows for more meaningful benchmarks.

@julienvalentin
Contributor

Hi Jack,

Ameesh recently pushed 228472e on this matter. Ameesh, when you get the chance, it would be great if you could fill Jack in on the details of what you have tried and submitted.

Best.

@amakadia
Contributor

amakadia commented Nov 6, 2019

Hi Jack,
Thanks for investigating this issue. It seems the simplest course of action was to do a straight swap of partition_sums_2d with unsorted_segment_sum.

An interesting note: we found the unsorted_segment_* ops seem to be an order of magnitude faster than their corresponding segment_* ops (this was surprising!). We could also swap segment_max for unsorted_segment_max, but currently there is an issue with the gradient for unsorted_segment_max, so we're leaving it as-is.

@jackd
Contributor Author

jackd commented Nov 6, 2019

we found the unsorted_segment_* ops seem to be an order of magnitude faster than their corresponding segment_* ops

This is the kind of thing that destroys my faith in humanity.

@amakadia did you also find the memory reduction from using foldl rather than unstack? From memory I was getting a ~M memory reduction for only a slight degradation in performance (~20% if I remember correctly). I won't have much free time in the next couple of weeks, but if that's an option that would be considered, I'm happy to put in a separate PR (this one conflated that change with a bunch of other things).
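
For reference, the foldl-vs-unstack distinction amounts to sequential accumulation versus materializing every slice at once. A NumPy sketch of the two access patterns (illustrative only, not the actual tf.foldl/tf.unstack code):

```python
import numpy as np

data = np.arange(12, dtype=np.float64).reshape(4, 3)

# unstack-style: every row slice is materialized before the reduction runs,
# so peak memory grows with the number of slices.
slices = [data[i] for i in range(data.shape[0])]
total_unstack = sum(slices)

# foldl-style: only the accumulator and the current slice are live at once.
acc = np.zeros(data.shape[1])
for i in range(data.shape[0]):
    acc = acc + data[i]

assert np.allclose(acc, total_unstack)  # same result, smaller peak footprint
```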

@julienvalentin
Contributor

@amakadia can you recommend what the next steps should be here?

@jackd
Contributor Author

jackd commented Jan 21, 2020

I'm satisfied these changes are now in master.

@jackd jackd closed this Jan 21, 2020