From 051ad179ce092eef6de8917a796a3ea74cf28467 Mon Sep 17 00:00:00 2001 From: Yinghai Lu Date: Wed, 17 Feb 2021 21:21:35 -0800 Subject: [PATCH] Add onnxifi interface for set/get options (#52388) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/52388 Pull Request resolved: https://github.com/pytorch/glow/pull/5364 This allows us to change global variables through onnxifi calls. And add python bindings along with it. Note that we supply a dummy backend_id as it's not needed by glow due to setting being global. #codemod Reviewed By: jfix71 Differential Revision: D26481652 fbshipit-source-id: 3d6368c200ee1f9da369627e6043bb058871a55b --- lib/Onnxifi/onnxifiGlow.cpp | 65 +++++++++++++++++++++++++++++++++---- thirdparty/foxi | 2 +- 2 files changed, 60 insertions(+), 7 deletions(-) diff --git a/lib/Onnxifi/onnxifiGlow.cpp b/lib/Onnxifi/onnxifiGlow.cpp index e81c8e0f76..2543f755ce 100644 --- a/lib/Onnxifi/onnxifiGlow.cpp +++ b/lib/Onnxifi/onnxifiGlow.cpp @@ -608,6 +608,48 @@ GLOW_ONNXIFI_LIBRARY_FUNCTION_WRAPPER(onnxReleaseTraceEvents)( return ONNXIFI_STATUS_SUCCESS; } +/// Set Onnxifi option +EXTERNC ONNXIFI_PUBLIC ONNXIFI_CHECK_RESULT onnxStatus ONNXIFI_ABI +GLOW_ONNXIFI_LIBRARY_FUNCTION_WRAPPER(onnxSetOption)(const char *optionName, + const char *optionValue) { + if (!optionName || !optionValue) { + return ONNXIFI_STATUS_INVALID_POINTER; + } + onnxStatus ret = ONNXIFI_STATUS_SUCCESS; + int d = 0; + if (!strcmp(optionName, "glow_num_devices")) { + if (sscanf(optionValue, "%d", &d) == 1) { + glow::flags::NumDevices = d; + } else { + ret = ONNXIFI_STATUS_UNSUPPORTED_ATTRIBUTE; + } + } else { + ret = ONNXIFI_STATUS_INVALID_NAME; + } + return ret; +} + +/// Get Onnxifi option +EXTERNC ONNXIFI_PUBLIC ONNXIFI_CHECK_RESULT onnxStatus ONNXIFI_ABI +GLOW_ONNXIFI_LIBRARY_FUNCTION_WRAPPER(onnxGetOption)( + const char *optionName, char *optionValue, size_t *optionValueLength) { + if (!optionName || !optionValue || !optionValueLength) { + return ONNXIFI_STATUS_INVALID_POINTER; + } + onnxStatus ret = ONNXIFI_STATUS_SUCCESS; + if (!strcmp(optionName, "glow_num_devices")) { + int n = snprintf(optionValue, *optionValueLength, "%d", + glow::flags::NumDevices); + if (n < 0) { + ret = ONNXIFI_STATUS_UNSUPPORTED_ATTRIBUTE; + } else if (n < *optionValueLength) { + *optionValueLength = n; + } + } else { + ret = ONNXIFI_STATUS_INVALID_NAME; + } + return ret; +} /// Get pointer to onnxifi extension function with \p name. EXTERNC ONNXIFI_PUBLIC ONNXIFI_CHECK_RESULT onnxStatus ONNXIFI_ABI GLOW_ONNXIFI_LIBRARY_FUNCTION_WRAPPER(onnxGetExtensionFunctionAddress)( @@ -617,11 +659,16 @@ GLOW_ONNXIFI_LIBRARY_FUNCTION_WRAPPER(onnxGetExtensionFunctionAddress)( return ONNXIFI_STATUS_INVALID_POINTER; } - auto &manager = glow::onnxifi::GlowOnnxifiManager::get(); - - auto *glowBackend = static_cast(backendID); - if (!manager.isValid(glowBackend)) { - return ONNXIFI_STATUS_INVALID_ID; + // We don't check backend id for set/get option functions as the options + // global to Glow. + static const std::unordered_set bypass{"onnxSetOptionFunction", + "onnxGetOptionFunction"}; + if (bypass.find(name) == bypass.end()) { + auto &manager = glow::onnxifi::GlowOnnxifiManager::get(); + auto *glowBackend = static_cast(backendID); + if (!manager.isValid(glowBackend)) { + return ONNXIFI_STATUS_INVALID_ID; + } } // Map of name to onnxExtensionFunctionPointer, one entry for each implemented @@ -638,7 +685,13 @@ GLOW_ONNXIFI_LIBRARY_FUNCTION_WRAPPER(onnxGetExtensionFunctionAddress)( GLOW_ONNXIFI_LIBRARY_FUNCTION_WRAPPER(onnxWaitEventFor))}, {"onnxReleaseTraceEventsFunction", reinterpret_cast( - GLOW_ONNXIFI_LIBRARY_FUNCTION_WRAPPER(onnxReleaseTraceEvents))}}; + GLOW_ONNXIFI_LIBRARY_FUNCTION_WRAPPER(onnxReleaseTraceEvents))}, + {"onnxSetOptionFunction", + reinterpret_cast( + GLOW_ONNXIFI_LIBRARY_FUNCTION_WRAPPER(onnxSetOption))}, + {"onnxGetOptionFunction", + reinterpret_cast( + GLOW_ONNXIFI_LIBRARY_FUNCTION_WRAPPER(onnxGetOption))}}; auto extensionIt = extensionMap.find(name); diff --git a/thirdparty/foxi b/thirdparty/foxi index 6a4e19a2aa..bd6feb6d0d 160000 --- a/thirdparty/foxi +++ b/thirdparty/foxi @@ -1 +1 @@ -Subproject commit 6a4e19a2aaf7ae4b9fa9597526e65b395d5e79ad +Subproject commit bd6feb6d0d3fc903df42b4feb82a602a5fcb1fd5