# Problem 217: Balanced Numbers

A positive integer with $k$ (decimal) digits is called balanced if its first $\lceil k/2 \rceil$ digits sum to the same value as its last $\lceil k/2 \rceil$ digits, where $\lceil x \rceil$, pronounced <i>ceiling</i> of $x$, is the smallest integer $\ge x$, thus $\lceil \pi \rceil = 4$ and $\lceil 5 \rceil = 5$.

So, for example, all palindromes are balanced, as is $13722$.

Let $T(n)$ be the sum of all balanced numbers less than $10^n$.\
Thus: $T(1) = 45$, $T(2) = 540$ and $T(5) = 334795890$.

Find $T(47) \bmod 3^{15}$.

In [1]:
# ith digit of n, indexed from the last digit
from math import ceil, log10


def d(n: int, i: int) -> int:
    """Return the decimal digit of ``n`` at position ``i``.

    Positions count from the least significant digit, so position 0 is the
    units digit. Positions beyond the most significant digit yield 0.
    """
    if n == 0:
        return 0
    unit = 10**i
    # n reduced to its lowest i + 1 digits, and to its lowest i digits;
    # their difference isolates the digit at position i, scaled by `unit`.
    through_digit = n % 10 ** (i + 1)
    below_digit = n % unit
    return (through_digit - below_digit) // unit


def digit_sum(n: int) -> int:
    """Return the sum of the decimal digits of the non-negative integer ``n``.

    Digits are peeled off with exact integer arithmetic, so the result does
    not depend on the floating-point accuracy of ``log10`` for large values.

    Raises:
        ValueError: if ``n`` is negative.
    """
    if n < 0:
        raise ValueError("digit_sum() requires a non-negative integer")
    total = 0
    while n:
        # Strip the least significant digit on every pass.
        n, digit = divmod(n, 10)
        total += digit
    return total


