
Conversation

@jackd jackd commented Sep 27, 2019

This PR is the result of some work I've been doing on variations of feature-steered convolutions. Each variation has its own strengths and weaknesses, so I'm unsure whether graphics would like to accept any/all of the options, or how to package them.

Algebraic manipulations

The idea behind the changes is to avoid storing temporary feature values at each edge (caused by gather -> reduce) by using tf.sparse.sparse_dense_matmul. In order to do this, we perform the following manipulations to the m terms from the paper (m == w in code).

y_m = sum_j(neighbors_ij * q_m(x_i, x_j) * (x_j @ W_m))
    = sum_j(weighted_neighbors_ij * exp(u_m.T @ x_i + c_m) * exp(v_m.T @ x_j) * (x_j @ W_m))
    = exp(u_m.T @ x_i + c_m) * sum_j(weighted_neighbors_ij * exp(v_m.T @ x_j) * (x_j @ W_m))

where weighted_neighbors_ij = neighbors_ij / sum_m(q_m(x_i, x_j)) is the original neighborhood weighting value divided by the softmax normalization factor.

This inner summation can then be implemented using tf.sparse.sparse_dense_matmul:

y_m = exp(u_m.T @ x_i + c_m) * (weighted_neighbors_ij @ (exp(v_m.T @ x_j) * x_j) @ W_m)

All implementations are algebraically equivalent and yield similar results (to within 1e-5).
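
For concreteness, here is a minimal sketch of the sparse-matmul formulation (not the PR code; the names features, weighted_neighbors, u, v, c, w and their shapes are assumptions). It loops over the W filters sequentially, closer in spirit to v3 than to the fully vectorized v2:

```python
import tensorflow as tf

def feature_steered_conv_sparse(features, weighted_neighbors, u, v, c, w):
  """Sketch of the rearranged sum above.

  Args:
    features: float tensor of vertex features, shape [V, C].
    weighted_neighbors: SparseTensor of shape [V, V] holding neighbors_ij
      divided by the softmax normalization factor.
    u, v: float tensors of shape [C, W] (one column per weight matrix).
    c: float tensor of shape [W].
    w: float tensor of shape [W, C, D] stacking the W_m matrices.

  Returns:
    float tensor of shape [V, D].
  """
  x_u = tf.exp(tf.matmul(features, u) + c)  # [V, W], exp(u_m.T @ x_i + c_m)
  x_v = tf.exp(tf.matmul(features, v))      # [V, W], exp(v_m.T @ x_j)
  out = 0.0
  for m in range(w.shape[0]):
    # Scale features by the x_j-dependent factor, then aggregate over
    # neighborhoods with a single sparse-dense matmul (no per-edge gather).
    scaled = x_v[:, m:m + 1] * features                                 # [V, C]
    summed = tf.sparse.sparse_dense_matmul(weighted_neighbors, scaled)  # [V, C]
    out += x_u[:, m:m + 1] * tf.matmul(summed, w[m])                    # [V, D]
  return out
```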

Main Changes

  • Renamed the existing feature_steered_convolution to feature_steered_convolution_v1 and added a new feature_steered_convolution that redirects to this or one of the other implementations.
  • Added feature_steered_convolution_v2, which is based on the sparse multiplication above. This implementation computes all m terms in a single vectorized block, which is fastest, though it requires a feature tensor of shape [V, D, W] before the final dimension is reduced, so it needs a lot of memory. The default is, as far as I can tell, always optimal, so this could easily be dropped (not least to make testing/verifying easier); happy to remove if approved.
  • Added feature_steered_convolution_v3, which addresses the memory issues of v2 by computing the last dimension of the conceptual [V, D, W] tensor sequentially.
  • Added transform_data_first option to all implementations. This allows transforming x_flat via var_w before or after other multiplications (before is more efficient if the number of features is decreasing). This is equivalent to taking advantage of associativity of matrix multiplication in the second equation above.
  • Added a memory_efficient option to v1 and v3 that uses foldl for sequential additions over the W dimension (see the sketch after this list).
  • Added different segment sum implementations to v1 similar to this PR based on tf.segment_sum and tf.unsorted_segment_sum.
  • Added basic benchmarks to show performance improvements for convolutions alone as well as the example model.
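
The memory_efficient accumulation can be pictured roughly as below (a sketch under assumed names and shapes, not the PR code): the per-filter contributions are summed with tf.foldl so that only one [V, D] term is live at a time, rather than stacking all W of them.

```python
import tensorflow as tf

def accumulate_filters_with_foldl(features, weighted_neighbors, x_u, x_v, w):
  # features: [V, C] vertex features.
  # weighted_neighbors: SparseTensor of shape [V, V].
  # x_u, x_v: [V, W] per-filter scaling factors (as in the sketch above).
  # w: [W, C, D] stacked weight matrices.
  num_vertices = tf.shape(features)[0]
  out_channels = tf.shape(w)[-1]

  def add_filter_term(acc, m):
    scale_v = tf.expand_dims(tf.gather(x_v, m, axis=1), axis=-1)      # [V, 1]
    scale_u = tf.expand_dims(tf.gather(x_u, m, axis=1), axis=-1)      # [V, 1]
    summed = tf.sparse.sparse_dense_matmul(weighted_neighbors,
                                           scale_v * features)        # [V, C]
    return acc + scale_u * tf.matmul(summed, tf.gather(w, m))         # [V, D]

  return tf.foldl(
      add_filter_term,
      tf.range(tf.shape(w)[0]),
      initializer=tf.zeros(tf.stack([num_vertices, out_channels]),
                           dtype=features.dtype))
```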

Not included

Tests. It would be straightforward enough to add a version of each implementation to the existing test suite, but how thorough do we want to be? Basic checks in feature_steered_conv_benchmark.py show that all implementations give values consistent with the original implementation.

Benchmark summary

  • v2 is consistently fastest, though it uses up to 70% more memory for very sparse neighborhoods. Memory efficiency scales well with less sparse neighborhoods, but poorly with m.
  • Memory-efficient implementations lead to considerable savings (a factor of 10 for large m) at a small performance penalty (~10%).
  • Sorted segment sum is always faster and generally consumes less memory (or roughly the same) than the custom method implemented in this package.

Sample benchmark results

Name keys:
p2d: uses partition_sums_2d implementation
sorted: uses tf.math.segment_sum
unsorted: uses tf.math.unsorted_segment_sum
bad: uses non-default transform_data_first argument
mem: uses memory efficient implementation

Single convolution
Lower sparsity

*************
** SUMMARY **
*************
batch_size     : 4
num_vertices   : 1000
in_channels    : 32
out_channels   : 16
num_filters    : 8
sparsity       : 0.25
Baseline time: v1_p2d, 0.11661195755004883s
rel times:
v1_p2d          1.000
v1_p2d_bad      1.321
v1_sorted       0.774
v1_unsorted     0.811
v1_p2d_mem      1.246
v1_sorted_mem   0.905
v1_unsorted_mem 1.028
v2_default      0.417
v2_bad          0.694
v3              0.472
v3_mem          0.567
v3_bad          0.730
v3_mem_bad      0.849
Baseline memory: v1_p2d, 1843.4498443603516mb
v1_p2d          1.000
v1_p2d_bad      1.798
v1_sorted       0.793
v1_unsorted     0.678
v1_p2d_mem      0.328
v1_sorted_mem   0.316
v1_unsorted_mem 0.326
v2_default      0.811
v2_bad          1.613
v3              0.580
v3_mem          0.151
v3_bad          1.148
v3_mem_bad      0.250
Errors w.r.t v1_p2d
v1_p2d_bad: 5.7220458984375e-06
v1_sorted : 5.7220458984375e-06
v1_unsorted: 7.62939453125e-06
v1_p2d_mem: 7.62939453125e-06
v1_sorted_mem: 5.7220458984375e-06
v1_unsorted_mem: 8.58306884765625e-06
v2_default: 7.62939453125e-06
v2_bad    : 5.7220458984375e-06
v3        : 7.62939453125e-06
v3_mem    : 8.58306884765625e-06
v3_bad    : 5.7220458984375e-06
v3_mem_bad: 5.7220458984375e-06

Single convolution
Higher sparsity

*************
** SUMMARY **
*************
batch_size     : 2
num_vertices   : 1000
in_channels    : 32
out_channels   : 16
num_filters    : 8
sparsity       : 0.5
Baseline time: v1_p2d, 0.11733543872833252s
rel times:
v1_p2d          1.000
v1_p2d_bad      1.318
v1_sorted       0.771
v1_unsorted     0.815
v1_p2d_mem      1.250
v1_sorted_mem   0.899
v1_unsorted_mem 1.032
v2_default      0.423
v2_bad          0.690
v3              0.452
v3_mem          0.565
v3_bad          0.734
v3_mem_bad      0.830
Baseline memory: v1_p2d, 1840.8919219970703mb
v1_p2d          1.000
v1_p2d_bad      1.798
v1_sorted       0.772
v1_unsorted     0.678
v1_p2d_mem      0.328
v1_sorted_mem   0.315
v1_unsorted_mem 0.326
v2_default      0.810
v2_bad          1.609
v3              0.579
v3_mem          0.150
v3_bad          1.144
v3_mem_bad      0.250
Errors w.r.t v1_p2d
v1_p2d_bad: 6.67572021484375e-06
v1_sorted : 7.62939453125e-06
v1_unsorted: 1.049041748046875e-05
v1_p2d_mem: 8.58306884765625e-06
v1_sorted_mem: 8.58306884765625e-06
v1_unsorted_mem: 9.5367431640625e-06
v2_default: 8.58306884765625e-06
v2_bad    : 7.62939453125e-06
v3        : 9.5367431640625e-06
v3_mem    : 8.58306884765625e-06
v3_bad    : 7.62939453125e-06
v3_mem_bad: 6.67572021484375e-06

Demo model
filters = 8

*************
** SUMMARY **
*************
num_filters    : 8
Baseline time: v1_p2d, 0.1005086898803711s
rel times:
v1_p2d          1.000
v1_sorted       0.571
v1_unsorted     0.728
v1_p2d_mem      0.730
v1_sorted_mem   0.626
v1_unsorted_mem 0.792
v2_default      0.551
v3              0.635
v3_mem          0.674
Baseline memory: v1_p2d, 630.019645690918mb
v1_p2d          1.000
v1_sorted       1.045
v1_unsorted     1.149
v1_p2d_mem      0.455
v1_sorted_mem   0.389
v1_unsorted_mem 0.405
v2              1.733
v3              1.376
v3_mem          0.366

Demo model
filters = 64

*************
** SUMMARY **
*************
num_filters    : 64
Baseline time: v1_p2d, 0.4433619976043701s
rel times:
v1_p2d          1.000
v1_sorted       0.803
v1_unsorted     1.093
v1_p2d_mem      1.159
v1_sorted_mem   0.989
v1_unsorted_mem 1.302
v2              0.764
v3              0.952
v3_mem          1.087
Baseline memory: v1_p2d, 4582.036735534668mb
v1_p2d          1.000
v1_sorted       1.004
v1_unsorted     1.119
v1_p2d_mem      0.092
v1_sorted_mem   0.093
v1_unsorted_mem 0.093
v2              1.871
v3              1.451
v3_mem          0.090

@julienvalentin
Contributor

@avneesh-sud @amakadia it would be awesome if you could look at this PR :)

jackd commented Jan 17, 2020

@julienvalentin @amakadia I've been working on a similar problem since making this PR. I haven't quite resolved things there yet, but I imagine the results will probably flow across when I get a chance. In short, benchmark results are significantly different when using JIT compilation. Preliminary results indicate:

  • When using JIT compilation, varying tensor sizes have a significant impact on performance. unsorted_segment_* has a guaranteed fixed-size output, while segment_* does not, which could be the source of the performance difference. My experience is that networks using segment_* instead of unsorted_segment_* are no faster even when the size can be guaranteed to be constant. I haven't benchmarked the operation in isolation, so it might be somewhat faster, but the difference is being swamped by higher-order operations in the network (a small illustration of the shape difference follows this list).
  • The memory benefits of folding rather than unstacking disappear when using JIT compilation.
  • I'm still not convinced sparse matrix products are slower than gather/summing. Switching to an implementation based on sparse matrix-vector products would require a somewhat uglier code base, since there's no max version of a sparse mat-vec product.
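
For reference, a tiny illustration of the shape difference (toy data, not from the PR): unsorted_segment_sum takes num_segments explicitly, so the output shape is static at trace time, whereas segment_sum's output size depends on the values in segment_ids.

```python
import tensorflow as tf

data = tf.random.uniform([6, 4])
segment_ids = tf.constant([0, 0, 1, 2, 2, 2])

# Output row count depends on the values in segment_ids (dynamic under JIT).
dynamic = tf.math.segment_sum(data, segment_ids)
# num_segments is given explicitly, so the [3, 4] shape is known statically.
static = tf.math.unsorted_segment_sum(data, segment_ids, num_segments=3)
```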

I've got a lot on my plate for the next month or so, but I'm hoping to revisit that work and this PR after that. I'm happy to close these for now and create a new merged PR later if that would make management easier?
