Skip to content

Commit

Permalink
Adds an environment variable check that allows only turning off the p…
Browse files Browse the repository at this point in the history
…arallel target for NUMBA
  • Loading branch information
hugohadfield committed Feb 6, 2019
1 parent fd05c00 commit 3859289
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 8 deletions.
13 changes: 10 additions & 3 deletions clifford/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,20 @@
_print_precision = 5 # pretty printing precision on floats
TEST_NUMBA = True

import os
try:
NUMBA_DISABLE_PARALLEL = os.environ['NUMBA_DISABLE_PARALLEL']
NUMBA_PARALLEL = not bool(NUMBA_DISABLE_PARALLEL)
except:
NUMBA_PARALLEL = True


def test_numba():
"""
This tests numba to see if it can successfully compile a specific program
https://github.com/numba/numba/issues/3671
"""
@numba.njit(parallel=True)
@numba.njit(parallel=NUMBA_PARALLEL)
def play_games():
monte_carlo_cell_visit_frequency = np.zeros(100, dtype=np.int_)
monte_carlo_cell_visit_frequency != 0
Expand Down Expand Up @@ -110,7 +117,7 @@ def adjoint_func(value):
return adjoint_func


@numba.njit(parallel=True, nogil=True)
@numba.njit(parallel=NUMBA_PARALLEL, nogil=True)
def construct_tables(gradeList, linear_map_to_bitmap,
bitmap_to_linear_map, signature):

Expand Down Expand Up @@ -716,7 +723,7 @@ def _genTables(self):
k_list, l_list, m_list, mult_table_vals, imt_prod_mask, omt_prod_mask, lcmt_prod_mask = construct_tables(np.array(self.gradeList),
self.linear_map_to_bitmap,
self.bitmap_to_linear_map,
np.array(self.sig))
self.sig)

# This generates the functions that will perform the various products
self.gmt_func = get_mult_function(k_list,l_list,m_list,mult_table_vals,self.gaDims,self.gradeList)
Expand Down
6 changes: 3 additions & 3 deletions clifford/tools/g3c/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@
random_rotation_rotor, generate_rotation_rotor, val_random_euc_mv
from clifford.g3c import *
import clifford as cf
from clifford import val_get_left_gmt_matrix, grades_present
from clifford import val_get_left_gmt_matrix, grades_present, NUMBA_PARALLEL
import warnings

# Allow syntactic alternatives to the standard included in the clifford package
Expand Down Expand Up @@ -420,7 +420,7 @@ def midpoint_of_line_cluster(line_cluster):
return layout.MultiVector(value=center_point)


@numba.njit(parallel=True)
@numba.njit(parallel=NUMBA_PARALLEL)
def val_midpoint_of_line_cluster(array_line_cluster):
"""
Gets an approximate center point of a line cluster
Expand All @@ -436,7 +436,7 @@ def val_midpoint_of_line_cluster(array_line_cluster):
return center_point


@numba.njit(parallel=True)
@numba.njit(parallel=NUMBA_PARALLEL)
def val_midpoint_of_line_cluster_grad(array_line_cluster):
"""
Gets an approximate center point of a line cluster
Expand Down
5 changes: 3 additions & 2 deletions clifford/tools/g3c/cost_functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

from . import *
from clifford import NUMBA_PARALLEL
import itertools

imt_func = layout.imt_func
Expand Down Expand Up @@ -193,7 +194,7 @@ def object_set_cost_sum(object_set_a, object_set_b, object_type='generic'):
return sum_val


@numba.njit(parallel=True, nogil=True)
@numba.njit(parallel=NUMBA_PARALLEL, nogil=True)
def val_object_set_cost_matrix(object_array_a, object_array_b):
"""
Evaluates the rotor cost matrix between two sets of objects
Expand All @@ -207,7 +208,7 @@ def val_object_set_cost_matrix(object_array_a, object_array_b):
return matrix


@numba.njit(parallel=True, nogil=True)
@numba.njit(parallel=NUMBA_PARALLEL, nogil=True)
def val_line_set_cost_matrix(object_array_a, object_array_b):
"""
Evaluates the rotor cost matrix between two sets of objects
Expand Down

0 comments on commit 3859289

Please sign in to comment.