Skip to content

Commit

Permalink
re fullmatch implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
pschanely committed May 19, 2020
1 parent 3bf7726 commit 3ed7d3e
Show file tree
Hide file tree
Showing 5 changed files with 359 additions and 1 deletion.
3 changes: 2 additions & 1 deletion crosshair/core.py
Expand Up @@ -170,7 +170,8 @@ def __exit__(self, exc_type, exc_value, tb):
'expected string or bytes-like object' in exc_str):
# Ideally we'd attempt literal strings after encountering this.
# See https://github.com/pschanely/CrossHair/issues/8
raise CrosshairUnsupported('Detected proxy intolerance: '+exc_str)
debug('Proxy intolerace at: ', traceback.format_exc())
raise CrosshairUnsupported('Detected proxy intolerance: ' + exc_str)
if isinstance(exc_value, (UnexploredPath, CrosshairInternal, z3.Z3Exception)):
return False # internal issue: re-raise
if isinstance(exc_value, BaseException): # TODO: should this be "Exception" instead?
Expand Down
2 changes: 2 additions & 0 deletions crosshair/libimpl/__init__.py
Expand Up @@ -2,10 +2,12 @@
from crosshair.libimpl import collectionslib
from crosshair.libimpl import datetimelib
from crosshair.libimpl import randomlib
from crosshair.libimpl import relib

def make_registrations():
builtinslib.make_registrations()
collectionslib.make_registrations()
datetimelib.make_registrations()
randomlib.make_registrations()
relib.make_registrations()

225 changes: 225 additions & 0 deletions crosshair/libimpl/relib.py
@@ -0,0 +1,225 @@
import re
from typing import *

from sre_parse import LITERAL, RANGE, ANY, IN, BRANCH, SUBPATTERN # type: ignore
from sre_parse import MAX_REPEAT, MAXREPEAT # type: ignore
from sre_parse import CATEGORY, CATEGORY_DIGIT # type: ignore
from sre_parse import parse


import z3 # type: ignore

from crosshair import debug, register_patch, register_type
from crosshair import realize, with_realized_args, IgnoreAttempt

from crosshair.libimpl.builtinslib import SmtInt, SmtStr


# TODO: test _Match methods
# TODO: SUBPATTERN
# TODO: CATEGORY
# TODO: re.MULTILINE
# TODO: re.DOTALL
# TODO: re.IGNORECASE
# TODO: Give up on re.LOCALE
# TODO: bytes input and re.ASCII
# TODO: Match edge conditions; IndexError etc
# TODO: Match.__repr__
# TODO: wait for unicode support in z3 and adapt this.
# TODO: greediness by default; also nongreedy: +? *? ?? {n,m}?
# TODO: ATs: parse(r'\A^\b\B$\Z', re.MULTILINE) == [(AT, AT_BEGINNING_STRING),
# (AT, AT_BEGINNING), (AT, AT_BOUNDARY), (AT, AT_NON_BOUNDARY),
# (AT, AT_END), (AT, AT_END_STRING)]
# TODO: capture groups
# TODO: backreferences to capture groups: parse(r'(\w) \1') ==
# [(SUBPATTERN, (1, 0, 0, [(IN, [(CATEGORY, CATEGORY_WORD)])])),
# (LITERAL, 32), (GROUPREF, 1)]
# TODO: categories: CATEGORY_SPACE, CATEGORY_WORD, CATEGORY_LINEBREAK
# TODO: NEGATE: parse(r'[^34]') == [(IN, [(NEGATE, None), (LITERAL, 51), (LITERAL, 52)])]
# TODO: NOT_LITERAL: parse(r'[^\n]') == [(NOT_LITERAL, 10)]
# TODO: search()
# TODO: split()
# TODO: findall() and finditer()
# TODO: sub() and subn()
# TODO: positive/negative lookahead/lookbehind


class ReUnhandled(Exception):
pass

