In [None]:
from math import log2

from z3 import *

In [None]:
def BitVec32(n):
    """Return the 32-bit z3 bit-vector variable named ``n``."""
    return BitVec(n, bv=32)


def BitVecVal32(v):
    """Return the 32-bit z3 bit-vector constant with value ``v``."""
    return BitVecVal(v, bv=32)

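The range-tracking part of value tracking uses the following C structure. A range denotes the values `start, start + 1, …, end` taken modulo $2^{32}$; `end < start` (unsigned) is allowed and means the range wraps around past the maximum value back to 0.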

```c
struct wrange {
	u32 start;
	u32 end;
};
```

In [None]:
class Wrange:
    """Symbolic counterpart of ``struct wrange``: a SIZE-bit range that may wrap.

    The range holds ``start, start + 1, ..., end`` modulo ``2**SIZE``.  When
    ``end`` is below ``start`` (unsigned) the range wraps past ``2**SIZE - 1``
    back to 0.  ``start`` and ``end`` are z3 bit-vector expressions.  Omitting
    them creates free symbolic variables named after the range, so properties
    can be proved for every possible range.
    """
    SIZE = 32 # Working with 32-bit integers
    name: str
    start: BitVecRef
    end: BitVecRef

    def __init__(self, name, start=None, end=None):
        self.name = name
        if start is None:
            start = BitVec(f'Wrange-{name}-start', bv=self.SIZE)
        self.start = start
        assert self.start.size() == self.SIZE
        if end is None:
            end = BitVec(f'Wrange-{name}-end', bv=self.SIZE)
        self.end = end
        assert self.end.size() == self.SIZE

    def _const(self, value):
        # SIZE-bit z3 bit-vector constant holding ``value``
        return BitVecVal(value, bv=self.SIZE)

    def print(self, model):
        """Print start, length and end (decimal and hex) as evaluated in ``model``."""
        indent = ' ' * (len(self.name) + 1)
        start, length, end = [model.eval(field).as_long()
                              for field in (self.start, self.length, self.end)]
        print(f'{self.name}(start={start}/{hex(start)},\n{indent}length={length}/{hex(length)},\n{indent}end={end}/{hex(end)})')

    def wellformed(self):
        """Well-formedness constraint: every (start, end) pair is valid.

        ``end < start`` is not an error; it simply denotes a wrapping range.
        """
        return BoolVal(True)

    def reset(self):
        """Constraint forcing this range to cover everything: start = 0, end = all ones."""
        lowest = self._const(0)
        highest = self._const(-1)
        return And(self.start == lowest, self.end == highest)

    @property
    def length(self):
        """``end - start`` modulo 2**SIZE (one less than the number of values)."""
        return self.end - self.start

    @property
    def uwrapping(self):
        """True when the range wraps in the unsigned view (end <u start)."""
        return ULT(self.end, self.start)

    @property
    def umin(self):
        """Smallest unsigned member: 0 for a wrapping range, else ``start``."""
        lowest = self._const(0)
        return If(self.uwrapping, lowest, self.start)

    @property
    def umax(self):
        """Largest unsigned member: 2**SIZE - 1 for a wrapping range, else ``end``."""
        highest = self._const(2**self.SIZE - 1)
        return If(self.uwrapping, highest, self.end)

    @property
    def swrapping(self):
        """True when the range wraps in the signed view (end <s start)."""
        return self.end < self.start

    @property
    def smin(self):
        """Smallest signed member: the most negative value if signed-wrapping, else ``start``."""
        most_negative = self._const(1 << (self.SIZE - 1))
        return If(self.swrapping, most_negative, self.start)

    @property
    def smax(self):
        """Largest signed member: the most positive value if signed-wrapping, else ``end``."""
        most_positive = self._const((2**self.SIZE - 1) >> 1)
        return If(self.swrapping, most_positive, self.end)

    def contains(self, val: BitVecRef):
        """Constraint that the SIZE-bit bit-vector ``val`` lies in the range."""
        assert val.size() == self.SIZE
        lowest = self._const(0)
        highest = self._const(2**self.SIZE - 1)
        # start <= val <= end
        in_plain = And(ULE(self.start, val), ULE(val, self.end))
        # 0 <= val <= end or start <= val <= 2**32-1
        in_wrapped = Or(
            And(ULE(lowest, val), ULE(val, self.end)),
            And(ULE(self.start, val), ULE(val, highest)),
        )
        return If(self.uwrapping, in_wrapped, in_plain)

In [None]:
# contains() on a non-wrapping range {1, 2, 3} and on a range wrapping
# through zero {-1, 0, 1} must match exactly the listed members.
x = BitVec32('x')
for lo, hi, members in ((1, 3, (1, 2, 3)), (-1, 1, (-1, 0, 1))):
    w1 = Wrange('w1', start=BitVecVal32(lo), end=BitVecVal32(hi))
    prove(w1.contains(x) == Or(*[x == BitVecVal32(v) for v in members]))

## Addition

In [None]:
def wrange_add(a: Wrange, b: Wrange):
    """Range of ``x + y`` for x in ``a`` and y in ``b`` (modulo 2**SIZE).

    The sum range is [a.start + b.start, a.end + b.end] with length
    a.length + b.length.  If that length sum overflows, the range would need
    more than 2**SIZE values, so the full range [0, 2**SIZE - 1] is returned.
    """
    assert(a.SIZE == b.SIZE)
    width = a.SIZE
    combined = a.length + b.length
    # unsigned addition wrapped iff the sum is below one of its operands
    overflow = Or(ULT(combined, a.length), ULT(combined, b.length))
    lo = If(overflow, BitVecVal(0, width), a.start + b.start)
    hi = If(overflow, BitVecVal(2**width-1, width), a.end + b.end)
    return Wrange(f'{a.name} + {b.name}', lo, hi)

In [None]:
x = BitVec32('x')
add_cases = [
    # (lhs bounds, rhs bounds, signed bounds of the expected sum)
    ((1, 3), (0, 0), (1, 3)),      # {1, 2, 3} + {0} = {1, 2, 3}
    ((-1, -1), (0, 2), (-1, 1)),   # {-1} + {0, 1, 2} = {-1, 0, 1}
]
for (s1, e1), (s2, e2), (lo, hi) in add_cases:
    w = wrange_add(
        Wrange('w1', start=BitVecVal32(s1), end=BitVecVal32(e1)),
        Wrange('w2', start=BitVecVal32(s2), end=BitVecVal32(e2)),
    )
    prove(w.contains(x) == And(BitVecVal32(lo) <= x, x <= BitVecVal32(hi)))

In [None]:
# Fully symbolic operands: w1 and w2 are arbitrary ranges, and x and y are
# arbitrary members of them.
w1, w2 = Wrange('w1'), Wrange('w2')
result = wrange_add(w1, w2)
x, y = BitVec32('x'), BitVec32('y')
premise = And(
    w1.wellformed(), w2.wellformed(),
    w1.contains(x), w2.contains(y),
)

In [None]:
# Soundness of wrange_add: x + y always lands in the computed range.
conclusion = And(result.contains(x + y), result.wellformed())
prove(Implies(premise, conclusion))

## Arithmetic Negation

In [None]:
def wrange_neg(a: Wrange):
    """Range of ``-x`` for x in ``a`` (modulo 2**SIZE).

    Negation reverses the order of the endpoints, so [start, end] maps to
    [-end, -start]; the length is unchanged.
    """
    lo, hi = -a.end, -a.start
    return Wrange(f'(-{a.name})', lo, hi)

In [None]:
x = BitVec32('x')

# -{1, 2, 3} = {-3, -2, -1}
w = wrange_neg(Wrange('w1', start=BitVecVal32(1), end=BitVecVal32(3)))
prove(w.contains(x) == And(-3 <= x, x <= -1))

# -{-1} = {1}
w = wrange_neg(Wrange('w1', start=BitVecVal32(-1), end=BitVecVal32(-1)))
prove(w.contains(x) == (x == 1))

In [None]:
# Symbolic operand: w1 is an arbitrary range and x an arbitrary member.
w1 = Wrange('w1')
result = wrange_neg(w1)
x = BitVec32('x')
premise = And(w1.wellformed(), w1.contains(x))

In [None]:
# Soundness of wrange_neg: -x always lands in the computed range.
conclusion = And(result.contains(-x), result.wellformed())
prove(Implies(premise, conclusion))

## Subtraction

In [None]:
def wrange_sub_composed(a: Wrange, b: Wrange):
    """Range of ``x - y`` for x in ``a`` and y in ``b``, computed as ``a + (-b)``.

    Built from wrange_add and wrange_neg.  The bounds it yields are
    equivalent to wrange_sub's: [a.start - b.end, a.end - b.start], or the
    full range when a.length + b.length overflows.
    """
    # Be a bit lazy here, improve later
    w = wrange_add(a, wrange_neg(b))
    # Only the name changes; start and end are taken from the composed range
    # (the end bound is w.end, not w.length).
    return Wrange(f'{a.name} - {b.name}', w.start, w.end)

def wrange_sub(a: Wrange, b: Wrange):
    """Range of ``x - y`` for x in ``a`` and y in ``b`` (modulo 2**SIZE).

    Subtracting [b.start, b.end] from [a.start, a.end] gives
    [a.start - b.end, a.end - b.start], whose length is a.length + b.length.
    If that length sum overflows, the full range [0, 2**SIZE - 1] is returned.
    """
    assert(a.SIZE == b.SIZE)
    width = a.SIZE
    span = a.length + b.length
    # unsigned addition wrapped iff the sum is below one of its operands
    overflow = Or(ULT(span, a.length), ULT(span, b.length))
    lo = If(overflow, BitVecVal(0, width), a.start - b.end)
    hi = If(overflow, BitVecVal(2**width-1, width), a.end - b.start)
    return Wrange(f'{a.name} - {b.name}', lo, hi)

In [None]:
x = BitVec32('x')

# {1, 2, 3} - {0} = {1, 2, 3}
w = wrange_sub(
    Wrange('w1', start=BitVecVal32(1), end=BitVecVal32(3)),
    Wrange('w2', start=BitVecVal32(0), end=BitVecVal32(0)),
)
prove(w.contains(x) == And(ULE(1, x), ULE(x, 3)))

# {-1} - {0, 1, 2} = {-3, -2, -1}
w = wrange_sub(
    Wrange('w1', start=BitVecVal32(-1), end=BitVecVal32(-1)),
    Wrange('w2', start=BitVecVal32(0), end=BitVecVal32(2)),
)
prove(w.contains(x) == And(-3 <= x, x <= -1))

In [None]:
# Fully symbolic operands: w1 and w2 are arbitrary ranges, and x and y are
# arbitrary members of them.
w1, w2 = Wrange('w1'), Wrange('w2')
result = wrange_sub(w1, w2)
x, y = BitVec32('x'), BitVec32('y')
premise = And(
    w1.wellformed(), w2.wellformed(),
    w1.contains(x), w2.contains(y),
)

In [None]:
# Soundness of wrange_sub: x - y always lands in the computed range.
conclusion = And(result.contains(x - y), result.wellformed())
prove(Implies(premise, conclusion))

## Multiplication

In [None]:
def wrange_mul(a: Wrange, b: Wrange):
    """Range of ``x * y`` for x in ``a`` and y in ``b`` (modulo 2**SIZE).

    Only the simple case is tracked precisely: both operands are non-negative
    in the signed view and both upper bounds fit in SIZE/2 bits.  Then no
    product can overflow and multiplication is monotone, so the result is
    [a.start * b.start, a.end * b.end].  In every other case the full range
    [0, 2**SIZE - 1] is returned.
    """
    assert(a.SIZE == b.SIZE)
    # Largest value representable in half the bit width (0xffff for 32 bits).
    # Integer division keeps this an int; ``a.SIZE/2`` produced a float.
    half_max = 2**(a.SIZE // 2) - 1
    too_large = Or(UGT(a.end, BitVecVal(half_max, bv=a.SIZE)),
                   UGT(b.end, BitVecVal(half_max, bv=b.SIZE)))
    # A possibly-negative operand (signed smin < 0) is not handled.  This also
    # excludes every range that wraps in the unsigned view, because such a
    # range always has a negative signed minimum.
    negative = Or(a.smin < 0, b.smin < 0)
    giveup = Or(too_large, negative)
    new_start = If(giveup, BitVecVal(0, a.SIZE), a.start * b.start)
    new_end = If(giveup, BitVecVal(-1, a.SIZE), a.end * b.end)
    return Wrange(f'{a.name} * {b.name}', new_start, new_end)

In [None]:
x = BitVec32('x')
# Small non-negative operands: the result is precise.
w = wrange_mul(
    # {1, 2, 3}
    Wrange('w1', start=BitVecVal32(1), end=BitVecVal32(3)),
    # * {0}
    Wrange('w2', start=BitVecVal32(0), end=BitVecVal32(0)),
)   # = {0}
prove(
    Implies(
        w.contains(x),
        x == 0,
    )
)

x = BitVec32('x')
w = wrange_mul(
    # {-1}
    Wrange('w1', start=BitVecVal32(-1), end=BitVecVal32(-1)),
    # * {-2, -1, 0}
    Wrange('w2', start=BitVecVal32(-2), end=BitVecVal32(0)),
)   # = {0, 1, 2}
# wrange_mul gives up on possibly-negative operands and returns the full
# range here, so the tight bound 0 <= x <= 2 is not provable.  Check
# soundness instead: every actual product must be contained in the result.
prove(
    Implies(
        Or(x == 0, x == 1, x == 2),
        w.contains(x),
    )
)

In [None]:
# Fully symbolic operands: w1 and w2 are arbitrary ranges, and x and y are
# arbitrary members of them.
w1, w2 = Wrange('w1'), Wrange('w2')
result = wrange_mul(w1, w2)
x, y = BitVec32('x'), BitVec32('y')
premise = And(
    w1.wellformed(), w2.wellformed(),
    w1.contains(x), w2.contains(y),
)

In [None]:
# Soundness of wrange_mul: x * y always lands in the computed range.
conclusion = And(result.contains(x * y), result.wellformed())
prove(Implies(premise, conclusion))

## Evaluation

In [None]:
# Search for a counterexample to the soundness of wrange_mul, asking the
# optimizer for small values so that any witness is easy to read.
x, y = BitVec32('x'), BitVec32('y')
w1, w2 = Wrange('w1'), Wrange('w2')
result = wrange_mul(w1, w2)
premise = And(
    w1.wellformed(), w2.wellformed(),
    w1.contains(x), w2.contains(y),
)
claim = Implies(premise, And(result.contains(x * y), result.wellformed()))

s = Optimize()
for objective in (x, y, w1.length, w2.start):
    s.minimize(objective)
s.add(Not(claim))
s.check()

In [None]:
# The cells below inspect a counterexample.  A model only exists when the
# search above returned `sat`; if the soundness proof for wrange_mul holds,
# the search is `unsat` and z3 has no model to return.  In that case, fail
# with an explanatory message instead of z3's opaque "model is not available".
try:
    m = s.model()
except Z3Exception as err:
    raise RuntimeError('no counterexample: the search above was not sat, '
                       'so there is no model to inspect') from err
m

In [None]:
# Counterexample value of x and whether it satisfies w1.
'x={}, w1.contains(x)={}'.format(m.eval(x), m.eval(w1.contains(x)))

In [None]:
# Show w1's start/length/end under the model, in decimal and hex.
w1.print(m)

In [None]:
# Counterexample value of y and whether it satisfies w2.
'y={}, w2.contains(y)={}'.format(m.eval(y), m.eval(w2.contains(y)))

In [None]:
# Show w2's start/length/end under the model, in decimal and hex.
w2.print(m)

In [None]:
# Show the range computed by wrange_mul(w1, w2) under the model.
result.print(m)

In [None]:
# Product of the counterexample values and whether the multiplication range
# (`result` = wrange_mul(w1, w2)) contains it.
f'x*y={m.eval(x*y)}, result.contains(x*y)={m.eval(result.contains(x*y))}'

In [None]:
# Well-formedness of both operands and of the result under the model.
tuple(m.eval(rng.wellformed()) for rng in (w1, w2, result))

In [None]:
# Does w1 wrap around in the unsigned view (end < start) under this model?
m.eval(w1.uwrapping, model_completion=False)

In [None]:
# Terms of the non-wrapping membership test (contains() ANDs them):
# start <= x and x <= end.
tuple(m.eval(term) for term in (ULE(w1.start, x), ULE(x, w1.end)))

In [None]:
# Terms of the wrapping membership test (contains() ORs them, alongside
# always-true bounds): x <= end or start <= x.
tuple(m.eval(term) for term in (ULE(x, w1.end), ULE(w1.start, x)))