In [50]:
from utils import read_lines

def parse_program(s):
    """Turn a comma-separated Intcode listing into a list of ints.

    Surrounding whitespace on each field (e.g. a trailing newline) is
    tolerated because ``int`` strips it; an empty field raises ValueError.
    """
    return list(map(int, s.split(',')))

def parse_code(num):
    """Split an Intcode instruction value into its opcode and parameter modes.

    The opcode is the two rightmost decimal digits. The next three digits,
    read right to left, are the modes of parameters 1, 2 and 3. Missing
    leading digits count as 0.

    Args:
        num: The instruction value, e.g. ``1002``.

    Returns:
        A tuple ``(opcode, param_modes)`` where ``param_modes`` always has
        three entries, e.g. ``parse_code(1002) == (2, [0, 1, 0])``.

    Raises:
        ValueError: if ``num`` is negative, which cannot encode an opcode.
    """
    if num < 0:
        # Python's % would map e.g. -1 to 99 (halt); reject instead.
        raise ValueError(f'negative instruction {num}')
    # Use the full two-digit field so e.g. 11 is not mistaken for opcode 1.
    opcode = num % 100
    param_modes = [(num // place) % 10 for place in (100, 1000, 10000)]
    return opcode, param_modes

def run(program, initial_input):
    """Run an Intcode program using opcodes 1-4 and 99; return the last output.

    Every parameter that is *read* honours its mode: mode 0 treats the
    parameter as an address, any other mode uses the value itself. Write
    targets (third parameter of 1/2, the parameter of 3) are always used as
    addresses. ``program`` is modified in place; pass a copy to keep it.

    Args:
        program: Intcode memory as a list of ints.
        initial_input: Value stored by every opcode-3 instruction.

    Returns:
        The most recent opcode-4 value (0 if none) when opcode 99 is reached,
        or None if the instruction pointer runs past the end of memory.

    Raises:
        ValueError: on any opcode other than 1, 2, 3, 4 or 99.
    """
    def read(addr, mode):
        # Mode 0 dereferences the parameter; other modes use it literally.
        return program[program[addr]] if mode == 0 else program[addr]

    i = 0
    output = 0
    while i < len(program):
        opcode, param_modes = parse_code(program[i])
        if opcode in (1, 2):
            operand1 = read(i + 1, param_modes[0])
            operand2 = read(i + 2, param_modes[1])
            result = operand1 + operand2 if opcode == 1 else operand1 * operand2
            program[program[i + 3]] = result
            i += 4
        elif opcode == 3:
            program[program[i + 1]] = initial_input
            i += 2
        elif opcode == 4:
            # Output honours its parameter mode too, e.g. `104,7` outputs 7.
            output = read(i + 1, param_modes[0])
            i += 2
        elif opcode == 99:
            return output
        else:
            raise ValueError(f'illegal code {i} {program[i]}')

def part1(input_file):
    """Return the part-1 answer: the last output of ``run`` given input 1.

    The program text is the first entry returned by ``read_lines(input_file)``.
    """
    return run(parse_program(read_lines(input_file)[0]), 1)

def run2(program, initial_input):
    """Interpret an Intcode program with opcodes 1-8 and 99; return its last output.

    Instructions handled:
        1 add, 2 multiply, 7 less-than, 8 equals: read two parameters and
            store the result at the address given by the third (4 cells).
        5 jump-if-true, 6 jump-if-false: read two parameters and jump to
            the second one's value when the condition on the first holds,
            otherwise advance 3 cells.
        3: store ``initial_input`` at the address in its parameter (2 cells).
        4: remember its parameter's value as the latest output (2 cells).
        99: halt.
    Read parameters are dereferenced when their mode digit is 0 and taken
    literally otherwise. ``program`` is mutated in place, so pass a copy if
    the original list must be reused.

    Returns:
        The latest output (0 if nothing was output) on halt, or None if the
        instruction pointer walks off the end of memory.

    Raises:
        ValueError: for any instruction not listed above.
    """
    def fetch(addr, mode):
        # Mode 0 = position (indirect) read; anything else = immediate.
        if mode == 0:
            return program[program[addr]]
        return program[addr]

    # Opcodes that combine two read parameters and write one result.
    binary_ops = {
        1: lambda a, b: a + b,
        2: lambda a, b: a * b,
        7: lambda a, b: 1 if a < b else 0,
        8: lambda a, b: 1 if a == b else 0,
    }

    pc = 0
    last_output = 0
    while pc < len(program):
        opcode, modes = parse_code(program[pc])
        if opcode in binary_ops or opcode in (5, 6):
            first = fetch(pc + 1, modes[0])
            second = fetch(pc + 2, modes[1])
            if opcode in binary_ops:
                program[program[pc + 3]] = binary_ops[opcode](first, second)
                pc += 4
            else:
                should_jump = first != 0 if opcode == 5 else first == 0
                pc = second if should_jump else pc + 3
        elif opcode == 3:
            program[program[pc + 1]] = initial_input
            pc += 2
        elif opcode == 4:
            last_output = fetch(pc + 1, modes[0])
            pc += 2
        elif program[pc] == 99:
            # Halt is recognised by the raw cell value.
            return last_output
        else:
            raise ValueError(f'illegal code {pc} {program[pc]}')

def part2(input_file):
    """Return the part-2 answer: the last output of ``run2`` given input 5.

    The program text is the first entry returned by ``read_lines(input_file)``.
    """
    return run2(parse_program(read_lines(input_file)[0]), 5)

In [69]:
# Part 1: run the puzzle program with input 1; the value shown below is this cell's output.
part1('inputs/day5.txt')

16225258

In [70]:
# Part 2: run the puzzle program with input 5 using the extended interpreter.
part2('inputs/day5.txt')

2808771

In [71]:
# Comparison examples. Each program is checked twice: the first call gets a
# copy, the second runs on (and mutates) the parsed list itself.
# Opcode 8 (equals), position mode: outputs 1 iff the input equals 8.
s = '3,9,8,9,10,9,4,9,99,-1,8'
program = parse_program(s)
assert run2(program.copy(), 8) == 1
assert run2(program, 9) == 0

# Opcode 7 (less than), position mode: outputs 1 iff the input is below 8.
s = '3,9,7,9,10,9,4,9,99,-1,8'
program = parse_program(s)
assert run2(program.copy(), 7) == 1
assert run2(program, 8) == 0

# 1108 = equals with both operands immediate; `3,3` overwrites the -1
# operand with the input, so the output is 1 iff the input equals 8.
s = '3,3,1108,-1,8,3,4,3,99'
program = parse_program(s)
assert run2(program.copy(), 8) == 1
assert run2(program, 9) == 0

# 1107 = less-than with both operands immediate (same self-modifying trick).
s = '3,3,1107,-1,8,3,4,3,99'
program = parse_program(s)
assert run2(program.copy(), 7) == 1
assert run2(program, 8) == 0

In [72]:
# Jump examples: both programs output 0 for input 0 and 1 otherwise.
# Opcode 6 (jump-if-false) in position mode: a zero input jumps straight to
# the output of cell 13 (0); otherwise cell 13 is incremented first.
s = '3,12,6,12,15,1,13,14,13,4,13,99,-1,0,1,9'
program = parse_program(s)
assert run2(program.copy(), 0) == 0
assert run2(program, 1) == 1

# 1105 = jump-if-true with immediate operands; a zero input falls through to
# `1101,0,0,12`, which overwrites the output cell with 0.
s = '3,3,1105,-1,9,1101,0,0,12,4,12,99,1'
program = parse_program(s)
assert run2(program.copy(), 0) == 0
assert run2(program, 1) == 1

In [73]:
# Larger example mixing comparisons and jumps. Per the asserts it outputs
# 999 for 7, 1000 for 8 and 1001 for 9 (presumably below / equal / above 8).
s = '3,21,1008,21,8,20,1005,20,22,107,8,21,20,1006,20,31,1106,0,36,98,0,0,1002,21,125,20,4,20,1105,1,46,104,999,1105,1,46,1101,1000,1,20,4,20,1105,1,46,98,99'
program = parse_program(s)
assert run2(program.copy(), 7) == 999
assert run2(program.copy(), 8) == 1000
assert run2(program.copy(), 9) == 1001