diff --git a/project/tests/utils.py b/project/tests/utils.py index 974a06a..b74c7cb 100644 --- a/project/tests/utils.py +++ b/project/tests/utils.py @@ -1,3 +1,4 @@ +from contextlib import contextmanager from flask import url_for from flask.ext.testing import TestCase from project import create_app as app_factory @@ -6,15 +7,11 @@ from project.models import User -class DisableCsrf: - def __init__(self, app): - self.app = app - - def __enter__(self): - self.app.config['WTF_CSRF_ENABLED'] = False - - def __exit__(self, exc_type, exc_val, exc_tb): - self.app.config['WTF_CSRF_ENABLED'] = True +@contextmanager +def disable_csrf(app): + app.config['WTF_CSRF_ENABLED'] = False + yield + app.config['WTF_CSRF_ENABLED'] = True class ProjectTestCase(TestCase): @@ -25,7 +22,7 @@ def log_in(self): user = User.query.get(1) # assume that login is equal to password credentials = {"login": user.login, "password": user.login} - with DisableCsrf(self.app): + with disable_csrf(self.app): self.client.post(url_for("auth.login"), data=credentials) # noinspection PyAttributeOutsideInit