# Modular Strided Intervals

$\texttt{i}N = \mathbb{Z}/{2^N}$

for $s, a, b \in \{0..2^N-1\}$:

$\gamma_N \colon \texttt{i}N \rightarrow \mathcal{P}(\mathbb{Z}/{2^N}), s[a, b]_N \mapsto \{k \in [a, b]_N \mid k \equiv a \mod s\}$ where

$[a, b]_N = \{k+2^N\mathbb{Z} \mid k \in \{a..\min\{l \in \mathbb{Z} \mid a \leq k \land l \equiv b \mod 2^N\}\}\}$

In [1]:
from itertools import count, takewhile
from random import randint
from sympy import gcd

In [2]:
class MSI(object):
    """
    Modular strided iterval
    """
    def __init__(self, bit_width, begin, end, stride=1):
        self.bit_width = bit_width
        self.begin = begin
        self.end = end
        self.stride = stride

    def __eq__(self, other):
        return (self.bit_width == other.bit_width
            and self.stride == other.stride
            and self.begin == other.begin
            and self.end == other.end)

    def __repr__(self):
        return f'{self.stride}[{self.begin}, {self.end}]_{{{self.bit_width}}}'
    
    def __hash__(self):
        return (self.begin+23) * (self.end+29) * (self.stride+31) % 16777216
    
    def _tuple_repr(self):
        return (self.bit_width, self.begin, self.end, self.stride)

## Defining Functions and Predicates

In [3]:
# This predicate has not tests, it's an axiom.
def valid(i):
    n, a, b, s = i._tuple_repr()
    if n <= 0:
        return False
    if a < 0 or 2**n <= a:
        return False
    if b < 0 or 2**n <= b:
        return False
    if s < 0 or 2**n <= s:
        return False
    return True

In [4]:
def gamma(i):
    n = i.bit_width
    s = i.stride
    a = i.begin
    b = i.end if a <= i.end else i.end + 2**n
    return {k % 2**n for k in takewhile(lambda k: k <= b, (a+l*s for l in count()) if s > 0 else [a])}

In [5]:
def normal(i):
    n, a, b, s = i._tuple_repr()
    if s == 0 and a != b:
        return False
    if a == b and s != 0:
        return False
    if not b in gamma(i):
        return False
    a_ = a - s
    if a_ != a and a_ >= 0 and gamma(i) == gamma(MSI(n, a_, (b-s) % 2**n, s)):
        return False
    if b < a and gamma(i) == gamma(MSI(n, b, a, 2**n - s)):
        return False
    return True

## Test sets of MSIs with theire respective concretizations

In [6]:
test_MSIs_handpicked_gamma = [
    # normalized
    #   no wraparound
    #     strid = 0
    #       begin = 0
    (MSI(4, 0, 0, 0), {0}),
    #       begin > 0
    (MSI(4, 3, 3, 0), {3}),
    #     strid = 1
    #       begin = 0
    #         end < 2**N-1
    (MSI(4, 0, 2, 1), {0, 1, 2}),
    #         end = 2**N-1
    (MSI(3, 0, 7, 1), {0, 1, 2, 3, 4, 5, 6, 7}),
    #       begin > 0
    (MSI(4, 3, 4, 1), {3, 4}),
    #     stride > 1
    #       begin = 0
    (MSI(4, 0, 4, 2), {0, 2, 4}),
    #       begin > 0
    (MSI(3, 1, 7, 3), {1, 4, 7}),
    (MSI(6, 6, 26, 10), {6, 16, 26}),
    #   wraparound
    #     stride = 1
    (MSI(4, 14, 2, 1), {14, 15, 0, 1, 2}),
    #     stride > 1
    (MSI(4, 11, 4, 3), {1, 4, 11, 14})]
