In [50]:
from utils import read_lines
import numpy as np

def read_input(input_file):
    """Parse the schedule file into (timestamp, bus IDs).

    Line 1 is converted to an int timestamp. Line 2 is split on commas; every
    'x' entry is dropped and the remaining entries are converted to int, in
    their original order.
    """
    lines = read_lines(input_file)
    timestamp = int(lines[0])
    bus_ids = []
    for token in lines[1].split(','):
        if token == 'x':
            continue
        bus_ids.append(int(token))
    return timestamp, bus_ids

def part1(input_file):
    """Return (bus ID * minutes to wait) for the first bus departing at or after ts.

    Each bus ``bid`` departs at every multiple of ``bid``. The wait from ``ts``
    is ``-ts % bid``, which lies in ``range(bid)`` and is 0 when ``ts`` is
    already a multiple of ``bid`` (the bus leaves immediately). On ties, the
    bus listed first wins.

    Returns 0 if the schedule lists no buses.
    """
    ts, buses = read_input(input_file)
    if not buses:
        return 0

    def wait_for(bid):
        # Python's % is non-negative for positive bid, so this is the exact
        # number of minutes until bid's next departure (0 if it departs at ts).
        return -ts % bid

    # min() returns the first minimal element, matching a strict-< scan.
    best_bus = min(buses, key=wait_for)
    return best_bus * wait_for(best_bus)

def parse_remainders(nums):
    """Map each bus ID in a comma-separated schedule to its required remainder.

    For the bus ``num`` at position ``i``, the stored remainder is ``(-i) % num``.
    Any ``t`` with ``t % num == remainders[num]`` therefore satisfies
    ``(t + i) % num == 0``.

    Args:
        nums: comma-separated schedule such as ``'7,13,x,x,59'``. ``'x'``
            entries are skipped. Whitespace around each entry is ignored, so a
            trailing newline does not break parsing.

    Returns:
        dict mapping bus ID -> remainder in ``range(bus ID)``, in schedule order.
    """
    remainders = {}
    for i, token in enumerate(nums.split(',')):
        token = token.strip()
        if token == 'x':
            continue
        num = int(token)
        # Python's % is always in [0, num) for positive num. That makes this
        # equal to "num - i % num, or 0 when i is a multiple of num".
        remainders[num] = (-i) % num
    return remainders

# Chinese remainder theorem (CRT) solver.
def solve(remainders):
    """Return the smallest t >= 0 with ``t % mod == rem`` for every (mod, rem).

    Uses the constructive CRT. The moduli must be pairwise coprime. All
    arithmetic uses Python ints, so results are exact for arbitrarily large
    moduli.

    Args:
        remainders: dict mapping modulus -> required remainder.

    Returns:
        int in ``range(product of moduli)``; 0 for an empty dict.

    Raises:
        ValueError: if the moduli are not pairwise coprime, because then some
            ``prod // mod`` has no inverse modulo ``mod``.
    """
    # Local import: math.prod gives an exact big-int product. np.prod yields
    # a fixed-width int64 that silently wraps past 2**63.
    import math

    prod = math.prod(remainders)
    ans = 0
    for mod, rem in remainders.items():
        others = prod // mod
        # Modular inverse of `others` modulo `mod`, in O(log mod). A linear
        # search would be O(mod) and would never terminate for mod == 1 or
        # non-coprime moduli.
        inv = pow(others, -1, mod)
        ans += others * inv * rem
    return ans % prod

def check(remainders, num):
    """Verify that num satisfies every congruence ``num % mod == rem``.

    Congruences are checked in dict order. On the first failure, a diagnostic
    line is printed and False is returned. If all of them hold (including
    when the dict is empty), True is returned.
    """
    for modulus, expected in remainders.items():
        actual = num % modulus
        if actual == expected:
            continue
        print(f'{num} % {modulus} = {actual} != {expected}')
        return False
    return True

def part2(input_file):
    """Return the CRT solution for the bus schedule on line 2 of input_file."""
    schedule = read_lines(input_file)[1]
    return solve(parse_remainders(schedule))

In [7]:
# Part 1 on the small test input; the cell displays the returned value.
part1('inputs/day13_test.txt')

295

In [8]:
# Part 1 on the full puzzle input.
part1('inputs/day13.txt')

153

In [51]:
# Part 2 on the full puzzle input.
part2('inputs/day13.txt')

471793476184394

In [49]:
# Cross-check the CRT solver on a small sample schedule. This prints the
# parsed remainders and the solution, then verifies the solution with check(),
# which prints any mismatching congruence.
remainders = parse_remainders('7,13,x,x,59,x,31,19')
print(remainders)
s = solve(remainders)
print(s)
check(remainders, s)

{7: 0, 13: 12, 59: 55, 31: 25, 19: 12}
1068781


True