diff --git a/onnx_diagnostic/helpers/rt_helper.py b/onnx_diagnostic/helpers/rt_helper.py index dab05142..ebd6e157 100644 --- a/onnx_diagnostic/helpers/rt_helper.py +++ b/onnx_diagnostic/helpers/rt_helper.py @@ -112,4 +112,14 @@ def make_feeds( if copy: flat = [t.copy() if hasattr(t, "copy") else t.clone() for t in flat] - return dict(zip(names, flat)) + # bool, int, float, onnxruntime does not support float, bool, int + new_flat = [] + for i in flat: + if isinstance(i, bool): + i = np.array(i, dtype=np.bool_) + elif isinstance(i, int): + i = np.array(i, dtype=np.int64) + elif isinstance(i, float): + i = np.array(i, dtype=np.float32) + new_flat.append(i) + return dict(zip(names, new_flat))