test_MSIs_handpicked_gamma_unnormalized = [
    # unnormalized
    #   no wraparound
    #     stride = 0
    #       begin = 0
    (MSI(4, 0, 3, 0), {0}),
    #       begin > 0
    (MSI(4, 3, 8, 0), {3}),
    #     stride = 1
    #       begin = 0
    #         end = begin
    (MSI(4, 0, 0, 1), {0}),
    #         end != begin
    (MSI(2, 2, 1, 1), {0, 1, 2, 3}),
    #       begin > 0
    #         end = begin
    (MSI(4, 3, 3, 1), {3}),
    #         end != begin
    (MSI(3, 5, 4, 1), {0, 1, 2, 3, 4, 5, 6, 7}),
    #     stride > 1
    #       begin = 0
    (MSI(4, 0, 5, 2), {0, 2, 4}),
    (MSI(4, 0, 3, 5), {0}),
    #       begin > 0
    #         end = begin - stride mod 2**N
    (MSI(4, 11, 7, 4), {3, 7, 11, 15}),
    #         end != begin - stride mod 2**N
    (MSI(6, 6, 35, 10), {6, 16, 26}),
    (MSI(4, 3, 7, 5), {3}),
    #   wraparound
    #     stride = 0
    (MSI(4, 5, 3, 0), {5}),
    #     stride = 1
    (MSI(3, 5, 4, 1), {0, 1, 2, 3, 4, 5, 6, 7}),
    (MSI(4, 15, 0, 1), {15, 0}),
    #     stride > 1
    #       end = begin - stride mod 2**N
    (MSI(4, 10, 6, 4), {2, 6, 10, 14}),
    (MSI(4, 12, 2, 6), {2, 12}),
    #       end != begin and != begin - stride mod 2**N
    (MSI(4, 13, 2, 8), {13}),
    (MSI(4, 11, 6, 3), {11, 14, 1, 4}),
    (MSI(4, 10, 9, 4), {2, 6, 10, 14}),
    (MSI(4, 12, 7, 6), {2, 12})
]

In [7]:
test_MSIs_handpicked = {}
for i, _ in test_MSIs_handpicked_gamma:
    n = i.bit_width
    if n not in test_MSIs_handpicked:
        test_MSIs_handpicked[n] = [i]
    else:
        test_MSIs_handpicked[n].append(i)
print('size: ' + ', '.join(f'{n}: {len(js)}' for n, js in test_MSIs_handpicked.items()))

test_MSIs_handpicked_unnormalized = {}
for i, _ in test_MSIs_handpicked_gamma_unnormalized:
    n = i.bit_width
    if n not in test_MSIs_handpicked_unnormalized:
        test_MSIs_handpicked_unnormalized[n] = [i]
    else:
        test_MSIs_handpicked_unnormalized[n].append(i)
print('size: ' + ', '.join(f'{n}: {len(js)}' for n, js in test_MSIs_handpicked_unnormalized.items()))

size: 4: 7, 3: 2, 6: 1
size: 4: 16, 2: 1, 3: 2, 6: 1


## Tests for gamma

In [8]:
def test_gamma():
    failed = False
    for i, ks in test_MSIs_handpicked_gamma:
        if not gamma(i) == ks:
            failed = True
            print(f'{i}: {gamma(i)}, {ks}')
    if not failed:
        print('succeeded')

def test_gamma_unnormalized():
    failed = False
    for i, ks in test_MSIs_handpicked_gamma_unnormalized:
        if not gamma(i) == ks:
            failed = True
            print(f'{i}: {gamma(i)}, {ks}')
    if not failed:
        print('succeeded')

In [9]:
test_gamma()
test_gamma_unnormalized()

succeeded
succeeded


## Normalization function

In [10]:
def normalize(i):
    n, a, b, s = i._tuple_repr()
    if s == 0:
        b = a
    else:
        b_ = b if a <= b else b+2**n
        b = (b_ - (b_-a) % s) % 2**n
        if a == b:
            s = 0
        else:
            if 2**n % s == 0 and (a-b) % 2**n == s:
                a = a % s
                b = (a-s) % 2**n
            elif b == (a+s) % 2**n and b < a:
                a, b = b, a
                s = b-a
    return MSI(n, a, b, s)

## Test sets and utility functions for testing

Warning:

`normal` is used in `unary_function_test` if the `unnormalized` parameter is `True`, but tested later. Therefore this parameter should not be set before `normal` is tested.

In [11]:
def test_set(bit_widths, begins, ends, strides, only_normal=True, print_stats=False):
    MSIs = {}
    for n in bit_widths:
        js = set()
        bs = begins(n)
        for b in bs:
            es = ends(n)
            for e in es:
                ss = strides(n)
                for s in ss:
                    if only_normal:
                        js.add(normalize(MSI(n, b, e, s)))
                    else:
                        js.add(MSI(n, b, e, s))
        MSIs[n] = list(js)
    if print_stats:
        print('size: ' + ', '.join(f'{n}: {len(js)}' for n, js in MSIs.items()))
        if not only_normal:
            print('unnormalized: ' + ', '.join(f'{n}: {len(list(0 for j in js if not normal(j)))}' for n, js in MSIs.items()))
    return MSIs

