From 5e670b2600173e3903de437c377fb020da28f411 Mon Sep 17 00:00:00 2001 From: Gabriel Musat Mestre Date: Sat, 4 Oct 2025 13:09:11 +0200 Subject: [PATCH] Use map_last_stream.rs in favor of callback_stream.rs --- src/common/callback_stream.rs | 84 ----------------------- src/common/map_last_stream.rs | 84 +++++++++++++++++++++++ src/common/mod.rs | 4 +- src/flight_service/do_get.rs | 126 ++++++---------------------------- 4 files changed, 106 insertions(+), 192 deletions(-) delete mode 100644 src/common/callback_stream.rs create mode 100644 src/common/map_last_stream.rs diff --git a/src/common/callback_stream.rs b/src/common/callback_stream.rs deleted file mode 100644 index 2edf97a..0000000 --- a/src/common/callback_stream.rs +++ /dev/null @@ -1,84 +0,0 @@ -use futures::Stream; -use pin_project::{pin_project, pinned_drop}; -use std::fmt::Display; -use std::pin::Pin; -use std::task::{Context, Poll}; - -/// The reason why the stream ended: -/// - [CallbackStreamEndReason::Finished] if it finished gracefully -/// - [CallbackStreamEndReason::Aborted] if it was abandoned. -#[derive(Debug)] -pub enum CallbackStreamEndReason { - /// The stream finished gracefully. - Finished, - /// The stream was abandoned. - Aborted, -} - -impl Display for CallbackStreamEndReason { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self) - } -} - -/// Stream that executes a callback when it is fully consumed or gets cancelled. -#[pin_project(PinnedDrop)] -pub struct CallbackStream -where - S: Stream, - F: FnOnce(CallbackStreamEndReason), -{ - #[pin] - stream: S, - callback: Option, -} - -impl Stream for CallbackStream -where - S: Stream, - F: FnOnce(CallbackStreamEndReason), -{ - type Item = S::Item; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - - match this.stream.poll_next(cx) { - Poll::Ready(None) => { - // Stream is fully consumed, execute the callback - if let Some(callback) = this.callback.take() { - callback(CallbackStreamEndReason::Finished); - } - Poll::Ready(None) - } - other => other, - } - } -} - -#[pinned_drop] -impl PinnedDrop for CallbackStream -where - S: Stream, - F: FnOnce(CallbackStreamEndReason), -{ - fn drop(self: Pin<&mut Self>) { - let this = self.project(); - if let Some(callback) = this.callback.take() { - callback(CallbackStreamEndReason::Aborted); - } - } -} - -/// Wrap a stream with a callback that will be executed when the stream is fully -/// consumed or gets canceled. -pub fn with_callback(stream: S, callback: F) -> CallbackStream -where - S: Stream, - F: FnOnce(CallbackStreamEndReason) + Send + 'static, -{ - CallbackStream { - stream, - callback: Some(callback), - } -} diff --git a/src/common/map_last_stream.rs b/src/common/map_last_stream.rs new file mode 100644 index 0000000..d0eb779 --- /dev/null +++ b/src/common/map_last_stream.rs @@ -0,0 +1,84 @@ +use futures::{Stream, StreamExt, stream}; +use std::task::Poll; + +/// Maps the last element of the provided stream. +pub(crate) fn map_last_stream( + mut input: impl Stream + Unpin, + map_f: impl FnOnce(T) -> T, +) -> impl Stream + Unpin { + let mut final_closure = Some(map_f); + + // this is used to peek the new value so that we can map upon emitting the last message + let mut current_value = None; + + stream::poll_fn(move |cx| match futures::ready!(input.poll_next_unpin(cx)) { + Some(new_val) => { + match current_value.take() { + // This is the first value, so we store it and repoll to get the next value + None => { + current_value = Some(new_val); + cx.waker().wake_by_ref(); + Poll::Pending + } + + Some(existing) => { + current_value = Some(new_val); + + Poll::Ready(Some(existing)) + } + } + } + // this is our last value, so we map it using the user provided closure + None => match current_value.take() { + Some(existing) => { + // make sure we wake ourselves to finish the stream + cx.waker().wake_by_ref(); + + if let Some(closure) = final_closure.take() { + Poll::Ready(Some(closure(existing))) + } else { + unreachable!("the closure is only executed once") + } + } + None => Poll::Ready(None), + }, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use futures::stream; + + #[tokio::test] + async fn test_map_last_stream_empty_stream() { + let input = stream::empty::(); + let mapped = map_last_stream(input, |x| x + 10); + let result: Vec = mapped.collect().await; + assert_eq!(result, Vec::::new()); + } + + #[tokio::test] + async fn test_map_last_stream_single_element() { + let input = stream::iter(vec![5]); + let mapped = map_last_stream(input, |x| x * 2); + let result: Vec = mapped.collect().await; + assert_eq!(result, vec![10]); + } + + #[tokio::test] + async fn test_map_last_stream_multiple_elements() { + let input = stream::iter(vec![1, 2, 3, 4]); + let mapped = map_last_stream(input, |x| x + 100); + let result: Vec = mapped.collect().await; + assert_eq!(result, vec![1, 2, 3, 104]); // Only the last element is transformed + } + + #[tokio::test] + async fn test_map_last_stream_preserves_order() { + let input = stream::iter(vec![10, 20, 30, 40, 50]); + let mapped = map_last_stream(input, |x| x - 50); + let result: Vec = mapped.collect().await; + assert_eq!(result, vec![10, 20, 30, 40, 0]); // Last element: 50 - 50 = 0 + } +} diff --git a/src/common/mod.rs b/src/common/mod.rs index 2cdf6ee..c0c7978 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,9 +1,9 @@ -mod callback_stream; mod composed_extension_codec; +mod map_last_stream; mod partitioning; #[allow(unused)] pub mod ttl_map; -pub(crate) use callback_stream::with_callback; pub(crate) use composed_extension_codec::ComposedPhysicalExtensionCodec; +pub(crate) use map_last_stream::map_last_stream; pub(crate) use partitioning::{scale_partitioning, scale_partitioning_props}; diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs index ebec452..8b754de 100644 --- a/src/flight_service/do_get.rs +++ b/src/flight_service/do_get.rs @@ -1,4 +1,4 @@ -use crate::common::with_callback; +use crate::common::map_last_stream; use crate::config_extension_ext::ContextGrpcMetadata; use crate::execution_plans::{DistributedTaskContext, StageExec}; use crate::flight_service::service::ArrowFlightEndpoint; @@ -17,15 +17,11 @@ use arrow_flight::flight_service_server::FlightService; use bytes::Bytes; use datafusion::common::exec_datafusion_err; -use datafusion::execution::SendableRecordBatchStream; -use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::prelude::SessionContext; -use futures::stream; -use futures::{StreamExt, TryStreamExt}; +use futures::TryStreamExt; use prost::Message; use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::task::Poll; use tonic::{Request, Response, Status}; #[derive(Clone, PartialEq, ::prost::Message)] @@ -132,30 +128,26 @@ impl ArrowFlightEndpoint { .execute(doget.target_partition as usize, session_state.task_ctx()) .map_err(|err| Status::internal(format!("Error executing stage plan: {err:#?}")))?; - let schema = stream.schema(); - - // TODO: We don't need to do this since the stage / plan is captured again by the - // TrailingFlightDataStream. However, we will eventuall only use the TrailingFlightDataStream - // if we are running an `explain (analyze)` command. We should update this section - // to only use one or the other - not both. - let plan_capture = stage.plan.clone(); - let stream = with_callback(stream, move |_| { - // We need to hold a reference to the plan for at least as long as the stream is - // execution. Some plans might store state necessary for the stream to work, and - // dropping the plan early could drop this state too soon. - let _ = plan_capture; + let stream = FlightDataEncoderBuilder::new() + .with_schema(stream.schema().clone()) + .build(stream.map_err(|err| { + FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err))) + })); + + let task_data_entries = Arc::clone(&self.task_data_entries); + let num_partitions_remaining = Arc::clone(&stage_data.num_partitions_remaining); + + let stream = map_last_stream(stream, move |last| { + if num_partitions_remaining.fetch_sub(1, Ordering::SeqCst) == 1 { + task_data_entries.remove(key.clone()); + } + last.and_then(|el| collect_and_create_metrics_flight_data(key, stage, el)) }); - let record_batch_stream = Box::pin(RecordBatchStreamAdapter::new(schema, stream)); - let task_data_capture = self.task_data_entries.clone(); - Ok(flight_stream_from_record_batch_stream( - key.clone(), - stage_data.clone(), - move || { - task_data_capture.remove(key.clone()); - }, - record_batch_stream, - )) + Ok(Response::new(Box::pin(stream.map_err(|err| match err { + FlightError::Tonic(status) => *status, + _ => Status::internal(format!("Error during flight stream: {err}")), + })))) } } @@ -163,84 +155,6 @@ fn missing(field: &'static str) -> impl FnOnce() -> Status { move || Status::invalid_argument(format!("Missing field '{field}'")) } -/// Creates a tonic response from a stream of record batches. Handles -/// - RecordBatch to flight conversion partition tracking, stage eviction, and trailing metrics. -fn flight_stream_from_record_batch_stream( - stage_key: StageKey, - stage_data: TaskData, - evict_stage: impl FnOnce() + Send + 'static, - stream: SendableRecordBatchStream, -) -> Response<::DoGetStream> { - let mut flight_data_stream = - FlightDataEncoderBuilder::new() - .with_schema(stream.schema().clone()) - .build(stream.map_err(|err| { - FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err))) - })); - - // executed once when the stream ends - // decorates the last flight data with our metrics - let mut final_closure = Some(move |last_flight_data| { - if stage_data - .num_partitions_remaining - .fetch_sub(1, Ordering::SeqCst) - == 1 - { - evict_stage(); - - collect_and_create_metrics_flight_data(stage_key, stage_data.stage, last_flight_data) - } else { - Ok(last_flight_data) - } - }); - - // this is used to peek the new value - // so that we can add our metrics to the last flight data - let mut current_value = None; - - let stream = - stream::poll_fn( - move |cx| match futures::ready!(flight_data_stream.poll_next_unpin(cx)) { - Some(Ok(new_val)) => { - match current_value.take() { - // This is the first value, so we store it and repoll to get the next value - None => { - current_value = Some(new_val); - cx.waker().wake_by_ref(); - Poll::Pending - } - - Some(existing) => { - current_value = Some(new_val); - - Poll::Ready(Some(Ok(existing))) - } - } - } - // this is our last value, so we add our metrics to this flight data - None => match current_value.take() { - Some(existing) => { - // make sure we wake ourselves to finish the stream - cx.waker().wake_by_ref(); - - if let Some(closure) = final_closure.take() { - Poll::Ready(Some(closure(existing))) - } else { - unreachable!("the closure is only executed once") - } - } - None => Poll::Ready(None), - }, - err => Poll::Ready(err), - }, - ); - - Response::new(Box::pin(stream.map_err(|err| match err { - FlightError::Tonic(status) => *status, - _ => Status::internal(format!("Error during flight stream: {err}")), - }))) -} - /// Collects metrics from the provided stage and includes it in the flight data fn collect_and_create_metrics_flight_data( stage_key: StageKey,