Skip to content
Permalink
Browse files

Merge pull request #1753 from privacyidea/1436/user-info-finetuning

Finetuning policy matching against arbitrary userinfo attributes
  • Loading branch information...
plettich committed Jul 25, 2019
2 parents 96b3e18 + 9c93c15 commit 680771497a1aceb8b56956d3bb808e661e972d52
@@ -26,21 +26,46 @@
2) implement a comparison function and add it to COMPARATOR_FUNCTIONS
3) add a description of the comparator to COMPARATOR_DESCRIPTIONS
"""
import csv
import logging
import re
from six import wraps

from privacyidea.lib.framework import _

log = logging.getLogger(__name__)


class CompareError(Exception):
"""
Signals that an error occurred when carrying out a comparison.
The error message is not presented to the user, but written to the logfile.
"""
def __init__(self, message):
self.message = message

def __repr__(self):
return u"CompareError({!r})".format(self.message)


def parse_comma_separated_string(input_string):
"""
Parse a string that contains a list of comma-separated values and return the list of values.
Each value may be quoted with a doublequote, and doublequotes may be escaped with a backslash.
Whitespace immediately following a delimiter is skipped.
Raise a CompareError if the input is malformed.
:param input_string: an input string
:return: a list of strings
"""
# We use Python's csv module because it supports quoted values
try:
reader = csv.reader([input_string], strict=True, skipinitialspace=True, doublequote=False, escapechar="\\")
rows = list(reader)
except csv.Error as exx:
raise CompareError(u"Malformed comma-separated value: {!r}".format(input_string, exx))
return rows[0]


def _compare_equality(left, comparator, right):
"""
Return True if two values are exactly equal, according to Python semantics.
@@ -62,25 +87,92 @@ def _compare_contains(left, comparator, right):
raise CompareError(u"Left value must be a list, not {!r}".format(type(left)))


def _compare_matches(left, comparator, right):
"""
Return True if the string in ``left`` completely matches the regular expression given in ``right``.
Raise a CompareError if ``right`` is not a valid regular expression, or
if any other matching error occurs.
:param left: a string
:param right: a regular expression
:return: True or False
"""
try:
return re.match("^" + right + "$", left) is not None
except re.error as e:
raise CompareError(u"Error during matching: {!r}".format(e))


def _compare_in(left, comparator, right):
"""
Return True if ``left`` is a member of ``right``, which is a string containing a
list of values, separated by commas (see ``parse_comma_separated_string``).
:param left: a string
:param right: a string of comma-separated values
:return: True or False
"""
return left in parse_comma_separated_string(right)


def negate(func):
"""
Given a comparison function ``func``, build and return a comparison function that negates
the result of ``func``.
:param func: a comparison function taking three arguments
:return: a comparison function taking three arguments
"""
@wraps(func)
def negated(left, comparator, right):
return not func(left, comparator, right)
return negated


#: This class enumerates all available comparators.
#: In order to add a comparator to this module, add a suitable member to COMPARATORS
#: and suitable entries to COMPARATOR_FUNCTIONS and COMPARATOR_DESCRIPTIONS.
class COMPARATORS(object):
EQUALS = "=="
EQUALS = "equals"
NOT_EQUALS = "!equals"

CONTAINS = "contains"
NOT_CONTAINS = "!contains"

MATCHES = "matches"
NOT_MATCHES = "!matches"

IN = "in"
NOT_IN = "!in"


#: This dictionary connects comparators to comparator functions.
#: A comparison function takes three parameters ``left``, ``comparator``, ``right``.
COMPARATOR_FUNCTIONS = {
COMPARATORS.EQUALS: _compare_equality,
COMPARATORS.NOT_EQUALS: negate(_compare_equality),

COMPARATORS.CONTAINS: _compare_contains,
COMPARATORS.NOT_CONTAINS: negate(_compare_contains),

COMPARATORS.MATCHES: _compare_matches,
COMPARATORS.NOT_MATCHES: negate(_compare_matches),

COMPARATORS.IN: _compare_in,
COMPARATORS.NOT_IN: negate(_compare_in),
}