In [12]:
f = lambda n: list(range(2**n))
g = lambda n: list(range(2**n))
print('test_MSIs_4_exhaustive')
test_MSIs_4_exhaustive = test_set(range(1, 4+1), f, g, f, print_stats=True)
print('test_MSIs_4_exhaustive')
test_MSIs_4_exhaustive_unnormalized = test_set(range(1, 4+1), f, g, f, only_normal=False, print_stats=True)

test_MSIs_4_exhaustive
size: 1: 3, 2: 15, 3: 95, 4: 575
test_MSIs_4_exhaustive
size: 1: 8, 2: 64, 3: 512, 4: 4096
unnormalized: 1: 5, 2: 49, 3: 417, 4: 3521


In [13]:
f = lambda n: list(range(2**n))
g = lambda n: list(range(2**n))
print('test_MSIs_5_6_exhaustive')
test_MSIs_5_6_exhaustive = test_set(range(5, 6+1), f, g, f, print_stats=True)
print('test_MSIs_5_6_exhaustive')
test_MSIs_5_6_exhaustive_unnormalized = test_set(range(5, 6+1), f, g, f, only_normal=False, print_stats=True)

test_MSIs_5_6_exhaustive
size: 5: 3039, 6: 15231
test_MSIs_5_6_exhaustive
size: 5: 32768, 6: 262144
unnormalized: 5: 29729, 6: 246913


In [14]:
test_MSIs_6_exhaustive = {
    **test_MSIs_4_exhaustive, **test_MSIs_5_6_exhaustive
}
test_MSIs_6_exhaustive_unnormalized = {
    **test_MSIs_4_exhaustive_unnormalized, **test_MSIs_5_6_exhaustive_unnormalized
}

In [15]:
ks = [a+b for a in [0, 30] for b in [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 15]]
ls = [30, 31, 32, 33, 35, 36, 40, 45]
f = lambda _: ks
print('test_MSIs_6_partial')
test_MSIs_6_partial = test_set([6], f, g, f, print_stats=True)
print('\ntest_MSIs_6_partial_unnormalized')
test_MSIs_6_partial_unnormalized = test_set([6], f, f, f, only_normal=False, print_stats=True)

test_MSIs_6_partial
size: 6: 4111

test_MSIs_6_partial_unnormalized
size: 6: 10648
unnormalized: 6: 9309


In [16]:
ks = [a+b for a in [0, 30, 60] for b in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 25]]
ls = [0, 2, 3, 5, 6, 10, 15]
f = lambda n: takewhile(lambda k: k < 2**n, ks)
g = lambda n: (((15 if 2**n < 30 else 45) + 15 + a) % 2**n for a in ls)
print('test_MSIs_8_partial')
test_MSIs_8_partial = test_set([8], f, g, f, print_stats=True)
print('\ntest_MSIs_8_partial_unnormalized')
test_MSIs_8_partial_unnormalized = test_set([8], f, g, f, only_normal=False, print_stats=True)

test_MSIs_8_partial
size: 8: 2438

test_MSIs_8_partial_unnormalized
size: 8: 10647
unnormalized: 8: 9705


In [17]:
f = lambda n: set(randint(0, 2**n-1) for _ in range(8))
g = lambda n: set(randint(0, 2**(n-1)-1) for _ in range(8))
print('test_MSIs_random')
test_MSIs_random = test_set(range(5, 8+1), f, f, g, print_stats=True)
print('\ntest_MSIs_random_unnormalized')
test_MSIs_random_unnormalized = test_set(range(5, 8+1), f, f, g, only_normal=False, print_stats=True)

test_MSIs_random
size: 5: 201, 6: 299, 7: 344, 8: 366

test_MSIs_random_unnormalized
size: 5: 326, 6: 317, 7: 496, 8: 486
unnormalized: 5: 271, 6: 286, 7: 475, 8: 469


In [18]:
def _unary_function_test(f, p, test_MSIs, test_count=0, fail_count=0, fail_lim=8):
    for n, js in test_MSIs.items():
        print(f'    testing bit width: {n}')
        for i in js:
            test_count += 1
            x = f(i)
            if not p(i, x):
                fail_count += 1
                print(f'        {i}: {x}')
                if fail_count == fail_lim:
                    return test_count, fail_count
            if test_count % 25000 == 0:
                print(f'- tested {test_count} arguments')
    return test_count, fail_count

