Skip to content

Commit

Permalink
Added some basic tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Ask Solem committed Sep 8, 2010
1 parent 09a67e2 commit 52f8a8c
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 19 deletions.
45 changes: 28 additions & 17 deletions flaskext/celery.py
@@ -1,16 +1,25 @@
# -*- coding: utf-8 -*-
"""
flaskext.celery
~~~~~~~~~~~~~~~
Celery integration for Flask.
:copyright: (c) 2010 Ask Solem <ask@celeryproject.org>
:license: BSD, see LICENSE for more details.
"""
from __future__ import absolute_import

import os

from flask import g
from functools import partial

from celery.datastructures import AttributeDict
from celery.loaders import default as _default
from celery.utils import get_full_cls_name




class FlaskLoader(_default.Loader):

def read_configuration(self):
Expand All @@ -29,31 +38,30 @@ def __init__(self, app):
from celery.conf import prepare
prepare(self.conf, AttributeDict(self.app.config))


def create_task_cls(self):
from celery.backends import default_backend, get_backend_cls
from celery.task.base import Task
conf = self.conf
defaults = self.conf

class BaseFlaskTask(Task):
abstract = True
app = self.app
ignore_result = conf.IGNORE_RESULT
serializer = conf.TASK_SERIALIZER
rate_limit = conf.DEFAULT_RATE_LIMIT
track_started = conf.TRACK_STARTED
acks_late = conf.ACKS_LATE
backend = get_backend_cls(conf.RESULT_BACKEND)()
ignore_result = defaults.IGNORE_RESULT
serializer = defaults.TASK_SERIALIZER
rate_limit = defaults.DEFAULT_RATE_LIMIT
track_started = defaults.TRACK_STARTED
acks_late = defaults.ACKS_LATE
backend = get_backend_cls(defaults.RESULT_BACKEND)()

@classmethod
def apply_async(self, *args, **kwargs):
def apply_async(cls, *args, **kwargs):
if not kwargs.get("connection") or kwargs.get("publisher"):
kwargs["connection"] = self.establish_connection(
kwargs["connection"] = cls.establish_connection(
connect_timeout=kwargs.get("connect_timeout"))
return super(BaseFlaskTask, self).apply_async(*args, **kwargs)
return super(BaseFlaskTask, cls).apply_async(*args, **kwargs)

@classmethod
def establish_connection(self, *args, **kwargs):
def establish_connection(cls, *args, **kwargs):
from celery.messaging import establish_connection
kwargs["defaults"] = conf
return establish_connection(*args, **kwargs)
Expand All @@ -62,8 +70,11 @@ def establish_connection(self, *args, **kwargs):

def task(self, *args, **kwargs):
from celery.decorators import task
kwargs.setdefault("base", self.create_task_cls())
return task(*args, **kwargs)
if len(args) == 1 and callable(args[0]):
return task(base=self.create_task_cls())(*args)
if "base" not in kwargs:
kwargs["base"] = self.create_task_cls()
return task(*args, **kwargs)

def Worker(self, **kwargs):
from celery.bin.celeryd import Worker
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Expand Up @@ -32,6 +32,6 @@
'Operating System :: OS Independent',
'Programming Language :: Python',
'Topic :: Internet :: WWW/HTTP :: Dynamic Content',
'Topic :: Software Development :: Libraries :: Python Modules'
]
'Topic :: Software Development :: Libraries :: Python Modules',
],
)
Empty file added tests/__init__.py
Empty file.
40 changes: 40 additions & 0 deletions tests/test_basic.py
@@ -0,0 +1,40 @@
import unittest2 as unittest

import flask

from flaskext import celery


class test_Celery(unittest.TestCase):

def setUp(self):
self.app = flask.Flask(__name__)
self.c = celery.Celery(self.app)

def test_loader_is_configured(self):
from celery.loaders import current_loader, load_settings
loader = current_loader()
self.assertIsInstance(loader, celery.FlaskLoader)
settings = load_settings()
self.assertTrue(loader.configured)

def test_task_honors_app_settings(self):
app = flask.Flask(__name__)
app.config["CELERY_IGNORE_RESULT"] = True
app.config["CELERY_TASK_SERIALIZER"] = "msgpack"
c = celery.Celery(app)

@c.task(foo=1)
def add_task_args(x, y):
return x + y

@c.task
def add_task_noargs(x, y):
return x + y

for task in add_task_args, add_task_noargs:
self.assertTrue(any("BaseFlaskTask" in repr(cls)
for cls in task.__class__.mro()))
self.assertEqual(task(2, 2), 4)
self.assertEqual(task.serializer, "msgpack")
self.assertTrue(task.ignore_result)

0 comments on commit 52f8a8c

Please sign in to comment.