Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 49 additions & 6 deletions src/_macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,15 @@ macro_rules! unsafe_tsk_column_access {

macro_rules! unsafe_tsk_ragged_column_access {
($i: expr, $lo: expr, $hi: expr, $array: expr, $offset_array: expr, $offset_array_len: expr) => {{
if $i < $lo || $i >= ($hi as tsk_id_t) {
use std::convert::TryFrom;
let i = crate::SizeType::try_from($i)?;
if $i < $lo || i >= $hi {
Err(TskitError::IndexError {})
} else if $offset_array_len == 0 {
Ok(None)
} else {
let start = unsafe { *$offset_array.offset($i as isize) };
let stop = if $i < ($hi as tsk_id_t) {
let stop = if i < $hi {
unsafe { *$offset_array.offset(($i + 1) as isize) }
} else {
$offset_array_len as tsk_size_t
Expand All @@ -70,13 +72,15 @@ macro_rules! unsafe_tsk_ragged_column_access {
}};

($i: expr, $lo: expr, $hi: expr, $array: expr, $offset_array: expr, $offset_array_len: expr, $output_id_type: expr) => {{
if $i < $lo || $i >= ($hi as tsk_id_t) {
use std::convert::TryFrom;
let i = crate::SizeType::try_from($i)?;
if $i < $lo || i >= $hi {
Err(TskitError::IndexError {})
} else if $offset_array_len == 0 {
Ok(None)
} else {
let start = unsafe { *$offset_array.offset($i as isize) };
let stop = if $i < ($hi as tsk_id_t) {
let stop = if i < $hi {
unsafe { *$offset_array.offset(($i + 1) as isize) }
} else {
$offset_array_len as tsk_size_t
Expand All @@ -99,13 +103,15 @@ macro_rules! unsafe_tsk_ragged_column_access {
#[allow(unused_macros)]
macro_rules! unsafe_tsk_ragged_char_column_access {
($i: expr, $lo: expr, $hi: expr, $array: expr, $offset_array: expr, $offset_array_len: expr) => {{
if $i < $lo || $i >= ($hi as tsk_id_t) {
use std::convert::TryFrom;
let i = crate::SizeType::try_from($i)?;
if $i < $lo || i >= $hi {
Err(TskitError::IndexError {})
} else if $offset_array_len == 0 {
Ok(None)
} else {
let start = unsafe { *$offset_array.offset($i as isize) };
let stop = if $i < ($hi as tsk_id_t) {
let stop = if i < $hi {
unsafe { *$offset_array.offset(($i + 1) as isize) }
} else {
$offset_array_len as tsk_size_t
Expand Down Expand Up @@ -299,6 +305,14 @@ macro_rules! impl_id_traits {
}
}

impl std::convert::TryFrom<$idtype> for crate::SizeType {
type Error = crate::TskitError;

fn try_from(value: $idtype) -> Result<Self, Self::Error> {
crate::SizeType::try_from(value.0)
}
}

impl PartialEq<$crate::tsk_id_t> for $idtype {
fn eq(&self, other: &$crate::tsk_id_t) -> bool {
self.0 == *other
Expand All @@ -325,6 +339,35 @@ macro_rules! impl_id_traits {
};
}

macro_rules! impl_size_type_comparisons_for_row_ids {
($idtype: ty) => {
impl PartialEq<$idtype> for crate::SizeType {
fn eq(&self, other: &$idtype) -> bool {
self.0 == other.0 as crate::bindings::tsk_size_t
}
}

impl PartialEq<crate::SizeType> for $idtype {
fn eq(&self, other: &crate::SizeType) -> bool {
(self.0 as crate::bindings::tsk_size_t) == other.0
}
}

impl PartialOrd<$idtype> for crate::SizeType {
fn partial_cmp(&self, other: &$idtype) -> Option<std::cmp::Ordering> {
self.0
.partial_cmp(&(other.0 as crate::bindings::tsk_size_t))
}
}

impl PartialOrd<crate::SizeType> for $idtype {
fn partial_cmp(&self, other: &crate::SizeType) -> Option<std::cmp::Ordering> {
(self.0 as crate::bindings::tsk_size_t).partial_cmp(&other.0)
}
}
};
}

/// Convenience macro to handle implementing
/// [`crate::metadata::MetadataRoundtrip`]
#[macro_export]
Expand Down
19 changes: 14 additions & 5 deletions src/edge_table.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::bindings as ll_bindings;
use crate::metadata;
use crate::{tsk_id_t, tsk_size_t, TskitError};
use crate::{tsk_id_t, TskitError};
use crate::{EdgeId, NodeId};

/// Row of an [`EdgeTable`]
Expand All @@ -25,7 +25,12 @@ impl PartialEq for EdgeTableRow {
}

fn make_edge_table_row(table: &EdgeTable, pos: tsk_id_t) -> Option<EdgeTableRow> {
if pos < table.num_rows() as tsk_id_t {
use std::convert::TryFrom;
// panic is okay here, as we are handling a bad
// input value before we first call this to
// set up the iterator
let p = crate::SizeType::try_from(pos).unwrap();
if p < table.num_rows() {
let rv = EdgeTableRow {
id: pos.into(),
left: table.left(pos).unwrap(),
Expand Down Expand Up @@ -78,8 +83,8 @@ impl<'a> EdgeTable<'a> {
}

/// Return the number of rows
pub fn num_rows(&'a self) -> tsk_size_t {
self.table_.num_rows
pub fn num_rows(&'a self) -> crate::SizeType {
self.table_.num_rows.into()
}

/// Return the ``parent`` value from row ``row`` of the table.
Expand Down Expand Up @@ -147,6 +152,10 @@ impl<'a> EdgeTable<'a> {
///
/// [`TskitError::IndexError`] if `r` is out of range.
pub fn row<E: Into<EdgeId> + Copy>(&self, r: E) -> Result<EdgeTableRow, TskitError> {
table_row_access!(r.into().0, self, make_edge_table_row)
let ri = r.into();
if ri < 0 {
return Err(crate::TskitError::IndexError);
}
table_row_access!(ri.0, self, make_edge_table_row)
}
}
3 changes: 3 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ use thiserror::Error;

#[derive(Error, Debug)]
pub enum TskitError {
/// Returned when conversion attempts fail
#[error("range error: {}", *.0)]
RangeError(&'static str),
/// Used when bad input is encountered.
#[error("we received {} but expected {}",*got, *expected)]
ValueError { got: String, expected: String },
Expand Down
17 changes: 13 additions & 4 deletions src/individual_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,12 @@ pub struct IndividualTable<'a> {
}

fn make_individual_table_row(table: &IndividualTable, pos: tsk_id_t) -> Option<IndividualTableRow> {
if pos < table.num_rows() as tsk_id_t {
use std::convert::TryFrom;
// panic is okay here, as we are handling a bad
// input value before we first call this to
// set up the iterator
let p = crate::SizeType::try_from(pos).unwrap();
if p < table.num_rows() {
let rv = IndividualTableRow {
id: pos.into(),
flags: table.flags(pos).unwrap(),
Expand Down Expand Up @@ -96,8 +101,8 @@ impl<'a> IndividualTable<'a> {
}

/// Return the number of rows
pub fn num_rows(&'a self) -> ll_bindings::tsk_size_t {
self.table_.num_rows
pub fn num_rows(&'a self) -> crate::SizeType {
self.table_.num_rows.into()
}

/// Return the flags for a given row.
Expand Down Expand Up @@ -181,6 +186,10 @@ impl<'a> IndividualTable<'a> {
&self,
r: I,
) -> Result<IndividualTableRow, TskitError> {
table_row_access!(r.into().0, self, make_individual_table_row)
let ri = r.into();
if ri < 0 {
return Err(crate::TskitError::IndexError);
}
table_row_access!(ri.0, self, make_individual_table_row)
}
}
115 changes: 111 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,9 @@ pub use bindings::TSK_NODE_IS_SAMPLE;

// re-export types, too
pub use bindings::tsk_flags_t;
pub use bindings::tsk_id_t;
pub use bindings::tsk_size_t;

use bindings::tsk_id_t;
use bindings::tsk_size_t;

/// A node ID
///
Expand All @@ -116,7 +117,7 @@ pub use bindings::tsk_size_t;
///
/// ```
/// use tskit::NodeId;
/// use tskit::tsk_id_t;
/// use tskit::bindings::tsk_id_t;
///
/// let x: tsk_id_t = 1;
/// let y: NodeId = NodeId::from(x);
Expand All @@ -137,7 +138,7 @@ pub use bindings::tsk_size_t;
///
/// ```
/// use tskit::NodeId;
/// use tskit::tsk_id_t;
/// use tskit::bindings::tsk_id_t;
///
/// fn interesting<N: Into<NodeId>>(x: N) -> NodeId {
/// x.into()
Expand Down Expand Up @@ -226,6 +227,112 @@ impl_id_traits!(MutationId);
impl_id_traits!(MigrationId);
impl_id_traits!(EdgeId);

impl_size_type_comparisons_for_row_ids!(NodeId);
impl_size_type_comparisons_for_row_ids!(EdgeId);
impl_size_type_comparisons_for_row_ids!(SiteId);
impl_size_type_comparisons_for_row_ids!(MutationId);
impl_size_type_comparisons_for_row_ids!(PopulationId);
impl_size_type_comparisons_for_row_ids!(MigrationId);

/// Wraps `tsk_size_t`
///
/// This type plays the role of C's `size_t` in the `tskit` C library.
///
/// # Examples
///
/// ```
/// let s = tskit::SizeType::from(1 as tskit::bindings::tsk_size_t);
/// let mut t: tskit::bindings::tsk_size_t = s.into();
/// assert!(t == s);
/// assert!(t == 1);
/// let u = tskit::SizeType::from(s);
/// assert!(u == s);
/// t += 1;
/// assert!(t > s);
/// assert!(s < t);
/// ```
///
/// #[repr(transparent)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, std::hash::Hash)]
pub struct SizeType(tsk_size_t);

impl std::fmt::Display for SizeType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "SizeType({})", self.0)
}
}

impl From<tsk_size_t> for SizeType {
fn from(value: tsk_size_t) -> Self {
Self(value)
}
}

impl From<SizeType> for tsk_size_t {
fn from(value: SizeType) -> Self {
value.0
}
}

impl From<SizeType> for usize {
fn from(value: SizeType) -> Self {
value.0 as usize
}
}

impl From<usize> for SizeType {
fn from(value: usize) -> Self {
Self(value as tsk_size_t)
}
}

impl std::convert::TryFrom<tsk_id_t> for SizeType {
type Error = crate::TskitError;

fn try_from(value: tsk_id_t) -> Result<Self, Self::Error> {
match value >= 0 {
true => Ok(Self(value as crate::bindings::tsk_size_t)),
false => Err(crate::TskitError::RangeError(stringify!(value.0))),
}
}
}

impl std::convert::TryFrom<SizeType> for tsk_id_t {
type Error = crate::TskitError;

fn try_from(value: SizeType) -> Result<Self, Self::Error> {
if value.0 > tsk_id_t::MAX as tsk_size_t {
Err(TskitError::RangeError(stringify!(value.0)))
} else {
Ok(value.0 as tsk_id_t)
}
}
}

impl PartialEq<SizeType> for tsk_size_t {
fn eq(&self, other: &SizeType) -> bool {
*self == other.0
}
}

impl PartialEq<tsk_size_t> for SizeType {
fn eq(&self, other: &tsk_size_t) -> bool {
self.0 == *other
}
}

impl PartialOrd<tsk_size_t> for SizeType {
fn partial_cmp(&self, other: &tsk_size_t) -> Option<std::cmp::Ordering> {
self.0.partial_cmp(other)
}
}

impl PartialOrd<SizeType> for tsk_size_t {
fn partial_cmp(&self, other: &SizeType) -> Option<std::cmp::Ordering> {
self.partial_cmp(&other.0)
}
}

// tskit defines this via a type cast
// in a macro. bindgen thus misses it.
// See bindgen issue 316.
Expand Down
13 changes: 7 additions & 6 deletions src/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@
//! into `Python` via the `tskit` `Python API`.

use crate::bindings::{tsk_id_t, tsk_size_t};
use crate::SizeType;
use thiserror::Error;

#[cfg(feature = "derive")]
Expand Down Expand Up @@ -233,8 +234,8 @@ impl EncodedMetadata {
}
}

pub(crate) fn len(&self) -> tsk_size_t {
self.encoded.len() as tsk_size_t
pub(crate) fn len(&self) -> SizeType {
self.encoded.len().into()
}
}

Expand Down Expand Up @@ -344,8 +345,8 @@ mod tests {
let enc = EncodedMetadata::new(&f).unwrap();
let p = enc.as_ptr();
let mut d = vec![];
for i in 0..enc.len() {
d.push(unsafe { *p.add(i as usize) as u8 });
for i in 0..usize::from(enc.len()) {
d.push(unsafe { *p.add(i) as u8 });
}
let df = F::decode(&d).unwrap();
assert_eq!(f.x, df.x);
Expand Down Expand Up @@ -378,8 +379,8 @@ mod test_serde {
let enc = EncodedMetadata::new(&f).unwrap();
let p = enc.as_ptr();
let mut d = vec![];
for i in 0..enc.len() {
d.push(unsafe { *p.add(i as usize) as u8 });
for i in 0..usize::from(enc.len()) {
d.push(unsafe { *p.add(i) as u8 });
}
let df = F::decode(&d).unwrap();
assert_eq!(f.x, df.x);
Expand Down
Loading