Skip to content

Commit a9d4ef3

Browse files
authored
Merge pull request #1792 from rust-osdev/bishop-scsi-ub-fix
uefi: Wrap ScsiPassThruProtocol in UnsafeCell
2 parents f2dfa4d + 5d034f0 commit a9d4ef3

File tree

1 file changed

+23
-17
lines changed

1 file changed

+23
-17
lines changed

uefi/src/proto/scsi/pass_thru.rs

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use crate::mem::{AlignedBuffer, PoolAllocation};
88
use crate::proto::device_path::PoolDevicePathNode;
99
use crate::proto::unsafe_protocol;
1010
use core::alloc::LayoutError;
11+
use core::cell::UnsafeCell;
1112
use core::ptr::{self, NonNull};
1213
use uefi_raw::Status;
1314
use uefi_raw::protocol::device_path::DevicePathProtocol;
@@ -42,7 +43,7 @@ impl Default for ScsiTargetLun {
4243
#[derive(Debug)]
4344
#[repr(transparent)]
4445
#[unsafe_protocol(ExtScsiPassThruProtocol::GUID)]
45-
pub struct ExtScsiPassThru(ExtScsiPassThruProtocol);
46+
pub struct ExtScsiPassThru(UnsafeCell<ExtScsiPassThruProtocol>);
4647

4748
impl ExtScsiPassThru {
4849
/// Retrieves the mode structure for the Extended SCSI Pass Thru protocol.
@@ -51,7 +52,7 @@ impl ExtScsiPassThru {
5152
/// The [`ExtScsiPassThruMode`] structure containing configuration details of the protocol.
5253
#[must_use]
5354
pub fn mode(&self) -> ExtScsiPassThruMode {
54-
let mut mode = unsafe { (*self.0.passthru_mode).clone() };
55+
let mut mode = unsafe { (*(*self.0.get()).passthru_mode).clone() };
5556
mode.io_align = mode.io_align.max(1); // 0 and 1 is the same, says UEFI spec
5657
mode
5758
}
@@ -113,7 +114,7 @@ impl ExtScsiPassThru {
113114
/// - [`Status::DEVICE_ERROR`] A device error occurred while attempting to reset the SCSI channel.
114115
/// - [`Status::TIMEOUT`] A timeout occurred while attempting to reset the SCSI channel.
115116
pub fn reset_channel(&mut self) -> crate::Result<()> {
116-
unsafe { (self.0.reset_channel)(&mut self.0).to_result() }
117+
unsafe { ((*self.0.get()).reset_channel)(self.0.get()).to_result() }
117118
}
118119
}
119120

@@ -126,14 +127,10 @@ impl ExtScsiPassThru {
126127
/// You have to probe for availability before doing anything meaningful with it.
127128
#[derive(Clone, Debug)]
128129
pub struct ScsiDevice<'a> {
129-
proto: &'a ExtScsiPassThruProtocol,
130+
proto: &'a UnsafeCell<ExtScsiPassThruProtocol>,
130131
target_lun: ScsiTargetLun,
131132
}
132133
impl ScsiDevice<'_> {
133-
const fn proto_mut(&mut self) -> *mut ExtScsiPassThruProtocol {
134-
ptr::from_ref(self.proto).cast_mut()
135-
}
136-
137134
/// Returns the SCSI target address of the potential device.
138135
#[must_use]
139136
pub const fn target(&self) -> &ScsiTarget {
@@ -153,8 +150,8 @@ impl ScsiDevice<'_> {
153150
pub fn path_node(&self) -> crate::Result<PoolDevicePathNode> {
154151
unsafe {
155152
let mut path_ptr: *const DevicePathProtocol = ptr::null();
156-
(self.proto.build_device_path)(
157-
self.proto,
153+
((*self.proto.get()).build_device_path)(
154+
self.proto.get(),
158155
self.target().as_ptr(),
159156
self.lun(),
160157
&mut path_ptr,
@@ -185,8 +182,12 @@ impl ScsiDevice<'_> {
185182
/// by `Target` and `Lun`.
186183
pub fn reset(&mut self) -> crate::Result<()> {
187184
unsafe {
188-
(self.proto.reset_target_lun)(self.proto_mut(), self.target_lun.0.as_ptr(), self.lun())
189-
.to_result()
185+
((*self.proto.get()).reset_target_lun)(
186+
self.proto.get(),
187+
self.target_lun.0.as_ptr(),
188+
self.lun(),
189+
)
190+
.to_result()
190191
}
191192
}
192193

@@ -223,8 +224,8 @@ impl ScsiDevice<'_> {
223224
mut scsi_req: ScsiRequest<'req>,
224225
) -> crate::Result<ScsiResponse<'req>> {
225226
unsafe {
226-
(self.proto.pass_thru)(
227-
self.proto_mut(),
227+
((*self.proto.get()).pass_thru)(
228+
self.proto.get(),
228229
self.target_lun.0.as_ptr(),
229230
self.target_lun.1,
230231
&mut scsi_req.packet,
@@ -238,7 +239,7 @@ impl ScsiDevice<'_> {
238239
/// An iterator over SCSI devices available on the channel.
239240
#[derive(Debug)]
240241
pub struct ScsiTargetLunIterator<'a> {
241-
proto: &'a ExtScsiPassThruProtocol,
242+
proto: &'a UnsafeCell<ExtScsiPassThruProtocol>,
242243
prev: ScsiTargetLun,
243244
}
244245
impl<'a> Iterator for ScsiTargetLunIterator<'a> {
@@ -248,8 +249,13 @@ impl<'a> Iterator for ScsiTargetLunIterator<'a> {
248249
// get_next_target_lun() takes the target as a double ptr, meaning that the spec allows
249250
// the implementation to return us a new buffer (most impls don't actually seem to do though)
250251
let mut target: *mut u8 = self.prev.0.as_mut_ptr();
251-
let result =
252-
unsafe { (self.proto.get_next_target_lun)(self.proto, &mut target, &mut self.prev.1) };
252+
let result = unsafe {
253+
((*self.proto.get()).get_next_target_lun)(
254+
self.proto.get(),
255+
&mut target,
256+
&mut self.prev.1,
257+
)
258+
};
253259
if target != self.prev.0.as_mut_ptr() {
254260
// impl has returned us a new pointer instead of writing in our buffer, copy back
255261
unsafe {

0 commit comments

Comments
 (0)