In [None]:
from aoc2024 import load_input, load_example

In [None]:
# NOTE(review): assumes load_input(4) returns the puzzle text as a str — confirm.
puzzle_text = load_input(4)
rows = puzzle_text.splitlines()

## Part 1

In [None]:
def flip(rows: list[str]) -> list[str]:
    """Mirror a grid left-to-right by reversing every row.

    Args:
        rows: the grid, one string per row.

    Returns:
        A new list in which each row is reversed; row order is unchanged.
    """
    mirrored = []
    for line in rows:
        mirrored.append(line[::-1])
    return mirrored


def transpose(rows: list[str]) -> list[str]:
    """Return the columns of a character grid as strings.

    Column ``i`` of the result is built from character ``i`` of every row, so
    ``transpose(["ab", "cd"]) == ["ac", "bd"]``. The number of columns is
    taken from the first row.

    Args:
        rows: the grid, one string per row.

    Returns:
        One string per column, top to bottom. An empty grid yields ``[]``.

    Raises:
        IndexError: if a later row is shorter than the first row.
    """
    # Without this guard, ``rows[0]`` raises IndexError on an empty grid.
    if not rows:
        return []

    width = len(rows[0])
    return ["".join(row[i_col] for row in rows) for i_col in range(width)]

def diagonals(rows: list[str]) -> list[str]:
    """Return every top-left -> bottom-right diagonal of a character grid.

    Diagonals start at each cell of the top row (left to right), then at each
    cell of the left column below the corner (top to bottom). Each one is read
    downwards and to the right until it leaves the grid. The opposite
    direction is obtained with ``diagonals(flip(rows))``.

    Args:
        rows: the grid, one string per row; width is taken from the first row.

    Returns:
        One string per diagonal. An empty grid yields ``[]``.
    """
    # Without this guard, ``rows[0]`` raises IndexError on an empty grid.
    if not rows:
        return []

    height = len(rows)
    width = len(rows[0])

    # Every diagonal starts on the top edge or the left edge; (0, 0) is only
    # listed once, via the top edge.
    start_coords = [(0, x) for x in range(width)] + [(y, 0) for y in range(1, height)]

    result = []
    for y_start, x_start in start_coords:
        # A diagonal ends at whichever of the bottom/right edges comes first.
        length = min(height - y_start, width - x_start)
        result.append("".join(rows[y_start + k][x_start + k] for k in range(length)))

    return result

In [None]:
import re

def count_xmas(row: str) -> int:
    """Count left-to-right occurrences of "XMAS" in a single line.

    "XMAS" cannot overlap itself, so a non-overlapping substring count is
    exact.
    """
    return row.count("XMAS")

def count_xmas_all_rows(rows: list[str]) -> int:
    """Total the left-to-right "XMAS" occurrences over every line in ``rows``."""
    total = 0
    for line in rows:
        total += count_xmas(line)
    return total

In [None]:
# Reading each of these line sets left-to-right, and again after `flip`,
# covers all eight directions "XMAS" can appear in.
forward_views = [
    rows,                   # horizontal
    transpose(rows),        # vertical
    diagonals(rows),        # down-right diagonals
    diagonals(flip(rows)),  # down-left diagonals
]
to_check = [view for base in forward_views for view in (base, flip(base))]

sum(map(count_xmas_all_rows, to_check))

## Part 2

In [None]:
def is_x_mas(rows: list[str], row: int, col: int) -> bool:
    """Return True if an X of two "MAS" words is centred on ``rows[row][col]``.

    The centre cell must be ``"A"``. On each of its two diagonals, the opposite
    corners must be one ``"M"`` and one ``"S"``, so each diagonal reads "MAS"
    in one direction or the other.

    Args:
        rows: the grid, one string per row.
        row: row index of the candidate centre.
        col: column index of the candidate centre.

    Returns:
        True for an X-MAS. Returns False when the 3x3 neighbourhood does not
        fit inside the grid. Without that check, ``row - 1`` / ``col - 1`` on
        the top or left edge would wrap around via negative indexing and
        compare unrelated cells.
    """
    height = len(rows)
    if not (1 <= row < height - 1 and 1 <= col < len(rows[row]) - 1):
        return False
    if rows[row][col] != "A":
        return False

    ends = {"M", "S"}
    # Main diagonal: top-left and bottom-right corners.
    if {rows[row - 1][col - 1], rows[row + 1][col + 1]} != ends:
        return False
    # Anti-diagonal: top-right and bottom-left corners.
    return {rows[row - 1][col + 1], rows[row + 1][col - 1]} == ends

def count_x_mas(rows: list[str]) -> int:
    """Count the X-MAS patterns in a character grid.

    Every cell whose 3x3 neighbourhood fits inside the grid is tested with
    ``is_x_mas``. Width is taken from the first row.

    Args:
        rows: the grid, one string per row.

    Returns:
        The number of X-MAS centres. Grids smaller than 3x3, including the
        empty grid, contain none.
    """
    # Without this guard, ``rows[0]`` raises IndexError on an empty grid.
    if not rows:
        return 0

    height = len(rows)
    width = len(rows[0])

    # Edge cells cannot be the centre of a 3x3 X, so only interior cells are tested.
    return sum(
        is_x_mas(rows, row, col)
        for row in range(1, height - 1)
        for col in range(1, width - 1)
    )


In [None]:
part2_answer = count_x_mas(rows)
part2_answer

## Part 2 - bonus: vectorized NumPy version

In [None]:
import numpy as np

In [None]:
# One single-character string per cell; for a rectangular grid this is 2-D.
input_arr = np.array(list(map(list, rows)))

# 3x3 X-MAS templates. "" marks a "don't care" cell: it never equals a real
# grid character, so it never contributes to the match count below.
one = np.array([
    ["M", "", ""],
    ["", "A", ""],
    ["", "", "S"]
])

two = np.array([
    ["", "", "M"],
    ["", "", ""],
    ["S", "", ""]
])


def _overlay(top: np.ndarray, bottom: np.ndarray) -> np.ndarray:
    """Merge two templates, taking ``top``'s cell wherever it is non-empty.

    Replaces ``top + bottom``: elementwise ``+`` on ``<U`` string arrays only
    works on NumPy >= 2.0 (it raises UFuncTypeError on 1.x), while
    ``np.where`` works on every version.
    """
    return np.where(top != "", top, bottom)


# np.flip with no axis reverses both axes, swapping the M/S ends of a
# diagonal while the centre "A" stays put; this yields all four orientations.
possibilities = np.array([
    _overlay(one, two),
    _overlay(one, np.flip(two)),
    _overlay(np.flip(one), two),
    _overlay(np.flip(one), np.flip(two)),
])

# Number of non-"don't care" cells a window must match (5: A plus 4 corners).
EXPECTED_OVERLAPPING = (possibilities[0] != "").sum()

In [None]:
# Every 3x3 window of the grid, stacked as (n_windows, 3, 3).
windows = np.lib.stride_tricks.sliding_window_view(input_arr, (3, 3)).reshape(-1, 3, 3)
# (n_windows, 1, 3, 3) vs (1, 4, 3, 3): broadcasting compares each window with each template.
windows_expanded = np.expand_dims(windows, axis=1)
possibilities_expanded = np.expand_dims(possibilities, axis=0)

matching_cells = (windows_expanded == possibilities_expanded).sum(axis=(2, 3))
int(np.count_nonzero(matching_cells == EXPECTED_OVERLAPPING))