Skip to content

Commit 6c89597

Browse files
authored
Merge pull request #1793 from rust-osdev/bishop-nvme-ub-fix
uefi: Wrap NmvePassThruProtocol in UnsafeCell
2 parents a9d4ef3 + a8412db commit 6c89597

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

uefi/src/proto/nvme/pass_thru.rs

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use crate::StatusExt;
77
use crate::mem::{AlignedBuffer, PoolAllocation};
88
use crate::proto::device_path::PoolDevicePathNode;
99
use core::alloc::LayoutError;
10+
use core::cell::UnsafeCell;
1011
use core::ptr::{self, NonNull};
1112
use uefi_macros::unsafe_protocol;
1213
use uefi_raw::Status;
@@ -40,7 +41,7 @@ pub type NvmeNamespaceId = u32;
4041
#[derive(Debug)]
4142
#[repr(transparent)]
4243
#[unsafe_protocol(NvmExpressPassThruProtocol::GUID)]
43-
pub struct NvmePassThru(NvmExpressPassThruProtocol);
44+
pub struct NvmePassThru(UnsafeCell<NvmExpressPassThruProtocol>);
4445

4546
impl NvmePassThru {
4647
/// Retrieves the mode of the NVMe Pass Thru protocol.
@@ -49,7 +50,7 @@ impl NvmePassThru {
4950
/// An instance of [`NvmePassThruMode`] describing the NVMe controller's capabilities.
5051
#[must_use]
5152
pub fn mode(&self) -> NvmePassThruMode {
52-
let mut mode = unsafe { (*self.0.mode).clone() };
53+
let mut mode = unsafe { (*(*self.0.get()).mode).clone() };
5354
mode.io_align = mode.io_align.max(1); // 0 and 1 is the same, says UEFI spec
5455
mode
5556
}
@@ -116,15 +117,11 @@ impl NvmePassThru {
116117
/// Typically, consumer devices only have a single namespace where all the data resides (id 1).
117118
#[derive(Debug)]
118119
pub struct NvmeNamespace<'a> {
119-
proto: &'a NvmExpressPassThruProtocol,
120+
proto: &'a UnsafeCell<NvmExpressPassThruProtocol>,
120121
namespace_id: NvmeNamespaceId,
121122
}
122123

123124
impl NvmeNamespace<'_> {
124-
const fn proto_mut(&mut self) -> *mut NvmExpressPassThruProtocol {
125-
ptr::from_ref(self.proto).cast_mut()
126-
}
127-
128125
/// Retrieves the namespace identifier (NSID) associated with this NVMe namespace.
129126
#[must_use]
130127
pub const fn namespace_id(&self) -> NvmeNamespaceId {
@@ -138,8 +135,12 @@ impl NvmeNamespace<'_> {
138135
pub fn path_node(&self) -> crate::Result<PoolDevicePathNode> {
139136
unsafe {
140137
let mut path_ptr: *const DevicePathProtocol = ptr::null();
141-
(self.proto.build_device_path)(self.proto, self.namespace_id, &mut path_ptr)
142-
.to_result()?;
138+
((*self.proto.get()).build_device_path)(
139+
self.proto.get(),
140+
self.namespace_id,
141+
&mut path_ptr,
142+
)
143+
.to_result()?;
143144
NonNull::new(path_ptr.cast_mut())
144145
.map(|p| PoolDevicePathNode(PoolAllocation::new(p.cast())))
145146
.ok_or_else(|| Status::OUT_OF_RESOURCES.into())
@@ -178,8 +179,8 @@ impl NvmeNamespace<'_> {
178179
req.packet.nvme_cmd = &req.cmd;
179180
req.packet.nvme_completion = &mut completion;
180181
unsafe {
181-
(self.proto.pass_thru)(
182-
self.proto_mut(),
182+
((*self.proto.get()).pass_thru)(
183+
self.proto.get(),
183184
self.namespace_id,
184185
&mut req.packet,
185186
ptr::null_mut(),
@@ -195,15 +196,16 @@ impl NvmeNamespace<'_> {
195196
/// on the NVMe controller.
196197
#[derive(Debug)]
197198
pub struct NvmeNamespaceIterator<'a> {
198-
proto: &'a NvmExpressPassThruProtocol,
199+
proto: &'a UnsafeCell<NvmExpressPassThruProtocol>,
199200
prev: NvmeNamespaceId,
200201
}
201202

202203
impl<'a> Iterator for NvmeNamespaceIterator<'a> {
203204
type Item = NvmeNamespace<'a>;
204205

205206
fn next(&mut self) -> Option<Self::Item> {
206-
let result = unsafe { (self.proto.get_next_namespace)(self.proto, &mut self.prev) };
207+
let result =
208+
unsafe { ((*self.proto.get()).get_next_namespace)(self.proto.get(), &mut self.prev) };
207209
match result {
208210
Status::SUCCESS => Some(NvmeNamespace {
209211
proto: self.proto,

0 commit comments

Comments
 (0)