Skip to content

Commit

Permalink
fix worker type logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Yifan Zhang committed Dec 16, 2021
1 parent 37865d3 commit 84d353b
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 44 deletions.
11 changes: 8 additions & 3 deletions src/pipeline/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,6 @@ def parse_args(self, args: List[str] = sys.argv[1:]) -> None:
)
)
self.logger.info(f"Destination: {self.destination}")
else:
self.worker_type = WorkerType.NoOutput

if options.help:
print(parser.format_help())
Expand All @@ -170,6 +168,13 @@ def parse_args(self, args: List[str] = sys.argv[1:]) -> None:
if self.has_input() and self.settings.in_kind is None:
self.logger.critical("Please specify '--in-kind' or environment 'IN_KIND'!")
raise PipelineError("Please specify '--in-kind' or environment 'IN_KIND'!")
elif self.has_output() and self.settings.out_kind is None:
self.logger.critical(
"Please specify '--out-kind' or environment 'OUT_KIND'!"
)
raise PipelineError(
"Please specify '--out-kind' or environment 'OUT_KIND'!"
)

# report worker info to monitor
self.monitor.record_worker_info()
Expand Down Expand Up @@ -405,7 +410,7 @@ def __init__(
if output_class is None:
super().__init__(settings, worker_type=WorkerType.NoOutput, logger=logger)
else:
super().__init__(settings, logger=logger)
super().__init__(settings, worker_type=WorkerType.Normal, logger=logger)
self.retryEnabled = False
self.input_class = input_class
self.output_class = output_class
Expand Down
63 changes: 22 additions & 41 deletions tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Processor,
SplitterSettings,
Splitter,
PipelineError,
PipelineOutputError,
CommandActions,
Definition,
Expand Down Expand Up @@ -42,21 +43,29 @@ def test_mem_producer(self, monkeypatch):
class Output(BaseModel):
key: int

class MyProducerSettings(ProducerSettings):
pass

class MyProducer(Producer):
def __init__(self):
settings = MyProducerSettings(
name="producer",
version="0.1.0",
description="test",
)
super().__init__(
settings=settings,
output_class=Output,
)

def generate(self):
for i in range(3):
yield Output(key=i)

settings = ProducerSettings(
name="producer",
version="0.1.0",
description="",
out_kind=TapKind.MEM,
debug=True,
monitoring=True,
)
producer = MyProducer(settings, output_class=Output)
producer.parse_args(args=["--out-topic", "test", "--debug"])
producer = MyProducer()
producer.parse_args(args="--out-kind MEM".split())
assert producer.has_input() is False
assert producer.has_output() is True
monkeypatch.setattr(producer, "monitor", mock.MagicMock())
producer.start()
assert len(producer.destination.results) == 3
Expand Down Expand Up @@ -362,51 +371,23 @@ class Output(BaseModel):

settings = ProcessorSettings(name="processor", version="0.0.0", description="")
processor = Processor(settings, input_class=Input, output_class=Output)
processor.parse_args(args=["--in-kind", "MEM"])
assert processor.settings.out_kind is None
assert not hasattr(processor, "destination")

def test_mem_processor_limit(self):
class Input(BaseModel):
pass

class Output(BaseModel):
pass

class MyProcessor(Processor):
def process(self, msg: Input, id: str) -> Output:
return Output()

msgs = [{}, {}, {}]
settings = ProcessorSettings(
name="processor",
version="0.0.0",
description="",
in_kind=TapKind.MEM,
out_kind=TapKind.MEM,
)
processor = MyProcessor(settings, input_class=Input, output_class=Output)
processor.parse_args(args=["--limit", "1", "--out-topic", "test"])
processor.source.load_data(msgs)
processor.start()
assert len(processor.destination.results) == 1
with pytest.raises(PipelineError):
processor.parse_args(args="--in-kind MEM".split())

def test_splitter(self):
msgs = [{"language": "en"}, {"language": "it"}]
settings = SplitterSettings(
name="splitter",
version="0.0.0",
description="",
in_kind=TapKind.MEM,
out_kind=TapKind.MEM,
)

class MySplitter(Splitter):
def get_topic(self, msg):
return f'test-{msg.get("language")}'

splitter = MySplitter(settings)
splitter.parse_args(args=["--out-topic", "test"])
splitter.parse_args(args="--in-kind MEM --out-kind MEM".split())
splitter.source.load_data(msgs)
splitter.start()
assert len(splitter.destinations["test-en"].results) == 1
Expand Down

0 comments on commit 84d353b

Please sign in to comment.