# Modular Strided Intervals

Fix $N \in \{1, \ldots, 2^{23} - 1\}$.

The LLVM type $\texttt{i}N$ represents $N$-bit tuples:

$\texttt{i}N := \{0, 1\}^N$

These tuples can be interpreted as elements of $\mathbb{Z}/{2^N}$ using the isomorphism $\phi_N$ together with an appropriate map of operations:

$\phi_N \colon \texttt{i}N \rightarrow \mathbb{Z}/{2^N}, (b_0, \ldots, b_{N-1}) \mapsto \left(\sum_{k=0}^{N-1}b_k^k\right) + 2^N \mathbb{Z}$

An abstraction of $\mathbb{Z}$ and thefore also of $\texttt{i}N$ can be obtained by a generalization of intervals over $\mathbb{Z}$, represeted by the type $\mathrm{MSI}_N$ of _modular strided intervals (MSI)_:

$\mathrm{MSI}_N := \{s[a, b]_N \mid a, b, s \in \mathbb{Z}/2^N\}$

The sematics of an MSI is given by the concetization function $\gamma_N$:

$\gamma_N \colon \mathrm{MSI}_N \rightarrow \mathcal{P}(\mathbb{Z}/{2^N}), s[a, b]_N \mapsto \{k + 2^N \mathbb{Z} \mid k \in \mathbb{Z}, a \leq k, k \leq \min \{l \in \mathbb{Z} \mid a \leq k, l \equiv b \mod 2^N\}, k \equiv a \mod s\}$

In [1]:
from itertools import count, takewhile
from random import randint
from sympy import gcd, lcm

In [2]:
class MSI(object):
    """
    Modular strided iterval
    """
    def __init__(self, bit_width, begin, end, stride=1):
        self.bit_width = bit_width
        self.begin = begin
        self.end = end
        self.stride = stride

    def __eq__(self, other):
        return (self.bit_width == other.bit_width
            and self.stride == other.stride
            and self.begin == other.begin
            and self.end == other.end)

    def __repr__(self):
        return f'{self.stride}[{self.begin}, {self.end}]_{{{self.bit_width}}}'
    
    def __hash__(self):
        return (self.begin+23) * (self.end+29) * (self.stride+31) % 16777216
    
    def _tuple_repr(self):
        return (self.bit_width, self.begin, self.end, self.stride)

## Defining Functions and Predicates

In [3]:
# This predicate has not tests, it's an axiom.
def valid(i):
    n, a, b, s = i._tuple_repr()
    if n <= 0:
        return False
    if a < 0 or 2**n <= a:
        return False
    if b < 0 or 2**n <= b:
        return False
    if s < 0 or 2**n <= s:
        return False
    return True

In [4]:
def gamma(i):
    n = i.bit_width
    s = i.stride
    a = i.begin
    b = i.end if a <= i.end else i.end + 2**n
    return {k % 2**n for k in takewhile(
        lambda k: k <= b,
        (a+l*s for l in count()) if s > 0 else [a]
    )}

$\gamma_N$ is not injective, therefore normalization of MSIs is needed s.t. $\gamma_N$ restricted to $\{i \in \mathrm{MSI}_N \mid \mathrm{normal}(i)\}$ is injective. All other operatins on MSIs assume that there operands are normal and are expected to return a normal MSI.

Expanation of $\textrm{normal}$:

Fix $s[a, b]_N \in \textrm{MSI}_N$.

Case 1:
  
  Assume $s = 0$.

In [5]:
def normal(i):
    n, a, b, s = i._tuple_repr()
    if s == 0 and a != b:
        return False
    if a == b and s != 0:
        return False
    if not b in gamma(i):
        return False
    a_ = a - s
    if a_ != a and a_ >= 0 and gamma(i) == gamma(MSI(n, a_, (b-s) % 2**n, s)):
        return False
    if b < a and gamma(i) == gamma(MSI(n, b, a, 2**n - s)):
        return False
    return True

## Test sets of MSIs with theire respective concretizations

In [6]:
test_MSIs_handpicked_gamma = [
    # normalized
    #   no wraparound
    #     strid = 0
    #       begin = 0
    (MSI(4, 0, 0, 0), {0}),
    #       begin > 0
    (MSI(4, 3, 3, 0), {3}),
    #     strid = 1
    #       begin = 0
    #         end < 2**N-1
    (MSI(4, 0, 2, 1), {0, 1, 2}),
    #         end = 2**N-1
    (MSI(3, 0, 7, 1), {0, 1, 2, 3, 4, 5, 6, 7}),
    #       begin > 0
    (MSI(4, 3, 4, 1), {3, 4}),
    #     stride > 1
    #       begin = 0
    (MSI(4, 0, 4, 2), {0, 2, 4}),
    #       begin > 0
    (MSI(3, 1, 7, 3), {1, 4, 7}),
    (MSI(6, 6, 26, 10), {6, 16, 26}),
    #   wraparound
    #     stride = 1
    (MSI(4, 14, 2, 1), {14, 15, 0, 1, 2}),
    #     stride > 1
    (MSI(4, 11, 4, 3), {1, 4, 11, 14})]
test_MSIs_handpicked_gamma_unnormalized = [
    # unnormalized
    #   no wraparound
    #     stride = 0
    #       begin = 0
    (MSI(4, 0, 3, 0), {0}),
    #       begin > 0
    (MSI(4, 3, 8, 0), {3}),
    #     stride = 1
    #       begin = 0
    #         end = begin
    (MSI(4, 0, 0, 1), {0}),
    #         end != begin
    (MSI(2, 2, 1, 1), {0, 1, 2, 3}),
    #       begin > 0
    #         end = begin
    (MSI(4, 3, 3, 1), {3}),
    #         end != begin
    (MSI(3, 5, 4, 1), {0, 1, 2, 3, 4, 5, 6, 7}),
    #     stride > 1
    #       begin = 0
    (MSI(4, 0, 5, 2), {0, 2, 4}),
    (MSI(4, 0, 3, 5), {0}),
    #       begin > 0
    #         end = begin - stride mod 2**N
    (MSI(4, 11, 7, 4), {3, 7, 11, 15}),
    #         end != begin - stride mod 2**N
    (MSI(6, 6, 35, 10), {6, 16, 26}),
    (MSI(4, 3, 7, 5), {3}),
    #   wraparound
    #     stride = 0
    (MSI(4, 5, 3, 0), {5}),
    #     stride = 1
    (MSI(3, 5, 4, 1), {0, 1, 2, 3, 4, 5, 6, 7}),
    (MSI(4, 15, 0, 1), {15, 0}),
    #     stride > 1
    #       end = begin - stride mod 2**N
    (MSI(4, 10, 6, 4), {2, 6, 10, 14}),
    (MSI(4, 12, 2, 6), {2, 12}),
    #       end != begin and != begin - stride mod 2**N
    (MSI(4, 13, 2, 8), {13}),
    (MSI(4, 11, 6, 3), {11, 14, 1, 4}),
    (MSI(4, 10, 9, 4), {2, 6, 10, 14}),
    (MSI(4, 12, 7, 6), {2, 12})
]

In [7]:
test_MSIs_handpicked = {}
for i, _ in test_MSIs_handpicked_gamma:
    n = i.bit_width
    if n not in test_MSIs_handpicked:
        test_MSIs_handpicked[n] = [i]
    else:
        test_MSIs_handpicked[n].append(i)