def unary_function_test(f, p, big=False, unnormalized=False):
    fail_lim = 16 if big else 8
    test_count = fail_count = 0
    print('testing MSIs with bit width up to 4 exhaustively')
    MSIs = test_MSIs_4_exhaustive_unnormalized if unnormalized else test_MSIs_4_exhaustive
    test_count, fail_count = _unary_function_test(f, p, MSIs, test_count, fail_count, fail_lim)
    if fail_count == fail_lim:
        return
    print('testing random MSIs with bit width from 5 to 8')
    MSIs = test_MSIs_random_unnormalized if unnormalized else test_MSIs_random
    test_count, fail_count = _unary_function_test(f, p, MSIs, test_count, fail_count, fail_lim)
    if fail_count == fail_lim:
        return
    if big:
        print('testing some MSIs with bit width 6')
        MSIs = test_MSIs_6_partial_unnormalized if unnormalized else test_MSIs_6_partial
        test_count, fail_count = _unary_function_test(f, p, MSIs, test_count, fail_count, fail_lim)
        if fail_count == fail_lim:
            return
        print('testing some MSIs with bit width 8')
        MSIs = test_MSIs_8_partial_unnormalized if unnormalized else test_MSIs_8_partial
        test_count, fail_count = _unary_function_test(f, p, MSIs, test_count, fail_count, fail_lim)
        if fail_count == fail_lim:
            return
    if fail_count == 0:
        print(f'succeeded (tested {test_count} arguments in total)')

In [19]:
def _bin_op_test(op_MSI, op, test_MSIs, test_count=0, fail_count=0, fail_lim=8, bad_args={}, bad_lim=8, non_zero=False):
    bad_precision = max(bad_args.values()) if len(bad_args) > 0 else 1
    for n, js in test_MSIs.items():
        print(f'    testing bit width: {n}')
        for i in js:
            vals_i = gamma(i)
            for j in js:
                test_count += 1
                vals_j = gamma(j)
                if non_zero and 0 in vals_j:
                    vals_op = {k for k in range(2**n)}
                else:
                    vals_op = {op(n, k, l) for k in vals_i for l in vals_j}
                vals_op_MSI = gamma(op_MSI(i, j))
                if not vals_op <= vals_op_MSI:
                    fail_count += 1
                    print(f'        {i} op {j}: {op(i, j)}')
                    if fail_count == fail_lim:
                        return test_count, fail_count, bad_args
                else:
                    precision = len(vals_op) / (len(vals_op_MSI) * 2**n)
                    if precision < bad_precision:
                        if len(bad_args) == bad_lim:
                            bad_args.pop(list(bad_args.keys())[list(bad_args.values()).index(bad_precision)])
                        bad_args[(i, j)] = precision
                        bad_precision = max(bad_args.values())
                if test_count % 25000 == 0:
                    print(f'- tested {test_count} arguments')
    return test_count, fail_count, bad_args

def bin_op_test(op_MSI, op, big=False, non_zero=False):
    fail_lim = bad_lim = 16 if big else 8
    test_count = fail_count = 0
    bad_args = {}
    print('testing MSIs with bit width up to 4 exhaustively')
    test_count, fail_count, bad_args = _bin_op_test(op_MSI, op, test_MSIs_4_exhaustive, test_count, fail_count, fail_lim, bad_args, bad_lim, non_zero=non_zero)
    if fail_count == fail_lim:
        return
    print('testing random MSIs with bit width from 5 to 8')
    test_count, fail_count, bad_args = _bin_op_test(op_MSI, op, test_MSIs_random, test_count, fail_count, fail_lim, bad_args, bad_lim, non_zero=non_zero)
    if fail_count == fail_lim:
        return
    if big:
        print('testing some MSIs with bit width 6')
        test_count, fail_count, bad_args = _bin_op_test(op_MSI, op, test_MSIs_6_partial, test_count, fail_count, fail_lim, bad_args, bad_lim, non_zero=non_zero)
        if fail_count == fail_lim:
            return
        print('testing some MSIs with bit width 8')
        test_count, fail_count, bad_args = _bin_op_test(op_MSI, op, test_MSIs_8_partial, test_count, fail_count, fail_lim, bad_args, bad_lim, non_zero=non_zero)
        if fail_count == fail_lim:
            return
    if fail_count == 0:
        print(f'succeeded (tested {test_count} arguments in total)')
        print('arguments with least precise results:')
        for (i, j), r in bad_args.items():
            print(f'{i}, {j}: {r}')

