Skip to content

Commit

Permalink
refactor: Proper ?next= redirect handling throughout
Browse files Browse the repository at this point in the history
  • Loading branch information
pennersr committed Apr 12, 2024
1 parent fb13bc5 commit 26832f0
Show file tree
Hide file tree
Showing 19 changed files with 424 additions and 340 deletions.
178 changes: 178 additions & 0 deletions allauth/account/mixins.py
@@ -0,0 +1,178 @@
from django.contrib.auth import REDIRECT_FIELD_NAME
from django.core.exceptions import ImproperlyConfigured
from django.http import HttpResponsePermanentRedirect, HttpResponseRedirect
from django.utils.html import format_html

from allauth.account import app_settings
from allauth.account.adapter import get_adapter
from allauth.account.internal import flows
from allauth.account.utils import (
get_login_redirect_url,
get_next_redirect_url,
passthrough_next_redirect_url,
)
from allauth.core.exceptions import ImmediateHttpResponse
from allauth.utils import get_request_param


def _ajax_response(request, response, form=None, data=None):
adapter = get_adapter()
if adapter.is_ajax(request):
if isinstance(response, HttpResponseRedirect) or isinstance(
response, HttpResponsePermanentRedirect
):
redirect_to = response["Location"]
else:
redirect_to = None
response = adapter.ajax_response(
request, response, form=form, data=data, redirect_to=redirect_to
)
return response


class RedirectAuthenticatedUserMixin:
def dispatch(self, request, *args, **kwargs):
if request.user.is_authenticated and app_settings.AUTHENTICATED_LOGIN_REDIRECTS:
redirect_to = self.get_authenticated_redirect_url()
response = HttpResponseRedirect(redirect_to)
return _ajax_response(request, response)
else:
response = super().dispatch(request, *args, **kwargs)
return response

def get_authenticated_redirect_url(self):
redirect_field_name = self.redirect_field_name
return get_login_redirect_url(
self.request,
url=self.get_success_url(),
redirect_field_name=redirect_field_name,
)


class LogoutFunctionalityMixin:
def logout(self):
flows.logout.logout(self.request)


class AjaxCapableProcessFormViewMixin:
def get(self, request, *args, **kwargs):
response = super().get(request, *args, **kwargs)
form = self.get_form()
return _ajax_response(
self.request, response, form=form, data=self._get_ajax_data_if()
)

def post(self, request, *args, **kwargs):
form_class = self.get_form_class()
form = self.get_form(form_class)
if form.is_valid():
response = self.form_valid(form)
else:
response = self.form_invalid(form)
return _ajax_response(
self.request, response, form=form, data=self._get_ajax_data_if()
)

def get_form(self, form_class=None):
form = getattr(self, "_cached_form", None)
if form is None:
form = super().get_form(form_class)
self._cached_form = form
return form

def _get_ajax_data_if(self):
return (
self.get_ajax_data()
if get_adapter(self.request).is_ajax(self.request)
else None
)

def get_ajax_data(self):
return None


class CloseableSignupMixin:
template_name_signup_closed = (
"account/signup_closed." + app_settings.TEMPLATE_EXTENSION
)

def dispatch(self, request, *args, **kwargs):
try:
if not self.is_open():
return self.closed()
except ImmediateHttpResponse as e:
return e.response
return super().dispatch(request, *args, **kwargs)

def is_open(self):
return get_adapter(self.request).is_open_for_signup(self.request)

def closed(self):
response_kwargs = {
"request": self.request,
"template": self.template_name_signup_closed,
}
return self.response_class(**response_kwargs)


class NextRedirectMixin:
redirect_field_name = REDIRECT_FIELD_NAME

def get_context_data(self, **kwargs):
ret = super().get_context_data(**kwargs)
redirect_field_value = get_request_param(self.request, self.redirect_field_name)
ret.update(
{
"redirect_field_name": self.redirect_field_name,
"redirect_field_value": redirect_field_value,
"redirect_field": format_html(
'<input type="hidden" name="{}" value="{}">',
self.redirect_field_name,
redirect_field_value,
)
if redirect_field_value
else "",
}
)
return ret

def get_success_url(self):
"""
We're in a mixin, so we cannot rely on the fact that our super() has a get_success_url.
Also, we want to check for -- in this order:
1) The `?next=/foo`
2) The `get_succes_url()` if available.
3) The `.success_url` if available.
4) A fallback default success URL: `get_default_success_url()`.
"""
url = self.get_next_url()
if url:
return url

