Skip to content

Commit 1c7e7ce

Browse files
kguerda-idriskguerda-idrisncassereauncassereaurflamary
authored
[MRG] OpenMP support (#260)
* Added : OpenMP support Restored : Epsilon and Debug mode Replaced : parmap => multiprocessing is now replace by multithreading * Commit clean up * Number of CPUs correctly calculated on SLURM clusters * Corrected number of processes for cluster slurm * Mistake corrected * parmap is now deprecated * Now a different solver is used depending on the requested number of threads * Tiny mistake corrected * Folders are now in the ot library instead of at the root * Helpers is now correctly placed * Attempt to make compilation work smoothly * OS compatible path * NumThreads now defaults to 1 * Better flags * Mistake corrected in case of OpenMP unavailability * Revert OpenMP flags modification, which do not compile on Windows * Test helper functions * Helpers comments * Documentation update * File title corrected * Warning no longer using print * Last attempt for macos compilation * pls work * atempt * solving a type error * TypeError OpenMP * Compilation finally working on Windows * Bug solve, number of threads now correctly selected * 64 bits solver to avoid overflows for bigger problems * 64 bits EMD corrected Co-authored-by: kguerda-idris <ssos023@jean-zay3.idris.fr> Co-authored-by: ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> Co-authored-by: ncassereau <nathan.cassereau@idris.fr> Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
1 parent 7dde9e8 commit 1c7e7ce

13 files changed

+2442
-185
lines changed

ot/helpers/openmp_helpers.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
"""Helpers for OpenMP support during the build."""
2+
3+
# This code is adapted for a large part from the astropy openmp helpers, which
4+
# can be found at: https://github.com/astropy/extension-helpers/blob/master/extension_helpers/_openmp_helpers.py # noqa
5+
6+
7+
import os
8+
import sys
9+
import textwrap
10+
import subprocess
11+
12+
from distutils.errors import CompileError, LinkError
13+
14+
from pre_build_helpers import compile_test_program
15+
16+
17+
def get_openmp_flag(compiler):
18+
"""Get openmp flags for a given compiler"""
19+
20+
if hasattr(compiler, 'compiler'):
21+
compiler = compiler.compiler[0]
22+
else:
23+
compiler = compiler.__class__.__name__
24+
25+
if sys.platform == "win32" and ('icc' in compiler or 'icl' in compiler):
26+
omp_flag = ['/Qopenmp']
27+
elif sys.platform == "win32":
28+
omp_flag = ['/openmp']
29+
elif sys.platform in ("darwin", "linux") and "icc" in compiler:
30+
omp_flag = ['-qopenmp']
31+
elif sys.platform == "darwin" and 'openmp' in os.getenv('CPPFLAGS', ''):
32+
omp_flag = []
33+
else:
34+
# Default flag for GCC and clang:
35+
omp_flag = ['-fopenmp']
36+
if sys.platform.startswith("darwin"):
37+
omp_flag += ["-Xpreprocessor", "-lomp"]
38+
return omp_flag
39+
40+
41+
def check_openmp_support():
42+
"""Check whether OpenMP test code can be compiled and run"""
43+
44+
code = textwrap.dedent(
45+
"""\
46+
#include <omp.h>
47+
#include <stdio.h>
48+
int main(void) {
49+
#pragma omp parallel
50+
printf("nthreads=%d\\n", omp_get_num_threads());
51+
return 0;
52+
}
53+
""")
54+
55+
extra_preargs = os.getenv('LDFLAGS', None)
56+
if extra_preargs is not None:
57+
extra_preargs = extra_preargs.strip().split(" ")
58+
extra_preargs = [
59+
flag for flag in extra_preargs
60+
if flag.startswith(('-L', '-Wl,-rpath', '-l'))]
61+
62+
extra_postargs = get_openmp_flag
63+
64+
try:
65+
output, compile_flags = compile_test_program(
66+
code,
67+
extra_preargs=extra_preargs,
68+
extra_postargs=extra_postargs
69+
)
70+
71+
if output and 'nthreads=' in output[0]:
72+
nthreads = int(output[0].strip().split('=')[1])
73+
openmp_supported = len(output) == nthreads
74+
elif "PYTHON_CROSSENV" in os.environ:
75+
# Since we can't run the test program when cross-compiling
76+
# assume that openmp is supported if the program can be
77+
# compiled.
78+
openmp_supported = True
79+
else:
80+
openmp_supported = False
81+
82+
except (CompileError, LinkError, subprocess.CalledProcessError):
83+
openmp_supported = False
84+
compile_flags = []
85+
return openmp_supported, compile_flags

ot/helpers/pre_build_helpers.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
"""Helpers to check build environment before actual build of POT"""
2+
3+
import os
4+
import sys
5+
import glob
6+
import tempfile
7+
import setuptools # noqa
8+
import subprocess
9+
10+
from distutils.dist import Distribution
11+
from distutils.sysconfig import customize_compiler
12+
from numpy.distutils.ccompiler import new_compiler
13+
from numpy.distutils.command.config_compiler import config_cc
14+
15+
16+
def _get_compiler():
17+
"""Get a compiler equivalent to the one that will be used to build POT
18+
Handles compiler specified as follows:
19+
- python setup.py build_ext --compiler=<compiler>
20+
- CC=<compiler> python setup.py build_ext
21+
"""
22+
dist = Distribution({'script_name': os.path.basename(sys.argv[0]),
23+
'script_args': sys.argv[1:],
24+
'cmdclass': {'config_cc': config_cc}})
25+
26+
cmd_opts = dist.command_options.get('build_ext')
27+
if cmd_opts is not None and 'compiler' in cmd_opts:
28+
compiler = cmd_opts['compiler'][1]
29+
else:
30+
compiler = None
31+
32+
ccompiler = new_compiler(compiler=compiler)
33+
customize_compiler(ccompiler)
34+
35+
return ccompiler
36+
37+
38+
def compile_test_program(code, extra_preargs=[], extra_postargs=[]):
39+
"""Check that some C code can be compiled and run"""
40+
ccompiler = _get_compiler()
41+
42+
# extra_(pre/post)args can be a callable to make it possible to get its
43+
# value from the compiler
44+
if callable(extra_preargs):
45+
extra_preargs = extra_preargs(ccompiler)
46+
if callable(extra_postargs):
47+
extra_postargs = extra_postargs(ccompiler)
48+
49+
start_dir = os.path.abspath('.')
50+
51+
with tempfile.TemporaryDirectory() as tmp_dir:
52+
try:
53+
os.chdir(tmp_dir)
54+
55+
# Write test program
56+
with open('test_program.c', 'w') as f:
57+
f.write(code)
58+
59+
os.mkdir('objects')
60+
61+
# Compile, test program
62+
ccompiler.compile(['test_program.c'], output_dir='objects',
63+
extra_postargs=extra_postargs)
64+
65+
# Link test program
66+
objects = glob.glob(
67+
os.path.join('objects', '*' + ccompiler.obj_extension))
68+
ccompiler.link_executable(objects, 'test_program',
69+
extra_preargs=extra_preargs,
70+
extra_postargs=extra_postargs)
71+
72+
if "PYTHON_CROSSENV" not in os.environ:
73+
# Run test program if not cross compiling
74+
# will raise a CalledProcessError if return code was non-zero
75+
output = subprocess.check_output('./test_program')
76+
output = output.decode(
77+
sys.stdout.encoding or 'utf-8').splitlines()
78+
else:
79+
# Return an empty output if we are cross compiling
80+
# as we cannot run the test_program
81+
output = []
82+
except Exception:
83+
raise
84+
finally:
85+
os.chdir(start_dir)
86+
87+
return output, extra_postargs

ot/lp/EMD.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,18 @@
1818

1919
#include <iostream>
2020
#include <vector>
21-
#include "network_simplex_simple.h"
2221

23-
using namespace lemon;
2422
typedef unsigned int node_id_type;
2523

2624
enum ProblemType {
2725
INFEASIBLE,
2826
OPTIMAL,
2927
UNBOUNDED,
30-
MAX_ITER_REACHED
28+
MAX_ITER_REACHED
3129
};
3230

3331
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter);
32+
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads);
3433

