Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/_macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ macro_rules! build_table_column_slice_getter {
($(#[$attr:meta])* => $column: ident, $name: ident, $cast: ty) => {
$(#[$attr])*
pub fn $name(&self) -> &[$cast] {
$crate::sys::generate_slice(self.as_ref().$column, self.num_rows())
// SAFETY: all array lengths are the number of rows in the table
unsafe{$crate::sys::generate_slice(self.as_ref().$column, self.num_rows())}
}
};
}
Expand All @@ -127,7 +128,8 @@ macro_rules! build_table_column_slice_mut_getter {
($(#[$attr:meta])* => $column: ident, $name: ident, $cast: ty) => {
$(#[$attr])*
pub fn $name(&mut self) -> &mut [$cast] {
$crate::sys::generate_slice_mut(self.as_ref().$column, self.num_rows())
// SAFETY: all array lengths are the number of rows in the table
unsafe{$crate::sys::generate_slice_mut(self.as_ref().$column, self.num_rows())}
}
};
}
Expand Down
22 changes: 14 additions & 8 deletions src/sys/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,22 +189,28 @@ pub fn tsk_ragged_column_access<
.map(|(p, n)| unsafe { std::slice::from_raw_parts(p.cast::<O>(), n) })
}

pub fn generate_slice<'a, L: Into<bindings::tsk_size_t>, I, O>(
/// # SAFETY
///
/// * data must not be NULL
/// * length must be a valid offset from data
/// (ideally it comes from the tskit-c API)
pub unsafe fn generate_slice<'a, L: Into<bindings::tsk_size_t>, I, O>(
data: *const I,
length: L,
) -> &'a [O] {
assert!(!data.is_null());
// SAFETY: pointer is not null, length comes from C API
unsafe { std::slice::from_raw_parts(data.cast::<O>(), length.into() as usize) }
std::slice::from_raw_parts(data.cast::<O>(), length.into() as usize)
}

pub fn generate_slice_mut<'a, L: Into<bindings::tsk_size_t>, I, O>(
/// # SAFETY
///
/// * data must not be NULL
/// * length must be a valid offset from data
/// (ideally it comes from the tskit-c API)
pub unsafe fn generate_slice_mut<'a, L: Into<bindings::tsk_size_t>, I, O>(
data: *mut I,
length: L,
) -> &'a mut [O] {
assert!(!data.is_null());
// SAFETY: pointer is not null, length comes from C API
unsafe { std::slice::from_raw_parts_mut(data.cast::<O>(), length.into() as usize) }
std::slice::from_raw_parts_mut(data.cast::<O>(), length.into() as usize)
}

