In [166]:
data = """11-22,95-115,998-1012,1188511880-1188511890,222220-222224,1698522-1698528,446443-446449,38593856-38593862,565653-565659,824824821-824824827,2121212118-2121212124"""

with open("input.txt", "r") as f:
    data = f.read().strip()

ranges = [tuple(map(int, p.split("-"))) for p in data.split(",")]

In [167]:
from math import log10, floor

def num_digits(x: int) -> int:
    return floor(log10(x)) + 1

def sum_invalid_ids_pt1(low : int, high : int) -> int:
    L = num_digits(low)
    if num_digits(high) > L:
        return sum_invalid_ids_pt1(low, 10**L - 1) + sum_invalid_ids_pt1(10**L, high)
    if L % 2 == 1:
        return 0
    
    total = 0
    low_half = int(low // 10**(L/2))
    high_half = int(high // 10**(L/2))
    for n in range(low_half, high_half+1):
        N = int(n * 10**(L/2)) + n
        if low <= N <= high:
            total += N

    return total

print(sum(sum_invalid_ids_pt1(*r) for r in ranges))


43952536386


In [234]:
%%timeit

def sum_invalid_ids_pt2(low: int, high: int) -> int:
    
    L = num_digits(low)
    if num_digits(high) > L:
        return sum_invalid_ids_pt2(low, 10**L - 1) + sum_invalid_ids_pt2(10**L, high)
    if L == 1:
        return 0
    
    # Look only at pattern lengths l that divide L and do not have a multiple (duplicates):
    ll = {l for l in range(1, max(1, L // 2 + 1)) if (L % l) == 0}
    ll = {l for l in ll if all(l*i not in ll for i in range(2, max(ll)+1))}

    total = 0
    for l in ll:
        low_q = low // 10 ** (L - l)
        high_q = high // 10 ** (L - l)

        # Consider sequence of values repeated patterns based of low_q ... high_q:
        
        # Treat the first and last values specially
        first = sum(low_q * 10 ** (p * l) for p in range(L // l))
        if low <= first <= high:
            total += first
            
        if low_q < high_q:
            last = sum(high_q * 10 ** (p * l) for p in range(L // l))
            if low <= last <= high:
                total += last
        
            if low_q < high_q - 1:
                # Calculate the sum of the middle values using arithmetic series
                S = ((high_q - low_q - 1) * (high_q + low_q) // 2)
                total += sum(S * 10 ** (p * l) for p in range(L // l))
                
    # treat same-digit numbers (present in all pattern lengths)
    for d in range(1, 10):
        if low <= int(str(d) * L) <= high:
            total -= int(str(d) * L) * (len(ll) - 1)
    

    return total

part2 = sum(sum_invalid_ids_pt2(*r) for r in ranges)
# print(part2)


117 μs ± 1.17 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
%%timeit

# Simpler O(n) version:

invalid_ids = set()

def sum_invalid_ids_pt2(low: int, high: int) -> int:
    L = num_digits(low)
    if num_digits(high) > L:
        return sum_invalid_ids_pt2(low, 10**L - 1) + sum_invalid_ids_pt2(10**L, high)

    total = 0
    for l in range(1, L // 2 + 1):
        if L % l != 0:
            continue
        low_q = low // 10 ** (L - l)
        high_q = high // 10 ** (L - l)
        for n in range(low_q, high_q + 1):
            N = sum(n * 10 ** (p * l) for p in range(L // l))
            if low <= N <= high:
                if N in invalid_ids:
                    # Avoid double counting on things like 42424242
                    continue
                invalid_ids.add(N)
                total += N

    return total

# print(sum(sum_invalid_ids_pt2(*r) for r in ranges))

73.2 ns ± 0.424 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