def _handle_item(parsed: Tuple[object, Any], flags: int) -> z3.ExprRef:
(op, arg) = parsed
if op is LITERAL:
if re.IGNORECASE & flags:
if re.ASCII & flags:
return z3.Union(z3.Re(chr(arg).lower()), z3.Re(chr(arg).upper()))
else:
raise ReUnhandled
else:
return z3.Re(chr(arg))
elif op is RANGE:
lo, hi = arg
if re.IGNORECASE & flags:
if re.ASCII & flags:
return z3.Union(z3.Range(chr(lo).lower(), chr(hi).lower()),
z3.Range(chr(lo).upper(), chr(hi).upper()))
else:
raise ReUnhandled
else:
return z3.Range(chr(lo), chr(hi))
elif op is IN:
return z3.Union(*(_handle_item(a, flags) for a in arg))
elif op is CATEGORY:
if arg == CATEGORY_DIGIT:
if re.ASCII & flags:
return z3.Range('0','9')
raise ReUnhandled
elif op is ANY and arg is None:
if re.ASCII & flags:
if re.DOTALL & flags:
return z3.Range(chr(0), chr(255))
else:
return z3.Union(z3.Range(chr(0), chr(9)),
z3.Range(chr(11), chr(255)))
raise ReUnhandled
elif op is BRANCH and arg[0] is None:
branches = arg[1]
return z3.Union(*(_handle_seq(b, flags) for b in branches))
elif op is SUBPATTERN and arg[1] == 0 == arg[2]:
group_num, _, _, subparsed = arg
raise ReUnhandled # need to figure out how to capture subpatterns
#return _handle_seq(subparsed, flags)
elif op is MAX_REPEAT:
(min_repeat, max_repeat, subparsed) = arg
if max_repeat == MAXREPEAT:
if min_repeat == 0:
return z3.Star(_handle_seq(subparsed, flags))
elif min_repeat == 1:
return z3.Plus(_handle_seq(subparsed, flags))
else:
raise ReUnhandled
elif isinstance(min_repeat, int) and isinstance(max_repeat, int):
return z3.Loop(_handle_seq(subparsed, flags), min_repeat, max_repeat)
raise ReUnhandled
else:
raise ReUnhandled(str(op))

def _handle_seq(parsed: Any, flags: int) -> z3.ExprRef:
if len(parsed) == 1:
return _handle_item(parsed[0], flags)
else:
return z3.Concat(*(_handle_item(p, flags) for p in parsed))

def _interpret(pattern: str, flags: int):
parsed = parse(pattern, flags)
try:
ret = _handle_seq(parsed, flags)
debug('Attempting symbolic regex interpretation: ', ret)
return ret
except ReUnhandled:
return None


class _Match:
def __init__(self,
patt: re.Pattern,
string: str,
pos: int,
endpos: Optional[int],
groups: List[Tuple[Optional[str], int, int]]):
self._groups = groups
self.string = string
self.pos = pos
self.endpos = endpos if endpos is not None else len(string)
self.re = patt
self.lastindex = None
self.lastgroup = None
def __bool__(self):
return True
def __repr__(self):
return f'<re.Match object; span={self.span()!r}, match={self.group()!r}>'
def __getitem__(self, idx):
return self.group(idx)
def group(self, *nums):
if not nums:
nums = (0,)
ret = []
for num in nums:
name, start, end = self._groups[num]
ret.append(self.string[start:end])
if len(nums) == 1:
return ret[0]
else:
return tuple(ret)
def groups(self):
indicies = range(1, len(self._groups))
if indicies:
return self.group(*indicies)
else:
return ()
def groupdict(self, default=None):
ret = {}
for groupname, start, end in self._groups:
if groupname is not None:
ret[groupname] = self.string[start:end]
return ret
def start(self, group=0):
return self._groups[group][1]
def end(self, group=0):
return self._groups[group][2]
def span(self, group=0):
_, start, end = self._groups[group]
return (start, end)


def _slice_match_area(string, pos=0, endpos=None):
smtstr = string.var
is_bounded = pos != 0 or endpos is not None
endpos = z3.Length(smtstr) if endpos is None else endpos
if is_bounded:
smtstr = z3.SubString(smtstr, pos, endpos - pos)
return (smtstr, endpos)

