diff --git a/src/_macros.rs b/src/_macros.rs index 181f57b4a..4d4d1ae92 100644 --- a/src/_macros.rs +++ b/src/_macros.rs @@ -30,6 +30,34 @@ macro_rules! unsafe_tsk_column_access { }}; } +macro_rules! build_tskit_type { + ($name: ident, $ll_name: ty, $drop: ident) => { + impl Drop for $name { + fn drop(&mut self) { + let rv = unsafe { $drop(&mut *self.inner) }; + panic_on_tskit_error!(rv); + } + } + + impl crate::ffi::TskitType<$ll_name> for $name { + fn wrap() -> Self { + let temp: std::mem::MaybeUninit<$ll_name> = std::mem::MaybeUninit::uninit(); + $name { + inner: unsafe { Box::<$ll_name>::new(temp.assume_init()) }, + } + } + + fn as_ptr(&self) -> *const $ll_name { + &*self.inner + } + + fn as_mut_ptr(&mut self) -> *mut $ll_name { + &mut *self.inner + } + } + }; +} + #[cfg(test)] mod test { use crate::error::TskitRustError; diff --git a/src/ffi.rs b/src/ffi.rs new file mode 100644 index 000000000..80e62989f --- /dev/null +++ b/src/ffi.rs @@ -0,0 +1,56 @@ +//! Define traits related to wrapping tskit stuff + +/// Define what it means to wrap a tskit struct. +/// The implementation of Drop should call the +/// tsk_foo_free() function corresponding +/// to tsk_foo_t. +pub trait TskitType: Drop { + /// Encapsulate tsk_foo_t and return rust + /// object. Best practices seem to + /// suggest using Box for this. + fn wrap() -> Self; + /// Return const pointer + fn as_ptr(&self) -> *const T; + /// Return mutable pointer + fn as_mut_ptr(&mut self) -> *mut T; +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::bindings as ll_bindings; + use ll_bindings::tsk_table_collection_free; + + pub struct TableCollectionMock { + inner: Box, + } + + build_tskit_type!( + TableCollectionMock, + ll_bindings::tsk_table_collection_t, + tsk_table_collection_free + ); + + impl TableCollectionMock { + fn new(len: f64) -> Self { + let mut s = Self::wrap(); + + let rv = unsafe { ll_bindings::tsk_table_collection_init(s.as_mut_ptr(), 0) }; + assert_eq!(rv, 0); + + s.inner.sequence_length = len; + + s + } + + fn sequence_length(&self) -> f64 { + unsafe { (*self.as_ptr()).sequence_length } + } + } + + #[test] + fn test_create_mock_type() { + let t = TableCollectionMock::new(10.); + assert_eq!(t.sequence_length() as i64, 10); + } +} diff --git a/src/lib.rs b/src/lib.rs index dad06ea3a..6328bfb2c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,6 +9,7 @@ pub mod bindings; mod _macros; // Starts w/_ to be sorted at front by rustfmt! mod edge_table; pub mod error; +pub mod ffi; mod mutation_table; mod node_table; mod population_table; @@ -53,4 +54,3 @@ pub fn version() -> &'static str { // Testing modules mod test_tsk_variables; - diff --git a/src/table_collection.rs b/src/table_collection.rs index 530c2737a..fa791d51e 100644 --- a/src/table_collection.rs +++ b/src/table_collection.rs @@ -1,5 +1,6 @@ use crate::bindings as ll_bindings; use crate::error::TskitRustError; +use crate::ffi::TskitType; use crate::types::Bookmark; use crate::EdgeTable; use crate::MutationTable; @@ -8,19 +9,7 @@ use crate::PopulationTable; use crate::SiteTable; use crate::TskReturnValue; use crate::{tsk_flags_t, tsk_id_t, tsk_size_t}; - -/// Handle allocation details. -fn new_tsk_table_collection_t() -> Result, TskitRustError> -{ - let mut tsk_tables: std::mem::MaybeUninit = - std::mem::MaybeUninit::uninit(); - let rv = unsafe { ll_bindings::tsk_table_collection_init(tsk_tables.as_mut_ptr(), 0) }; - if rv < 0 { - return Err(TskitRustError::ErrorCode { code: rv }); - } - let rv = unsafe { Box::::new(tsk_tables.assume_init()) }; - Ok(rv) -} +use ll_bindings::tsk_table_collection_free; /// A table collection. /// @@ -69,9 +58,15 @@ fn new_tsk_table_collection_t() -> Result, + inner: Box, } +build_tskit_type!( + TableCollection, + ll_bindings::tsk_table_collection_t, + tsk_table_collection_free +); + impl TableCollection { /// Create a new table collection with a sequence length. pub fn new(sequence_length: f64) -> Result { @@ -81,16 +76,13 @@ impl TableCollection { expected: "sequence_length >= 0.0".to_string(), }); } - let tables = new_tsk_table_collection_t(); - match tables { - Ok(_) => (), - Err(e) => return Err(e), + let mut tables = Self::wrap(); + let rv = unsafe { ll_bindings::tsk_table_collection_init(tables.as_mut_ptr(), 0) }; + if rv < 0 { + return Err(crate::error::TskitRustError::ErrorCode { code: rv }); } - let mut rv = TableCollection { - tables: tables.unwrap(), - }; - rv.tables.sequence_length = sequence_length; - Ok(rv) + tables.inner.sequence_length = sequence_length; + Ok(tables) } /// Load a table collection from a file. @@ -119,16 +111,6 @@ impl TableCollection { } } - /// Access to raw C pointer as const tsk_table_collection_t *. - pub fn as_ptr(&self) -> *const ll_bindings::tsk_table_collection_t { - &*self.tables - } - - /// Access to raw C pointer as tsk_table_collection_t *. - pub fn as_mut_ptr(&mut self) -> *mut ll_bindings::tsk_table_collection_t { - &mut *self.tables - } - /// Length of the sequence/"genome". pub fn sequence_length(&self) -> f64 { unsafe { (*self.as_ptr()).sequence_length } @@ -138,35 +120,35 @@ impl TableCollection { /// Lifetime of return value is tied to (this) /// parent object. pub fn edges<'a>(&'a self) -> EdgeTable<'a> { - EdgeTable::<'a>::new_from_table(&self.tables.edges) + EdgeTable::<'a>::new_from_table(&self.inner.edges) } /// Get reference to the [``NodeTable``](crate::NodeTable). /// Lifetime of return value is tied to (this) /// parent object. pub fn nodes<'a>(&'a self) -> NodeTable<'a> { - NodeTable::<'a>::new_from_table(&self.tables.nodes) + NodeTable::<'a>::new_from_table(&self.inner.nodes) } /// Get reference to the [``SiteTable``](crate::SiteTable). /// Lifetime of return value is tied to (this) /// parent object. pub fn sites<'a>(&'a self) -> SiteTable<'a> { - SiteTable::<'a>::new_from_table(&self.tables.sites) + SiteTable::<'a>::new_from_table(&self.inner.sites) } /// Get reference to the [``MutationTable``](crate::MutationTable). /// Lifetime of return value is tied to (this) /// parent object. pub fn mutations<'a>(&'a self) -> MutationTable<'a> { - MutationTable::<'a>::new_from_table(&self.tables.mutations) + MutationTable::<'a>::new_from_table(&self.inner.mutations) } /// Get reference to the [``PopulationTable``](crate::PopulationTable). /// Lifetime of return value is tied to (this) /// parent object. pub fn populations<'a>(&'a self) -> PopulationTable<'a> { - PopulationTable::<'a>::new_from_table(&self.tables.populations) + PopulationTable::<'a>::new_from_table(&self.inner.populations) } /// Add a row to the edge table @@ -347,13 +329,6 @@ impl TableCollection { } } -impl Drop for TableCollection { - fn drop(&mut self) { - let rv = unsafe { ll_bindings::tsk_table_collection_free(&mut *self.tables) }; - panic_on_tskit_error!(rv); - } -} - #[cfg(test)] mod test { use super::*;