Skip to content

Commit

Permalink
Include queue extra config in task_router (#361)
Browse files Browse the repository at this point in the history
* Include queue extra config in `task_router`

We now want to have queue-specific config for enterprise-queues.
To get that specific config we need the task name, but it should only be applied if using enterprise-queue.

So a natural way to put it is when we get the queue for a task, we also get any extra config for it.

Putting this in `shared` to make sure it's consistent across `worker` and `api`.
  • Loading branch information
giovanni-guidini authored Mar 13, 2023
1 parent 6478226 commit 4b845db
Showing 4 changed files with 92 additions and 39 deletions.
1 change: 0 additions & 1 deletion shared/celery_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# http://docs.celeryq.org/en/latest/configuration.html#configuration

from typing import Optional

from shared.config import get_config
21 changes: 18 additions & 3 deletions shared/celery_router.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,30 @@
from shared.billing import BillingPlan, is_enterprise_cloud_plan
from shared.celery_config import BaseCeleryConfig
from shared.celery_config import BaseCeleryConfig, get_task_group
from shared.config import get_config


def route_tasks_based_on_user_plan(task_name: str, user_plan: str):
"""Helper function to dynamically route tasks based on the user plan.
This cannot be used as a celery router function directly.
Returns extra config for the queue, if any.
"""
default_task_queue = BaseCeleryConfig.task_routes.get(
task_name, dict(queue=BaseCeleryConfig.task_default_queue)
)["queue"]
billing_plan = BillingPlan.from_str(user_plan)
if is_enterprise_cloud_plan(billing_plan):
return {"queue": "enterprise_" + default_task_queue}
return {"queue": default_task_queue}
default_enterprise_queue_specific_config = get_config(
"setup", "tasks", "celery", "enterprise", default=dict()
)
this_queue_specific_config = get_config(
"setup",
"tasks",
get_task_group(task_name),
"enterprise",
default=default_enterprise_queue_specific_config,
)
return {
"queue": "enterprise_" + default_task_queue,
"extra_config": this_queue_specific_config,
}
return {"queue": default_task_queue, "extra_config": {}}
73 changes: 43 additions & 30 deletions tests/unit/test_celery_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

import shared.celery_config as celery_config
from shared.utils.enums import TaskConfigGroup


def test_celery_config():
@@ -62,36 +63,48 @@ def test_celery_config():
@pytest.mark.parametrize(
"task_name,task_group",
[
("app.cron.healthcheck.HealthCheckTask", "healthcheck"),
("app.cron.profiling.findinguncollected", "profiling"),
("app.tasks.add_to_sendgrid_list.AddToSendgridList", "add_to_sendgrid_list"),
("app.tasks.archive.MigrateToArchive", "archive"),
("app.tasks.verify_bot.VerifyBot", "verify_bot"),
("app.tasks.comment.Comment", "comment"),
("app.tasks.commit_update.CommitUpdate", "commit_update"),
("app.tasks.compute_comparison.ComputeComparison", "compute_comparison"),
("app.tasks.delete_owner.DeleteOwner", "delete_owner"),
("app.tasks.flush_repo.FlushRepo", "flush_repo"),
("app.tasks.sync_plans.SyncPlans", "sync_plans"),
("app.tasks.new_user_activated.NewUserActivated", "new_user_activated"),
("app.tasks.notify.Notify", "notify"),
("app.tasks.profiling.collection", "profiling"),
("app.tasks.profiling.normalizer", "profiling"),
("app.tasks.profiling.summarization", "profiling"),
("app.tasks.pulls.Sync", "pulls"),
("app.tasks.remove_webhook.RemoveOldHook", "remove_webhook"),
("app.tasks.status.SetError", "status"),
("app.tasks.status.SetPending", "status"),
("app.tasks.sync_repos.SyncRepos", "sync_repos"),
("app.tasks.sync_teams.SyncTeams", "sync_teams"),
("app.tasks.synchronize.Synchronize", "synchronize"),
("app.tasks.timeseries.backfill", "timeseries"),
("app.tasks.timeseries.backfill_commits", "timeseries"),
("app.tasks.timeseries.backfill_dataset", "timeseries"),
("app.tasks.timeseries.delete", "timeseries"),
("app.tasks.upload.Upload", "upload"),
("app.tasks.upload.UploadProcessor", "upload"),
("app.tasks.upload.UploadFinisher", "upload"),
("app.cron.healthcheck.HealthCheckTask", TaskConfigGroup.healthcheck.value),
("app.cron.profiling.findinguncollected", TaskConfigGroup.profiling.value),
(
"app.tasks.add_to_sendgrid_list.AddToSendgridList",
TaskConfigGroup.add_to_sendgrid_list.value,
),
("app.tasks.archive.MigrateToArchive", TaskConfigGroup.archive.value),
("app.tasks.verify_bot.VerifyBot", TaskConfigGroup.verify_bot.value),
("app.tasks.comment.Comment", TaskConfigGroup.comment.value),
("app.tasks.commit_update.CommitUpdate", TaskConfigGroup.commit_update.value),
(
"app.tasks.compute_comparison.ComputeComparison",
TaskConfigGroup.compute_comparison.value,
),
("app.tasks.delete_owner.DeleteOwner", TaskConfigGroup.delete_owner.value),
("app.tasks.flush_repo.FlushRepo", TaskConfigGroup.flush_repo.value),
("app.tasks.sync_plans.SyncPlans", TaskConfigGroup.sync_plans.value),
(
"app.tasks.new_user_activated.NewUserActivated",
TaskConfigGroup.new_user_activated.value,
),
("app.tasks.notify.Notify", TaskConfigGroup.notify.value),
("app.tasks.profiling.collection", TaskConfigGroup.profiling.value),
("app.tasks.profiling.normalizer", TaskConfigGroup.profiling.value),
("app.tasks.profiling.summarization", TaskConfigGroup.profiling.value),
("app.tasks.pulls.Sync", TaskConfigGroup.pulls.value),
(
"app.tasks.remove_webhook.RemoveOldHook",
TaskConfigGroup.remove_webhook.value,
),
("app.tasks.status.SetError", TaskConfigGroup.status.value),
("app.tasks.status.SetPending", TaskConfigGroup.status.value),
("app.tasks.sync_repos.SyncRepos", TaskConfigGroup.sync_repos.value),
("app.tasks.sync_teams.SyncTeams", TaskConfigGroup.sync_teams.value),
("app.tasks.synchronize.Synchronize", TaskConfigGroup.synchronize.value),
("app.tasks.timeseries.backfill", TaskConfigGroup.timeseries.value),
("app.tasks.timeseries.backfill_commits", TaskConfigGroup.timeseries.value),
("app.tasks.timeseries.backfill_dataset", TaskConfigGroup.timeseries.value),
("app.tasks.timeseries.delete", TaskConfigGroup.timeseries.value),
("app.tasks.upload.Upload", TaskConfigGroup.upload.value),
("app.tasks.upload.UploadProcessor", TaskConfigGroup.upload.value),
("app.tasks.upload.UploadFinisher", TaskConfigGroup.upload.value),
("unknown.task", None),
("app.tasks.legacy", None),
],
36 changes: 31 additions & 5 deletions tests/unit/test_router.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,44 @@
from shared.billing import BillingPlan
from shared.celery_config import upload_task_name
from shared.celery_config import timeseries_backfill_task_name, upload_task_name
from shared.celery_router import route_tasks_based_on_user_plan


def test_route_tasks_based_on_user_plan_defaults():
assert route_tasks_based_on_user_plan(
upload_task_name, BillingPlan.users_basic.db_name
) == {"queue": "celery"}
) == {"queue": "celery", "extra_config": {}}
assert route_tasks_based_on_user_plan(
upload_task_name, BillingPlan.enterprise_cloud_monthly.db_name
) == {"queue": "enterprise_celery"}
) == {"queue": "enterprise_celery", "extra_config": {}}
assert route_tasks_based_on_user_plan(
"misterious_task", BillingPlan.users_basic.db_name
) == {"queue": "celery"}
) == {"queue": "celery", "extra_config": {}}
assert route_tasks_based_on_user_plan(
"misterious_task", BillingPlan.enterprise_cloud_monthly.db_name
) == {"queue": "enterprise_celery"}
) == {"queue": "enterprise_celery", "extra_config": {}}


def test_route_tasks_with_config(mock_configuration):
mock_configuration._params["setup"] = {
"tasks": {
"celery": {"enterprise": {"soft_timelimit": 100, "hard_timelimit": 200}},
"timeseries": {
"enterprise": {"soft_timelimit": 400, "hard_timelimit": 500}
},
}
}
assert route_tasks_based_on_user_plan(
upload_task_name, BillingPlan.users_basic.db_name
) == {"queue": "celery", "extra_config": {}}
assert route_tasks_based_on_user_plan(
upload_task_name, BillingPlan.enterprise_cloud_monthly.db_name
) == {
"queue": "enterprise_celery",
"extra_config": {"soft_timelimit": 100, "hard_timelimit": 200},
}
assert route_tasks_based_on_user_plan(
timeseries_backfill_task_name, BillingPlan.enterprise_cloud_monthly.db_name
) == {
"queue": "enterprise_celery",
"extra_config": {"soft_timelimit": 400, "hard_timelimit": 500},
}

0 comments on commit 4b845db

Please sign in to comment.