In [166]:
# Worked example ranges; this value is immediately replaced by the contents
# of input.txt below and is kept only as a reference.
data = """11-22,95-115,998-1012,1188511880-1188511890,222220-222224,1698522-1698528,446443-446449,38593856-38593862,565653-565659,824824821-824824827,2121212118-2121212124"""

# Read the whole input file, dropping leading/trailing whitespace.
with open("input.txt", "r") as f:
    data = f.read().strip()

# Entries are comma-separated "low-high" pairs; turn each one into a tuple
# of ints.
ranges = [
    tuple(int(bound) for bound in entry.split("-"))
    for entry in data.split(",")
]

In [None]:
from math import log10, floor

def num_digits(x: int) -> int:
    """Return the number of decimal digits in ``x``.

    The count is taken from the decimal string of ``abs(x)`` instead of
    ``floor(log10(x)) + 1``. The logarithm is computed in floating point and
    can round up to the next integer for values just below a large power of
    ten (for example 999_999_999_999_999), which over-counts by one; it also
    raises for 0. The string length is exact for any int.

    Args:
        x: The integer to measure. The sign is ignored.

    Returns:
        The number of decimal digits; 0 is considered to have one digit.
    """
    return len(str(abs(x)))

43952536386


In [248]:
# Simpler O(n) version:

def sum_invalid_ids(low: int, high: int, part1: bool = True) -> int:
    """Sum the IDs in ``[low, high]`` whose digits are one block repeated.

    With ``part1=True`` only IDs made of a block repeated exactly twice are
    considered (so the digit count must be even). With ``part1=False`` every
    block length that divides the digit count and repeats at least twice is
    tried.

    The module-level set ``invalid_ids`` is used for de-duplication: an ID
    already in the set is skipped, and every newly counted ID is added to it.

    Args:
        low: Inclusive lower bound of the range.
        high: Inclusive upper bound of the range.
        part1: Select the "repeated exactly twice" rule when true.

    Returns:
        The sum of the IDs newly counted by this call.
    """
    width = num_digits(low)

    # A range spanning several digit counts is handled one digit count at a
    # time: first the part sharing low's width, then the remainder.
    if num_digits(high) > width:
        boundary = 10**width
        same_width_sum = sum_invalid_ids(low, boundary - 1, part1)
        wider_sum = sum_invalid_ids(boundary, high, part1)
        return same_width_sum + wider_sum

    if part1:
        if width % 2:
            # An odd number of digits cannot be split into two equal halves.
            return 0
        block_sizes = [width // 2]
    else:
        # Every proper divisor of the width is a possible block length.
        # (Lengths that divide another candidate length are redundant; the
        # shared set below makes keeping them harmless.)
        block_sizes = {
            size for size in range(1, width // 2 + 1) if width % size == 0
        }

    total = 0
    for size in block_sizes:
        # Dropping the last (width - size) digits leaves the leading block.
        shift = 10 ** (width - size)
        # 1 + 10**size + 10**(2*size) + ...: multiplying a block by this
        # value writes the block (width // size) times in a row.
        repeater = sum(10 ** (k * size) for k in range(width // size))

        for block in range(low // shift, high // shift + 1):
            candidate = block * repeater
            if not low <= candidate <= high:
                continue
            if candidate in invalid_ids:
                # Already counted, e.g. 42424242 is both "42" x 4 and
                # "4242" x 2.
                continue
            invalid_ids.add(candidate)
            total += candidate

    return total


# Part 1: start from an empty de-duplication set (sum_invalid_ids adds every
# ID it counts to this module-level set).
invalid_ids = set()
part1 = sum([sum_invalid_ids(*bounds, part1=True) for bounds in ranges])

# Part 2: reset the set first so IDs recorded during part 1 are not skipped.
invalid_ids = set()
part2 = sum([sum_invalid_ids(*bounds, part1=False) for bounds in ranges])

# Keep the part 2 total under a second name, then display it as the cell
# output.
debug = part2
part2

54486209192

In [None]:

# Optimized O(log n) version for part 2:

def _mobius(n: int) -> int:
    """Return the Moebius function mu(n) for n >= 1, using trial division."""
    sign = 1
    factor = 2
    while factor * factor <= n:
        if n % factor == 0:
            n //= factor
            if n % factor == 0:
                return 0  # a squared prime factor makes mu(n) zero
            sign = -sign
        factor += 1
    if n > 1:
        sign = -sign  # one prime factor is left over
    return sign


def _sum_repeated_blocks(low: int, high: int, width: int, block: int) -> int:
    """Sum the width-digit values in [low, high] made of one block-digit chunk repeated.

    ``low`` and ``high`` must both be ``width``-digit numbers and ``block``
    must divide ``width``.
    """
    # 1 + 10**block + 10**(2*block) + ...: chunk * repeater writes the chunk
    # (width // block) times in a row.
    repeater = sum(10 ** (k * block) for k in range(width // block))

    # chunk * repeater grows with chunk, so the chunks that land inside
    # [low, high] form one contiguous run; clamp it to block-digit values.
    first = max(-(-low // repeater), 10 ** (block - 1))  # ceiling division
    last = min(high // repeater, 10**block - 1)
    if first > last:
        return 0

    # Arithmetic series first + (first + 1) + ... + last, scaled by repeater.
    return repeater * (first + last) * (last - first + 1) // 2


def sum_invalid_ids_pt2(low: int, high: int) -> int:
    """Sum every ID in ``[low, high]`` whose digits are a block repeated >= 2 times.

    Closed-form counterpart of ``sum_invalid_ids(..., part1=False)``: no
    candidate is enumerated, and no global state is read or written.

    A ``width``-digit number built from ``r`` copies of a block (``r``
    dividing ``width``) may also be expressible with other repeat counts, so
    the per-repeat-count sums overlap. Weighting the sum for ``r`` repeats by
    ``-mu(r)`` (Moebius function) is the inclusion-exclusion over the prime
    factors of ``width`` and counts each ID exactly once. Subtracting only
    the single-digit repeats is not enough in general: for 12 digits the
    block lengths 4 and 6 also share IDs such as 121212121212.

    Digit counts are taken from the decimal string, so the result is exact
    for arbitrarily large ints.

    Args:
        low: Inclusive lower bound; must be non-negative.
        high: Inclusive upper bound. If ``high < low`` the result is 0.

    Returns:
        The sum of all matching IDs in the range, each counted once.

    Raises:
        ValueError: If ``low`` is negative.
    """
    if low < 0:
        raise ValueError("low must be non-negative")

    total = 0
    # Handle one digit count at a time, clamping the range to that width.
    for width in range(len(str(low)), len(str(high)) + 1):
        if width < 2:
            continue  # a single digit cannot contain a repeated block
        lo = max(low, 10 ** (width - 1))
        hi = min(high, 10**width - 1)

        for repeats in range(2, width + 1):
            if width % repeats:
                continue
            mu = _mobius(repeats)
            if mu:
                total -= mu * _sum_repeated_blocks(lo, hi, width, width // repeats)

    return total

# Part 2 total from the closed-form implementation, summed over every range.
part2 = sum([sum_invalid_ids_pt2(*bounds) for bounds in ranges])

# Guard against regressions: the total must equal this hard-coded value.
assert part2 == 54486209192