print('size: ' + ', '.join(f'{n}: {len(js)}' for n, js in test_MSIs_handpicked.items()))

test_MSIs_handpicked_unnormalized = {}
for i, _ in test_MSIs_handpicked_gamma_unnormalized:
    n = i.bit_width
    if n not in test_MSIs_handpicked_unnormalized:
        test_MSIs_handpicked_unnormalized[n] = [i]
    else:
        test_MSIs_handpicked_unnormalized[n].append(i)
print('size: ' + ', '.join(f'{n}: {len(js)}' for n, js in test_MSIs_handpicked_unnormalized.items()))

size: 4: 7, 3: 2, 6: 1
size: 4: 16, 2: 1, 3: 2, 6: 1


## Tests for gamma

In [8]:
def test_gamma():
    failed = False
    for i, ks in test_MSIs_handpicked_gamma:
        if not gamma(i) == ks:
            failed = True
            print(f'{i}: {gamma(i)}, {ks}')
    if not failed:
        print('succeeded')

def test_gamma_unnormalized():
    failed = False
    for i, ks in test_MSIs_handpicked_gamma_unnormalized:
        if not gamma(i) == ks:
            failed = True
            print(f'{i}: {gamma(i)}, {ks}')
    if not failed:
        print('succeeded')

In [9]:
test_gamma()
test_gamma_unnormalized()

succeeded
succeeded


## Normalization function

In [10]:
def normalize(i):
    n, a, b, s = i._tuple_repr()
    if s == 0:
        b = a
    else:
        b_ = b if a <= b else b+2**n
        b = (b_ - (b_-a) % s) % 2**n
        if a == b:
            s = 0
        else:
            if 2**n % s == 0 and (a-b) % 2**n == s:
                a = a % s
                b = (a-s) % 2**n
            elif b == (a+s) % 2**n and b < a:
                a, b = b, a
                s = b-a
    return MSI(n, a, b, s)

## Test sets and utility functions for testing

Warning:

`normal` is used in `unary_function_test` if the `unnormalized` parameter is `True`, but tested later. Therefore this parameter should not be set before `normal` is tested.

In [11]:
def test_set(bit_widths, begins, ends, strides, only_normal=True, print_stats=False):
    MSIs = {}
    for n in bit_widths:
        js = set()
        bs = begins(n)
        for b in bs:
            es = ends(n)
            for e in es:
                ss = strides(n)
                for s in ss:
                    if only_normal:
                        js.add(normalize(MSI(n, b, e, s)))
                    else:
                        js.add(MSI(n, b, e, s))
        MSIs[n] = list(js)
    if print_stats:
        print('size: ' + ', '.join(f'{n}: {len(js)}' for n, js in MSIs.items()))
        if not only_normal:
            print('unnormalized: ' + ', '.join(f'{n}: {len(list(0 for j in js if not normal(j)))}' for n, js in MSIs.items()))
    return MSIs

In [12]:
f = lambda n: list(range(2**n))
g = lambda n: list(range(2**n))
print('test_MSIs_4_exhaustive')
test_MSIs_4_exhaustive = test_set(range(1, 4+1), f, g, f, print_stats=True)
print('test_MSIs_4_exhaustive')
test_MSIs_4_exhaustive_unnormalized = test_set(range(1, 4+1), f, g, f, only_normal=False, print_stats=True)

test_MSIs_4_exhaustive
size: 1: 3, 2: 15, 3: 95, 4: 575
test_MSIs_4_exhaustive
size: 1: 8, 2: 64, 3: 512, 4: 4096
unnormalized: 1: 5, 2: 49, 3: 417, 4: 3521


In [13]:
f = lambda n: list(range(2**n))
g = lambda n: list(range(2**n))
print('test_MSIs_5_6_exhaustive')
test_MSIs_5_6_exhaustive = test_set(range(5, 6+1), f, g, f, print_stats=True)
print('test_MSIs_5_6_exhaustive')
test_MSIs_5_6_exhaustive_unnormalized = test_set(range(5, 6+1), f, g, f, only_normal=False, print_stats=True)

test_MSIs_5_6_exhaustive
size: 5: 3039, 6: 15231
test_MSIs_5_6_exhaustive
size: 5: 32768, 6: 262144
unnormalized: 5: 29729, 6: 246913


In [14]:
test_MSIs_6_exhaustive = {
    **test_MSIs_4_exhaustive, **test_MSIs_5_6_exhaustive
}
test_MSIs_6_exhaustive_unnormalized = {
    **test_MSIs_4_exhaustive_unnormalized, **test_MSIs_5_6_exhaustive_unnormalized
}

In [15]:
ks = [a+b for a in [0, 30] for b in [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 15]]
ls = [30, 31, 32, 33, 35, 36, 40, 45]
f = lambda _: ks
print('test_MSIs_6_partial')
test_MSIs_6_partial = test_set([6], f, g, f, print_stats=True)
print('\ntest_MSIs_6_partial_unnormalized')
test_MSIs_6_partial_unnormalized = test_set([6], f, f, f, only_normal=False, print_stats=True)

test_MSIs_6_partial
size: 6: 4111

test_MSIs_6_partial_unnormalized
size: 6: 10648
unnormalized: 6: 9309


In [16]:
ks = [a+b for a in [0, 30, 60] for b in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 25]]
ls = [0, 2, 3, 5, 6, 10, 15]
f = lambda n: takewhile(lambda k: k < 2**n, ks)
g = lambda n: (((15 if 2**n < 30 else 45) + 15 + a) % 2**n for a in ls)
print('test_MSIs_8_partial')
test_MSIs_8_partial = test_set([8], f, g, f, print_stats=True)
print('\ntest_MSIs_8_partial_unnormalized')
test_MSIs_8_partial_unnormalized = test_set([8], f, g, f, only_normal=False, print_stats=True)

test_MSIs_8_partial
size: 8: 2438

test_MSIs_8_partial_unnormalized
size: 8: 10647
unnormalized: 8: 9705


In [17]:
f = lambda n: set(randint(0, 2**n-1) for _ in range(8))
g = lambda n: set(randint(0, 2**(n-1)-1) for _ in range(8))
print('test_MSIs_random')
test_MSIs_random = test_set(range(5, 8+1), f, f, g, print_stats=True)
print('\ntest_MSIs_random_unnormalized')
test_MSIs_random_unnormalized = test_set(range(5, 8+1), f, f, g, only_normal=False, print_stats=True)

test_MSIs_random
size: 5: 181, 6: 288, 7: 339, 8: 316

test_MSIs_random_unnormalized
size: 5: 253, 6: 412, 7: 462, 8: 487
unnormalized: 5: 218, 6: 366, 7: 432, 8: 468


In [18]:
def _unary_function_test(f, p, test_MSIs, test_count=0, fail_count=0, fail_lim=8):
    for n, js in test_MSIs.items():
        print(f'    testing bit width: {n}')
        for i in js:
            test_count += 1
            x = f(i)
            if not p(n, i, x):
                fail_count += 1
                print(f'        {i}: {x}')
                if fail_count == fail_lim:
                    return test_count, fail_count
            if test_count % 25000 == 0:
                print(f'- tested {test_count} arguments')
    return test_count, fail_count

