Skip to content

Commit

Permalink
TaintedBytes to processInputs :lines
Browse files Browse the repository at this point in the history
  • Loading branch information
gotcha committed Jan 28, 2018
1 parent 1ec18f4 commit 22dc25d
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 10 deletions.
44 changes: 38 additions & 6 deletions src/AccessControl/tainted.py
Expand Up @@ -28,6 +28,23 @@
from cgi import escape


def should_be_tainted(value):
if isinstance(value, TaintedString):
return should_be_tainted(value._value)
elif isinstance(value, int):
return 60 == value
elif isinstance(value, bytes):
return 60 in value
else:
return '<' in value


def taint_string(value):
if isinstance(value, bytes):
return TaintedBytes(value)
else:
return TaintedString(value)

@total_ordering
class TaintedString(object):

Expand Down Expand Up @@ -58,15 +75,15 @@ def __len__(self):

def __getitem__(self, index):
v = self._value[index]
if '<' in v:
if should_be_tainted(v):
v = self.__class__(v)
return v

def __getslice__(self, i, j):
i = max(i, 0)
j = max(j, 0)
v = self._value[i:j]
if '<' in v:
if should_be_tainted(v):
v = self.__class__(v)
return v

Expand Down Expand Up @@ -120,21 +137,21 @@ def expandtabs(self, *args):

def replace(self, *args):
v = self._value.replace(*args)
if '<' in v:
if should_be_tainted(v):
v = self.__class__(v)
return v

def split(self, *args):
r = self._value.split(*args)
return list(map(lambda v, c=self.__class__: '<' in v and c(v) or v, r))
return list(map(lambda v, c=self.__class__: should_be_tainted(v) and c(v) or v, r))

def splitlines(self, *args):
r = self._value.splitlines(*args)
return list(map(lambda v, c=self.__class__: '<' in v and c(v) or v, r))
return list(map(lambda v, c=self.__class__: should_be_tainted(v) and c(v) or v, r))

def translate(self, *args):
v = self._value.translate(*args)
if '<' in v:
if should_be_tainted(v):
v = self.__class__(v)
return v

Expand Down Expand Up @@ -169,3 +186,18 @@ def createOneOptArgWrapper(func):

for f in oneOptArgWrappedMethods:
setattr(TaintedString, f, createOneOptArgWrapper(f))


class TaintedBytes(TaintedString):

def __init__(self, value):
if isinstance(value, int):
value = bytes([value])
self._value = value

def quoted(self):
result = escape(self._value.decode('utf8'), 1)
return result.encode('utf8')

def __str__(self):
return self._value.decode('utf8')
72 changes: 68 additions & 4 deletions src/AccessControl/tests/test_tainted.py
Expand Up @@ -18,6 +18,32 @@
import six


class TestFunctions(unittest.TestCase):

def test_taint_string(self):
from AccessControl.tainted import taint_string
from AccessControl.tainted import TaintedString
from AccessControl.tainted import TaintedBytes
self.assertIsInstance(taint_string('string'), TaintedString)
self.assertIsInstance(taint_string(b'bytes'), TaintedBytes)

def test_should_be_tainted(self):
from AccessControl.tainted import should_be_tainted
from AccessControl.tainted import taint_string
from AccessControl.tainted import TaintedString
from AccessControl.tainted import TaintedBytes
self.assertFalse(should_be_tainted('string'))
self.assertTrue(should_be_tainted('<string'))
self.assertFalse(should_be_tainted(b'string'))
self.assertTrue(should_be_tainted(b'<string'))
self.assertFalse(should_be_tainted(b'string'[0]))
self.assertTrue(should_be_tainted(b'<string'[0]))
self.assertFalse(should_be_tainted(taint_string('string')))
self.assertTrue(should_be_tainted(taint_string('<string')))
self.assertFalse(should_be_tainted(taint_string(b'string')))
self.assertTrue(should_be_tainted(taint_string(b'<string')))


class TestTaintedString(unittest.TestCase):

def setUp(self):
Expand All @@ -35,6 +61,9 @@ def testStr(self):
def testRepr(self):
self.assertEqual(repr(self.tainted), repr(self.quoted))

def testEqual(self):
self.assertTrue(self.tainted, self.unquoted)

def testCmp(self):
self.assertTrue(self.tainted == self.unquoted)
self.assertTrue(self.tainted < 'a')
Expand All @@ -61,11 +90,12 @@ def testGetSlice(self):
self.assertFalse(isinstance(self.tainted[1:], self._getClass()))
self.assertEqual(self.tainted[1:], self.unquoted[1:])

CONCAT = 'test'
def testConcat(self):
self.assertTrue(isinstance(self.tainted + 'test', self._getClass()))
self.assertEqual(self.tainted + 'test', self.unquoted + 'test')
self.assertTrue(isinstance('test' + self.tainted, self._getClass()))
self.assertEqual('test' + self.tainted, 'test' + self.unquoted)
self.assertTrue(isinstance(self.tainted + self.CONCAT, self._getClass()))
self.assertEqual(self.tainted + self.CONCAT, self.unquoted + self.CONCAT)
self.assertTrue(isinstance(self.CONCAT + self.tainted, self._getClass()))
self.assertEqual(self.CONCAT + self.tainted, self.CONCAT + self.unquoted)

def testMultiply(self):
self.assertTrue(isinstance(2 * self.tainted, self._getClass()))
Expand Down Expand Up @@ -157,3 +187,37 @@ def testStringMethods(self):

def testQuoted(self):
self.assertEqual(self.tainted.quoted(), self.quoted)


class TestTaintedBytes(TestTaintedString):

def setUp(self):
self.unquoted = b'<test attr="&">'
self.quoted = b'&lt;test attr=&quot;&amp;&quot;&gt;'
self.tainted = self._getClass()(self.unquoted)

def _getClass(self):
from AccessControl.tainted import TaintedBytes
return TaintedBytes

def testCmp(self):
self.assertTrue(self.tainted == self.unquoted)
self.assertTrue(self.tainted < b'a')
self.assertTrue(self.tainted > b'.')

CONCAT = b'test'

def testGetItem(self):
self.assertTrue(isinstance(self.tainted[0], self._getClass()))
self.assertEqual(self.tainted[0], self._getClass()(b'<'))
self.assertFalse(isinstance(self.tainted[-1], self._getClass()))
self.assertEqual(self.tainted[-1], 62)

def testStr(self):
self.assertEqual(str(self.tainted), self.unquoted.decode('utf8'))

def testGetSlice(self):
self.assertTrue(isinstance(self.tainted[0:1], self._getClass()))
self.assertEqual(self.tainted[0:1], b'<')
self.assertFalse(isinstance(self.tainted[1:], self._getClass()))
self.assertEqual(self.tainted[1:], self.unquoted[1:])

0 comments on commit 22dc25d

Please sign in to comment.