In [86]:
from utils import read_lines
from collections import defaultdict
import numpy as np

# A tile has 8 possible orientations in total:
#   not flipped: rotated 0, 90, 180 or 270 degrees;
#   flipped:     rotated 0, 90, 180 or 270 degrees.
def parse_input(input_file):
    """Parse the puzzle file into a dict mapping tile ID -> list of 10 row strings.

    The lines returned by ``read_lines`` are consumed in fixed strides of 12:
    a header line, 10 tile rows, then one separator line.  The tile ID is the
    header with its first 5 characters and last character removed
    (presumably ``Tile `` and ``:`` -- confirm against the input format).
    """
    lines = read_lines(input_file)
    tiles = {}
    start = 0
    while start < len(lines):
        tile_id = int(lines[start][5:-1])
        tiles[tile_id] = [lines[k] for k in range(start + 1, start + 11)]
        start += 12
    return tiles

def find_corners(input_file):
    """Index every tile edge and find the corner tiles.

    Returns a 3-tuple ``(corners, edge_to_tile_orient, tiles)``:

    * ``corners`` -- IDs (in input order) of tiles that share exactly two of
      their four edges with some *other* tile.
    * ``edge_to_tile_orient`` -- ``defaultdict(list)`` mapping an edge string
      to ``(tid, edge_no, flip)`` entries.  ``edge_no`` is 0 = top,
      1 = right, 2 = bottom, 3 = left.  ``flip`` is True when the key is that
      edge read in reverse.
    * ``tiles`` -- the ``parse_input`` result (tile ID -> row strings).

    The tile size is taken from the tile itself instead of being fixed at 10.
    """
    tiles = parse_input(input_file)
    edge_to_tile_orient = defaultdict(list)  # edge -> [(tid, edge_no, flip)]
    for tid, rows in tiles.items():
        top, right, bottom, left = get_matrix_edges(rows)
        # Register each edge both as read and reversed, because a neighbour
        # may only match after one of the two tiles is flipped.
        for edge_no, edge in ((0, top), (2, bottom), (1, right), (3, left)):
            edge_to_tile_orient[edge].append((tid, edge_no, False))
            edge_to_tile_orient[edge[::-1]].append((tid, edge_no, True))

    corners = []
    for tid, rows in tiles.items():
        # Only count matches with *other* tiles.  A palindromic edge lands in
        # its own bucket twice (forward and reversed), which would otherwise
        # look like a shared edge.
        shared = sum(
            1
            for edge in get_matrix_edges(rows)
            if any(other != tid for other, _, _ in edge_to_tile_orient[edge])
        )
        if shared == 2:
            corners.append(tid)
    return corners, edge_to_tile_orient, tiles

def part1(input_file):
    """Solve part 1: return ``(product of corner tile IDs, corner IDs)``.

    The product is computed with Python integers.  ``np.prod`` on a list of
    ints accumulates in the platform default integer, which is int32 on
    Windows with NumPy < 2.  There, four 4-digit IDs can silently overflow.
    An empty corner list yields 1 rather than ``np.prod``'s float 1.0.
    """
    corners, _, _ = find_corners(input_file)
    product = 1
    for tid in corners:
        product *= tid
    return product, corners

def flip_left_right(matrix):
    """Mirror *matrix* horizontally.

    Returns a new list whose rows are reversed slices of the input rows
    (string rows stay strings, list rows stay lists).
    """
    return [row[::-1] for row in matrix]

def flip_top_down(matrix):
    """Mirror *matrix* vertically (the last row becomes the first).

    Returns a new list of lists (rows are copied cell by cell, so string rows
    become lists of characters).  Rectangular matrices are supported.  The
    previous index-based version assumed a square matrix and truncated or
    raised on anything else.
    """
    return [list(row) for row in reversed(matrix)]

def rotate_clockwise_90(matrix):
    """Rotate *matrix* 90 degrees clockwise.

    For an R x C input the result is C x R with
    ``result[i][j] == matrix[R - 1 - j][i]``.  That is the same formula as
    before, now valid for rectangular matrices too.  Returns a new list of
    lists.
    """
    # Reading the rows bottom-up and zipping them yields the original
    # columns, each already ordered as a clockwise-rotated row.
    return [list(col) for col in zip(*reversed(matrix))]

