diff --git a/src/proto/device_path.rs b/src/proto/device_path.rs index daaa59564..33daf5f92 100644 --- a/src/proto/device_path.rs +++ b/src/proto/device_path.rs @@ -17,6 +17,7 @@ //! total size of the Node including the header. use crate::{proto::Protocol, unsafe_guid}; +use core::slice; /// Header that appears at the start of every [`DevicePath`] node. #[derive(Clone, Copy, Debug, Eq, PartialEq)] @@ -35,7 +36,7 @@ pub struct DevicePathHeader { /// This can be opened on a `LoadedImage.device()` handle using the `HandleProtocol` boot service. #[repr(C, packed)] #[unsafe_guid("09576e91-6d3f-11d2-8e39-00a0c969723b")] -#[derive(Protocol)] +#[derive(Eq, Protocol)] pub struct DevicePath { header: DevicePathHeader, } @@ -70,6 +71,27 @@ impl DevicePath { } } +impl PartialEq for DevicePath { + fn eq(&self, other: &DevicePath) -> bool { + // Check for equality with a byte-by-byte comparison of the device + // paths. Note that this covers the entire payload of the device path + // using the `length` field in the header, so it's not the same as just + // comparing the fields of the `DevicePath` struct. + unsafe { + let self_bytes = slice::from_raw_parts( + self as *const DevicePath as *const u8, + self.length() as usize, + ); + let other_bytes = slice::from_raw_parts( + other as *const DevicePath as *const u8, + other.length() as usize, + ); + + self_bytes == other_bytes + } + } +} + /// Iterator over [`DevicePath`] nodes. /// /// Iteration ends when a path is reached where [`DevicePath::is_end_entire`]