# Information Flow

In this chapter, we detail how to track information flows in Python by tainting input strings and tracking the taint across string operations.

Some material on `eval` exploitation is adapted from the excellent [blog post](https://nedbatchelder.com/blog/201206/eval_really_is_dangerous.html) by Ned Batchelder.

**Prerequisites**

* You should have read the [chapter on coverage](Coverage.ipynb).

Setting up our infrastructure

In [None]:
import fuzzingbook_utils

In [None]:
from ExpectError import ExpectError

In [None]:
import inspect
import enum

In [None]:
%%html
<div>
<style>
div.todo {
    color:red;
    font-weight: bold;
}
div.todo::before {
    content: "TODO: ";
}
div.done {
    color:blue;
    font-weight: bold;
}
div.done::after {
    content: " :DONE";
}

</style>
<script>
  function todo_toggle() {
    if (todo_shown){
      $('div.todo').hide('500');
      $('div.done').hide('500');
      $('#toggleButton').val('Show Todo')
    } else {
      $('div.todo').show('500');
      $('div.done').show('500');
      $('#toggleButton').val('Hide Todo')
    }
    todo_shown = !todo_shown
  }
  $( document ).ready(function(){
    todo_shown=false;
    $('div.todo').hide()
  });
</script>
<form action="javascript:todo_toggle()"><input type="submit" id="toggleButton" value="Show Todo"></form>

Say we want to implement a calculator service in Python. A really simple way to do that is to rely on the `eval()` function in Python. Since we do not want our users to be able to execute arbitrary commands on our server, we call `eval()` with empty `globals` and `locals` dictionaries:

In [None]:
def my_calculator(my_input):
    result = eval(my_input, {}, {})
    print("The result of %s was %d" % (my_input, result))

It works as expected:

In [None]:
my_calculator('1+2')

Does it?

In [None]:
with ExpectError():
    my_calculator('__import__("os").popen("ls").read()')

As the error shows, `eval()` itself completed successfully: the system command `ls` was executed, and only formatting its (string) result with `%d` failed. It is just as easy for the user to see the output:

In [None]:
my_calculator("1 if __builtins__['print'](__import__('os').popen('ls').read()) else 0")

The problem is that `eval()` [inserts](https://docs.python.org/3/library/functions.html#eval) a reference to the builtins under the key `__builtins__` whenever the `globals` dictionary does not contain that key. We can avoid this by setting `__builtins__` explicitly to `None` in the `globals` we pass to `eval()`.
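The insertion is easy to observe. In the following sketch (the variable names `g` and `h` are ours), a fresh `globals` dictionary gains a `__builtins__` entry after `eval()` runs, whereas a `__builtins__` key that we supply ourselves is left as it is:

```python
# A fresh, empty globals dictionary
g = {}
eval("1 + 2", g)
# eval() has added a reference to the builtins under '__builtins__'
print("__builtins__" in g)    # True

# If we provide the key ourselves, eval() does not overwrite it
h = {"__builtins__": None}
eval("1 + 2", h)
print(h["__builtins__"])      # None
```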

In [None]:
def my_calculator(my_input):
    result = eval(my_input, {"__builtins__":None}, {})
    print("The result of %s was %d" % (my_input, result))

Does it help?

In [None]:
with ExpectError():
    my_calculator("1 if __builtins__['print'](__import__('os').popen('ls').read()) else 0")

But does it actually?

In [None]:
my_calculator("1 if [x['print'](x['__import__']('os').popen('ls').read()) for x in ([x for x in (1).__class__.__base__.__subclasses__() if x.__name__ == 'Sized'][0].__len__.__globals__['__builtins__'],)] else 0")

The problem here is that when users can inject **uninterpreted strings** that reach a dangerous routine such as `eval()` or `exec()`, they can inject dangerous code. What we need is a way to prevent uninterpreted fragments of input strings from reaching dangerous portions of code.
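As a first, hedged illustration of this idea, we could mark user input with a `str` subclass and refuse to pass marked strings to `eval()`. The names `TaintedStr`, `TaintError`, `guarded_eval()`, and `sanitize()` are hypothetical and exist only for this sketch. The sketch also shows why a mere marker is not enough: ordinary string operations return plain `str` objects, so the mark is silently lost. Propagating the mark across such operations is what the taint tracker below is for.

```python
class TaintedStr(str):
    """Hypothetical marker for strings that come directly from the user."""

class TaintError(Exception):
    pass

def guarded_eval(expr):
    # Refuse to evaluate anything that is still marked as user input
    if isinstance(expr, TaintedStr):
        raise TaintError("tainted input reached eval()")
    return eval(expr, {"__builtins__": None}, {})

def sanitize(s):
    # Hypothetical sanitizer: keep only characters of arithmetic expressions.
    # join() builds a plain (unmarked) str.
    return ''.join(c for c in s if c in '0123456789+-*/() ')

user_input = TaintedStr("1+2")
try:
    guarded_eval(user_input)
except TaintError as e:
    print("Blocked:", e)

print(guarded_eval(sanitize(user_input)))   # 3

# The weakness: concatenation returns a plain str, so the mark is lost
print(type("0+" + user_input))              # <class 'str'>
print(guarded_eval("0+" + user_input))      # 3, not blocked
```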

## A Simple Taint Tracker

To capture information flows, we need a new string class. The idea is to use a tainted string class `tstr` as a subclass of (and drop-in replacement for) the original `str` class.

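We need to write a `__new__()` method (here in the helper base class `tstr_`) because we want to track the taint and the parent object responsible for it. This requires customizing object creation: `str` objects are immutable, and `str.__new__()` would reject the additional `taint` and `parent` arguments. Overriding `__init__()` alone is [too late](https://docs.python.org/3/reference/datamodel.html#basic-customization), since `__new__()` receives the same arguments first.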

The taint map in the attribute `_taint` records, for each character of the string, the index of the corresponding character in the original input; untainted characters are marked with `-1`.

In [None]:
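class tstr_(str):
    def __new__(cls, value, *args, **kw):
        # str is immutable; pass only the value on to str.__new__(),
        # ignoring the extra taint/parent arguments
        return super(tstr_, cls).__new__(cls, value)

class tstr(tstr_):
    def __init__(self, value, taint=None, parent=None, **kwargs):
        self.parent = parent  # the tstr this one was derived from, if any
        if taint:
            if isinstance(taint, int):
                # an integer taint is the input index of the first character
                self._taint = list(range(taint, taint + len(self)))
            else:
                # a list gives the taint of each character explicitly
                assert len(taint) == len(self)
                self._taint = taint
        else:
            # by default, character i carries taint i
            self._taint = list(range(0, len(self)))

    def has_taint(self):
        # -1 marks an untainted character
        return any(i >= 0 for i in self._taint)

    def __repr__(self):
        return str.__repr__(self)

    def __str__(self):
        return str.__str__(self)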

In [None]:
t = tstr('hello')
t.has_taint(), t._taint

In [None]:
t = tstr('world', taint = 6)
t._taint
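Why does `tstr_` override `__new__()` at all? Because `str` is immutable, its value is fixed in `str.__new__()`, which receives the same arguments as `__init__()` and rejects any it does not know. The two hypothetical classes below, `Plain` and `WithNew`, sketch the difference:

```python
class Plain(str):
    # Only __init__ is customized; str.__new__ still sees the extra argument
    def __init__(self, value, taint=None):
        self.taint = taint

class WithNew(str):
    def __new__(cls, value, *args, **kwargs):
        # Pass only the value on to str.__new__()
        return super().__new__(cls, value)

    def __init__(self, value, taint=None):
        self.taint = taint

try:
    Plain("abc", taint=3)
except TypeError as e:
    print("Plain:", e)

w = WithNew("abc", taint=3)
print(w, w.taint)   # abc 3
```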

By default, when we wrap a string, it is tainted. Hence we also need a way to `untaint` the string.

In [None]:
class tstr(tstr):
    def untaint(self):
        self._taint =  [-1] * len(self)
        return self

In [None]:
t = tstr('hello world')
t.untaint()
t.has_taint()

However, the taint does not yet propagate from the whole string to its parts: slicing a `tstr` returns a plain `str`, which has no `has_taint()` method.

In [None]:
with ExpectError():
    t = tstr('hello world')
    t[0:5].has_taint()

### Slice

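The Python slice operator `[n:m]` calls `__getitem__()` with a `slice` object as key, and indexing with `[n]` passes an integer. We override `__getitem__()` so that both return a `tstr` carrying the corresponding part of the taint map. We also define `__iter__()`, so that iterating over a `tstr` yields tainted characters.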

In [None]:
class tstr(tstr):
    def __iter__(self):
        return tstr_iterator(self)
    
    def create(self, res, taint):
        return tstr(res, taint, self)

    def __getitem__(self, key):
        res = super().__getitem__(key)
        if type(key) == int:
            key = len(self) + key if key < 0 else key
            return self.create(res, [self._taint[key]])
        elif type(key) == slice:
            return self.create(res, self._taint[key])
        else:
            assert False

#### The iterator class
The `__iter__()` method requires a supporting `iterator` object.

In [None]:
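class tstr_iterator():
    def __init__(self, tstr_obj):
        self._tstr = tstr_obj
        self._str_idx = 0

    def __iter__(self):
        # iterators are expected to be iterable themselves
        return self

    def __next__(self):
        if self._str_idx == len(self._tstr):
            raise StopIteration
        # indexing calls tstr.__getitem__(), which returns a tstr
        # (or an instance of a tstr subclass)
        c = self._tstr[self._str_idx]
        assert isinstance(c, tstr)
        self._str_idx += 1
        return c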

In [None]:
t = tstr('hello world')
t[0:5].has_taint()
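The two hooks are distinct in Python: `s[n]` and `s[n:m]` call `__getitem__()` (with an integer or a `slice` object), whereas a `for` loop calls `__iter__()`, and `str` has its own iterator that never goes through `__getitem__()`. The hypothetical `KeyRecorder` class below makes this visible. It also shows that the same `slice` object can index a parallel taint list, which is what keeps string and taint aligned.

```python
class KeyRecorder(str):
    """Hypothetical str subclass that records every key passed to __getitem__."""
    def __init__(self, value):
        self.seen = []

    def __getitem__(self, key):
        self.seen.append(key)
        return super().__getitem__(key)

s = KeyRecorder("hello world")
s[0:5]
s[-1]
chars = list(s)      # iteration uses str's own iterator, not __getitem__
print(s.seen)        # [slice(0, 5, None), -1]

# One slice object indexes both the string and a parallel taint list
taint = list(range(100, 111))
key = slice(0, 5)
print(str(s)[key], taint[key])   # hello [100, 101, 102, 103, 104]
```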

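### Helper Methods

We define a few helper methods that deal with the mapped taint index: `x(i)` returns the input index of character `i` and raises an exception if that character is untainted, whereas `_x(i)` returns the raw taint value, which may be `-1`.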

In [None]:
class tstr(tstr):
    class TaintException(Exception):
        pass

    def x(self, i=0):
        v = self._x(i)
        if v < 0:
            raise self.TaintException('Invalid mapped char idx in tstr')
        return v

    def _x(self, i=0):
        return self.get_mapped_char_idx(i)

    def get_mapped_char_idx(self, i):
        if self._taint:
            return self._taint[i]
        else:
            raise self.TaintException('Invalid request idx')

    def get_first_mapped_char(self):
        for i in self._taint:
            if i >= 0:
                return i
        return -1

    def is_tpos_contained(self, tpos):
        return tpos in self._taint

    def is_idx_tainted(self, idx):
        return self._taint[idx] != -1

In [None]:
my_str = tstr('abcdefghijkl', taint=list(range(4,16)))
my_str[0].x(),my_str[-1].x(),my_str[-2].x()

In [None]:
s = my_str[0:4]
s.x(0),s.x(3)

In [None]:
s = my_str[0:-1]
len(s),s.x(10)

### Concatenation

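Implementing concatenation is straightforward: the taint map of the result is the concatenation of the operands' taint maps, where characters from a plain `str` are marked as untainted (`-1`):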

In [None]:
class tstr(tstr):
    def __add__(self, other):
        if type(other) is tstr:
            return self.create(str.__add__(self, other), (self._taint + other._taint))
        else:
            return self.create(str.__add__(self, other), (self._taint + [-1 for i in other]))

Testing concatenations

In [None]:
my_str1 = tstr("hello")
my_str2 = tstr("world", taint=6)
my_str3 = "bye"
v = my_str1 + my_str2
print(v._taint)

w = my_str1 + my_str3 + my_str2
print(w._taint)

In [None]:
class tstr(tstr):
    def __radd__(self, other):  #concatenation (+) -- other is not tstr
        if type(other) is tstr:
            return self.create(str.__add__(other, self), (other._taint + self._taint))
        else:
            return self.create(str.__add__(other, self), ([-1 for i in other] + self._taint))

In [None]:
my_str1 = "hello"
my_str2 = tstr("world")
v = my_str1 + my_str2
v._taint
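Why does `"hello" + tstr("world")` reach `tstr.__radd__()`? Python's rules for reflected operators give the right operand's `__radd__()` the first chance when its type is a subclass of the left operand's type and provides its own reflected method, as `tstr` does. Without an `__add__()` of its own, however, the left-operand case falls back to `str` and yields a plain string. A hypothetical `Marked` class sketches both cases:

```python
class Marked(str):
    """Hypothetical str subclass that only defines __radd__."""
    def __radd__(self, other):
        # Build the plain concatenation, then re-wrap it so the mark survives
        return Marked(str.__add__(other, self))

left = "hello" + Marked("world")    # calls Marked.__radd__
right = Marked("hello") + "world"   # inherited str.__add__ returns a plain str
print(type(left).__name__, type(right).__name__)   # Marked str
```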

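### Replace

Replacing a substring splices the taint of the replacement (or `-1`s, if it is a plain `str`) into the taint map, in place of the taint of the replaced characters.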

In [None]:
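class tstr(tstr):
    def replace(self, a, b, n=-1):
        # n < 0 means "replace all occurrences", as with str.replace()
        old_taint = self._taint
        b_taint = b._taint if isinstance(b, tstr) else [-1] * len(b)
        mystr = str(self)
        i = 0
        start = 0
        while n < 0 or i < n:
            idx = mystr.find(a, start)
            if idx == -1:
                break
            last = idx + len(a)
            # splice the replacement and its taint into place
            mystr = mystr[:idx] + str(b) + mystr[last:]
            old_taint = old_taint[:idx] + b_taint + old_taint[last:]
            # resume searching after the inserted text, so that occurrences
            # of `a` within `b` are not replaced again; an empty `a` matches
            # between characters, so also skip one original character
            start = idx + len(b) + (0 if a else 1)
            i += 1
        return self.create(mystr, old_taint)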

In [None]:
my_str = tstr("aa cde aa")
res = my_str.replace('aa', 'bb')
res, res._taint

### Split

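We essentially have to re-implement the split operations: we let `str` do the splitting and then slice each part, together with its taint, out of the original string. Splitting on whitespace (no `sep`) is slightly different from other splits, since runs of whitespace of any length act as a single separator.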

In [None]:
class tstr(tstr):
    def _split_helper(self, sep, splitted):
        result_list = []
        last_idx = 0
        first_idx = 0
        sep_len = len(sep)

        for s in splitted:
            last_idx = first_idx + len(s)
            item = self[first_idx:last_idx]
            result_list.append(item)
            first_idx = last_idx + sep_len
        return result_list

    def _split_space(self, splitted):
        result_list = []
        mystr = str(self)
        last_idx = 0
        for s in splitted:
            # a part begins either with a non-whitespace character or (for the
            # remainder kept by rsplit() with maxsplit) at index 0, so its next
            # occurrence after the previous part is its actual position
            first_idx = mystr.find(s, last_idx)
            last_idx = first_idx + len(s)
            result_list.append(self[first_idx:last_idx])
        return result_list

    def rsplit(self, sep=None, maxsplit=-1):
        splitted = super().rsplit(sep, maxsplit)
        if not sep:
            return self._split_space(splitted)
        return self._split_helper(sep, splitted)

    def split(self, sep=None, maxsplit=-1):
        splitted = super().split(sep, maxsplit)
        if not sep:
            return self._split_space(splitted)
        return self._split_helper(sep, splitted)

In [None]:
my_str = tstr('ab cdef ghij kl')
ab, cdef, ghij, kl = my_str.rsplit(sep=' ')
print(ab._taint, cdef._taint, ghij._taint, kl._taint)

my_str = tstr('ab   cdef ghij    kl', taint=100)
ab, cdef, ghij, kl = my_str.rsplit()
print(ab._taint, cdef._taint, ghij._taint, kl._taint)

In [None]:
my_str = tstr('ab cdef ghij kl', taint=list(range(0, 15)))
ab, cdef, ghij, kl = my_str.split(sep=' ')
print(ab._taint, cdef._taint, kl._taint)

my_str = tstr('ab   cdef ghij    kl', taint=list(range(0, 20)))
ab, cdef, ghij, kl = my_str.split()
print(ab._taint, cdef._taint, kl._taint)
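The whitespace case is subtle because of how `str.split()` and `str.rsplit()` treat whitespace (spaces, tabs, and newlines alike): leading and trailing whitespace is dropped, except that with `maxsplit` the unsplit remainder keeps it. One way to recover each part's position, and hence its slice of the taint map, is to search for each part starting at the end of the previous one. The helper `part_spans()` below is a hypothetical sketch of that idea:

```python
def part_spans(s, parts):
    """Hypothetical helper: (start, stop) of each part from s.split()/s.rsplit().

    A part starts either with a non-whitespace character or (for the remainder
    kept by rsplit() with maxsplit) at index 0, so searching from the end of
    the previous part finds its actual position.
    """
    spans, pos = [], 0
    for p in parts:
        start = s.find(p, pos)
        spans.append((start, start + len(p)))
        pos = start + len(p)
    return spans

s = "  ab \t cdef\nkl  "
print(s.split())                          # ['ab', 'cdef', 'kl']
print(part_spans(s, s.split()))           # [(2, 4), (7, 11), (12, 14)]
print(s.rsplit(None, 1))                  # ['  ab \t cdef', 'kl']
print(part_spans(s, s.rsplit(None, 1)))   # [(0, 11), (12, 14)]
```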

### Strip

In [None]:
class tstr(tstr):
    def strip(self, cl=None):
        return self.lstrip(cl).rstrip(cl)

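    def lstrip(self, cl=None):
        res = super().lstrip(cl)
        # compute the offset from the lengths; find() would return 0 for an
        # empty result and thus keep the whole string
        return self[len(self) - len(res):]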

    def rstrip(self, cl=None):
        res = super().rstrip(cl)
        return self[0:len(res)]


In [None]:
my_str1 = tstr("  abc  ")
v = my_str1.strip()
v, v._taint

In [None]:
my_str1 = tstr("  abc  ")
v = my_str1.lstrip()
v, v._taint

In [None]:
my_str1 = tstr("  abc  ")
v = my_str1.rstrip()
v, v._taint
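When mapping a stripped result back to the original string, the offset of what remains is best computed from lengths rather than searched for: for an all-whitespace string, `lstrip()` returns `''`, and `find('')` returns `0`, which would keep the whole string. A small sketch (the function name `lstrip_offset()` is ours):

```python
def lstrip_offset(s, chars=None):
    # Number of characters that lstrip() removes from the left
    return len(s) - len(s.lstrip(chars))

s = "   "
print(repr(s.lstrip()), s.find(s.lstrip()))   # '' 0
print(lstrip_offset(s))                        # 3
print(lstrip_offset("xxabcxx", "x"))           # 2
```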

### Expand Tabs

In [None]:
class tstr(tstr):
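    def expandtabs(self, n=8):
        res = super().expandtabs(n)
        taint = []
        column = 0
        for c, t in zip(str(self), self._taint):
            if c == '\t':
                # a tab becomes spaces up to the next multiple of n (no
                # spaces if n <= 0); they all carry the taint of the tab
                spaces = n - (column % n) if n > 0 else 0
                taint.extend([t] * spaces)
                column += spaces
            else:
                taint.append(t)
                # newlines and carriage returns reset the column
                column = 0 if c in '\r\n' else column + 1
        return self.create(res, taint)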

In [None]:
my_tstr = tstr("ab\tcd")
my_str = str("ab\tcd")
v1 = my_str.expandtabs(4)
v2 = my_tstr.expandtabs(4)
print(len(v1), repr(my_tstr), repr(v2), v2._taint)
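A quick way to sanity-check a taint rule for `expandtabs()` is to compute, for every output character, the index of the input character it came from, and to compare the result against `str.expandtabs()`. The sketch below (`expand_source_map()` is a hypothetical name) follows the documented behavior: a tab advances to the next multiple of the tab size, and `\n` or `\r` resets the column.

```python
def expand_source_map(s, n=8):
    """Hypothetical helper: source index in s of each character of s.expandtabs(n)."""
    result, column = [], 0
    for i, c in enumerate(s):
        if c == '\t':
            # A tab becomes n - (column % n) spaces (none if n <= 0)
            spaces = n - (column % n) if n > 0 else 0
            result.extend([i] * spaces)
            column += spaces
        else:
            result.append(i)
            column = 0 if c in '\r\n' else column + 1
    return result

for text in ["ab\tcd", "abc\td", "\t\tx", "a\nb\tc"]:
    assert len(expand_source_map(text, 4)) == len(text.expandtabs(4))
print(expand_source_map("ab\tcd", 4))   # [0, 1, 2, 2, 3, 4]
```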

### Join

In [None]:
class tstr(tstr):
    def join(self, iterable):
        mystr = ''
        mytaint = []
        sep_taint = self._taint
        lst = list(iterable)
        for i, s in enumerate(lst):
            staint = s._taint if type(s) is tstr else [-1] * len(s)
            mytaint.extend(staint)
            mystr += str(s)
            if i < len(lst)-1:
                mytaint.extend(sep_taint)
                mystr += str(self)
        res = super().join(lst)  # iterable may be a one-shot iterator
        assert len(res) == len(mystr)
        return self.create(res, mytaint)

In [None]:
my_str = tstr("ab cd", taint=100)
(v1, v2), v3 = my_str.split(), 'ef'
print(v1._taint, v2._taint)
v4 = tstr('').join([v2,v3,v1])
print(v4, v4._taint)

In [None]:
my_str = tstr("ab cd", taint=100)
(v1, v2), v3 = my_str.split(), 'ef'
print(v1._taint, v2._taint)
v4 = tstr(',').join([v2,v3,v1])
print(v4, v4._taint)
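A subtlety for `join()`: its argument may be a one-shot iterator such as a generator. Once `list(iterable)` has consumed it, passing the same `iterable` on to `str.join()` yields an empty result, so both the taint computation and the actual join should use the materialized list. A minimal illustration:

```python
parts = (p for p in ["ab", "cd"])   # a generator can be consumed only once
lst = list(parts)
print(repr(",".join(parts)))        # '' since the generator is exhausted
print(repr(",".join(lst)))          # 'ab,cd'
```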

### Partitions

In [None]:
class tstr(tstr):
    def _partition_taint(self, partA, sep, partB):
        # the three parts are consecutive in the original string,
        # so we slice the taint map at the same boundaries
        i, j = len(partA), len(partA) + len(sep)
        return (self.create(partA, self._taint[0:i]),
                self.create(sep, self._taint[i:j]),
                self.create(partB, self._taint[j:]))

    def partition(self, sep):
        return self._partition_taint(*super().partition(sep))

    def rpartition(self, sep):
        return self._partition_taint(*super().rpartition(sep))
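Since the three parts returned by `partition()` and `rpartition()` are always consecutive in the original string, slicing the taint map at the lengths of the first two parts is enough, even when the separator is not found. A sketch with plain lists (`partition_taint()` is a hypothetical helper):

```python
def partition_taint(s, taint, sep, right=False):
    """Hypothetical helper: split a taint list alongside s.partition()/s.rpartition()."""
    a, b, c = s.rpartition(sep) if right else s.partition(sep)
    i, j = len(a), len(a) + len(b)
    return (a, taint[:i]), (b, taint[i:j]), (c, taint[j:])

taint = list(range(10, 15))
print(partition_taint("a=b=c", taint, "="))
# (('a', [10]), ('=', [11]), ('b=c', [12, 13, 14]))
print(partition_taint("a=b=c", taint, "=", right=True))
# (('a=b', [10, 11, 12]), ('=', [13]), ('c', [14]))
print(partition_taint("abcde", taint, "=", right=True))
# (('', []), ('', []), ('abcde', [10, 11, 12, 13, 14]))
```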

### Justify

In [None]:
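class tstr(tstr):
    def _fill_taint(self, fillchar):
        # padding inherits the taint of a tstr fill character;
        # padding from a plain str is untainted (-1)
        return fillchar._x() if isinstance(fillchar, tstr) else -1

    def ljust(self, width, fillchar=' '):
        # ljust() left-justifies the text, i.e., pads on the right
        res = super().ljust(width, fillchar)
        pad = len(res) - len(self)
        return self.create(res, self._taint + [self._fill_taint(fillchar)] * pad)

    def rjust(self, width, fillchar=' '):
        # rjust() right-justifies the text, i.e., pads on the left
        res = super().rjust(width, fillchar)
        pad = len(res) - len(self)
        return self.create(res, [self._fill_taint(fillchar)] * pad + self._taint)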
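Note the direction of the padding: `ljust()` left-justifies the text and therefore pads on the *right*, while `rjust()` pads on the left, so the padding taint has to go on the corresponding side. With plain lists (`justify_taint()` is a hypothetical helper; `-1` marks untainted padding, as elsewhere in this chapter):

```python
def justify_taint(s, taint, width, left=True, fill_taint=-1):
    """Hypothetical helper: taint list for s.ljust(width) (left=True) or s.rjust(width)."""
    pad = max(0, width - len(s))
    if left:
        return s.ljust(width), taint + [fill_taint] * pad    # padding on the right
    return s.rjust(width), [fill_taint] * pad + taint        # padding on the left

print(justify_taint("ab", [5, 6], 4))               # ('ab  ', [5, 6, -1, -1])
print(justify_taint("ab", [5, 6], 4, left=False))   # ('  ab', [-1, -1, 5, 6])
```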

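### String methods that do not change taint

Case conversions such as `upper()` and `lower()` keep each character in place, so the result can reuse the taint map of the original string.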

In [None]:
def make_str_wrapper_eq_taint(fun):
    def proxy(*args, **kwargs):
        res = fun(*args, **kwargs)
        return args[0].create(res, args[0]._taint)
    return proxy

for name, fn in inspect.getmembers(str, callable):
    if name in ['swapcase', 'upper', 'lower', 'capitalize', 'title']:
        setattr(tstr, name, make_str_wrapper_eq_taint(fn))


In [None]:
a = tstr('aa', taint=100).upper()
a, a._taint
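Reusing the taint map unchanged assumes that the case conversion preserves the length of the string. That holds for ASCII, but not for all of Unicode: for instance, `'ß'.upper()` is `'SS'`. A defensive variant could check the length before reusing the taint; the guard `same_length_or_fail()` below is a hypothetical sketch:

```python
def same_length_or_fail(original, converted):
    # Hypothetical guard: a taint map can only be reused if the lengths match
    if len(original) != len(converted):
        raise ValueError("case conversion changed the length: %r -> %r"
                         % (original, converted))
    return converted

print(same_length_or_fail("abc", "abc".upper()))   # ABC
print(len("straße"), len("straße".upper()))        # 6 7
```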

### General wrappers

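These wrappers simply forward each remaining `str` method to the original implementation and return its result unchanged (string results are thus plain, untainted `str` objects). They are not strictly needed for operation, but provide a single place to hook in tracing.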

In [None]:
def make_str_wrapper(fun):
    def proxy(*args, **kwargs):
        res = fun(*args, **kwargs)
        return res
    return proxy

import types
tstr_members = [name for name, fn in inspect.getmembers(tstr, callable)
                if type(fn) == types.FunctionType and fn.__qualname__.startswith('tstr')]

for name, fn in inspect.getmembers(str, callable):
    if name not in set(['__class__', '__new__', '__str__', '__init__',
                        '__repr__','__getattribute__']) | set(tstr_members):
        setattr(tstr, name, make_str_wrapper(fn))
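To illustrate the tracing use, here is a sketch of a wrapper that logs each call before returning its result. It is applied to a hypothetical `TracedStr` class with a hypothetical `call_log` list rather than to `tstr`, so that it runs on its own:

```python
import functools

call_log = []   # hypothetical log of (method name, result) pairs

def make_tracing_wrapper(fun):
    @functools.wraps(fun)
    def proxy(*args, **kwargs):
        res = fun(*args, **kwargs)
        call_log.append((fun.__name__, res))
        return res
    return proxy

class TracedStr(str):
    pass

for name in ['find', 'startswith', 'count']:
    setattr(TracedStr, name, make_tracing_wrapper(getattr(str, name)))

t = TracedStr("hello world")
t.find("o"), t.startswith("he"), t.count("l")
print(call_log)   # [('find', 4), ('startswith', True), ('count', 3)]
```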

### Methods yet to be translated

These methods also derive new strings from existing ones, but we do not yet have taint-preserving implementations for any of them. Hence, they are marked as dangerous: calling one of them on a `tstr` raises an exception until we provide the right translations.

In [None]:
def make_str_abort_wrapper(fun):
    def proxy(*args, **kwargs):
        raise tstr.TaintException('%s Not implemented in TSTR' % fun.__name__)
    return proxy

for name, fn in inspect.getmembers(str, callable):
    if name in ['__format__', '__rmod__', '__mod__', 'format_map', 'format',
               '__mul__','__rmul__','center','zfill', 'decode', 'encode', 'splitlines']:
        setattr(tstr, name, make_str_abort_wrapper(fn))

## EOF Tracker

Sometimes we want to know where an empty string came from. That is, if an empty string is the result of operations on a tainted string, we want a best guess for the taint index at its position in the input, i.e., the taint of the character that follows it.

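### Slice

To detect EOF, an empty string needs to carry a *cursor*: the taint of the character in front of which it sits, i.e., the character that follows it in the original string, or one past the last character if the empty string lies at the very end.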

In [None]:
class eoftstr(tstr):
    def create(self, res, taint):
        return eoftstr(res, taint, self)
    
    def __getitem__(self, key):
        def get_interval(key):
            # missing bounds default to the start and end of this string
            return ((0 if key.start is None else key.start),
                    (len(self) if key.stop is None else key.stop))

        res = super().__getitem__(key)
        if type(key) == int:
            key = len(self) + key if key < 0 else key
            return self.create(res, [self._taint[key]])
        elif type(key) == slice:
            if res:
                return self.create(res, self._taint[key])
            # Result is an empty string
            t = self.create(res, self._taint[key])
            key_start, key_stop = get_interval(key)
            cursor = 0
            if key_start < len(self):
                assert key_stop < len(self)
                cursor = self._taint[key_stop]
            else:
                if len(self) == 0:
                    # if the original string was empty, we assume that any
                    # empty string produced from it should carry the same taint.
                    cursor = self.x()
                else:
                    # Key start was not in the string. We can reply only
                    # if the key start was just outside the string, in
                    # which case, we guess.
                    if key_start != len(self):
                        raise self.TaintException('Can\'t guess the taint')
                    cursor = self._taint[len(self) - 1] + 1
            # _tcursor gets created only for empty strings.
            t._tcursor = cursor
            return t

        else:
            assert False

In [None]:
class eoftstr(eoftstr):
    def get_mapped_char_idx(self, i):
        if self._taint:
            return self._taint[i]
        else:
            if i != 0:
                raise self.TaintException('Invalid request idx')
            # self._tcursor gets created only for empty strings.
            # use the exception to determine which ones need it.
            return self._tcursor

In [None]:
t = eoftstr('hello world')
print(repr(t[11:]))
print(t[11:].x(), t[11:]._taint)
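The cursor rule can be sketched with plain lists. For an empty slice `[pos:pos]`, the hypothetical `eof_cursor()` below returns the taint of the character that follows it, or one past the taint of the last character when `pos` is exactly the end of the string. Positions further out raise an error, as in `eoftstr`; the case of an empty original string is left out of this sketch.

```python
def eof_cursor(taint, pos):
    """Hypothetical helper: taint cursor of the empty string at position pos."""
    if pos < len(taint):
        return taint[pos]           # taint of the following character
    if pos == len(taint) and taint:
        return taint[-1] + 1        # just past the end: guess the next index
    raise ValueError("can't guess the taint")

taint = list(range(100, 111))       # e.g. 'hello world' with taint=100
print(eof_cursor(taint, 3), eof_cursor(taint, 11))   # 103 111
```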

## A Comparison Tracker

Sometimes, we also want to know what each character in an input was compared to.

### Operators

In [None]:
class Op(enum.Enum):
    LT = 0
    LE = enum.auto()
    EQ = enum.auto()
    NE = enum.auto()
    GT = enum.auto()
    GE = enum.auto()
    IN = enum.auto()
    NOT_IN = enum.auto()
    IS = enum.auto()
    IS_NOT = enum.auto()
    FIND_STR = enum.auto()


COMPARE_OPERATORS = {
    Op.EQ: lambda x, y: x == y,
    Op.NE: lambda x, y: x != y,
    Op.IN: lambda x, y: x in y,
    Op.NOT_IN: lambda x, y: x not in y,
    Op.FIND_STR: lambda x, y: x.find(y)
}

Comparisons = []

### Instructions

In [None]:
class Instr:
    def __init__(self, o, a, b):
        self.opA = a
        self.opB = b
        self.op = o

    def o(self):
        if self.op == Op.EQ:
            return 'eq'
        elif self.op == Op.NE:
            return 'ne'
        else:
            return '?'

    def opS(self):
        if not self.opA.has_taint() and type(self.opB) is tstr:
            return (self.opB, self.opA)
        else:
            return (self.opA, self.opB)

    @property
    def op_A(self):
        return self.opS()[0]

    @property
    def op_B(self):
        return self.opS()[1]

    def __repr__(self):
        return "%s,%s,%s" % (self.o(), repr(self.opA), repr(self.opB))

    def __str__(self):
        if self.op == Op.EQ:
            if str(self.opA) == str(self.opB):
                return "%s = %s" % (repr(self.opA), repr(self.opB))
            else:
                return "%s != %s" % (repr(self.opA), repr(self.opB))
        elif self.op == Op.NE:
            if str(self.opA) == str(self.opB):
                return "%s = %s" % (repr(self.opA), repr(self.opB))
            else:
                return "%s != %s" % (repr(self.opA), repr(self.opB))
        elif self.op == Op.IN:
            if str(self.opA) in str(self.opB):
                return "%s in %s" % (repr(self.opA), repr(self.opB))
            else:
                return "%s not in %s" % (repr(self.opA), repr(self.opB))
        elif self.op == Op.NOT_IN:
            if str(self.opA) in str(self.opB):
                return "%s in %s" % (repr(self.opA), repr(self.opB))
            else:
                return "%s not in %s" % (repr(self.opA), repr(self.opB))
        else:
            assert False

### Equivalence

In [None]:
class ctstr(eoftstr):
    def create(self, res, taint):
        o = ctstr(res, taint, self)
        o.comparisons = self.comparisons
        return o
    
    def with_comparisons(self, comparisons):
        self.comparisons = comparisons
        return self

In [None]:
class ctstr(ctstr):
    def __eq__(self, other):
        if len(self) == 0 and len(other) == 0:
            self.comparisons.append(Instr(Op.EQ, self, other))
            return True
        elif len(self) == 0:
            self.comparisons.append(Instr(Op.EQ, self, other[0]))
            return False
        elif len(other) == 0:
            self.comparisons.append(Instr(Op.EQ, self[0], other))
            return False
        elif len(self) == 1 and len(other) == 1:
            self.comparisons.append(Instr(Op.EQ, self, other))
            return super().__eq__(other)
        else:
            if not self[0] == other[0]:
                return False
            return self[1:] == other[1:]

    # defining __eq__() in a class body would otherwise set __hash__ to None
    __hash__ = str.__hash__

In [None]:
t = ctstr('hello world', taint=100).with_comparisons([])
print(t.comparisons)
t == 'hello'
for c in t.comparisons:
    print(repr(c))
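One caveat of overriding `__eq__()` in a class body: unless the same class also defines `__hash__()`, Python sets `__hash__` to `None`, which makes instances unhashable (so they cannot serve as dictionary keys or set members). A minimal sketch with hypothetical classes:

```python
class EqOnly(str):
    def __eq__(self, other):
        return str.__eq__(self, other)

class EqAndHash(str):
    def __eq__(self, other):
        return str.__eq__(self, other)
    __hash__ = str.__hash__   # restore hashing explicitly

print(EqOnly.__hash__)                     # None
try:
    hash(EqOnly("a"))
except TypeError as e:
    print("TypeError:", e)
print(hash(EqAndHash("a")) == hash("a"))   # True
```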

In [None]:
class ctstr(ctstr):
    def __ne__(self, other):
        return not self.__eq__(other)

In [None]:
t = ctstr('hello', taint=100).with_comparisons([])
print(t.comparisons)
t != 'bye'
for c in t.comparisons:
    print(repr(c))

In [None]:
class ctstr(ctstr):
    def __contains__(self, other):
        self.comparisons.append(Instr(Op.IN, self, other))
        return super().__contains__(other)

In [None]:
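class ctstr(ctstr):
    def find(self, sub, start=None, end=None):
        # normalize the search bounds as str.find() does: None means the
        # start/end of the string, and negative values count from the end
        start_val, end_val, _ = slice(start, end).indices(len(self))
        self.comparisons.append(Instr(Op.IN, self[start_val:end_val], sub))
        return super().find(sub, start, end)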
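The search bounds of `str.find()` follow slice semantics: `None` stands for the start or end of the string, and negative values count from the end. `slice(start, end).indices(len(s))` performs exactly this normalization, which is handy for recording which part of a string a search looked at. A small check (`normalized_bounds()` is a hypothetical name):

```python
def normalized_bounds(s, start=None, end=None):
    # Hypothetical helper: the (start, end) range that s.find(sub, start, end) searches
    lo, hi, _ = slice(start, end).indices(len(s))
    return lo, hi

s = "hello world"
print(normalized_bounds(s))          # (0, 11)
print(normalized_bounds(s, -5))      # (6, 11)
print(normalized_bounds(s, 2, -3))   # (2, 8)

# Searching the normalized range gives the same position as find() itself
lo, hi = normalized_bounds(s, -5)
assert s.find("o", -5) == lo + s[lo:hi].find("o")
```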

## Lessons Learned

* One can track the information flow from input to the internals of a system.

## Next Steps

_Link to subsequent chapters (notebooks) here:_

## Background

\cite{Lin2008}
