Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 23 additions & 17 deletions uefi/src/proto/scsi/pass_thru.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::mem::{AlignedBuffer, PoolAllocation};
use crate::proto::device_path::PoolDevicePathNode;
use crate::proto::unsafe_protocol;
use core::alloc::LayoutError;
use core::cell::UnsafeCell;
use core::ptr::{self, NonNull};
use uefi_raw::Status;
use uefi_raw::protocol::device_path::DevicePathProtocol;
Expand Down Expand Up @@ -42,7 +43,7 @@ impl Default for ScsiTargetLun {
#[derive(Debug)]
#[repr(transparent)]
#[unsafe_protocol(ExtScsiPassThruProtocol::GUID)]
pub struct ExtScsiPassThru(ExtScsiPassThruProtocol);
pub struct ExtScsiPassThru(UnsafeCell<ExtScsiPassThruProtocol>);

impl ExtScsiPassThru {
/// Retrieves the mode structure for the Extended SCSI Pass Thru protocol.
Expand All @@ -51,7 +52,7 @@ impl ExtScsiPassThru {
/// The [`ExtScsiPassThruMode`] structure containing configuration details of the protocol.
#[must_use]
pub fn mode(&self) -> ExtScsiPassThruMode {
let mut mode = unsafe { (*self.0.passthru_mode).clone() };
let mut mode = unsafe { (*(*self.0.get()).passthru_mode).clone() };
mode.io_align = mode.io_align.max(1); // 0 and 1 is the same, says UEFI spec
mode
}
Expand Down Expand Up @@ -113,7 +114,7 @@ impl ExtScsiPassThru {
/// - [`Status::DEVICE_ERROR`] A device error occurred while attempting to reset the SCSI channel.
/// - [`Status::TIMEOUT`] A timeout occurred while attempting to reset the SCSI channel.
pub fn reset_channel(&mut self) -> crate::Result<()> {
unsafe { (self.0.reset_channel)(&mut self.0).to_result() }
unsafe { ((*self.0.get()).reset_channel)(self.0.get()).to_result() }
}
}

Expand All @@ -126,14 +127,10 @@ impl ExtScsiPassThru {
/// You have to probe for availability before doing anything meaningful with it.
#[derive(Clone, Debug)]
pub struct ScsiDevice<'a> {
proto: &'a ExtScsiPassThruProtocol,
proto: &'a UnsafeCell<ExtScsiPassThruProtocol>,
target_lun: ScsiTargetLun,
}
impl ScsiDevice<'_> {
const fn proto_mut(&mut self) -> *mut ExtScsiPassThruProtocol {
ptr::from_ref(self.proto).cast_mut()
}

/// Returns the SCSI target address of the potential device.
#[must_use]
pub const fn target(&self) -> &ScsiTarget {
Expand All @@ -153,8 +150,8 @@ impl ScsiDevice<'_> {
pub fn path_node(&self) -> crate::Result<PoolDevicePathNode> {
unsafe {
let mut path_ptr: *const DevicePathProtocol = ptr::null();
(self.proto.build_device_path)(
self.proto,
((*self.proto.get()).build_device_path)(
self.proto.get(),
self.target().as_ptr(),
self.lun(),
&mut path_ptr,
Expand Down Expand Up @@ -185,8 +182,12 @@ impl ScsiDevice<'_> {
/// by `Target` and `Lun`.
pub fn reset(&mut self) -> crate::Result<()> {
unsafe {
(self.proto.reset_target_lun)(self.proto_mut(), self.target_lun.0.as_ptr(), self.lun())
.to_result()
((*self.proto.get()).reset_target_lun)(
self.proto.get(),
self.target_lun.0.as_ptr(),
self.lun(),
)
.to_result()
}
}

Expand Down Expand Up @@ -223,8 +224,8 @@ impl ScsiDevice<'_> {
mut scsi_req: ScsiRequest<'req>,
) -> crate::Result<ScsiResponse<'req>> {
unsafe {
(self.proto.pass_thru)(
self.proto_mut(),
((*self.proto.get()).pass_thru)(
self.proto.get(),
self.target_lun.0.as_ptr(),
self.target_lun.1,
&mut scsi_req.packet,
Expand All @@ -238,7 +239,7 @@ impl ScsiDevice<'_> {
/// An iterator over SCSI devices available on the channel.
#[derive(Debug)]
pub struct ScsiTargetLunIterator<'a> {
proto: &'a ExtScsiPassThruProtocol,
proto: &'a UnsafeCell<ExtScsiPassThruProtocol>,
prev: ScsiTargetLun,
}
impl<'a> Iterator for ScsiTargetLunIterator<'a> {
Expand All @@ -248,8 +249,13 @@ impl<'a> Iterator for ScsiTargetLunIterator<'a> {
// get_next_target_lun() takes the target as a double ptr, meaning that the spec allows
// the implementation to return us a new buffer (most impls don't actually seem to do though)
let mut target: *mut u8 = self.prev.0.as_mut_ptr();
let result =
unsafe { (self.proto.get_next_target_lun)(self.proto, &mut target, &mut self.prev.1) };
let result = unsafe {
((*self.proto.get()).get_next_target_lun)(
self.proto.get(),
&mut target,
&mut self.prev.1,
)
};
if target != self.prev.0.as_mut_ptr() {
// impl has returned us a new pointer instead of writing in our buffer, copy back
unsafe {
Expand Down