3534

3635

ot/lp/EMD_wrapper.cpp

Lines changed: 117 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,22 @@
1212
*
1313
*/
1414

15+
16+
#include "network_simplex_simple.h"
17+
#include "network_simplex_simple_omp.h"
1518
#include "EMD.h"
19+
#include <cstdint>
1620

1721

1822
int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
1923
double* alpha, double* beta, double *cost, int maxIter) {
20-
// beware M and C anre strored in row major C style!!!
21-
int n, m, i, cur;
24+
// beware M and C are stored in row major C style!!!
25+
26+
using namespace lemon;
27+
int n, m, cur;
2228

2329
typedef FullBipartiteDigraph Digraph;
24-
DIGRAPH_TYPEDEFS(FullBipartiteDigraph);
30+
DIGRAPH_TYPEDEFS(Digraph);
2531

2632
// Get the number of non zero coordinates for r and c
2733
n=0;
@@ -48,7 +54,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
4854
std::vector<int> indI(n), indJ(m);
4955
std::vector<double> weights1(n), weights2(m);
5056
Digraph di(n, m);
51-
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter);
57+
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter);
5258

5359
// Set supply and demand, don't account for 0 values (faster)
5460

@@ -76,23 +82,26 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
7682
net.supplyMap(&weights1[0], n, &weights2[0], m);
7783