_orig_match = re.Pattern.match
def _match(self, string, pos=0, endpos=None):
# TODO: Work in progress. Greediness is not accounted for here.
if type(string) is SmtStr:
interp = _interpret(self.pattern, self.flags)
if interp is not None:
smtstr, endpos = _slice_match_area(string, pos, endpos)
space = string.statespace
match_end = SmtInt(space, int, 'matchend' + space.uniq())
matching_substr = z3.SubString(smtstr, 0, match_end)
if space.smt_fork(z3.InRe(matching_substr, interp)):
## It's the greediest match:
#x = z3.Var(0, z3.IntSort())
#space.add(z3.ForAll([x], z3.Implies(z3.And(match_end < x, x < z3.Length(smtstr))),
# z3.Not(z3.InRe(z3.SubString(smtstr, 0, x), interp))))
return _Match(self, string, pos, endpos, [(None, pos, match_end)])
else:
return None
string = realize(string)
return _orig_match(self, string, pos) if endpos is None else _orig_match(self, string, pos, endpos)

_orig_fullmatch = re.Pattern.fullmatch
def _fullmatch(self, string, pos=0, endpos=None):
if type(string) is SmtStr:
interp = _interpret(self.pattern, self.flags)
if interp is not None:
smtstr, endpos = _slice_match_area(string, pos, endpos)
if string.statespace.smt_fork(z3.InRe(smtstr, interp)):
return _Match(self, string, pos, endpos, [(None, pos, endpos)])
else:
return None
return _orig_fullmatch(self, realize(string), self.flags)

def make_registrations():
register_patch(re.Pattern, with_realized_args(re.Pattern.search), 'search')
#register_patch(re.Pattern, with_realized_args(re.Pattern.match), 'match')
#register_patch(re.Pattern, _match, 'match')
register_patch(re.Pattern, _fullmatch, 'fullmatch')
register_patch(re.Pattern, with_realized_args(re.Pattern.split), 'split')
register_patch(re.Pattern, with_realized_args(re.Pattern.findall), 'findall')
register_patch(re.Pattern, with_realized_args(re.Pattern.finditer), 'finditer')
register_patch(re.Pattern, with_realized_args(re.Pattern.sub), 'sub')
register_patch(re.Pattern, with_realized_args(re.Pattern.subn), 'subn')
129 changes: 129 additions & 0 deletions crosshair/libimpl/relib_test.py
@@ -0,0 +1,129 @@
import re
import sre_parse
import sys
import unittest
from typing import *

import z3 # type: ignore

from crosshair.libimpl.relib import _handle_seq

from crosshair.core_and_libs import *
from crosshair.test_util import check_ok
from crosshair.test_util import check_exec_err
from crosshair.test_util import check_post_err
from crosshair.test_util import check_fail
from crosshair.test_util import check_unknown
from crosshair.test_util import check_messages
from crosshair.util import set_debug

class RegularExpressionTests(unittest.TestCase):

def test_handle_simple(self):
z3re = _handle_seq(sre_parse.parse('abc'), 0)
self.assertEqual(z3.simplify(z3.InRe('abc', z3re)), True)
self.assertEqual(z3.simplify(z3.InRe('ab', z3re)), False)

def test_handle_or(self):
z3re = _handle_seq(sre_parse.parse('a|bc'), 0)
self.assertEqual(z3.simplify(z3.InRe('bc', z3re)), True)
self.assertEqual(z3.simplify(z3.InRe('a', z3re)), True)
self.assertEqual(z3.simplify(z3.InRe('ac', z3re)), False)

def test_handle_noncapturing_subgroup(self):
z3re = _handle_seq(sre_parse.parse('(?:a|b)c'), 0)
self.assertEqual(z3.simplify(z3.InRe('ac', z3re)), True)
self.assertEqual(z3.simplify(z3.InRe('bc', z3re)), True)
self.assertEqual(z3.simplify(z3.InRe('a', z3re)), False)

def test_handle_range(self):
z3re = _handle_seq(sre_parse.parse('[a-z]7'), 0)
self.assertEqual(z3.simplify(z3.InRe('b7', z3re)), True)
self.assertEqual(z3.simplify(z3.InRe('z7', z3re)), True)
self.assertEqual(z3.simplify(z3.InRe('A7', z3re)), False)

