-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
169 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import requests | ||
from django.contrib.auth import get_user_model | ||
from django.http import HttpResponseForbidden | ||
|
||
from accounts.settings import accounts_settings | ||
|
||
|
||
class OAuth2TokenMiddleware: | ||
""" | ||
When a view is requested using a Bearer Authorization header, | ||
check and set request.user to the owner of said token | ||
""" | ||
|
||
def __init__(self, get_response): | ||
self.get_response = get_response | ||
|
||
def __call__(self, request): | ||
authorization = request.META.get('HTTP_AUTHORIZATION') | ||
if authorization and ' ' in authorization: | ||
auth_type, token = authorization.split() | ||
if auth_type == 'Bearer': # Only validate if Authorization header type is Bearer | ||
body = {'token': token} | ||
headers = {'Authorization': 'Bearer {}'.format(token)} | ||
try: | ||
data = requests.post( | ||
url=accounts_settings.PLATFORM_URL + '/accounts/introspect/', | ||
headers=headers, | ||
data=body | ||
) | ||
if data.status_code == 200: # Access token is valid | ||
data = data.json() | ||
User = get_user_model() | ||
user = User.objects.filter(id=int(data['user']['pennid'])) | ||
if len(user) == 1: | ||
request.user = user.first() | ||
else: # User doesn't have an account on this product | ||
pass | ||
else: # Access token is invalid | ||
return HttpResponseForbidden() | ||
except requests.exceptions.RequestException: # Can't connect to platform | ||
return HttpResponseForbidden() | ||
|
||
response = self.get_response(request) | ||
return response |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
from unittest.mock import Mock, patch | ||
|
||
from django.contrib.auth import get_user_model | ||
from django.contrib.auth.models import AnonymousUser | ||
from django.test import TestCase | ||
from django.urls import reverse | ||
|
||
from accounts.middleware import OAuth2TokenMiddleware | ||
|
||
|
||
@patch('accounts.middleware.requests.post') | ||
class OAuth2TokenMiddlewareTestCase(TestCase): | ||
def setUp(self): | ||
self.request = Mock() | ||
self.request.META = {} | ||
self.request.user = AnonymousUser | ||
self.middleware = OAuth2TokenMiddleware(Mock()) | ||
self.user = get_user_model().objects.create(id=123, username='username') | ||
self.valid_response = { | ||
'user': { | ||
'pennid': '123' | ||
} | ||
} | ||
|
||
def test_no_authorization_header(self, mock_request): | ||
self.middleware(self.request) | ||
self.assertEqual(AnonymousUser, self.request.user) | ||
|
||
def test_authorization_header_wrong_type(self, mock_request): | ||
self.request.META['HTTP_AUTHORIZATION'] = 'Basic abc' | ||
self.middleware(self.request) | ||
self.assertEqual(AnonymousUser, self.request.user) | ||
|
||
def test_authorization_header_malformed(self, mock_request): | ||
self.request.META['HTTP_AUTHORIZATION'] = 'Basic' | ||
self.middleware(self.request) | ||
self.assertEqual(AnonymousUser, self.request.user) | ||
|
||
def test_authorization_header_invalid(self, mock_request): | ||
mock_request.return_value.status_code = 403 | ||
self.request.META['HTTP_AUTHORIZATION'] = 'Bearer abc' | ||
self.middleware(self.request) | ||
self.assertEqual(AnonymousUser, self.request.user) | ||
|
||
def test_authorization_header_valid_user_exists(self, mock_request): | ||
mock_request.return_value.status_code = 200 | ||
mock_request.return_value.json.return_value = self.valid_response | ||
self.request.META['HTTP_AUTHORIZATION'] = 'Bearer abc123' | ||
self.middleware(self.request) | ||
self.assertEqual(self.user, self.request.user) | ||
|
||
def test_authorization_header_valid_user_no_exists(self, mock_request): | ||
mock_request.return_value.status_code = 200 | ||
self.valid_response['user']['pennid'] = '456' | ||
mock_request.return_value.json.return_value = self.valid_response | ||
self.request.META['HTTP_AUTHORIZATION'] = 'Bearer abc123' | ||
self.middleware(self.request) | ||
self.assertEqual(AnonymousUser, self.request.user) | ||
|
||
|
||
@patch('accounts.middleware.requests.post') | ||
class TestViewTestCase(TestCase): | ||
def setUp(self): | ||
self.user = get_user_model().objects.create(id=123, username='username') | ||
self.headers = {} | ||
self.valid_response = { | ||
'user': { | ||
'pennid': '123' | ||
} | ||
} | ||
|
||
def test_no_authorization_header(self, mock_request): | ||
response = self.client.get(reverse('test'), **self.headers) | ||
self.assertEqual(200, response.status_code) | ||
|
||
def test_authorization_header_wrong_type(self, mock_request): | ||
self.headers['HTTP_AUTHORIZATION'] = 'Basic abc' | ||
response = self.client.get(reverse('test'), **self.headers) | ||
self.assertEqual(200, response.status_code) | ||
|
||
def test_authorization_header_malformed(self, mock_request): | ||
self.headers['HTTP_AUTHORIZATION'] = 'Basic' | ||
response = self.client.get(reverse('test'), **self.headers) | ||
self.assertEqual(200, response.status_code) | ||
|
||
def test_authorization_header_invalid(self, mock_request): | ||
mock_request.return_value.status_code = 403 | ||
self.headers['HTTP_AUTHORIZATION'] = 'Bearer abc' | ||
response = self.client.get(reverse('test'), **self.headers) | ||
self.assertEqual(403, response.status_code) | ||
|
||
def test_authorization_header_valid_user_exists(self, mock_request): | ||
mock_request.return_value.status_code = 200 | ||
mock_request.return_value.json.return_value = self.valid_response | ||
self.headers['HTTP_AUTHORIZATION'] = 'Bearer abc123' | ||
response = self.client.get(reverse('test'), **self.headers) | ||
self.assertEqual(200, response.status_code) | ||
|
||
def test_authorization_header_valid_user_no_exists(self, mock_request): | ||
mock_request.return_value.status_code = 200 | ||
self.valid_response['user']['pennid'] = '456' | ||
mock_request.return_value.json.return_value = self.valid_response | ||
self.headers['HTTP_AUTHORIZATION'] = 'Bearer abc123' | ||
response = self.client.get(reverse('test'), **self.headers) | ||
self.assertEqual(200, response.status_code) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,11 @@ | ||
from django.contrib import admin | ||
from django.urls import include, path | ||
|
||
from tests.views import TestView | ||
|
||
|
||
urlpatterns = [ | ||
path('accounts/', include('accounts.urls', namespace='accounts')), | ||
path('admin/', admin.site.urls), | ||
path('test/', TestView.as_view(), name='test'), | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from django.http import HttpResponse | ||
from django.views import View | ||
|
||
|
||
class TestView(View): | ||
def get(self, request): | ||
return HttpResponse('Success') |