def unary_function_test(f, p, big=False, unnormalized=False):
    fail_lim = 16 if big else 8
    test_count = fail_count = 0
    print('testing MSIs with bit width up to 4 exhaustively')
    MSIs = test_MSIs_4_exhaustive_unnormalized if unnormalized else test_MSIs_4_exhaustive
    test_count, fail_count = _unary_function_test(f, p, MSIs, test_count, fail_count, fail_lim)
    if fail_count == fail_lim:
        return
    print('testing random MSIs with bit width from 5 to 8')
    MSIs = test_MSIs_random_unnormalized if unnormalized else test_MSIs_random
    test_count, fail_count = _unary_function_test(f, p, MSIs, test_count, fail_count, fail_lim)
    if fail_count == fail_lim:
        return
    if big:
        print('testing some MSIs with bit width 6')
        MSIs = test_MSIs_6_partial_unnormalized if unnormalized else test_MSIs_6_partial
        test_count, fail_count = _unary_function_test(f, p, MSIs, test_count, fail_count, fail_lim)
        if fail_count == fail_lim:
            return
        print('testing some MSIs with bit width 8')
        MSIs = test_MSIs_8_partial_unnormalized if unnormalized else test_MSIs_8_partial
        test_count, fail_count = _unary_function_test(f, p, MSIs, test_count, fail_count, fail_lim)
        if fail_count == fail_lim:
            return
    if fail_count == 0:
        print(f'succeeded (tested {test_count} arguments in total)')

In [19]:
def _bin_fun_test(f, p, test_MSIs, test_count=0, fail_count=0, fail_lim=8):
    for n, js in test_MSIs.items():
        print(f'    testing bit width: {n}')
        for i in js:
            for j in js:
                test_count += 1
                x = f(i, j)
                if not p(n, i, j, x):
                    fail_count += 1
                    print(f'        f {i} {j}: {x}')
                    if fail_count == fail_lim:
                        return test_count, fail_count
                if test_count % 25000 == 0:
                    print(f'- tested {test_count} arguments')
    return test_count, fail_count

def bin_fun_test(f, p, big=False, non_zero=False):
    fail_lim = 16 if big else 8
    test_count = fail_count = 0
    print('testing MSIs with bit width up to 4 exhaustively')
    test_count, fail_count = _bin_fun_test(f, p, test_MSIs_4_exhaustive, test_count, fail_count, fail_lim)
    if fail_count == fail_lim:
        return
    print('testing random MSIs with bit width from 5 to 8')
    test_count, fail_count = _bin_fun_test(f, p, test_MSIs_random, test_count, fail_count, fail_lim)
    if fail_count == fail_lim:
        return
    if big:
        print('testing some MSIs with bit width 6')
        test_count, fail_count = _bin_fun_test(f, p, test_MSIs_6_partial, test_count, fail_count, fail_lim)
        if fail_count == fail_lim:
            return
        print('testing some MSIs with bit width 8')
        test_count, fail_count = _bin_fun_test(f, p, test_MSIs_8_partial, test_count, fail_count, fail_lim)
        if fail_count == fail_lim:
            return
    if fail_count == 0:
        print(f'succeeded (tested {test_count} arguments in total)')

In [20]:
def _bin_op_test(op_MSI, op, test_MSIs, test_count=0, fail_count=0, fail_lim=8, bad_args={}, bad_lim=8, non_zero=False):
    bad_precision = max(bad_args.values()) if len(bad_args) > 0 else 1
    for n, js in test_MSIs.items():
        print(f'    testing bit width: {n}')
        for i in js:
            vals_i = gamma(i)
            for j in js:
                test_count += 1
                vals_j = gamma(j)
                if non_zero and 0 in vals_j:
                    vals_op = {op(n, k, l) for k in vals_i for l in vals_j if not l == 0}
                else:
                    vals_op = {op(n, k, l) for k in vals_i for l in vals_j}
                vals_op_MSI = gamma(op_MSI(i, j))
                if not vals_op <= vals_op_MSI:
                    fail_count += 1
                    print(f'        {i} op {j}: {op_MSI(i, j)}, {vals_op}, {vals_i}, {vals_j}')
                    if fail_count == fail_lim:
                        return test_count, fail_count, bad_args
                elif not len(vals_op) == 0:
                    precision = len(vals_op) / (len(vals_op_MSI) * 2**n)
                    if precision < bad_precision:
                        if len(bad_args) == bad_lim:
                            bad_args.pop(list(bad_args.keys())[list(bad_args.values()).index(bad_precision)])
                        bad_args[(i, j)] = precision
                        bad_precision = max(bad_args.values())
                if test_count % 25000 == 0:
                    print(f'- tested {test_count} arguments')
    return test_count, fail_count, bad_args

def bin_op_test(op_MSI, op, big=False, non_zero=False):
    fail_lim = bad_lim = 16 if big else 8
    test_count = fail_count = 0
    bad_args = {}
    print('testing MSIs with bit width up to 4 exhaustively')
    test_count, fail_count, bad_args = _bin_op_test(op_MSI, op, test_MSIs_4_exhaustive, test_count, fail_count, fail_lim, bad_args, bad_lim, non_zero=non_zero)
    if fail_count == fail_lim:
        return
    print('testing random MSIs with bit width from 5 to 8')
    test_count, fail_count, bad_args = _bin_op_test(op_MSI, op, test_MSIs_random, test_count, fail_count, fail_lim, bad_args, bad_lim, non_zero=non_zero)
    if fail_count == fail_lim:
        return
    if big:
        print('testing some MSIs with bit width 6')
        test_count, fail_count, bad_args = _bin_op_test(op_MSI, op, test_MSIs_6_partial, test_count, fail_count, fail_lim, bad_args, bad_lim, non_zero=non_zero)
        if fail_count == fail_lim:
            return
        print('testing some MSIs with bit width 8')
        test_count, fail_count, bad_args = _bin_op_test(op_MSI, op, test_MSIs_8_partial, test_count, fail_count, fail_lim, bad_args, bad_lim, non_zero=non_zero)
        if fail_count == fail_lim:
            return
    if fail_count == 0:
        print(f'succeeded (tested {test_count} arguments in total)')
        print('arguments with least precise results:')
        for (i, j), r in bad_args.items():
            print(f'{i}, {j}: {r}')

In [21]:
def _bin_rel_test(rel_MSI, rel, test_MSIs, test_count=0, fail_count=0, fail_lim=8):
    for n, js in test_MSIs.items():
        print(f'    testing bit width: {n}')
        for i in js:
            for j in js:
                test_count += 1
                if not (rel_MSI(i, j) == rel(i, j)):
                    fail_count += 1
                    print(f'        {i} rel {j}: {rel_MSI(i, j)}')
                    if fail_count == fail_lim:
                        return test_count, fail_count
                if test_count % 25000 == 0:
                    print(f'- tested {test_count} arguments')
    return test_count, fail_count

def bin_rel_test(rel_MSI, rel, big=False):
    fail_lim = 16 if big else 8
    test_count = fail_count = 0
    print('testing MSIs with bit width up to 4 exhaustively')
    test_count, fail_count = _bin_rel_test(rel_MSI, rel, test_MSIs_4_exhaustive, test_count, fail_count, fail_lim)
    if fail_count == fail_lim:
        return
    print('testing random MSIs with bit width from 5 to 8')
    test_count, fail_count = _bin_rel_test(rel_MSI, rel, test_MSIs_random, test_count, fail_count, fail_lim)
    if fail_count == fail_lim:
        return
    if big:
        print('testing some MSIs with bit width 6')
        test_count, fail_count = _bin_rel_test(rel_MSI, rel, test_MSIs_6_partial, test_count, fail_count, fail_lim)
        if fail_count == fail_lim:
            return
        print('testing some MSIs with bit width 8')
        test_count, fail_count = _bin_rel_test(rel_MSI, rel, test_MSIs_8_partial, test_count, fail_count, fail_lim)
        if fail_count == fail_lim:
            return
    if fail_count == 0:
        print(f'succeeded (tested {test_count} arguments in total)')

