In [None]:
import random
import math
import scipy.stats as stats

In [None]:
def binary_search(fn, i, j):
    """Find the last index in the half-open range [i, j) accepted by `fn`.

    The search assumes `fn` is truthy on a (possibly empty) leading part of
    the range and falsy on the rest.

    Args:
        fn: predicate called with a single integer index.
        i: inclusive lower bound of the range.
        j: exclusive upper bound of the range.

    Returns:
        The largest probed index for which `fn` was truthy, or ``i - 1`` when
        the range is empty or `fn` was falsy at every probe.
    """
    lo, hi = i, j
    while lo < hi:
        mid = lo + (hi - lo) // 2
        if not fn(mid):
            # Everything from `mid` upwards is rejected.
            hi = mid
        else:
            # `mid` is accepted; keep looking strictly above it.
            lo = mid + 1
    # `lo` is now the first rejected index, so the answer sits just below it.
    return lo - 1

def skewed_binary_search(fn, k, i, j):
    """Search [i, j) from the top in fixed strides, then refine with binary search.

    Like `binary_search`, this assumes `fn` is truthy on a leading part of the
    range and falsy on the rest. It probes downwards from `j` in steps of
    `diff`, the largest power of two not exceeding ``(j - i) / k`` (at least
    1). On the first truthy probe it hands the remaining gap between that
    probe and the previous one to `binary_search`.

    Args:
        fn: predicate called with a single integer index.
        k: value that scales the stride; larger `k` gives smaller strides.
            ``k == 0`` short-circuits to "nothing found" without calling `fn`.
        i: inclusive lower bound of the range.
        j: exclusive upper bound of the range.

    Returns:
        The largest index for which `fn` was found truthy, or ``i - 1`` when
        the range is empty, `k` is 0, or no probe was truthy.
    """
    # An empty range has nothing to find; returning here also avoids
    # math.log(0), which raises ValueError.
    if k == 0 or i >= j:
        return i - 1
    diff = max(1, 2 ** math.floor(math.log((j - i) / k, 2)))
    while i < j:
        # Clamp to the lower bound `i` (not 0) so probes never leave [i, j).
        pivot = max(i, j - diff)
        if fn(pivot):
            # Every index from the old `j` upwards was already rejected, so
            # the last truthy index lies in [pivot, j).
            return binary_search(fn, pivot + 1, j)
        j = pivot
    return i - 1

In [None]:
# Demo: visualise the pivots probed by binary_search over a ten-element list.
ll = [*range(10)]

def greater_than(i):
    """Draw a bar for index `i`, then report whether ll[i] is at most 5.

    Despite its name, the predicate is truthy for the *small* values, which
    is the "truthy prefix" shape the search functions expect.
    """
    bar, rest = "x" * i, "." * (10 - i)
    print(bar, rest)
    return not ll[i] > 5

binary_search(greater_than, 0, 10)
print()
# A one-element range whose only index is rejected.
binary_search(greater_than, 7, 8)

In [None]:
# Demo: visualise the pivots probed by skewed_binary_search on the same list.
ll = [*range(10)]

def greater_than(i):
    """Print index `i` and its bar, then report whether ll[i] is at most 5."""
    print(i)
    bar, rest = "x" * i, "." * (10 - i)
    print(bar, rest)
    return not ll[i] > 5

# k = 30 exceeds the range size, so the stride clamps to 1 and the search
# walks down from the top one index at a time.
skewed_binary_search(greater_than, 30, 0, 10)

In [None]:
# Size of the example distribution and how many of its slots are required.
n, k = 100, 10

# One-character markers for a slot: untouched, required (set by
# create_distribution), and included (written by the reduction functions
# once a required slot has been located).
NOTHING, REQUIRED, INCLUDED = ".", "x", "o"


def create_distribution(n, k):
    """Build a list of `n` markers with `k` randomly chosen REQUIRED slots.

    Positions are drawn without replacement via ``random.sample``; every
    other slot holds NOTHING.
    """
    chosen = random.sample(range(n), k)
    layout = [NOTHING] * n
    for position in chosen:
        layout[position] = REQUIRED
    return layout
    
def pretty(dist):
    """Concatenate the marker strings in `dist` into one displayable string."""
    rendered = "".join(dist)
    return rendered

