Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 0 additions & 84 deletions src/common/callback_stream.rs

This file was deleted.

84 changes: 84 additions & 0 deletions src/common/map_last_stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
use futures::{Stream, StreamExt, stream};
use std::task::Poll;

/// Maps the last element of the provided stream.
///
/// Every element except the final one is forwarded unchanged and in order. The final element
/// is passed through `map_f` exactly once before being emitted. If `input` is empty, `map_f`
/// is never called and the returned stream is empty as well.
///
/// Because the last element is only known once `input` reports completion, the returned
/// stream runs one element behind `input`: an element is emitted only after the next one
/// (or the end of the stream) has been observed.
///
/// Once `input` has returned `None` it is never polled again, so `input` does not need to be
/// fused.
pub(crate) fn map_last_stream<T>(
    mut input: impl Stream<Item = T> + Unpin,
    map_f: impl FnOnce(T) -> T,
) -> impl Stream<Item = T> + Unpin {
    let mut final_closure = Some(map_f);

    // One-element look-ahead buffer: holds the most recent value from `input` until we know
    // whether it is the last one.
    let mut current_value: Option<T> = None;

    // Set once `input` has returned `None`. Polling a finished stream again is not allowed by
    // the `Stream` contract, so every later poll short-circuits on this flag.
    let mut input_done = false;

    stream::poll_fn(move |cx| {
        if input_done {
            return Poll::Ready(None);
        }

        loop {
            match futures::ready!(input.poll_next_unpin(cx)) {
                Some(new_val) => {
                    // Buffer the new value and release the previous one, if any.
                    if let Some(existing) = current_value.replace(new_val) {
                        return Poll::Ready(Some(existing));
                    }
                    // This was the first value: nothing to emit yet, so poll `input` again
                    // right away instead of bouncing through the executor.
                }
                None => {
                    input_done = true;

                    // The buffered value (if any) is our last one, so we map it using the
                    // user provided closure.
                    return Poll::Ready(match (current_value.take(), final_closure.take()) {
                        (Some(existing), Some(closure)) => Some(closure(existing)),
                        (Some(_), None) => unreachable!("the closure is only executed once"),
                        (None, _) => None,
                    });
                }
            }
        }
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;

    /// An empty input yields an empty output.
    #[tokio::test]
    async fn test_map_last_stream_empty_stream() {
        let source = stream::empty::<i32>();
        let collected = map_last_stream(source, |x| x + 10)
            .collect::<Vec<i32>>()
            .await;
        assert_eq!(collected, Vec::<i32>::new());
    }

    /// With a single element, that element is the last one and gets mapped.
    #[tokio::test]
    async fn test_map_last_stream_single_element() {
        let source = stream::iter(vec![5]);
        let collected = map_last_stream(source, |x| x * 2)
            .collect::<Vec<i32>>()
            .await;
        assert_eq!(collected, vec![10]);
    }

    /// Only the final element is transformed; the preceding ones pass through untouched.
    #[tokio::test]
    async fn test_map_last_stream_multiple_elements() {
        let source = stream::iter(vec![1, 2, 3, 4]);
        let collected = map_last_stream(source, |x| x + 100)
            .collect::<Vec<i32>>()
            .await;
        assert_eq!(collected, vec![1, 2, 3, 104]);
    }

    /// Elements come out in their original order, with the last one mapped (50 - 50 = 0).
    #[tokio::test]
    async fn test_map_last_stream_preserves_order() {
        let source = stream::iter(vec![10, 20, 30, 40, 50]);
        let collected = map_last_stream(source, |x| x - 50)
            .collect::<Vec<i32>>()
            .await;
        assert_eq!(collected, vec![10, 20, 30, 40, 0]);
    }
}
4 changes: 2 additions & 2 deletions src/common/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
mod callback_stream;
mod composed_extension_codec;
mod map_last_stream;
mod partitioning;
#[allow(unused)]
pub mod ttl_map;

pub(crate) use callback_stream::with_callback;
pub(crate) use composed_extension_codec::ComposedPhysicalExtensionCodec;
pub(crate) use map_last_stream::map_last_stream;
pub(crate) use partitioning::{scale_partitioning, scale_partitioning_props};
126 changes: 20 additions & 106 deletions src/flight_service/do_get.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::common::with_callback;
use crate::common::map_last_stream;
use crate::config_extension_ext::ContextGrpcMetadata;
use crate::execution_plans::{DistributedTaskContext, StageExec};
use crate::flight_service::service::ArrowFlightEndpoint;
Expand All @@ -17,15 +17,11 @@ use arrow_flight::flight_service_server::FlightService;
use bytes::Bytes;

use datafusion::common::exec_datafusion_err;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::prelude::SessionContext;
use futures::stream;
use futures::{StreamExt, TryStreamExt};
use futures::TryStreamExt;
use prost::Message;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::Poll;
use tonic::{Request, Response, Status};

#[derive(Clone, PartialEq, ::prost::Message)]
Expand Down Expand Up @@ -132,115 +128,33 @@ impl ArrowFlightEndpoint {
.execute(doget.target_partition as usize, session_state.task_ctx())
.map_err(|err| Status::internal(format!("Error executing stage plan: {err:#?}")))?;

let schema = stream.schema();

// TODO: We don't need to do this since the stage / plan is captured again by the
// TrailingFlightDataStream. However, we will eventually only use the TrailingFlightDataStream
// if we are running an `explain (analyze)` command. We should update this section
// to only use one or the other - not both.
let plan_capture = stage.plan.clone();
let stream = with_callback(stream, move |_| {
// We need to hold a reference to the plan for at least as long as the stream is
// executing. Some plans might store state necessary for the stream to work, and
// dropping the plan early could drop this state too soon.
let _ = plan_capture;
let stream = FlightDataEncoderBuilder::new()
.with_schema(stream.schema().clone())
.build(stream.map_err(|err| {
FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err)))
}));

let task_data_entries = Arc::clone(&self.task_data_entries);
let num_partitions_remaining = Arc::clone(&stage_data.num_partitions_remaining);

let stream = map_last_stream(stream, move |last| {
if num_partitions_remaining.fetch_sub(1, Ordering::SeqCst) == 1 {
task_data_entries.remove(key.clone());
}
last.and_then(|el| collect_and_create_metrics_flight_data(key, stage, el))
});

let record_batch_stream = Box::pin(RecordBatchStreamAdapter::new(schema, stream));
let task_data_capture = self.task_data_entries.clone();
Ok(flight_stream_from_record_batch_stream(
key.clone(),
stage_data.clone(),
move || {
task_data_capture.remove(key.clone());
},
record_batch_stream,
))
Ok(Response::new(Box::pin(stream.map_err(|err| match err {
FlightError::Tonic(status) => *status,
_ => Status::internal(format!("Error during flight stream: {err}")),
}))))
}
}

