Skip to content

Commit

Permalink
Fixing the openmp interface calls to avoid build failures when openmp…
Browse files Browse the repository at this point in the history
… is not enabled.
  • Loading branch information
shz9 committed Apr 8, 2024
1 parent 5d3fb56 commit 98df100
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions viprs/model/vi/e_step.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from cython.parallel import prange, parallel
cimport openmp
from ...utils.math_utils cimport (
sigmoid,
softmax,
Expand All @@ -16,6 +15,24 @@ cimport numpy as np
from cython cimport floating, integral


# A safe way to get the number of the thread currently executing the code:
# This is used to avoid compile-time errors when compiling the code with OpenMP support disabled.
# In earlier iterations, we used:
# cimport openmp
# openmp.omp_get_thread_num()
# But this tends to fail when OpenMP is not enabled.
# The code below is a safer way to get the thread number.
cdef extern from *:
"""
#ifdef _OPENMP
#include <omp.h>
#else
int omp_get_thread_num() { return 0; }
#endif
"""
int omp_get_thread_num() noexcept nogil


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
Expand Down Expand Up @@ -243,7 +260,7 @@ cpdef void e_step_mixture(int[::1] ld_left_bound,
for j in prange(c_size, nogil=True, schedule='static', num_threads=threads):

# Set the thread offset for the u_j array:
thread_offset = openmp.omp_get_thread_num() * (K + 1)
thread_offset = omp_get_thread_num() * (K + 1)

# The start and end coordinates for the flattened LD matrix:
ld_start = ld_indptr[j]
Expand Down

0 comments on commit 98df100

Please sign in to comment.