Skip to content
This repository has been archived by the owner on May 5, 2020. It is now read-only.

Commit

Permalink
feat: Support next in rest
Browse files Browse the repository at this point in the history
  • Loading branch information
rubengrill authored and relekang committed Sep 14, 2018
1 parent 44363c5 commit deba82e
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 17 deletions.
16 changes: 10 additions & 6 deletions nopassword/forms.py
Expand Up @@ -18,6 +18,8 @@ class LoginForm(forms.Form):
'inactive': _("This account is inactive."),
}

next = forms.CharField(max_length=200, required=False, widget=forms.HiddenInput)

def __init__(self, *args, **kwargs):
super(LoginForm, self).__init__(*args, **kwargs)

Expand All @@ -42,14 +44,15 @@ def clean_username(self):
code='inactive',
)

self.cleaned_data['login_code'] = models.LoginCode.create_code_for_user(user)
self.cleaned_data['user'] = user

return username

def save(self, request, login_code_url='login_code', domain_override=None, extra_context=None):
login_code = self.cleaned_data['login_code']
login_code.next = request.GET.get('next')
login_code.save()
login_code = models.LoginCode.create_code_for_user(
user=self.cleaned_data['user'],
next=self.cleaned_data['next'],
)

if not domain_override:
current_site = get_current_site(request)
Expand All @@ -58,12 +61,11 @@ def save(self, request, login_code_url='login_code', domain_override=None, extra
else:
site_name = domain = domain_override

url = '{}://{}{}?code={}&next={}'.format(
url = '{}://{}{}?code={}'.format(
'https' if request.is_secure() else 'http',
domain,
resolve_url(login_code_url),
login_code.code,
login_code.next,
)

context = {
Expand All @@ -78,6 +80,8 @@ def save(self, request, login_code_url='login_code', domain_override=None, extra

self.send_login_code(login_code, context)

return login_code

def send_login_code(self, login_code, context, **kwargs):
for backend in get_backends():
if hasattr(backend, 'send_login_code'):
Expand Down
3 changes: 2 additions & 1 deletion nopassword/rest/serializers.py
Expand Up @@ -7,6 +7,7 @@

class LoginSerializer(serializers.Serializer):
username = serializers.CharField()
next = serializers.CharField(required=False, allow_null=True)

form_class = forms.LoginForm

Expand All @@ -20,7 +21,7 @@ def validate(self, data):

def save(self):
request = self.context.get('request')
self.form.save(request=request)
return self.form.save(request=request)


class LoginCodeSerializer(serializers.Serializer):
Expand Down
6 changes: 4 additions & 2 deletions nopassword/rest/views.py
Expand Up @@ -49,11 +49,13 @@ def login(self):
self.process_login()

def get_response(self):
serializer = self.token_serializer_class(
token_serializer = self.token_serializer_class(
instance=self.token,
context=self.get_serializer_context(),
)
return Response(serializer.data, status=status.HTTP_200_OK)
data = token_serializer.data
data['next'] = self.serializer.validated_data['code'].next
return Response(data, status=status.HTTP_200_OK)

def post(self, request, *args, **kwargs):
self.serializer = self.get_serializer(data=request.data)
Expand Down
5 changes: 5 additions & 0 deletions nopassword/views.py
Expand Up @@ -24,6 +24,11 @@ class LoginView(FormView):
def dispatch(self, request, *args, **kwargs):
return super(LoginView, self).dispatch(request, *args, **kwargs)

def get_form_kwargs(self):
kwargs = super(LoginView, self).get_form_kwargs()
kwargs['initial'] = {'next': self.request.GET.get('next')}
return kwargs

def form_valid(self, form):
form.save(request=self.request)
return super(LoginView, self).form_valid(form)
Expand Down
17 changes: 14 additions & 3 deletions tests/test_rest_views.py
@@ -1,5 +1,6 @@
# -*- coding: utf8 -*-
from django.contrib.auth import get_user_model
from django.core import mail
from django.test import TestCase
from rest_framework.authtoken.models import Token

Expand All @@ -9,18 +10,25 @@
class TestRestViews(TestCase):

def setUp(self):
self.user = get_user_model().objects.create(username='user')
self.user = get_user_model().objects.create(username='user', email='foo@bar.com')

def test_request_login_code(self):
response = self.client.post('/accounts-rest/login/', {
'username': self.user.username,
'next': '/private/',
})

self.assertEqual(response.status_code, 200)

login_code = LoginCode.objects.filter(user=self.user).first()

self.assertIsNotNone(login_code)
self.assertEqual(login_code.next, '/private/')
self.assertEqual(len(mail.outbox), 1)
self.assertIn(
'http://testserver/accounts/login/code/?code={}'.format(login_code.code),
mail.outbox[0].body,
)

def test_request_login_code_missing_username(self):
response = self.client.post('/accounts-rest/login/')
Expand Down Expand Up @@ -54,7 +62,7 @@ def test_request_login_code_inactive_user(self):
})

def test_login(self):
login_code = LoginCode.objects.create(user=self.user, code='foobar')
login_code = LoginCode.objects.create(user=self.user, code='foobar', next='/private/')

response = self.client.post('/accounts-rest/login/code/', {
'code': login_code.code,
Expand All @@ -66,7 +74,10 @@ def test_login(self):
token = Token.objects.filter(user=self.user).first()

self.assertIsNotNone(token)
self.assertEqual(response.data['key'], token.key)
self.assertEqual(response.data, {
'key': token.key,
'next': '/private/',
})

def test_login_missing_code(self):
response = self.client.post('/accounts-rest/login/code/')
Expand Down
18 changes: 13 additions & 5 deletions tests/test_views.py
@@ -1,5 +1,6 @@
# -*- coding: utf8 -*-
from django.contrib.auth import get_user_model
from django.core import mail
from django.test import TestCase, override_settings

from nopassword.models import LoginCode
Expand All @@ -8,11 +9,12 @@
class TestViews(TestCase):

def setUp(self):
self.user = get_user_model().objects.create(username='user')
self.user = get_user_model().objects.create(username='user', email='foo@bar.com')

def test_request_login_code(self):
response = self.client.post('/accounts/login/', {
'username': self.user.username,
'next': '/private/',
})

self.assertEqual(response.status_code, 302)
Expand All @@ -21,6 +23,12 @@ def test_request_login_code(self):
login_code = LoginCode.objects.filter(user=self.user).first()

self.assertIsNotNone(login_code)
self.assertEqual(login_code.next, '/private/')
self.assertEqual(len(mail.outbox), 1)
self.assertIn(
'http://testserver/accounts/login/code/?code={}'.format(login_code.code),
mail.outbox[0].body,
)

def test_request_login_code_missing_username(self):
response = self.client.post('/accounts/login/')
Expand Down Expand Up @@ -54,14 +62,14 @@ def test_request_login_code_inactive_user(self):
})

def test_login_post(self):
login_code = LoginCode.objects.create(user=self.user, code='foobar')
login_code = LoginCode.objects.create(user=self.user, code='foobar', next='/private/')

response = self.client.post('/accounts/login/code/', {
'code': login_code.code,
})

self.assertEqual(response.status_code, 302)
self.assertEqual(response['Location'], '/accounts/profile/')
self.assertEqual(response['Location'], '/private/')
self.assertEqual(response.wsgi_request.user, self.user)
self.assertFalse(LoginCode.objects.filter(pk=login_code.pk).exists())

Expand All @@ -79,14 +87,14 @@ def test_login_get(self):

@override_settings(NOPASSWORD_LOGIN_ON_GET=True)
def test_login_get_non_idempotent(self):
login_code = LoginCode.objects.create(user=self.user, code='foobar')
login_code = LoginCode.objects.create(user=self.user, code='foobar', next='/private/')

response = self.client.get('/accounts/login/code/', {
'code': login_code.code,
})

self.assertEqual(response.status_code, 302)
self.assertEqual(response['Location'], '/accounts/profile/')
self.assertEqual(response['Location'], '/private/')
self.assertEqual(response.wsgi_request.user, self.user)
self.assertFalse(LoginCode.objects.filter(pk=login_code.pk).exists())

Expand Down

0 comments on commit deba82e

Please sign in to comment.