def print_matrix(matrix):
    """Print *matrix* one row per line, with cells joined without separators."""
    for line in map(''.join, matrix):
        print(line)

def tile_to_matrix(tile):
    """Convert a tile given as row strings into a mutable list of character lists."""
    return [list(row) for row in tile]

def sea_monster():
    """Return the sea-monster pattern as ``(sharps, height, width)``.

    ``sharps`` is a list of ``[row, col]`` offsets of every ``#`` in the
    pattern.  ``height`` and ``width`` are the pattern's bounding-box size
    (3 x 20).

    The rows are written as separate string literals, and the width is the
    longest row.  The old triple-quoted literal depended on trailing spaces;
    if an editor stripped them, the width shrank and the scan raised
    IndexError or missed cells.
    """
    pattern = (
        "                  # ",
        "#    ##    ##    ###",
        " #  #  #  #  #  #   ",
    )
    height = len(pattern)
    width = max(len(row) for row in pattern)
    sharps = [
        [i, j]
        for i, row in enumerate(pattern)
        for j, ch in enumerate(row)
        if ch == '#'
    ]
    return sharps, height, width


def all_oritentions(matrix):
    """Return the 8 orientations of *matrix* as a list.

    Entries 0-3 are a shallow copy of *matrix* rotated clockwise by 0, 90,
    180 and 270 degrees.  Entries 4-7 are the top-down flip of *matrix*,
    rotated the same way.
    """
    def four_turns(start):
        # start, then three successive clockwise quarter turns of it
        turns = [start]
        while len(turns) < 4:
            turns.append(rotate_clockwise_90(turns[-1]))
        return turns

    return four_turns(matrix.copy()) + four_turns(flip_top_down(matrix))

def get_matrix_edges(matrix):
    """Return the four border strings of *matrix* as ``[top, right, bottom, left]``.

    Top and bottom read left-to-right.  Right and left read top-to-bottom.
    """
    top = ''.join(matrix[0])
    bottom = ''.join(matrix[-1])
    right = ''.join(row[-1] for row in matrix)
    left = ''.join(row[0] for row in matrix)
    return [top, right, bottom, left]

def puzzle_to_matrix(puzzle):
    """Stitch a grid of oriented tiles into one image, dropping tile borders.

    *puzzle* is a list of grid rows, each a list of tiles.  Each tile is a
    list of rows of characters.  The outermost row and column on every side
    of each tile are discarded, and the remaining interiors are concatenated
    in grid order.  The interior size comes from the tiles themselves instead
    of being hard-coded to 8 (i.e. 10x10 tiles).

    Returns a list of image rows, each a list of single characters.  An
    empty puzzle yields ``[]``.
    """
    image = []
    for tile_row in puzzle:
        # Interior (border stripped) of every tile in this grid row.
        interiors = [[row[1:-1] for row in tile[1:-1]] for tile in tile_row]
        # Walk the tiles' interior rows in lockstep to build full image rows.
        for pieces in zip(*interiors):
            line = []
            for piece in pieces:
                line.extend(piece)
            image.append(line)
    return image

def count_monster(matrix):
    """Return how many ``#`` cells of *matrix* are not part of any sea monster.

    Every orientation of *matrix* is scanned for the ``sea_monster()``
    pattern, and the orientation whose monsters cover the most ``#`` cells
    is used.

    Fixes over the previous version:
    * Window starts now run up to and including ``rows - height`` and
      ``cols - width``.  The old ``range(n - height)`` missed monsters
      touching the last row or column.
    * Covered cells are collected in a set, so cells shared by overlapping
      monsters are subtracted once rather than once per monster.
    * Rows and columns are measured per orientation, so non-square images
      also work.
    """
    total_sharps = sum(row.count('#') for row in matrix)
    sharps, height, width = sea_monster()
    best_covered = 0
    for ori in all_oritentions(matrix):
        rows = len(ori)
        cols = len(ori[0]) if rows else 0
        covered = set()
        for i in range(rows - height + 1):
            for j in range(cols - width + 1):
                if all(ori[i + di][j + dj] == '#' for di, dj in sharps):
                    covered.update((i + di, j + dj) for di, dj in sharps)
        best_covered = max(best_covered, len(covered))
    return total_sharps - best_covered
        

