@@ -7,6 +7,7 @@ use crate::StatusExt;
77use crate :: mem:: { AlignedBuffer , PoolAllocation } ;
88use crate :: proto:: device_path:: PoolDevicePathNode ;
99use core:: alloc:: LayoutError ;
10+ use core:: cell:: UnsafeCell ;
1011use core:: ptr:: { self , NonNull } ;
1112use uefi_macros:: unsafe_protocol;
1213use uefi_raw:: Status ;
@@ -33,7 +34,7 @@ pub type AtaPassThruMode = uefi_raw::protocol::ata::AtaPassThruMode;
3334#[ derive( Debug ) ]
3435#[ repr( transparent) ]
3536#[ unsafe_protocol( AtaPassThruProtocol :: GUID ) ]
36- pub struct AtaPassThru ( AtaPassThruProtocol ) ;
37+ pub struct AtaPassThru ( UnsafeCell < AtaPassThruProtocol > ) ;
3738
3839impl AtaPassThru {
3940 /// Retrieves the mode structure for the Extended SCSI Pass Thru protocol.
@@ -42,7 +43,7 @@ impl AtaPassThru {
4243 /// The [`AtaPassThruMode`] structure containing configuration details of the protocol.
4344 #[ must_use]
4445 pub fn mode ( & self ) -> AtaPassThruMode {
45- let mut mode = unsafe { ( * self . 0 . mode ) . clone ( ) } ;
46+ let mut mode = unsafe { ( * ( * self . 0 . get ( ) ) . mode ) . clone ( ) } ;
4647 mode. io_align = mode. io_align . max ( 1 ) ; // 0 and 1 is the same, says UEFI spec
4748 mode
4849 }
@@ -101,16 +102,12 @@ impl AtaPassThru {
101102/// available / connected device using [`AtaDevice::execute_command`] before doing anything meaningful.
102103#[ derive( Debug ) ]
103104pub struct AtaDevice < ' a > {
104- proto : & ' a AtaPassThruProtocol ,
105+ proto : & ' a UnsafeCell < AtaPassThruProtocol > ,
105106 port : u16 ,
106107 pmp : u16 ,
107108}
108109
109110impl AtaDevice < ' _ > {
110- const fn proto_mut ( & mut self ) -> * mut AtaPassThruProtocol {
111- ptr:: from_ref ( self . proto ) . cast_mut ( )
112- }
113-
114111 /// Returns the port number of the device.
115112 ///
116113 /// # Details
@@ -142,7 +139,9 @@ impl AtaDevice<'_> {
142139 /// - [`Status::DEVICE_ERROR`] A device error occurred while attempting to reset the specified ATA device.
143140 /// - [`Status::TIMEOUT`] A timeout occurred while attempting to reset the specified ATA device.
144141 pub fn reset ( & mut self ) -> crate :: Result < ( ) > {
145- unsafe { ( self . proto . reset_device ) ( self . proto_mut ( ) , self . port , self . pmp ) . to_result ( ) }
142+ unsafe {
143+ ( ( * self . proto . get ( ) ) . reset_device ) ( self . proto . get ( ) , self . port , self . pmp ) . to_result ( )
144+ }
146145 }
147146
148147 /// Get the final device path node for this device.
@@ -152,8 +151,13 @@ impl AtaDevice<'_> {
152151 pub fn path_node ( & self ) -> crate :: Result < PoolDevicePathNode > {
153152 unsafe {
154153 let mut path_ptr: * const DevicePathProtocol = ptr:: null ( ) ;
155- ( self . proto . build_device_path ) ( self . proto , self . port , self . pmp , & mut path_ptr)
156- . to_result ( ) ?;
154+ ( ( * self . proto . get ( ) ) . build_device_path ) (
155+ self . proto . get ( ) ,
156+ self . port ,
157+ self . pmp ,
158+ & mut path_ptr,
159+ )
160+ . to_result ( ) ?;
157161 NonNull :: new ( path_ptr. cast_mut ( ) )
158162 . map ( |p| PoolDevicePathNode ( PoolAllocation :: new ( p. cast ( ) ) ) )
159163 . ok_or_else ( || Status :: OUT_OF_RESOURCES . into ( ) )
@@ -184,8 +188,8 @@ impl AtaDevice<'_> {
184188 ) -> crate :: Result < AtaResponse < ' req > > {
185189 req. packet . acb = & req. acb ;
186190 unsafe {
187- ( self . proto . pass_thru ) (
188- self . proto_mut ( ) ,
191+ ( ( * self . proto . get ( ) ) . pass_thru ) (
192+ self . proto . get ( ) ,
189193 self . port ,
190194 self . pmp ,
191195 & mut req. packet ,
@@ -203,7 +207,7 @@ impl AtaDevice<'_> {
203207/// is actually available and connected!
204208#[ derive( Debug ) ]
205209pub struct AtaDeviceIterator < ' a > {
206- proto : & ' a AtaPassThruProtocol ,
210+ proto : & ' a UnsafeCell < AtaPassThruProtocol > ,
207211 // when there are no more devices on this port -> get next port
208212 end_of_port : bool ,
209213 prev_port : u16 ,
@@ -216,7 +220,9 @@ impl<'a> Iterator for AtaDeviceIterator<'a> {
216220 fn next ( & mut self ) -> Option < Self :: Item > {
217221 loop {
218222 if self . end_of_port {
219- let result = unsafe { ( self . proto . get_next_port ) ( self . proto , & mut self . prev_port ) } ;
223+ let result = unsafe {
224+ ( ( * self . proto . get ( ) ) . get_next_port ) ( self . proto . get ( ) , & mut self . prev_port )
225+ } ;
220226 match result {
221227 Status :: SUCCESS => self . end_of_port = false ,
222228 Status :: NOT_FOUND => return None , // no more ports / devices. End of list
@@ -233,7 +239,11 @@ impl<'a> Iterator for AtaDeviceIterator<'a> {
233239 // to the port! A port where the device is directly connected uses a pmp-value of 0xFFFF.
234240 let was_first = self . prev_pmp == 0xFFFF ;
235241 let result = unsafe {
236- ( self . proto . get_next_device ) ( self . proto , self . prev_port , & mut self . prev_pmp )
242+ ( ( * self . proto . get ( ) ) . get_next_device ) (
243+ self . proto . get ( ) ,
244+ self . prev_port ,
245+ & mut self . prev_pmp ,
246+ )
237247 } ;
238248 match result {
239249 Status :: SUCCESS => {
0 commit comments