In [37]:
from utils import read_lines
from itertools import permutations

def parse_program(s):
    """Convert a comma-separated Intcode source string into a list of ints."""
    return list(map(int, s.split(',')))

def parse_code(num):
    """Split an Intcode instruction into its opcode and parameter modes.

    An instruction ``ABCDE`` stores the opcode in its two lowest decimal
    digits (``DE``). The modes of parameters 1, 2 and 3 are in ``C``, ``B``
    and ``A`` respectively (0 = position mode, 1 = immediate mode).

    Args:
        num: the raw instruction value read from memory.

    Returns:
        ``(opcode, param_modes)`` where ``param_modes`` lists the modes of
        parameters 1..3. The halt instruction 99 returns ``(99, [0])``.
        Negative values are never valid instructions; they return opcode 0
        so that callers reject them as illegal.
    """
    if num == 99:
        return num, [0]
    if num < 0:
        return 0, [0, 0, 0]
    # Use both low digits so an invalid code such as 11 is not silently
    # decoded as opcode 1.
    opcode = num % 100
    param_modes = [(num // 100) % 10, (num // 1000) % 10, (num // 10000) % 10]
    return opcode, param_modes

def run(program, phase, input):
    """Run one amplifier's Intcode program until it halts.

    The first input instruction (opcode 3) reads ``phase``; every later one
    reads ``input``. ``program`` is mutated in place, so pass a copy if the
    original must be preserved.

    Args:
        program: list of ints holding the Intcode memory.
        phase: value supplied to the first input instruction.
        input: value supplied to every subsequent input instruction.

    Returns:
        The value written by the most recent output instruction (opcode 4),
        or 0 if the program produced no output before halting.

    Raises:
        ValueError: on an unknown opcode, or if execution runs past the end
            of memory without reaching the halt instruction (99).
    """
    def operand(offset, mode):
        # Mode 0 is position mode (dereference); any other mode is immediate.
        raw = program[i + offset]
        return program[raw] if mode == 0 else raw

    i = 0
    first_input = True
    output = 0
    while i < len(program):
        opcode, param_modes = parse_code(program[i])
        if opcode in (1, 2, 5, 6, 7, 8):
            # Both operands are read before any write, as in the original.
            a = operand(1, param_modes[0])
            b = operand(2, param_modes[1])
            if opcode == 1:
                program[program[i + 3]] = a + b
                i += 4
            elif opcode == 2:
                program[program[i + 3]] = a * b
                i += 4
            elif opcode == 5:  # jump-if-true
                i = b if a != 0 else i + 3
            elif opcode == 6:  # jump-if-false
                i = b if a == 0 else i + 3
            elif opcode == 7:
                program[program[i + 3]] = 1 if a < b else 0
                i += 4
            else:  # opcode == 8
                program[program[i + 3]] = 1 if a == b else 0
                i += 4
        elif opcode == 3:
            program[program[i + 1]] = phase if first_input else input
            first_input = False
            i += 2
        elif opcode == 4:
            output = operand(1, param_modes[0])
            i += 2
        elif opcode == 99:
            return output
        else:
            raise ValueError(f'illegal code {i} {program[i]}')
    # Returning None here would silently feed a bogus signal to the next
    # amplifier; a program that never halts is malformed.
    raise ValueError(f'program ran past end of memory at {i} without halting')

def run_combination(program, phases):
    """Chain amplifiers in series and return the final output signal.

    Each amplifier runs on a fresh copy of ``program`` with its own phase
    setting. The first amplifier receives 0; each later one receives the
    previous amplifier's output.
    """
    signal = 0
    for setting in phases:
        memory = program.copy()
        signal = run(memory, setting, signal)
    return signal

def part1(input_file):
    """Return the highest series-amplifier signal over all phase orders of 0-4.

    Args:
        input_file: path passed to ``read_lines``; element 0 of its result
            is parsed as the comma-separated Intcode program.
    """
    program = parse_program(read_lines(input_file)[0])
    # Take the true maximum rather than starting from 0, so a best signal
    # that happens to be negative is still reported correctly.
    return max(
        run_combination(program, phases)
        for phases in permutations(range(5), 5)
    )

def run2(program, phases):
    """Run amplifiers in a feedback loop; return the last amplifier's final output.

    One copy of ``program`` is made per entry in ``phases``. Amplifiers take
    turns. The active one runs until it executes an output instruction; its
    position is then saved and control passes to the next amplifier, wrapping
    from the last back to the first. Each amplifier's first input instruction
    reads its phase setting. Later inputs read the most recent output of the
    previously active amplifier (0 before any output).

    NOTE: control switches only on output, so this assumes each amplifier
    reads at most one signal between consecutive outputs.

    Args:
        program: list of ints holding the Intcode memory (not modified).
        phases: sequence of phase settings, one per amplifier.

    Returns:
        The last value output by the final amplifier in ``phases`` at the
        moment the active amplifier executes halt (99); 0 if it never output.

    Raises:
        ValueError: if ``phases`` is empty, on an unknown opcode, or if
            execution runs past the end of memory without halting.
    """
    count = len(phases)
    if count == 0:
        raise ValueError('phases must not be empty')
    memories = [program.copy() for _ in range(count)]
    pointers = [0] * count
    needs_phase = [True] * count
    active = 0
    mem = memories[active]
    i = 0
    signal = 0
    last_output = 0

    def operand(offset, mode):
        # Mode 0 is position mode (dereference); any other mode is immediate.
        raw = mem[i + offset]
        return mem[raw] if mode == 0 else raw

    while i < len(mem):
        opcode, param_modes = parse_code(mem[i])
        if opcode in (1, 2, 5, 6, 7, 8):
            a = operand(1, param_modes[0])
            b = operand(2, param_modes[1])
            if opcode == 1:
                mem[mem[i + 3]] = a + b
                i += 4
            elif opcode == 2:
                mem[mem[i + 3]] = a * b
                i += 4
            elif opcode == 5:  # jump-if-true
                i = b if a != 0 else i + 3
            elif opcode == 6:  # jump-if-false
                i = b if a == 0 else i + 3
            elif opcode == 7:
                mem[mem[i + 3]] = 1 if a < b else 0
                i += 4
            else:  # opcode == 8
                mem[mem[i + 3]] = 1 if a == b else 0
                i += 4
        elif opcode == 3:
            if needs_phase[active]:
                mem[mem[i + 1]] = phases[active]
                needs_phase[active] = False
            else:
                mem[mem[i + 1]] = signal
            i += 2
        elif opcode == 4:
            signal = operand(1, param_modes[0])
            if active == count - 1:
                last_output = signal
            # Suspend this amplifier just past its output; resume the next one.
            pointers[active] = i + 2
            active = (active + 1) % count
            mem = memories[active]
            i = pointers[active]
        elif opcode == 99:
            return last_output
        else:
            raise ValueError(f'illegal code {i} {mem[i]}')
    raise ValueError(f'program ran past end of memory at {i} without halting')

def part2(input_file):
    """Return the highest feedback-loop signal over all phase orders of 5-9.

    Args:
        input_file: path passed to ``read_lines``; element 0 of its result
            is parsed as the comma-separated Intcode program.
    """
    program = parse_program(read_lines(input_file)[0])
    # Take the true maximum rather than starting from 0, so a best signal
    # that happens to be negative is still reported correctly.
    return max(
        run2(program, phases)
        for phases in permutations(range(5, 10), 5)
    )

In [38]:
# Part 1: best signal from amplifiers in series, phase settings 0-4.
part1('inputs/day7.txt')

24405

In [39]:
# Part 2: best signal from amplifiers in a feedback loop, phase settings 5-9.
part2('inputs/day7.txt')

8271623

In [6]:
# Part 1 examples: each program, run in series with the given phase order,
# must produce the expected final signal.
s = '3,15,3,16,1002,16,10,16,1,16,15,15,4,15,99,0,0'
program = parse_program(s)
assert run_combination(program, [4,3,2,1,0]) == 43210

s = '3,23,3,24,1002,24,10,24,1002,23,-1,23,101,5,23,23,1,24,23,23,4,23,99,0,0'
program = parse_program(s)
assert run_combination(program, [0, 1, 2, 3, 4]) == 54321

s = '3,31,3,32,1002,32,10,32,1001,31,-2,31,1007,31,0,33,1002,33,7,33,1,33,31,31,1,32,31,31,4,31,99,0,0,0'
program = parse_program(s)
assert run_combination(program, [1, 0, 4, 3, 2]) == 65210


In [28]:
# Part 2 examples: each program, run in a feedback loop with the given phase
# order, must produce the expected final signal from the last amplifier.
s = '3,26,1001,26,-4,26,3,27,1002,27,2,27,1,27,26,27,4,27,1001,28,-1,28,1005,28,6,99,0,0,5'
program = parse_program(s)
assert run2(program, [9,8,7,6,5]) == 139629729

s = '3,52,1001,52,-5,52,3,53,1,52,56,54,1007,54,5,55,1005,55,26,1001,54,-5,54,1105,1,12,1,53,54,53,1008,54,0,55,1001,55,1,55,2,53,55,53,4,53,1001,56,-1,56,1005,56,6,99,0,0,0,0,10'
program = parse_program(s)
assert run2(program, [9,7,8,5,6]) == 18216