# Shared sample distribution reused by the debug cells further down.
example = create_distribution(n=n, k=k)
# pretty(example)

In [None]:
def binary_reduction(items, debug=False):
    """Locate every REQUIRED slot with repeated plain binary searches.

    Each round searches the prefix ``[0, upper)`` for the last index whose
    tail still contains a REQUIRED marker, rewrites that slot to INCLUDED,
    and continues below it. `items` is modified in place.

    Args:
        items: mutable list of marker strings.
        debug: when true, print every probe and the final state.

    Returns:
        The number of predicate evaluations performed.
    """
    probes = 0

    def tail_has_required(start):
        # Truthy while some REQUIRED marker remains at or after `start`.
        nonlocal probes
        if debug:
            print(f"{probes: 3}", pretty(items[:start]) + "|" + pretty(items[start + 1:]))
        probes += 1
        return any(marker == REQUIRED for marker in items[start:])

    upper = len(items)
    while upper > 0:
        if debug:
            print()
        upper = binary_search(tail_has_required, 0, upper)
        # -1 means nothing was found; there is no slot to rewrite then.
        if upper >= 0:
            items[upper] = INCLUDED

    if debug:
        print(pretty(items))

    return probes
        
def skewed_binary_reduction(items, debug=False):
    """Locate every REQUIRED slot using `skewed_binary_search` with an adaptive `k`.

    Each round searches ``[0, upper)``, rewrites the slot it finds to
    INCLUDED and continues below it. The `k` passed to the next search is
    derived from the gaps seen so far: every round records
    ``ceil(upper / gap)``, all recorded values are then decremented (never
    below 1), and the next `k` is the smaller of the found index and their
    geometric mean. `items` is modified in place.

    Args:
        items: mutable list of marker strings.
        debug: when true, print every probe, the per-round `k`, and the
            final state.

    Returns:
        The number of predicate evaluations performed.
    """
    probes = 0

    def tail_has_required(start):
        # Truthy while some REQUIRED marker remains at or after `start`.
        nonlocal probes
        if debug:
            print(f"{probes: 4}", pretty(items[:start]) + "|" + pretty(items[start + 1:]))
        probes += 1
        return any(marker == REQUIRED for marker in items[start:])

    estimate = 1
    upper = len(items)
    history = []
    while upper > 0:
        if debug:
            print(f"k={estimate:02}", f"{probes:03}")
        found = skewed_binary_search(tail_has_required, estimate, 0, upper)
        latest = math.ceil(upper / (upper - found))
        # Decay every recorded value, the new one included, with a floor of 1.
        history = [max(1, value - 1) for value in (*history, latest)]
        estimate = min(found, stats.gmean(history))
        upper = found
        # -1 means nothing was found; there is no slot to rewrite then.
        if upper >= 0:
            items[upper] = INCLUDED

    if debug:
        print(pretty(items))

    return probes



import itertools

def stepped_binary_reduction(items, debug=False):
    """Locate every REQUIRED slot with a stride that adapts to the gaps found.

    `r` is the exclusive upper bound of the part of `items` still to be
    searched. Each round probes ``pivot = r - diff`` (clamped to 0):

    * no REQUIRED marker at or after `pivot`: drop the range to `pivot` and
      double the stride;
    * otherwise, with a stride above 1: binary-search ``[pivot, r)`` for the
      last REQUIRED slot, mark it INCLUDED and halve the stride;
    * otherwise (stride 1): `pivot` itself is the REQUIRED slot.

    `items` is modified in place.

    Args:
        items: mutable list of marker strings.
        debug: when true, print every probe, the per-hit state, and the
            final state.

    Returns:
        The number of predicate evaluations performed (0 for empty input).
    """
    count = 0

    def p(i):
        # Truthy while some REQUIRED marker remains at or after index `i`.
        nonlocal count
        if debug: print(f"{count: 4}", pretty(items[:i]) + "|" + pretty(items[i+1:]))
        count += 1
        return any(j == REQUIRED for j in items[i:])

    r = len(items)
    # Initial stride is 2 ** (floor(log2(r)) - 1), clamped to at least 1.
    # Integer bit arithmetic keeps it an int for r == 1 (a float stride would
    # produce a float slice index) and is well-defined for r == 0, where
    # math.log would raise.
    diff = 1 << max(0, r.bit_length() - 2)
    while r > 0:
        pivot = max(0, r - diff)
        if not p(pivot):
            r = pivot
            diff *= 2
            continue

        if debug:
            print()
            print(f"diff={diff:02}", f"{pivot} {r}", f"{count:03}")

        if diff > 1:
            # NOTE(review): the search starts at `pivot`, so p(pivot) may be
            # evaluated a second time here.
            r = binary_search(p, pivot, r)
            if r >= 0:
                items[r] = INCLUDED
            diff //= 2
        else:
            # Stride 1 means pivot == r - 1, the only slot left in [pivot, r).
            items[pivot] = INCLUDED
            r = pivot

        if debug:
            print()

    if debug:
        print(pretty(items))

    return count
    
    
    

