# **Follow The Leading History 2** Algorithm

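This notebook explores the pruning scheme behind the Follow The Leading History 2 (FLH2) algorithm from Elad Hazan's book *Introduction to Online Convex Optimization*. FLH2 is a regret-minimizing algorithm that stays efficient by keeping only a logarithmic number of indices "alive" at each time step.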


## Imports

We'll start by importing some required packages.


In [1]:
import math

## Lifetimes

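Lifetimes are the key concept behind pruning: every index `i` is assigned a lifetime and is kept in the alive set only for the time steps `i, ..., i + lifetime(i) - 1`. Discarding indices once their lifetime has expired keeps the algorithm memory-efficient.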


In [2]:
def lifetime(i):
    """
    Return the lifetime assigned to index i.

    Writing i = r * 2**k with r odd, the lifetime is 2**(k + 2) + 1.
    The index 0 has no such decomposition and is given a lifetime of 1.

    Args:
    i (int): The index whose lifetime is requested.

    Returns:
    int: The lifetime of i.
    """
    if i == 0:
        return 1
    # i & -i isolates the lowest set bit (a power of two). Its bit_length
    # minus one is the position of that bit, obtained with exact integer
    # arithmetic instead of a floating-point logarithm.
    k = (i & -i).bit_length() - 1
    return 2 ** (k + 2) + 1

def is_alive(i, t):
    """
    Tell whether index i belongs to the alive set at time t.

    Index i is alive from time i through time i + lifetime(i) - 1, inclusive.

    Args:
    i (int): The index being queried.
    t (int): The time step.

    Returns:
    bool: True if i is alive at time t, False otherwise.
    """
    if t < i:
        # Index i has not been added yet at time t.
        return False
    last_alive = i + lifetime(i) - 1
    return t <= last_alive

def naive_pruning(T):
    """
    Build every alive set by brute force.

    For each time step, every candidate index up to that step is tested
    with is_alive and kept if it is still alive.

    Args:
    T (int): The number of time steps to run the algorithm for.

    Returns:
    list: A list of sets, where the t-th set (0-indexed) represents S_{t+1}.
    """
    return [
        {i for i in range(1, step + 1) if is_alive(i, step)}
        for step in range(1, T + 1)
    ]

def pruning(T):
    """
    Efficient pruning algorithm function without sequential looping over all integers.

    Every index i >= 1 can be written uniquely as i = r * 2**k with r odd. It
    joins the alive set at time i and stays for 2**(k + 2) + 1 time steps.
    The expiry time of each index is precomputed, one value of k at a time,
    and the alive set is then updated incrementally from one time step to
    the next.

    Args:
    T (int): The number of time steps to run the algorithm for. Values
        below 1 yield an empty list.

    Returns:
    list: A list of sets, where the t-th set (0-indexed) represents S_{t+1}.
    """
    # expiring[t] lists the indices that must leave the alive set at time t.
    expiring = {}

    low_bit = 1  # 2**k: the lowest set bit shared by the indices handled below
    while low_bit <= T:
        span = 4 * low_bit + 1  # lifetime 2**(k + 2) + 1
        # Indices whose lowest set bit is 2**k: 2**k, 3 * 2**k, 5 * 2**k, ...
        for i in range(low_bit, T + 1, 2 * low_bit):
            t_remove = i + span  # first time step at which i is no longer alive
            if t_remove <= T:
                # Expiries beyond the horizon never affect the output.
                expiring.setdefault(t_remove, []).append(i)
        low_bit *= 2

    alive_sets = []
    current_set = set()
    for t in range(1, T + 1):
        current_set.add(t)  # index t always joins at its own time step
        current_set.difference_update(expiring.get(t, ()))
        # Store a snapshot so later updates do not alter earlier results.
        alive_sets.append(current_set.copy())

    return alive_sets

def print_table(alive_sets):
    """
    Print a table with one row per time step.

    Each row shows the time step, the sorted alive set, the indices that are
    new compared with the previous step, and the indices that were removed.
    An empty input prints only the header and the separator line.

    Args:
    alive_sets (list): A list of sets, where the t-th set (0-indexed)
        represents S_{t+1}.
    """
    # Render each set once; the text is reused for the width and for the row.
    rendered = [str(sorted(s)) for s in alive_sets]
    # With no rows, fall back to the width of the column label.
    max_width = max((len(text) for text in rendered), default=len("Alive Set"))

    print(f"{'Time':5} | {'Alive Set':{max_width}} | New | Removed")
    print("-" * (max_width + 30))

    previous = set()  # the set before time step 1 is empty
    for t, (current, text) in enumerate(zip(alive_sets, rendered), start=1):
        new = current - previous
        removed = previous - current

        print(f"{t:5} | {text:{max_width}} | {sorted(new)} | {sorted(removed)}")
        previous = current

def verify_property_1(alive_sets):
    """
    Check Property 1 on every alive set.

    For each time step t and each s in 1..t, the alive set S_t must contain
    at least one index in the interval [s, (s + t) // 2].

    Args:
    alive_sets (list): A list of sets, where the t-th set (0-indexed)
        represents S_{t+1}.

    Raises:
    AssertionError: If some interval contains no alive index.
    """
    for t, S_t in enumerate(alive_sets, start=1):
        for s in range(1, t + 1):
            upper = (s + t) // 2
            # Compare the members of S_t with the interval bounds directly
            # instead of materialising the whole interval as a set: the work
            # per (t, s) pair is then proportional to |S_t|, not to t - s.
            assert any(s <= i <= upper for i in S_t), f"Property 1 violated at t={t}, s={s}"

def verify_property_2(alive_sets):
    """
    Check Property 2: the alive sets stay logarithmically small.

    S_1 must contain exactly one index; for t > 1 the size of S_t must not
    exceed 2 * log2(t) + 3.

    Args:
    alive_sets (list): A list of sets, where the t-th set (0-indexed)
        represents S_{t+1}.

    Raises:
    AssertionError: If a set violates its size requirement.
    """
    for t, S_t in enumerate(alive_sets, start=1):
        size = len(S_t)
        if t > 1:
            bound = 2 * math.log2(t) + 3
            assert size <= bound, f"Property 2 violated at t={t}: |S_t| = {size} > {bound}"
            continue
        # The first time step is held to an exact size rather than the bound.
        assert size == 1, f"Property 2 violated at t=1: |S_t| = {size} != 1"

def verify_property_3(alive_sets):
    """
    Check Property 3: between consecutive steps only the new index is added.

    For every t >= 2, the indices present in S_t but absent from S_{t-1}
    must be exactly {t}.

    Args:
    alive_sets (list): A list of sets, where the t-th set (0-indexed)
        represents S_{t+1}.

    Raises:
    AssertionError: If some step adds anything other than its own index.
    """
    consecutive = zip(alive_sets, alive_sets[1:])
    for t, (earlier, later) in enumerate(consecutive, start=2):
        added = later - earlier
        assert added == {t}, f"Property 3 violated at t={t}"

def prove_logarithmic_size(alive_sets):
    """
    Check that each alive set can be counted one power of two at a time.

    For every time step t and every k in 0..floor(log2(t)), the indices of
    the form r * 2**k (r odd) inside a window ending at t are tested with
    is_alive. The per-k counts, summed over k, must equal |S_t|.

    Args:
    alive_sets (list): A list of sets, where the t-th set (0-indexed)
        represents S_{t+1}.

    Raises:
    AssertionError: If the summed count differs from the size of a set.
    """
    for t, S_t in enumerate(alive_sets, start=1):
        count = 0
        # t.bit_length() == floor(log2(t)) + 1, so k runs over 0..floor(log2(t)).
        for k in range(t.bit_length()):
            step = 2 ** k
            period = 2 * step  # odd multiples of 2**k are spaced 2**(k + 1) apart
            lower = max(1, t - 2 ** (k + 2) - 1)
            # Smallest i >= lower with i % period == step, i.e. the first odd
            # multiple of 2**k in the window. Stepping from there visits only
            # the candidates rather than every integer in the window.
            first = lower + (step - lower) % period
            count += sum(1 for i in range(first, t + 1, period) if is_alive(i, t))
        assert count == len(S_t), f"Logarithmic size proof failed at t={t}"

def compare_alive_sets(alive_sets1, alive_sets2):
    """
    Compare two sequences of alive sets time step by time step.

    The first differing time step, if any, is printed together with both
    sets. Sequences of different length are reported as different even when
    they agree on their common prefix.

    Args:
    alive_sets1 (list): First list of alive sets.
    alive_sets2 (list): Second list of alive sets.

    Returns:
    bool: True if both lists have the same length and equal sets at every
    time step, False otherwise.
    """
    for t, (S1, S2) in enumerate(zip(alive_sets1, alive_sets2), start=1):
        if S1 != S2:
            print(f"Difference at time t={t}:")
            print(f"S1: {sorted(S1)}")
            print(f"S2: {sorted(S2)}")
            return False
    # zip stops at the shorter input, so a length mismatch must be checked
    # explicitly or it would go unnoticed.
    if len(alive_sets1) != len(alive_sets2):
        print(f"Different number of time steps: {len(alive_sets1)} vs {len(alive_sets2)}")
        return False
    return True

# Number of time steps to simulate
T = 8192

# Build the alive sets with both implementations
alive_sets_naive = naive_pruning(T)
alive_sets_pruned = pruning(T)

# Report whether the two implementations agree
print(
    "Alive sets match between naive_pruning and pruning methods."
    if compare_alive_sets(alive_sets_naive, alive_sets_pruned)
    else "Alive sets do not match between methods."
)

# Either result can be used from here on (swap in alive_sets_naive if desired)
alive_sets = alive_sets_pruned

# Uncomment the next line to print the table of alive sets
#print_table(alive_sets)

# Check the three properties, then the size-counting argument
verify_property_1(alive_sets)
verify_property_2(alive_sets)
verify_property_3(alive_sets)
prove_logarithmic_size(alive_sets)

print("All properties verified and proofs passed!")


Alive sets match between naive_pruning and pruning methods.
