From d2fb7d4c4c3e276bd2f75cb8be35af155454720d Mon Sep 17 00:00:00 2001 From: Clemens Wasser Date: Mon, 17 Apr 2023 22:12:48 +0200 Subject: [PATCH] Use windows-sys and fix feature flag --- Cargo.lock | 2 +- Cargo.toml | 8 ++++- src/timer/mod.rs | 7 ++-- src/timer/windows_timer.rs | 70 ++++++++++++++++++++------------------ 4 files changed, 49 insertions(+), 38 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 70748b3c9..8403c34df 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -434,7 +434,7 @@ dependencies = [ "statistical", "tempfile", "thiserror", - "winapi", + "windows-sys 0.45.0", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 49fb72d54..ee5a7f96b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,7 +34,13 @@ anyhow = "1.0" libc = "0.2" [target.'cfg(windows)'.dependencies] -winapi = { version = "0.3", features = ["processthreadsapi", "minwindef", "winnt", "jobapi2", "tlhelp32", "handleapi"] } +windows-sys = { version = "0.45", features = [ + "Win32_Foundation", + "Win32_Security", + "Win32_System_JobObjects", + "Win32_System_LibraryLoader", + "Win32_System_Threading", +] } [target.'cfg(all(windows, not(windows_process_extensions_main_thread_handle)))'.dependencies] once_cell = "1.17" diff --git a/src/timer/mod.rs b/src/timer/mod.rs index d91bef2f2..2c57c2b5c 100644 --- a/src/timer/mod.rs +++ b/src/timer/mod.rs @@ -13,6 +13,9 @@ use std::fs::File; #[cfg(target_os = "linux")] use std::os::unix::io::AsRawFd; +#[cfg(target_os = "windows")] +use windows_sys::Win32::System::Threading::CREATE_SUSPENDED; + use crate::util::units::Second; use wall_clock_timer::WallClockTimer; @@ -82,8 +85,8 @@ pub fn execute_and_measure(mut command: Command) -> Result { { use std::os::windows::process::CommandExt; - // Create a suspended process - command.creation_flags(4); + // Create the process in a suspended state so that we don't miss any cpu time between process creation and `CPUTimer` start. + command.creation_flags(CREATE_SUSPENDED); } let wallclock_timer = WallClockTimer::start(); diff --git a/src/timer/windows_timer.rs b/src/timer/windows_timer.rs index 84ffac58e..c6c5f25b8 100644 --- a/src/timer/windows_timer.rs +++ b/src/timer/windows_timer.rs @@ -3,43 +3,47 @@ use std::{mem, os::windows::io::AsRawHandle, process, ptr}; -use winapi::{ - shared::{ntdef::NTSTATUS, ntstatus::STATUS_SUCCESS}, - um::{ - handleapi::CloseHandle, - jobapi2::{AssignProcessToJobObject, CreateJobObjectW, QueryInformationJobObject}, - libloaderapi::{GetModuleHandleA, GetProcAddress}, - winnt::{ - JobObjectBasicAccountingInformation, HANDLE, JOBOBJECT_BASIC_ACCOUNTING_INFORMATION, - }, +use windows_sys::Win32::{ + Foundation::{CloseHandle, HANDLE}, + System::JobObjects::{ + AssignProcessToJobObject, CreateJobObjectW, JobObjectBasicAccountingInformation, + QueryInformationJobObject, JOBOBJECT_BASIC_ACCOUNTING_INFORMATION, }, }; -#[cfg(windows_process_extensions_main_thread_handle)] -use winapi::shared::minwindef::DWORD; +#[cfg(feature = "windows_process_extensions_main_thread_handle")] +use std::os::windows::process::ChildExt; +#[cfg(feature = "windows_process_extensions_main_thread_handle")] +use windows_sys::Win32::System::Threading::ResumeThread; -#[cfg(not(windows_process_extensions_main_thread_handle))] +#[cfg(not(feature = "windows_process_extensions_main_thread_handle"))] use once_cell::sync::Lazy; +#[cfg(not(feature = "windows_process_extensions_main_thread_handle"))] +use windows_sys::{ + s, w, + Win32::{ + Foundation::{NTSTATUS, STATUS_SUCCESS}, + System::LibraryLoader::{GetModuleHandleW, GetProcAddress}, + }, +}; use crate::util::units::Second; const HUNDRED_NS_PER_MS: i64 = 10; -#[cfg(not(windows_process_extensions_main_thread_handle))] +#[cfg(not(feature = "windows_process_extensions_main_thread_handle"))] #[allow(non_upper_case_globals)] static NtResumeProcess: Lazy NTSTATUS> = Lazy::new(|| { // SAFETY: Getting the module handle for ntdll.dll is safe - let ntdll = unsafe { GetModuleHandleA(b"ntdll.dll\0".as_ptr().cast()) }; - assert!(!ntdll.is_null(), "GetModuleHandleA failed"); + let ntdll = unsafe { GetModuleHandleW(w!("ntdll.dll")) }; + assert!(ntdll != 0, "GetModuleHandleW failed"); // SAFETY: The ntdll handle is valid - let nt_resume_process = - unsafe { GetProcAddress(ntdll, b"NtResumeProcess\0".as_ptr().cast()) }; - assert!(!nt_resume_process.is_null(), "GetProcAddress failed"); + let nt_resume_process = unsafe { GetProcAddress(ntdll, s!("NtResumeProcess")) }; // SAFETY: We transmute to the correct function signature - unsafe { mem::transmute(nt_resume_process) } + unsafe { mem::transmute(nt_resume_process.unwrap()) } }); pub struct CPUTimer { @@ -48,24 +52,24 @@ pub struct CPUTimer { impl CPUTimer { pub unsafe fn start_suspended_process(child: &process::Child) -> Self { - let child_handle = child.as_raw_handle().cast(); + let child_handle = child.as_raw_handle() as HANDLE; // SAFETY: Creating a new job object is safe let job_object = unsafe { CreateJobObjectW(ptr::null_mut(), ptr::null_mut()) }; - assert!(!job_object.is_null(), "CreateJobObjectW failed"); + assert!(job_object != 0, "CreateJobObjectW failed"); // SAFETY: The job object handle is valid let ret = unsafe { AssignProcessToJobObject(job_object, child_handle) }; assert!(ret != 0, "AssignProcessToJobObject failed"); - #[cfg(windows_process_extensions_main_thread_handle)] + #[cfg(feature = "windows_process_extensions_main_thread_handle")] { // SAFETY: The main thread handle is valid - let ret = unsafe { ResumeThread(child.main_thread_handle().as_raw_handle()) }; - assert!(ret != -1 as DWORD, "ResumeThread failed"); + let ret = unsafe { ResumeThread(child.main_thread_handle().as_raw_handle() as HANDLE) }; + assert!(ret != u32::MAX, "ResumeThread failed"); } - #[cfg(not(windows_process_extensions_main_thread_handle))] + #[cfg(not(feature = "windows_process_extensions_main_thread_handle"))] { // Since we can't get the main thread handle on stable rust, we use // the undocumented but widely known `NtResumeProcess` function to @@ -98,17 +102,15 @@ impl CPUTimer { // SAFETY: The job object info got correctly initialized let job_object_info = unsafe { job_object_info.assume_init() }; - // SAFETY: The `TotalUserTime` is "The total amount of user-mode execution time for + // The `TotalUserTime` is "The total amount of user-mode execution time for // all active processes associated with the job, as well as all terminated processes no - // longer associated with the job, in 100-nanosecond ticks." and is safe to extract - let user: i64 = unsafe { job_object_info.TotalUserTime.QuadPart() } / HUNDRED_NS_PER_MS; + // longer associated with the job, in 100-nanosecond ticks." + let user: i64 = job_object_info.TotalUserTime / HUNDRED_NS_PER_MS; - // SAFETY: The `TotalKernelTime` is "The total amount of kernel-mode execution time + // The `TotalKernelTime` is "The total amount of kernel-mode execution time // for all active processes associated with the job, as well as all terminated - // processes no longer associated with the job, in 100-nanosecond ticks." and is safe - // to extract - let kernel: i64 = - unsafe { job_object_info.TotalKernelTime.QuadPart() } / HUNDRED_NS_PER_MS; + // processes no longer associated with the job, in 100-nanosecond ticks." + let kernel: i64 = job_object_info.TotalKernelTime / HUNDRED_NS_PER_MS; (user as f64 * 1e-6, kernel as f64 * 1e-6) } else { (0.0, 0.0) @@ -117,7 +119,7 @@ impl CPUTimer { } impl Drop for CPUTimer { - fn drop(self: &mut Self) { + fn drop(&mut self) { // SAFETY: A valid job object got created in `start_suspended_process` unsafe { CloseHandle(self.job_object) }; }