Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor PublicTask into a decorator task #4656

Merged
merged 2 commits into from Oct 2, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 0 additions & 1 deletion readthedocs/core/utils/tasks/__init__.py
Expand Up @@ -3,7 +3,6 @@
from .permission_checks import user_id_matches # noqa for unused import
from .public import PublicTask # noqa
from .public import TaskNoPermission # noqa
from .public import permission_check # noqa
from .public import get_public_task_data # noqa
from .retrieve import TaskNotFound # noqa
from .retrieve import get_task_data # noqa
75 changes: 43 additions & 32 deletions readthedocs/core/utils/tasks/public.py
@@ -1,16 +1,20 @@
"""Celery tasks with publicly viewable status"""

from __future__ import absolute_import
from __future__ import (
absolute_import,
division,
print_function,
unicode_literals,
)

from celery import Task, states
from django.conf import settings

from .retrieve import TaskNotFound
from .retrieve import get_task_data

from .retrieve import TaskNotFound, get_task_data

__all__ = (
'PublicTask', 'TaskNoPermission', 'permission_check',
'get_public_task_data')
'PublicTask', 'TaskNoPermission', 'get_public_task_data'
)


STATUS_UPDATES_ENABLED = not getattr(settings, 'CELERY_ALWAYS_EAGER', False)
Expand All @@ -19,22 +23,20 @@
class PublicTask(Task):

"""
See oauth.tasks for usage example.
Encapsulates common behaviour to expose a task publicly.

Subclasses need to define a ``run_public`` method.
"""
Tasks should use this class as ``base``. And define a ``check_permission``
property or use the ``permission_check`` decorator.

public_name = 'unknown'
The check_permission should be a function like:
function(request, state, context), and needs to return a boolean value.

@classmethod
def check_permission(cls, request, state, context):
"""Override this method to define who can monitor this task."""
# pylint: disable=unused-argument
return False
See oauth.tasks for usage example.
"""

def get_task_data(self):
"""Return tuple with state to be set next and results task."""
state = 'STARTED'
state = states.STARTED
info = {
'task_name': self.name,
'context': self.request.get('permission_context', {}),
Expand Down Expand Up @@ -66,12 +68,13 @@ def set_public_data(self, data):
self.request.update(public_data=data)
self.update_progress_data()

def run(self, *args, **kwargs):
def __call__(self, *args, **kwargs):
# We override __call__ to let tasks use the run method.
error = False
exception_raised = None
self.set_permission_context(kwargs)
try:
result = self.run_public(*args, **kwargs)
result = self.run(*args, **kwargs)
except Exception as e:
# With Celery 4 we lost the ability to keep our data dictionary into
# ``AsyncResult.info`` when an exception was raised inside the
Expand All @@ -90,22 +93,26 @@ def run(self, *args, **kwargs):

return info

@staticmethod
def permission_check(check):
"""
Decorator for tasks that have PublicTask as base.

