From a8412db5f3c2b2f0733a45fd15c1991d7103035c Mon Sep 17 00:00:00 2001 From: Nicholas Bishop Date: Mon, 20 Oct 2025 20:14:18 -0400 Subject: [PATCH] uefi: Wrap NmvePassThruProtocol in UnsafeCell See abdd6c4ff87f ("uefi: Wrap AtaPassThruProtocol in UnsafeCell") for details on why this change is needed. --- uefi/src/proto/nvme/pass_thru.rs | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/uefi/src/proto/nvme/pass_thru.rs b/uefi/src/proto/nvme/pass_thru.rs index 415ddd9b5..14c087252 100644 --- a/uefi/src/proto/nvme/pass_thru.rs +++ b/uefi/src/proto/nvme/pass_thru.rs @@ -7,6 +7,7 @@ use crate::StatusExt; use crate::mem::{AlignedBuffer, PoolAllocation}; use crate::proto::device_path::PoolDevicePathNode; use core::alloc::LayoutError; +use core::cell::UnsafeCell; use core::ptr::{self, NonNull}; use uefi_macros::unsafe_protocol; use uefi_raw::Status; @@ -40,7 +41,7 @@ pub type NvmeNamespaceId = u32; #[derive(Debug)] #[repr(transparent)] #[unsafe_protocol(NvmExpressPassThruProtocol::GUID)] -pub struct NvmePassThru(NvmExpressPassThruProtocol); +pub struct NvmePassThru(UnsafeCell); impl NvmePassThru { /// Retrieves the mode of the NVMe Pass Thru protocol. @@ -49,7 +50,7 @@ impl NvmePassThru { /// An instance of [`NvmePassThruMode`] describing the NVMe controller's capabilities. #[must_use] pub fn mode(&self) -> NvmePassThruMode { - let mut mode = unsafe { (*self.0.mode).clone() }; + let mut mode = unsafe { (*(*self.0.get()).mode).clone() }; mode.io_align = mode.io_align.max(1); // 0 and 1 is the same, says UEFI spec mode } @@ -116,15 +117,11 @@ impl NvmePassThru { /// Typically, consumer devices only have a single namespace where all the data resides (id 1). #[derive(Debug)] pub struct NvmeNamespace<'a> { - proto: &'a NvmExpressPassThruProtocol, + proto: &'a UnsafeCell, namespace_id: NvmeNamespaceId, } impl NvmeNamespace<'_> { - const fn proto_mut(&mut self) -> *mut NvmExpressPassThruProtocol { - ptr::from_ref(self.proto).cast_mut() - } - /// Retrieves the namespace identifier (NSID) associated with this NVMe namespace. #[must_use] pub const fn namespace_id(&self) -> NvmeNamespaceId { @@ -138,8 +135,12 @@ impl NvmeNamespace<'_> { pub fn path_node(&self) -> crate::Result { unsafe { let mut path_ptr: *const DevicePathProtocol = ptr::null(); - (self.proto.build_device_path)(self.proto, self.namespace_id, &mut path_ptr) - .to_result()?; + ((*self.proto.get()).build_device_path)( + self.proto.get(), + self.namespace_id, + &mut path_ptr, + ) + .to_result()?; NonNull::new(path_ptr.cast_mut()) .map(|p| PoolDevicePathNode(PoolAllocation::new(p.cast()))) .ok_or_else(|| Status::OUT_OF_RESOURCES.into()) @@ -178,8 +179,8 @@ impl NvmeNamespace<'_> { req.packet.nvme_cmd = &req.cmd; req.packet.nvme_completion = &mut completion; unsafe { - (self.proto.pass_thru)( - self.proto_mut(), + ((*self.proto.get()).pass_thru)( + self.proto.get(), self.namespace_id, &mut req.packet, ptr::null_mut(), @@ -195,7 +196,7 @@ impl NvmeNamespace<'_> { /// on the NVMe controller. #[derive(Debug)] pub struct NvmeNamespaceIterator<'a> { - proto: &'a NvmExpressPassThruProtocol, + proto: &'a UnsafeCell, prev: NvmeNamespaceId, } @@ -203,7 +204,8 @@ impl<'a> Iterator for NvmeNamespaceIterator<'a> { type Item = NvmeNamespace<'a>; fn next(&mut self) -> Option { - let result = unsafe { (self.proto.get_next_namespace)(self.proto, &mut self.prev) }; + let result = + unsafe { ((*self.proto.get()).get_next_namespace)(self.proto.get(), &mut self.prev) }; match result { Status::SUCCESS => Some(NvmeNamespace { proto: self.proto,