Skip to content

Commit abdd6c4

Browse files
uefi: Wrap AtaPassThruProtocol in UnsafeCell
The existing code had some potential UB; it created a mutable pointer from a const reference and passed it across the FFI boundary. (Whether the pointee is actually mutated depends on the firmware implementation.) An `UnsafeCell` allows the interior data to be mutated through a const reference. `AtaPassThru` now contains an `UnsafeCell<AtaPassThruProtocol>`, which allows a mutable pointer to be created with less risk of UB. (Note that it's still not allowed to create multiple mutable _references_ to the data, but as long as only raw pointers are used, it should be OK.) The `AtaDevice` and `AtaDeviceIterator` types have been adjusted to take a reference to the `UnsafeCell`.
1 parent f2f6abf commit abdd6c4

File tree

1 file changed

+25
-15
lines changed

1 file changed

+25
-15
lines changed

uefi/src/proto/ata/pass_thru.rs

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use crate::StatusExt;
77
use crate::mem::{AlignedBuffer, PoolAllocation};
88
use crate::proto::device_path::PoolDevicePathNode;
99
use core::alloc::LayoutError;
10+
use core::cell::UnsafeCell;
1011
use core::ptr::{self, NonNull};
1112
use uefi_macros::unsafe_protocol;
1213
use uefi_raw::Status;
@@ -33,7 +34,7 @@ pub type AtaPassThruMode = uefi_raw::protocol::ata::AtaPassThruMode;
3334
#[derive(Debug)]
3435
#[repr(transparent)]
3536
#[unsafe_protocol(AtaPassThruProtocol::GUID)]
36-
pub struct AtaPassThru(AtaPassThruProtocol);
37+
pub struct AtaPassThru(UnsafeCell<AtaPassThruProtocol>);
3738

3839
impl AtaPassThru {
3940
/// Retrieves the mode structure for the Extended SCSI Pass Thru protocol.
@@ -42,7 +43,7 @@ impl AtaPassThru {
4243
/// The [`AtaPassThruMode`] structure containing configuration details of the protocol.
4344
#[must_use]
4445
pub fn mode(&self) -> AtaPassThruMode {
45-
let mut mode = unsafe { (*self.0.mode).clone() };
46+
let mut mode = unsafe { (*(*self.0.get()).mode).clone() };
4647
mode.io_align = mode.io_align.max(1); // 0 and 1 is the same, says UEFI spec
4748
mode
4849
}
@@ -101,16 +102,12 @@ impl AtaPassThru {
101102
/// available / connected device using [`AtaDevice::execute_command`] before doing anything meaningful.
102103
#[derive(Debug)]
103104
pub struct AtaDevice<'a> {
104-
proto: &'a AtaPassThruProtocol,
105+
proto: &'a UnsafeCell<AtaPassThruProtocol>,
105106
port: u16,
106107
pmp: u16,
107108
}
108109

109110
impl AtaDevice<'_> {
110-
const fn proto_mut(&mut self) -> *mut AtaPassThruProtocol {
111-
ptr::from_ref(self.proto).cast_mut()
112-
}
113-
114111
/// Returns the port number of the device.
115112
///
116113
/// # Details
@@ -142,7 +139,9 @@ impl AtaDevice<'_> {
142139
/// - [`Status::DEVICE_ERROR`] A device error occurred while attempting to reset the specified ATA device.
143140
/// - [`Status::TIMEOUT`] A timeout occurred while attempting to reset the specified ATA device.
144141
pub fn reset(&mut self) -> crate::Result<()> {
145-
unsafe { (self.proto.reset_device)(self.proto_mut(), self.port, self.pmp).to_result() }
142+
unsafe {
143+
((*self.proto.get()).reset_device)(self.proto.get(), self.port, self.pmp).to_result()
144+
}
146145
}
147146

148147
/// Get the final device path node for this device.
@@ -152,8 +151,13 @@ impl AtaDevice<'_> {
152151
pub fn path_node(&self) -> crate::Result<PoolDevicePathNode> {
153152
unsafe {
154153
let mut path_ptr: *const DevicePathProtocol = ptr::null();
155-
(self.proto.build_device_path)(self.proto, self.port, self.pmp, &mut path_ptr)
156-
.to_result()?;
154+
((*self.proto.get()).build_device_path)(
155+
self.proto.get(),
156+
self.port,
157+
self.pmp,
158+
&mut path_ptr,
159+
)
160+
.to_result()?;
157161
NonNull::new(path_ptr.cast_mut())
158162
.map(|p| PoolDevicePathNode(PoolAllocation::new(p.cast())))
159163
.ok_or_else(|| Status::OUT_OF_RESOURCES.into())
@@ -184,8 +188,8 @@ impl AtaDevice<'_> {
184188
) -> crate::Result<AtaResponse<'req>> {
185189
req.packet.acb = &req.acb;
186190
unsafe {
187-
(self.proto.pass_thru)(
188-
self.proto_mut(),
191+
((*self.proto.get()).pass_thru)(
192+
self.proto.get(),
189193
self.port,
190194
self.pmp,
191195
&mut req.packet,
@@ -203,7 +207,7 @@ impl AtaDevice<'_> {
203207
/// is actually available and connected!
204208
#[derive(Debug)]
205209
pub struct AtaDeviceIterator<'a> {
206-
proto: &'a AtaPassThruProtocol,
210+
proto: &'a UnsafeCell<AtaPassThruProtocol>,
207211
// when there are no more devices on this port -> get next port
208212
end_of_port: bool,
209213
prev_port: u16,
@@ -216,7 +220,9 @@ impl<'a> Iterator for AtaDeviceIterator<'a> {
216220
fn next(&mut self) -> Option<Self::Item> {
217221
loop {
218222
if self.end_of_port {
219-
let result = unsafe { (self.proto.get_next_port)(self.proto, &mut self.prev_port) };
223+
let result = unsafe {
224+
((*self.proto.get()).get_next_port)(self.proto.get(), &mut self.prev_port)
225+
};
220226
match result {
221227
Status::SUCCESS => self.end_of_port = false,
222228
Status::NOT_FOUND => return None, // no more ports / devices. End of list
@@ -233,7 +239,11 @@ impl<'a> Iterator for AtaDeviceIterator<'a> {
233239
// to the port! A port where the device is directly connected uses a pmp-value of 0xFFFF.
234240
let was_first = self.prev_pmp == 0xFFFF;
235241
let result = unsafe {
236-
(self.proto.get_next_device)(self.proto, self.prev_port, &mut self.prev_pmp)
242+
((*self.proto.get()).get_next_device)(
243+
self.proto.get(),
244+
self.prev_port,
245+
&mut self.prev_pmp,
246+
)
237247
};
238248
match result {
239249
Status::SUCCESS => {

0 commit comments

Comments
 (0)