In [31]:
from utils import read_lines
from collections import defaultdict, deque

def parse_program(s):
    """Convert a comma-separated Intcode listing into a list of ints.

    Surrounding whitespace on each field (e.g. a trailing newline) is
    tolerated because ``int`` strips it.
    """
    fields = s.split(',')
    return list(map(int, fields))

def parse_code(num):
    """Split an Intcode instruction into its opcode and parameter modes.

    The opcode is the value's last two decimal digits. The hundreds,
    thousands and ten-thousands digits give the modes of parameters 1-3,
    with missing leading digits meaning mode 0.

    Args:
        num: the raw instruction value read from memory.

    Returns:
        ``(opcode, param_modes)`` where ``param_modes`` is a list of three
        ints. For the halt instruction 99 the modes are returned as ``[0]``,
        which is the shape existing callers already receive.
    """
    if num == 99:
        return num, [0]
    # Use both low digits. Taking only the last digit would decode e.g. 11 as
    # opcode 1 instead of letting the caller reject it as illegal.
    opcode = num % 100
    param_modes = [(num // 10 ** place) % 10 for place in (2, 3, 4)]
    return opcode, param_modes

def run(program, input):
    """Execute an Intcode program until it emits one output value.

    ``program`` is modified in place by write instructions. Execution always
    starts at address 0 with relative base 0 and an empty extended-memory
    map, so only the contents of ``program`` persist between calls.
    NOTE(review): driving an interactive program by calling ``run`` again
    therefore relies on the program returning to address 0 after each
    output — confirm this for any new program.

    Args:
        program: list of ints holding Intcode memory; mutated in place.
        input: value stored by every opcode-3 (input) instruction.

    Returns:
        The first value produced by an opcode-4 (output) instruction, or
        ``None`` if the instruction pointer runs past the end of ``program``.

    Raises:
        ValueError: on instruction 99 ('program end'), on an unknown opcode,
            or when an access targets a negative memory address.
    """
    i = 0
    extend_memory = defaultdict(int)
    relative_base = 0

    def read_memory(addr):
        # Negative addresses are invalid in Intcode. Without this check,
        # Python's negative indexing would silently read from the end of
        # `program`.
        if addr < 0:
            raise ValueError(f'negative address {addr}')
        if addr < len(program):
            return program[addr]
        # Memory beyond the loaded program reads as 0 until written.
        return extend_memory.get(addr, 0)

    def write_memory(idx, value, mode):
        addr = read_memory(idx)
        if mode == 2:
            addr = relative_base + addr
        if addr < 0:
            raise ValueError(f'negative address {addr}')
        if addr < len(program):
            program[addr] = value
        else:
            extend_memory[addr] = value

    def get_oprand(idx, mode):
        param = read_memory(idx)
        if mode == 1:
            # Immediate mode: the parameter is the value itself.
            return param
        if mode == 2:
            # Relative mode: the parameter is an offset from relative_base.
            param = relative_base + param
        return read_memory(param)

    while i < len(program):
        # read_memory rejects a negative pointer left behind by a jump.
        instruction = read_memory(i)
        opcode, param_modes = parse_code(instruction)
        if opcode in (1, 2, 5, 6, 7, 8):
            oprand1 = get_oprand(i + 1, param_modes[0])
            oprand2 = get_oprand(i + 2, param_modes[1])
            if opcode == 1:
                write_memory(i + 3, oprand1 + oprand2, param_modes[2])
                i += 4
            elif opcode == 2:
                write_memory(i + 3, oprand1 * oprand2, param_modes[2])
                i += 4
            elif opcode == 5:
                i = oprand2 if oprand1 != 0 else i + 3
            elif opcode == 6:
                i = oprand2 if oprand1 == 0 else i + 3
            elif opcode == 7:
                write_memory(i + 3, 1 if oprand1 < oprand2 else 0, param_modes[2])
                i += 4
            else:  # opcode == 8
                write_memory(i + 3, 1 if oprand1 == oprand2 else 0, param_modes[2])
                i += 4
        elif opcode == 3:
            write_memory(i + 1, input, param_modes[0])
            i += 2
        elif opcode == 4:
            # Return the first output immediately. The instruction pointer is
            # not saved, so the next call starts again at address 0.
            return get_oprand(i + 1, param_modes[0])
        elif opcode == 9:
            relative_base += get_oprand(i + 1, param_modes[0])
            i += 2
        elif instruction == 99:
            raise ValueError('program end')
        else:
            raise ValueError(f'illegal code {i} {instruction}')

# Movement command -> offset added component-wise to a position tuple.
# NOTE(review): per the puzzle these are presumably north, south, west and
# east for commands 1-4 — confirm against the puzzle text.
deltas = dict(enumerate([(-1, 0), (1, 0), (0, -1), (0, 1)], start=1))

# Each command maps to the command whose offset is its exact negation,
# i.e. the move that undoes it (used to back the droid up after exploring).
reverse_command = {
    command: next(
        other for other, step in deltas.items()
        if step == (-offset[0], -offset[1])
    )
    for command, offset in deltas.items()
}
def part1(input_file):
    """Map the droid's area and find the fewest moves to the oxygen system.

    Loads the Intcode program from the first line of ``input_file`` and
    drives it one move at a time through ``run``. A depth-first search
    visits every reachable cell, including cells beyond the oxygen system,
    and backs the droid up with ``reverse_command`` after each branch. It
    records the smallest step count seen for each cell and revisits a cell
    whenever a shorter route to it turns up.

    Droid status codes handled: 0 = wall (droid did not move), 1 = open
    cell, 2 = oxygen system.

    NOTE(review): the search is recursive, so its depth grows with the
    current path length. Very long corridors could hit Python's recursion
    limit.

    Returns:
        ``(steps_to_oxygen, board, oxygen_pos)``. ``board`` maps each probed
        position to '#' (wall), '.' (open) or 'O' (oxygen system). The
        start position is (0, 0).

    Raises:
        ValueError: if the droid reports an unknown status or no oxygen
            system is found.
    """
    s = read_lines(input_file)[0]
    program = parse_program(s)
    min_steps = {(0, 0): 0}
    board = {(0, 0): '.'}
    path = [(0, 0)]
    target_pos = None

    def backtrack():
        nonlocal target_pos
        cur_pos = path[-1]
        # `path` includes the start cell, so its length is the distance of
        # any neighbour of the current cell.
        next_step = len(path)
        for command, delta in deltas.items():
            new_pos = (cur_pos[0] + delta[0], cur_pos[1] + delta[1])
            if board.get(new_pos) == '#':
                continue
            if new_pos in min_steps and min_steps[new_pos] <= next_step:
                continue  # already reached at least as cheaply
            status = run(program, command)
            if status == 0:
                board[new_pos] = '#'  # wall: the droid stayed at cur_pos
                continue
            if status == 1:
                board[new_pos] = '.'
            elif status == 2:
                board[new_pos] = 'O'
                if target_pos is None:
                    target_pos = new_pos
            else:
                raise ValueError(f'unexpected droid status {status!r}')
            min_steps[new_pos] = next_step
            # Keep exploring past the oxygen system as well. Cells reachable
            # only through it still belong on the map used by the flood fill.
            path.append(new_pos)
            backtrack()
            path.pop()
            run(program, reverse_command[command])  # step back to cur_pos

    backtrack()
    if target_pos is None:
        raise ValueError('oxygen system not found')
    return min_steps[target_pos], board, target_pos

def part2(input_file):
    """Return how many steps oxygen needs to fill every open cell.

    Runs ``part1`` to get the board and the oxygen system's position, then
    does a breadth-first search from that position over cells marked '.'.
    The answer is the largest BFS distance reached.
    """
    _, board, oxygen = part1(input_file)
    distance = {oxygen: 0}
    frontier = deque([oxygen])
    while frontier:
        here = frontier.popleft()
        for d0, d1 in deltas.values():
            there = (here[0] + d0, here[1] + d1)
            if there in distance or board.get(there) != '.':
                continue
            distance[there] = distance[here] + 1
            frontier.append(there)
    return max(distance.values())

In [28]:
# Part 1: fewest movement commands from the start to the oxygen system.
min_step = part1('inputs/day15.txt')[0]
print(min_step)

236


In [32]:
# Part 2: steps until oxygen has spread to every open cell.
oxygen_fill_minutes = part2('inputs/day15.txt')
oxygen_fill_minutes

368