In [25]:
from dataclasses import dataclass
import re
from typing import Optional
from tqdm import tqdm
from IPython.display import clear_output

In [31]:
# Puzzle input: the warehouse map, one blank line, then the robot's moves
# (the move list may be wrapped over several lines).
with open("input.txt", "r", encoding="utf-8") as f:
    lines = f.read()

# Split only on the FIRST blank line so that any trailing blank line after the
# move list cannot produce a third part and break the unpacking.
# NOTE: ``map`` shadows the builtin, but ``load_data`` reads it by this name.
map, movements = lines.split("\n\n", 1)


@dataclass
class Block:
    """A box in the warehouse, made of one or more grid cells.

    Cells are stored as a tuple so the block is hashable.  Equality and the
    hash are both derived from the current cells, so they change whenever the
    block is moved with ``update_cords``; a Block must not be moved while a
    set or dict that still needs to look it up holds it by value.
    """

    # Cells currently occupied by the box, as (x, y) pairs.
    cords: tuple[tuple[int, int], ...]

    def __init__(self, cords: list[tuple[int, int]]):
        # Accept any iterable of cells but freeze it so the block can be hashed.
        self.cords = tuple(cords)

    def get_new_cords(self, direction: tuple) -> list[tuple]:
        """Return the cells the box would occupy after one step along *direction*."""
        dx, dy = direction[0], direction[1]
        return [(x + dx, y + dy) for x, y in self.cords]

    def update_cords(self, direction: tuple):
        """Move the box one step along *direction* in place (changes its hash)."""
        self.cords = tuple(self.get_new_cords(direction))

    def __hash__(self):
        return hash(self.cords)

    def __eq__(self, other):
        # Let Python fall back to its default handling for foreign types
        # instead of crashing on a missing ``cords`` attribute.
        if not isinstance(other, Block):
            return NotImplemented
        return self.cords == other.cords


def load_data(part_2: bool = False, grid: Optional[str] = None):
    """Parse the warehouse map into walls, boxes and the robot position.

    Args:
        part_2: When true, every map tile is doubled horizontally.  Walls cover
            both cells, each box becomes a two-cell ``Block`` registered under
            both of its cells, and the robot lands on the left cell.
        grid: Map text to parse.  Defaults to the module-level ``map`` string
            read from ``input.txt``.

    Returns:
        ``(walls, blocks, robot)``: a set of wall cells, a dict mapping every
        box cell to its ``Block`` (a wide box appears under both cells), and
        the robot's ``(x, y)`` position.

    Raises:
        ValueError: If the map contains no robot (``@``).
    """
    if grid is None:
        grid = map  # module-level map text (shadows the builtin ``map``)

    width = 2 if part_2 else 1  # cells per map tile
    walls: set[tuple[int, int]] = set()
    blocks: dict[tuple[int, int], Block] = {}
    robot: Optional[tuple[int, int]] = None

    for y, line in enumerate(grid.splitlines()):
        for col, cell in enumerate(line):
            x = col * width
            if cell == "#":
                walls.update((x + dx, y) for dx in range(width))
            elif cell == "@":
                robot = (x, y)
            elif cell == "O":
                block = Block([(x + dx, y) for dx in range(width)])
                # Index the same Block object under each of its cells.
                for cord in block.cords:
                    blocks[cord] = block

    if robot is None:
        raise ValueError("map contains no robot ('@')")
    return walls, blocks, robot

def print_map(robot, walls, blocks):
    """Draw the warehouse to stdout, one text row per grid row, then a blank line.

    The drawing spans from (0, 0) to the largest wall coordinate on each axis.
    Precedence per cell is robot ``@``, then box ``O``, then wall ``#``,
    otherwise floor ``.``.
    """
    width = max(x for x, _ in walls) + 1
    height = max(y for _, y in walls) + 1

    def glyph(pos):
        if pos == robot:
            return "@"
        if pos in blocks:
            return "O"
        if pos in walls:
            return "#"
        return "."

    for row in range(height):
        print("".join(glyph((col, row)) for col in range(width)))
    print()

def push_block(block_cords: tuple, direction: tuple, walls, blocks):
    """Try to shove the box at ``block_cords`` one step along ``direction``.

    Every box touched by the move is collected first: a box moving into
    another box drags that one along, transitively.  If any collected box
    would enter a wall, nothing is changed and ``False`` is returned.
    Otherwise ``blocks`` is re-keyed in place, every collected box is shifted,
    and ``True`` is returned.
    """
    start = blocks[block_cords]

    # Discover the full set of boxes involved, depth-first.
    affected = {start}
    pending = [start]
    while pending:
        current = pending.pop()
        for target in current.get_new_cords(direction):
            if target in walls:
                return False  # nothing has been mutated yet
            if target in blocks:
                neighbour = blocks[target]
                if neighbour not in affected:
                    affected.add(neighbour)
                    pending.append(neighbour)

    # Snapshot before moving: moving a box changes its hash.
    moving = list(affected)

    # Clear every old cell first so boxes shifting into each other's previous
    # cells are not overwritten, then register the new cells and move.
    for box in moving:
        for cell in box.cords:
            del blocks[cell]
    for box in moving:
        for cell in box.get_new_cords(direction):
            blocks[cell] = box
        box.update_cords(direction)

    return True

def simulate_movements(movements, walls, blocks, robot):
    """Run the robot through every move in ``movements`` and return ``blocks``.

    Newlines in ``movements`` are ignored.  A move into a wall is skipped; a
    move into a box succeeds only if ``push_block`` can shove it.  ``blocks``
    is updated in place, and the same dict is returned.
    """
    steps = {">": (1, 0), "<": (-1, 0), "^": (0, -1), "v": (0, 1)}

    for symbol in movements:
        if symbol == "\n":
            continue

        dx, dy = steps[symbol]
        target = (robot[0] + dx, robot[1] + dy)
        if target in walls:
            continue

        # Free floor, or a box that could be pushed out of the way.
        if target not in blocks or push_block(target, (dx, dy), walls, blocks):
            robot = target

    return blocks

# Part 1: every box is a single cell, so each dict key is one box's position.
walls, blocks, robot = load_data(part_2=False)
final_blocks = simulate_movements(movements, walls, blocks, robot)
print(sum(x + 100 * y for x, y in final_blocks))

# Part 2: boxes are two cells wide and each Block is stored under both of its
# cells, so deduplicate (by identity) before scoring instead of halving a
# double-counted sum, which produced a float.  The score uses the box's
# left-most column and its top row.
walls, blocks, robot = load_data(part_2=True)
final_blocks = simulate_movements(movements, walls, blocks, robot)
unique_blocks = {id(b): b for b in final_blocks.values()}.values()
print(
    sum(
        min(x for x, _ in b.cords) + 100 * min(y for _, y in b.cords)
        for b in unique_blocks
    )
)

1563092
1582688.0
