Skip to content

Commit

Permalink
tensor_wrap(): minor improvements
Browse files Browse the repository at this point in the history
- don't crash when given an invalid tensor descriptor wrapping a NULL pointer
- don't crash when a tensor framework cannot be imported
  • Loading branch information
wjakob committed Dec 5, 2022
1 parent adfa9e5 commit 633672c
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 22 deletions.
52 changes: 31 additions & 21 deletions src/tensor.cpp
Expand Up @@ -544,46 +544,56 @@ static void tensor_capsule_destructor(PyObject *o) {
error_scope scope; // temporarily save any existing errors
managed_tensor *mt =
(managed_tensor *) PyCapsule_GetPointer(o, "dltensor");

if (mt)
tensor_dec_ref((tensor_handle *) mt->manager_ctx);
else
PyErr_Clear();
}

PyObject *tensor_wrap(tensor_handle *th, int framework) noexcept {
if (!th)
return none().release().ptr();

tensor_inc_ref(th);
object o = steal(PyCapsule_New(th->tensor, "dltensor", tensor_capsule_destructor)),
object o = steal(PyCapsule_New(th->tensor, "dltensor",
tensor_capsule_destructor)),
package;

switch ((tensor_framework) framework) {
case tensor_framework::none:
break;
try {
switch ((tensor_framework) framework) {
case tensor_framework::none:
break;

case tensor_framework::numpy:
package = module_::import_("numpy");
o = handle(internals_get().nb_tensor)(o);
break;
case tensor_framework::numpy:
package = module_::import_("numpy");
o = handle(internals_get().nb_tensor)(o);
break;

case tensor_framework::pytorch:
package = module_::import_("torch.utils.dlpack");
break;
case tensor_framework::pytorch:
package = module_::import_("torch.utils.dlpack");
break;


case tensor_framework::tensorflow:
package = module_::import_("tensorflow.experimental.dlpack");
break;
case tensor_framework::tensorflow:
package = module_::import_("tensorflow.experimental.dlpack");
break;

case tensor_framework::jax:
package = module_::import_("jax.dlpack");
break;
case tensor_framework::jax:
package = module_::import_("jax.dlpack");
break;


default:
fail("nanobind::detail::tensor_wrap(): unknown framework "
"specified!");
default:
fail("nanobind::detail::tensor_wrap(): unknown framework "
"specified!");
}
} catch (const std::exception &e) {
PyErr_Format(PyExc_RuntimeError,
"Could not import tensor framework: %s", e.what());
return nullptr;
}


if (package.is_valid()) {
try {
o = package.attr("from_dlpack")(o);
Expand Down
2 changes: 1 addition & 1 deletion tests/test_tensor.cpp
Expand Up @@ -136,5 +136,5 @@ NB_MODULE(test_tensor_ext, m) {
});

return nb::tensor<nb::numpy, float>(f, 0, shape, deleter);
});
});
}

0 comments on commit 633672c

Please sign in to comment.