#: This dictionary connects comparators to their human-readable (and translated) descriptions.
COMPARATOR_DESCRIPTIONS = {
COMPARATORS.CONTAINS: _("true if the left value contains the right value"),
COMPARATORS.EQUALS: _("true if the two values are equal")
COMPARATORS.NOT_CONTAINS: _("false if the left value contains the right value"),

COMPARATORS.EQUALS: _("true if the two values are equal"),
COMPARATORS.NOT_EQUALS: _("false if the two values are equal"),

COMPARATORS.MATCHES: _("true if the left value completely matches the given regular expression pattern"),
COMPARATORS.NOT_MATCHES: _("false if the left value completely matches the given regular expression pattern"),

COMPARATORS.IN: _("true if the left value is contained in the comma-separated values on the right"),
COMPARATORS.NOT_IN: _("false if the left value is contained in the comma-separated values on the right")
}


@@ -96,5 +188,4 @@ def compare_values(left, comparator, right):
if comparator in COMPARATOR_FUNCTIONS:
return COMPARATOR_FUNCTIONS[comparator](left, comparator, right)
else:
# We intentionally leave out the values, in case sensitive values are compared
raise CompareError(u"Invalid comparator: {!r}".format(comparator))
@@ -1500,7 +1500,7 @@ class PolicyCondition(MethodsMixin, db.Model):
# We use upper-case "Key" and "Value" to prevent conflicts with databases
# that do not support "key" or "value" as column names
Key = db.Column(db.Unicode(255), nullable=False)
comparator = db.Column(db.Unicode(255), nullable=False, default=u'==')
comparator = db.Column(db.Unicode(255), nullable=False, default=u'equals')
Value = db.Column(db.Unicode(2000), nullable=False, default=u'')
active = db.Column(db.Boolean, nullable=False, default=True)

@@ -455,7 +455,7 @@ myApp.directive("piPolicyConditions", function (instanceUrl) {
// Called when the user clicks on the "add condition" button.
// Adds a condition with default values
scope.addCondition = function () {
scope.policyConditions.push(["userinfo", "", "==", "", false]);
scope.policyConditions.push(["userinfo", "", "equals", "", false]);
scope.editIndex = scope.policyConditions.length - 1;
};
},
@@ -1106,22 +1106,22 @@ def test_28_conditions(self):

# Set policy with conditions
set_policy("act1", scope=SCOPE.AUTH, action="{0!s}=userstore".format(ACTION.OTPPIN),
conditions=[("userinfo", "type", "==", "verysecure", True)])
conditions=[("userinfo", "type", "equals", "verysecure", True)])

P = PolicyClass()
self.assertEqual(P.list_policies()[0]["conditions"],
[("userinfo", "type", "==", "verysecure", True)])
[("userinfo", "type", "equals", "verysecure", True)])

# Update existing policy with conditions
set_policy("act1", conditions=[
("userinfo", "type", "==", "notverysecure", True),
("request", "user_agent", "==", "vpn", True)
("userinfo", "type", "equals", "notverysecure", True),
("request", "user_agent", "equals", "vpn", True)
])
P = PolicyClass()

self.assertEqual(P.list_policies()[0]["conditions"],
[("userinfo", "type", "==", "notverysecure", True),
("request", "user_agent", "==", "vpn", True)])
[("userinfo", "type", "equals", "notverysecure", True),
("request", "user_agent", "equals", "vpn", True)])

delete_policy("act1")
delete_realm("realm1")
@@ -1132,9 +1132,9 @@ def _names(policies):
return set(p['name'] for p in policies)

