diff --git a/CHANGELOG.md b/CHANGELOG.md
index 67ce4f3..b5d617a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,18 @@
 
 All notable changes to this project will be documented in this file.
 
+## [0.17.2] - 2025-11-27
+
+### Bug Fixes
+
+- Handle backpressure in async zarr storage (Adrian Seyboldt)
+
+
+### Miscellaneous Tasks
+
+- Update pulp dependency (Adrian Seyboldt)
+
+
 ## [0.17.1] - 2025-11-13
 
 ### Bug Fixes
@@ -20,6 +32,10 @@ All notable changes to this project will be documented in this file.
 
 - Update dependencies (Adrian Seyboldt)
 
+- Bump version (Adrian Seyboldt)
+
+- Bump nuts-storable version (Adrian Seyboldt)
+
 ## [0.17.0] - 2025-10-08
 
 
diff --git a/Cargo.toml b/Cargo.toml
index 6148b3e..6aa9f1a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "nuts-rs"
-version = "0.17.1"
+version = "0.17.2"
 authors = [
     "Adrian Seyboldt ",
     "PyMC Developers ",
@@ -25,7 +25,7 @@ thiserror = "2.0.3"
 rand_chacha = "0.9.0"
 anyhow = "1.0.72"
 faer = { version = "0.23.2", default-features = false, features = ["linalg"] }
-pulp = "0.21.4"
+pulp = "0.22.2"
 rayon = "1.10.0"
 zarrs = { version = "0.22.0", features = [
     "filesystem",
@@ -42,7 +42,7 @@ nuts-derive = { path = "./nuts-derive", version = "0.1.0" }
 nuts-storable = { path = "./nuts-storable", version = "0.2.0" }
 serde = { version = "1.0.219", features = ["derive"] }
 serde_json = "1.0"
-tokio = { version = "1.0", features = ["rt"], optional = true }
+tokio = { version = "1.0", features = ["rt", "sync", "fs"], optional = true }
 
 [dev-dependencies]
 proptest = "1.6.0"
diff --git a/src/storage/zarr/async_impl.rs b/src/storage/zarr/async_impl.rs
index 69a59d3..77c8ca0 100644
--- a/src/storage/zarr/async_impl.rs
+++ b/src/storage/zarr/async_impl.rs
@@ -2,7 +2,8 @@ use std::collections::HashMap;
 use std::iter::once;
 use std::num::NonZero;
 use std::sync::Arc;
-use tokio::task::JoinHandle;
+use tokio::runtime::Handle;
+use tokio::task::JoinSet;
 
 use anyhow::{Context, Result};
 use nuts_storable::{ItemType, Value};
@@ -43,8 +44,9 @@ pub struct ZarrAsyncChainStorage {
     arrays: Arc,
     chain: u64,
     last_sample_was_warmup: bool,
-    pending_writes: Vec<JoinHandle<Result<()>>>,
+    pending_writes: Arc<tokio::sync::Mutex<JoinSet<Result<()>>>>,
     rt_handle: tokio::runtime::Handle,
+    max_queued_writes: usize,
 }
 
 /// Write a chunk of data to a Zarr array asynchronously
@@ -240,22 +242,28 @@ impl ZarrAsyncChainStorage {
         chain: u64,
         rt_handle: tokio::runtime::Handle,
     ) -> Self {
-        let draw_buffers = draw_types
+        let draw_buffers: HashMap<String, SampleBuffer> = draw_types
             .iter()
             .map(|(name, item_type)| (name.clone(), SampleBuffer::new(*item_type, buffer_size)))
             .collect();
 
-        let stats_buffers = param_types
+        let stats_buffers: HashMap<String, SampleBuffer> = param_types
             .iter()
             .map(|(name, item_type)| (name.clone(), SampleBuffer::new(*item_type, buffer_size)))
             .collect();
+
+        let num_arrays = draw_buffers.len() + stats_buffers.len();
+
         Self {
             draw_buffers,
             stats_buffers,
             arrays,
             chain,
             last_sample_was_warmup: true,
-            pending_writes: Vec::new(),
+            pending_writes: Arc::new(tokio::sync::Mutex::new(JoinSet::new())),
+            // We allow as many pending writes as there are arrays, so that
+            // we can queue one write per array for each draw.
+            max_queued_writes: num_arrays.max(1),
             rt_handle,
         }
     }
@@ -275,10 +283,15 @@ impl ZarrAsyncChainStorage {
                 self.arrays.sample_param_arrays[name].clone()
             };
             let chain = self.chain;
-            let handle = self
-                .rt_handle
-                .spawn(async move { store_zarr_chunk_async(array, chunk, chain).await });
-            self.pending_writes.push(handle);
+
+            queue_write(
+                &self.rt_handle,
+                self.pending_writes.clone(),
+                self.max_queued_writes,
+                array,
+                chunk,
+                chain,
+            )?;
         }
         Ok(())
     }
@@ -298,15 +311,57 @@ impl ZarrAsyncChainStorage {
                 self.arrays.sample_draw_arrays[name].clone()
            };
             let chain = self.chain;
-            let handle = self
-                .rt_handle
-                .spawn(async move { store_zarr_chunk_async(array, chunk, chain).await });
-            self.pending_writes.push(handle);
+
+            queue_write(
+                &self.rt_handle,
+                self.pending_writes.clone(),
+                self.max_queued_writes,
+                array,
+                chunk,
+                chain,
+            )?;
         }
         Ok(())
     }
 }
 
+fn queue_write(
+    handle: &Handle,
+    queue: Arc<tokio::sync::Mutex<JoinSet<Result<()>>>>,
+    max_queued_writes: usize,
+    array: Array,
+    chunk: Chunk,
+    chain: u64,
+) -> Result<()> {
+    let rt_handle = handle.clone();
+    // We need an async task to interface with the async storage
+    // and JoinSet API.
+    let spawn_write_task = handle.spawn(async move {
+        // This should never actually block, because this lock
+        // is only held in tasks that are spawned and immediately blocked_on
+        // from the sampling thread.
+        let mut writes_guard = queue.lock().await;
+
+        while writes_guard.len() >= max_queued_writes {
+            let out = writes_guard.join_next().await;
+            if let Some(out) = out {
+                out.context("Failed to await previous trace write operation")?
+                    .context("Chunk write operation failed")?;
+            } else {
+                break;
+            }
+        }
+        writes_guard.spawn_on(
+            async move { store_zarr_chunk_async(array, chunk, chain).await },
+            &rt_handle,
+        );
+        Ok(())
+    });
+    let res: Result<()> = handle.block_on(spawn_write_task)?;
+    res?;
+    Ok(())
+}
+
 impl ChainStorage for ZarrAsyncChainStorage {
     type Finalized = ();
 
@@ -323,20 +378,30 @@ impl ChainStorage for ZarrAsyncChainStorage {
             if let Some(chunk) = buffer.reset() {
                 let array = self.arrays.warmup_draw_arrays[key].clone();
                 let chain = self.chain;
-                let handle = self
-                    .rt_handle
-                    .spawn(async move { store_zarr_chunk_async(array, chunk, chain).await });
-                self.pending_writes.push(handle);
+
+                queue_write(
+                    &self.rt_handle,
+                    self.pending_writes.clone(),
+                    self.max_queued_writes,
+                    array,
+                    chunk,
+                    chain,
+                )?;
             }
         }
         for (key, buffer) in self.stats_buffers.iter_mut() {
             if let Some(chunk) = buffer.reset() {
                 let array = self.arrays.warmup_param_arrays[key].clone();
                 let chain = self.chain;
-                let handle = self
-                    .rt_handle
-                    .spawn(async move { store_zarr_chunk_async(array, chunk, chain).await });
-                self.pending_writes.push(handle);
+
+                queue_write(
+                    &self.rt_handle,
+                    self.pending_writes.clone(),
+                    self.max_queued_writes,
+                    array,
+                    chunk,
+                    chain,
+                )?;
             }
         }
         self.last_sample_was_warmup = false;
@@ -382,11 +447,14 @@ impl ChainStorage for ZarrAsyncChainStorage {
         }
 
         // Join all pending writes
+        // All tasks that hold a reference to the queue are blocked_on
+        // right away, so we hold the only reference to `self.pending_writes`.
+        let pending_writes = Arc::into_inner(self.pending_writes)
+            .expect("Could not take ownership of pending writes queue")
+            .into_inner();
         self.rt_handle.block_on(async move {
-            for join_handle in self.pending_writes {
-                let _ = join_handle
-                    .await
-                    .context("Failed to await async chunk write operation")?;
+            for join_handle in pending_writes.join_all().await {
+                let _ = join_handle.context("Failed to await async chunk write operation")?;
             }
             Ok::<(), anyhow::Error>(())
         })?;
@@ -420,6 +488,21 @@ impl ChainStorage for ZarrAsyncChainStorage {
             }
         }
 
+        // Join all pending writes
+        let pending_writes = self.pending_writes.clone();
+        self.rt_handle.block_on(async move {
+            let mut pending_writes = pending_writes.lock().await;
+            loop {
+                let Some(join_handle) = pending_writes.join_next().await else {
+                    break;
+                };
+                join_handle
+                    .context("Failed to await async chunk write operation")?
+                    .context("Chunk write operation failed")?;
+            }
+            Ok::<(), anyhow::Error>(())
+        })?;
+
         Ok(())
     }
 }
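
For readers of the patch, the backpressure fix amounts to capping a `tokio::task::JoinSet`: before a new chunk write is spawned, previously queued writes are awaited with `join_next` until the set is below the limit. The sketch below shows that pattern in isolation, under simplified assumptions: the names `write_all_chunks`, `max_in_flight`, and `chunk_id` are illustrative only, and the spawned body is a placeholder for `store_zarr_chunk_async`. In the patch itself the `JoinSet` additionally sits behind an `Arc<tokio::sync::Mutex<...>>` and is driven via `Handle::block_on`, because the sampling thread is synchronous rather than async.

use anyhow::Result;
use tokio::task::JoinSet;

/// Minimal sketch of the capped-JoinSet backpressure pattern used by `queue_write`.
async fn write_all_chunks(num_chunks: u64, max_in_flight: usize) -> Result<()> {
    let mut writes: JoinSet<Result<()>> = JoinSet::new();
    for chunk_id in 0..num_chunks {
        // Backpressure: wait for older writes once the cap is reached.
        while writes.len() >= max_in_flight {
            match writes.join_next().await {
                Some(done) => done??, // surface join errors and write errors
                None => break,
            }
        }
        // Placeholder for the real chunk write (store_zarr_chunk_async in the patch).
        writes.spawn(async move {
            let _ = chunk_id;
            Ok(())
        });
    }
    // Drain the remaining writes before returning.
    while let Some(done) = writes.join_next().await {
        done??;
    }
    Ok(())
}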