def test_handle_ascii_wildcard(self):
z3re = _handle_seq(sre_parse.parse('1.2'), re.A)
self.assertEqual(z3.simplify(z3.InRe('1x2', z3re)), True)
self.assertEqual(z3.simplify(z3.InRe('1\x002', z3re)), True)
self.assertEqual(z3.simplify(z3.InRe('111', z3re)), False)

def test_handle_repeats(self):
z3re = _handle_seq(sre_parse.parse('y*e+s{2,3}'), 0)
self.assertEqual(z3.simplify(z3.InRe('yess', z3re)), True)
self.assertEqual(z3.simplify(z3.InRe('ess', z3re)), True)
self.assertEqual(z3.simplify(z3.InRe('yyesss', z3re)), True)
self.assertEqual(z3.simplify(z3.InRe('yss', z3re)), False)
self.assertEqual(z3.simplify(z3.InRe('yessss', z3re)), False)
self.assertEqual(z3.simplify(z3.InRe('e', z3re)), False)

def test_handle_ascii_numeric(self):
z3re = _handle_seq(sre_parse.parse('a\d+'), re.A)
self.assertEqual(z3.simplify(z3.InRe('a32', z3re)), True)
self.assertEqual(z3.simplify(z3.InRe('a0', z3re)), True)
self.assertEqual(z3.simplify(z3.InRe('a-', z3re)), False)

def test_fullmatch_basic_fail(self) -> None:
def f(s: str) -> bool:
''' post: _ '''
return not re.compile('ab+').fullmatch(s)
self.assertEqual(*check_fail(f))

def test_fullmatch_basic_ok(self) -> None:
def f(s: str) -> Optional[re.Match]:
'''
pre: s == 'a'
post: _
'''
return re.compile('a').fullmatch(s)
self.assertEqual(*check_ok(f))

def test_fullmatch_complex_fail(self) -> None:
def f(s: str) -> str:
'''
pre: re.fullmatch('ab+aXb+a+', s)
post: _ != 'X'
'''
return s[-5]
self.assertEqual(*check_fail(f))

def TODO_test_match_basic_fail(self) -> None:
def f(s: str) -> bool:
''' post: implies(_, len(s) <= 3) '''
return re.compile('ab?c').match(s)
self.assertEqual(*check_ok(f))

def test_match_properties(self) -> None:
test_string = '01ab9'
match = re.compile('ab').fullmatch('01ab9', 2, 4)
assert match is not None
self.assertEqual(match.span(), (2, 4))
self.assertEqual(match.groups(), ())
self.assertEqual(match.group(0), 'ab')
self.assertEqual(match[0], 'ab')
self.assertEqual(match.pos, 2)
self.assertEqual(match.endpos, 4)
self.assertEqual(match.lastgroup, None)
self.assertEqual(match.string, '01ab9')
self.assertEqual(match.re.pattern, 'ab')
def f(s:str) -> Optional[re.Match]:
'''
pre: s == '01ab9'
post: _.span() == (2, 4)
post: _.groups() == ()
post: _.group(0) == 'ab'
post: _[0] == 'ab'
post: _.pos == 2
post: _.endpos == 4
post: _.lastgroup == None
post: _.string == '01ab9'
post: _.re.pattern == 'ab'
'''
return re.compile('ab').fullmatch(s, 2, 4)
self.assertEqual(*check_ok(f))


if __name__ == '__main__':
if ('-v' in sys.argv) or ('--verbose' in sys.argv):
set_debug(True)
unittest.main()
1 change: 1 addition & 0 deletions crosshair/statespace.py
Expand Up @@ -424,6 +424,7 @@ def __init__(self, rand: random.Random, expr: z3.ExprRef, solver: z3.Solver):
notexpr = z3.Not(expr)
could_be_true = solver_is_sat(solver, expr)
could_be_false = solver_is_sat(solver, notexpr)
#debug(expr, ' true?', could_be_true, ' false?', could_be_false)
if (not could_be_true) and (not could_be_false):
debug(' *** Reached impossible code path *** ')
debug('Current solver state:\n', str(solver))
Expand Down

0 comments on commit 3ed7d3e

Please sign in to comment.