In [116]:
from utils import read_lines
from collections import defaultdict

def parse_program(s):
    """Convert a comma-separated string of integers into a list of ints."""
    return list(map(int, s.split(',')))

def parse_code(num):
    """Split an instruction value into its opcode and parameter modes.

    The two lowest decimal digits form the opcode. The remaining digits are
    the parameter modes, lowest digit first, zero-padded so that at least
    three modes are returned.

    The exact value 99 is special-cased and returns ``(99, [0])``.

    Args:
        num: the integer stored at the instruction pointer.

    Returns:
        A ``(opcode, param_modes)`` tuple, where ``param_modes`` is a list
        of ints.
    """
    if num == 99:
        return num, [0]

    digits = []
    while num > 0:
        num, digit = divmod(num, 10)
        digits.append(digit)
    # Pad to five digits so both opcode digits and three modes always exist.
    digits += [0] * (5 - len(digits))

    # Use both low digits: taking only the ones digit would decode e.g. 12
    # as opcode 2 instead of reporting it as an unknown opcode.
    opcode = digits[0] + 10 * digits[1]
    return opcode, digits[2:]

def run(program, inputs=()):
    """Execute an Intcode-style program and collect its outputs.

    ``program`` is modified in place by store instructions; pass a copy if
    the original must be preserved. Addresses at or beyond ``len(program)``
    are backed by a separate store whose cells default to 0.

    Args:
        program: mutable list of ints holding the instructions and data.
        inputs: sequence of ints consumed in order by opcode 3.

    Returns:
        The list of values emitted by opcode 4 once opcode 99 is executed,
        or ``None`` if the instruction pointer moves past the end of
        ``program`` without reaching opcode 99.

    Raises:
        ValueError: on an unknown opcode or a negative memory address.
        IndexError: if opcode 3 executes after ``inputs`` is exhausted.
    """
    ip = 0
    outputs = []
    extended_memory = defaultdict(int)
    relative_base = 0
    input_pos = 0

    def check_address(addr):
        # Without this guard a negative address would silently wrap around
        # through Python's negative list indexing.
        if addr < 0:
            raise ValueError(f'negative address {addr}')
        return addr

    def read_memory(addr):
        check_address(addr)
        if addr < len(program):
            return program[addr]
        return extended_memory[addr]

    def write_memory(idx, value, mode):
        # The parameter holds the destination address; mode 2 makes it
        # relative to the current relative base.
        addr = program[idx]
        if mode == 2:
            addr = relative_base + addr
        check_address(addr)
        if addr < len(program):
            program[addr] = value
        else:
            extended_memory[addr] = value

    def get_operand(idx, mode):
        # Mode 1: the parameter is the value itself.
        if mode == 1:
            return program[idx]
        # Mode 2: address relative to the relative base; otherwise absolute.
        if mode == 2:
            addr = relative_base + program[idx]
        else:
            addr = program[idx]
        return read_memory(addr)

    while ip < len(program):
        opcode, param_modes = parse_code(program[ip])
        if opcode in (1, 2, 5, 6, 7, 8):
            # All of these instructions read two operands first.
            first = get_operand(ip + 1, param_modes[0])
            second = get_operand(ip + 2, param_modes[1])

            if opcode == 1:
                write_memory(ip + 3, first + second, param_modes[2])
                ip += 4
            elif opcode == 2:
                write_memory(ip + 3, first * second, param_modes[2])
                ip += 4
            elif opcode == 5:
                # Jump if the first operand is non-zero.
                ip = second if first != 0 else ip + 3
            elif opcode == 6:
                # Jump if the first operand is zero.
                ip = second if first == 0 else ip + 3
            elif opcode == 7:
                write_memory(ip + 3, 1 if first < second else 0, param_modes[2])
                ip += 4
            else:  # opcode == 8
                write_memory(ip + 3, 1 if first == second else 0, param_modes[2])
                ip += 4
        elif opcode == 3:
            value = inputs[input_pos]
            input_pos += 1
            write_memory(ip + 1, value, param_modes[0])
            ip += 2
        elif opcode == 4:
            outputs.append(get_operand(ip + 1, param_modes[0]))
            ip += 2
        elif opcode == 9:
            relative_base += get_operand(ip + 1, param_modes[0])
            ip += 2
        elif opcode == 99:
            return outputs
        else:
            raise ValueError(f'illegal code {ip} {program[ip]}')

    # The instruction pointer ran past the program without reaching opcode 99.
    return None

