42 changes: 38 additions & 4 deletions docs/connectors/sinks/README.md
@@ -30,12 +30,46 @@ sdf = app.dataframe(topic)
sdf.sink(influx_sink)
```

## Sinks Are Destinations
When `.sink()` is called on a StreamingDataFrame instance, it marks the end of the processing pipeline, and
the StreamingDataFrame can't be changed anymore.
## Sinks Are Terminal Operations
`StreamingDataFrame.sink()` is special in that it is "terminal":
**no additional operations can be added after it is called** (with branching, only the
branch it is called on becomes terminal).

Make sure you call `StreamingDataFrame.sink()` as the last operation.
This is to ensure no further mutations can be applied to the outbound data.

_However_, you can continue other operations with other branches, including using
the same `Sink` to push another value (with another `SDF.sink()` call), as shown below.

[Learn more about _branching_ here](../../advanced/branching.md).
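
For example, here is a minimal sketch of that pattern (it reuses the `influx_sink`
instance from the snippet above and keeps the bare `.apply()` placeholder style used
elsewhere on this page), where two branches push values to the same `Sink`:

```python
sdf = app.dataframe(topic)

# Branch 1: sink the values as-is
sdf.sink(influx_sink)

# Branch 2: keep transforming in another branch and sink the result to the same Sink
sdf = sdf.apply()
sdf.sink(influx_sink)
```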

### Branching after SDF.sink()

It is still possible to branch after using `SDF.sink()`, as long as _you do NOT reassign
your SDF to its result_ (it returns `None`):

```python
sdf = app.dataframe(topic)
sdf = sdf.apply()

# Approach 1: allows further branching from `sdf`
sdf.sink(influx_sink)

# Approach 2: disables further branching from `sdf`
sdf = sdf.sink(influx_sink)
```
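
To make Approach 2 concrete (a small sketch, again assuming the `influx_sink` from
above): after the reassignment, `sdf` is `None`, so any further call on it fails:

```python
sdf = sdf.sink(influx_sink)  # .sink() returns None, so `sdf` is now None
sdf = sdf.apply()            # AttributeError: 'NoneType' object has no attribute 'apply'
```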

### Suggested Use of SDF.sink()

If further operations are required (or you want to keep `sdf` available for
other branches), it's recommended to use `SDF.sink()` as a standalone operation:

```python
sdf = app.dataframe(topic)
# [other operations here...]
sdf = sdf.apply().apply() # last transforms before a sink
sdf.sink(influx_sink) # do sink as a standalone call, no reassignment
sdf = sdf.apply() # continue different operations with another branch...
```

## Supported Sinks

57 changes: 6 additions & 51 deletions quixstreams/dataframe/dataframe.py
@@ -17,10 +17,9 @@
Tuple,
Literal,
Collection,
TypeVar,
)

from typing_extensions import Self, ParamSpec
from typing_extensions import Self

from quixstreams.context import (
message_context,
@@ -48,7 +47,7 @@
from quixstreams.sinks import BaseSink
from quixstreams.state.types import State
from .base import BaseStreaming
from .exceptions import InvalidOperation, DataFrameLocked
from .exceptions import InvalidOperation
from .registry import DataframeRegistry
from .series import StreamingSeries
from .utils import ensure_milliseconds
@@ -62,27 +61,6 @@
FilterWithMetadataCallbackStateful = Callable[[Any, Any, int, Any, State], bool]


_T = TypeVar("_T")
_P = ParamSpec("_P")


def _ensure_unlocked(func: Callable[_P, _T]) -> Callable[_P, _T]:
"""
Ensure the SDF instance is not locked by the sink() call before adding new
operations to it.
"""

@functools.wraps(func)
def wrapper(self: StreamingDataFrame, *args, **kwargs):
if self._locked:
raise DataFrameLocked(
"StreamingDataFrame is already sinked and cannot be modified"
)
return func(self, *args, **kwargs)

return wrapper


class StreamingDataFrame(BaseStreaming):
"""
`StreamingDataFrame` is the main object you will use for ETL work.
@@ -189,7 +167,6 @@ def apply(
expand: bool = ...,
) -> Self: ...

@_ensure_unlocked
def apply(
self,
func: Union[
@@ -279,7 +256,6 @@ def update(
metadata: Literal[True],
) -> Self: ...

@_ensure_unlocked
def update(
self,
func: Union[
@@ -372,7 +348,6 @@ def filter(
metadata: Literal[True],
) -> Self: ...

@_ensure_unlocked
def filter(
self,
func: Union[
@@ -459,7 +434,6 @@ def group_by(
key_serializer: Optional[SerializerType] = ...,
) -> Self: ...

@_ensure_unlocked
def group_by(
self,
key: Union[str, Callable[[Any], Any]],
@@ -559,7 +533,6 @@ def contains(key: str) -> StreamingSeries:
lambda value, key_, timestamp, headers: key in value
)

@_ensure_unlocked
def to_topic(
self, topic: Topic, key: Optional[Callable[[Any], Any]] = None
) -> Self:
@@ -594,7 +567,7 @@ def to_topic(
By default, the current message key will be used.
:return: the updated StreamingDataFrame instance (reassignment NOT required).
"""
return self.apply(
return self._add_update(
lambda value, orig_key, timestamp, headers: self._produce(
topic=topic,
value=value,
@@ -605,7 +578,6 @@ def to_topic(
metadata=True,
)

@_ensure_unlocked
def set_timestamp(self, func: Callable[[Any, Any, int, Any], int]) -> Self:
"""
Set a new timestamp based on the current message value and its metadata.
@@ -647,7 +619,6 @@ def _set_timestamp_callback(
stream = self.stream.add_transform(func=_set_timestamp_callback)
return self.__dataframe_clone__(stream=stream)

@_ensure_unlocked
def set_headers(
self,
func: Callable[
@@ -699,7 +670,6 @@ def _set_headers_callback(
stream = self.stream.add_transform(func=_set_headers_callback)
return self.__dataframe_clone__(stream=stream)

@_ensure_unlocked
def print(self, pretty: bool = True, metadata: bool = False) -> Self:
"""
Print out the current message value (and optionally, the message metadata) to
@@ -813,7 +783,6 @@ def test(
context.run(composed[topic.name], value, key, timestamp, headers)
return result

@_ensure_unlocked
def tumbling_window(
self,
duration_ms: Union[int, timedelta],
@@ -890,7 +859,6 @@ def tumbling_window(
duration_ms=duration_ms, grace_ms=grace_ms, dataframe=self, name=name
)

@_ensure_unlocked
def hopping_window(
self,
duration_ms: Union[int, timedelta],
@@ -983,7 +951,6 @@ def hopping_window(
name=name,
)

@_ensure_unlocked
def drop(
self,
columns: Union[str, List[str]],
@@ -1028,7 +995,6 @@ def drop(
metadata=False,
)

@_ensure_unlocked
def sink(self, sink: BaseSink):
"""
Sink the processed data to the specified destination.
@@ -1044,8 +1010,8 @@ def sink(self, sink: BaseSink):
and resume again after the timeout.
The backpressure handling and timeouts are defined by the specific sinks.

Note: `sink()` is a terminal operation, and you cannot add new operations
to the same StreamingDataFrame after it's called.
Note: `sink()` is a terminal operation - it cannot receive any additional
operations, but branches can still be generated from its originating SDF.

"""
self._processing_context.sink_manager.register(sink)
@@ -1064,15 +1030,8 @@ def _sink_callback(
offset=ctx.offset,
)

# even though using apply, don't return since we lock afterward anyway
# uses apply without returning to make this operation terminal
self.apply(_sink_callback, metadata=True)
self._lock()

def _lock(self):
"""
Lock the StreamingDataFrame to prevent adding new operations to it.
"""
self._locked = True

def _produce(
self,
@@ -1087,8 +1046,6 @@ def _produce(
value=value, key=key, timestamp=timestamp, context=ctx, headers=headers
)
self._producer.produce_row(row=row, topic=topic, key=key, timestamp=timestamp)
# return value so produce can be an "apply" function (no branch copy required)
return value

def _add_update(
self,
@@ -1140,7 +1097,6 @@ def __dataframe_clone__(
)
return clone

@_ensure_unlocked
def __setitem__(self, item_key: Any, item: Union[Self, object]):
if isinstance(item, self.__class__):
# Update an item key with a result of another sdf.apply()
@@ -1173,7 +1129,6 @@ def __getitem__(self, item: str) -> StreamingSeries: ...
@overload
def __getitem__(self, item: Union[StreamingSeries, List[str], Self]) -> Self: ...

@_ensure_unlocked
def __getitem__(
self, item: Union[str, List[str], StreamingSeries, Self]
) -> Union[Self, StreamingSeries]:
4 changes: 0 additions & 4 deletions quixstreams/dataframe/exceptions.py
@@ -6,7 +6,6 @@
"GroupByNestingLimit",
"InvalidColumnReference",
"ColumnDoesNotExist",
"DataFrameLocked",
"StreamingDataFrameDuplicate",
"GroupByDuplicate",
)
@@ -27,7 +26,4 @@ class GroupByNestingLimit(QuixException): ...
class GroupByDuplicate(QuixException): ...


class DataFrameLocked(QuixException): ...


class StreamingDataFrameDuplicate(QuixException): ...
87 changes: 87 additions & 0 deletions tests/test_quixstreams/test_app.py
@@ -2090,6 +2090,93 @@ def write(self, batch: SinkBatch):
)
assert committed.offset == total_messages

def test_run_with_sink_branches_success(
self,
app_factory,
executor,
):

processed_count = 0
total_messages = 3

def on_message_processed(topic_, partition, offset):
# Set the callback to track total messages processed
# The callback is not triggered if processing fails
nonlocal processed_count

processed_count += 1
# Stop processing after consuming all the messages
if processed_count == total_messages:
done.set_result(True)

app = app_factory(
auto_offset_reset="earliest",
on_message_processed=on_message_processed,
)
sink = DummySink()

topic = app.topic(
str(uuid.uuid4()),
value_deserializer="str",
config=TopicConfig(num_partitions=3, replication_factor=1),
)
sdf = app.dataframe(topic)
sdf = sdf.apply(lambda x: x + "_branch")
sdf.apply(lambda x: x + "0").sink(sink)
sdf.apply(lambda x: x + "1").sink(sink)
sdf = sdf.apply(lambda x: x + "2")
sdf.sink(sink)

key, value, timestamp_ms = b"key", "value", 1000
headers = [("key", b"value")]

# Produce messages to different topic partitions and flush
with app.get_producer() as producer:
for i in range(total_messages):
producer.produce(
topic=topic.name,
partition=i,
key=key,
value=value,
timestamp=timestamp_ms,
headers=headers,
)

done = Future()

# Stop app when the future is resolved
executor.submit(_stop_app_on_future, app, done, 15.0)
app.run(sdf)

# Check that all messages have been processed
assert processed_count == total_messages

# Ensure all messages were flushed to the sink
assert len(sink.results) == 9
for i in range(3):
assert (
len([r for r in sink.results if f"_branch{i}" in r.value])
== total_messages
)
for item in sink.results:
assert item.key == key
assert value in item.value
assert item.timestamp == timestamp_ms
assert item.headers == headers

# Ensure that the offsets are committed
with app.get_consumer() as consumer:
committed0, committed1, committed2 = consumer.committed(
[
TopicPartition(topic=topic.name, partition=0),
TopicPartition(topic=topic.name, partition=1),
TopicPartition(topic=topic.name, partition=2),
]
)
assert committed0.offset == 1
assert committed1.offset == 1
assert committed2.offset == 1


class TestApplicationMultipleSdf:
def test_multiple_sdfs(