-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Include queue extra config in
task_router
(#361)
* Include queue extra config in `task_router` We now want to have queue-specific config for enterprise-queues. To get that specific config we need the task name, but it should only be applied if using enterprise-queue. So a natural way to put it is when we get the queue for a task, we also get any extra config for it. Putting this in `shared` to make sure it's consistent across `worker` and `api`.
- Loading branch information
1 parent
6478226
commit 4b845db
Showing
4 changed files
with
92 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,30 @@ | ||
from shared.billing import BillingPlan, is_enterprise_cloud_plan | ||
from shared.celery_config import BaseCeleryConfig | ||
from shared.celery_config import BaseCeleryConfig, get_task_group | ||
from shared.config import get_config | ||
|
||
|
||
def route_tasks_based_on_user_plan(task_name: str, user_plan: str): | ||
"""Helper function to dynamically route tasks based on the user plan. | ||
This cannot be used as a celery router function directly. | ||
Returns extra config for the queue, if any. | ||
""" | ||
default_task_queue = BaseCeleryConfig.task_routes.get( | ||
task_name, dict(queue=BaseCeleryConfig.task_default_queue) | ||
)["queue"] | ||
billing_plan = BillingPlan.from_str(user_plan) | ||
if is_enterprise_cloud_plan(billing_plan): | ||
return {"queue": "enterprise_" + default_task_queue} | ||
return {"queue": default_task_queue} | ||
default_enterprise_queue_specific_config = get_config( | ||
"setup", "tasks", "celery", "enterprise", default=dict() | ||
) | ||
this_queue_specific_config = get_config( | ||
"setup", | ||
"tasks", | ||
get_task_group(task_name), | ||
"enterprise", | ||
default=default_enterprise_queue_specific_config, | ||
) | ||
return { | ||
"queue": "enterprise_" + default_task_queue, | ||
"extra_config": this_queue_specific_config, | ||
} | ||
return {"queue": default_task_queue, "extra_config": {}} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,44 @@ | ||
from shared.billing import BillingPlan | ||
from shared.celery_config import upload_task_name | ||
from shared.celery_config import timeseries_backfill_task_name, upload_task_name | ||
from shared.celery_router import route_tasks_based_on_user_plan | ||
|
||
|
||
def test_route_tasks_based_on_user_plan_defaults(): | ||
assert route_tasks_based_on_user_plan( | ||
upload_task_name, BillingPlan.users_basic.db_name | ||
) == {"queue": "celery"} | ||
) == {"queue": "celery", "extra_config": {}} | ||
assert route_tasks_based_on_user_plan( | ||
upload_task_name, BillingPlan.enterprise_cloud_monthly.db_name | ||
) == {"queue": "enterprise_celery"} | ||
) == {"queue": "enterprise_celery", "extra_config": {}} | ||
assert route_tasks_based_on_user_plan( | ||
"misterious_task", BillingPlan.users_basic.db_name | ||
) == {"queue": "celery"} | ||
) == {"queue": "celery", "extra_config": {}} | ||
assert route_tasks_based_on_user_plan( | ||
"misterious_task", BillingPlan.enterprise_cloud_monthly.db_name | ||
) == {"queue": "enterprise_celery"} | ||
) == {"queue": "enterprise_celery", "extra_config": {}} | ||
|
||
|
||
def test_route_tasks_with_config(mock_configuration): | ||
mock_configuration._params["setup"] = { | ||
"tasks": { | ||
"celery": {"enterprise": {"soft_timelimit": 100, "hard_timelimit": 200}}, | ||
"timeseries": { | ||
"enterprise": {"soft_timelimit": 400, "hard_timelimit": 500} | ||
}, | ||
} | ||
} | ||
assert route_tasks_based_on_user_plan( | ||
upload_task_name, BillingPlan.users_basic.db_name | ||
) == {"queue": "celery", "extra_config": {}} | ||
assert route_tasks_based_on_user_plan( | ||
upload_task_name, BillingPlan.enterprise_cloud_monthly.db_name | ||
) == { | ||
"queue": "enterprise_celery", | ||
"extra_config": {"soft_timelimit": 100, "hard_timelimit": 200}, | ||
} | ||
assert route_tasks_based_on_user_plan( | ||
timeseries_backfill_task_name, BillingPlan.enterprise_cloud_monthly.db_name | ||
) == { | ||
"queue": "enterprise_celery", | ||
"extra_config": {"soft_timelimit": 400, "hard_timelimit": 500}, | ||
} |