Skip to content

Commit

Permalink
fix(saml): Respect SOCIALACCOUNT_LOGIN_ON_GET
Browse files Browse the repository at this point in the history
  • Loading branch information
pennersr committed Jan 15, 2024
1 parent 0c49379 commit 3b65b11
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 14 deletions.
17 changes: 4 additions & 13 deletions allauth/socialaccount/providers/base/mixins.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,12 @@
from django.shortcuts import render

from allauth.account import app_settings as account_app_settings
from allauth.socialaccount import app_settings
from allauth.socialaccount.providers.base.utils import respond_to_login_on_get


class OAuthLoginMixin:
def dispatch(self, request, *args, **kwargs):
provider = self.adapter.get_provider()
if (not app_settings.LOGIN_ON_GET) and request.method == "GET":
return render(
request,
"socialaccount/login." + account_app_settings.TEMPLATE_EXTENSION,
{
"provider": provider,
"process": request.GET.get("process"),
},
)
resp = respond_to_login_on_get(request, provider)
if resp:
return resp
return self.login(request, *args, **kwargs)

def login(self, request, *args, **kwargs):
Expand Down
16 changes: 16 additions & 0 deletions allauth/socialaccount/providers/base/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from django.shortcuts import render

from allauth.account import app_settings as account_app_settings
from allauth.socialaccount import app_settings


def respond_to_login_on_get(request, provider):
if (not app_settings.LOGIN_ON_GET) and request.method == "GET":
return render(
request,
"socialaccount/login." + account_app_settings.TEMPLATE_EXTENSION,
{
"provider": provider,
"process": request.GET.get("process"),
},
)
9 changes: 8 additions & 1 deletion allauth/socialaccount/providers/saml/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from django.utils.http import urlencode

import pytest
from pytest_django.asserts import assertTemplateUsed

from allauth.account.models import EmailAddress
from allauth.socialaccount.adapter import get_adapter
Expand Down Expand Up @@ -79,6 +80,12 @@ def test_sls_get(client, db, saml_settings):
assert resp.status_code == 400


def test_login_on_get(client, db, saml_settings):
resp = client.get(reverse("saml_login", kwargs={"organization_slug": "org"}))
assert resp.status_code == 200
assertTemplateUsed(resp, "socialaccount/login.html")


@pytest.mark.parametrize(
"query,expected_relay_state",
[
Expand All @@ -89,7 +96,7 @@ def test_sls_get(client, db, saml_settings):
],
)
def test_login(client, db, saml_settings, query, expected_relay_state):
resp = client.get(
resp = client.post(
reverse("saml_login", kwargs={"organization_slug": "org"}) + query
)
assert resp.status_code == 302
Expand Down
4 changes: 4 additions & 0 deletions allauth/socialaccount/providers/saml/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from allauth.socialaccount.models import SocialLogin
from allauth.socialaccount.providers.base.constants import AuthError
from allauth.socialaccount.providers.base.utils import respond_to_login_on_get
from allauth.socialaccount.sessions import LoginSession

from .utils import (
Expand Down Expand Up @@ -185,6 +186,9 @@ def dispatch(self, request, organization_slug):
class LoginView(SAMLViewMixin, View):
def dispatch(self, request, organization_slug):
provider = self.get_provider(organization_slug)
resp = respond_to_login_on_get(request, provider)
if resp:
return resp
auth = self.build_auth(provider, organization_slug)
process = self.request.GET.get("process")
next_url = get_next_redirect_url(request)
Expand Down

0 comments on commit 3b65b11

Please sign in to comment.