set_policy("verysecure", scope=SCOPE.AUTH, action="{0!s}=userstore".format(ACTION.OTPPIN),
conditions=[("userinfo", "type", "==", "verysecure", True)])
conditions=[("userinfo", "type", "equals", "verysecure", True)])
set_policy("notverysecure", scope=SCOPE.AUTH, action="{0!s}=userstore".format(ACTION.OTPPIN),
conditions=[("userinfo", "type", "==", "notverysecure", True),
conditions=[("userinfo", "type", "equals", "notverysecure", True),
("userinfo", "groups", "contains", "b", True)])
P = PolicyClass()

@@ -1183,14 +1183,14 @@ class MockUser(object):

# Policy with initially inactive condition
set_policy("extremelysecure", scope=SCOPE.AUTH, action="{0!s}=userstore".format(ACTION.OTPPIN),
conditions=[("userinfo", "type", "==", "notverysecure", False)])
conditions=[("userinfo", "type", "equals", "notverysecure", False)])

# user1 matches, because the condition on type is inactive
self.assertEqual(_names(P.match_policies(user_object=user1)),
{"extremelysecure"})

# activate the condition
set_policy("extremelysecure", conditions=[("userinfo", "type", "==", "notverysecure", True)])
set_policy("extremelysecure", conditions=[("userinfo", "type", "equals", "notverysecure", True)])

# user1 does not match anymore, because the condition on type is active
self.assertEqual(_names(P.match_policies(user_object=user1)),
@@ -1213,21 +1213,21 @@ class MockUser(object):

# an unknown section in the condition
set_policy("unknownsection", scope=SCOPE.AUTH, action="{0!s}=userstore".format(ACTION.OTPPIN),
conditions=[("somesection", "bla", "==", "verysecure", True)])
conditions=[("somesection", "bla", "equals", "verysecure", True)])
with self.assertRaisesRegexp(PolicyError, r".*unknown section.*"):
P.match_policies(user_object=user1)
delete_policy("unknownsection")

# ... but the error does not occur if the condition is inactive
set_policy("unknownsection", scope=SCOPE.AUTH, action="{0!s}=userstore".format(ACTION.OTPPIN),
conditions=[("somesection", "bla", "==", "verysecure", False)])
conditions=[("somesection", "bla", "equals", "verysecure", False)])
all_policies = P.list_policies()
self.assertEqual(P.match_policies(user_object=user1), all_policies)
delete_policy("unknownsection")

# an unknown key in the condition
set_policy("unknownkey", scope=SCOPE.AUTH, action="{0!s}=userstore".format(ACTION.OTPPIN),
conditions=[("userinfo", "bla", "==", "verysecure", True)])
conditions=[("userinfo", "bla", "equals", "verysecure", True)])
with self.assertRaisesRegexp(PolicyError, r".*Unknown key.*"):
P.match_policies(user_object=user1)
delete_policy("unknownkey")
@@ -2,17 +2,20 @@
"""
This tests the module lib.utils.compare
"""
from privacyidea.lib.utils.compare import compare_values, CompareError
from privacyidea.lib.utils.compare import compare_values, CompareError, parse_comma_separated_string
from .base import MyTestCase


class UtilsCompareTestCase(MyTestCase):
def test_01_compare_equal(self):
self.assertTrue(compare_values("hello", "==", "hello"))
self.assertTrue(compare_values(1, "==", 1))
self.assertFalse(compare_values("hello", "==", " hello"))
self.assertFalse(compare_values(1, "==", 2))
self.assertFalse(compare_values(1, "==", "1"))
self.assertTrue(compare_values("hello", "equals", "hello"))
self.assertTrue(compare_values(1, "equals", 1))
self.assertFalse(compare_values("hello", "equals", " hello"))
self.assertFalse(compare_values(1, "equals", 2))
self.assertFalse(compare_values(1, "equals", "1"))
# negation
self.assertFalse(compare_values("hello", "!equals", "hello"))
self.assertTrue(compare_values(1, "!equals", "1"))

