From 1b5988587600ba00a7a2a7f33af30b02423ff467 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 12 Nov 2025 16:02:05 +0100 Subject: [PATCH 1/5] chore: update dependencies --- Cargo.toml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ed7ebd9..059892d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,9 +33,9 @@ zarrs = { version = "0.22.0", features = [ "sharding", "async", ], optional = true } -ndarray = { version = "0.16.1", optional = true } -arrow = { version = "56.2.0", optional = true } -arrow-schema = { version = "56.2.0", features = [ +ndarray = { version = "0.17.1", optional = true } +arrow = { version = "57.0.0", optional = true } +arrow-schema = { version = "57.0.0", features = [ "canonical_extension_types", ], optional = true } nuts-derive = { path = "./nuts-derive", version = "0.1.0" } @@ -50,9 +50,9 @@ pretty_assertions = "1.4.0" criterion = "0.7.0" nix = { version = "0.30.0", features = ["sched"] } approx = "0.5.1" -equator = "0.4.2" +equator = "0.4.0" serde_json = "1.0" -ndarray = "0.16.1" +ndarray = "0.17.1" tempfile = "3.0" zarrs_object_store = "0.5.0" object_store = "0.12.0" From 494769e125901aeba80f676aceb8c7bc24c584b1 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 15 Oct 2025 17:36:34 +0200 Subject: [PATCH 2/5] fix: store step size info in transform_adapt_strategy --- src/transform_adapt_strategy.rs | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/transform_adapt_strategy.rs b/src/transform_adapt_strategy.rs index 6cffd93..e109536 100644 --- a/src/transform_adapt_strategy.rs +++ b/src/transform_adapt_strategy.rs @@ -1,11 +1,12 @@ use nuts_derive::Storable; +use nuts_storable::{HasDims, Storable}; use serde::Serialize; use crate::adapt_strategy::CombinedCollector; use crate::chain::AdaptStrategy; use crate::hamiltonian::{Hamiltonian, Point}; use crate::nuts::{Collector, NutsOptions, SampleInfo}; -use crate::sampler_stats::SamplerStats; +use crate::sampler_stats::{SamplerStats, StatsDims}; use crate::state::State; use crate::stepsize::AcceptanceRateCollector; use crate::stepsize::{StepSizeSettings, Strategy as StepSizeStrategy}; @@ -43,17 +44,23 @@ pub struct TransformAdaptation { } #[derive(Debug, Storable)] -pub struct Stats { +pub struct Stats> { tuning: bool, + #[storable(flatten)] + pub step_size: S, + #[storable(ignore)] + _phantom: std::marker::PhantomData P>, } impl SamplerStats for TransformAdaptation { - type Stats = Stats; + type Stats = Stats>::Stats>; type StatsOptions = (); - fn extract_stats(&self, _math: &mut M, _opt: Self::StatsOptions) -> Self::Stats { + fn extract_stats(&self, math: &mut M, _opt: Self::StatsOptions) -> Self::Stats { Stats { tuning: self.tuning, + step_size: { self.step_size.extract_stats(math, ()) }, + _phantom: std::marker::PhantomData, } } } From 18cd93799398a0c04d3d363c3cb7e7d313772b5a Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Mon, 10 Nov 2025 12:07:32 +0100 Subject: [PATCH 3/5] feat: support datetime coordinates --- nuts-storable/src/lib.rs | 12 ++++++++++ src/lib.rs | 2 +- src/storage/arrow.rs | 24 +++++++++++++++++++ src/storage/csv.rs | 2 ++ src/storage/hashmap.rs | 6 +++++ src/storage/ndarray.rs | 3 +++ src/storage/zarr/async_impl.rs | 44 ++++++++++++++++++++++++++++++++++ src/storage/zarr/common.rs | 25 ++++++++++++++++++- src/storage/zarr/sync_impl.rs | 41 ++++++++++++++++++++++++++++++- 9 files changed, 156 insertions(+), 3 deletions(-) diff --git a/nuts-storable/src/lib.rs b/nuts-storable/src/lib.rs index 032cb0c..67fe13c 100644 --- a/nuts-storable/src/lib.rs +++ b/nuts-storable/src/lib.rs @@ -1,5 +1,13 @@ use std::collections::HashMap; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DateTimeUnit { + Seconds, + Milliseconds, + Microseconds, + Nanoseconds, +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ItemType { U64, @@ -8,6 +16,8 @@ pub enum ItemType { F32, Bool, String, + DateTime64(DateTimeUnit), + TimeDelta64(DateTimeUnit), } #[derive(Debug, Clone, PartialEq)] @@ -18,6 +28,8 @@ pub enum Value { F32(Vec), Bool(Vec), ScalarString(String), + DateTime64(DateTimeUnit, Vec), + TimeDelta64(DateTimeUnit, Vec), ScalarU64(u64), ScalarI64(i64), ScalarF64(f64), diff --git a/src/lib.rs b/src/lib.rs index 0a2a61b..637aabe 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -117,7 +117,7 @@ mod transform_adapt_strategy; mod transformed_hamiltonian; pub use nuts_derive::Storable; -pub use nuts_storable::{HasDims, ItemType, Storable, Value}; +pub use nuts_storable::{DateTimeUnit, HasDims, ItemType, Storable, Value}; pub use adapt_strategy::EuclideanAdaptOptions; pub use chain::Chain; diff --git a/src/storage/arrow.rs b/src/storage/arrow.rs index c1aa571..da0845b 100644 --- a/src/storage/arrow.rs +++ b/src/storage/arrow.rs @@ -32,6 +32,12 @@ impl ArrowBuilder { ItemType::I64 => Box::new(Int64Builder::with_capacity(capacity)), ItemType::U64 => Box::new(UInt64Builder::with_capacity(capacity)), ItemType::String => Box::new(StringBuilder::with_capacity(capacity, capacity)), + ItemType::DateTime64(_) => { + panic!("DateTime values not supported as values in arrow storage") + } + ItemType::TimeDelta64(_) => { + panic!("TimeDelta values not supported as values in arrow storage") + } }; if shape.is_empty() { @@ -100,6 +106,12 @@ impl ArrowBuilder { string_builder.append_value(&item); } } + Value::DateTime64(_, _) => { + panic!("DateTime64 scalar values not supported in arrow storage") + } + Value::TimeDelta64(_, _) => { + panic!("TimeDelta64 scalar values not supported in arrow storage") + } }, ArrowBuilder::Tensor(list_builder) => { match value { @@ -154,6 +166,12 @@ impl ArrowBuilder { downcast_builder!(list_builder.values(), BooleanBuilder, ScalarBool)? .append_value(val); } + Value::DateTime64(_, _) => { + panic!("DateTime64 scalar values not supported in arrow storage") + } + Value::TimeDelta64(_, _) => { + panic!("TimeDelta64 scalar values not supported in arrow storage") + } } list_builder.append(true); } @@ -211,6 +229,12 @@ fn item_type_to_arrow_type(item_type: ItemType) -> DataType { ItemType::I64 => DataType::Int64, ItemType::Bool => DataType::Boolean, ItemType::String => DataType::Utf8, + ItemType::DateTime64(_) => { + panic!("DateTime64 scalar values not supported in arrow storage") + } + ItemType::TimeDelta64(_) => { + panic!("TimeDelta64 scalar values not supported in arrow storage") + } } } diff --git a/src/storage/csv.rs b/src/storage/csv.rs index 005e525..12ca9b8 100644 --- a/src/storage/csv.rs +++ b/src/storage/csv.rs @@ -220,6 +220,8 @@ impl CsvChainStorage { vec[0].clone() } } + Value::DateTime64(_, _) => panic!("DateTime64 not supported in CSV output"), + Value::TimeDelta64(_, _) => panic!("TimeDelta64 not supported in CSV output"), } } diff --git a/src/storage/hashmap.rs b/src/storage/hashmap.rs index 204b8d7..7364957 100644 --- a/src/storage/hashmap.rs +++ b/src/storage/hashmap.rs @@ -26,6 +26,7 @@ impl HashMapValue { ItemType::I64 => HashMapValue::I64(Vec::new()), ItemType::U64 => HashMapValue::U64(Vec::new()), ItemType::String => HashMapValue::String(Vec::new()), + ItemType::DateTime64(_) | ItemType::TimeDelta64(_) => HashMapValue::I64(Vec::new()), } } @@ -45,6 +46,11 @@ impl HashMapValue { (HashMapValue::Bool(vec), Value::Bool(v)) => vec.extend(v), (HashMapValue::I64(vec), Value::I64(v)) => vec.extend(v), + (HashMapValue::String(vec), Value::Strings(v)) => vec.extend(v), + (HashMapValue::String(vec), Value::ScalarString(v)) => vec.push(v), + (HashMapValue::I64(vec), Value::DateTime64(_, v)) => vec.extend(v), + (HashMapValue::I64(vec), Value::TimeDelta64(_, v)) => vec.extend(v), + _ => panic!("Mismatched item type"), } } diff --git a/src/storage/ndarray.rs b/src/storage/ndarray.rs index 39e2d93..2f9b65b 100644 --- a/src/storage/ndarray.rs +++ b/src/storage/ndarray.rs @@ -30,6 +30,9 @@ impl NdarrayValue { ItemType::String => { NdarrayValue::String(ArrayD::from_elem(IxDyn(shape), String::new())) } + ItemType::DateTime64(_) | ItemType::TimeDelta64(_) => { + NdarrayValue::I64(ArrayD::zeros(IxDyn(shape))) + } } } diff --git a/src/storage/zarr/async_impl.rs b/src/storage/zarr/async_impl.rs index a361aec..69a59d3 100644 --- a/src/storage/zarr/async_impl.rs +++ b/src/storage/zarr/async_impl.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; use std::iter::once; +use std::num::NonZero; use std::sync::Arc; use tokio::task::JoinHandle; @@ -8,6 +9,7 @@ use nuts_storable::{ItemType, Value}; use zarrs::array::{ArrayBuilder, DataType, FillValue}; use zarrs::array_subset::ArraySubset; use zarrs::group::GroupBuilder; +use zarrs::metadata_ext::data_type::NumpyTimeUnit; use zarrs::storage::{ AsyncReadableWritableListableStorage, AsyncReadableWritableListableStorageTraits, }; @@ -140,6 +142,38 @@ async fn store_coords( &Value::I64(ref v) => (DataType::Int64, v.len(), FillValue::from(0i64)), &Value::Bool(ref v) => (DataType::Bool, v.len(), FillValue::from(false)), &Value::Strings(ref v) => (DataType::String, v.len(), FillValue::from("")), + &Value::DateTime64(unit, ref v) => { + let unit = match unit { + nuts_storable::DateTimeUnit::Seconds => NumpyTimeUnit::Second, + nuts_storable::DateTimeUnit::Milliseconds => NumpyTimeUnit::Millisecond, + nuts_storable::DateTimeUnit::Microseconds => NumpyTimeUnit::Microsecond, + nuts_storable::DateTimeUnit::Nanoseconds => NumpyTimeUnit::Nanosecond, + }; + ( + DataType::NumpyDateTime64 { + unit, + scale_factor: NonZero::new(1).unwrap(), + }, + v.len(), + FillValue::from(0i64), + ) + } + &Value::TimeDelta64(unit, ref v) => { + let unit = match unit { + nuts_storable::DateTimeUnit::Seconds => NumpyTimeUnit::Second, + nuts_storable::DateTimeUnit::Milliseconds => NumpyTimeUnit::Millisecond, + nuts_storable::DateTimeUnit::Microseconds => NumpyTimeUnit::Microsecond, + nuts_storable::DateTimeUnit::Nanoseconds => NumpyTimeUnit::Nanosecond, + }; + ( + DataType::NumpyTimeDelta64 { + unit, + scale_factor: NonZero::new(1).unwrap(), + }, + v.len(), + FillValue::from(0i64), + ) + } _ => panic!("Unsupported coordinate type for {}", name), }; let name: &String = name; @@ -179,6 +213,16 @@ async fn store_coords( .async_store_chunk_elements::(&subset, v) .await? } + &Value::DateTime64(_, ref data) => { + coord_array + .async_store_chunk_elements::(&subset, data) + .await? + } + &Value::TimeDelta64(_, ref data) => { + coord_array + .async_store_chunk_elements::(&subset, data) + .await? + } _ => unreachable!(), } coord_array.async_store_metadata().await?; diff --git a/src/storage/zarr/common.rs b/src/storage/zarr/common.rs index 5e5d222..042b148 100644 --- a/src/storage/zarr/common.rs +++ b/src/storage/zarr/common.rs @@ -1,10 +1,11 @@ -use std::collections::HashMap; use std::mem::replace; use std::sync::Arc; +use std::{collections::HashMap, num::NonZero}; use anyhow::Result; use nuts_storable::{ItemType, Value}; use zarrs::array::{Array, ArrayBuilder, DataType, FillValue}; +use zarrs::metadata_ext::data_type::NumpyTimeUnit; /// Container for different types of sample values #[derive(Clone, Debug)] @@ -51,6 +52,8 @@ impl SampleBuffer { ItemType::Bool => SampleBufferValue::Bool(Vec::with_capacity(chunk_size)), ItemType::I64 => SampleBufferValue::I64(Vec::with_capacity(chunk_size)), ItemType::String => panic!("String type not supported in SampleBuffer"), + ItemType::DateTime64(_) => panic!("DateTime64 type not supported in SampleBuffer"), + ItemType::TimeDelta64(_) => panic!("TimeDelta64 type not supported in SampleBuffer"), }; Self { items: inner, @@ -196,6 +199,24 @@ pub fn create_arrays( ItemType::I64 => DataType::Int64, ItemType::Bool => DataType::Bool, ItemType::String => DataType::String, + ItemType::DateTime64(unit) => DataType::NumpyDateTime64 { + unit: match unit { + nuts_storable::DateTimeUnit::Seconds => NumpyTimeUnit::Second, + nuts_storable::DateTimeUnit::Milliseconds => NumpyTimeUnit::Millisecond, + nuts_storable::DateTimeUnit::Microseconds => NumpyTimeUnit::Microsecond, + nuts_storable::DateTimeUnit::Nanoseconds => NumpyTimeUnit::Nanosecond, + }, + scale_factor: NonZero::new(1).unwrap(), + }, + ItemType::TimeDelta64(unit) => DataType::NumpyTimeDelta64 { + unit: match unit { + nuts_storable::DateTimeUnit::Seconds => NumpyTimeUnit::Second, + nuts_storable::DateTimeUnit::Milliseconds => NumpyTimeUnit::Millisecond, + nuts_storable::DateTimeUnit::Microseconds => NumpyTimeUnit::Microsecond, + nuts_storable::DateTimeUnit::Nanoseconds => NumpyTimeUnit::Nanosecond, + }, + scale_factor: NonZero::new(1).unwrap(), + }, }; let fill_value = match item_type { ItemType::F64 => FillValue::from(f64::NAN), @@ -204,6 +225,8 @@ pub fn create_arrays( ItemType::I64 => FillValue::from(0i64), ItemType::Bool => FillValue::from(false), ItemType::String => FillValue::from(""), + ItemType::DateTime64(_) => FillValue::new_null(), + ItemType::TimeDelta64(_) => FillValue::new_null(), }; let grid: Vec = std::iter::once(1) .chain(std::iter::once(draw_chunk_size)) diff --git a/src/storage/zarr/sync_impl.rs b/src/storage/zarr/sync_impl.rs index 0966c65..106287d 100644 --- a/src/storage/zarr/sync_impl.rs +++ b/src/storage/zarr/sync_impl.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; use std::iter::once; +use std::num::NonZero; use std::sync::Arc; use anyhow::{Context, Result}; @@ -7,6 +8,7 @@ use nuts_storable::{ItemType, Value}; use zarrs::array::{ArrayBuilder, DataType, FillValue}; use zarrs::array_subset::ArraySubset; use zarrs::group::GroupBuilder; +use zarrs::metadata_ext::data_type::NumpyTimeUnit; use zarrs::storage::{ReadableWritableListableStorage, ReadableWritableListableStorageTraits}; use super::common::{Chunk, SampleBuffer, SampleBufferValue}; @@ -37,7 +39,38 @@ pub fn store_coords( &Value::I64(ref v) => (DataType::Int64, v.len(), FillValue::from(0i64)), &Value::Bool(ref v) => (DataType::Bool, v.len(), FillValue::from(false)), &Value::Strings(ref v) => (DataType::String, v.len(), FillValue::from("")), - _ => panic!("Unsupported coordinate type for {}", name), + &Value::DateTime64(ref unit, ref data) => ( + DataType::NumpyDateTime64 { + unit: match unit { + nuts_storable::DateTimeUnit::Seconds => NumpyTimeUnit::Second, + nuts_storable::DateTimeUnit::Milliseconds => NumpyTimeUnit::Millisecond, + nuts_storable::DateTimeUnit::Microseconds => NumpyTimeUnit::Microsecond, + nuts_storable::DateTimeUnit::Nanoseconds => NumpyTimeUnit::Nanosecond, + }, + scale_factor: NonZero::new(1).unwrap(), + }, + data.len(), + FillValue::new_null(), + ), + &Value::TimeDelta64(ref unit, ref data) => ( + DataType::NumpyTimeDelta64 { + unit: match unit { + nuts_storable::DateTimeUnit::Seconds => NumpyTimeUnit::Second, + nuts_storable::DateTimeUnit::Milliseconds => NumpyTimeUnit::Millisecond, + nuts_storable::DateTimeUnit::Microseconds => NumpyTimeUnit::Microsecond, + nuts_storable::DateTimeUnit::Nanoseconds => NumpyTimeUnit::Nanosecond, + }, + scale_factor: NonZero::new(1).unwrap(), + }, + data.len(), + FillValue::new_null(), + ), + _ => { + return Err(anyhow::anyhow!( + "Unsupported coordinate type for coordinate {}", + name + )); + } }; let name: &String = name; @@ -53,6 +86,12 @@ pub fn store_coords( &Value::I64(ref v) => coord_array.store_chunk_elements::(&subset, v)?, &Value::Bool(ref v) => coord_array.store_chunk_elements::(&subset, v)?, &Value::Strings(ref v) => coord_array.store_chunk_elements::(&subset, v)?, + &Value::DateTime64(_, ref data) => { + coord_array.store_chunk_elements::(&subset, data)? + } + &Value::TimeDelta64(_, ref data) => { + coord_array.store_chunk_elements::(&subset, data)? + } _ => unreachable!(), } coord_array.store_metadata()?; From 25f4d75b5106d5b2971e1915f5806c3145367349 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Mon, 10 Nov 2025 12:07:32 +0100 Subject: [PATCH 4/5] fix: mindepth when check_turning=True was misbehaving --- src/nuts.rs | 42 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/src/nuts.rs b/src/nuts.rs index 6c5af53..8cb94ed 100644 --- a/src/nuts.rs +++ b/src/nuts.rs @@ -257,6 +257,19 @@ pub struct NutsOptions { pub store_divergences: bool, } +impl Default for NutsOptions { + fn default() -> Self { + NutsOptions { + maxdepth: 10, + mindepth: 0, + store_gradient: false, + store_unconstrained: false, + check_turning: true, + store_divergences: false, + } + } +} + pub(crate) fn draw( math: &mut M, init: &mut State, @@ -282,18 +295,31 @@ where return Ok((init.clone(), info)); } + let options_no_check = NutsOptions { + check_turning: false, + ..*options + }; + while tree.depth < options.maxdepth { let direction: Direction = rng.random(); - tree = match tree.extend(math, rng, hamiltonian, direction, collector, options) { + let current_options = if tree.depth < options.mindepth { + &options_no_check + } else { + options + }; + tree = match tree.extend( + math, + rng, + hamiltonian, + direction, + collector, + current_options, + ) { ExtendResult::Ok(tree) => tree, ExtendResult::Turning(tree) => { - if tree.depth < options.mindepth { - tree - } else { - let info = tree.info(false, None); - collector.register_draw(math, &tree.draw, &info); - return Ok((tree.draw, info)); - } + let info = tree.info(false, None); + collector.register_draw(math, &tree.draw, &info); + return Ok((tree.draw, info)); } ExtendResult::Diverging(tree, info) => { let info = tree.info(false, Some(info)); From 8184846f2e40e5d794ab423c5249b87b5d9b54b0 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Mon, 10 Nov 2025 12:07:32 +0100 Subject: [PATCH 5/5] chore(release): bump version --- CHANGELOG.md | 23 +++++++++++++++++++++++ Cargo.toml | 2 +- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dcdde05..67ce4f3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,25 @@ All notable changes to this project will be documented in this file. +## [0.17.1] - 2025-11-13 + +### Bug Fixes + +- Store step size info in transform_adapt_strategy (Adrian Seyboldt) + +- Mindepth when check_turning=True was misbehaving (Adrian Seyboldt) + + +### Features + +- Support datetime coordinates (Adrian Seyboldt) + + +### Miscellaneous Tasks + +- Update dependencies (Adrian Seyboldt) + + ## [0.17.0] - 2025-10-08 ### Bug Fixes @@ -52,6 +71,10 @@ All notable changes to this project will be documented in this file. - Update dependencies (Adrian Seyboldt) +- Prepare 0.17.0 (Adrian Seyboldt) + +- Correctly specify dependencies in workspace (Adrian Seyboldt) + ### Performance diff --git a/Cargo.toml b/Cargo.toml index 059892d..ea99bfb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "nuts-rs" -version = "0.17.0" +version = "0.17.1" authors = [ "Adrian Seyboldt ", "PyMC Developers ",