def is_balanced(n: int) -> bool:
    """Return True if the positive integer ``n`` is balanced.

    A k-digit number is balanced when its first ceil(k/2) digits and its last
    ceil(k/2) digits have the same sum. For odd k the middle digit belongs to
    both halves, so it is enough to compare the outer floor(k/2) digits.

    The digits are extracted with exact integer arithmetic; counting them via
    ``int(log10(n)) + 1`` over-counts when ``log10`` rounds up to a whole
    number (e.g. for 10**15 - 1).

    Raises:
        ValueError: if ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError("is_balanced() requires a positive integer")
    digits = []
    rest = n
    while rest:
        rest, digit = divmod(rest, 10)
        digits.append(digit)
    half = len(digits) // 2
    # `digits` is least-significant first; the comparison is symmetric, so
    # the order does not matter. Slice from an explicit start index because
    # digits[-0:] would be the whole list when half == 0.
    return sum(digits[:half]) == sum(digits[len(digits) - half:])


# sanity check: digit extraction and digit sum first, then a few balanced numbers
print("Digits of 394002:", [d(394002, i) for i in range(6)])
print("Digit sum of 394002:", digit_sum(394002))
for _label, _sample in (
    ("is_balanced(5):", 5),
    ("is_balanced(11):", 11),
    ("is_balanced(141):", 141),
    ("is_balanced(2314):", 2314),
    ("is_balanced(13722):", 13722),
):
    print(_label, is_balanced(_sample))

Digits of 394002: [2, 0, 0, 4, 9, 3]
Digit sum of 394002: 18
is_balanced(5): True
is_balanced(11): True
is_balanced(141): True
is_balanced(2314): True
is_balanced(13722): True


In [2]:
# sum of all balanced numbers < 10**n (brute force)
def T_bf(n: int) -> int:
    """Return the sum of all balanced numbers in ``1 .. 10**n - 1``.

    Brute-force reference: every candidate is tested with ``is_balanced``.
    """
    return sum(
        candidate for candidate in range(1, 10**n) if is_balanced(candidate)
    )


# Print the brute-force totals for 1 to 6 digits, one line per digit count.
for _label, _exponent in (
    ("T_bf(1) = ", 1),
    ("T_bf(2) = ", 2),
    ("T_bf(3) = ", 3),
    ("T_bf(4) = ", 4),
    ("T_bf(5) = ", 5),
    ("T_bf(6) = ", 6),
):
    print(_label, T_bf(_exponent))

T_bf(1) =  45
T_bf(2) =  540
T_bf(3) =  50040
T_bf(4) =  3364890
T_bf(5) =  334795890
T_bf(6) =  27671338200


In [3]:
# sum of all numbers < 10**dmax where the digits sum to sdigit (brute force)
def bf_s(sdigit: int, dmax: int) -> int:
    """Return the sum of all integers in ``1 .. 10**dmax - 1`` with digit sum ``sdigit``.

    Brute-force reference. Returns 0 without scanning when ``sdigit`` is not
    in ``1 .. 9 * dmax``.
    """
    if not 0 < sdigit <= 9 * dmax:
        return 0
    matching = (k for k in range(1, 10**dmax) if digit_sum(k) == sdigit)
    return sum(matching)


# Print a few brute-force digit-sum totals for small inputs.
for _label, _args in (
    ("bf_s(1, 1)", (1, 1)),
    ("bf_s(8, 1)", (8, 1)),
    ("bf_s(2, 2)", (2, 2)),
    ("bf_s(2, 5)", (2, 5)),
    ("bf_s(3, 3)", (3, 3)),
):
    print(_label, bf_s(*_args))

bf_s(1, 1) 1
bf_s(8, 1) 8
bf_s(2, 2) 33
bf_s(2, 5) 66666
bf_s(3, 3) 1110


In [4]:
from functools import cache
from typing import Tuple


# sum of all numbers < 10**ndigit where the digits sum to sdigit
# returns (count, sum): how many integers make up the sum, then the sum itself
@cache
def s(sdigit: int, ndigit: int) -> Tuple[int, int]:
    """Count and sum the integers in ``0 .. 10**ndigit - 1`` with digit sum ``sdigit``.

    Equivalently, this ranges over all digit strings of length ``ndigit``
    (leading zeros allowed) whose digits add up to ``sdigit``.

    Args:
        sdigit: required digit sum.
        ndigit: maximum number of decimal digits.

    Returns:
        ``(count, total)``: how many such integers exist, and their sum.
        ``(0, 0)`` when no integer qualifies, including ``sdigit > 0`` with
        ``ndigit == 0``.
    """
    # ndigit digits can only produce sums in 0 .. 9 * ndigit. This also
    # stops the recursion for negative ndigit.
    if sdigit < 0 or sdigit > 9 * ndigit:
        return 0, 0
    if ndigit == 0:
        # Only sdigit == 0 reaches this point: the empty digit string, i.e. 0.
        return 1, 0
    place = 10 ** (ndigit - 1)
    count, total = 0, 0
    # Choose the leading digit, then recurse on the remaining ndigit - 1 digits.
    for digit in range(min(sdigit, 9) + 1):
        sub_count, sub_total = s(sdigit - digit, ndigit - 1)
        # Each of the sub_count tails contributes the leading digit once.
        total += sub_count * digit * place + sub_total
        count += sub_count
    return count, total


# Start from an empty memo table, then compare against the brute-force values.
s.cache_clear()
for _label, _args in (
    ("s(1, 1)", (1, 1)),
    ("s(8, 1)", (8, 1)),
    ("s(2, 2)", (2, 2)),
    ("s(2, 5)", (2, 5)),
    ("s(3, 3)", (3, 3)),
    ("s(100, 20)", (100, 20)),
):
    print(_label, s(*_args))

s(1, 1) (1, 1)
s(8, 1) (1, 8)
s(2, 2) (3, 33)
s(2, 5) (15, 66666)
s(3, 3) (10, 1110)
s(100, 20) (2295217152050316613, 127512064002795367387613768248860935215)


In [5]:
# sum of all balanced numbers with ndigit digits
def t(ndigit: int) -> int:
    """Return the sum of all balanced numbers with exactly ``ndigit`` digits.

    A number is split into a leading half and a trailing half of
    ``ndigit // 2`` digits each, plus a free middle digit when ``ndigit`` is
    odd. For every shared digit sum, the counts and sums of both halves
    (taken from ``s``) are combined in closed form.
    """
    if ndigit == 1:
        return 45
    half = ndigit // 2
    shift = 10**half
    even_length = ndigit % 2 == 0
    acc = 0
    for target in range(1, 9 * half + 1):
        # Trailing half: any digit string of length `half`, leading zeros allowed.
        tail_count, tail_sum = s(target, half)
        # Strings that start with 0 are the ones that fit in half - 1 digits.
        if half > 1:
            short_count, short_sum = s(target, half - 1)
        else:
            short_count, short_sum = 0, 0
        # Leading half: same digit sum, but no leading zero.
        head_count = tail_count - short_count
        head_sum = tail_sum - short_sum
        if even_length:
            acc += tail_count * shift * head_sum + head_count * tail_sum
        else:
            # The middle digit takes all ten values (0 + 1 + ... + 9 = 45),
            # which multiplies the tail contribution by 10.
            acc += (
                tail_count * shift * (100 * head_sum + head_count * 45)
                + 10 * head_count * tail_sum
            )
    return acc


def T(dmax: int) -> int:
    """Return the sum of all balanced numbers with 1 to ``dmax`` digits, i.e. below ``10**dmax``."""
    grand_total = 0
    for length in range(1, dmax + 1):
        grand_total += t(length)
    return grand_total


# sanity checks against the brute-force totals, then the requested answer
for _label, _digits in (
    ("T(1) = ", 1),
    ("T(2) = ", 2),
    ("T(3) = ", 3),
    ("T(4) = ", 4),
    ("T(5) = ", 5),
    ("T(6) = ", 6),
    ("T(47) = ", 47),
):
    print(_label, T(_digits))
print("T(47) mod 3**15 = ", T(47) % 3**15)

T(1) =  45
T(2) =  540
T(3) =  50040
T(4) =  3364890
T(5) =  334795890
T(6) =  27671338200
T(47) =  102046816840408378801546174937467503924727519749183651680790418160257686837140983269435948860
T(47) mod 3**15 =  6273134
