This repository contains the implementation of the algorithms from our paper (OpenReview) and the code for its experiments.
Optimal transport (OT) is a popular distance for measuring the similarity between probability distributions. Exact and approximate combinatorial algorithms for computing the OT distance are hard to parallelize, which has motivated numerical solvers (e.g., the Sinkhorn method) that exploit GPU parallelism and produce approximate solutions.
We introduce the first parallel combinatorial algorithm to find an additive ε-approximation of the optimal transport distance.
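To make the additive guarantee concrete for the assignment case, here is a minimal sketch under our own assumptions: costs are scaled into [0, 1], each of the n points on either side carries mass 1/n, and an "additive ε-approximation" is a solution whose cost is at most the optimum plus ε. The helper names are hypothetical and are not part of matching.py. The exact optimum is found by brute force, so this only works for tiny n; for realistic sizes, scipy.optimize.linear_sum_assignment computes the exact assignment.

```python
from itertools import permutations


def matching_cost(cost, perm):
    """Average cost of assigning row i to column perm[i] (each point has mass 1/n)."""
    n = len(cost)
    return sum(cost[i][perm[i]] for i in range(n)) / n


def optimal_cost(cost):
    """Exact optimal assignment cost by brute force; only feasible for very small n."""
    n = len(cost)
    return min(matching_cost(cost, p) for p in permutations(range(n)))


def is_additive_eps_approx(cost, perm, eps):
    """True if the matching's cost is within an additive eps of the optimum."""
    return matching_cost(cost, perm) <= optimal_cost(cost) + eps
```

For example, with cost [[0, 1], [1, 0]] the swapped matching (1, 0) costs 1.0 against an optimum of 0.0, so it is an additive ε-approximation only for ε ≥ 1.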
This directory contains three parts:
- Implementation of the parallel transport and assignment algorithms (Section 4): transport.py, matching.py
- Experiments comparing our method with Sinkhorn (Section 5): plgpu_vs_sinkhorn_bench.py, plgpu_vs_sinkhorn_bench_rev.py
- Experiments comparing our method with DROT (Appendix C.2): plgpu_vs_drot_bench_step1.py, plgpu_vs_drot_bench_step2.py
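The Sinkhorn baseline solves an entropy-regularized version of OT by alternately rescaling the rows and columns of the kernel K = exp(-C / reg). For reference, the textbook iteration is sketched below in NumPy; it is not necessarily the Sinkhorn implementation the benchmark scripts call, and the function name and defaults are our own.

```python
import numpy as np


def sinkhorn(a, b, C, reg, n_iter=1000):
    """Entropy-regularized OT via Sinkhorn's matrix scaling (textbook form).

    a, b: source and target weights, each summing to 1.
    C: cost matrix of shape (len(a), len(b)); reg: regularization strength.
    Returns the transport plan P = diag(u) K diag(v) with K = exp(-C / reg).
    """
    K = np.exp(-C / reg)
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)    # rescale rows so row sums match a
        v = b / (K.T @ u)  # rescale columns so column sums match b
    return u[:, None] * K * v[None, :]
```

With very small regularization (the reverse Sinkhorn runs go down to reg = 0.00015), exp(-C / reg) can underflow to zero and this plain version then divides by zero; practical implementations work in the log domain for that reason.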
To use our algorithm or reproduce our experiments, install the following dependencies in your Python environment and run the code.
For the first part, our algorithm implementation requires:
Reproducing our experiments requires:
To run the experiments in this repository, please download the datasets from the anonymized link here, and download the GloVe embedding file from here.
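The GloVe embedding file is plain text with one word per line, followed by its vector components separated by spaces. A minimal loader is sketched below; the function name is hypothetical, and we do not know which GloVe file (corpus, dimension) the NLP experiments expect or how the scripts themselves load it.

```python
import numpy as np


def load_glove(path, vocab=None):
    """Read a GloVe text file ("word v1 v2 ... vd" per line) into a dict of float32 vectors.

    If vocab is given, only words in it are kept, which keeps memory small for large files.
    """
    embeddings = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # skip blank lines
            word, *values = line.rstrip().split(" ")
            if vocab is not None and word not in vocab:
                continue
            embeddings[word] = np.array([float(x) for x in values], dtype=np.float32)
    return embeddings
```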
If you find this work helpful, please consider citing our paper:
@inproceedings{lahn2023combinatorial,
title={A Combinatorial Algorithm for Approximating the Optimal Transport in the Parallel and MPC Settings},
author={Lahn, Nathaniel and Raghvendra, Sharath and Zhang, Kaiyi},
booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
year={2023}
}
Comparison with Sinkhorn (Section 5)
Synthetic Data OT
python plgpu_vs_sinkhorn_bench.py --nexp 10 --n 10000 --dataset_name synthetic_OT --is_transport 1 --delta_num 10 --delta_low 0.0007 --delta_high 0.1
Synthetic Data OT (reverse)
python plgpu_vs_sinkhorn_bench_rev.py --nexp 10 --n 10000 --dataset_name synthetic_OT --is_transport 1 --reg_num 10 --reg_low 0.00015 --reg_high 0.01
Synthetic Data Assignment
python plgpu_vs_sinkhorn_bench.py --nexp 10 --n 10000 --dataset_name synthetic_matching --delta_num 10 --delta_low 0.0007 --delta_high 0.01 --is_transport 0
Synthetic Data Assignment (reverse)
python plgpu_vs_sinkhorn_bench_rev.py --nexp 10 --n 10000 --dataset_name synthetic_matching --reg_num 10 --reg_low 0.00045 --reg_high 0.01 --is_transport 0
MNIST Data Assignment
python plgpu_vs_sinkhorn_bench.py --nexp 10 --n 10000 --dataset_name mnist_matching --delta_num 10 --delta_low 0.02 --delta_high 0.2 --is_transport 0
MNIST Data Assignment (reverse)
python plgpu_vs_sinkhorn_bench_rev.py --nexp 10 --n 10000 --dataset_name mnist_matching --reg_num 10 --reg_low 0.002 --reg_high 0.02 --is_transport 0
NLP Data
The Count of Monte Cristo
python plgpu_vs_sinkhorn_bench.py --nexp 5 --dataset_name NLP_OT --is_transport 1 --delta_num 10 --delta_low 0.1 --delta_high 1 --nlp_name the-count-of-monte-cristo --metric euclidean --nlp_portion_size 2000
IMDB
python plgpu_vs_sinkhorn_bench.py --nexp 5 --dataset_name NLP_OT --is_transport 1 --delta_num 10 --delta_low 0.1 --delta_high 1 --nlp_name IMDB --metric euclidean --nlp_portion_size 100
20NEWS
python plgpu_vs_sinkhorn_bench.py --nexp 5 --dataset_name NLP_OT --is_transport 1 --delta_num 10 --delta_low 0.1 --delta_high 1 --nlp_name 20news --metric euclidean --nlp_portion_size 3000
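The NLP runs pass --metric euclidean. Assuming this selects the Euclidean distance between word embeddings as the ground cost, a cost matrix between two sets of embeddings could be built as below. The function name is hypothetical, and rescaling the costs into [0, 1] is our assumption (a common normalization for additive guarantees), not something this README states.

```python
import numpy as np


def euclidean_cost_matrix(X, Y, normalize=True):
    """Pairwise Euclidean distances between rows of X (n x d) and rows of Y (m x d).

    Uses ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y; small negative values caused by
    round-off are clipped to 0 before the square root.
    If normalize is True, divides by the largest entry so costs lie in [0, 1].
    """
    sq = (X ** 2).sum(axis=1)[:, None] + (Y ** 2).sum(axis=1)[None, :] - 2.0 * X @ Y.T
    C = np.sqrt(np.maximum(sq, 0.0))
    if normalize and C.max() > 0:
        C = C / C.max()
    return C
```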
NLP Data (reverse)
The Count of Monte Cristo
python plgpu_vs_sinkhorn_bench_rev.py --nexp 5 --dataset_name NLP_OT --is_transport 1 --reg_num 10 --reg_low 0.001 --reg_high 0.1 --nlp_name the-count-of-monte-cristo --metric euclidean --nlp_portion_size 2000
IMDB
python plgpu_vs_sinkhorn_bench_rev.py --nexp 5 --dataset_name NLP_OT --is_transport 1 --reg_num 10 --reg_low 0.001 --reg_high 0.1 --nlp_name IMDB --metric euclidean --nlp_portion_size 100
20NEWS
python plgpu_vs_sinkhorn_bench_rev.py --nexp 5 --dataset_name NLP_OT --is_transport 1 --reg_num 10 --reg_low 0.001 --reg_high 0.1 --nlp_name 20news --metric euclidean --nlp_portion_size 3000
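The NLP commands differ only in the script, the swept parameter (delta in the forward runs, reg in the reverse runs, presumably the Sinkhorn regularization), --nlp_name, and --nlp_portion_size. The hypothetical driver below rebuilds those commands from the values listed above and, when dry_run=False, runs them one after another. It assumes it is launched from the repository root with `python` on the PATH.

```python
import shlex
import subprocess

# Per-dataset portion sizes, copied from the NLP commands above.
NLP_PORTION_SIZE = {
    "the-count-of-monte-cristo": 2000,
    "IMDB": 100,
    "20news": 3000,
}


def nlp_command(nlp_name: str, reverse: bool = False) -> list[str]:
    """Build the argv for one NLP experiment (forward delta sweep or reverse reg sweep)."""
    if reverse:
        script = "plgpu_vs_sinkhorn_bench_rev.py"
        sweep = ["--reg_num", "10", "--reg_low", "0.001", "--reg_high", "0.1"]
    else:
        script = "plgpu_vs_sinkhorn_bench.py"
        sweep = ["--delta_num", "10", "--delta_low", "0.1", "--delta_high", "1"]
    return (
        ["python", script, "--nexp", "5", "--dataset_name", "NLP_OT", "--is_transport", "1"]
        + sweep
        + ["--nlp_name", nlp_name, "--metric", "euclidean",
           "--nlp_portion_size", str(NLP_PORTION_SIZE[nlp_name])]
    )


def run_all(dry_run: bool = True) -> None:
    """Print every NLP command; execute them only when dry_run is False."""
    for name in NLP_PORTION_SIZE:
        for reverse in (False, True):
            cmd = nlp_command(name, reverse)
            print(shlex.join(cmd))
            if not dry_run:
                subprocess.run(cmd, check=True)  # stop on the first failing run


if __name__ == "__main__":
    run_all(dry_run=True)
```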
Comparison with DROT (Appendix C.2)
Synthetic Data OT
python plgpu_vs_drot_bench_step1.py --nexp 10 --n 10000 --dataset_name synthetic_OT --is_transport 1 --delta_num 10 --delta_low 0.0001 --delta_high 0.01
python plgpu_vs_drot_bench_step2.py --nexp 10 --n 10000 --dataset_name synthetic_OT --is_transport 1 --delta_num 10 --delta_low 0.0001 --delta_high 0.01
Synthetic Data Assignment
python plgpu_vs_drot_bench_step1.py --nexp 10 --n 10000 --dataset_name synthetic_matching --delta_num 10 --delta_low 0.0001 --delta_high 0.01 --is_transport 0
python plgpu_vs_drot_bench_step2.py --nexp 10 --n 10000 --dataset_name synthetic_matching --delta_num 10 --delta_low 0.0001 --delta_high 0.01 --is_transport 0
MNIST Data Assignment
python plgpu_vs_drot_bench_step1.py --nexp 10 --n 10000 --dataset_name mnist_matching --delta_num 10 --delta_low 0.02 --delta_high 0.2 --is_transport 0
python plgpu_vs_drot_bench_step2.py --nexp 10 --n 10000 --dataset_name mnist_matching --delta_num 10 --delta_low 0.02 --delta_high 0.2 --is_transport 0
NLP Data
The Count of Monte Cristo
python plgpu_vs_drot_bench_step1.py --nexp 5 --dataset_name NLP_OT --is_transport 1 --delta_num 10 --delta_low 0.1 --delta_high 1 --nlp_name the-count-of-monte-cristo --metric euclidean --nlp_portion_size 2000
python plgpu_vs_drot_bench_step2.py --nexp 5 --dataset_name NLP_OT --is_transport 1 --delta_num 10 --delta_low 0.1 --delta_high 1 --nlp_name the-count-of-monte-cristo --metric euclidean --nlp_portion_size 2000
IMDB
python plgpu_vs_drot_bench_step1.py --nexp 5 --dataset_name NLP_OT --is_transport 1 --delta_num 10 --delta_low 0.1 --delta_high 1 --nlp_name IMDB --metric euclidean --nlp_portion_size 100
python plgpu_vs_drot_bench_step2.py --nexp 5 --dataset_name NLP_OT --is_transport 1 --delta_num 10 --delta_low 0.1 --delta_high 1 --nlp_name IMDB --metric euclidean --nlp_portion_size 100
20NEWS
python plgpu_vs_drot_bench_step1.py --nexp 5 --dataset_name NLP_OT --is_transport 1 --delta_num 10 --delta_low 0.1 --delta_high 1 --nlp_name 20news --metric euclidean --nlp_portion_size 3000
python plgpu_vs_drot_bench_step2.py --nexp 5 --dataset_name NLP_OT --is_transport 1 --delta_num 10 --delta_low 0.1 --delta_high 1 --nlp_name 20news --metric euclidean --nlp_portion_size 3000
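Each DROT comparison is a pair of scripts invoked with identical arguments. Judging only from the names, step2 presumably depends on step1 having finished, so the hypothetical helper below runs them in order and stops if step1 fails. The function name is ours; the example flag string is copied from the MNIST commands above.

```python
import shlex
import subprocess


def run_drot_pair(flags: str, dry_run: bool = True) -> list[list[str]]:
    """Run plgpu_vs_drot_bench_step1.py, then step2, with the same flags.

    With check=True, subprocess.run raises CalledProcessError on a non-zero
    exit code, so step2 never starts if step1 fails.
    """
    args = shlex.split(flags)
    cmds = [["python", f"plgpu_vs_drot_bench_step{step}.py", *args] for step in (1, 2)]
    for cmd in cmds:
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)
    return cmds


if __name__ == "__main__":
    # Flags copied from the "MNIST Data Assignment" DROT commands above.
    mnist_flags = ("--nexp 10 --n 10000 --dataset_name mnist_matching "
                   "--delta_num 10 --delta_low 0.02 --delta_high 0.2 --is_transport 0")
    run_drot_pair(mnist_flags, dry_run=True)
```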