def print_matrix(matrix):
    """Print each row of ``matrix`` as one line, its cells joined together.

    Every row must be an iterable of strings.
    """
    for cells in matrix:
        line = ''.join(cells)
        print(line)
    
def part1(input_file):
    """Render the program's output as a character grid and score it.

    The first line of ``input_file`` is parsed as a program and run without
    inputs. Its outputs, minus the final one, are read as character codes,
    with code 10 ending a row. The grid is printed, then every interior
    cell that is '#' and has '#' directly above, below, left and right of it
    adds ``row * column`` to the result.

    Returns:
        The sum of ``row * column`` over all such cells.
    """
    program = parse_program(read_lines(input_file)[0])

    grid, line = [], []
    for code in run(program)[:-1]:
        if code != 10:
            line.append(chr(code))
            continue
        grid.append(line)
        line = []
    print_matrix(grid)

    height, width = len(grid), len(grid[0])
    # The cell itself first, then its four neighbours, checked lazily.
    cross = ((0, 0), (-1, 0), (1, 0), (0, -1), (0, 1))
    total = 0
    for r in range(1, height - 1):
        for c in range(1, width - 1):
            if all(grid[r + dr][c + dc] == '#' for dr, dc in cross):
                total += r * c
    return total

# (row, column) offset of a single step for each facing character.
# Key order is kept as '^', 'v', '<', '>' for code that iterates this table.
deltas = dict(zip('^v<>', ((-1, 0), (1, 0), (0, -1), (0, 1))))

# Facing obtained by turning left from the key facing.
left_turn = dict(zip('^v<>', '<>v^'))

# Facing obtained by turning right from the key facing.
right_turn = dict(zip('^v<>', '><^v'))

def find_start_pos(matrix):
    """Return the (row, column) of the first '^' cell in row-major order.

    The width of the first row is used for every row. Returns ``None`` when
    no cell holds '^'.
    """
    height, width = len(matrix), len(matrix[0])
    hits = (
        (r, c)
        for r in range(height)
        for c in range(width)
        if matrix[r][c] == '^'
    )
    return next(hits, None)

def get_value(matrix, pos):
    """Return the cell at ``pos`` = (row, column), or ``None`` if out of bounds.

    The column bound is taken from the length of the first row.
    """
    row = pos[0]
    if row < 0 or row >= len(matrix):
        return None
    col = pos[1]
    if col < 0 or col >= len(matrix[0]):
        return None
    return matrix[row][col]

def set_value(matrix, pos, v):
    """Store ``v`` in place at ``pos`` = (row, column); no bounds checking."""
    target_row = matrix[pos[0]]
    target_row[pos[1]] = v

def move(pos, delta):
    """Return ``pos`` shifted by ``delta``; both are (row, column) pairs."""
    new_row = pos[0] + delta[0]
    new_col = pos[1] + delta[1]
    return new_row, new_col

def find_path(matrix):
    """Trace the '#' track from the '^' cell and return the moves taken.

    Starting at the '^' cell, the walker repeatedly picks a neighbouring
    '#' cell, records a turn, then advances in that direction for as long
    as the next cell is '#' or an already-visited cell (visited cells are
    overwritten with the direction character used to enter them).

    The caller's ``matrix`` is left untouched; marking happens on a copy.
    The start position is printed.

    Args:
        matrix: list of rows, each a list of single-character strings.

    Returns:
        A flat list alternating turn letters ('L' or 'R') and int step
        counts, e.g. ``['R', 12, 'L', 10]``.

    Raises:
        ValueError: if ``matrix`` contains no '^' cell.
    """
    start_pos = find_start_pos(matrix)
    if start_pos is None:
        raise ValueError("matrix has no '^' start cell")
    print(start_pos)

    path = []
    # Copy each row: a shallow copy of the outer list would share the row
    # lists, so marking visited cells would modify the caller's grid.
    matrix = [list(row) for row in matrix]
    cur_pos = start_pos
    while True:
        cur_dir = get_value(matrix, cur_pos)

        # Visited cells no longer hold '#', so any neighbouring '#' is
        # untravelled track. If several match, the last one in ``deltas``
        # order wins.
        next_dir = None
        for direction, delta in deltas.items():
            if get_value(matrix, move(cur_pos, delta)) == '#':
                next_dir = direction

        if not next_dir:
            return path

        # Anything that is not a left turn is recorded as 'R'.
        path.append('L' if left_turn[cur_dir] == next_dir else 'R')

        cur_dir = next_dir
        steps = 0
        while True:
            next_pos = move(cur_pos, deltas[cur_dir])
            cell = get_value(matrix, next_pos)
            if cell == '#' or cell in deltas:
                steps += 1
                cur_pos = next_pos
                # Mark the cell as visited and remember the current facing.
                set_value(matrix, cur_pos, cur_dir)
            else:
                path.append(steps)
                break
    

