From 84d353b5ea4decb998079c3075ae74165e2776a7 Mon Sep 17 00:00:00 2001 From: Yifan Zhang Date: Thu, 16 Dec 2021 16:42:33 +0300 Subject: [PATCH] fix worker type logic --- src/pipeline/worker.py | 11 ++++++-- tests/test_worker.py | 63 +++++++++++++++--------------------------- 2 files changed, 30 insertions(+), 44 deletions(-) diff --git a/src/pipeline/worker.py b/src/pipeline/worker.py index c886e35..05a86b4 100644 --- a/src/pipeline/worker.py +++ b/src/pipeline/worker.py @@ -160,8 +160,6 @@ def parse_args(self, args: List[str] = sys.argv[1:]) -> None: ) ) self.logger.info(f"Destination: {self.destination}") - else: - self.worker_type = WorkerType.NoOutput if options.help: print(parser.format_help()) @@ -170,6 +168,13 @@ def parse_args(self, args: List[str] = sys.argv[1:]) -> None: if self.has_input() and self.settings.in_kind is None: self.logger.critical("Please specify '--in-kind' or environment 'IN_KIND'!") raise PipelineError("Please specify '--in-kind' or environment 'IN_KIND'!") + elif self.has_output() and self.settings.out_kind is None: + self.logger.critical( + "Please specify '--out-kind' or environment 'OUT_KIND'!" + ) + raise PipelineError( + "Please specify '--out-kind' or environment 'OUT_KIND'!" + ) # report worker info to monitor self.monitor.record_worker_info() @@ -405,7 +410,7 @@ def __init__( if output_class is None: super().__init__(settings, worker_type=WorkerType.NoOutput, logger=logger) else: - super().__init__(settings, logger=logger) + super().__init__(settings, worker_type=WorkerType.Normal, logger=logger) self.retryEnabled = False self.input_class = input_class self.output_class = output_class diff --git a/tests/test_worker.py b/tests/test_worker.py index 58bac3d..740b18d 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -15,6 +15,7 @@ Processor, SplitterSettings, Splitter, + PipelineError, PipelineOutputError, CommandActions, Definition, @@ -42,21 +43,29 @@ def test_mem_producer(self, monkeypatch): class Output(BaseModel): key: int + class MyProducerSettings(ProducerSettings): + pass + class MyProducer(Producer): + def __init__(self): + settings = MyProducerSettings( + name="producer", + version="0.1.0", + description="test", + ) + super().__init__( + settings=settings, + output_class=Output, + ) + def generate(self): for i in range(3): yield Output(key=i) - settings = ProducerSettings( - name="producer", - version="0.1.0", - description="", - out_kind=TapKind.MEM, - debug=True, - monitoring=True, - ) - producer = MyProducer(settings, output_class=Output) - producer.parse_args(args=["--out-topic", "test", "--debug"]) + producer = MyProducer() + producer.parse_args(args="--out-kind MEM".split()) + assert producer.has_input() is False + assert producer.has_output() is True monkeypatch.setattr(producer, "monitor", mock.MagicMock()) producer.start() assert len(producer.destination.results) == 3 @@ -362,34 +371,8 @@ class Output(BaseModel): settings = ProcessorSettings(name="processor", version="0.0.0", description="") processor = Processor(settings, input_class=Input, output_class=Output) - processor.parse_args(args=["--in-kind", "MEM"]) - assert processor.settings.out_kind is None - assert not hasattr(processor, "destination") - - def test_mem_processor_limit(self): - class Input(BaseModel): - pass - - class Output(BaseModel): - pass - - class MyProcessor(Processor): - def process(self, msg: Input, id: str) -> Output: - return Output() - - msgs = [{}, {}, {}] - settings = ProcessorSettings( - name="processor", - version="0.0.0", - description="", - in_kind=TapKind.MEM, - out_kind=TapKind.MEM, - ) - processor = MyProcessor(settings, input_class=Input, output_class=Output) - processor.parse_args(args=["--limit", "1", "--out-topic", "test"]) - processor.source.load_data(msgs) - processor.start() - assert len(processor.destination.results) == 1 + with pytest.raises(PipelineError): + processor.parse_args(args="--in-kind MEM".split()) def test_splitter(self): msgs = [{"language": "en"}, {"language": "it"}] @@ -397,8 +380,6 @@ def test_splitter(self): name="splitter", version="0.0.0", description="", - in_kind=TapKind.MEM, - out_kind=TapKind.MEM, ) class MySplitter(Splitter): @@ -406,7 +387,7 @@ def get_topic(self, msg): return f'test-{msg.get("language")}' splitter = MySplitter(settings) - splitter.parse_args(args=["--out-topic", "test"]) + splitter.parse_args(args="--in-kind MEM --out-kind MEM".split()) splitter.source.load_data(msgs) splitter.start() assert len(splitter.destinations["test-en"].results) == 1