# Track Access

**IMPORTANT.** Not all string methods are implemented. A more detailed treatment of `tstr` can be found in the FuzzingBook chapter [Tracking Information Flow](https://www.fuzzingbook.org/html/InformationFlow.html).

In [None]:
import sys
import random
import enum
import src.utils as utils
import string

## tstr

`tstr` is a simple proxy for strings. For every character of any string derived from it, it tracks the index of the character in the original string that it came from (its *taint*).

In [None]:
import inspect
import enum

class tstr_(str):
    """Thin str subclass whose constructor ignores any extra arguments.

    Subclasses (tstr) take additional arguments such as ``taint`` and ``parent``.
    Only ``value`` reaches str's constructor.
    """

    def __new__(cls, value, *args, **kw):
        # Extra positional/keyword arguments are consumed by the subclass __init__.
        return super().__new__(cls, value)

class tstr(tstr_):
    """A str carrying one taint entry per character.

    ``self.taint[i]`` describes where character ``i`` originated. When ``taint``
    is an int (or None, meaning 0), the characters are numbered consecutively
    starting from it. Otherwise ``taint`` is used as the taint list directly and
    must have exactly one entry per character.
    """

    def __init__(self, value, taint=None, parent=None, **kwargs):
        # The tstr this one was derived from (None for an original string).
        self.parent = parent
        size = len(self)
        origin = 0 if taint is None else taint
        if isinstance(origin, int):
            self.taint = list(range(origin, origin + size))
        else:
            self.taint = origin
        assert len(self.taint) == size

    def __repr__(self):
        # Deliberately returns the string itself (no surrounding quotes).
        return self

    def __str__(self):
        # Return the underlying text via str's own __str__.
        return str.__str__(self)

class tstr(tstr):
    def untaint(self):
        """Mark every character as untainted (None), in place, and return self."""
        self.taint = [None for _ in range(len(self))]
        return self

    def has_taint(self):
        """Return True if at least one character has a non-None taint."""
        return any(t is not None for t in self.taint)

    def taint_in(self, gsentence):
        """Return True if every taint of this string also occurs in gsentence.taint."""
        return set(self.taint).issubset(set(gsentence.taint))



class tstr(tstr):
    def create(self, res, taint):
        """Build a tstr holding ``res`` with the given taint and ``self`` as parent.

        Every derived-string operation goes through this factory, so subclasses
        (e.g. xtstr) can override it to return their own type.
        """
        derived = tstr(res, taint=taint, parent=self)
        return derived



class tstr(tstr):
    def __getitem__(self, key):
        """Index or slice the string, returning a tstr with the matching taint.

        For an int key the result's taint is ``[self.taint[key]]``. For a slice
        the taint list is sliced with the same slice object.

        Empty slice results also get a ``_tcursor`` attribute, which ``t()`` reads.
        Every assignment that would compute a real cursor is commented out, so
        ``_tcursor`` is currently always 0.

        Raises tstr.TaintException when an empty tstr is sliced with a start
        greater than 0.
        """
        def get_interval(key):
            # (start, stop) of the slice with None replaced by defaults.
            # NOTE(review): the stop default is len(res). This helper is only
            # called when res is empty, so a missing stop becomes 0 rather than
            # len(self). Confirm this is intended.
            return ((0 if key.start is None else key.start),
                    (len(res) if key.stop is None else key.stop))

        res = super().__getitem__(key)
        if isinstance(key, int):
            # Normalise a negative index so it can address self.taint directly.
            key = len(self) + key if key < 0 else key
            return self.create(res, [self.taint[key]])
        elif isinstance(key, slice):
            if res:
                # Slicing the taint list with the same slice keeps characters and
                # taints aligned, including for steps.
                return self.create(res, self.taint[key])
            # Result is an empty string
            t = self.create(res, self.taint[key])
            key_start, key_stop = get_interval(key)
            cursor = 0
            if key_start < len(self):
                # NOTE(review): this can fail for empty negative-step slices,
                # e.g. s[0:10:-1] on a 3-character string.
                assert key_stop < len(self)
                #cursor = self.taint[key_stop]
            else:
                if len(self) == 0:
                    # if the original string was empty, we assume that any
                    # empty string produced from it should carry the same
                    # taint.
                    #cursor = self.x()
                #else:
                    # Key start was not in the string. We can reply only
                    # if the key start was just outside the string, in
                    # which case, we guess.
                    # Because the `else:` above is commented out, this check
                    # only runs when self is empty.
                    if key_start != len(self):
                        raise tstr.TaintException('Can\'t guess the taint')
                    #cursor = self.taint[len(self) - 1] + 1
            # _tcursor gets created only for empty strings.
            t._tcursor = cursor
            return t

        else:
            # str.__getitem__ above already rejects key types other than int/slice.
            assert False

class tstr(tstr):
    def __iter__(self):
        """Iterate over single-character tstr slices, so each keeps its taint."""
        iterator = tstr_iterator(self)
        return iterator

class tstr_iterator():
    """Iterator yielding the characters of a tstr as one-character tstr objects.

    Each yielded character is produced by indexing the tstr, so it carries the
    taint of that position.
    """

    def __init__(self, tstr):
        self._tstr = tstr
        self._str_idx = 0

    def __iter__(self):
        # The iterator protocol requires iterators to be iterable themselves.
        # Without this, iter(it), list(it) and zip(it, ...) raise TypeError.
        return self

    def __next__(self):
        if self._str_idx >= len(self._tstr):
            raise StopIteration
        # Indexing a tstr goes through tstr.__getitem__, so the result is a tstr.
        c = self._tstr[self._str_idx]
        assert isinstance(c, tstr)
        self._str_idx += 1
        return c

class tstr(tstr):
    def __add__(self, other):
        """Return self + other as a tstr: self's taint followed by other's.

        Characters coming from a plain (non-tstr) right operand get taint -1.

        NOTE(review): __radd__, join and replace mark untainted characters with
        None, which has_taint() treats as untainted. A -1 counts as tainted.
        Confirm which convention is intended.
        """
        joined = str.__add__(self, other)
        if isinstance(other, tstr):
            extra = other.taint
        else:
            extra = [-1] * len(other)
        return self.create(joined, self.taint + extra)

class tstr(tstr):
    def __radd__(self, other):
        """Return other + self as a tstr, for a plain str on the left.

        Characters from a non-tstr left operand are marked untainted (None).
        """
        if not other:
            prefix = []
        elif isinstance(other, tstr):
            prefix = other.taint
        else:
            prefix = [None] * len(other)
        return self.create(str.__add__(other, self), prefix + self.taint)

class tstr(tstr):
    class TaintException(Exception):
        """Raised when a taint cannot be determined or an operation is unsupported."""

    def x(self, i=0):
        """Return the characters of self whose taint matches ``i``.

        ``i`` may be an int (exact taint match) or a slice, which is interpreted
        as a range of taint values. Returns a list of one-character tstr
        objects, or None when ``i`` is neither an int nor a slice.

        Raises TaintException if self has no taint entries (empty string).
        """
        if not self.taint:
            raise tstr.TaintException('Invalid request idx')
        if isinstance(i, int):
            positions = [k for k, t in enumerate(self.taint) if t == i]
        elif isinstance(i, slice):
            wanted = range(i.start or 0, i.stop or len(self), i.step or 1)
            positions = [k for k, t in enumerate(self.taint) if t in wanted]
        else:
            return None
        return [self[p] for p in positions]

class tstr(tstr):
    def replace(self, a, b, n=None):
        """Return a copy with occurrences of ``a`` replaced by ``b``, keeping taint.

        Characters kept from self keep their taint. Each inserted copy of ``b``
        carries ``b.taint`` when ``b`` is a tstr, otherwise None (untainted).

        :param a: substring to replace.
        :param b: replacement text.
        :param n: maximum number of replacements. None or a negative value means
            "all" and 0 means "none", matching str.replace's count.
        :returns: a new tstr built with self.create().
        """
        src = str(self)
        b_taint = b.taint if isinstance(b, tstr) else [None] * len(b)
        limit = -1 if n is None else n
        taint = []
        done = 0
        if a:
            pos = 0
            # Scan the ORIGINAL text left to right. Text inserted from `b` is
            # never searched again, which previously looped forever when `b`
            # contained `a`.
            while limit < 0 or done < limit:
                idx = src.find(a, pos)
                if idx == -1:
                    break
                taint.extend(self.taint[pos:idx])
                taint.extend(b_taint)
                pos = idx + len(a)
                done += 1
            taint.extend(self.taint[pos:])
        else:
            # An empty pattern matches before every character and at the end.
            for k in range(len(src) + 1):
                if limit < 0 or done < limit:
                    taint.extend(b_taint)
                    done += 1
                if k < len(src):
                    taint.append(self.taint[k])
        res = src.replace(a, b, limit)
        assert len(res) == len(taint)
        return self.create(res, taint)

class tstr(tstr):
    def _split_helper(self, sep, splitted):
        """Map pieces produced by splitting on ``sep`` back to tainted slices of self.

        Consecutive pieces are separated by exactly ``len(sep)`` characters.
        This holds for both split and rsplit, since both return pieces in
        left-to-right order.
        """
        result_list = []
        first_idx = 0
        for s in splitted:
            last_idx = first_idx + len(s)
            result_list.append(self[first_idx:last_idx])
            first_idx = last_idx + len(sep)
        return result_list

    def _split_space(self, splitted):
        """Map pieces of a whitespace split (sep=None) back to tainted slices of self.

        Each piece is located with a forward search from the end of the previous
        piece. The gap between pieces is pure whitespace and a piece never starts
        with whitespace, except the first piece of an rsplit, which starts at 0.
        So the first hit is the right one. This handles leading whitespace and
        any whitespace character (tabs, newlines), not just ' '.
        """
        plain = str(self)
        result_list = []
        pos = 0
        for s in splitted:
            start = plain.find(s, pos)
            assert start != -1
            end = start + len(s)
            result_list.append(self[start:end])
            pos = end
        return result_list

    def rsplit(self, sep=None, maxsplit=-1):
        """Like str.rsplit, but each piece is a tstr keeping its taint."""
        splitted = super().rsplit(sep, maxsplit)
        if not sep:
            return self._split_space(splitted)
        return self._split_helper(sep, splitted)

    def split(self, sep=None, maxsplit=-1):
        """Like str.split, but each piece is a tstr keeping its taint."""
        splitted = super().split(sep, maxsplit)
        if not sep:
            return self._split_space(splitted)
        return self._split_helper(sep, splitted)

class tstr(tstr):
    def strip(self, cl=None):
        """Like str.strip, returning a tstr slice of self (taint preserved)."""
        return self.lstrip(cl).rstrip(cl)

    def lstrip(self, cl=None):
        """Like str.lstrip, returning a tstr slice of self (taint preserved)."""
        res = super().lstrip(cl)
        # self.find is kept (rather than pure arithmetic) so subclasses that
        # instrument find() still observe this access.
        i = self.find(res)
        if not res:
            # Everything was stripped. find('') returns 0, which would wrongly
            # keep the whole string.
            i = len(self)
        return self[i:]

    def rstrip(self, cl=None):
        """Like str.rstrip, returning a tstr slice of self (taint preserved)."""
        res = super().rstrip(cl)
        return self[0:len(res)]


class tstr(tstr):
    def expandtabs(self, n=8):
        """Like str.expandtabs(n), keeping one taint entry per output character.

        The spaces a tab expands into inherit the tab's own taint. The tab width
        follows str.expandtabs exactly:
        - a tab advances to the next multiple of ``n``;
        - '\\n' and '\\r' reset the column;
        - with ``n <= 0`` tabs are removed.
        """
        res = super().expandtabs(n)
        taint = []
        col = 0
        for ch, t in zip(str(self), self.taint):
            if ch == '\t':
                if n > 0:
                    width = n - (col % n)
                    taint.extend([t] * width)
                    col += width
            else:
                taint.append(t)
                col = 0 if ch in '\r\n' else col + 1
        assert len(taint) == len(res)
        return self.create(res, taint)

class tstr(tstr):
    def join(self, iterable):
        """Like str.join, producing a tstr whose taint interleaves items and separator.

        Items that are tstr contribute their taint. Plain strings contribute
        None (untainted). Copies of the separator carry self.taint.
        """
        # Materialise once. A generator would otherwise be exhausted before
        # super().join() sees it.
        items = list(iterable)
        taint = []
        for k, s in enumerate(items):
            if k:
                taint.extend(self.taint)
            taint.extend(s.taint if isinstance(s, tstr) else [None] * len(s))
        res = super().join(items)
        assert len(res) == len(taint)
        return self.create(res, taint)

class tstr(tstr):
    def partition(self, sep):
        """Like str.partition, returning three tstr parts with their taint."""
        head, mid, tail = super().partition(sep)
        cut1 = len(head)
        cut2 = cut1 + len(mid)
        return (self.create(head, self.taint[:cut1]),
                self.create(mid, self.taint[cut1:cut2]),
                self.create(tail, self.taint[cut2:]))

    def rpartition(self, sep):
        """Like str.rpartition, returning three tstr parts with their taint."""
        head, mid, tail = super().rpartition(sep)
        cut1 = len(head)
        cut2 = cut1 + len(mid)
        return (self.create(head, self.taint[:cut1]),
                self.create(mid, self.taint[cut1:cut2]),
                self.create(tail, self.taint[cut2:]))

class tstr(tstr):
    def _pad_taint(self, fillchar):
        """Taint for padding characters: the fill char's taint, or -1 if untainted."""
        # str.ljust/rjust already enforce a single-character fillchar.
        return fillchar.taint[0] if isinstance(fillchar, tstr) else -1

    def ljust(self, width, fillchar=' '):
        """Like str.ljust: original text first, padding appended on the RIGHT."""
        res = super().ljust(width, fillchar)
        pad = len(res) - len(self)
        return self.create(res, self.taint + [self._pad_taint(fillchar)] * pad)

    def rjust(self, width, fillchar=' '):
        """Like str.rjust: padding prepended on the LEFT, then the original text."""
        res = super().rjust(width, fillchar)
        pad = len(res) - len(self)
        return self.create(res, [self._pad_taint(fillchar)] * pad + self.taint)

class tstr(tstr):
    # Case conversions reuse the taint list unchanged. Output character i is
    # assumed to come from input character i.
    # NOTE(review): length-changing mappings (e.g. 'ß'.upper() == 'SS') will
    # trip the length assertion in tstr.__init__.

    def _same_taint(self, text):
        """Wrap ``text`` (same length as self) as a tstr carrying self.taint."""
        return self.create(text, self.taint)

    def swapcase(self):
        return self._same_taint(str(self).swapcase())

    def upper(self):
        return self._same_taint(str(self).upper())

    def lower(self):
        return self._same_taint(str(self).lower())

    def capitalize(self):
        return self._same_taint(str(self).capitalize())

    def title(self):
        return self._same_taint(str(self).title())

    def t(self, i=0):
        """Return the taint of character ``i``.

        For an empty string only ``i == 0`` is allowed. It returns
        ``self._tcursor``, which exists only on empty strings produced by slicing.
        """
        if not self.taint:
            if i != 0:
                raise tstr.TaintException('Invalid request idx')
            return self._tcursor
        return self.taint[i]
def taint_include(gword, gsentence):
    """Return True if every taint of ``gword`` also occurs in ``gsentence.taint``."""
    return set(gword.taint).issubset(gsentence.taint)


def make_str_wrapper(fun):
    """Return a plain function that forwards all arguments to ``fun`` unchanged.

    Installed on tstr for str methods it does not override. Results are
    returned exactly as ``fun`` produces them.
    """
    def proxy(*args, **kwargs):
        return fun(*args, **kwargs)
    return proxy

import types
# Names of methods defined in the tstr class hierarchy itself, i.e. plain
# functions whose __qualname__ starts with 'tstr' (this also matches 'tstr_').
tstr_members = [name for name, fn in inspect.getmembers(tstr, callable)
                if isinstance(fn, types.FunctionType) and fn.__qualname__.startswith('tstr')]

# Replace every other callable attribute of str on the final tstr class with a
# forwarding wrapper (make_str_wrapper), except a few core dunders. The wrapper
# returns str's result unchanged, so these methods do not propagate taint.
for name, fn in inspect.getmembers(str, callable):
    if name not in set(['__class__', '__new__', '__str__', '__init__',
                        '__repr__', '__getattribute__']) | set(tstr_members):
        setattr(tstr, name, make_str_wrapper(fn))


def make_str_abort_wrapper(fun):
    """Return a stand-in for ``fun`` that always raises tstr.TaintException."""
    message = '%s Not implemented in TSTR' % fun.__name__

    def proxy(*args, **kwargs):
        raise tstr.TaintException(message)
    return proxy

In [None]:
if __name__ == '__main__':
    # Demo: a fresh tstr numbers its characters 0..len-1. Pieces from split()
    # and slicing keep the original indexes of their characters.
    my_str = tstr('hello world')
    print(my_str.taint)
    a, b = my_str.split(' ')
    print(a.taint, b.taint)
    print(my_str[0:3].taint)

## xtstr

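`xtstr` extends `tstr` to record which indexes of the original input are accessed by character comparisons (equality, `in`, `find`, and similar operations).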

In [None]:
class Op(enum.Enum):
    """Kinds of operations that can be recorded on tainted strings."""
    LT = 0
    LE = 1
    EQ = 2
    NE = 3
    GT = 4
    GE = 5
    IN = 6
    NOT_IN = 7
    IS = 8
    IS_NOT = 9
    FIND_STR = 10

# Evaluators for a subset of Op. FIND_STR returns the str.find index rather
# than a boolean.
COMPARE_OPERATORS = {
    Op.EQ: lambda x, y: x == y,
    Op.NE: lambda x, y: x != y,
    Op.IN: lambda x, y: x in y,
    Op.NOT_IN: lambda x, y: x not in y,
    Op.FIND_STR: lambda x, y: x.find(y)
}

# Module-level list; not referenced elsewhere in the visible source.
Comparisons = []

# ### Instructions

class Instr:
    """One recorded comparison: operator ``op`` applied to operands ``opA`` and ``opB``."""

    def __init__(self, o, a, b):
        self.op = o
        self.opA = a
        self.opB = b

    def o(self):
        """Short operator mnemonic: 'eq', 'ne', or '?' for anything else."""
        if self.op == Op.EQ:
            return 'eq'
        if self.op == Op.NE:
            return 'ne'
        return '?'

    def opS(self):
        """Return the operand pair, putting a tstr second operand first when opA is untainted."""
        swap = not self.opA.has_taint() and isinstance(self.opB, tstr)
        return (self.opB, self.opA) if swap else (self.opA, self.opB)

    @property
    def op_A(self):
        first, _ = self.opS()
        return first

    @property
    def op_B(self):
        _, second = self.opS()
        return second

    def __repr__(self):
        return "%s,%s,%s" % (self.o(), repr(self.opA), repr(self.opB))

    def __str__(self):
        # EQ/NE (and IN/NOT_IN) are rendered by the relation that actually
        # holds between the operands, not by the recorded operator.
        if self.op in (Op.EQ, Op.NE):
            fmt = "%s = %s" if str(self.opA) == str(self.opB) else "%s != %s"
            return fmt % (repr(self.opA), repr(self.opB))
        if self.op in (Op.IN, Op.NOT_IN):
            fmt = "%s in %s" if str(self.opA) in str(self.opB) else "%s not in %s"
            return fmt % (repr(self.opA), repr(self.opB))
        assert False


class xtstr(tstr):
    """tstr variant that records comparisons into a list shared by all derived strings."""

    def create(self, res, taint):
        derived = xtstr(res, taint, self)
        # Share (not copy) the comparison log with every derived string.
        derived.comparisons = self.comparisons
        return derived

    def add_instr(self, op, c_a, c_b):
        record = Instr(op, c_a, c_b)
        self.comparisons.append(record)

    def with_comparisons(self, comparisons):
        """Attach ``comparisons`` as this string's log and return self (fluent)."""
        self.comparisons = comparisons
        return self

class xtstr(xtstr):
    def __eq__(self, other):
        """Compare character by character, recording each character comparison.

        Comparison stops at the first mismatch, so only the inspected characters
        are recorded. Non-string operands compare unequal without recording
        anything. Previously len(other) raised TypeError, e.g. for ``x == None``.
        """
        if not isinstance(other, str):
            return False
        if len(self) == 0 and len(other) == 0:
            self.add_instr(Op.EQ, self, other)
            return True
        elif len(self) == 0:
            self.add_instr(Op.EQ, self, other[0])
            return False
        elif len(other) == 0:
            self.add_instr(Op.EQ, self[0], other)
            return False
        elif len(self) == 1 and len(other) == 1:
            self.add_instr(Op.EQ, self, other)
            return super().__eq__(other)
        else:
            if not self[0] == other[0]:
                return False
            return self[1:] == other[1:]

class xtstr(xtstr):
    def __ne__(self, other):
        """Negation of the recording __eq__ (so comparisons are recorded here too)."""
        equal = self.__eq__(other)
        return not equal

class xtstr(xtstr):
    def __contains__(self, other):
        """``other in self``: record an IN access, then defer to str for the answer."""
        self.add_instr(Op.IN, self, other)
        found = super().__contains__(other)
        return found

class xtstr(xtstr):
    def find(self, sub, start=None, end=None):
        """Like str.find, recording the searched window as an IN access."""
        lo = 0 if start is None else start
        hi = len(self) if end is None else end
        self.add_instr(Op.IN, self[lo:hi], sub)
        return super().find(sub, start, end)

class xtstr(xtstr):
    def rfind(self, sub, start=None, end=None):
        """Like str.rfind, recording the searched window as an IN access.

        NOTE(review): the recorded accesses follow add_instr's left-to-right
        search, not rfind's right-to-left scan.
        """
        lo = 0 if start is None else start
        hi = len(self) if end is None else end
        self.add_instr(Op.IN, self[lo:hi], sub)
        # Delegate to rfind (not find) so the LAST occurrence is returned.
        return super().rfind(sub, start, end)

class xtstr(xtstr):
    def startswith(self, s, beg=0, end=None):
        """Like str.startswith, recording the characters compared against each prefix.

        The window ``self[beg:end]`` is compared with each candidate prefix
        (``s`` may be a str or a tuple of str). This matches what
        str.startswith inspects. The original code sliced the prefix instead of
        self, and broke on tuple prefixes.
        """
        if end is None:
            end = len(self)
        lo, hi, _ = slice(beg, end).indices(len(self))
        window = self[lo:hi]
        for prefix in (s if isinstance(s, tuple) else (s,)):
            # Evaluated only for its side effect: __eq__ records each character
            # it inspects.
            window[:len(prefix)] == prefix
        return super().startswith(s, beg, end)


def substrings(s, l):
    """Yield every contiguous slice of ``s`` of length ``l``, left to right."""
    start = 0
    while start + l <= len(s):
        yield s[start:start + l]
        start += 1

class xtstr(xtstr):
    def in_(self, s):
        """Membership test ``self in s`` that records the comparisons made.

        When ``s`` is a str, self is compared against EVERY window of
        ``len(self)`` characters, without short-circuiting, so all characters
        are recorded; e.g. ``c in '0123456789'`` becomes
        ``__fn(c).in_('0123456789')``. Otherwise items of ``s`` are compared
        until the first match.
        """
        if not isinstance(s, str):
            return any(self == item for item in s)
        matches = [self == window for window in substrings(s, len(self))]
        return any(matches)

class xtstr(xtstr):
    def split(self, sep=None, maxsplit=-1):
        """Like str.split, recording the characters inspected while splitting.

        Whitespace splitting (sep=None) inspects every character, so every taint
        is recorded. Previously add_instr called len(None) and raised TypeError.
        """
        if sep is None:
            m = get_current_method()
            self.comparisons.extend([(t, m) for t in self.taint])
        else:
            self.add_instr(Op.IN, self, sep)
        return super().split(sep, maxsplit)

In [None]:
class xtstr(xtstr):
    def _find(self, substr, sub, m):
        """Return (taint, m) records for the characters of ``substr`` a search for ``sub`` reads.

        If ``sub`` is found at offset v, positions up to v + len(sub) are
        recorded. Otherwise every position is recorded. Taints are assumed to be
        consecutive integers starting at ``substr.taint[0]``.
        """
        text = str(substr)
        if not text:
            return []
        hit = text.find(str(sub))
        first = substr.taint[0]
        span = len(substr) if hit == -1 else hit + len(sub)
        return [(i, m) for i in range(first, first + span)]

    def add_instr(self, op, c_a, c_b):
        """Append (taint index, current method) records for one comparison.

        A single-character xtstr operand records its own taint (c_a is checked
        first). IN comparisons record the positions a search would read.
        Comparisons with an empty operand record nothing.
        """
        m = get_current_method()
        for operand in (c_a, c_b):
            if len(operand) == 1 and isinstance(operand, xtstr):
                self.comparisons.append((operand.taint[0], m))
                return
        if op == Op.IN:
            self.comparisons.extend(self._find(c_a, c_b, m))
            return
        assert len(c_a) == 0 or len(c_b) == 0, "op:%s A:%s B:%s" % (op, c_a, c_b)

    def replace(self, old, new, count=None):
        """Like str.replace, recording every character as accessed."""
        m = get_current_method()
        # NOTE(review): records all characters rather than only those compared.
        self.comparisons.extend([(t, m) for t in self.taint])
        if count is None:
            return super().replace(old, new)
        return super().replace(old, new, count)

    def create(self, res, taint):
        derived = xtstr(res, taint, self)
        derived.comparisons = self.comparisons
        return derived

    def __hash__(self):
        # Hash like the plain text, so xtstr works as a dict/set key alongside str.
        return hash(str(self))

import inspect
def make_str_abort_wrapper(fun):
    """Return a stand-in for ``fun`` that always raises tstr.TaintException (xtstr variant)."""
    message = '%s Not implemented in `xtstr`' % fun.__name__

    def proxy(*args, **kwargs):
        raise tstr.TaintException(message)
    return proxy

# Collect names of callables whose __qualname__ class part is 'xtstr', i.e.
# the methods overridden by the xtstr classes above.
defined_xtstr = {}
for name, fn in inspect.getmembers(xtstr, callable):
    clz = fn.__qualname__.split('.')[0]
    if clz in {'xtstr'}:
        defined_xtstr[name] = clz

# Every other str callable (apart from the exempt set below) is replaced on
# xtstr with a stub that raises tstr.TaintException. Unsupported operations
# therefore fail loudly instead of silently losing tracking.
for name, fn in inspect.getmembers(str, callable):
    if name not in defined_xtstr and name not in {
            '__init__', '__str__', '__eq__', '__ne__', '__class__', '__new__',
            '__setattr__', '__len__', '__getattribute__', '__le__', 'lower',
            'strip', 'lstrip', 'rstrip', '__iter__', '__getitem__', '__add__', '__repr__'}:
        setattr(xtstr, name, make_str_abort_wrapper(fn))

## helpers

In [None]:
# (method, stack_depth, mid) tuple set by set_current_method(), or None.
CURRENT_METHOD = None
# Stack of active invocation records (method_num, method_name, children).
METHOD_NUM_STACK = []
# method_num -> invocation record.
METHOD_MAP = {}
# Counter used to number invocations; 0 is the root record created by trace_init().
METHOD_NUM = 0

def get_current_method():
    """Return the tuple last stored by set_current_method(), or None if none is set."""
    current = CURRENT_METHOD
    return current

def set_current_method(method, stack_depth, mid):
    """Store and return (method, stack_depth, mid) as the current method."""
    global CURRENT_METHOD
    CURRENT_METHOD = current = (method, stack_depth, mid)
    return current

def trace_init():
    """Reset all tracing state and push a root invocation record numbered 0."""
    global CURRENT_METHOD, METHOD_NUM
    CURRENT_METHOD = None
    # Clear in place so existing references to these containers stay valid.
    METHOD_NUM_STACK.clear()
    METHOD_MAP.clear()
    METHOD_NUM = 0

    root = (METHOD_NUM, None, [])
    METHOD_NUM_STACK.append(root)
    METHOD_MAP[METHOD_NUM] = root

def trace_call(method):
    """Record a new invocation of ``method`` as a child of the current one and push it."""
    global METHOD_NUM
    METHOD_NUM += 1
    # Invocation record: (method_num, method_name, children).
    frame = (METHOD_NUM, method, [])
    METHOD_MAP[METHOD_NUM] = frame
    parent = METHOD_NUM_STACK[-1]
    parent[2].append(frame)
    METHOD_NUM_STACK.append(frame)

def trace_return():
    """Pop the innermost invocation record (the call being returned from)."""
    METHOD_NUM_STACK.pop()
    return None

def trace_set_method(method):
    """Set ``method`` as current, tagged with the stack depth and the innermost invocation number."""
    depth = len(METHOD_NUM_STACK)
    top = METHOD_NUM_STACK[-1]
    set_current_method(method, depth, top[0])
    
class in_wrap:
    """Holds the left operand of a rewritten ``a in b`` so accesses can be recorded."""

    def __init__(self, s):
        self.s = s

    def in_(self, s):
        """Evaluate ``self.s in s``, first recording accesses on any xtstr operand.

        An xtstr container records the positions a search for self.s reads. An
        xtstr element records all of its characters.
        """
        m = get_current_method()
        if isinstance(s, xtstr):
            s.comparisons.extend(s._find(s, self.s, m))
        if isinstance(self.s, xtstr):
            self.s.comparisons.extend([(t, m) for t in self.s.taint])
        return self.s in s

def taint_wrap__(st):
    """Wrap the left operand of a rewritten ``st in container`` expression.

    The InRewriter turns every ``a in b`` into ``taint_wrap__(a).in_(b)``, so
    the returned object must provide ``in_``. Strings (and any object lacking
    ``in_``, e.g. ints) are wrapped in in_wrap. Objects that already provide
    ``in_`` are returned unchanged. Previously non-strings were returned as-is,
    so ``1 in [1, 2]`` failed with AttributeError after rewriting.
    """
    if isinstance(st, str) or not hasattr(st, 'in_'):
        return in_wrap(st)
    return st

def wrap_input(inputstr):
    """Return ``inputstr`` as an xtstr with a fresh, empty comparison log."""
    tainted = xtstr(inputstr, parent=None)
    return tainted.with_comparisons([])

def convert_comparisons(comparisons, inputstr):
    """Turn (idx, (method, depth, mid)) records into (idx, inputstr[idx], mid).

    Records whose idx is None (untainted characters) are skipped.
    """
    return [(idx, inputstr[idx], mid)
            for idx, (_method, _depth, mid) in comparisons
            if idx is not None]

def convert_method_map(method_map):
    """Flatten invocation records so each child is referenced by its number, not its record."""
    return {k: (k, name, [child[0] for child in kids])
            for k, (_num, name, kids) in method_map.items()}

## calculator.py

In [None]:
#%pycat subjects/calculator.py

## InRewriter

We also rewrite the source so that `astring in value` is converted to `taint_wrap__(astring).in_(value)`. Note that what we are tracking is not really taints, but rather _character accesses_ to the original string.

The `InRewriter` class transforms `in` comparisons so that character accesses can be tracked. It has two methods. The `wrap()` method wraps an expression `a` into the call `taint_wrap__(a)`.

In [None]:
import ast

In [None]:
class InRewriter(ast.NodeTransformer):
    """AST transformer that routes ``in`` comparisons through taint_wrap__()."""

    def wrap(self, node):
        """Return the call expression ``taint_wrap__(node)``."""
        callee = ast.Name(id='taint_wrap__', ctx=ast.Load())
        return ast.Call(func=callee, args=[node], keywords=[])

The `wrap()` method is used internally by the `visit_Compare()` method to transform `a in lst` into `taint_wrap__(a).in_(lst)`. We need to do this because Python dispatches the `in` operator to the `__contains__()` method of the container `lst`, not of `a`. In our case, however, it is usually `a` that is tainted (and hence proxied), so we need a method invoked on the `a` object instead.

In [None]:
class InRewriter(InRewriter):
    def visit_Compare(self, tree_node):
        """Rewrite ``left in right`` into ``taint_wrap__(left).in_(right)``.

        Sub-expressions are visited first, so nested ``in`` tests are rewritten
        too; previously children were never visited. Anything other than a
        single ``in`` operator is returned unchanged. That includes ``not in``
        and chained comparisons such as ``a in b < c``, which were previously
        turned into an invalid ``in_(b, c)`` call.
        """
        self.generic_visit(tree_node)
        if len(tree_node.ops) != 1 or not isinstance(tree_node.ops[0], ast.In):
            return tree_node
        return ast.Call(
            func=ast.Attribute(
                value=self.wrap(tree_node.left),
                attr='in_',
                ctx=ast.Load()),
            args=tree_node.comparators,
            keywords=[])

Tying it together

In [None]:
def rewrite_in(src):
    """Parse ``src``, apply InRewriter, and return the rewritten source text."""
    tree = InRewriter().visit(ast.parse(src))
    return ast.unparse(ast.fix_missing_locations(tree))

In [None]:
if __name__ == '__main__':
    # Demo: show how an `in` test is rewritten to go through taint_wrap__().
    print(rewrite_in('s in ["a", "b", "c"]'))

In [None]:
def rewrite(src, original):
    """Return ``src`` rewritten by InRewriter, prefixed with json/sys imports.

    ``original`` (the source file name) is accepted but not used.
    """
    tree = ast.fix_missing_locations(InRewriter().visit(ast.parse(src)))
    header = """
import json
import sys
    """
    return header + "\n" + ast.unparse(tree)

In [None]:
if __name__ == '__main__':
    # Rewrite the calculator subject so its `in` tests are instrumented.
    calc_parse_rewritten = rewrite(utils.slurp('subjects/calculator.py'), original='calculator.py')
    print(calc_parse_rewritten)

In [None]:
if __name__ == '__main__':
    # Execute the rewritten subject in this module's namespace. Presumably this
    # defines main(), which the following cells call; confirm against
    # subjects/calculator.py.
    exec(calc_parse_rewritten)

In [None]:
if __name__ == '__main__':
    from src.utils import ExpectError
    # Reset tracing state and tag subsequent comparisons with 'parse_expr'.
    trace_init()
    trace_set_method('parse_expr')

In [None]:
if __name__ == '__main__':
    mystring = '1+1+xyz'
    tainted_input = wrap_input(mystring)
    # The input is expected to be rejected; ExpectError captures the exception.
    with ExpectError() as e:
        main(tainted_input)
    print(e.msg)

In [None]:
if __name__ == '__main__':
    # Each record is (input index, current method tuple).
    for c in tainted_input.comparisons:
        print(c)

# Generator

In [None]:
class NeedMoreException(Exception):
    """Not raised anywhere in the visible source; presumably signals that more input is needed."""


class InvalidValueException(Exception):
    """Not raised anywhere in the visible source; presumably signals an invalid input value."""


class InputLimitException(Exception):
    """Not raised anywhere in the visible source; presumably signals an input-size limit."""


class IterationLimitException(Exception):
    """Not raised anywhere in the visible source; presumably signals an iteration limit."""

In [None]:
def logit(*v):
    """Print the arguments to stderr, space-separated; returns None."""
    print(*v, sep=' ', end='\n', file=sys.stderr)
    return None

In [None]:
class Status(enum.Enum):
    """Outcome of running the subject on a candidate input (see calc_wrapper)."""
    Complete = 0     # main() accepted the input
    Incomplete = 1   # rejected, but characters past the end were accessed (wants more input)
    Incorrect = -1   # rejected within the given input

In [None]:
def calc_wrapper(mystring):
    """Run the subject's main() on ``mystring`` and classify the outcome.

    main() is first run on the plain string. If that raises, it is re-run on a
    tainted copy with one extra trailing space, so accesses past the end can be
    observed.

    Returns ``(Status, None, text)``, where ``text`` is ``mystring`` on a plain
    success and '' otherwise. When the tainted run also fails, the furthest
    accessed input index decides:
    - an access at or beyond ``len(mystring)`` (the appended space) means the
      parser wanted more input: Incomplete;
    - otherwise: Incorrect.
    """
    tainted_input = wrap_input(mystring + ' ')
    try:
        try:
            main(mystring)
            return Status.Complete, None, mystring
        # `except Exception` rather than bare `except:` so KeyboardInterrupt
        # can still stop a long hill-climbing run.
        except Exception:
            main(tainted_input)
            return Status.Complete, None, ''
    except Exception:
        # Untainted characters record None; skip them so max() cannot fail on
        # None-vs-int comparisons.
        accessed = [i for i, _ in tainted_input.comparisons if i is not None]
        if not accessed:
            return Status.Incorrect, None, ''
        if max(accessed) < len(mystring):
            return Status.Incorrect, None, ''
        return Status.Incomplete, None, ''

In [None]:
def get_fitness(s):
    """Score a candidate input; smaller is better.

    Complete or Incomplete inputs score 1/(len(s)+10), so longer is better.
    A Complete input longer than 100 characters scores 0. Incorrect inputs
    score 10.
    """
    status, _, _ = calc_wrapper(s)
    score = 1.0 / (len(s) + 10)
    if status == Status.Complete:
        return 0 if len(s) > 100 else score
    if status == Status.Incomplete:
        return score
    if status == Status.Incorrect:
        return 10
    assert False, (status, s)

In [None]:
def neighbours(x):
    """Return ``x`` extended by each printable character, in random order."""
    shuffled = random.sample(string.printable, len(string.printable))
    return [x + ch for ch in shuffled]

In [None]:
# Intended cap on the "New value" lines printed by hillclimber(). Its counter
# increment is commented out, so currently every improvement is printed while
# this is > 0.
LOG_VALUES = 1

In [None]:
def hillclimber():
    """Greedy hill climbing that grows an input one printable character at a time.

    Starting from '', it repeatedly moves to the first neighbour (in random
    order) whose get_fitness() is strictly lower. It stops once fitness is
    0.001 or below, or when no neighbour improves. Progress goes to stdout.
    """
    x = ''
    fitness = get_fitness(x)
    print("Initial value: %s at fitness %.4f" % (repr(x), fitness))
    iterations = 0
    # NOTE: `logs` is never incremented, so LOG_VALUES does not currently cap output.
    logs = 0

    while fitness > 0.001:
        iterations += 1
        step = _first_improving_neighbour(x, fitness)
        if step is None:
            break
        x, fitness = step
        if logs < LOG_VALUES:
            print("New value: %s at fitness %.4f" % (x, fitness))
        elif logs == LOG_VALUES:
            print("...")

    print("Found optimum after %d iterations at %s" % (iterations, x))


def _first_improving_neighbour(x, fitness):
    """Return (candidate, its fitness) for the first neighbour of x beating ``fitness``, else None."""
    for candidate in neighbours(x):
        candidate_fitness = get_fitness(candidate)
        if candidate_fitness < fitness:
            return candidate, candidate_fitness
    return None

In [None]:
if __name__ == '__main__':
    # Search for a valid calculator input by greedy hill climbing.
    hillclimber()

# Done

In [None]:
#%tb