## Test for normal

In [22]:
def test_normal():
    failed = False
    test_count = fail_count = 0
    for n, js in test_MSIs_6_exhaustive.items():
        equiv_classes = {}
        for i in js:
            a = frozenset(gamma(i))
            if a in equiv_classes:
                equiv_classes[a].add(i)
            else:
                equiv_classes[a] = {i}
        for equiv_class in equiv_classes.values():
            norm_forms = list(filter(normal, equiv_class))
            test_count += 1
            if len(norm_forms) != 1:
                failed = True
                fail_count += 1
                if len(norm_forms) == 0:
                    print(f'no normal form for {equiv_class}')
                else:
                    print(f'multiple normal forms {norm_forms}')
                if fail_count > 8:
                    return
    print(f'succeeded (tested {test_count} equivalence classes in total)')

In [23]:
test_normal()

succeeded (tested 18958 equivalence classes in total)


## Helper functions

In [22]:
def bounds(i):
    n, a, b, _ = i._tuple_repr()
    if a <= b:
        return a, b, False
    else:
        return a, b + 2**n, True

In [23]:
def contains(i, k):
    n, a, b, s = i._tuple_repr()
    if s == 0:
        return a == k
    elif a <= b:
        return a <= k and k <= b and (k - a) % s == 0
    else:
        if k >= a:
            return (k - a) % s == 0
        elif k <= b:
            return (k - b) % s == 0
        else:
            return False

In [24]:
def test_contains():
    failed = False
    test_count = fail_count = 0
    for n, js in test_MSIs_6_exhaustive.items():
        for i in js:
            test_count += 1
            a = gamma(i)
            for k in range(2**n):
                if k in a and not contains(i, k):
                    failed = True
                    fail_count += 1
                    print(f'{k} in gamma({i})')
                if k not in a and contains(i, k):
                    failed = True
                    fail_count += 1
                    print(f'{k} not in gamma({i})')
                if fail_count > 8:
                    return
    print(f'succeeded (tested {test_count} arguments)')

In [27]:
test_contains()

succeeded (tested 18958 arguments)


