diff --git a/streamz/core.py b/streamz/core.py index 65b02066..fb055501 100644 --- a/streamz/core.py +++ b/streamz/core.py @@ -181,6 +181,30 @@ def _inform_asynchronous(self, asynchronous): if downstream: downstream._inform_asynchronous(asynchronous) + def _add_upstream(self, upstream): + """Add upstream to current upstreams, this method is overridden for + classes which handle stream specific buffers/caches""" + if self.upstreams == [None]: + self.upstreams[0] = upstream + else: + self.upstreams.append(upstream) + + def _add_downstream(self, downstream): + """Add downstream to current downstreams""" + self.downstreams.add(downstream) + + def _remove_downstream(self, downstream): + """Remove downstream from current downstreams""" + self.downstreams.remove(downstream) + + def _remove_upstream(self, upstream): + """Remove upstream from current upstreams, this method is overridden for + classes which handle stream specific buffers/caches""" + if len(self.upstreams) == 1: + self.upstreams[0] = [None] + else: + self.upstreams.remove(upstream) + @classmethod def register_api(cls, modifier=identity): """ Add callable to Stream API @@ -349,12 +373,8 @@ def connect(self, downstream): downstream: Stream The downstream stream to connect to ''' - self.downstreams.add(downstream) - - if downstream.upstreams == [None]: - downstream.upstreams = [self] - else: - downstream.upstreams.append(self) + self._add_downstream(downstream) + downstream._add_upstream(self) def disconnect(self, downstream): ''' Disconnect this stream to a downstream element. @@ -364,9 +384,9 @@ def disconnect(self, downstream): downstream: Stream The downstream stream to disconnect from ''' - self.downstreams.remove(downstream) + self._remove_downstream(downstream) - downstream.upstreams.remove(self) + downstream._remove_upstream(self) @property def upstream(self): @@ -789,7 +809,8 @@ def update(self, x, who=None): def _check_end(self): if self.end and self.state >= self.end: # we're done - self.upstream.downstreams.remove(self) + for upstream in self.upstreams: + upstream._remove_downstream(self) @Stream.register_api() @@ -1013,6 +1034,16 @@ def __init__(self, *upstreams, **kwargs): Stream.__init__(self, upstreams=upstreams2, **kwargs) + def _add_upstream(self, upstream): + # Override method to handle setup of buffer for new stream + self.buffers[upstream] = deque() + super(zip, self)._add_upstream(upstream) + + def _remove_upstream(self, upstream): + # Override method to handle removal of buffer for stream + self.buffers.pop(upstream) + super(zip, self)._remove_upstream(upstream) + def pack_literals(self, tup): """ Fill buffers for literals whenever we empty them """ inp = list(tup)[::-1] @@ -1064,6 +1095,7 @@ class combine_latest(Stream): def __init__(self, *upstreams, **kwargs): emit_on = kwargs.pop('emit_on', None) + self._initial_emit_on = emit_on self.last = [None for _ in upstreams] self.missing = set(upstreams) @@ -1077,6 +1109,30 @@ def __init__(self, *upstreams, **kwargs): self.emit_on = upstreams Stream.__init__(self, upstreams=upstreams, **kwargs) + def _add_upstream(self, upstream): + # Override method to handle setup of last and missing for new stream + self.last.append(None) + self.missing.update([upstream]) + super(combine_latest, self)._add_upstream(upstream) + if self._initial_emit_on is None: + self.emit_on = self.upstreams + + def _remove_upstream(self, upstream): + # Override method to handle removal of last and missing for stream + if self.emit_on == upstream: + raise RuntimeError("Can't remove the ``emit_on`` stream since that" + "would cause no data to be emitted. " + "Consider adding an ``emit_on`` first by " + "running ``node.emit_on=(upstream,)`` to add " + "a new ``emit_on`` or running " + "``node.emit_on=tuple(node.upstreams)`` to " + "emit on all incoming data") + self.last.pop(self.upstreams.index(upstream)) + self.missing.remove(upstream) + super(combine_latest, self)._remove_upstream(upstream) + if self._initial_emit_on is None: + self.emit_on = self.upstreams + def update(self, x, who=None): if self.missing and who in self.missing: self.missing.remove(who) diff --git a/streamz/tests/test_core.py b/streamz/tests/test_core.py index 1b07823b..f08e1417 100644 --- a/streamz/tests/test_core.py +++ b/streamz/tests/test_core.py @@ -1207,5 +1207,59 @@ def start(self): assert flag == [True] +def test_connect_zip(): + a = Stream() + b = Stream() + c = Stream() + x = a.zip(b) + L = x.sink_to_list() + c.connect(x) + a.emit(1) + b.emit(1) + assert not L + c.emit(1) + assert L == [(1, 1, 1)] + + +def test_disconnect_zip(): + a = Stream() + b = Stream() + c = Stream() + x = a.zip(b, c) + L = x.sink_to_list() + b.disconnect(x) + a.emit(1) + b.emit(1) + assert not L + c.emit(1) + assert L == [(1, 1)] + + +def test_connect_combine_latest(): + a = Stream() + b = Stream() + c = Stream() + x = a.combine_latest(b, emit_on=a) + L = x.sink_to_list() + c.connect(x) + b.emit(1) + c.emit(1) + a.emit(1) + assert L == [(1, 1, 1)] + + +def test_connect_discombine_latest(): + a = Stream() + b = Stream() + c = Stream() + x = a.combine_latest(b, c, emit_on=a) + L = x.sink_to_list() + c.disconnect(x) + b.emit(1) + c.emit(1) + a.emit(1) + assert L == [(1, 1)] + + if sys.version_info >= (3, 5): from streamz.tests.py3_test_core import * # noqa