In [None]:
# Debug trace on a fully required distribution: all 30 slots are REQUIRED.
stepped_binary_reduction(items=create_distribution(n=30, k=30), debug=True)

In [None]:
# Show the shared example, indented by four columns.
print(f"    {pretty(example)}")

In [None]:
# Debug trace of the plain binary reduction on a fully required distribution.
binary_reduction(items=list(create_distribution(n=30, k=30)), debug=True)

In [None]:
# Debug trace of the skewed reduction on a copy, leaving `example` untouched.
skewed_binary_reduction(items=list(example), debug=True)

In [None]:
# Experiment 1: keep k fixed and vary the input size n = k * x.
# The random module is not seeded here, so every run draws different inputs.
k = 40
ns, normal, skewed, stepped = [], [], [], []
for i in range(100):
    x = random.randrange(1, 500)
    n = k * x
    d = create_distribution(n, k)
    print(f"---- {i:3} {x} {n}")

    # Each strategy mutates its input, so give each one its own copy.
    bx, by, bz = list(d), list(d), list(d)
    normal.append(binary_reduction(bx))
    skewed.append(skewed_binary_reduction(by))
    stepped.append(stepped_binary_reduction(bz))
    ns.append(n)

    # All three strategies must end with the same final marking.
    assert bx == by
    assert bx == bz
    

In [None]:
import matplotlib.pyplot as plt

# Predicate calls against input size n for the fixed-k experiment.
plt.figure()
# Size the y-axis from all three series so none of them can be clipped.
plt.ylim(0, max(max(normal), max(skewed), max(stepped)) + 20)
plt.scatter(ns, normal, label="binary")
plt.scatter(ns, skewed, label="skewed")
plt.scatter(ns, stepped, label="stepped")
# Label the figure so it can be read without the surrounding cells.
plt.xlabel("n (number of items)")
plt.ylabel("predicate calls")
plt.title(f"Predicate calls per reduction, k = {k} required items")
plt.legend()


In [None]:
# Experiment 2: keep n fixed and vary the number of required slots k.
# The random module is not seeded here, so every run draws different inputs.
n = 3000
ks, normal, skewed, stepped = [], [], [], []
for i in range(100):
    k = random.randrange(1, n)
    d = create_distribution(n, k)
    print(f"---- {i:3} {k} {n}")

    # Each strategy mutates its input, so give each one its own copy.
    bx, by, bz = list(d), list(d), list(d)
    normal.append(binary_reduction(bx))
    skewed.append(skewed_binary_reduction(by))
    stepped.append(stepped_binary_reduction(bz))
    ks.append(k)

    # All three strategies must end with the same final marking.
    assert bx == by
    assert bx == bz
    

In [None]:
import matplotlib.pyplot as plt

# Predicate calls against the number of required slots k for the fixed-n experiment.
plt.figure()
# Size the y-axis from all three series so none of them can be clipped.
plt.ylim(0, max(max(normal), max(skewed), max(stepped)) + 20)
plt.scatter(ks, normal, label="binary")
plt.scatter(ks, skewed, label="skewed")
plt.scatter(ks, stepped, label="stepped")
# Label the figure so it can be read without the surrounding cells.
plt.xlabel("k (number of required items)")
plt.ylabel("predicate calls")
plt.title(f"Predicate calls per reduction, n = {n} items")
plt.legend()