From 0c493797b2f279ee0eaa0b5d4b186a2120cbbb40 Mon Sep 17 00:00:00 2001 From: Raymond Penners Date: Mon, 15 Jan 2024 19:34:37 +0100 Subject: [PATCH] fix(saml): Handle wrong methods at acs/sls --- allauth/socialaccount/providers/saml/tests.py | 13 +++++++ allauth/socialaccount/providers/saml/views.py | 36 +++++++++++++++---- 2 files changed, 42 insertions(+), 7 deletions(-) diff --git a/allauth/socialaccount/providers/saml/tests.py b/allauth/socialaccount/providers/saml/tests.py index 2c0a66fdb8..32525e80d6 100644 --- a/allauth/socialaccount/providers/saml/tests.py +++ b/allauth/socialaccount/providers/saml/tests.py @@ -66,6 +66,19 @@ def test_acs_error(client, db, saml_settings): assert "socialaccount/authentication_error.html" in (t.name for t in resp.templates) +def test_acs_get(client, db, saml_settings): + """ACS expects POST""" + resp = client.get(reverse("saml_acs", kwargs={"organization_slug": "org"})) + assert resp.status_code == 200 + assert "socialaccount/authentication_error.html" in (t.name for t in resp.templates) + + +def test_sls_get(client, db, saml_settings): + """SLS expects POST""" + resp = client.get(reverse("saml_sls", kwargs={"organization_slug": "org"})) + assert resp.status_code == 400 + + @pytest.mark.parametrize( "query,expected_relay_state", [ diff --git a/allauth/socialaccount/providers/saml/views.py b/allauth/socialaccount/providers/saml/views.py index 7adb359498..6a863fb77d 100644 --- a/allauth/socialaccount/providers/saml/views.py +++ b/allauth/socialaccount/providers/saml/views.py @@ -8,6 +8,7 @@ from django.views.decorators.csrf import csrf_exempt from onelogin.saml2.auth import OneLogin_Saml2_Auth, OneLogin_Saml2_Settings +from onelogin.saml2.errors import OneLogin_Saml2_Error from allauth.account.adapter import get_adapter as get_account_adapter from allauth.account.utils import get_next_redirect_url @@ -54,24 +55,30 @@ class ACSView(SAMLViewMixin, View): def dispatch(self, request, organization_slug): provider = self.get_provider(organization_slug) auth = self.build_auth(provider, organization_slug) + error_reason = None + errors = [] try: auth.process_response() except binascii.Error: errors = ["invalid_response"] - else: + error_reason = "Invalid response" + except OneLogin_Saml2_Error as e: + error_reason = str(e) + if not errors: errors = auth.get_errors() if errors: # e.g. ['invalid_response'] + error_reason = auth.get_last_error_reason() or error_reason logger.error( - "Error processing SAML response: %s: %s" - % (", ".join(errors), auth.get_last_error_reason()) + "Error processing SAML ACS response: %s: %s" + % (", ".join(errors), error_reason) ) return render_authentication_error( request, provider, extra_context={ "saml_errors": errors, - "saml_last_error_reason": auth.get_last_error_reason(), + "saml_last_error_reason": error_reason, }, ) if not auth.is_authenticated(): @@ -126,9 +133,24 @@ def dispatch(self, request, organization_slug): def force_logout(): account_adapter.logout(request) - redirect_to = auth.process_slo( - delete_session_cb=force_logout, keep_local_session=not should_logout - ) + redirect_to = None + error_reason = None + try: + redirect_to = auth.process_slo( + delete_session_cb=force_logout, keep_local_session=not should_logout + ) + except OneLogin_Saml2_Error as e: + error_reason = str(e) + errors = auth.get_errors() + if errors: + error_reason = auth.get_last_error_reason() or error_reason + logger.error( + "Error processing SAML SLS response: %s: %s" + % (", ".join(errors), error_reason) + ) + resp = HttpResponse(error_reason, content_type="text/plain") + resp.status_code = 400 + return resp if not redirect_to: redirect_to = account_adapter.get_logout_redirect_url(request) return HttpResponseRedirect(redirect_to)