Skip to content

Commit

Permalink
upcast array to solve njit failure (#3652)
Browse files Browse the repository at this point in the history
* upcast array to solve njit failure

* add tests and remove debug statements
  • Loading branch information
CloseChoice committed May 14, 2024
1 parent 86d8bc5 commit e21d70f
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 2 deletions.
33 changes: 31 additions & 2 deletions shap/utils/_masked_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,12 +356,40 @@ def _build_delta_masked_inputs(masks, batch_positions, num_mask_samples, num_var

return all_masked_inputs, i + 1 # i + 1 is the number of output rows after averaging

def _upcast_array(arr: np.ndarray) -> np.ndarray:
"""Since njit doesn't support float16, we need to upcast it to float32.
Args:
arr (np.ndarray): array to upcast
Returns
-------
np.ndarray: upcasted array
"""
if arr.dtype == np.float16:
return arr.astype(np.float32)
else:
return arr

def _build_fixed_output(averaged_outs, last_outs, outputs, batch_positions, varying_rows, num_varying_rows, link, linearizing_weights):
    """Dispatch to the single- or multi-output fixed-output kernel.

    The njit-compiled kernels cannot handle float16 (GH #3651), so any
    float16 input is upcast to float32 first. Because the kernels write
    their results into ``averaged_outs`` (and appear to update
    ``last_outs``) in place, an upcast *copy* of those buffers would
    silently swallow the results — so when a copy was made, the results
    are written back into the caller's original arrays afterwards.
    """
    # Upcast once, instead of repeating the calls in each branch.
    avg = _upcast_array(averaged_outs)
    last = _upcast_array(last_outs)
    outs = _upcast_array(outputs)
    if len(last_outs.shape) == 1:
        _build_fixed_single_output(avg, last, outs, batch_positions,
                                   varying_rows, num_varying_rows, link,
                                   linearizing_weights)
    else:
        _build_fixed_multi_output(avg, last, outs, batch_positions,
                                  varying_rows, num_varying_rows, link,
                                  linearizing_weights)
    # _upcast_array returns a new array only for float16 inputs; copy the
    # kernel's in-place writes back so callers still observe the results.
    if avg is not averaged_outs:
        averaged_outs[:] = avg
    if last is not last_outs:
        last_outs[:] = last

@njit # we can't use this when using a custom link function...
def _build_fixed_single_output(averaged_outs, last_outs, outputs, batch_positions, varying_rows, num_varying_rows, link, linearizing_weights):
Expand Down Expand Up @@ -389,6 +417,7 @@ def _build_fixed_single_output(averaged_outs, last_outs, outputs, batch_position
def _build_fixed_multi_output(averaged_outs, last_outs, outputs, batch_positions, varying_rows, num_varying_rows, link, linearizing_weights):
# here we can assume that the outputs will always be the same size, and we need
# to carry over evaluation outputs

sample_count = last_outs.shape[0]
for i in range(len(averaged_outs)):
if batch_positions[i] < batch_positions[i+1]:
Expand Down
23 changes: 23 additions & 0 deletions tests/utils/test_masked_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import numpy as np

from shap.links import identity
from shap.utils._masked_model import _build_fixed_output


def test__build_fixed_output():
    """Regression test for GH3651: float16 ``outputs`` must not break the
    njit-compiled fixed-output kernels."""
    outputs = np.random.rand(1, 10).astype(np.float16)
    averaged_outs = np.zeros((1, 10), dtype=np.float32)
    last_outs = np.zeros((1, 10), dtype=np.float32)
    batch_positions = np.array([0, 1])
    varying_rows = np.array([[True]])
    num_varying_rows = np.array([1])

    _build_fixed_output(
        averaged_outs,
        last_outs,
        outputs,
        batch_positions,
        varying_rows,
        num_varying_rows,
        identity,
        None,
    )

    assert np.allclose(averaged_outs, outputs, 1e-2)

0 comments on commit e21d70f

Please sign in to comment.