From 667a574c8fca0393d351f558680805cad9b71ec5 Mon Sep 17 00:00:00 2001 From: Steve Dower Date: Mon, 15 Apr 2024 16:43:28 +0100 Subject: [PATCH] gh-112278: Improve error handling in wmi module and tests (GH-117818) --- Lib/test/test_wmi.py | 28 ++++++++++---- PC/_wmimodule.cpp | 91 +++++++++++++++++++++++++++++++++++--------- 2 files changed, 94 insertions(+), 25 deletions(-) diff --git a/Lib/test/test_wmi.py b/Lib/test/test_wmi.py index 3445702846d8a0..f667926d1f8ddf 100644 --- a/Lib/test/test_wmi.py +++ b/Lib/test/test_wmi.py @@ -1,17 +1,31 @@ # Test the internal _wmi module on Windows # This is used by the platform module, and potentially others +import time import unittest -from test.support import import_helper, requires_resource +from test.support import import_helper, requires_resource, LOOPBACK_TIMEOUT # Do this first so test will be skipped if module doesn't exist _wmi = import_helper.import_module('_wmi', required_on=['win']) +def wmi_exec_query(query): + # gh-112278: WMI maybe slow response when first call. + try: + return _wmi.exec_query(query) + except BrokenPipeError: + pass + except WindowsError as e: + if e.winerror != 258: + raise + time.sleep(LOOPBACK_TIMEOUT) + return _wmi.exec_query(query) + + class WmiTests(unittest.TestCase): def test_wmi_query_os_version(self): - r = _wmi.exec_query("SELECT Version FROM Win32_OperatingSystem").split("\0") + r = wmi_exec_query("SELECT Version FROM Win32_OperatingSystem").split("\0") self.assertEqual(1, len(r)) k, eq, v = r[0].partition("=") self.assertEqual("=", eq, r[0]) @@ -28,7 +42,7 @@ def test_wmi_query_repeated(self): def test_wmi_query_error(self): # Invalid queries fail with OSError try: - _wmi.exec_query("SELECT InvalidColumnName FROM InvalidTableName") + wmi_exec_query("SELECT InvalidColumnName FROM InvalidTableName") except OSError as ex: if ex.winerror & 0xFFFFFFFF == 0x80041010: # This is the expected error code. All others should fail the test @@ -42,7 +56,7 @@ def test_wmi_query_repeated_error(self): def test_wmi_query_not_select(self): # Queries other than SELECT are blocked to avoid potential exploits with self.assertRaises(ValueError): - _wmi.exec_query("not select, just in case someone tries something") + wmi_exec_query("not select, just in case someone tries something") @requires_resource('cpu') def test_wmi_query_overflow(self): @@ -50,11 +64,11 @@ def test_wmi_query_overflow(self): # Test multiple times to ensure consistency for _ in range(2): with self.assertRaises(OSError): - _wmi.exec_query("SELECT * FROM CIM_DataFile") + wmi_exec_query("SELECT * FROM CIM_DataFile") def test_wmi_query_multiple_rows(self): # Multiple instances should have an extra null separator - r = _wmi.exec_query("SELECT ProcessId FROM Win32_Process WHERE ProcessId < 1000") + r = wmi_exec_query("SELECT ProcessId FROM Win32_Process WHERE ProcessId < 1000") self.assertFalse(r.startswith("\0"), r) self.assertFalse(r.endswith("\0"), r) it = iter(r.split("\0")) @@ -69,6 +83,6 @@ def test_wmi_query_threads(self): from concurrent.futures import ThreadPoolExecutor query = "SELECT ProcessId FROM Win32_Process WHERE ProcessId < 1000" with ThreadPoolExecutor(4) as pool: - task = [pool.submit(_wmi.exec_query, query) for _ in range(32)] + task = [pool.submit(wmi_exec_query, query) for _ in range(32)] for t in task: self.assertRegex(t.result(), "ProcessId=") diff --git a/PC/_wmimodule.cpp b/PC/_wmimodule.cpp index 310aa86d94d9b6..22ed05276e6f07 100644 --- a/PC/_wmimodule.cpp +++ b/PC/_wmimodule.cpp @@ -8,6 +8,11 @@ // Version history // 2022-08: Initial contribution (Steve Dower) +// clinic/_wmimodule.cpp.h uses internal pycore_modsupport.h API +#ifndef Py_BUILD_CORE_BUILTIN +# define Py_BUILD_CORE_MODULE 1 +#endif + #define _WIN32_DCOM #include #include @@ -39,6 +44,8 @@ struct _query_data { LPCWSTR query; HANDLE writePipe; HANDLE readPipe; + HANDLE initEvent; + HANDLE connectEvent; }; @@ -75,12 +82,18 @@ _query_thread(LPVOID param) IID_IWbemLocator, (LPVOID *)&locator ); } + if (SUCCEEDED(hr) && !SetEvent(data->initEvent)) { + hr = HRESULT_FROM_WIN32(GetLastError()); + } if (SUCCEEDED(hr)) { hr = locator->ConnectServer( bstr_t(L"ROOT\\CIMV2"), NULL, NULL, 0, NULL, 0, 0, &services ); } + if (SUCCEEDED(hr) && !SetEvent(data->connectEvent)) { + hr = HRESULT_FROM_WIN32(GetLastError()); + } if (SUCCEEDED(hr)) { hr = CoSetProxyBlanket( services, RPC_C_AUTHN_WINNT, RPC_C_AUTHZ_NONE, NULL, @@ -184,6 +197,24 @@ _query_thread(LPVOID param) } +static DWORD +wait_event(HANDLE event, DWORD timeout) +{ + DWORD err = 0; + switch (WaitForSingleObject(event, timeout)) { + case WAIT_OBJECT_0: + break; + case WAIT_TIMEOUT: + err = WAIT_TIMEOUT; + break; + default: + err = GetLastError(); + break; + } + return err; +} + + /*[clinic input] _wmi.exec_query @@ -226,7 +257,11 @@ _wmi_exec_query_impl(PyObject *module, PyObject *query) Py_BEGIN_ALLOW_THREADS - if (!CreatePipe(&data.readPipe, &data.writePipe, NULL, 0)) { + data.initEvent = CreateEvent(NULL, TRUE, FALSE, NULL); + data.connectEvent = CreateEvent(NULL, TRUE, FALSE, NULL); + if (!data.initEvent || !data.connectEvent || + !CreatePipe(&data.readPipe, &data.writePipe, NULL, 0)) + { err = GetLastError(); } else { hThread = CreateThread(NULL, 0, _query_thread, (LPVOID*)&data, 0, NULL); @@ -238,6 +273,19 @@ _wmi_exec_query_impl(PyObject *module, PyObject *query) } } + // gh-112278: If current user doesn't have permission to query the WMI, the + // function IWbemLocator::ConnectServer will hang for 5 seconds, and there + // is no way to specify the timeout. So we use an Event object to simulate + // a timeout. The initEvent will be set after COM initialization, it will + // take a longer time when first initialized. The connectEvent will be set + // after connected to WMI. + if (!err) { + err = wait_event(data.initEvent, 1000); + if (!err) { + err = wait_event(data.connectEvent, 100); + } + } + while (!err) { if (ReadFile( data.readPipe, @@ -259,28 +307,35 @@ _wmi_exec_query_impl(PyObject *module, PyObject *query) CloseHandle(data.readPipe); } - // Allow the thread some time to clean up - switch (WaitForSingleObject(hThread, 1000)) { - case WAIT_OBJECT_0: - // Thread ended cleanly - if (!GetExitCodeThread(hThread, (LPDWORD)&err)) { - err = GetLastError(); - } - break; - case WAIT_TIMEOUT: - // Probably stuck - there's not much we can do, unfortunately - if (err == 0 || err == ERROR_BROKEN_PIPE) { - err = WAIT_TIMEOUT; + if (hThread) { + // Allow the thread some time to clean up + int thread_err; + switch (WaitForSingleObject(hThread, 100)) { + case WAIT_OBJECT_0: + // Thread ended cleanly + if (!GetExitCodeThread(hThread, (LPDWORD)&thread_err)) { + thread_err = GetLastError(); + } + break; + case WAIT_TIMEOUT: + // Probably stuck - there's not much we can do, unfortunately + thread_err = WAIT_TIMEOUT; + break; + default: + thread_err = GetLastError(); + break; } - break; - default: + // An error on our side is more likely to be relevant than one from + // the thread, but if we don't have one on our side we'll take theirs. if (err == 0 || err == ERROR_BROKEN_PIPE) { - err = GetLastError(); + err = thread_err; } - break; + + CloseHandle(hThread); } - CloseHandle(hThread); + CloseHandle(data.initEvent); + CloseHandle(data.connectEvent); hThread = NULL; Py_END_ALLOW_THREADS