## Setup

In [None]:
import sys
from math import gcd
from pathlib import Path

from aocd import get_data, submit

In [5]:
# Make the parent directory importable so the local `common` package can be imported from this notebook.
# Guarded so that re-running this cell does not keep appending duplicate entries to sys.path.
parent_dir = str(Path.cwd().parent)
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

In [7]:
# Download (or read from the aocd cache) the raw puzzle input for Advent of Code 2024, day 8
data: str = get_data(day=8, year=2024)

## Part a

In [None]:
# Imports
from itertools import combinations
from typing import TYPE_CHECKING

from common.utils.dict_grid import print_grid, text_to_dict

if TYPE_CHECKING:
    from collections.abc import Callable

In [None]:
# Functions


def get_direct_antinodes_for_antenna_pair(
    grid: dict[complex, str], antenna_1: complex, antenna_2: complex
) -> set[complex]:
    """Get direct antinodes (which are one diff_vec away from each antenna) for antenna pair within grid bounds."""
    diff_vec: complex = antenna_2 - antenna_1  # Get the vector between the two antennas
    return {antenna_1 - diff_vec, antenna_2 + diff_vec} & grid.keys()  # Only return antinodes that are in the grid


def count_antinodes_in_grid(
    grid: dict[complex, str],
    anti_node_counting_func: Callable[
        [dict[complex, str], complex, complex], set[complex]
    ] = get_direct_antinodes_for_antenna_pair,
) -> int:
    """Count the number of antinodes in the grid."""
    # Find the antenna locations for each frequency
    antennas_per_frequency = {
        frequency: [pos for pos, char in grid.items() if char == frequency]  # Store the positions of the antennas
        for frequency in set(grid.values()) - {"."}  # For each frequency in the grid
    }

    # Find the antinodes for each frequency
    antinodes = {
        antinode
        for antennas_of_frequency in antennas_per_frequency.values()  # For each set of frequency antennas
        for a1, a2 in combinations(antennas_of_frequency, 2)  # Get all pairs of antennas
        for antinode in anti_node_counting_func(grid, a1, a2)  # Get the antinodes for each antenna pair
    }

    return len(antinodes)

In [None]:
# Turn the raw puzzle text into a position -> character mapping, then render it as a visual sanity check
grid: dict[complex, str] = text_to_dict(data)
print_grid(grid)

In [None]:
# Part a: count the distinct in-grid positions holding a direct antinode of any same-frequency pair
total_direct_antinodes_in_grid = count_antinodes_in_grid(
    grid, anti_node_counting_func=get_direct_antinodes_for_antenna_pair
)

In [None]:
# Send the part a result to Advent of Code
submit(total_direct_antinodes_in_grid, year=2024, day=8, part="a")

## Part b

In [25]:
# Functions
def get_all_antinodes_for_antenna_pair(
    grid: dict[complex, str], antenna_1: complex, antenna_2: complex
) -> set[complex]:
    """Antinodes are the points in the that are one or more diff_vecs away from the antennas."""
    # Get the vector between the two antennas
    diff_vec: complex = antenna_2 - antenna_1

    # Initialize the set of antinodes
    antinodes = set()
    antinode_1, antinode_2 = antenna_1, antenna_2

    # Traverse the grid in the direction of the diff_vec
    while antinode_1 in grid:
        antinodes.add(antinode_1)
        antinode_1 -= diff_vec

    # Traverse the grid in opposite direction
    while antinode_2 in grid:
        antinodes.add(antinode_2)
        antinode_2 += diff_vec
    return antinodes

In [26]:
# Part b: count the distinct in-grid positions returned by get_all_antinodes_for_antenna_pair
# across every pair of same-frequency antennas
total_antinodes_in_grid = count_antinodes_in_grid(
    grid,
    anti_node_counting_func=get_all_antinodes_for_antenna_pair,
)

In [None]:
# Send the part b result to Advent of Code
submit(total_antinodes_in_grid, year=2024, day=8, part="b")