if not url:
if hasattr(super(), "get_success_url"):
try:
url = super().get_success_url()
except ImproperlyConfigured:
# Django's default get_success_url() checks self.succes_url,
# and throws this if that is not set. Yet, in our case, we
# want to fallback to the default.
pass
elif hasattr(self, "success_url"):
url = self.success_url
if url:
url = str(url) # reverse_lazy
if not url:
url = self.get_default_success_url()
return url

def get_default_success_url(self):
return None

def get_next_url(self):
return get_next_redirect_url(self.request, self.redirect_field_name)

def passthrough_next_url(self, url):
return passthrough_next_redirect_url(
self.request, url, self.redirect_field_name
)
39 changes: 22 additions & 17 deletions allauth/account/tests/test_change_password.py
Expand Up @@ -19,16 +19,13 @@ def test_set_usable_password_redirects_to_change(auth_client, user):


@pytest.mark.parametrize(
"logout,redirect_chain",
"logout,next_url,redirect_chain",
[
(
False,
[
(reverse("account_change_password"), 302),
],
),
(False, "", [(reverse("account_change_password"), 302)]),
(False, "/foo", [("/foo", 302)]),
(
True,
"",
[
(reverse("account_change_password"), 302),
(
Expand All @@ -37,33 +34,36 @@ def test_set_usable_password_redirects_to_change(auth_client, user):
),
],
),
(True, "/foo", [("/foo", 302)]),
],
)
def test_set_password(client, user, password_factory, logout, settings, redirect_chain):
def test_set_password(
client, user, next_url, password_factory, logout, settings, redirect_chain
):
settings.ACCOUNT_LOGOUT_ON_PASSWORD_CHANGE = logout
user.set_unusable_password()
user.save()
client.force_login(user)
password = password_factory()
data = {"password1": password, "password2": password}
if next_url:
data["next"] = next_url
resp = client.post(
reverse("account_set_password"),
{"password1": password, "password2": password},
data,
follow=True,
)
assert resp.redirect_chain == redirect_chain


@pytest.mark.parametrize(
"logout,redirect_chain",
"logout,next_url,redirect_chain",
[
(
False,
[
(reverse("account_change_password"), 302),
],
),
(False, "", [(reverse("account_change_password"), 302)]),
(False, "/foo", [("/foo", 302)]),
(
True,
"",
[
(reverse("account_change_password"), 302),
(
Expand All @@ -72,12 +72,14 @@ def test_set_password(client, user, password_factory, logout, settings, redirect
),
],
),
(True, "/foo", [("/foo", 302)]),
],
)
def test_change_password(
auth_client,
user,
user_password,
next_url,
password_factory,
logout,
settings,
Expand All @@ -87,9 +89,12 @@ def test_change_password(
settings.ACCOUNT_LOGOUT_ON_PASSWORD_CHANGE = logout
settings.ACCOUNT_EMAIL_NOTIFICATIONS = True
password = password_factory()
data = {"oldpassword": user_password, "password1": password, "password2": password}
if next_url:
data["next"] = next_url
resp = auth_client.post(
reverse("account_change_password"),
{"oldpassword": user_password, "password1": password, "password2": password},
data,
follow=True,
)
assert resp.redirect_chain == redirect_chain
Expand Down
13 changes: 11 additions & 2 deletions allauth/account/tests/test_confirm_email.py
Expand Up @@ -8,6 +8,7 @@
from django.urls import reverse
from django.utils.timezone import now

import pytest
from pytest_django.asserts import (
assertRedirects,
assertTemplateNotUsed,
Expand All @@ -26,7 +27,14 @@
from .test_models import UUIDUser


def test_login_on_confirm(user_factory, client):
@pytest.mark.parametrize(
"query,expected_location",
[
("", settings.LOGIN_REDIRECT_URL),
("?next=/foo", "/foo"),
],
)
def test_login_on_confirm(user_factory, client, query, expected_location):
settings.ACCOUNT_EMAIL_CONFIRMATION_HMAC = True
settings.ACCOUNT_LOGIN_ON_EMAIL_CONFIRMATION = True
user = user_factory(email_verified=False)
Expand All @@ -41,7 +49,8 @@ def test_login_on_confirm(user_factory, client):
session["account_user"] = user_pk_to_url_str(user)
session.save()

resp = client.post(reverse("account_confirm_email", args=[key]))
resp = client.post(reverse("account_confirm_email", args=[key]) + query)
assert resp["location"] == expected_location
email = EmailAddress.objects.get(pk=email.pk)
assert email.verified

Expand Down

0 comments on commit 26832f0

Please sign in to comment.