def encode(func):
    """Encode a sequence of tokens as character codes for program input.

    Tokens are rendered with ``str``, joined with commas, and terminated by
    a newline (code 10). An empty sequence encodes as just the newline.

    Args:
        func: iterable of tokens, e.g. ``['R', 8, 'L', 12]`` or
            ``['A', 'B', 'C']``.

    Returns:
        A list of int character codes ending in 10.
    """
    text = ','.join(str(token) for token in func)
    codes = [ord(ch) for ch in text]
    codes.append(10)
    return codes
                
def part2(input_file):
    """Feed a hard-coded movement routine to the program and return its output.

    The first line of ``input_file`` is parsed as a program. A copy is run
    without inputs to obtain the character grid (code 10 ends a row; the
    final output value is dropped), the grid dimensions are printed, and the
    path through the grid is traced. The fixed functions A, B and C, expanded
    according to ``routine``, are asserted to reproduce that path.

    The program is then run again with address 0 set to 2 and with the
    encoded routine, the three functions and an 'n' line as input. The last
    ten outputs are printed.

    Returns:
        The full output list of the second run.
    """
    program = parse_program(read_lines(input_file)[0])

    grid, line = [], []
    for code in run(program.copy())[:-1]:
        if code != 10:
            line.append(chr(code))
            continue
        grid.append(line)
        line = []
    print(len(grid), len(grid[0]))
    path = find_path(grid)

    A = ['R', 12, 'L', 10, 'R', 10, 'L', 8]
    B = ['R', 12, 'L', 10, 'R', 12]
    C = ['L', 8, 'R', 10, 'R', 6]
    routine = ['B', 'C', 'B', 'A', 'C', 'A', 'C', 'A', 'B', 'A']

    # Expanding the routine must give exactly the traced path.
    functions = {'A': A, 'B': B, 'C': C}
    assert [step for name in routine for step in functions[name]] == path

    inputs = encode(routine) + encode(A) + encode(B) + encode(C) + [ord('n'), 10]
    program[0] = 2
    outputs = run(program, inputs)
    print(outputs[-10:])
    return outputs


In [None]:
part1('inputs/day17.txt')

In [117]:
outputs = part2('inputs/day17.txt')

65 53
(50, 18)


[46, 46, 46, 46, 46, 46, 46, 10, 10, 1409507]


In [107]:
encode(['A','B','C','B','A','C'])

[65, 44, 66, 44, 67, 44, 66, 44, 65, 44, 67, 10]

In [114]:
encode(['R', 8, 'R', 8])

[82, 44, 56, 44, 82, 44, 56, 10]

In [115]:
# Scratch check: the three movement functions, expanded in routine order,
# must reproduce the full traced path.
A = ['R', 12, 'L', 10, 'R', 10, 'L', 8]
B = ['R', 12, 'L', 10, 'R', 12]
C = ['L', 8, 'R', 10, 'R', 6]

routine = ['B', 'C', 'B', 'A', 'C', 'A', 'C', 'A', 'B', 'A']

path = [
    'R', 12, 'L', 10, 'R', 12,
    'L', 8, 'R', 10, 'R', 6,
    'R', 12, 'L', 10, 'R', 12,
    'R', 12, 'L', 10, 'R', 10, 'L', 8,
    'L', 8, 'R', 10, 'R', 6,
    'R', 12, 'L', 10, 'R', 10, 'L', 8,
    'L', 8, 'R', 10, 'R', 6,
    'R', 12, 'L', 10, 'R', 10, 'L', 8,
    'R', 12, 'L', 10, 'R', 12,
    'R', 12, 'L', 10, 'R', 10, 'L', 8,
]

assert [step for name in routine for step in {'A': A, 'B': B, 'C': C}[name]] == path