Skip to content

Commit

Permalink
Merge pull request #1828 from privacyidea/1826/checkserial
Browse files Browse the repository at this point in the history
Add serial number check
  • Loading branch information
Friedrich Weber committed Aug 28, 2019
2 parents 4599835 + 76b6ade commit 2c90807
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 3 deletions.
4 changes: 4 additions & 0 deletions privacyidea/lib/decorators.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from privacyidea.lib.error import TokenAdminError from privacyidea.lib.error import TokenAdminError
from privacyidea.lib.error import ParameterError from privacyidea.lib.error import ParameterError
from privacyidea.lib import _ from privacyidea.lib import _
from privacyidea.lib.utils import check_serial_valid
log = logging.getLogger(__name__) log = logging.getLogger(__name__)




Expand Down Expand Up @@ -66,6 +67,9 @@ def user_or_serial_wrapper(*args, **kwds):
# We either have an empty User object or None # We either have an empty User object or None
raise ParameterError(ParameterError.USER_OR_SERIAL) raise ParameterError(ParameterError.USER_OR_SERIAL)


if serial:
check_serial_valid(serial)

f_result = func(*args, **kwds) f_result = func(*args, **kwds)
return f_result return f_result


Expand Down
3 changes: 2 additions & 1 deletion privacyidea/lib/token.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
from privacyidea.lib.decorators import (check_user_or_serial, from privacyidea.lib.decorators import (check_user_or_serial,
check_copy_serials) check_copy_serials)
from privacyidea.lib.tokenclass import TokenClass from privacyidea.lib.tokenclass import TokenClass
from privacyidea.lib.utils import is_true, BASE58, hexlify_and_unicode from privacyidea.lib.utils import is_true, BASE58, hexlify_and_unicode, check_serial_valid
from privacyidea.lib.crypto import generate_password from privacyidea.lib.crypto import generate_password
from privacyidea.lib.log import log_with from privacyidea.lib.log import log_with
from privacyidea.models import (Token, Realm, TokenRealm, Challenge, from privacyidea.models import (Token, Realm, TokenRealm, Challenge,
Expand Down Expand Up @@ -974,6 +974,7 @@ def init_token(param, user=None, tokenrealms=None,


tokentype = param.get("type") or "hotp" tokentype = param.get("type") or "hotp"
serial = param.get("serial") or gen_serial(tokentype, param.get("prefix")) serial = param.get("serial") or gen_serial(tokentype, param.get("prefix"))
check_serial_valid(serial)
realms = [] realms = []


# unsupported tokentype # unsupported tokentype
Expand Down
17 changes: 16 additions & 1 deletion privacyidea/lib/utils/__init__.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@


BASE58 = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz" BASE58 = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"


ALLOWED_SERIAL = "^[0-9a-zA-Z\-_]+$"



def check_time_in_range(time_range, check_time=None): def check_time_in_range(time_range, check_time=None):
""" """
Expand Down Expand Up @@ -1253,4 +1255,17 @@ def create_tag_dict(logged_in_user=None,
escaped_tags[key] = cgi.escape(value) if value is not None else None escaped_tags[key] = cgi.escape(value) if value is not None else None
tags = escaped_tags tags = escaped_tags


return tags return tags


def check_serial_valid(serial):
"""
This function checks the given serial number for allowed values.
Raises an exception if the format of the serial number is not allowed
:param serial:
:return: True or Exception
"""
if not re.match(ALLOWED_SERIAL, serial):
raise ParameterError("Invalid serial number. Must comply to {0!s}.".format(ALLOWED_SERIAL))
return True
12 changes: 12 additions & 0 deletions tests/test_api_token.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -1782,6 +1782,18 @@ def test_30_force_app_pin(self):
remove_token("goog2") remove_token("goog2")
delete_policy('app_pin') delete_policy('app_pin')


def test_31_invalid_serial(self):
# Run a test with an invalid serial
with self.app.test_request_context('/token/init',
method='POST',
data={"serial": "invalid/character",
"genkey": 1},
headers={'Authorization': self.at}):
res = self.app.full_dispatch_request()
self.assertTrue(res.status_code == 400, res)
result = res.json.get("result")
self.assertTrue("Invalid serial number" in result.get("error").get("message"))



class API00TokenPerformance(MyApiTestCase): class API00TokenPerformance(MyApiTestCase):


Expand Down
8 changes: 8 additions & 0 deletions tests/test_lib_token.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -1481,6 +1481,14 @@ def test_56_get_tokens_paginated_generator_removal(self):
# Check that we did not miss any tokens # Check that we did not miss any tokens
self.assertEquals(set(t.token.serial for t in list1 + list2), all_serials) self.assertEquals(set(t.token.serial for t in list1 + list2), all_serials)


def test_0057_check_invalid_serial(self):
# This is an invalid serial, which will trigger an exception
self.assertRaises(Exception, reset_token, "hans wurst")

self.assertRaises(Exception, init_token,
{"serial": "invalid/chars",
"genkey": 1})



class TokenOutOfBandTestCase(MyTestCase): class TokenOutOfBandTestCase(MyTestCase):


Expand Down
18 changes: 17 additions & 1 deletion tests/test_lib_utils.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
b64encode_and_unicode, create_png, create_img, b64encode_and_unicode, create_png, create_img,
convert_timestamp_to_utc, modhex_encode, convert_timestamp_to_utc, modhex_encode,
modhex_decode, checksum, urlsafe_b64encode_and_unicode, modhex_decode, checksum, urlsafe_b64encode_and_unicode,
check_ip_in_policy, split_pin_pass, create_tag_dict) check_ip_in_policy, split_pin_pass, create_tag_dict,
check_serial_valid)
from datetime import timedelta, datetime from datetime import timedelta, datetime
from netaddr import IPAddress, IPNetwork, AddrFormatError from netaddr import IPAddress, IPNetwork, AddrFormatError
from dateutil.tz import tzlocal, tzoffset, gettz from dateutil.tz import tzlocal, tzoffset, gettz
Expand Down Expand Up @@ -681,3 +682,18 @@ class RequestMock():
self.assertEqual(dict2["ua_string"], "<b>hello world</b>") self.assertEqual(dict2["ua_string"], "<b>hello world</b>")
self.assertEqual(dict2["action"], "/validate/check") self.assertEqual(dict2["action"], "/validate/check")
self.assertEqual(dict2["recipient_givenname"], u"<b>Sömeone</b>") self.assertEqual(dict2["recipient_givenname"], u"<b>Sömeone</b>")

def test_32_allowed_serial_numbers(self):
self.assertTrue(check_serial_valid("TOTP12345"))
# Blank is not allowed
self.assertRaises(Exception, check_serial_valid, "TOTP 12345")

# Minus and underscore is allowed
self.assertTrue(check_serial_valid("spass-123"))
self.assertTrue(check_serial_valid("spass_123"))
# Slash and backslash is not allowed
self.assertRaises(Exception, check_serial_valid, "spass/123")
self.assertRaises(Exception, check_serial_valid, "spass\\123")

# an empty serial is not allowed
self.assertRaises(Exception, check_serial_valid, "")

0 comments on commit 2c90807

Please sign in to comment.