Skip to content

Commit

Permalink
expand tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gotcha committed Feb 1, 2018
1 parent 01aaedf commit c071b8d
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 4 deletions.
20 changes: 17 additions & 3 deletions src/AccessControl/tainted.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,14 +193,28 @@ def createOneOptArgWrapper(func):
class TaintedBytes(TaintedString):

def __init__(self, value):
if isinstance(value, int):
assert six.PY3
if isinstance(value, bytes):
self._value = value
elif isinstance(value, int):
if six.PY2:
raise ValueError(
"Constructing from a single character as an int "
"is valid only with Python 3."
)
value = bytes([value])
self._value = value
self._value = value
else:
raise ValueError(
"Can be constructed only from bytes "
"(or a single int with Python3)."
)

def quoted(self):
result = escape(self._value.decode('utf8'), 1)
return result.encode('utf8')

def __str__(self):
return self._value.decode('utf8')

def decode(self, *args):
return TaintedString(self._value.decode(*args))
99 changes: 98 additions & 1 deletion src/AccessControl/tests/test_tainted.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def testRepr(self):
self.assertEqual(repr(self.tainted), repr(self.quoted))

def testEqual(self):
self.assertTrue(self.tainted, self.unquoted)
self.assertEqual(self.tainted, self.unquoted)

def testCmp(self):
self.assertTrue(self.tainted == self.unquoted)
Expand Down Expand Up @@ -195,6 +195,7 @@ def _getClass(self):

def testCmp(self):
self.assertTrue(self.tainted == self.unquoted)
self.assertEqual(self.tainted, self.unquoted)
self.assertTrue(self.tainted < b'a')
self.assertTrue(self.tainted > b'.')

Expand All @@ -214,3 +215,99 @@ def testGetSlice(self):
self.assertEqual(self.tainted[0:1], b'<')
self.assertFalse(isinstance(self.tainted[1:], self._getClass()))
self.assertEqual(self.tainted[1:], self.unquoted[1:])

def testInterpolate(self):
tainted = self._getClass()(b'<%s>')
self.assertTrue(isinstance(tainted % b'foo', self._getClass()))
self.assertEqual(tainted % b'foo', b'<foo>')
tainted = self._getClass()(b'<%s attr="%s">')
self.assertTrue(isinstance(tainted % (b'foo', b'bar'), self._getClass()))
self.assertEqual(tainted % (b'foo', b'bar'), b'<foo attr="bar">')

def testStringMethods(self):
simple = "capitalize isalpha isdigit islower isspace istitle isupper" \
" lower lstrip rstrip strip swapcase upper".split()
returnsTainted = "capitalize lower lstrip rstrip strip swapcase upper"
returnsTainted = returnsTainted.split()
unquoted = b'\tThis is a test '
tainted = self._getClass()(unquoted)
for f in simple:
v = getattr(tainted, f)()
self.assertEqual(v, getattr(unquoted, f)())
if f in returnsTainted:
self.assertTrue(isinstance(v, self._getClass()))
else:
self.assertFalse(isinstance(v, self._getClass()))

optArg = "lstrip rstrip strip".split()
for f in optArg:
v = getattr(tainted, f)(b" ")
self.assertEqual(v, getattr(unquoted, f)(b" "))
self.assertTrue(isinstance(v, self._getClass()))

justify = "center ljust rjust".split()
for f in justify:
v = getattr(tainted, f)(30)
self.assertEqual(v, getattr(unquoted, f)(30))
self.assertTrue(isinstance(v, self._getClass()))

searches = "find index rfind rindex endswith startswith".split()
searchraises = "index rindex".split()
for f in searches:
v = getattr(tainted, f)(b'test')
self.assertEqual(v, getattr(unquoted, f)(b'test'))
if f in searchraises:
self.assertRaises(ValueError, getattr(tainted, f), b'nada')

self.assertEqual(tainted.count(b'test', 1, -1),
unquoted.count(b'test', 1, -1))

self.assertEqual(tainted.decode(), unquoted.decode())
from AccessControl.tainted import TaintedString
self.assertTrue(isinstance(tainted.decode(), TaintedString))

self.assertEqual(tainted.expandtabs(10), unquoted.expandtabs(10))
self.assertTrue(isinstance(tainted.expandtabs(), self._getClass()))

self.assertEqual(tainted.replace(b'test', b'spam'),
unquoted.replace(b'test', b'spam'))
self.assertTrue(isinstance(tainted.replace(b'test', b'<'),
self._getClass()))
self.assertFalse(isinstance(tainted.replace(b'test', b'spam'),
self._getClass()))

self.assertEqual(tainted.split(), unquoted.split())
for part in self._getClass()(b'< < <').split():
self.assertTrue(isinstance(part, self._getClass()))
for part in tainted.split():
self.assertFalse(isinstance(part, self._getClass()))

multiline = b'test\n<tainted>'
lines = self._getClass()(multiline).split()
self.assertEqual(lines, multiline.split())
self.assertTrue(isinstance(lines[1], self._getClass()))
self.assertFalse(isinstance(lines[0], self._getClass()))

if six.PY3:
transtable = bytes(range(256))
else:
transtable = ''.join(map(chr, range(256)))
self.assertEqual(tainted.translate(transtable),
unquoted.translate(transtable))
self.assertTrue(isinstance(self._getClass()(b'<').translate(transtable),
self._getClass()))
if six.PY2:
# Translate no longer supports a second argument
self.assertFalse(isinstance(self._getClass()(b'<').translate(transtable,
b'<'),
self._getClass()))

def testConstructor(self):
from AccessControl.tainted import TaintedBytes
if six.PY2:
self.assertRaises(ValueError, TaintedBytes, [60])
if six.PY3:
self.assertEqual(TaintedBytes(60), b'<')
self.assertEqual(TaintedBytes(32), b' ')
self.assertEqual(TaintedBytes(b'abc'), b'abc')
self.assertRaises(ValueError, TaintedBytes, "abc")

0 comments on commit c071b8d

Please sign in to comment.