In [25]:
def leq_MSI(i, j, debug=False):
    n, a, b, s = i._tuple_repr()
    m, c, d, t = j._tuple_repr()
    assert n == m, 'strides must be equal'
    if s == 0: # i contains exactly 1 value
        return contains(j, a)
    elif t == 0: # j contains exactly 1 value
        return False
    elif b == (a+s) % 2**n: # i contains exactly 2 values
        return contains(j, a) and contains(j, b)
    elif s % t == 0:
        if 2**n % t == 0 and (c-d) % 2**n == t: # j represents a residue class of Z/t (=> t | 2**n)
            return (a-c) % t == 0
        else:
            b_ = (b-a) % 2**n
            c_, d_ = (c-a) % 2**n, (d-a) % 2**n
            if d_ < c_ and c_ <= b_: # this branch may not return, but continue below [a...d_...c_...b_...]
                e_ = s * (d_ // s)
                f_ = (b_ - s * ((b_-c_) // s)) % s**n
                if (f_-e_) == s:
                    if e_ < s:
                        if contains(j, a) and c_ % t == 0:
                            return True
                    elif contains(j, b) and d_ % t == 0:
                        return True
            if c_ <= d_:
                return c_ == 0 and b_ <= d_
            else:
                return b_ <= d_ and (d_-b_) % t == 0
    else:
        return False

In [59]:
bin_rel_test(leq_MSI, lambda i, j: gamma(i) <= gamma(j))

testing MSIs with bit width up to 4 exhaustively
    testing bit width: 1
    testing bit width: 2
    testing bit width: 3
    testing bit width: 4
- tested 25000 arguments
- tested 50000 arguments
- tested 75000 arguments
- tested 100000 arguments
- tested 125000 arguments
- tested 150000 arguments
- tested 175000 arguments
- tested 200000 arguments
- tested 225000 arguments
- tested 250000 arguments
- tested 275000 arguments
- tested 300000 arguments
- tested 325000 arguments
testing random MSIs with bit width from 5 to 8
    testing bit width: 5
- tested 350000 arguments
    testing bit width: 6
- tested 375000 arguments
- tested 400000 arguments
- tested 425000 arguments
- tested 450000 arguments
    testing bit width: 7
- tested 475000 arguments
- tested 500000 arguments
- tested 525000 arguments
- tested 550000 arguments
    testing bit width: 8
- tested 575000 arguments
- tested 600000 arguments
- tested 625000 arguments
succeeded (tested 645567 arguments in total)


In [29]:
lhs = MSI(3, 2, 0, 3)
rhs = MSI(3, 5, 3, 3)
res = leq_MSI(lhs, rhs)
print(f'{lhs} leq {rhs} = {res}')
print(f'{gamma(lhs)} leq {gamma(rhs)} = {gamma(lhs) <= gamma(rhs)}')
leq_MSI(lhs, rhs, debug=True)

3[2, 0]_{3} leq 3[5, 3]_{3} = False
{0, 2, 5} leq {0, 3, 5} = False


False

In [26]:
def size(i):
    n, a, b, s = i._tuple_repr()
    if s == 0:
        return 1
    else:
        return ((b-a) % 2**n) // s + 1

In [31]:
unary_function_test(size, lambda n, i, s: s == len(gamma(i)), big=True)

testing MSIs with bit width up to 4 exhaustively
    testing bit width: 1
    testing bit width: 2
    testing bit width: 3
    testing bit width: 4
testing random MSIs with bit width from 5 to 8
    testing bit width: 5
    testing bit width: 6
    testing bit width: 7
    testing bit width: 8
testing some MSIs with bit width 6
    testing bit width: 6
testing some MSIs with bit width 8
    testing bit width: 8
succeeded (tested 8386 arguments in total)


In [27]:
def lub(i, j, debug=False):
    n, a, b, s = i._tuple_repr()
    m, c, d, t = j._tuple_repr()
    assert n == m, 'strides must be equal'
    if b == (a+s) % 2**n and s >= 2**(n-1):
        if debug:
            print('correction 1')
        a, b = b, a
        s = 2**n - s
    if d == (c+t) % 2**n and t >= 2**(n-1):
        if debug:
            print('correction 2')
        c, d = d, c
        t = 2**n - t
    b_ = (b-a) % 2**n
    c_, d_ = (c-a) % 2**n, (d-a) % 2**n
    if debug:
        print(f'a: {a}, b: {b}, c: {c}, d: {d}')
        print(f'b_: {b_}, c_: {c_}, d_: {d_}')
    if (b_ < c_ and c_ < d_): # no overlapping regions
        if debug:
            print(f'case 0: b_: {b_}, c_: {c_}, d_: {d_}')
        u1 = int(gcd(gcd(s, t), (c-b) % 2**n))
        e1, f1 = a, d
        u2 = int(gcd(gcd(s, t), (a-d) % 2**n))
        e2, f2 = c, b
        opt1 = normalize(MSI(n, e1, f1, u1))
        opt2 = normalize(MSI(n, e2, f2, u2))
        if debug:
            print(f'opt2: {opt2}, opt1: {opt1}')
        if (size(opt1) < size(opt2)):
            return opt1
        else:
            return opt2
    elif d_ < c_ and c_ <= b_: # two overlapping regions
        if debug:
            print(f'case 1: b_: {b_}, c_: {c_}, d_: {d_}')
        u = int(gcd(gcd(s, t), gcd(c_ if c_ <= d_ else d_, 2**(n-1))))
        e = a % u
        f = (e - u) % 2**n
        return normalize(MSI(n, e, f, u))
    else: # one overlapping region
        if debug:
            print(f'case 2: b_: {b_}, c_: {c_}, d_: {d_}')
        e = a if c_ <= d_ else c
        f = b if d_ < b_ else d
        u = int(gcd(gcd(s, t), (c_ if c_ <= d_ else d_)))
        return normalize(MSI(n, e, f, u))

In [28]:
lhs, rhs = MSI(2, 0, 1, 1), MSI(2, 3, 3, 0)
print(f'{lhs}, {rhs}: {gamma(lhs)}, {gamma(rhs)}')
res = lub(lhs, rhs, debug=True)
print(f'{res}: {gamma(res)}')
print()
res = lub(rhs, lhs, debug=True)
print(f'{res}: {gamma(res)}')

1[0, 1]_{2}, 0[3, 3]_{2}: {0, 1}, {3}
a: 0, b: 1, c: 3, d: 3
b_: 1, c_: 3, d_: 3
case 2: b_: 1, c_: 3, d_: 3
1[0, 3]_{2}: {0, 1, 2, 3}

a: 3, b: 3, c: 0, d: 1
b_: 0, c_: 1, d_: 2
case 0: b_: 0, c_: 1, d_: 2
opt2: 1[0, 3]_{2}, opt1: 1[3, 1]_{2}
1[3, 1]_{2}: {0, 1, 3}


In [73]:
bin_fun_test(lub, lambda n, i, j, x: gamma(i) | gamma(j) <= gamma(x) and lub(i, j) == lub(j, i))

testing MSIs with bit width up to 4 exhaustively
    testing bit width: 1
    testing bit width: 2
        f 1[0, 1]_{2} 0[3, 3]_{2}: 1[0, 3]_{2}
        f 3[0, 3]_{2} 0[2, 2]_{2}: 1[0, 3]_{2}
        f 0[0, 0]_{2} 1[1, 2]_{2}: 1[0, 2]_{2}
        f 1[1, 2]_{2} 0[0, 0]_{2}: 1[0, 3]_{2}
        f 1[2, 3]_{2} 0[1, 1]_{2}: 1[0, 3]_{2}
        f 0[3, 3]_{2} 1[0, 1]_{2}: 1[3, 1]_{2}
        f 0[1, 1]_{2} 1[2, 3]_{2}: 1[1, 3]_{2}
        f 0[2, 2]_{2} 3[0, 3]_{2}: 1[2, 0]_{2}


In [29]:
lhs, rhs = MSI(3, 2, 0, 3), MSI(3, 4, 2, 3)
print(f'{lhs}, {rhs}: {gamma(lhs)}, {gamma(rhs)}')
res = lub(lhs, rhs, debug=True)
print(f'{res}: {gamma(res)}')

3[2, 0]_{3}, 3[4, 2]_{3}: {0, 2, 5}, {2, 4, 7}
a: 2, b: 0, c: 4, d: 2
b_: 6, c_: 2, d_: 0
case 1: b_: 6, c_: 2, d_: 0
1[0, 7]_{3}: {0, 1, 2, 3, 4, 5, 6, 7}


In [35]:
lub(MSI(6, 1, 37, 6), MSI(6, 31, 7, 6))

2[1, 63]_{6}

In [30]:
def as_signed_int(n, k):
    return k if k < 2**(n-1) else k - 2**n

In [31]:
def umax_MSI(i):
    n, a, b, s = i._tuple_repr()
    if a <= b:
        return b
    else:
        return 2**n - 1 - ((2**n - 1 - a) % s)

In [32]:
def umin_MSI(i):
    n, a, b, s = i._tuple_repr()
    if a <= b:
        return a
    else:
        return b % s

In [33]:
def smax_MSI(i):
    n, a, b, s = i._tuple_repr()
    a, b = as_signed_int(n, a), as_signed_int(n, b)
    if a <= b:
        return b % 2**n
    else:
        return 2**(n-1) - 1 - ((2**(n-1) - 1 - a) % s)

In [34]:
def smin_MSI(i):
    n, a, b, s = i._tuple_repr()
    a, b = as_signed_int(n, a), as_signed_int(n, b)
    if a <= b:
        return a % 2**n
    else:
        b = b % 2**n
        return (((b + 2**(n-1)) % 2**n % s) - 2**(n-1)) % 2**n

In [41]:
smin_MSI(MSI(2, 0, 3, 3))

3

In [35]:
def ustride(i):
    n, a, b, s = i._tuple_repr()
    if a <= b:
        return s
    else:
        return int(gcd(s, a-b))

In [36]:
def sstride(i):
    n, a, b, s = i._tuple_repr()
    if as_signed_int(n, a) <= as_signed_int(n, b):
        return s
    else:
        return int(gcd(s, as_signed_int(n, a)-as_signed_int(n, b)))

In [37]:
def pos_min(i):
    m = umin_MSI(i)
    if m < 2**(n-1):
        return m
    else:
        return None

In [38]:
def neg_max(i):
    m = umax_MSI(i)
    if m < 2**(n-1):
        return None
    else:
        return m

In [39]:
def sabsmin(i):
    n, a, b, s = i._tuple_repr()
    a_, b_ = as_singed_int(a), as_singed_int(b)
    if s == 0:
        return a
    elif b_ < 0:
        return b
    elif 0 < a_:
        return a
    else:
        x = a % s
        y = x - s
        return x if x <= y else y

In [40]:
def absmax(i):
    a, b = as_singed_int(i.begin), as_singed_int(i.end)
    b if abs(a) <= abs(b) else a

In [48]:
unary_function_test(umax_MSI, lambda n, i, k: max(gamma(i)) == k, big=True)

testing MSIs with bit width up to 4 exhaustively
    testing bit width: 1
    testing bit width: 2
    testing bit width: 3
    testing bit width: 4
testing random MSIs with bit width from 5 to 8
    testing bit width: 5
    testing bit width: 6
    testing bit width: 7
    testing bit width: 8
testing some MSIs with bit width 6
    testing bit width: 6
testing some MSIs with bit width 8
    testing bit width: 8
succeeded (tested 8386 arguments in total)


In [49]:
unary_function_test(umin_MSI, lambda n, i, k: min(gamma(i)) == k, big=True)

testing MSIs with bit width up to 4 exhaustively
    testing bit width: 1
    testing bit width: 2
    testing bit width: 3
    testing bit width: 4
testing random MSIs with bit width from 5 to 8
    testing bit width: 5
    testing bit width: 6
    testing bit width: 7
    testing bit width: 8
testing some MSIs with bit width 6
    testing bit width: 6
testing some MSIs with bit width 8
    testing bit width: 8
succeeded (tested 8386 arguments in total)


In [50]:
unary_function_test(smax_MSI, lambda n, i, k: max(map(lambda k: as_signed_int(n, k), gamma(i))) % 2**n == k, big=True)

testing MSIs with bit width up to 4 exhaustively
    testing bit width: 1
    testing bit width: 2
    testing bit width: 3
    testing bit width: 4
testing random MSIs with bit width from 5 to 8
    testing bit width: 5
    testing bit width: 6
    testing bit width: 7
    testing bit width: 8
testing some MSIs with bit width 6
    testing bit width: 6
testing some MSIs with bit width 8
    testing bit width: 8
succeeded (tested 8386 arguments in total)


In [51]:
unary_function_test(smin_MSI, lambda n, i, k: min(map(lambda k: as_signed_int(n, k), gamma(i))) % 2**n == k, big=True)

testing MSIs with bit width up to 4 exhaustively
    testing bit width: 1
    testing bit width: 2
    testing bit width: 3
    testing bit width: 4
testing random MSIs with bit width from 5 to 8
    testing bit width: 5
    testing bit width: 6
    testing bit width: 7
    testing bit width: 8
testing some MSIs with bit width 6
    testing bit width: 6
testing some MSIs with bit width 8
    testing bit width: 8
succeeded (tested 8386 arguments in total)


In [41]:
def as_unsigned(i):
    n, a, b, s = i._tuple_repr()
    if a <= b:
        return MSI(n, a, b, s)
    else:
        t = int(gcd(s, (a-b) & 2**n))
        c = a % t
        d = (c-t) % 2**n
        return MSI(n, c, d, t)

## Implementation of Operations

In [42]:
def add(i, j):
    n, a, b, s = i._tuple_repr()
    m, c, d, t = j._tuple_repr()
    assert n == m, 'strides must be equal'
    u = int(gcd(s, t))
    b_ = b if a <= b else b + 2**n
    d_ = d if c <= d else d + 2**n
    e, f = a+c, b_+d_
    if f-e < 2**n:
        u_ = u
        e_, f_ = e % 2**n, f % 2**n
    else:
        u_ = int(gcd(u, 2**n))
        e_ = e % 2**n
        f_ = (e_-u_) % 2**n
    return normalize(MSI(n, e_, f_, u_))

In [49]:
bin_op_test(add, lambda n, a, b: (a+b) % 2**n)

testing MSIs with bit width up to 4 exhaustively
    testing bit width: 1
    testing bit width: 2
    testing bit width: 3
    testing bit width: 4
- tested 25000 arguments
- tested 50000 arguments
- tested 75000 arguments
- tested 100000 arguments
- tested 125000 arguments
- tested 150000 arguments
- tested 175000 arguments
- tested 200000 arguments
- tested 225000 arguments
- tested 250000 arguments
- tested 275000 arguments
- tested 300000 arguments
- tested 325000 arguments
testing random MSIs with bit width from 5 to 8
    testing bit width: 5
- tested 350000 arguments
    testing bit width: 6
- tested 375000 arguments
- tested 400000 arguments
- tested 425000 arguments
- tested 450000 arguments
    testing bit width: 7
- tested 475000 arguments
- tested 500000 arguments
- tested 525000 arguments
    testing bit width: 8
- tested 550000 arguments
- tested 575000 arguments
- tested 600000 arguments
- tested 625000 arguments
- tested 650000 arguments
- tested 675000 arguments
succe

In [43]:
def sub(i, j):
    n, a, b, s = i._tuple_repr()
    m, c, d, t = j._tuple_repr()
    assert n == m, 'strides must be equal'
    u = int(gcd(s, t))
    b_ = b if a <= b else b + 2**n
    d_ = d if c <= d else d + 2**n
    e, f = a-d_, b_-c
    if f-e < 2**n:
        u_ = u
        e_, f_ = e % 2**n, f % 2**n
    else:
        u_ = int(gcd(u, 2**n))
        e_ = e % 2**n
        f_ = (e_-u_) % 2**n
    return normalize(MSI(n, e_, f_, u_))

In [56]:
bin_op_test(sub, lambda n, a, b: (a-b) % 2**n)

testing MSIs with bit width up to 4 exhaustively
    testing bit width: 1
    testing bit width: 2
    testing bit width: 3
    testing bit width: 4
- tested 25000 arguments
- tested 50000 arguments
- tested 75000 arguments
- tested 100000 arguments
- tested 125000 arguments
- tested 150000 arguments
- tested 175000 arguments
- tested 200000 arguments
- tested 225000 arguments
- tested 250000 arguments
- tested 275000 arguments
- tested 300000 arguments
- tested 325000 arguments
testing random MSIs with bit width from 5 to 8
    testing bit width: 5
- tested 350000 arguments
    testing bit width: 6
- tested 375000 arguments
- tested 400000 arguments
- tested 425000 arguments
- tested 450000 arguments
    testing bit width: 7
- tested 475000 arguments
- tested 500000 arguments
- tested 525000 arguments
    testing bit width: 8
- tested 550000 arguments
- tested 575000 arguments
- tested 600000 arguments
- tested 625000 arguments
- tested 650000 arguments
- tested 675000 arguments
succe

In [44]:
def mul(i, j, debug=False):
    n, a, b, s = i._tuple_repr()
    m, c, d, t = j._tuple_repr()
    assert n == m, 'strides must be equal'
    m = 2**n
    u = int(gcd(a, s)) * int(gcd(c, t))
    b_ = b if a <= b else b + m
    d_ = d if c <= d else d + m
    e, f = a*c, b_*d_
    if f-e < m:
        u_ = u
        e_, f_ = e % m, f % m
    else:
        u_ = int(gcd(u, m))
        e_ = e % m
        f_  = (e_-u_) % m
    if debug:
        print(f'u: {u}, e: {e}, f: {f}, u_: {u_}, e_: {e_}, f_: {f_}')
    return normalize(MSI(n, e_, f_, u_))

In [45]:
def urem(i, j, debug=False):
    n, _, _, s = i._tuple_repr()
    m, _, _, t = j._tuple_repr()
    assert n == m, 'strides must be equal'
    a, b = umin_MSI(i), umax_MSI(i)
    c, d = umin_MSI(j), umax_MSI(j)
    s, t = ustride(i), ustride(j)
    if c == 0:
        if t == 0:
            if debug:
                print('case 1')
            return MSI(n, 0, (-1) % 2**n, 1)
        else:
            c = t
    if b < c:
        if debug:
            print('case 2')
        return i
    elif t == 0:
        if a//c == b//c:
            if debug:
                print('case 3.1')
            return normalize(MSI(n, a % c, b % c, s))
        else:
            if debug:
                print('case 3.2')
            u = int(gcd(s, c))
            return normalize(MSI(n, a % u, c-1, u))
    else:
        if debug:
            print('case 4')
        u = int(gcd(gcd(c, t), s))
        return normalize(MSI(n, a % u, min(b, d-1), u))

In [199]:
bin_op_test(urem, lambda n, a, b: (a % b) % 2**n, big=False, non_zero=True)

testing MSIs with bit width up to 4 exhaustively
    testing bit width: 1
    testing bit width: 2
    testing bit width: 3
    testing bit width: 4
- tested 25000 arguments
- tested 50000 arguments
- tested 75000 arguments


KeyboardInterrupt: 

In [46]:
def udiv(i, j, debug=False):
    n, _, _, _ = i._tuple_repr()
    m, _, _, t = j._tuple_repr()
    assert n == m, 'strides must be equal'
    a, b = umin_MSI(i), umax_MSI(i)
    c, d = umin_MSI(j), umax_MSI(j)
    s = ustride(i)
    m = 2**n
    if c == 0:
        if t == 0:
            return MSI(n, 0, (-1) % 2**n, 1)
        else:
            c = ustride(j)
    s_ = int(gcd(a, s))
    if t == 0:
        u = s_ // c
        u = u if u*c == s_ else 1
        return normalize(MSI(n, a//c, b//c, u))
    else:
        e, f = a//d, b//c
        return normalize(MSI(n, e, f, 1))

In [57]:
lhs, rhs = MSI(3, 1, 5, 1), MSI(3, 2, 0, 3)
print(f'{lhs}, {rhs}: {gamma(lhs)}, {gamma(rhs)}')
res = udiv(lhs, rhs, debug=True)
print(f'{res}: {gamma(res)}')

1[1, 5]_{3}, 3[2, 0]_{3}: {1, 2, 3, 4, 5}, {0, 2, 5}
1[0, 5]_{3}: {0, 1, 2, 3, 4, 5}


In [164]:
bin_op_test(udiv, lambda n, a, b: (a // b), big=False, non_zero=True)

testing MSIs with bit width up to 4 exhaustively
    testing bit width: 1
    testing bit width: 2
    testing bit width: 3
    testing bit width: 4
- tested 25000 arguments
- tested 50000 arguments
- tested 75000 arguments
- tested 100000 arguments
- tested 125000 arguments
- tested 150000 arguments
- tested 175000 arguments
- tested 200000 arguments
- tested 225000 arguments
- tested 250000 arguments
- tested 275000 arguments
- tested 300000 arguments
- tested 325000 arguments
testing random MSIs with bit width from 5 to 8
    testing bit width: 5
- tested 350000 arguments
    testing bit width: 6
- tested 375000 arguments
- tested 400000 arguments
- tested 425000 arguments
    testing bit width: 7
- tested 450000 arguments
- tested 475000 arguments
- tested 500000 arguments
- tested 525000 arguments
    testing bit width: 8
- tested 550000 arguments
- tested 575000 arguments
- tested 600000 arguments
- tested 625000 arguments
succeeded (tested 647793 arguments in total)
arguments wi

In [47]:
def rem(k, n):
    assert not n == 0, 'remainder by 0'
    if k > 0:
        return k % abs(n)
    else:
        return -(abs(k) % abs(n))

In [48]:
def div(k, n):
    assert not n == 0, 'division by 0'
    if k > 0:
        return k // n
    else:
        return -(abs(k) // n)

In [49]:
def smin(n, k, l):
    k_, l_ = as_signed_int(n, k), as_signed_int(n, l)
    if k_ <= l_:
        return k_
    else:
        return l_

def smax(n, k, l):
    k_, l_ = as_signed_int(n, k), as_signed_int(n, l)
    if k_ >= l_:
        return k_
    else:
        return l_

In [50]:
def srem(i, j, debug=False):
    n, m = i.bit_width, j.bit_width
    assert n == m, 'strides must be equal'
    a, b = as_signed_int(n, smin_MSI(i)), as_signed_int(n, smax_MSI(i))
    c, d = as_signed_int(n, smin_MSI(j)), as_signed_int(n, smax_MSI(j))
    s, t = sstride(i), sstride(j)
    if debug:
        print(f'a: {a}, b: {b}, c: {c}, d: {d}, s: {s}, t: {t}')
    if d < 0:
        if debug:
            print('all negative')
        c, d = -d, -c
    elif c < 0:
        if debug:
            print('some negative')
        t_ = (d+c) % t
        c, d = min(-c % t, d % t), max(-c, d)
        t = gcd(t, t_)
    if debug:
        print(f'a: {a}, b: {b}, c: {c}, d: {d}, s: {s}, t: {t}')
    if c == 0: # remainder by bound not possible
        if t == 0: # definite remainder by 0
            if debug:
                print('case 1')
            return MSI(n, 0, (-1) % 2**n, 1)
        else: # correct bound to avoid ramainder by 0
            if debug:
                print('avoid 0')
            c = c+t
            # renormalize
            if c == d:
                if debug:
                    print('renormalize')
                t = 0
    if debug:
        print(f'a: {a}, b: {b}, c: {c}, d: {d}, s: {s}, t: {t}')
    absMaxI = max(abs(a), abs(b))
    if absMaxI < c: # remainder has no effect
        if debug:
            print('case 2')
        return i
    elif t == 0: # remainder by constant
        if div(a, c) == div(b, c): # E x. x*c <= a <= b < (x+1)*c 
            if debug:
                print('case 3')
            return normalize(MSI(n, rem(a, c) % 2**n, rem(b, c) % 2**n, s))
    if debug:
        print(f'case 5')
    u = int(gcd(gcd(c, t), s))
    e = a % u if 0 < a else max(a, 1-d + (a+d-1) % u)
    f = min(b, d-1) if 0 < b else (e-1) % u + 1 - u
    if debug:
        print(f'u: {u}, e: {e}, f: {f}, {a % u if 0 < a else 1-d + (a-d+1) % u}, {d-1 if 0 < b else (e-1) % u + 1 - u}')
    return normalize(MSI(n, e % 2**n, f % 2**n, u))

In [63]:
srem(MSI(3, 5, 6, 1), MSI(3, 2, 2, 0), debug=True)

a: -3, b: -2, c: 2, d: 2, s: 1, t: 0
a: -3, b: -2, c: 2, d: 2, s: 1, t: 0
a: -3, b: -2, c: 2, d: 2, s: 1, t: 0
case 3


7[0, 7]_{3}

In [51]:
def srem_cases(i, j):
    p = []
    n, m = i.bit_width, j.bit_width
    assert n == m, 'strides must be equal'
    a, b = as_signed_int(n, smin_MSI(i)), as_signed_int(n, smax_MSI(i))
    c, d = as_signed_int(n, smin_MSI(j)), as_signed_int(n, smax_MSI(j))
    s, t = sstride(i), sstride(j)
    if append(p, d < 0) and d < 0:
        c, d = -d, -c
    elif append(p, c < 0) and c < 0:
        t_ = (d+c) % t
        append(p, -c % t <= d % t)
        append(p, -c <= d)
        c, d = min(-c % t, d % t), max(-c, d)
        t = gcd(t, t_)
    if append(p, c == 0) and c == 0: # remainder by bound not possible
        if append(p, t == 0) and t == 0: # definite remainder by 0
            return MSI(n, 0, (-1) % 2**n, 1), p
        else: # correct bound to avoid ramainder by 0
            c = c+t
            # renormalize
            if append(p, c == d) and c == d:
                t = 0
    append(p, a >= 0)
    append(p, a >= b)
    append(p, abs(a) >= abs(b))
    absMaxI = max(abs(a), abs(b))
    if append(p, absMaxI < c) and absMaxI < c: # remainder has no effect
        return i, p
    elif append(p, t == 0) and t == 0: # remainder by constant
        if append(p, div(a, c) == div(b, c)) and div(a, c) == div(b, c): # E x. x*c <= a <= b < (x+1)*c 
            return normalize(MSI(n, rem(a, c) % 2**n, rem(b, c) % 2**n, s)), p
    u = int(gcd(gcd(c, t), s))
    if append(p, 0 < a) and 0 < a:
        e = a % u
    else:
        append(p, a >= 1-d + (a+d-1) % u)
        e = max(a, 1-d + (a+d-1) % u)
    if append(p, 0 < b) and 0 < b:
        append(p, b >= d-1)
        f = min(b, d-1)
    else:
        f = (e-1) % u + 1 - u
    return normalize(MSI(n, e % 2**n, f % 2**n, u)), p

In [269]:
srem_cases(MSI(4, 2, 5, 3), MSI(4, 15, 3, 2))

(1[0, 2]_{4},
 [False,
  True,
  True,
  True,
  False,
  True,
  False,
  False,
  False,
  False,
  True,
  True,
  True])

In [52]:
def append(xs, x):
    xs.append(x)
    return True

In [53]:
def add_case(cases, path, ex):
    if cases is None:
        if len(path) == 0:
            return (True, True, ex)
        else:
            return (False, False, {path[0]: add_case(None, path[1:], ex), not path[0]: None})
    else:
        fin0, a0, cs0 = cases
        if fin0:
            return cases
        else:
            b = path[0]
            fin1, a1, cs1 = add_case(cs0[b], path[1:], ex)
            fin = cs0[not b] is not None and cs0[not b][0] and fin1
            return (fin, a0, {b: (fin1, a1, cs1), not b: cs0[not b]})

In [54]:
def gen_test_cases(f):
    cases = None
    for n in range(1, 4+1):
        for i in test_MSIs_6_exhaustive[n]:
            for j in test_MSIs_6_exhaustive[n]:
                _, path = f(i, j)
                cases = add_case(cases, path, (i, j))
                if cases[0]:
                    return cases
    return cases

In [55]:
def get_cases(cases):
    if cases is None:
        return []
    elif cases[1]:
        return [cases[2]]
    else:
        return get_cases(cases[2][True]) + get_cases(cases[2][False])

In [56]:
def find_unreachable_pathes(cases):
    if cases is None:
        return [p]
    elif cases[0]:
        return []
    r = []
    for b in [True, False]:
        if cases[2][b] is None:
            r += [[b]]
        else:
            r += [[b]+p for p in find_unreachable_pathes(cases[2][b])]
    return r

In [287]:
cases = gen_test_cases(srem_cases)

In [290]:
cs = get_cases(cases)

In [291]:
len(cs)

267

In [292]:
cs[0]

(0[0, 0]_{1}, 0[1, 1]_{1})

In [305]:
srem(MSI(2, 2, 2, 0), MSI(2, 2, 3, 1))

3[0, 3]_{2}

In [304]:
for lhs, rhs in cs:
    n, a, b, s = lhs._tuple_repr()
    _, c, d, t = rhs._tuple_repr()
    ref = srem(lhs, rhs)
    _, e, f, u = ref._tuple_repr()
    print(f'lhs = {{{n}, {a}, {b}, {s}}}; rhs = {{{n}, {c}, {d}, {t}}}; ref = {{{n}, {e}, {f}, {u}}};')
    print(f'res_p = lhs.srem({n}, rhs);');
    print(f'res = *(static_cast<StridedInterval *>(res_p.get()));')
    print(f'if (res != ref) {{')
    print(f'  errs() << "[testSrem] failed with operands " << lhs << ", " << rhs << ": got " << res << ", expected " << ref << "\\n";')
    print(f'}}')

lhs = {1, 0, 0, 0}; rhs = {1, 1, 1, 0}; ref = {1, 0, 0, 0};
res_p = lhs.srem(1, rhs);
res = *(static_cast<StridedInterval *>(res_p.get()));
if (res != ref) {
  errs() << "[testSrem] failed with operands " << lhs << ", " << rhs << ": got " << res << ", expected " << ref << "\n";
}
lhs = {2, 1, 1, 0}; rhs = {2, 3, 3, 0}; ref = {2, 0, 0, 0};
res_p = lhs.srem(2, rhs);
res = *(static_cast<StridedInterval *>(res_p.get()));
if (res != ref) {
  errs() << "[testSrem] failed with operands " << lhs << ", " << rhs << ": got " << res << ", expected " << ref << "\n";
}
lhs = {2, 1, 1, 0}; rhs = {2, 2, 3, 1}; ref = {2, 0, 1, 1};
res_p = lhs.srem(2, rhs);
res = *(static_cast<StridedInterval *>(res_p.get()));
if (res != ref) {
  errs() << "[testSrem] failed with operands " << lhs << ", " << rhs << ": got " << res << ", expected " << ref << "\n";
}
lhs = {3, 1, 1, 0}; rhs = {3, 5, 7, 1}; ref = {3, 0, 1, 1};
res_p = lhs.srem(3, rhs);
res = *(static_cast<StridedInterval *>(res_p.get()));
if (res != ref) {

res = *(static_cast<StridedInterval *>(res_p.get()));
if (res != ref) {
  errs() << "[testSrem] failed with operands " << lhs << ", " << rhs << ": got " << res << ", expected " << ref << "\n";
}
lhs = {3, 7, 7, 0}; rhs = {3, 0, 3, 1}; ref = {3, 0, 7, 7};
res_p = lhs.srem(3, rhs);
res = *(static_cast<StridedInterval *>(res_p.get()));
if (res != ref) {
  errs() << "[testSrem] failed with operands " << lhs << ", " << rhs << ": got " << res << ", expected " << ref << "\n";
}
lhs = {3, 6, 6, 0}; rhs = {3, 0, 2, 1}; ref = {3, 0, 7, 7};
res_p = lhs.srem(3, rhs);
res = *(static_cast<StridedInterval *>(res_p.get()));
if (res != ref) {
  errs() << "[testSrem] failed with operands " << lhs << ", " << rhs << ": got " << res << ", expected " << ref << "\n";
}
lhs = {4, 1, 14, 13}; rhs = {4, 0, 6, 3}; ref = {4, 1, 14, 13};
res_p = lhs.srem(4, rhs);
res = *(static_cast<StridedInterval *>(res_p.get()));
if (res != ref) {
  errs() << "[testSrem] failed with operands " << lhs << ", " << rhs << ": got " 

In [258]:
n = 4
lhs, rhs = MSI(n, 10, 13, 3), MSI(n, 10, 13, 3)
print(f'{lhs}, {rhs}: {set(map(lambda k: as_signed_int(n, k), gamma(lhs)))}, {set(map(lambda k: as_signed_int(n, k), gamma(rhs)))}')
res = srem(lhs, rhs, debug=True)
print(f'{res}: {set(map(lambda k: as_signed_int(n, k), gamma(res)))}')

3[10, 13]_{4}, 3[10, 13]_{4}: {-6, -3}, {-6, -3}
a: -6, b: -3, c: -6, d: -3, s: 3, t: 3
all negative
a: -6, b: -3, c: 3, d: 6, s: 3, t: 3
a: -6, b: -3, c: 3, d: 6, s: 3, t: 3
case 5
u: 3, e: -3, f: 0, -4, 0
13[0, 13]_{4}: {0, -3}


In [302]:
bin_op_test(srem, lambda n, a, b: rem(as_signed_int(n, a), as_signed_int(n, b)) % 2**n, big=False, non_zero=True)

testing MSIs with bit width up to 4 exhaustively
    testing bit width: 1
    testing bit width: 2
    testing bit width: 3
    testing bit width: 4
- tested 25000 arguments
- tested 50000 arguments


KeyboardInterrupt: 