Skip to content

Commit

Permalink
Merge pull request #81 from pygae/fix_jit_windows
Browse files Browse the repository at this point in the history
Adds an environment variable check that allows only turning off the p…
  • Loading branch information
hugohadfield committed Feb 6, 2019
2 parents fd05c00 + 3859289 commit b99a46e
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 8 deletions.
13 changes: 10 additions & 3 deletions clifford/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,20 @@
_print_precision = 5 # pretty printing precision on floats
TEST_NUMBA = True

import os
try:
NUMBA_DISABLE_PARALLEL = os.environ['NUMBA_DISABLE_PARALLEL']
NUMBA_PARALLEL = not bool(NUMBA_DISABLE_PARALLEL)
except:
NUMBA_PARALLEL = True


def test_numba():
"""
This tests numba to see if it can successfully compile a specific program
https://github.com/numba/numba/issues/3671
"""
@numba.njit(parallel=True)
@numba.njit(parallel=NUMBA_PARALLEL)
def play_games():
monte_carlo_cell_visit_frequency = np.zeros(100, dtype=np.int_)
monte_carlo_cell_visit_frequency != 0
Expand Down Expand Up @@ -110,7 +117,7 @@ def adjoint_func(value):
return adjoint_func


@numba.njit(parallel=True, nogil=True)
@numba.njit(parallel=NUMBA_PARALLEL, nogil=True)
def construct_tables(gradeList, linear_map_to_bitmap,
bitmap_to_linear_map, signature):

Expand Down Expand Up @@ -716,7 +723,7 @@ def _genTables(self):
k_list, l_list, m_list, mult_table_vals, imt_prod_mask, omt_prod_mask, lcmt_prod_mask = construct_tables(np.array(self.gradeList),
self.linear_map_to_bitmap,
self.bitmap_to_linear_map,
np.array(self.sig))
self.sig)

# This generates the functions that will perform the various products
self.gmt_func = get_mult_function(k_list,l_list,m_list,mult_table_vals,self.gaDims,self.gradeList)
Expand Down
6 changes: 3 additions & 3 deletions clifford/tools/g3c/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@
random_rotation_rotor, generate_rotation_rotor, val_random_euc_mv
from clifford.g3c import *
import clifford as cf
from clifford import val_get_left_gmt_matrix, grades_present
from clifford import val_get_left_gmt_matrix, grades_present, NUMBA_PARALLEL
import warnings

# Allow syntactic alternatives to the standard included in the clifford package
Expand Down Expand Up @@ -420,7 +420,7 @@ def midpoint_of_line_cluster(line_cluster):
return layout.MultiVector(value=center_point)


@numba.njit(parallel=True)
@numba.njit(parallel=NUMBA_PARALLEL)
def val_midpoint_of_line_cluster(array_line_cluster):
"""
Gets an approximate center point of a line cluster
Expand All @@ -436,7 +436,7 @@ def val_midpoint_of_line_cluster(array_line_cluster):
return center_point


@numba.njit(parallel=True)
@numba.njit(parallel=NUMBA_PARALLEL)
def val_midpoint_of_line_cluster_grad(array_line_cluster):
"""
Gets an approximate center point of a line cluster
Expand Down
5 changes: 3 additions & 2 deletions clifford/tools/g3c/cost_functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

from . import *
from clifford import NUMBA_PARALLEL
import itertools

imt_func = layout.imt_func
Expand Down Expand Up @@ -193,7 +194,7 @@ def object_set_cost_sum(object_set_a, object_set_b, object_type='generic'):
return sum_val


@numba.njit(parallel=True, nogil=True)
@numba.njit(parallel=NUMBA_PARALLEL, nogil=True)
def val_object_set_cost_matrix(object_array_a, object_array_b):
"""
Evaluates the rotor cost matrix between two sets of objects
Expand All @@ -207,7 +208,7 @@ def val_object_set_cost_matrix(object_array_a, object_array_b):
return matrix


@numba.njit(parallel=True, nogil=True)
@numba.njit(parallel=NUMBA_PARALLEL, nogil=True)
def val_line_set_cost_matrix(object_array_a, object_array_b):
"""
Evaluates the rotor cost matrix between two sets of objects
Expand Down

0 comments on commit b99a46e

Please sign in to comment.