def _orient_tile(tile_rows, left, up, edge_to_tile_orient):
    """Return the first orientation of *tile_rows* that fits its neighbours, or None.

    *left* / *up* are the edge strings the tile must reproduce on its left /
    top side.  ``None`` means that side lies on the image border, in which
    case the tile's edge there must be unmatched (exactly one entry in
    *edge_to_tile_orient*).
    """
    for ori in all_oritentions(tile_to_matrix(tile_rows)):
        edges = get_matrix_edges(ori)
        if left is None:
            left_ok = len(edge_to_tile_orient[edges[3]]) == 1
        else:
            left_ok = left == edges[3]
        if not left_ok:
            continue
        if up is None:
            up_ok = len(edge_to_tile_orient[edges[0]]) == 1
        else:
            up_ok = up == edges[0]
        if up_ok:
            return ori
    return None


def _neighbor_tile(edge_to_tile_orient, edge, tid, error_msg):
    """Return the ID of the tile other than *tid* that shares *edge*.

    Raises ValueError(error_msg) when the edge is not shared by exactly two
    entries, or when both entries belong to *tid* itself (palindromic edge).
    """
    matches = edge_to_tile_orient[edge]
    if len(matches) == 2:
        for other_tid, _, _ in matches:
            if other_tid != tid:
                return other_tid
    raise ValueError(error_msg)


def part2(input_file):
    """Solve part 2: assemble the image and return its water roughness.

    Tiles are placed greedily in row-major order, starting from the first
    corner returned by ``find_corners``.  Each tile is oriented so that its
    left/top edges match the already-placed neighbours (or are unmatched on
    the image border).  The next tile IDs to the right and below are then
    read from the shared-edge index.  The bordered tiles are stitched
    together, and ``count_monster`` computes the result.

    Raises ValueError if any edge is shared by more than two tiles, no corner
    is found, the tile count is not a perfect square, or a tile cannot be
    oriented or linked to a neighbour.
    """
    corners, edge_to_tile_orient, tiles = find_corners(input_file)
    # Explicit checks instead of `assert`, so they survive `python -O`.
    if any(len(x) > 2 for x in edge_to_tile_orient.values()):
        raise ValueError('an edge is shared by more than two tiles')
    if not corners:
        raise ValueError('no corner tiles found')
    m = round(len(tiles) ** 0.5)
    if m * m != len(tiles):
        raise ValueError(f'{len(tiles)} tiles do not form a square image')

    puzzle_ids = [[0] * m for _ in range(m)]
    puzzle_ids[0][0] = corners[0]
    puzzle = [[None] * m for _ in range(m)]
    for i in range(m):
        for j in range(m):
            tid = puzzle_ids[i][j]
            up = get_matrix_edges(puzzle[i - 1][j])[2] if i > 0 else None
            left = get_matrix_edges(puzzle[i][j - 1])[1] if j > 0 else None
            ori = _orient_tile(tiles[tid], left, up, edge_to_tile_orient)
            if ori is None:
                raise ValueError(f'cannot find oritation in cell {i}, {j}')
            puzzle[i][j] = ori

            _, right, down, _ = get_matrix_edges(ori)
            # The right neighbour may already be known from the row above.
            if j < m - 1 and not puzzle_ids[i][j + 1]:
                puzzle_ids[i][j + 1] = _neighbor_tile(
                    edge_to_tile_orient, right, tid, f'no same edge right: {i} {j}')
            if i < m - 1:
                puzzle_ids[i + 1][j] = _neighbor_tile(
                    edge_to_tile_orient, down, tid, f'no same edge down: {i} {j}')

    puzzle_matrix = puzzle_to_matrix(puzzle)
    return count_monster(puzzle_matrix)

    
    

In [57]:
# Part 1 on the full input -> (product of corner IDs, corner IDs).
part1('inputs/day20.txt')

(4006801655873, [1327, 1087, 2753, 1009])

In [35]:
# Part 1 on the example input.
part1('inputs/day20_test.txt')

(20899048083289, [1951, 1171, 2971, 3079])

In [87]:
# Part 2 on the example input.
part2('inputs/day20_test.txt')

273

In [88]:
# Part 2 on the full input.
part2('inputs/day20.txt')

1838

In [5]:
# NOTE(review): `input` shadows the builtin; 144 tiles form a 12x12 grid.
input = parse_input('inputs/day20.txt')
len(input)

144