diff --git a/src/_macros.rs b/src/_macros.rs index f21d5741b..4ee56fc8d 100644 --- a/src/_macros.rs +++ b/src/_macros.rs @@ -118,7 +118,8 @@ macro_rules! build_table_column_slice_getter { ($(#[$attr:meta])* => $column: ident, $name: ident, $cast: ty) => { $(#[$attr])* pub fn $name(&self) -> &[$cast] { - $crate::sys::generate_slice(self.as_ref().$column, self.num_rows()) + // SAFETY: all array lengths are the number of rows in the table + unsafe{$crate::sys::generate_slice(self.as_ref().$column, self.num_rows())} } }; } @@ -127,7 +128,8 @@ macro_rules! build_table_column_slice_mut_getter { ($(#[$attr:meta])* => $column: ident, $name: ident, $cast: ty) => { $(#[$attr])* pub fn $name(&mut self) -> &mut [$cast] { - $crate::sys::generate_slice_mut(self.as_ref().$column, self.num_rows()) + // SAFETY: all array lengths are the number of rows in the table + unsafe{$crate::sys::generate_slice_mut(self.as_ref().$column, self.num_rows())} } }; } diff --git a/src/sys/mod.rs b/src/sys/mod.rs index a8d5f058d..9f1490cf8 100644 --- a/src/sys/mod.rs +++ b/src/sys/mod.rs @@ -189,22 +189,28 @@ pub fn tsk_ragged_column_access< .map(|(p, n)| unsafe { std::slice::from_raw_parts(p.cast::(), n) }) } -pub fn generate_slice<'a, L: Into, I, O>( +/// # SAFETY +/// +/// * data must not be NULL +/// * length must be a valid offset from data +/// (ideally it comes from the tskit-c API) +pub unsafe fn generate_slice<'a, L: Into, I, O>( data: *const I, length: L, ) -> &'a [O] { - assert!(!data.is_null()); - // SAFETY: pointer is not null, length comes from C API - unsafe { std::slice::from_raw_parts(data.cast::(), length.into() as usize) } + std::slice::from_raw_parts(data.cast::(), length.into() as usize) } -pub fn generate_slice_mut<'a, L: Into, I, O>( +/// # SAFETY +/// +/// * data must not be NULL +/// * length must be a valid offset from data +/// (ideally it comes from the tskit-c API) +pub unsafe fn generate_slice_mut<'a, L: Into, I, O>( data: *mut I, length: L, ) -> &'a mut [O] { - assert!(!data.is_null()); - // SAFETY: pointer is not null, length comes from C API - unsafe { std::slice::from_raw_parts_mut(data.cast::(), length.into() as usize) } + std::slice::from_raw_parts_mut(data.cast::(), length.into() as usize) } pub fn get_tskit_error_message(code: i32) -> String { diff --git a/src/sys/tree.rs b/src/sys/tree.rs index 87898b538..61e22b1c8 100644 --- a/src/sys/tree.rs +++ b/src/sys/tree.rs @@ -59,7 +59,8 @@ impl<'treeseq> LLTree<'treeseq> { pub fn samples_array(&self) -> Result<&[super::newtypes::NodeId], TskitError> { err_if_not_tracking_samples!( self.flags, - super::generate_slice(self.as_ll_ref().samples, self.num_samples()) + // SAFETY: num_samples is the correct value + unsafe { super::generate_slice(self.as_ll_ref().samples, self.num_samples()) } ) } @@ -182,43 +183,63 @@ impl<'treeseq> LLTree<'treeseq> { pub fn sample_nodes(&self) -> &[NodeId] { assert!(!self.as_ptr().is_null()); - // SAFETY: self ptr is not null and the tree is initialized - let num_samples = - unsafe { bindings::tsk_treeseq_get_num_samples(self.as_ll_ref().tree_sequence) }; - super::generate_slice(self.as_ll_ref().samples, num_samples) + unsafe { + // SAFETY: self ptr is not null and the tree is initialized + // num_samples is the correct array length + let num_samples = bindings::tsk_treeseq_get_num_samples(self.as_ll_ref().tree_sequence); + super::generate_slice(self.as_ll_ref().samples, num_samples) + } } pub fn parent_array(&self) -> &[NodeId] { - super::generate_slice(self.as_ll_ref().parent, self.treeseq.num_nodes_raw() + 1) + // SAFETY: the array length is the number of nodes + 1 for the "virtual root" + unsafe { super::generate_slice(self.as_ll_ref().parent, self.treeseq.num_nodes_raw() + 1) } } pub fn left_sib_array(&self) -> &[NodeId] { - super::generate_slice(self.as_ll_ref().left_sib, self.treeseq.num_nodes_raw() + 1) + // SAFETY: the array length is the number of nodes + 1 for the "virtual root" + unsafe { + super::generate_slice(self.as_ll_ref().left_sib, self.treeseq.num_nodes_raw() + 1) + } } pub fn right_sib_array(&self) -> &[NodeId] { - super::generate_slice(self.as_ll_ref().right_sib, self.treeseq.num_nodes_raw() + 1) + // SAFETY: the array length is the number of nodes + 1 for the "virtual root" + unsafe { + super::generate_slice(self.as_ll_ref().right_sib, self.treeseq.num_nodes_raw() + 1) + } } pub fn left_child_array(&self) -> &[NodeId] { - super::generate_slice( - self.as_ll_ref().left_child, - self.treeseq.num_nodes_raw() + 1, - ) + // SAFETY: the array length is the number of nodes + 1 for the "virtual root" + unsafe { + super::generate_slice( + self.as_ll_ref().left_child, + self.treeseq.num_nodes_raw() + 1, + ) + } } pub fn right_child_array(&self) -> &[NodeId] { - super::generate_slice( - self.as_ll_ref().right_child, - self.treeseq.num_nodes_raw() + 1, - ) + // SAFETY: the array length is the number of nodes + 1 for the "virtual root" + unsafe { + super::generate_slice( + self.as_ll_ref().right_child, + self.treeseq.num_nodes_raw() + 1, + ) + } } pub fn total_branch_length(&self, by_span: bool) -> Result { - let time: &[Time] = super::generate_slice( - unsafe { (*(*(*self.as_ptr()).tree_sequence).tables).nodes.time }, - self.treeseq.num_nodes_raw() + 1, - ); + assert!(!self.treeseq.as_ref().tables.is_null()); + // SAFETY: array len is number of nodes + 1 for the "virtual root" + // tables ptr is not NULL + let time: &[Time] = unsafe { + super::generate_slice( + (*(self.treeseq.as_ref()).tables).nodes.time, + self.treeseq.num_nodes_raw() + 1, + ) + }; let mut b = Time::from(0.); for n in self.traverse_nodes(NodeTraversalOrder::Preorder) { let p = self.parent(n).ok_or(TskitError::IndexError {})?; @@ -246,32 +267,38 @@ impl<'treeseq> LLTree<'treeseq> { } pub fn left_sample_array(&self) -> Result<&[NodeId], TskitError> { - err_if_not_tracking_samples!( - self.flags, + err_if_not_tracking_samples!(self.flags, unsafe { + // SAFETY: array length is number of nodes + 1 for the "virtual root" super::generate_slice( self.as_ll_ref().left_sample, - self.treeseq.num_nodes_raw() + 1 + self.treeseq.num_nodes_raw() + 1, ) - ) + }) } pub fn right_sample_array(&self) -> Result<&[NodeId], TskitError> { err_if_not_tracking_samples!( self.flags, - super::generate_slice( - self.as_ll_ref().right_sample, - self.treeseq.num_nodes_raw() + 1 - ) + // SAFETY: array length is number of nodes + 1 for the "virtual root" + unsafe { + super::generate_slice( + self.as_ll_ref().right_sample, + self.treeseq.num_nodes_raw() + 1, + ) + } ) } pub fn next_sample_array(&self) -> Result<&[NodeId], TskitError> { err_if_not_tracking_samples!( self.flags, - super::generate_slice( - self.as_ll_ref().next_sample, - self.treeseq.num_nodes_raw() + 1 - ) + // SAFETY: array length is number of nodes + 1 for the "virtual root" + unsafe { + super::generate_slice( + self.as_ll_ref().next_sample, + self.treeseq.num_nodes_raw() + 1, + ) + } ) } diff --git a/src/trees/treeseq.rs b/src/trees/treeseq.rs index 9c7616552..e04d84ad4 100644 --- a/src/trees/treeseq.rs +++ b/src/trees/treeseq.rs @@ -283,8 +283,10 @@ impl TreeSequence { /// Get the list of sample nodes as a slice. pub fn sample_nodes(&self) -> &[NodeId] { - let num_samples = unsafe { ll_bindings::tsk_treeseq_get_num_samples(self.as_ptr()) }; - sys::generate_slice(self.as_ref().samples, num_samples) + unsafe { + let num_samples = ll_bindings::tsk_treeseq_get_num_samples(self.as_ref()); + sys::generate_slice(self.as_ref().samples, num_samples) + } } /// Get the number of trees.