pub fn get_tskit_error_message(code: i32) -> String {
Expand Down
91 changes: 59 additions & 32 deletions src/sys/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ impl<'treeseq> LLTree<'treeseq> {
pub fn samples_array(&self) -> Result<&[super::newtypes::NodeId], TskitError> {
err_if_not_tracking_samples!(
self.flags,
super::generate_slice(self.as_ll_ref().samples, self.num_samples())
// SAFETY: num_samples is the correct value
unsafe { super::generate_slice(self.as_ll_ref().samples, self.num_samples()) }
)
}

Expand Down Expand Up @@ -182,43 +183,63 @@ impl<'treeseq> LLTree<'treeseq> {

pub fn sample_nodes(&self) -> &[NodeId] {
assert!(!self.as_ptr().is_null());
// SAFETY: self ptr is not null and the tree is initialized
let num_samples =
unsafe { bindings::tsk_treeseq_get_num_samples(self.as_ll_ref().tree_sequence) };
super::generate_slice(self.as_ll_ref().samples, num_samples)
unsafe {
// SAFETY: self ptr is not null and the tree is initialized
// num_samples is the correct array length
let num_samples = bindings::tsk_treeseq_get_num_samples(self.as_ll_ref().tree_sequence);
super::generate_slice(self.as_ll_ref().samples, num_samples)
}
}

pub fn parent_array(&self) -> &[NodeId] {
super::generate_slice(self.as_ll_ref().parent, self.treeseq.num_nodes_raw() + 1)
// SAFETY: the array length is the number of nodes + 1 for the "virtual root"
unsafe { super::generate_slice(self.as_ll_ref().parent, self.treeseq.num_nodes_raw() + 1) }
}

pub fn left_sib_array(&self) -> &[NodeId] {
super::generate_slice(self.as_ll_ref().left_sib, self.treeseq.num_nodes_raw() + 1)
// SAFETY: the array length is the number of nodes + 1 for the "virtual root"
unsafe {
super::generate_slice(self.as_ll_ref().left_sib, self.treeseq.num_nodes_raw() + 1)
}
}

pub fn right_sib_array(&self) -> &[NodeId] {
super::generate_slice(self.as_ll_ref().right_sib, self.treeseq.num_nodes_raw() + 1)
// SAFETY: the array length is the number of nodes + 1 for the "virtual root"
unsafe {
super::generate_slice(self.as_ll_ref().right_sib, self.treeseq.num_nodes_raw() + 1)
}
}

pub fn left_child_array(&self) -> &[NodeId] {
super::generate_slice(
self.as_ll_ref().left_child,
self.treeseq.num_nodes_raw() + 1,
)
// SAFETY: the array length is the number of nodes + 1 for the "virtual root"
unsafe {
super::generate_slice(
self.as_ll_ref().left_child,
self.treeseq.num_nodes_raw() + 1,
)
}
}

pub fn right_child_array(&self) -> &[NodeId] {
super::generate_slice(
self.as_ll_ref().right_child,
self.treeseq.num_nodes_raw() + 1,
)
// SAFETY: the array length is the number of nodes + 1 for the "virtual root"
unsafe {
super::generate_slice(
self.as_ll_ref().right_child,
self.treeseq.num_nodes_raw() + 1,
)
}
}

pub fn total_branch_length(&self, by_span: bool) -> Result<Time, TskitError> {
let time: &[Time] = super::generate_slice(
unsafe { (*(*(*self.as_ptr()).tree_sequence).tables).nodes.time },
self.treeseq.num_nodes_raw() + 1,
);
assert!(!self.treeseq.as_ref().tables.is_null());
// SAFETY: array len is number of nodes + 1 for the "virtual root"
// tables ptr is not NULL
let time: &[Time] = unsafe {
super::generate_slice(
(*(self.treeseq.as_ref()).tables).nodes.time,
self.treeseq.num_nodes_raw() + 1,
)
};
let mut b = Time::from(0.);
for n in self.traverse_nodes(NodeTraversalOrder::Preorder) {
let p = self.parent(n).ok_or(TskitError::IndexError {})?;
Expand Down Expand Up @@ -246,32 +267,38 @@ impl<'treeseq> LLTree<'treeseq> {
}

pub fn left_sample_array(&self) -> Result<&[NodeId], TskitError> {
err_if_not_tracking_samples!(
self.flags,
err_if_not_tracking_samples!(self.flags, unsafe {
// SAFETY: array length is number of nodes + 1 for the "virtual root"
super::generate_slice(
self.as_ll_ref().left_sample,
self.treeseq.num_nodes_raw() + 1
self.treeseq.num_nodes_raw() + 1,
)
)
})
}

pub fn right_sample_array(&self) -> Result<&[NodeId], TskitError> {
err_if_not_tracking_samples!(
self.flags,
super::generate_slice(
self.as_ll_ref().right_sample,
self.treeseq.num_nodes_raw() + 1
)
// SAFETY: array length is number of nodes + 1 for the "virtual root"
unsafe {
super::generate_slice(
self.as_ll_ref().right_sample,
self.treeseq.num_nodes_raw() + 1,
)
}
)
}

pub fn next_sample_array(&self) -> Result<&[NodeId], TskitError> {
err_if_not_tracking_samples!(
self.flags,
super::generate_slice(
self.as_ll_ref().next_sample,
self.treeseq.num_nodes_raw() + 1
)
// SAFETY: array length is number of nodes + 1 for the "virtual root"
unsafe {
super::generate_slice(
self.as_ll_ref().next_sample,
self.treeseq.num_nodes_raw() + 1,
)
}
)
}

Expand Down
6 changes: 4 additions & 2 deletions src/trees/treeseq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,10 @@ impl TreeSequence {

/// Get the list of sample nodes as a slice.
pub fn sample_nodes(&self) -> &[NodeId] {
let num_samples = unsafe { ll_bindings::tsk_treeseq_get_num_samples(self.as_ptr()) };
sys::generate_slice(self.as_ref().samples, num_samples)
unsafe {
let num_samples = ll_bindings::tsk_treeseq_get_num_samples(self.as_ref());
sys::generate_slice(self.as_ref().samples, num_samples)
}
}

/// Get the number of trees.
Expand Down
Loading