Skip to content

Commit

Permalink
set_token_storage_sqlalchemy()
Browse files Browse the repository at this point in the history
  • Loading branch information
singingwolfboy committed Oct 11, 2014
1 parent 3bd8462 commit ba9b02e
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 29 deletions.
46 changes: 42 additions & 4 deletions flask_dance/consumer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self, name, import_name,
view_func=self.authorized,
)

self.create_token_accessors()
self.set_token_storage_session()
self.logged_in_funcs = []
self.before_app_request(self.assign_token_to_session)

Expand Down Expand Up @@ -91,17 +91,55 @@ def token_deleter(self, func):
"""
self.delete_token = func

def create_token_accessors(self):
def set_token_storage_session(self):
key = "{name}_oauth_token".format(name=self.name)

@self.token_getter
def get_token():
return flask.session.get(key)

@self.token_setter
def set_token(value):
flask.session[key] = value
def set_token(token):
flask.session[key] = token

@self.token_deleter
def delete_token():
del flask.session[key]

def set_token_storage_sqlalchemy(self, model, session, user=None):
"""
Set token accessors to work with a SQLAlchemy database for token
storage/retrieval.
"""
from sqlalchemy.orm.exc import NoResultFound

@self.token_getter
def get_token():
query = session.query(model).filter_by(provider=self.name)
if hasattr(model, "user"):
u = user() if callable(user) else user
query = query.filter_by(user=u)
try:
return query.one().token
except NoResultFound:
return None

@self.token_setter
def set_token(token):
kwargs = {
"provider": self.name,
"token": token,
}
if hasattr(model, "user"):
u = user() if callable(user) else user
kwargs["user"] = u
session.add(model(**kwargs))
session.commit()

@self.token_deleter
def delete_token():
query = session.query(model).filter_by(provider=self.name)
if hasattr(model, "user"):
u = user() if callable(user) else user
query = query.filter_by(user=u)
query.delete()
85 changes: 60 additions & 25 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,65 @@ def test_model(request):
app.config["SQLALCHEMY_DATABASE_URI"] = os.environ["DATABASE_URI"]
db = SQLAlchemy(app)

class OAuth(db.Model, OAuthMixin):
pass

blueprint = OAuth2ConsumerBlueprint("test-service", __name__,
client_id="client_id",
client_secret="client_secret",
state="random-string",
base_url="https://example.com",
authorization_url="https://example.com/oauth/authorize",
token_url="https://example.com/oauth/access_token",
redirect_url="/oauth_done",
)
blueprint.set_token_storage_sqlalchemy(OAuth, db.session)
app.secret_key = "secret"
app.register_blueprint(blueprint, url_prefix="/login")

db.create_all()
def done():
db.session.remove()
db.drop_all()
request.addfinalizer(done)

responses.add(
responses.POST,
"https://example.com/oauth/access_token",
body='{"access_token":"foobar","token_type":"bearer","scope":""}',
)
with app.test_client() as client:
# reset the session before the request
with client.session_transaction() as sess:
sess["test-service_oauth_state"] = "random-string"
# make the request
resp = client.get(
"/login/test-service/authorized?code=secret-code&state=random-string",
base_url="https://a.b.c",
)
# check that we redirected the client
assert resp.status_code == 302
assert resp.headers["Location"] == "https://a.b.c/oauth_done"

# check the database
authorizations = OAuth.query.all()
assert len(authorizations) == 1
oauth = authorizations[0]
assert oauth.provider == "test-service"
assert isinstance(oauth.token, dict)
assert oauth.token == {
"access_token": "foobar",
"token_type": "bearer",
"scope": [""],
}


@pytest.mark.usefixtures("responses")
def test_model_with_user(request):
app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = os.environ["DATABASE_URI"]
db = SQLAlchemy(app)

class User(db.Model):
id = db.Column(db.Integer, primary_key=True)
name = db.Column(db.String(80))
Expand All @@ -34,6 +93,7 @@ class OAuth(db.Model, OAuthMixin):
token_url="https://example.com/oauth/access_token",
redirect_url="/oauth_done",
)
blueprint.set_token_storage_sqlalchemy(OAuth, db.session, lambda: User.query.first())
app.secret_key = "secret"
app.register_blueprint(blueprint, url_prefix="/login")

Expand All @@ -48,31 +108,6 @@ def done():
db.session.add(alice)
db.session.commit()

@blueprint.token_setter
def set_token(token):
alice = User.query.first()
oauth = OAuth(
provider=blueprint.name,
token=token,
user=alice,
)
db.session.add(oauth)
db.session.commit()

@blueprint.token_getter
def get_token():
alice = User.query.first()
query = OAuth.query.filter_by(provider=blueprint.name, user=alice)
try:
return query.one().token
except NoResultFound:
return None

@blueprint.token_deleter
def delete_token():
alice = User.query.first()
OAuth.query.filter_by(provider=blueprint.name, user=alice).delete()

responses.add(
responses.POST,
"https://example.com/oauth/access_token",
Expand Down

0 comments on commit ba9b02e

Please sign in to comment.