In [32]:
from utils import read_lines
from itertools import permutations
from collections import defaultdict

def parse_program(s):
    return [int(x) for x in s.split(',')]

def parse_code(num):
    if num == 99:
        return num, [0]
    else:
        digits = []
        while num > 0:
            digits.append(num % 10)
            num //=10
        digits += [0] * (5-len(digits))
        opcode = digits[0]
        param_modes = digits[2:]
        return opcode, param_modes

def run(program, input=0):
    i = 0
    output = 0
    extend_memory = defaultdict(int)
    relative_base = 0
    def read_memory(addr):
        if addr < len(program):
            return program[addr]
        else:
            return extend_memory[addr]
        
    def write_memory(idx, value, mode):
        addr = program[idx]
        if mode == 2:
            addr = relative_base + addr
        if addr < len(program):
            program[addr] = value
        else:
            extend_memory[addr] = value
    
    def get_oprand(idx, mode):
        if mode == 1:
            return program[idx]
        if mode == 2:
            addr = relative_base + program[idx]
        else:
            addr = program[idx]
        return read_memory(addr)

    while i < len(program):
        opcode, param_modes = parse_code(program[i])
        if opcode in (1, 2, 5, 6, 7, 8):
            oprand1 = get_oprand(i+1, param_modes[0])
            oprand2 = get_oprand(i+2, param_modes[1])
    
            if opcode == 1:
                write_memory(i+3, oprand1 + oprand2, param_modes[2])
                i+=4
            elif opcode == 2:
                write_memory(i+3, oprand1 * oprand2, param_modes[2])
                i+=4
            elif opcode == 5:
                if oprand1 != 0:
                    i = oprand2
                else:
                    i += 3
            elif opcode == 6:
                if oprand1 == 0:
                    i = oprand2
                else:
                    i += 3
            elif opcode == 7:
                write_memory(i+3, 1 if oprand1 < oprand2 else 0, param_modes[2])
                i+=4
            elif opcode == 8:
                write_memory(i+3, 1 if oprand1 == oprand2 else 0, param_modes[2])
                i+=4
        elif opcode ==3 :
            write_memory(i+1, input, param_modes[0])
            i += 2
        elif opcode == 4:
            oprand1 = get_oprand(i+1, param_modes[0])
            output = oprand1
            print(output)
            i += 2
        elif opcode == 9:
            oprand1 = get_oprand(i+1, param_modes[0])
            relative_base += oprand1
            i+=2
        elif program[i] == 99:
            return output
        else:
            raise ValueError(f'illegal code {i} {program[i]}')
        
def part1(input_file):
    s = read_lines(input_file)[0]
    program = parse_program(s)
    return run(program, 1)

def part2(input_file):
    s = read_lines(input_file)[0]
    program = parse_program(s)
    return run(program, 2)

In [31]:
part1('inputs/day9.txt')

2870072642


2870072642

In [33]:
part2('inputs/day9.txt')

58534


58534

In [24]:
s = '109,1,204,-1,1001,100,1,100,1008,100,16,101,1006,101,0,99'
program = parse_program(s)
assert run(program) == 99

s = '1102,34915192,34915192,7,4,7,99,0'
program = parse_program(s)
assert len(str(run(program))) == 16

109
1
204
-1
1001
100
1
100
1008
100
16
101
1006
101
0
99
1219070632396864


In [27]:
s = '3,9,8,9,10,9,4,9,99,-1,8'
program = parse_program(s)
assert run(program.copy(), 8) == 1
assert run(program, 9) == 0

s = '3,9,7,9,10,9,4,9,99,-1,8'
program = parse_program(s)
assert run(program.copy(), 7) == 1
assert run(program, 8) == 0

s = '3,3,1108,-1,8,3,4,3,99'
program = parse_program(s)
assert run(program.copy(), 8) == 1
assert run(program, 9) == 0

s = '3,3,1107,-1,8,3,4,3,99'
program = parse_program(s)
assert run(program.copy(), 7) == 1
assert run(program, 8) == 0

s = '3,12,6,12,15,1,13,14,13,4,13,99,-1,0,1,9'
program = parse_program(s)
assert run(program.copy(), 0) == 0
assert run(program, 1) == 1

s = '3,3,1105,-1,9,1101,0,0,12,4,12,99,1'
program = parse_program(s)
assert run(program.copy(), 0) == 0
assert run(program, 1) == 1

s = '3,21,1008,21,8,20,1005,20,22,107,8,21,20,1006,20,31,1106,0,36,98,0,0,1002,21,125,20,4,20,1105,1,46,104,999,1105,1,46,1101,1000,1,20,4,20,1105,1,46,98,99'
program = parse_program(s)
assert run(program.copy(), 7) == 999
assert run(program.copy(), 8) == 1000
assert run(program.copy(), 9) == 1001

1
0
1
0
1
0
1
0
0
1
0
1
999
1000
1001