/// Returns a deferred constructor for an `invalid_argument` [`Status`] reporting that `field`
/// is missing. The status (and its message) is only built when the returned closure is called.
fn missing(field: &'static str) -> impl FnOnce() -> Status {
    move || {
        let message = format!("Missing field '{field}'");
        Status::invalid_argument(message)
    }
}

/// Creates a tonic response from a stream of record batches. Handles:
/// - RecordBatch to flight conversion,
/// - partition tracking,
/// - stage eviction, and
/// - trailing metrics.
///
/// The encoded flight data is forwarded one message behind the encoder so that the final
/// message can be post-processed when the encoder finishes: at that point
/// `stage_data.num_partitions_remaining` is decremented, and if this was the last remaining
/// partition, `evict_stage` is called and the last message is passed through
/// `collect_and_create_metrics_flight_data`. Otherwise the last message is emitted unchanged.
///
/// Errors produced by the encoder are forwarded as soon as they are observed. Once the encoder
/// has reported completion it is never polled again.
fn flight_stream_from_record_batch_stream(
    stage_key: StageKey,
    stage_data: TaskData,
    evict_stage: impl FnOnce() + Send + 'static,
    stream: SendableRecordBatchStream,
) -> Response<<ArrowFlightEndpoint as FlightService>::DoGetStream> {
    let mut flight_data_stream =
        FlightDataEncoderBuilder::new()
            .with_schema(stream.schema().clone())
            .build(stream.map_err(|err| {
                FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err)))
            }));

    // executed once when the stream ends
    // decorates the last flight data with our metrics
    let mut final_closure = Some(move |last_flight_data| {
        if stage_data
            .num_partitions_remaining
            .fetch_sub(1, Ordering::SeqCst)
            == 1
        {
            evict_stage();

            collect_and_create_metrics_flight_data(stage_key, stage_data.stage, last_flight_data)
        } else {
            Ok(last_flight_data)
        }
    });

    // this is used to peek the new value
    // so that we can add our metrics to the last flight data
    let mut current_value = None;

    // Set once the encoder has returned `None`. A finished stream must not be polled again,
    // so every later poll short-circuits on this flag.
    let mut input_done = false;

    let stream = stream::poll_fn(move |cx| {
        if input_done {
            return Poll::Ready(None);
        }

        loop {
            match futures::ready!(flight_data_stream.poll_next_unpin(cx)) {
                Some(Ok(new_val)) => {
                    // Buffer the new value and release the previous one, if any.
                    if let Some(existing) = current_value.replace(new_val) {
                        return Poll::Ready(Some(Ok(existing)));
                    }
                    // This is the first value, so we only store it and poll again right away
                    // to get the next value.
                }
                // this is our last value, so we add our metrics to this flight data
                None => {
                    input_done = true;

                    return Poll::Ready(match (current_value.take(), final_closure.take()) {
                        (Some(existing), Some(closure)) => Some(closure(existing)),
                        (Some(_), None) => unreachable!("the closure is only executed once"),
                        (None, _) => None,
                    });
                }
                // Errors are forwarded immediately; the buffered value (if any) is kept.
                err => return Poll::Ready(err),
            }
        }
    });

    Response::new(Box::pin(stream.map_err(|err| match err {
        FlightError::Tonic(status) => *status,
        _ => Status::internal(format!("Error during flight stream: {err}")),
    })))
}

/// Collects metrics from the provided stage and includes it in the flight data
fn collect_and_create_metrics_flight_data(
stage_key: StageKey,
Expand Down