In [109]:
def _bin_rel_test(rel_MSI, rel, test_MSIs, test_count=0, fail_count=0, fail_lim=8):
    for n, js in test_MSIs.items():
        print(f'    testing bit width: {n}')
        for i in js:
            for j in js:
                test_count += 1
                if not (rel_MSI(i, j) == rel(i, j)):
                    fail_count += 1
                    print(f'        {i} rel {j}: {rel_MSI(i, j)}')
                    if fail_count == fail_lim:
                        return test_count, fail_count
                if test_count % 25000 == 0:
                    print(f'- tested {test_count} arguments')
    return test_count, fail_count

def bin_rel_test(rel_MSI, rel, big=False):
    fail_lim = 16 if big else 8
    test_count = fail_count = 0
    print('testing MSIs with bit width up to 4 exhaustively')
    test_count, fail_count = _bin_rel_test(rel_MSI, rel, test_MSIs_4_exhaustive, test_count, fail_count, fail_lim)
    if fail_count == fail_lim:
        return
    print('testing random MSIs with bit width from 5 to 8')
    test_count, fail_count = _bin_rel_test(rel_MSI, rel, test_MSIs_random, test_count, fail_count, fail_lim)
    if fail_count == fail_lim:
        return
    if big:
        print('testing some MSIs with bit width 6')
        test_count, fail_count = _bin_rel_test(rel_MSI, rel, test_MSIs_6_partial, test_count, fail_count, fail_lim)
        if fail_count == fail_lim:
            return
        print('testing some MSIs with bit width 8')
        test_count, fail_count = _bin_rel_test(rel_MSI, rel, test_MSIs_8_partial, test_count, fail_count, fail_lim)
        if fail_count == fail_lim:
            return
    if fail_count == 0:
        print(f'succeeded (tested {test_count} arguments in total)')

## Test for normal

In [21]:
def test_normal():
    failed = False
    test_count = fail_count = 0
    for n, js in test_MSIs_6_exhaustive.items():
        equiv_classes = {}
        for i in js:
            a = frozenset(gamma(i))
            if a in equiv_classes:
                equiv_classes[a].add(i)
            else:
                equiv_classes[a] = {i}
        for equiv_class in equiv_classes.values():
            norm_forms = list(filter(normal, equiv_class))
            test_count += 1
            if len(norm_forms) != 1:
                failed = True
                fail_count += 1
                if len(norm_forms) == 0:
                    print(f'no normal form for {equiv_class}')
                else:
                    print(f'multiple normal forms {norm_forms}')
                if fail_count > 8:
                    return
    print(f'succeeded (tested {test_count} equivalence classes in total)')

In [22]:
test_normal()

succeeded (tested 18958 equivalence classes in total)


## Helper functions

In [23]:
def bounds(i):
    n, a, b, _ = i._tuple_repr()
    if a <= b:
        return a, b, False
    else:
        return a, b + 2**n, True

In [24]:
def contains(i, k):
    n, a, b, s = i._tuple_repr()
    if s == 0:
        return a == k
    elif a <= b:
        return a <= k and k <= b and (k - a) % s == 0
    else:
        if k >= a:
            return (k - a) % s == 0
        elif k <= b:
            return (k - b) % s == 0
        else:
            return False

In [25]:
def test_contains():
    failed = False
    test_count = fail_count = 0
    for n, js in test_MSIs_6_exhaustive.items():
        for i in js:
            test_count += 1
            a = gamma(i)
            for k in range(2**n):
                if k in a and not contains(i, k):
                    failed = True
                    fail_count += 1
                    print(f'{k} in gamma({i})')
                if k not in a and contains(i, k):
                    failed = True
                    fail_count += 1
                    print(f'{k} not in gamma({i})')
                if fail_count > 8:
                    return
    print(f'succeeded (tested {test_count} arguments)')

In [26]:
test_contains()

succeeded (tested 18958 arguments)


In [None]:
            if a <= b:
                if c <= d:
                    if debug:
                        print('5.1.1.1')
                    return c <= a and b <= d and (a-c) % t == 0
                else:
                    if a <= d:
                        if b <= d:
                            if debug:
                                print('5.1.1.2.1.1')
                            return (d-a) % t == 0
                        else:
                            if debug:
                                print('5.1.1.2.1.2')
                            return c-d <= s and (d-a) % t == 0 and (b-c) % t == 0
                    else:
                        if debug:
                            print('5.1.1.2.2')
                        return c <= a and (d-a) % t == 0
            else:
                if c <= d:
                    if debug:
                        print('5.1.2.1')
                    e = d - (d-a) % s
                    return (c-d) % 2**n <= s and (a-c) % t == 0 and (b-c) % t == 0 and e+s >= c + 2**n
                else:
                    if debug:
                        print('5.1.2.2')
                    return c <= a and b <= d and (a-c) % t == 0