def test_02_compare_contains(self):
self.assertTrue(compare_values(["hello", "world"], "contains", "hello"))
@@ -24,7 +27,73 @@ def test_02_compare_contains(self):
with self.assertRaises(CompareError):
compare_values("hello world", "contains", "hello")

# negation
self.assertTrue(compare_values([1, "world"], "!contains", "hello"))
self.assertFalse(compare_values([1, "world"], "!contains", "world"))
with self.assertRaises(CompareError):
compare_values("hello world", "!contains", "hello")

def test_03_compare_errors(self):
with self.assertRaises(CompareError) as cm:
compare_values("hello world", "something", "hello")
self.assertIn("Invalid comparator", repr(cm.exception))

def test_04_compare_matches(self):
self.assertTrue(compare_values("hello world", "matches", "hello world"))
self.assertTrue(compare_values("hello world", "matches", ".*world"))
self.assertTrue(compare_values("uid=hello,cn=users,dc=test,dc=intranet", "matches",
"uid=[^,]+,cn=users,dc=test,dc=intranet"))
# only complete matches
self.assertFalse(compare_values("hello world", "matches", "world"))
self.assertFalse(compare_values("uid=hello,cn=users,dc=test,dc=intranet,dc=world", "matches",
"uid=[^,]+,cn=users,dc=test,dc=intranet"))
# supports more advanced regex features
self.assertTrue(compare_values("hElLo WoRLd", "matches", "(?i)hello world( and stuff)?"))
# raises errors on invalid patterns
with self.assertRaises(CompareError):
compare_values("hello world", "matches", "this is (invalid")

# negation
self.assertTrue(compare_values("uid=hello,cn=users,dc=test,dc=intranet", "!matches",
"uid=[^,]+,cn=admins,dc=test,dc=intranet"))
self.assertFalse(compare_values("uid=hello,cn=admins,dc=test,dc=intranet", "!matches",
"uid=[^,]+,cn=admins,dc=test,dc=intranet"))

def test_05_parse_comma_separated_string(self):
self.assertEquals(parse_comma_separated_string("hello world"), ["hello world"])
# whitespace immediately following a delimiter is skipped
self.assertEquals(parse_comma_separated_string("realm1, realm2,realm3"),
["realm1", "realm2", "realm3"])
# whitespace before delimiters is not skipped
self.assertEquals(parse_comma_separated_string(" realm1 ,realm2"),
["realm1 ", "realm2"])
# strings can be quoted
self.assertEquals(parse_comma_separated_string('realm1, "realm2", " realm3"'),
["realm1", "realm2", " realm3"])
# even with commas
self.assertEquals(parse_comma_separated_string('realm1, "realm2, with a, strange, name", other stuff'),
["realm1", "realm2, with a, strange, name", "other stuff"])
# double quotes can be escaped
self.assertEquals(parse_comma_separated_string(r'realm\", realm2'),
['realm"', 'realm2'])
# error if a string is not properly quoted
with self.assertRaises(CompareError):
parse_comma_separated_string('"no')
# error if we pass multiple lines
with self.assertRaises(CompareError):
parse_comma_separated_string('realm1\nrealm2')
# but we can quote newlines
self.assertEquals(parse_comma_separated_string('"realm1\nrealm2"'),
["realm1\nrealm2"])

def test_06_compare_in(self):
self.assertTrue(compare_values("hello", "in", "hello"))
self.assertTrue(compare_values("world", "in", "hello, world, this is a list"))
self.assertFalse(compare_values("hello", "in", "hello world"))
self.assertFalse(compare_values("hello,world", "in", 'hello,world'))
self.assertTrue(compare_values("hello,world", "in", '"hello,world"'))

# negation
self.assertTrue(compare_values("hello", "!in", "world"))
self.assertFalse(compare_values("hello", "!in", " hello, world"))
self.assertTrue(compare_values("hello", "!in", "hello world"))

0 comments on commit 6807714

Please sign in to comment.
You can’t perform that action at this time.