From deba82ec696cc6ddb4d8eda435bf91d0796ec4c6 Mon Sep 17 00:00:00 2001 From: Ruben Grill Date: Sat, 1 Sep 2018 23:00:46 +0200 Subject: [PATCH] feat: Support next in rest --- nopassword/forms.py | 16 ++++++++++------ nopassword/rest/serializers.py | 3 ++- nopassword/rest/views.py | 6 ++++-- nopassword/views.py | 5 +++++ tests/test_rest_views.py | 17 ++++++++++++++--- tests/test_views.py | 18 +++++++++++++----- 6 files changed, 48 insertions(+), 17 deletions(-) diff --git a/nopassword/forms.py b/nopassword/forms.py index adedbba..16633fe 100644 --- a/nopassword/forms.py +++ b/nopassword/forms.py @@ -18,6 +18,8 @@ class LoginForm(forms.Form): 'inactive': _("This account is inactive."), } + next = forms.CharField(max_length=200, required=False, widget=forms.HiddenInput) + def __init__(self, *args, **kwargs): super(LoginForm, self).__init__(*args, **kwargs) @@ -42,14 +44,15 @@ def clean_username(self): code='inactive', ) - self.cleaned_data['login_code'] = models.LoginCode.create_code_for_user(user) + self.cleaned_data['user'] = user return username def save(self, request, login_code_url='login_code', domain_override=None, extra_context=None): - login_code = self.cleaned_data['login_code'] - login_code.next = request.GET.get('next') - login_code.save() + login_code = models.LoginCode.create_code_for_user( + user=self.cleaned_data['user'], + next=self.cleaned_data['next'], + ) if not domain_override: current_site = get_current_site(request) @@ -58,12 +61,11 @@ def save(self, request, login_code_url='login_code', domain_override=None, extra else: site_name = domain = domain_override - url = '{}://{}{}?code={}&next={}'.format( + url = '{}://{}{}?code={}'.format( 'https' if request.is_secure() else 'http', domain, resolve_url(login_code_url), login_code.code, - login_code.next, ) context = { @@ -78,6 +80,8 @@ def save(self, request, login_code_url='login_code', domain_override=None, extra self.send_login_code(login_code, context) + return login_code + def send_login_code(self, login_code, context, **kwargs): for backend in get_backends(): if hasattr(backend, 'send_login_code'): diff --git a/nopassword/rest/serializers.py b/nopassword/rest/serializers.py index edf5812..de27857 100644 --- a/nopassword/rest/serializers.py +++ b/nopassword/rest/serializers.py @@ -7,6 +7,7 @@ class LoginSerializer(serializers.Serializer): username = serializers.CharField() + next = serializers.CharField(required=False, allow_null=True) form_class = forms.LoginForm @@ -20,7 +21,7 @@ def validate(self, data): def save(self): request = self.context.get('request') - self.form.save(request=request) + return self.form.save(request=request) class LoginCodeSerializer(serializers.Serializer): diff --git a/nopassword/rest/views.py b/nopassword/rest/views.py index 11250ab..780eb59 100644 --- a/nopassword/rest/views.py +++ b/nopassword/rest/views.py @@ -49,11 +49,13 @@ def login(self): self.process_login() def get_response(self): - serializer = self.token_serializer_class( + token_serializer = self.token_serializer_class( instance=self.token, context=self.get_serializer_context(), ) - return Response(serializer.data, status=status.HTTP_200_OK) + data = token_serializer.data + data['next'] = self.serializer.validated_data['code'].next + return Response(data, status=status.HTTP_200_OK) def post(self, request, *args, **kwargs): self.serializer = self.get_serializer(data=request.data) diff --git a/nopassword/views.py b/nopassword/views.py index 52a3eb6..43b46fe 100644 --- a/nopassword/views.py +++ b/nopassword/views.py @@ -24,6 +24,11 @@ class LoginView(FormView): def dispatch(self, request, *args, **kwargs): return super(LoginView, self).dispatch(request, *args, **kwargs) + def get_form_kwargs(self): + kwargs = super(LoginView, self).get_form_kwargs() + kwargs['initial'] = {'next': self.request.GET.get('next')} + return kwargs + def form_valid(self, form): form.save(request=self.request) return super(LoginView, self).form_valid(form) diff --git a/tests/test_rest_views.py b/tests/test_rest_views.py index 97bcb5f..d67b37f 100644 --- a/tests/test_rest_views.py +++ b/tests/test_rest_views.py @@ -1,5 +1,6 @@ # -*- coding: utf8 -*- from django.contrib.auth import get_user_model +from django.core import mail from django.test import TestCase from rest_framework.authtoken.models import Token @@ -9,11 +10,12 @@ class TestRestViews(TestCase): def setUp(self): - self.user = get_user_model().objects.create(username='user') + self.user = get_user_model().objects.create(username='user', email='foo@bar.com') def test_request_login_code(self): response = self.client.post('/accounts-rest/login/', { 'username': self.user.username, + 'next': '/private/', }) self.assertEqual(response.status_code, 200) @@ -21,6 +23,12 @@ def test_request_login_code(self): login_code = LoginCode.objects.filter(user=self.user).first() self.assertIsNotNone(login_code) + self.assertEqual(login_code.next, '/private/') + self.assertEqual(len(mail.outbox), 1) + self.assertIn( + 'http://testserver/accounts/login/code/?code={}'.format(login_code.code), + mail.outbox[0].body, + ) def test_request_login_code_missing_username(self): response = self.client.post('/accounts-rest/login/') @@ -54,7 +62,7 @@ def test_request_login_code_inactive_user(self): }) def test_login(self): - login_code = LoginCode.objects.create(user=self.user, code='foobar') + login_code = LoginCode.objects.create(user=self.user, code='foobar', next='/private/') response = self.client.post('/accounts-rest/login/code/', { 'code': login_code.code, @@ -66,7 +74,10 @@ def test_login(self): token = Token.objects.filter(user=self.user).first() self.assertIsNotNone(token) - self.assertEqual(response.data['key'], token.key) + self.assertEqual(response.data, { + 'key': token.key, + 'next': '/private/', + }) def test_login_missing_code(self): response = self.client.post('/accounts-rest/login/code/') diff --git a/tests/test_views.py b/tests/test_views.py index 8955b29..a001fe6 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -1,5 +1,6 @@ # -*- coding: utf8 -*- from django.contrib.auth import get_user_model +from django.core import mail from django.test import TestCase, override_settings from nopassword.models import LoginCode @@ -8,11 +9,12 @@ class TestViews(TestCase): def setUp(self): - self.user = get_user_model().objects.create(username='user') + self.user = get_user_model().objects.create(username='user', email='foo@bar.com') def test_request_login_code(self): response = self.client.post('/accounts/login/', { 'username': self.user.username, + 'next': '/private/', }) self.assertEqual(response.status_code, 302) @@ -21,6 +23,12 @@ def test_request_login_code(self): login_code = LoginCode.objects.filter(user=self.user).first() self.assertIsNotNone(login_code) + self.assertEqual(login_code.next, '/private/') + self.assertEqual(len(mail.outbox), 1) + self.assertIn( + 'http://testserver/accounts/login/code/?code={}'.format(login_code.code), + mail.outbox[0].body, + ) def test_request_login_code_missing_username(self): response = self.client.post('/accounts/login/') @@ -54,14 +62,14 @@ def test_request_login_code_inactive_user(self): }) def test_login_post(self): - login_code = LoginCode.objects.create(user=self.user, code='foobar') + login_code = LoginCode.objects.create(user=self.user, code='foobar', next='/private/') response = self.client.post('/accounts/login/code/', { 'code': login_code.code, }) self.assertEqual(response.status_code, 302) - self.assertEqual(response['Location'], '/accounts/profile/') + self.assertEqual(response['Location'], '/private/') self.assertEqual(response.wsgi_request.user, self.user) self.assertFalse(LoginCode.objects.filter(pk=login_code.pk).exists()) @@ -79,14 +87,14 @@ def test_login_get(self): @override_settings(NOPASSWORD_LOGIN_ON_GET=True) def test_login_get_non_idempotent(self): - login_code = LoginCode.objects.create(user=self.user, code='foobar') + login_code = LoginCode.objects.create(user=self.user, code='foobar', next='/private/') response = self.client.get('/accounts/login/code/', { 'code': login_code.code, }) self.assertEqual(response.status_code, 302) - self.assertEqual(response['Location'], '/accounts/profile/') + self.assertEqual(response['Location'], '/private/') self.assertEqual(response.wsgi_request.user, self.user) self.assertFalse(LoginCode.objects.filter(pk=login_code.pk).exists())