Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions taskiq/cli/worker/args.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from dataclasses import dataclass, field
from typing import List, Optional, Sequence, Tuple
from typing import List, Optional, Sequence, Tuple, Union

from taskiq.abc.broker import AsyncBroker
from taskiq.acks import AcknowledgeType
from taskiq.cli.common_args import LogLevel

Expand All @@ -24,7 +25,7 @@ def receiver_arg_type(string: str) -> Tuple[str, str]:
class WorkerArgs:
"""Taskiq worker CLI arguments."""

broker: str
broker: Union[str, AsyncBroker]
modules: List[str]
app_dir: Optional[str] = None
tasks_pattern: Sequence[str] = ("**/tasks.py",)
Expand Down
19 changes: 11 additions & 8 deletions taskiq/cli/worker/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,17 @@ def interrupt_handler(signum: int, _frame: Any) -> None:
# We must set this field before importing tasks,
# so broker will remember all tasks it's related to.

broker = import_object(args.broker, app_dir=args.app_dir)
if inspect.isfunction(broker):
broker = broker()
if not isinstance(broker, AsyncBroker):
raise ValueError(
"Unknown broker type. Please use AsyncBroker instance "
"or pass broker factory function that returns an AsyncBroker instance.",
)
if isinstance(args.broker, AsyncBroker):
broker = args.broker
else:
broker = import_object(args.broker, app_dir=args.app_dir)
if inspect.isfunction(broker):
broker = broker()
if not isinstance(broker, AsyncBroker):
raise ValueError(
"Unknown broker type. Please use AsyncBroker instance "
"or pass broker factory function that returns an AsyncBroker instance.",
)

broker.is_worker_process = True
import_tasks(args.modules, args.tasks_pattern, args.fs_discover)
Expand Down
Loading