Efficient Graph Field Integrators

Let $G=(V, E, W)$ be a weighted undirected graph, $K:V\times V \rightarrow \mathbb{R}$ a kernel, and $\mathcal{F}:V \rightarrow \mathbb{R}^{d_{1} \times \ldots \times d_{l}}$ a tensor field defined on $V$, where $d_{1},\ldots,d_{l}$ are the tensor dimensions.

In this repository, we implement several methods that allow efficient computation of

$$i(v) := \sum_{w \in V} K(w,v)\mathcal{F}(w), \qquad \text{for all } v \in V.$$

We refer to the process of computing $i(v)$ as graph-field integration (GFI).
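
As a point of reference, the sum above can be evaluated directly, at a cost that is at least quadratic in the number of vertices. The sketch below is such a brute-force baseline, not one of the repository's integrators. It assumes a shortest-path kernel $K(w,v)=\exp(-\lambda\,\mathrm{dist}(w,v))$, and the names `shortest_path_lengths`, `brute_force_gfi`, and `lam` are hypothetical. Tensor values are flattened to plain lists so the example needs only the standard library.

```python
import heapq
import math


def shortest_path_lengths(adjacency, source):
    """Dijkstra distances from `source`; adjacency[u] is a list of (v, weight)."""
    dist = {source: 0.0}
    heap = [(0.0, source)]
    while heap:
        d, u = heapq.heappop(heap)
        if d > dist.get(u, math.inf):
            continue  # stale heap entry, a shorter path was already found
        for v, w in adjacency[u]:
            nd = d + w
            if nd < dist.get(v, math.inf):
                dist[v] = nd
                heapq.heappush(heap, (nd, v))
    return dist


def brute_force_gfi(adjacency, field, lam=1.0):
    """Return i(v) = sum_w K(w, v) F(w), with the assumed kernel
    K(w, v) = exp(-lam * dist(w, v)).

    field[w] is a flat list of floats (a tensor flattened to a vector).
    """
    vertices = list(adjacency)
    dim = len(next(iter(field.values())))
    result = {v: [0.0] * dim for v in vertices}
    for w in vertices:
        dist = shortest_path_lengths(adjacency, w)
        for v in vertices:
            if v not in dist:
                continue  # v is unreachable from w, kernel value taken as 0
            k = math.exp(-lam * dist[v])
            for j in range(dim):
                result[v][j] += k * field[w][j]
    return result


# Path graph 0 - 1 - 2 with unit edge weights and a 2-dimensional field.
adjacency = {0: [(1, 1.0)], 1: [(0, 1.0), (2, 1.0)], 2: [(1, 1.0)]}
field = {0: [1.0, 0.0], 1: [0.0, 0.0], 2: [0.0, 2.0]}
integrated = brute_force_gfi(adjacency, field)
```

The integrators in this repository exist to avoid exactly this cost, so the baseline is mainly useful for checking a faster method on a small graph.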

This repository accompanies the paper "Efficient Graph Field Integrators Meet Point Clouds".

Krzysztof Choromanski*, Arijit Sehanobish*, Han Lin*, Yunfan Zhao*, Eli Berger, Tetiana Parshakova, Alvin Pan, David Watkins, Tianyi Zhang, Valerii Likhosherstov, Somnath Basu Roy Chowdhury, Avinava Dubey, Deepali Jain, Tamas Sarlos, Snigdha Chaturvedi, Adrian Weller

Google Research, Columbia University, Haifa University, Stanford University, The Boston Dynamics AI Institute, University of Cambridge, The University of North Carolina at Chapel Hill, The Alan Turing Institute.

The Fortieth International Conference on Machine Learning (ICML), 2023


git clone
cd efficient_graph_algorithms
python3 -m venv env
source env/bin/activate
pip3 install -r requirements.txt
pip3 install -e . --user

If you encounter an error when running pip3 install -e . --user, you can follow this link.

Getting started

This repository contains implementations of several GFIs that inherit from GFIntegrator. They can be categorized based on their representation of point clouds:

  1. Mesh graph-based representation
    • separator factorization GFI SeparationGFIntegrator
    • trees approximating graph metric
      • FRT trees-based GFI FRTTreeGFIntegrator
      • Bartal trees-based GFI BartalTreeGFIntegrator
      • spanning tree-based GFI SpanningTreeGFIntegrator
  2. $\epsilon$-NN (Nearest Neighbor) based representation
    • random feature diffusion GFI DFGFIntegrator
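
To illustrate why replacing the graph metric with a tree metric helps, the sketch below integrates a scalar field on a weighted tree in linear time. It is not the code behind FRTTreeGFIntegrator, BartalTreeGFIntegrator, or SpanningTreeGFIntegrator. It assumes the exponential kernel $K(w,v)=\exp(-\lambda\,\mathrm{dist}(w,v))$, which factorizes along tree paths, and the name `tree_exp_gfi` and its parameters are hypothetical.

```python
import math


def tree_exp_gfi(tree, field, lam=1.0, root=0):
    """Integrate a scalar field on a weighted tree with the assumed kernel
    K(w, v) = exp(-lam * dist(w, v)).

    tree[u] is a list of (neighbor, edge_weight); field[u] is a float.
    Uses one bottom-up and one top-down sweep, so the cost is O(N).
    """
    # Iterative DFS: every vertex appears after its parent in `order`.
    parent = {root: (None, 0.0)}
    order = []
    stack = [root]
    while stack:
        u = stack.pop()
        order.append(u)
        for v, w in tree[u]:
            if v not in parent:
                parent[v] = (u, w)
                stack.append(v)

    # Bottom-up: down[u] is the kernel-weighted sum over the subtree of u.
    down = {u: field[u] for u in order}
    for u in reversed(order):
        p, w = parent[u]
        if p is not None:
            down[p] += math.exp(-lam * w) * down[u]

    # Top-down: add the contribution of all vertices outside the subtree of u.
    # Those vertices reach u through its parent p, so their sum at p is
    # total[p] minus the part that came from u's own subtree.
    total = {root: down[root]}
    for u in order[1:]:
        p, w = parent[u]
        decay = math.exp(-lam * w)
        total[u] = down[u] + decay * (total[p] - decay * down[u])
    return total


# Path-shaped tree 0 - 1 - 2 with edge weights 1 and 2.
tree = {0: [(1, 1.0)], 1: [(0, 1.0), (2, 2.0)], 2: [(1, 2.0)]}
integrated = tree_exp_gfi(tree, {0: 1.0, 1: 2.0, 2: 3.0}, lam=0.5)
```

The two sweeps never materialize the $N \times N$ kernel matrix. Other kernels do not factorize this way, so this particular shortcut should be read as specific to the exponential case.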

These GFIs can be readily used for the following tasks:

  • interpolation task using Interpolator
    • by specifying parameters GFIntegrator, vertices_known, vertices_interpolate at instantiation
    • and then calling the method interpolate with the field values on vertices_known
  • Wasserstein barycenter using ConvolutionalBarycenter
    • by specifying parameters niter, tolerance at instantiation
    • and then calling the method get_convolutional_barycenter with an array of distributions, the mixing weights, and GFIntegrator.integrate_graph_field
  • Gromov Wasserstein discrepancy
    • proximal point algorithm using gromov_wasserstein_discrepancy
    • conditional gradient algorithm using gw_lp
  • Fused Gromov Wasserstein using fgw_lp
  • Point cloud classification
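
The interpolation workflow listed above can be sketched as follows. This is an illustration of the idea, not the Interpolator class itself: the function names, argument order, and the dense-kernel helper are hypothetical, and `integrate` stands in for a method such as GFIntegrator.integrate_graph_field. The field is set to zero on unknown vertices, integrated once, and the result is read off at the vertices to interpolate.

```python
def dense_kernel_integrator(kernel):
    """Build an `integrate` callable from an explicit N x N kernel matrix.

    Hypothetical stand-in for a fast integrator; kernel[w][v] = K(w, v).
    """
    def integrate(field):
        n, dim = len(field), len(field[0])
        return [
            [sum(kernel[w][v] * field[w][j] for w in range(n)) for j in range(dim)]
            for v in range(n)
        ]
    return integrate


def interpolate_field(integrate, num_vertices, vertices_known,
                      known_values, vertices_interpolate):
    """Predict field values at `vertices_interpolate` from `vertices_known`.

    The prediction at v is sum over known w of K(w, v) * F(w).
    """
    dim = len(known_values[0])
    # Unknown vertices contribute nothing to the sum, so they are zero-filled.
    field = [[0.0] * dim for _ in range(num_vertices)]
    for v, value in zip(vertices_known, known_values):
        field[v] = list(value)
    integrated = integrate(field)
    return [integrated[v] for v in vertices_interpolate]


# Three vertices; vertex 1 is interpolated from vertices 0 and 2.
kernel = [[1.0, 0.5, 0.25], [0.5, 1.0, 0.5], [0.25, 0.5, 1.0]]
predicted = interpolate_field(
    dense_kernel_integrator(kernel),
    num_vertices=3,
    vertices_known=[0, 2],
    known_values=[[1.0, 0.0], [0.0, 1.0]],
    vertices_interpolate=[1],
)
```

Because the interpolation only needs one call to `integrate`, any of the integrators can be swapped in without changing the surrounding logic.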


Vertex normal prediction

First, download the Thingi10K mesh data from this link. The mesh IDs used in our paper are listed in its Appendix C1.

scripts/experiments/vertex_normal_prediction_config.yaml is an example configuration file for running the vertex normal prediction task.

To run an experiment on this task:

python scripts/experiments/ 

For information on how to run each experiment:

Wasserstein barycenter

Follow the instructions above to download the Thingi10K mesh data.

To run experiments for RFD:

python scripts/experiments/

To run experiments for SF:

python scripts/experiments/ 

To run experiments for trees:

python scripts/experiments/
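
For orientation, the iteration behind a convolutional Wasserstein barycenter can be sketched in a few lines. This is a simplified version under stated assumptions, not the ConvolutionalBarycenter class: it assumes uniform vertex area weights, omits entropic sharpening, and the function signature is hypothetical. Only the parameter names niter and tolerance are taken from the description under Getting started. The argument `apply_kernel` plays the role of an integrator's integrate_graph_field: it maps a vector $x$ to $Kx$ for a symmetric, strictly positive kernel.

```python
def convolutional_barycenter(distributions, weights, apply_kernel,
                             niter=100, tolerance=1e-9):
    """Barycenter of `distributions` with mixing `weights` (summing to 1).

    Each distribution is a list of N non-negative floats. The kernel is only
    accessed through `apply_kernel`, so no N x N matrix is needed here.
    """
    n = len(distributions[0])
    k = len(distributions)
    v = [[1.0] * n for _ in range(k)]  # one scaling vector per distribution
    barycenter = [1.0 / n] * n
    for _ in range(niter):
        previous = barycenter
        d = []
        for i in range(k):
            # Rescale so the first marginal matches distributions[i].
            kv = apply_kernel(v[i])
            w_i = [p / x for p, x in zip(distributions[i], kv)]
            # Second marginal of the current transport plan.
            kw = apply_kernel(w_i)
            d.append([a * b for a, b in zip(v[i], kw)])
        # Weighted geometric mean of the second marginals.
        barycenter = [1.0] * n
        for i in range(k):
            barycenter = [b * x ** weights[i] for b, x in zip(barycenter, d[i])]
        # Update the scalings so every plan shares the barycenter as marginal.
        for i in range(k):
            v[i] = [a * b / x for a, b, x in zip(v[i], barycenter, d[i])]
        if max(abs(a - b) for a, b in zip(barycenter, previous)) < tolerance:
            break
    return barycenter


# Five vertices on a line with an illustrative positive kernel.
n = 5
kernel = [[0.5 ** ((r - c) ** 2) for c in range(n)] for r in range(n)]


def apply_kernel(x):
    return [sum(kernel[r][c] * x[c] for c in range(n)) for r in range(n)]


left = [1.0, 0.0, 0.0, 0.0, 0.0]
right = [0.0, 0.0, 0.0, 0.0, 1.0]
result = convolutional_barycenter([left, right], [0.5, 0.5], apply_kernel)
```

Each iteration costs two kernel applications per input distribution, which is why a fast integrator translates directly into a faster barycenter computation.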

(Fused) Gromov-Wasserstein discrepancy

To run experiments for RFD on GW with conditional gradient method:

python scripts/experiments/

To run experiments for RFD on GW with proximal method:

python scripts/experiments/

To run experiments for RFD on FGW:

python scripts/experiments/

To run experiments for SF on GW with conditional gradient method:

python scripts/experiments/

To run experiments for SF on GW with proximal method:

python scripts/experiments/

To run experiments for SF on FGW:

python scripts/experiments/

Point Cloud classification

We also use the approximated RFD kernel matrix for point cloud classification on ModelNet10. The code to run it is in notebooks/point_cloud_classification.ipynb.

MeshGraphNet datasets

For information on how to download and prepare the MeshGraphNet datasets: