In [None]:
from math import log2

from z3 import *

The range-tracking part of value tracking will be done with the following C structure:

```c
struct wrange {
	/* It is possible to have start > end; this is valid and simply means that the
	 * tracked range wraps around, e.g. start=0xffffffffffffffff, end=0x0000000000000001
	 * means that we are tracking {0xffffffffffffffff, 0x0, 0x1}.
	 */
	u64 start;
	u64 end;
}
```

```python
class Wrange:
    SIZE = 64 # Working with 64-bit integers
    name: str
    start: BitVecRef
    end: BitVecRef

    def __init__(self, name, start=None, end=None):
        self.name = name
        self.start = BitVec(f'Wrange-{name}-start', bv=self.SIZE) if start is None else start
        self.end = BitVec(f'Wrange-{name}-end', bv=self.SIZE) if end is None else end

    def wellformed(self):
        return If(self.end == self.start - 1, self.start == 0, True)
        
    def reset(self):
        return And(self.start == BitVecVal(0, bv=self.SIZE), self.end == BitVecVal(-1, bv=self.SIZE))

    def contains(self, val: BitVecRef):
        nonwrapping_cond = And(ULE(self.start, val), ULE(val, self.end))
        wrapping_cond = Or(ULE(val, self.end), UGE(val, self.start))
        return If(ULT(self.end, self.start), wrapping_cond, nonwrapping_cond)

def wrange_add(a: Wrange, b: Wrange):
    diff_a = a.end - a.start
    diff_b = b.end - b.start
    new_diff = diff_a + diff_b
    # The combined width overflowed, so the result has to cover every value
    too_wide = Or(ULT(new_diff, diff_a), ULT(new_diff, diff_b))
    new_start = If(too_wide, BitVecVal(0, a.SIZE), a.start + b.start)
    new_end = If(too_wide, BitVecVal(-1, a.SIZE), a.end + b.end)
    return Wrange(f'{a.name} + {b.name}', new_start, new_end)
```

In [None]:
class Wrange:
    """Symbolic model of a (possibly wrapping) range of SIZE-bit unsigned values.

    The range is described by a `base` and an unsigned `diff` and covers the
    `diff + 1` consecutive values `base, base + 1, ..., base + diff`, where the
    additions wrap around modulo 2**SIZE.  When `base + diff` wraps, the range
    runs from `base` up to the largest value and continues from 0 to the
    wrapped end.
    """
    SIZE = 64 # Working with 64-bit integers
    name: str
    base: BitVecRef
    diff: BitVecRef

    def __init__(self, name, base=None, diff=None):
        """Create a range; fields that are not given become fresh symbolic bit-vectors.

        Args:
            name: label used to name the fresh symbolic variables.
            base: optional SIZE-bit bit-vector for the first value of the range.
            diff: optional SIZE-bit bit-vector for the distance from base to the end.
        """
        self.name = name
        self.base = BitVec(f'Wrange-{name}-base', bv=self.SIZE) if base is None else base
        assert self.base.size() == self.SIZE
        self.diff = BitVec(f'Wrange-{name}-diff', bv=self.SIZE) if diff is None else diff
        assert self.diff.size() == self.SIZE

    def wellformed(self):
        """Constraint: the range that covers every value must be stored with base == 0."""
        all_ones = BitVecVal(2**self.SIZE - 1, bv=self.SIZE)
        return Implies(self.diff == all_ones, self.base == 0)

    def reset(self):
        """Constraint: this range is the canonical full range (base 0, diff all ones)."""
        return And(self.base == BitVecVal(0, bv=self.SIZE), self.diff == BitVecVal(-1, bv=self.SIZE))

    def contains(self, val: BitVecRef):
        """Constraint: `val` is one of the values covered by this range.

        All comparisons are unsigned.  `end` is computed with wrap-around, so
        `end < base` signals a wrapping range.
        """
        assert val.size() == self.SIZE
        end = self.base + self.diff
        nonwrapping_cond = And(ULE(self.base, val), ULE(val, end))
        # Wrapping range = [base, max] joined with [0, end].  The second operand
        # must be `val >= base`; with the operands swapped (`base >= val`) a
        # wrapping range would accept values below base and reject those above it.
        wrapping_cond = Or(ULE(val, end), UGE(val, self.base))
        return If(ULT(end, self.base), wrapping_cond, nonwrapping_cond)

In [None]:
# Sanity check of the wrapping membership test for the range {-1, 0, 1}
# (base = -1, diff = 2, so the end wraps around to -1 + 2 = 1):
# the value -1 has to satisfy "x <= end or x >= base" in unsigned arithmetic.
# `x` is declared here so the cell also runs on a fresh kernel, and the
# claim is stated as an implication because a conjunction that pins a free
# variable to one value can never be proved valid.
x = BitVec('x', bv=64)
prove(Implies(
    x == BitVecVal(0, bv=64) - 1,
    Or(ULE(x, BitVecVal(-1, bv=64) + BitVecVal(2, bv=64)), UGE(x, BitVecVal(-1, bv=64)))
))

In [None]:
x = BitVec('x', 64)

# base = 1, diff = 1: membership must be exactly {1, 2}
w1 = Wrange('w1', BitVecVal(1, 64), BitVecVal(1, 64))
prove(w1.contains(x) == Or([x == BitVecVal(member, 64) for member in (1, 2)]))

# base = -1, diff = 2: membership must be exactly {-1, 0, 1}
w1 = Wrange('w1', BitVecVal(-1, 64), BitVecVal(2, 64))
prove(w1.contains(x) == Or([x == BitVecVal(member, 64) for member in (-1, 0, 1)]))

## Addition

In [None]:
def wrange_add(a: Wrange, b: Wrange):
    """Return a Wrange that contains x + y for every x in `a` and every y in `b`.

    The sum of two ranges starts at `a.base + b.base` and is `a.diff + b.diff`
    wide.  If that width does not fit in SIZE bits, or is exactly the all-ones
    value, the result covers every value and is returned in the canonical form
    required by `Wrange.wellformed()` (base 0, diff all ones).
    """
    full = BitVecVal(-1, a.SIZE)
    sum_diff = a.diff + b.diff
    # Unsigned overflow of the width: the wrapped sum is smaller than an operand.
    overflowed = Or(ULT(sum_diff, a.diff), ULT(sum_diff, b.diff))
    # A width of all ones without overflow also means "every value"; without
    # this case the result would keep a non-zero base and not be well-formed.
    too_wide = Or(overflowed, sum_diff == full)
    new_base = If(too_wide, BitVecVal(0, a.SIZE), a.base + b.base)
    new_diff = If(too_wide, full, sum_diff)
    return Wrange(f'{a.name} + {b.name}', new_base, new_diff)

In [None]:
x = BitVec('x', 64)

# {1, 2, 3} + {0} = {1, 2, 3}, i.e. 1 <= x <= 3
w = wrange_add(
    Wrange('w1', BitVecVal(1, 64), BitVecVal(2, 64)),
    Wrange('w2', BitVecVal(0, 64), BitVecVal(0, 64)),
)
prove(w.contains(x) == And(BitVecVal(1, 64) <= x, x <= BitVecVal(3, 64)))

# {-1} + {0, 1, 2} = {-1, 0, 1}, i.e. -1 <= x <= 1
w = wrange_add(
    Wrange('w1', BitVecVal(-1, 64), BitVecVal(0, 64)),
    Wrange('w2', BitVecVal(0, 64), BitVecVal(2, 64)),
)
prove(w.contains(x) == And(BitVecVal(-1, 64) <= x, x <= BitVecVal(1, 64)))

In [None]:
# Two fully symbolic operands, their sum, and one arbitrary member of each operand.
x, y = BitVec('x', 64), BitVec('y', 64)
w1, w2 = Wrange('w1'), Wrange('w2')
result = wrange_add(w1, w2)
premise = And([w1.wellformed(), w2.wellformed(), w1.contains(x), w2.contains(y)])

In [None]:
# Under `premise`, x + y must be in `result`, and `result` must be well-formed.
prove(Implies(premise, And(result.contains(x + y), result.wellformed())))

## Arithmetic Negation

In [None]:
def wrange_neg(a: Wrange):
    """Return the Wrange of the arithmetic negations of the values in `a`.

    The width is unchanged.  The new base is `-a.base - a.diff`, except when
    `a.diff` is all ones, in which case the base is set to 0.
    """
    covers_everything = a.diff == -1
    negated_base = If(covers_everything, 0, - a.base - a.diff)
    return Wrange(f'(-{a.name})', negated_base, a.diff)

In [None]:
x = BitVec('x', 64)

# -{1, 2, 3} = {-3, -2, -1}
w = wrange_neg(Wrange('w1', BitVecVal(0x1, 64), BitVecVal(0x2, 64)))
prove(Implies(w.contains(x), And(-3 <= x, x <= -1)))

# -{-1} = {1}
w = wrange_neg(Wrange('w1', BitVecVal(-1, 64), BitVecVal(0x0, 64)))
prove(Implies(w.contains(x), x == 1))

In [None]:
# One fully symbolic operand, its negation, and an arbitrary member of the operand.
x = BitVec('x', 64)
w1 = Wrange('w1')
result = wrange_neg(w1)
premise = And([w1.wellformed(), w1.contains(x)])

In [None]:
# Under `premise`, -x must be in `result`, and `result` must be well-formed.
prove(Implies(premise, And(result.contains(-x), result.wellformed())))

## Subtraction

In [None]:
#def wrange_sub(a: Wrange, b: Wrange):
#    new_diff = a.diff + b.diff
#    too_wide = Or(new_diff < a.diff, new_diff < b.diff)
#    new_base = If(too_wide, BitVecVal(0, a.SIZE), a.base - b.base)
#    new_diff = If(too_wide, BitVecVal(-1, a.SIZE), a.diff + b.diff)
#    return Wrange(f'{a.name} - {b.name}', new_base, new_diff)

def wrange_sub(a: Wrange, b: Wrange):
    """Return a Wrange for `a - b`, computed as `a + (-b)`.

    This reuses wrange_neg and wrange_add rather than a dedicated
    implementation (a bit lazy, to be improved later); only the name of the
    resulting range is replaced.
    """
    negated_b = wrange_neg(b)
    total = wrange_add(a, negated_b)
    return Wrange(f'{a.name} - {b.name}', total.base, total.diff)

In [None]:
x = BitVec('x', 64)

# {1, 2, 3} - {0} = {1, 2, 3}, i.e. 1 <= x <= 3
w = wrange_sub(
    Wrange('w1', BitVecVal(0x1, 64), BitVecVal(0x2, 64)),
    Wrange('w2', BitVecVal(0x0, 64), BitVecVal(0x0, 64)),
)
prove(Implies(w.contains(x), And(ULE(1, x), ULE(x, 3))))

# {-1} - {0, 1, 2} = {-3, -2, -1}, i.e. -3 <= x <= -1
w = wrange_sub(
    Wrange('w1', BitVecVal(-1, 64), BitVecVal(0x0, 64)),
    Wrange('w2', BitVecVal(0x0, 64), BitVecVal(0x2, 64)),
)
prove(Implies(w.contains(x), And(-3 <= x, x <= -1)))

In [None]:
# Two fully symbolic operands, their difference, and one arbitrary member of each operand.
x, y = BitVec('x', 64), BitVec('y', 64)
w1, w2 = Wrange('w1'), Wrange('w2')
result = wrange_sub(w1, w2)
premise = And([w1.wellformed(), w2.wellformed(), w1.contains(x), w2.contains(y)])

In [None]:
# Under `premise`, x - y must be in `result`, and `result` must be well-formed.
prove(Implies(premise, And(result.contains(x - y), result.wellformed())))

## Multiplication

In [None]:
def wrange_mul(a: Wrange, b: Wrange):
    """Return a Wrange that contains x * y for every x in `a` and every y in `b`.

    Multiplying the bases and the diffs separately is not enough: for
    x = a.base + i and y = b.base + j the product also contains the cross
    terms a.base * j and b.base * i (e.g. {1, 2, 3} * {1, 2} contains 6).
    Instead, both operands are read as *signed* intervals [base, base + diff].
    A product of two intervals is bounded by the smallest and largest of the
    four corner products, so those become the new base and end.

    This is only done when both bounds of both operands have a small
    magnitude, which guarantees that no corner product overflows.  Otherwise
    the result is the full range (base 0, diff all ones).
    """
    size = a.SIZE
    # Being very conservative here, at the very least a magnitude limit of S32_MAX would also work
    limit = BitVecVal(0xffff, bv=size)

    a_end = a.base + a.diff
    b_end = b.base + b.diff

    def small_signed_interval(low, high):
        # `<=` on bit-vectors is the signed comparison; requiring low <= high
        # also rejects ranges whose end wrapped past the signed boundary.
        return And(-limit <= low, low <= high, high <= limit)

    too_wide = Not(And(
        small_signed_interval(a.base, a_end),
        small_signed_interval(b.base, b_end),
    ))

    # Signed minimum and maximum over the four corner products.
    corners = [a.base * b.base, a.base * b_end, a_end * b.base, a_end * b_end]
    lowest = highest = corners[0]
    for corner in corners[1:]:
        lowest = If(corner < lowest, corner, lowest)
        highest = If(corner > highest, corner, highest)

    new_base = If(too_wide, BitVecVal(0, size), lowest)
    new_diff = If(too_wide, BitVecVal(-1, size), highest - lowest)
    return Wrange(f'{a.name} * {b.name}', new_base, new_diff)

In [None]:
x = BitVec('x', 64)

# {1, 2, 3} * {0} = {0}
w = wrange_mul(
    Wrange('w1', BitVecVal(0x1, 64), BitVecVal(0x2, 64)),
    Wrange('w2', BitVecVal(0x0, 64), BitVecVal(0x0, 64)),
)
prove(Implies(w.contains(x), x == 0))

# {-1} * {0, 1, 2} = {-2, -1, 0}, i.e. -2 <= x <= 0
w = wrange_mul(
    Wrange('w1', BitVecVal(-1, 64), BitVecVal(0x0, 64)),
    Wrange('w2', BitVecVal(0x0, 64), BitVecVal(0x2, 64)),
)
prove(Implies(w.contains(x), And(-2 <= x, x <= 0)))

In [None]:
# Two fully symbolic operands, their product, and one arbitrary member of each operand.
x, y = BitVec('x', 64), BitVec('y', 64)
w1, w2 = Wrange('w1'), Wrange('w2')
result = wrange_mul(w1, w2)
premise = And([w1.wellformed(), w2.wellformed(), w1.contains(x), w2.contains(y)])

In [None]:
# Under `premise`, x * y must be in `result`, and `result` must be well-formed.
prove(Implies(premise, And(result.contains(x * y), result.wellformed())))

## Bitwise-AND

In [None]:
def wrange_and(a: Wrange, b: Wrange):
    """Return a Wrange that contains x & y for every x in `a` and every y in `b`.

    A bitwise AND can only clear bits, so as unsigned numbers x & y <= x and
    x & y <= y.  The result is therefore bounded by the smaller of the two
    operands' largest unsigned values, and the returned range is
    [0, min(umax(a), umax(b))].  The base is always 0, so the result is
    well-formed even when it covers every value.
    """
    size = a.SIZE
    all_ones = BitVecVal(-1, bv=size)

    def unsigned_max(w):
        end = w.base + w.diff
        # A wrapping range (end < base) passes through the largest value.
        return If(ULT(end, w.base), all_ones, end)

    a_max = unsigned_max(a)
    b_max = unsigned_max(b)
    upper = If(ULE(a_max, b_max), a_max, b_max)
    return Wrange(f'{a.name} & {b.name}', base=BitVecVal(0, bv=size), diff=upper)

In [None]:
x = BitVec('x', 64)

# {1 (0b01), 2 (0b10), 3 (0b11)} & {0 (0b00)} = {0 (0b00)}
w = wrange_and(
    Wrange('w1', BitVecVal(0x1, 64), BitVecVal(0x2, 64)),
    Wrange('w2', BitVecVal(0x0, 64), BitVecVal(0x0, 64)),
)
prove(w.contains(x) == (x == 0))

# {-1 (0xffffffffffffffff)} & {0 (0b00), 1 (0b01), 2 (0b10)} = {0 (0b00), 1 (0b01), 2 (0b10)}
w = wrange_and(
    Wrange('w1', BitVecVal(-1, 64), BitVecVal(0x0, 64)),
    Wrange('w2', BitVecVal(0x0, 64), BitVecVal(0x2, 64)),
)
prove(w.contains(x) == And(0 <= x, x <= 2))

In [None]:
# Two fully symbolic operands, their bitwise AND, and one arbitrary member of each operand.
x, y = BitVec('x', 64), BitVec('y', 64)
w1, w2 = Wrange('w1'), Wrange('w2')
result = wrange_and(w1, w2)
premise = And([w1.wellformed(), w2.wellformed(), w1.contains(x), w2.contains(y)])

In [None]:
# Under `premise`, x & y must be in `result`, and `result` must be well-formed.
prove(Implies(premise, And(result.contains(x & y), result.wellformed())))

## Evaluation

In [None]:
# Search for a value x on which wrange_add({-1}, {0, 1, 2}) disagrees with
# the expected membership condition -1 <= x <= 1.
s = Optimize()
w1 = Wrange('w1', base=BitVecVal(-1, bv=64), diff=BitVecVal(0, bv=64))
w2 = Wrange('w2', base=BitVecVal(0, bv=64), diff=BitVecVal(2, bv=64))
x = BitVec('x', bv=64)
w = wrange_add(w1, w2)
# Objectives are registered in this order: x first, then the operands' fields.
# NOTE(review): w1 and w2 are built from concrete BitVecVal constants here, so
# minimizing their base/diff presumably has no effect — confirm.
s.minimize(x)
s.minimize(w1.base)
s.minimize(w2.base)
s.minimize(w1.diff)
s.minimize(w2.diff)
# Assert the negation of the expected equivalence, so any model found is a
# counterexample to it.
s.add(
    Not(w.contains(x) == And(BitVecVal(-1, bv=64) <= x, x <= BitVecVal(1, bv=64))),
)
# Last expression of the cell: the outcome of the check is displayed.
s.check()

In [None]:
# Fetch the model from the optimizer `s` and display it.
# NOTE(review): presumably a model is only available when the preceding
# s.check() reported sat — confirm before relying on the cells below.
m = s.model()
m

In [None]:
# Value of x under the model m.
m.eval(x)

In [None]:
# Base, diff and end (base + diff) of operand w1 under the model m.
m.eval(w1.base), m.eval(w1.diff), m.eval(w1.base + w1.diff)

In [None]:
# Base, diff and end (base + diff) of operand w2 under the model m.
m.eval(w2.base), m.eval(w2.diff), m.eval(w2.base + w2.diff)

In [None]:
# Base, diff and end (base + diff) of the sum w under the model m.
m.eval(w.base), m.eval(w.diff), m.eval(w.base + w.diff)

In [None]:
# Whether the sum w contains x under the model m.
m.eval(w.contains(x))

In [None]:
# Well-formedness of both operands and of the sum under the model m.
m.eval(w1.wellformed()), m.eval(w2.wellformed()), m.eval(w.wellformed())

In [None]:
# Display the bit width of the bit-vector x.
x.size()