def permission_check(check):
"""
Class decorator for subclasses of PublicTask to sprinkle in re-usable
.. note::

The decorator should be on top of the task decorator.

permission checks::
permission checks::

@permission_check(user_id_matches)
class MyTask(PublicTask):
def run_public(self, user_id):
@PublicTask.permission_check(user_id_matches)
@celery.task(base=PublicTask)
def my_public_task(user_id):
pass
"""
def decorator(cls):
cls.check_permission = staticmethod(check)
return cls
return decorator
"""
def decorator(func):
func.check_permission = check
return func
return decorator


class TaskNoPermission(Exception):
Expand Down Expand Up @@ -139,5 +146,9 @@ def get_public_task_data(request, task_id):
context = info.get('context', {})
if not task.check_permission(request, state, context):
raise TaskNoPermission(task_id)
public_name = task.public_name
return public_name, state, info.get('public_data', {}), info.get('error', None)
return (
task.name,
state,
info.get('public_data', {}),
info.get('error', None),
)
12 changes: 9 additions & 3 deletions readthedocs/core/utils/tasks/retrieve.py
@@ -1,9 +1,15 @@
"""Utilities for retrieving task data."""

from __future__ import absolute_import
from __future__ import (
absolute_import,
division,
print_function,
unicode_literals,
)

from celery import states
from celery.result import AsyncResult


__all__ = ('TaskNotFound', 'get_task_data')


Expand All @@ -23,7 +29,7 @@ def get_task_data(task_id):

result = AsyncResult(task_id)
state, info = result.state, result.info
if state == 'PENDING':
if state == states.PENDING:
raise TaskNotFound(task_id)
if 'task_name' not in info:
raise TaskNotFound(task_id)
Expand Down
5 changes: 0 additions & 5 deletions readthedocs/oauth/apps.py
Expand Up @@ -5,8 +5,3 @@

class OAuthConfig(AppConfig):
name = 'readthedocs.oauth'

def ready(self):
from .tasks import SyncRemoteRepositories
from readthedocs.worker import app
app.tasks.register(SyncRemoteRepositories)
35 changes: 16 additions & 19 deletions readthedocs/oauth/tasks.py
Expand Up @@ -2,17 +2,22 @@
"""Tasks for OAuth services."""

from __future__ import (
absolute_import, division, print_function, unicode_literals)
absolute_import,
division,
print_function,
unicode_literals,
)

import logging

from allauth.socialaccount.providers import registry as allauth_registry
from django.contrib.auth.models import User

from readthedocs.core.utils.tasks import (
PublicTask, permission_check, user_id_matches)
from readthedocs.core.utils.tasks import PublicTask, user_id_matches
from readthedocs.oauth.notifications import (
AttachWebhookNotification, InvalidProjectWebhookNotification)
AttachWebhookNotification,
InvalidProjectWebhookNotification,
)
from readthedocs.projects.models import Project
from readthedocs.worker import app

Expand All @@ -21,21 +26,13 @@
log = logging.getLogger(__name__)


@permission_check(user_id_matches)
class SyncRemoteRepositories(PublicTask):

name = __name__ + '.sync_remote_repositories'
public_name = 'sync_remote_repositories'
queue = 'web'

def run_public(self, user_id):
user = User.objects.get(pk=user_id)
for service_cls in registry:
for service in service_cls.for_user(user):
service.sync()


sync_remote_repositories = SyncRemoteRepositories()
@PublicTask.permission_check(user_id_matches)
@app.task(queue='web', base=PublicTask)
def sync_remote_repositories(user_id):
user = User.objects.get(pk=user_id)
for service_cls in registry:
for service in service_cls.for_user(user):
service.sync()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The api for defining a public task is cleaner now :D

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much nicer!



@app.task(queue='web')
Expand Down
29 changes: 19 additions & 10 deletions readthedocs/restapi/views/task_views.py
@@ -1,19 +1,23 @@
"""Endpoints relating to task/job status, etc."""

from __future__ import absolute_import
from __future__ import (
absolute_import,
division,
print_function,
unicode_literals,
)

import logging

from django.core.urlresolvers import reverse
from redis import ConnectionError
from rest_framework import decorators, permissions
from rest_framework.renderers import JSONRenderer
from rest_framework.response import Response
from redis import ConnectionError

from readthedocs.core.utils.tasks import TaskNoPermission
from readthedocs.core.utils.tasks import get_public_task_data
from readthedocs.core.utils.tasks import TaskNoPermission, get_public_task_data
from readthedocs.oauth import tasks


log = logging.getLogger(__name__)


Expand Down Expand Up @@ -43,20 +47,25 @@ def get_status_data(task_name, state, data, error=None):
@decorators.renderer_classes((JSONRenderer,))
def job_status(request, task_id):
try:
task_name, state, public_data, error = get_public_task_data(request, task_id)
task_name, state, public_data, error = get_public_task_data(
request, task_id
)
except (TaskNoPermission, ConnectionError):
return Response(
get_status_data('unknown', 'PENDING', {}))
get_status_data('unknown', 'PENDING', {})
)
return Response(
get_status_data(task_name, state, public_data, error))
get_status_data(task_name, state, public_data, error)
)


@decorators.api_view(['POST'])
@decorators.permission_classes((permissions.IsAuthenticated,))
@decorators.renderer_classes((JSONRenderer,))
def sync_remote_repositories(request):
result = tasks.SyncRemoteRepositories().delay(
user_id=request.user.id)
result = tasks.sync_remote_repositories.delay(
user_id=request.user.id
)
task_id = result.task_id
return Response({
'task_id': task_id,
Expand Down
12 changes: 4 additions & 8 deletions readthedocs/rtd_tests/tests/test_celery.py
Expand Up @@ -191,15 +191,11 @@ def test_public_task_exception(self):
from readthedocs.core.utils.tasks import PublicTask
from readthedocs.worker import app

class PublicTaskException(PublicTask):
name = 'public_task_exception'
@app.task(name='public_task_exception', base=PublicTask)
def public_task_exception():
raise Exception('Something bad happened')

def run_public(self):
raise Exception('Something bad happened')

app.tasks.register(PublicTaskException)
exception_task = PublicTaskException()
result = exception_task.delay()
result = public_task_exception.delay()

# although the task risen an exception, it's success since we add the
# exception into the ``info`` attributes
Expand Down