Skip to content
This repository was archived by the owner on May 5, 2025. It is now read-only.

Commit 4b845db

Browse files
Include queue extra config in task_router (#361)
* Include queue extra config in `task_router` We now want to have queue-specific config for enterprise-queues. To get that specific config we need the task name, but it should only be applied if using enterprise-queue. So a natural way to put it is when we get the queue for a task, we also get any extra config for it. Putting this in `shared` to make sure it's consistent across `worker` and `api`.
1 parent 6478226 commit 4b845db

File tree

4 files changed

+92
-39
lines changed

4 files changed

+92
-39
lines changed

shared/celery_config.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# http://docs.celeryq.org/en/latest/configuration.html#configuration
2-
32
from typing import Optional
43

54
from shared.config import get_config

shared/celery_router.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,30 @@
11
from shared.billing import BillingPlan, is_enterprise_cloud_plan
2-
from shared.celery_config import BaseCeleryConfig
2+
from shared.celery_config import BaseCeleryConfig, get_task_group
3+
from shared.config import get_config
34

45

56
def route_tasks_based_on_user_plan(task_name: str, user_plan: str):
67
"""Helper function to dynamically route tasks based on the user plan.
78
This cannot be used as a celery router function directly.
9+
Returns extra config for the queue, if any.
810
"""
911
default_task_queue = BaseCeleryConfig.task_routes.get(
1012
task_name, dict(queue=BaseCeleryConfig.task_default_queue)
1113
)["queue"]
1214
billing_plan = BillingPlan.from_str(user_plan)
1315
if is_enterprise_cloud_plan(billing_plan):
14-
return {"queue": "enterprise_" + default_task_queue}
15-
return {"queue": default_task_queue}
16+
default_enterprise_queue_specific_config = get_config(
17+
"setup", "tasks", "celery", "enterprise", default=dict()
18+
)
19+
this_queue_specific_config = get_config(
20+
"setup",
21+
"tasks",
22+
get_task_group(task_name),
23+
"enterprise",
24+
default=default_enterprise_queue_specific_config,
25+
)
26+
return {
27+
"queue": "enterprise_" + default_task_queue,
28+
"extra_config": this_queue_specific_config,
29+
}
30+
return {"queue": default_task_queue, "extra_config": {}}

tests/unit/test_celery_config.py

Lines changed: 43 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22

33
import shared.celery_config as celery_config
4+
from shared.utils.enums import TaskConfigGroup
45

56

67
def test_celery_config():
@@ -62,36 +63,48 @@ def test_celery_config():
6263
@pytest.mark.parametrize(
6364
"task_name,task_group",
6465
[
65-
("app.cron.healthcheck.HealthCheckTask", "healthcheck"),
66-
("app.cron.profiling.findinguncollected", "profiling"),
67-
("app.tasks.add_to_sendgrid_list.AddToSendgridList", "add_to_sendgrid_list"),
68-
("app.tasks.archive.MigrateToArchive", "archive"),
69-
("app.tasks.verify_bot.VerifyBot", "verify_bot"),
70-
("app.tasks.comment.Comment", "comment"),
71-
("app.tasks.commit_update.CommitUpdate", "commit_update"),
72-
("app.tasks.compute_comparison.ComputeComparison", "compute_comparison"),
73-
("app.tasks.delete_owner.DeleteOwner", "delete_owner"),
74-
("app.tasks.flush_repo.FlushRepo", "flush_repo"),
75-
("app.tasks.sync_plans.SyncPlans", "sync_plans"),
76-
("app.tasks.new_user_activated.NewUserActivated", "new_user_activated"),
77-
("app.tasks.notify.Notify", "notify"),
78-
("app.tasks.profiling.collection", "profiling"),
79-
("app.tasks.profiling.normalizer", "profiling"),
80-
("app.tasks.profiling.summarization", "profiling"),
81-
("app.tasks.pulls.Sync", "pulls"),
82-
("app.tasks.remove_webhook.RemoveOldHook", "remove_webhook"),
83-
("app.tasks.status.SetError", "status"),
84-
("app.tasks.status.SetPending", "status"),
85-
("app.tasks.sync_repos.SyncRepos", "sync_repos"),
86-
("app.tasks.sync_teams.SyncTeams", "sync_teams"),
87-
("app.tasks.synchronize.Synchronize", "synchronize"),
88-
("app.tasks.timeseries.backfill", "timeseries"),
89-
("app.tasks.timeseries.backfill_commits", "timeseries"),
90-
("app.tasks.timeseries.backfill_dataset", "timeseries"),
91-
("app.tasks.timeseries.delete", "timeseries"),
92-
("app.tasks.upload.Upload", "upload"),
93-
("app.tasks.upload.UploadProcessor", "upload"),
94-
("app.tasks.upload.UploadFinisher", "upload"),
66+
("app.cron.healthcheck.HealthCheckTask", TaskConfigGroup.healthcheck.value),
67+
("app.cron.profiling.findinguncollected", TaskConfigGroup.profiling.value),
68+
(
69+
"app.tasks.add_to_sendgrid_list.AddToSendgridList",
70+
TaskConfigGroup.add_to_sendgrid_list.value,
71+
),
72+
("app.tasks.archive.MigrateToArchive", TaskConfigGroup.archive.value),
73+
("app.tasks.verify_bot.VerifyBot", TaskConfigGroup.verify_bot.value),
74+
("app.tasks.comment.Comment", TaskConfigGroup.comment.value),
75+
("app.tasks.commit_update.CommitUpdate", TaskConfigGroup.commit_update.value),
76+
(
77+
"app.tasks.compute_comparison.ComputeComparison",
78+
TaskConfigGroup.compute_comparison.value,
79+
),
80+
("app.tasks.delete_owner.DeleteOwner", TaskConfigGroup.delete_owner.value),
81+
("app.tasks.flush_repo.FlushRepo", TaskConfigGroup.flush_repo.value),
82+
("app.tasks.sync_plans.SyncPlans", TaskConfigGroup.sync_plans.value),
83+
(
84+
"app.tasks.new_user_activated.NewUserActivated",
85+
TaskConfigGroup.new_user_activated.value,
86+
),
87+
("app.tasks.notify.Notify", TaskConfigGroup.notify.value),
88+
("app.tasks.profiling.collection", TaskConfigGroup.profiling.value),
89+
("app.tasks.profiling.normalizer", TaskConfigGroup.profiling.value),
90+
("app.tasks.profiling.summarization", TaskConfigGroup.profiling.value),
91+
("app.tasks.pulls.Sync", TaskConfigGroup.pulls.value),
92+
(
93+
"app.tasks.remove_webhook.RemoveOldHook",
94+
TaskConfigGroup.remove_webhook.value,
95+
),
96+
("app.tasks.status.SetError", TaskConfigGroup.status.value),
97+
("app.tasks.status.SetPending", TaskConfigGroup.status.value),
98+
("app.tasks.sync_repos.SyncRepos", TaskConfigGroup.sync_repos.value),
99+
("app.tasks.sync_teams.SyncTeams", TaskConfigGroup.sync_teams.value),
100+
("app.tasks.synchronize.Synchronize", TaskConfigGroup.synchronize.value),
101+
("app.tasks.timeseries.backfill", TaskConfigGroup.timeseries.value),
102+
("app.tasks.timeseries.backfill_commits", TaskConfigGroup.timeseries.value),
103+
("app.tasks.timeseries.backfill_dataset", TaskConfigGroup.timeseries.value),
104+
("app.tasks.timeseries.delete", TaskConfigGroup.timeseries.value),
105+
("app.tasks.upload.Upload", TaskConfigGroup.upload.value),
106+
("app.tasks.upload.UploadProcessor", TaskConfigGroup.upload.value),
107+
("app.tasks.upload.UploadFinisher", TaskConfigGroup.upload.value),
95108
("unknown.task", None),
96109
("app.tasks.legacy", None),
97110
],

tests/unit/test_router.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,44 @@
11
from shared.billing import BillingPlan
2-
from shared.celery_config import upload_task_name
2+
from shared.celery_config import timeseries_backfill_task_name, upload_task_name
33
from shared.celery_router import route_tasks_based_on_user_plan
44

55

66
def test_route_tasks_based_on_user_plan_defaults():
77
assert route_tasks_based_on_user_plan(
88
upload_task_name, BillingPlan.users_basic.db_name
9-
) == {"queue": "celery"}
9+
) == {"queue": "celery", "extra_config": {}}
1010
assert route_tasks_based_on_user_plan(
1111
upload_task_name, BillingPlan.enterprise_cloud_monthly.db_name
12-
) == {"queue": "enterprise_celery"}
12+
) == {"queue": "enterprise_celery", "extra_config": {}}
1313
assert route_tasks_based_on_user_plan(
1414
"misterious_task", BillingPlan.users_basic.db_name
15-
) == {"queue": "celery"}
15+
) == {"queue": "celery", "extra_config": {}}
1616
assert route_tasks_based_on_user_plan(
1717
"misterious_task", BillingPlan.enterprise_cloud_monthly.db_name
18-
) == {"queue": "enterprise_celery"}
18+
) == {"queue": "enterprise_celery", "extra_config": {}}
19+
20+
21+
def test_route_tasks_with_config(mock_configuration):
22+
mock_configuration._params["setup"] = {
23+
"tasks": {
24+
"celery": {"enterprise": {"soft_timelimit": 100, "hard_timelimit": 200}},
25+
"timeseries": {
26+
"enterprise": {"soft_timelimit": 400, "hard_timelimit": 500}
27+
},
28+
}
29+
}
30+
assert route_tasks_based_on_user_plan(
31+
upload_task_name, BillingPlan.users_basic.db_name
32+
) == {"queue": "celery", "extra_config": {}}
33+
assert route_tasks_based_on_user_plan(
34+
upload_task_name, BillingPlan.enterprise_cloud_monthly.db_name
35+
) == {
36+
"queue": "enterprise_celery",
37+
"extra_config": {"soft_timelimit": 100, "hard_timelimit": 200},
38+
}
39+
assert route_tasks_based_on_user_plan(
40+
timeseries_backfill_task_name, BillingPlan.enterprise_cloud_monthly.db_name
41+
) == {
42+
"queue": "enterprise_celery",
43+
"extra_config": {"soft_timelimit": 400, "hard_timelimit": 500},
44+
}

0 commit comments

Comments
 (0)