16 changes: 16 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,18 @@

All notable changes to this project will be documented in this file.

## [0.17.2] - 2025-11-27

### Bug Fixes

- Handle backpressure in async zarr storage (Adrian Seyboldt)


### Miscellaneous Tasks

- Update pulp dependency (Adrian Seyboldt)


## [0.17.1] - 2025-11-13

### Bug Fixes
@@ -20,6 +32,10 @@ All notable changes to this project will be documented in this file.

- Update dependencies (Adrian Seyboldt)

- Bump version (Adrian Seyboldt)

- Bump nuts-storable version (Adrian Seyboldt)


## [0.17.0] - 2025-10-08

6 changes: 3 additions & 3 deletions Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "nuts-rs"
version = "0.17.1"
version = "0.17.2"
authors = [
"Adrian Seyboldt <adrian.seyboldt@gmail.com>",
"PyMC Developers <pymc.devs@gmail.com>",
@@ -25,7 +25,7 @@ thiserror = "2.0.3"
rand_chacha = "0.9.0"
anyhow = "1.0.72"
faer = { version = "0.23.2", default-features = false, features = ["linalg"] }
pulp = "0.21.4"
pulp = "0.22.2"
rayon = "1.10.0"
zarrs = { version = "0.22.0", features = [
"filesystem",
@@ -42,7 +42,7 @@ nuts-derive = { path = "./nuts-derive", version = "0.1.0" }
nuts-storable = { path = "./nuts-storable", version = "0.2.0" }
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0"
tokio = { version = "1.0", features = ["rt"], optional = true }
tokio = { version = "1.0", features = ["rt", "sync", "fs"], optional = true }

[dev-dependencies]
proptest = "1.6.0"
133 changes: 108 additions & 25 deletions src/storage/zarr/async_impl.rs
@@ -2,7 +2,8 @@ use std::collections::HashMap;
use std::iter::once;
use std::num::NonZero;
use std::sync::Arc;
use tokio::task::JoinHandle;
use tokio::runtime::Handle;
use tokio::task::JoinSet;

use anyhow::{Context, Result};
use nuts_storable::{ItemType, Value};
@@ -43,8 +44,9 @@ pub struct ZarrAsyncChainStorage {
arrays: Arc<ArrayCollection>,
chain: u64,
last_sample_was_warmup: bool,
pending_writes: Vec<JoinHandle<Result<()>>>,
pending_writes: Arc<tokio::sync::Mutex<JoinSet<Result<()>>>>,
rt_handle: tokio::runtime::Handle,
max_queued_writes: usize,
}

/// Write a chunk of data to a Zarr array asynchronously
@@ -240,22 +242,28 @@ impl ZarrAsyncChainStorage {
chain: u64,
rt_handle: tokio::runtime::Handle,
) -> Self {
let draw_buffers = draw_types
let draw_buffers: HashMap<String, SampleBuffer> = draw_types
.iter()
.map(|(name, item_type)| (name.clone(), SampleBuffer::new(*item_type, buffer_size)))
.collect();

let stats_buffers = param_types
let stats_buffers: HashMap<String, SampleBuffer> = param_types
.iter()
.map(|(name, item_type)| (name.clone(), SampleBuffer::new(*item_type, buffer_size)))
.collect();

let num_arrays = draw_buffers.len() + stats_buffers.len();

Self {
draw_buffers,
stats_buffers,
arrays,
chain,
last_sample_was_warmup: true,
pending_writes: Vec::new(),
pending_writes: Arc::new(tokio::sync::Mutex::new(JoinSet::new())),
// We allow as many pending writes as there are arrays, so
// that each draw can queue one write per array.
max_queued_writes: num_arrays.max(1),
rt_handle,
}
}
@@ -275,10 +283,15 @@ impl ZarrAsyncChainStorage {
self.arrays.sample_param_arrays[name].clone()
};
let chain = self.chain;
let handle = self
.rt_handle
.spawn(async move { store_zarr_chunk_async(array, chunk, chain).await });
self.pending_writes.push(handle);

queue_write(
&self.rt_handle,
self.pending_writes.clone(),
self.max_queued_writes,
array,
chunk,
chain,
)?;
}
Ok(())
}
@@ -298,15 +311,57 @@ impl ZarrAsyncChainStorage {
self.arrays.sample_draw_arrays[name].clone()
};
let chain = self.chain;
let handle = self
.rt_handle
.spawn(async move { store_zarr_chunk_async(array, chunk, chain).await });
self.pending_writes.push(handle);

queue_write(
&self.rt_handle,
self.pending_writes.clone(),
self.max_queued_writes,
array,
chunk,
chain,
)?;
}
Ok(())
}
}

fn queue_write(
handle: &Handle,
queue: Arc<tokio::sync::Mutex<JoinSet<Result<()>>>>,
max_queued_writes: usize,
array: Array,
chunk: Chunk,
chain: u64,
) -> Result<()> {
let rt_handle = handle.clone();
// We need an async task to interface with the async storage
// and JoinSet API.
let spawn_write_task = handle.spawn(async move {
// This should never actually block, because the lock is only
// held by tasks that the sampling thread spawns and immediately
// blocks on.
let mut writes_guard = queue.lock().await;

while writes_guard.len() >= max_queued_writes {
let out = writes_guard.join_next().await;
if let Some(out) = out {
out.context("Failed to await previous trace write operation")?
.context("Chunk write operation failed")?;
} else {
break;
}
}
writes_guard.spawn_on(
async move { store_zarr_chunk_async(array, chunk, chain).await },
&rt_handle,
);
Ok(())
});
let res: Result<()> = handle.block_on(spawn_write_task)?;
res?;
Ok(())
}
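
Aside: `queue_write` drives async work from the synchronous sampling thread by spawning a task on the runtime and immediately blocking on its JoinHandle. A minimal, self-contained sketch of that pattern (all names here are illustrative, not from this PR; assumes tokio with the "rt-multi-thread" feature):

use std::thread;

use anyhow::Result;
use tokio::runtime::Runtime;

fn main() -> Result<()> {
    let rt = Runtime::new()?;
    let handle = rt.handle().clone();

    // A plain OS thread, standing in for the sampling thread: it submits
    // async work to the runtime and blocks until that work completes.
    let worker = thread::spawn(move || -> Result<()> {
        let task = handle.spawn(async { Ok::<_, anyhow::Error>(21 * 2) });
        // Handle::block_on is fine here because this thread is not a
        // runtime worker thread.
        let value = handle.block_on(task)??;
        assert_eq!(value, 42);
        Ok(())
    });
    worker.join().expect("worker thread panicked")
}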

impl ChainStorage for ZarrAsyncChainStorage {
type Finalized = ();

@@ -323,20 +378,30 @@ impl ChainStorage for ZarrAsyncChainStorage {
if let Some(chunk) = buffer.reset() {
let array = self.arrays.warmup_draw_arrays[key].clone();
let chain = self.chain;
let handle = self
.rt_handle
.spawn(async move { store_zarr_chunk_async(array, chunk, chain).await });
self.pending_writes.push(handle);

queue_write(
&self.rt_handle,
self.pending_writes.clone(),
self.max_queued_writes,
array,
chunk,
chain,
)?;
}
}
for (key, buffer) in self.stats_buffers.iter_mut() {
if let Some(chunk) = buffer.reset() {
let array = self.arrays.warmup_param_arrays[key].clone();
let chain = self.chain;
let handle = self
.rt_handle
.spawn(async move { store_zarr_chunk_async(array, chunk, chain).await });
self.pending_writes.push(handle);

queue_write(
&self.rt_handle,
self.pending_writes.clone(),
self.max_queued_writes,
array,
chunk,
chain,
)?;
}
}
self.last_sample_was_warmup = false;
@@ -382,11 +447,14 @@
}

// Join all pending writes
// All tasks that hold a reference to the queue are blocked on
// right away, so we hold the only reference to `self.pending_writes`.
let pending_writes = Arc::into_inner(self.pending_writes)
.expect("Could not take ownership of pending writes queue")
.into_inner();
self.rt_handle.block_on(async move {
for join_handle in self.pending_writes {
let _ = join_handle
.await
.context("Failed to await async chunk write operation")?;
for join_handle in pending_writes.join_all().await {
let _ = join_handle.context("Failed to await async chunk write operation")?;
}
Ok::<(), anyhow::Error>(())
})?;
@@ -420,6 +488,21 @@
}
}

// Join all pending writes
let pending_writes = self.pending_writes.clone();
self.rt_handle.block_on(async move {
let mut pending_writes = pending_writes.lock().await;
loop {
let Some(join_handle) = pending_writes.join_next().await else {
break;
};
join_handle
.context("Failed to await async chunk write operation")?
.context("Chunk write operation failed")?;
}
Ok::<(), anyhow::Error>(())
})?;

Ok(())
}
}
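
For reference, the backpressure mechanism itself boils down to capping a JoinSet at a fixed number of in-flight tasks and draining completed ones before spawning more. A minimal sketch of that idea (the names `bounded_spawn` and `max_queued` are illustrative, not part of this PR; assumes tokio with the "macros", "rt-multi-thread", and "time" features):

use std::time::Duration;

use anyhow::{Context, Result};
use tokio::task::JoinSet;

// Wait until the set has capacity, surfacing errors from finished
// tasks, then spawn the new task.
async fn bounded_spawn<F>(set: &mut JoinSet<Result<()>>, max_queued: usize, task: F) -> Result<()>
where
    F: std::future::Future<Output = Result<()>> + Send + 'static,
{
    while set.len() >= max_queued {
        match set.join_next().await {
            Some(res) => res.context("write task panicked or was cancelled")??,
            None => break,
        }
    }
    set.spawn(task);
    Ok(())
}

#[tokio::main]
async fn main() -> Result<()> {
    let mut set = JoinSet::new();
    for i in 0..16u64 {
        // At most 4 writes are ever in flight; the 5th call waits.
        bounded_spawn(&mut set, 4, async move {
            tokio::time::sleep(Duration::from_millis(i % 3)).await;
            Ok(())
        })
        .await?;
    }
    // Drain whatever is still pending, as finalize does above.
    while let Some(res) = set.join_next().await {
        res.context("write task panicked or was cancelled")??;
    }
    Ok(())
}

The PR additionally wraps its JoinSet in an Arc<tokio::sync::Mutex<...>> so the set can be shared with tasks spawned on the runtime; the sketch keeps a single owner for brevity.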