7884
// Set the cost of each edge
85+
int64_t idarc = 0;
7986
for (int i=0; i<n; i++) {
8087
for (int j=0; j<m; j++) {
8188
double val=*(D+indI[i]*n2+indJ[j]);
82-
net.setCost(di.arcFromId(i*m+j), val);
89+
net.setCost(di.arcFromId(idarc), val);
90+
++idarc;
8391
}
8492
}
8593

8694

8795
// Solve the problem with the network simplex algorithm
8896

8997
int ret=net.run();
98+
int i, j;
9099
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
91100
*cost = 0;
92101
Arc a; di.first(a);
93102
for (; a != INVALID; di.next(a)) {
94-
int i = di.source(a);
95-
int j = di.target(a);
103+
i = di.source(a);
104+
j = di.target(a);
96105
double flow = net.flow(a);
97106
*cost += flow * (*(D+indI[i]*n2+indJ[j-n]));
98107
*(G+indI[i]*n2+indJ[j-n]) = flow;
@@ -106,3 +115,104 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
106115
return ret;
107116
}
108117

118+
119+
120+
121+
122+
123+
124+
int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
125+
double* alpha, double* beta, double *cost, int maxIter, int numThreads) {
126+
// beware M and C are stored in row major C style!!!
127+
128+
using namespace lemon_omp;
129+
int n, m, cur;
130+
131+
typedef FullBipartiteDigraph Digraph;
132+
DIGRAPH_TYPEDEFS(Digraph);
133+
134+
// Get the number of non zero coordinates for r and c
135+
n=0;
136+
for (int i=0; i<n1; i++) {
137+
double val=*(X+i);
138+
if (val>0) {
139+
n++;
140+
}else if(val<0){
141+
return INFEASIBLE;
142+
}
143+
}
144+
m=0;
145+
for (int i=0; i<n2; i++) {
146+
double val=*(Y+i);
147+
if (val>0) {
148+
m++;
149+
}else if(val<0){
150+
return INFEASIBLE;
151+
}
152+
}
153+
154+
// Define the graph
155+
156+
std::vector<int> indI(n), indJ(m);
157+
std::vector<double> weights1(n), weights2(m);
158+
Digraph di(n, m);
159+
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter, numThreads);
160+
161+
// Set supply and demand, don't account for 0 values (faster)
162+
163+
cur=0;
164+
for (int i=0; i<n1; i++) {
165+
double val=*(X+i);
166+
if (val>0) {
167+
weights1[ cur ] = val;
168+
indI[cur++]=i;
169+
}
170+
}
171+
172+
// Demand is actually negative supply...
173+
174+
cur=0;
175+
for (int i=0; i<n2; i++) {
176+
double val=*(Y+i);
177+
if (val>0) {
178+
weights2[ cur ] = -val;
179+
indJ[cur++]=i;
180+
}
181+
}
182+
183+
184+
net.supplyMap(&weights1[0], n, &weights2[0], m);
185+
186+
// Set the cost of each edge
187+
int64_t idarc = 0;
188+
for (int i=0; i<n; i++) {
189+
for (int j=0; j<m; j++) {
190+
double val=*(D+indI[i]*n2+indJ[j]);
191+
net.setCost(di.arcFromId(idarc), val);
192+
++idarc;
193+
}
194+
}
195+
196+
197+
// Solve the problem with the network simplex algorithm
198+
199+
int ret=net.run();
200+
int i, j;
201+
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
202+
*cost = 0;
203+
Arc a; di.first(a);
204+
for (; a != INVALID; di.next(a)) {
205+
i = di.source(a);
206+
j = di.target(a);
207+
double flow = net.flow(a);
208+
*cost += flow * (*(D+indI[i]*n2+indJ[j-n]));
209+
*(G+indI[i]*n2+indJ[j-n]) = flow;
210+
*(alpha + indI[i]) = -net.potential(i);
211+
*(beta + indJ[j-n]) = net.potential(j);
212+
}
213+
214+
}
215+
216+
217+
return ret;
218+
}

0 commit comments

Comments
 (0)