In [123]:
def leq_MSI(i, j, debug=False):
    n, a, b, s = i._tuple_repr()
    m, c, d, t = j._tuple_repr()
    assert n == m, 'strides must be equal'
    if s == 0: # i contains exactly 1 value
        return contains(j, a)
    elif t == 0: # j contains exactly 1 value
        return False
    elif b == (a+s) % 2**n: # i contains exactly 2 values
        return contains(j, a) and contains(j, b)
    if s % t == 0:
        if 2**n % t == 0 and (c-d) % 2**n == t: # j represents a residue class of Z/t (=> t | 2**n)
            return (a-c) % t == 0 and s % t == 0
        else:
            b_ = (b-a) % 2**n
            c_, d_ = (c-a) % 2**n, (d-a) % 2**n
            if (c-d) % 2**n <= s and d_ < c_ and d_ <= b_ and c_ <= b_: # this branch may not return, but continue below [a...d_...c_...b_...]
                e_ = s * (d_ // s)
                f_ = (b_ - s * ((b_-c_) // s)) % s**n
                if (f_-e_) == s:
                    if e_ < s:
                        if contains(j, a) and s % t == 0 and c_ % t == 0:
                            return True
                    elif (f_-c_) < s:
                        if contains(j, b) and s % t == 0 and d_ % t == 0:
                            return True
                    else:
                        if s % t == 0:
                            return True
            if c_ <= d_:
                return c_ == 0 and b_ <= d_
            else:
                return b_ <= d_ and (d_-b_) % t == 0
    else:
        return False

In [124]:
bin_rel_test(leq_MSI, lambda i, j: gamma(i) <= gamma(j))

testing MSIs with bit width up to 4 exhaustively
    testing bit width: 1
    testing bit width: 2
    testing bit width: 3
    testing bit width: 4
- tested 25000 arguments
- tested 50000 arguments
- tested 75000 arguments
- tested 100000 arguments
- tested 125000 arguments
- tested 150000 arguments
- tested 175000 arguments
- tested 200000 arguments
- tested 225000 arguments
- tested 250000 arguments
- tested 275000 arguments
- tested 300000 arguments
- tested 325000 arguments
testing random MSIs with bit width from 5 to 8
    testing bit width: 5
- tested 350000 arguments
- tested 375000 arguments
    testing bit width: 6
- tested 400000 arguments
- tested 425000 arguments
- tested 450000 arguments
    testing bit width: 7
- tested 475000 arguments
- tested 500000 arguments
- tested 525000 arguments
- tested 550000 arguments
- tested 575000 arguments
    testing bit width: 8
- tested 600000 arguments
- tested 625000 arguments
- tested 650000 arguments
- tested 675000 arguments
- tes

In [107]:
lhs = MSI(3, 2, 0, 3)
rhs = MSI(3, 5, 3, 3)
res = leq_MSI(lhs, rhs)
print(f'{lhs} leq {rhs} = {res}')
print(f'{gamma(lhs)} leq {gamma(rhs)} = {gamma(lhs) <= gamma(rhs)}')
leq_MSI(lhs, rhs, debug=True)

3[2, 0]_{3} leq 3[5, 3]_{3} = False
{0, 2, 5} leq {0, 3, 5} = False
5.1
5.1.1
5.1.1.1
5.1.1.1.1
5.2.2 - b_: 6, c_: 3, d_: 1


False

In [57]:
def max_MSI(i):
    n, a, b, s = i._tuple_repr()
    if a <= b:
        return b
    else:
        return 2**n - ((2**n - a - 1) % s + 1)

In [98]:
def min_MSI(i):
    n, a, b, s = i._tuple_repr()
    if a <= b:
        return a
    else:
        return a % s

In [82]:
def as_unsigned(i):
    n, a, b, s = i._tuple_repr()
    if a <= b:
        return MSI(n, a, b, s)
    else:
        t = int(gcd(s, (a-b) & 2**n))
        c = a % t
        d = (c-t) % 2**n
        return MSI(n, c, d, t)

In [99]:
unary_function_test(min_MSI, lambda i, k: k == min(gamma(i)), big=True)

testing MSIs with bit width up to 4 exhaustively
    testing bit width: 1
    testing bit width: 2
    testing bit width: 3
        3[3, 1]_{3}: 0
        3[4, 2]_{3}: 1
        3[5, 3]_{3}: 2
        3[2, 0]_{3}: 2
    testing bit width: 4
        5[9, 3]_{4}: 4
        3[12, 2]_{4}: 0
        3[9, 5]_{4}: 0
        6[11, 7]_{4}: 5
        3[14, 7]_{4}: 2
        3[13, 3]_{4}: 1
        7[8, 6]_{4}: 1
        5[10, 4]_{4}: 0
        6[4, 0]_{4}: 4
        5[3, 2]_{4}: 3
        3[10, 6]_{4}: 1
        6[5, 1]_{4}: 5


In [87]:
unary_function_test(max_MSI, lambda i, k: k == max(gamma(i)), big=True)

testing MSIs with bit width up to 4 exhaustively
    testing bit width: 1
    testing bit width: 2
    testing bit width: 3
    testing bit width: 4
testing random MSIs with bit width from 5 to 8
    testing bit width: 5
    testing bit width: 6
    testing bit width: 7
    testing bit width: 8
testing some MSIs with bit width 6
    testing bit width: 6
testing some MSIs with bit width 8
    testing bit width: 8
succeeded (tested 5041 arguments in total)


In [None]:
unary_function_test(as_unigned, lambda i, j: )

## Tests of Basic Functions

In [60]:
unary_function_test(normalize, lambda i, j: i == j, big=True)

testing MSIs with bit width up to 4 exhaustively
    testing bit width: 1
    testing bit width: 2
    testing bit width: 3
    testing bit width: 4
testing random MSIs with bit width from 5 to 8
    testing bit width: 5
    testing bit width: 6
    testing bit width: 7
    testing bit width: 8
testing some MSIs with bit width 6
    testing bit width: 6
testing some MSIs with bit width 8
    testing bit width: 8
succeeded (tested 6484 arguments in total)


In [61]:
unary_function_test(normalize, lambda i, j: valid(j) and gamma(i) == gamma(j), unnormalized=True, big=True)

testing MSIs with bit width up to 4 exhaustively
    testing bit width: 1
    testing bit width: 2
    testing bit width: 3
    testing bit width: 4
testing random MSIs with bit width from 5 to 8
    testing bit width: 5
    testing bit width: 6
    testing bit width: 7
    testing bit width: 8
testing some MSIs with bit width 6
    testing bit width: 6
testing some MSIs with bit width 8
    testing bit width: 8
- tested 25000 arguments
succeeded (tested 27648 arguments in total)


# Operations

## Helper Functions for Tests of Operations

## Implementation of Operations

In [215]:
def add(i, j):
    n, a, b, s = i._tuple_repr()
    m, c, d, t = j._tuple_repr()
    assert n == m, 'strides must be equal'
    u = int(gcd(s, t))
    b_ = b if a <= b else b + 2**n
    d_ = d if c <= d else d + 2**n
    e, f = a+c, b_+d_
    if f-e < 2**n:
        u_ = u
        e_, f_ = e % 2**n, f % 2**n
    else:
        u_ = int(gcd(u, 2**n))
        e_ = e % 2**n
        f_ = (e_-u_) % 2**n
    return normalize(MSI(n, e_, f_, u_))

In [325]:
bin_op_test(add, lambda n, a, b: (a+b) % 2**n)

testing MSIs with bit width up to 4 exhaustively
    testing bit width: 1
    testing bit width: 2
    testing bit width: 3
    testing bit width: 4
- tested 25000 arguments
- tested 50000 arguments
- tested 75000 arguments
- tested 100000 arguments
- tested 125000 arguments
testing random MSIs with bit width from 5 to 8
    testing bit width: 5
- tested 150000 arguments
    testing bit width: 6
- tested 175000 arguments
- tested 200000 arguments
- tested 225000 arguments
    testing bit width: 7
- tested 250000 arguments
- tested 275000 arguments
- tested 300000 arguments
- tested 325000 arguments
    testing bit width: 8
- tested 350000 arguments
- tested 375000 arguments
- tested 400000 arguments
- tested 425000 arguments
- tested 450000 arguments
- tested 475000 arguments
succeeded
arguments with least precise results:
119[80, 199]_{8}, 137[69, 206]_{8}: 0.01171875
129[101, 230]_{8}, 129[101, 230]_{8}: 0.01171875
129[101, 230]_{8}, 129[14, 143]_{8}: 0.01171875
129[14, 143]_{8}, 129

In [88]:
def mul(i, j, debug=False):
    n, a, b, s = i._tuple_repr()
    m, c, d, t = j._tuple_repr()
    assert n == m, 'strides must be equal'
    m = 2**n
    u = int(gcd(a, s)) * int(gcd(c, t))
    b_ = b if a <= b else b + m
    d_ = d if c <= d else d + m
    e, f = a*c, b_*d_
    if f-e < m:
        u_ = u
        e_, f_ = e % m, f % m
    else:
        u_ = int(gcd(u, m))
        e_ = e % m
        f_  = (e_-u_) % m
    if debug:
        print(f'u: {u}, e: {e}, f: {f}, u_: {u_}, e_: {e_}, f_: {f_}')
    return normalize(MSI(n, e_, f_, u_))

In [91]:
bin_op_test(mul, lambda n, a, b: (a*b) % 2**n, big=False)

testing MSIs with bit width up to 4 exhaustively
    testing bit width: 1
    testing bit width: 2
    testing bit width: 3
    testing bit width: 4
- tested 25000 arguments


KeyboardInterrupt: 

In [22]:
interprete_as_unsigned(MSI(4, 7, 3, 3))

3[1, 14]_{4}

In [450]:
def urem(i, j, debug=False):
    n, a, b, s = i._tuple_repr()
    m, c, d, t = j._tuple_repr()
    assert n == m, 'strides must be equal'
    if c == 0 or d < c and d % t == 0:
        if debug:
            print('1')
        e_, f_ = 0, 2**n-1
        u_ = 1
    elif s == 0:
        c_ = c if c <= d else t - ((2**n - c) % t)
        if a < c_:
            if debug:
                print('2.1')
            e_, f_ = a, b
            u_ = s
        else:
            if debug:
                print('2.2')
            e_ = a % int(gcd(t, c))
            f_ = d-1
            u_ = int(gcd(a, gcd(t, d)))
    else:
        s_ = s if a <= b else gcd(s, (a-b) % 2**n)
        t_ = t if c <= d else gcd(t, (c-d) % 2**n)
        if t == 0:
            p3_1 = False
            if a <= b:
                v, w = a // c, b // c
                x, y = a % c, b % c
                if v == w:
                    if debug:
                        print('3.1')
                    e_, f_ = x, y
                    u_ = s
                    p3_1 = True
            if not p3_1:
                if debug:
                    print('3.2')
                u_ = gcd(s_, c)
                e_, f_ = a % u_, c-1
        else:
            if debug:
                print('4')
            u_ = gcd(s_, gcd(t_, c))
            # if d < c, a more precise result can be achieved by spliting the lhs
            # at 2**n-1/0, computing urem for both parts and then take their lub.
            e_ = a % u_
            d_ = d if c <= d else 2**n - ((2**n - c) % t)
            f_ = d_-1
    if debug:
        print(MSI(n, e_, f_, u_))
    return normalize(MSI(n, e_, f_, u_))

In [None]:
def urem(i, j, debug=False):
    

In [451]:
lhs, rhs = MSI(3, 1, 1, 0), MSI(3, 3, 1, 3)
print(f'{lhs}, {rhs}: {gamma(lhs)}, {gamma(rhs)}')
res = urem(lhs, rhs, debug=True)
print(f'{res}: {gamma(res)}')

0[1, 1]_{3}, 3[3, 1]_{3}: {1}, {1, 3, 6}
2.2
1[1, 0]_{3}
1[0, 7]_{3}: {0, 1, 2, 3, 4, 5, 6, 7}


In [447]:
bin_op_test(urem, lambda n, a, b: (a % b) % 2**n, big=False, non_zero=True)

testing MSIs with bit width up to 4 exhaustively
    testing bit width: 1
    testing bit width: 2
    testing bit width: 3
        0[1, 1]_{3} op 3[3, 1]_{3}: 3[4, 2]_{3}
        0[1, 1]_{3} op 2[7, 3]_{3}: 2[0, 4]_{3}
        0[1, 1]_{3} op 2[5, 1]_{3}: 2[6, 2]_{3}
        0[3, 3]_{3} op 3[3, 1]_{3}: 3[6, 4]_{3}
        0[3, 3]_{3} op 2[7, 3]_{3}: 2[2, 6]_{3}
        0[3, 3]_{3} op 3[4, 2]_{3}: 3[7, 5]_{3}
        0[3, 3]_{3} op 2[5, 1]_{3}: 2[0, 4]_{3}
        0[5, 5]_{3} op 2[7, 3]_{3}: 2[4, 0]_{3}
