In [25]:
from utils import read_lines
from collections import defaultdict
import numpy as np

# There are 8 orientations in total:
# not flipped: rotated 0, 90, 180, 270 degrees
# flipped:     rotated 0, 90, 180, 270 degrees
def parse_input(input_file):
    """Load the enhancement rules from *input_file*.

    Every line is expected to look like ``pattern => replacement``; the
    returned dict maps each pattern string to its replacement string.
    """
    mapping = {}
    for entry in read_lines(input_file):
        pattern, replacement = entry.split(' => ')
        mapping[pattern] = replacement
    return mapping

def flip_left_right(matrix):
    """Return a new grid whose rows are the rows of *matrix* reversed.

    This mirrors the grid across its vertical axis; each row keeps its type
    (a list stays a list, a string stays a string).
    """
    return [row[::-1] for row in matrix]

def flip_top_down(matrix):
    """Return a new n x n grid with the row order of *matrix* reversed.

    Only the first ``len(matrix)`` cells of each row are read, so *matrix*
    is expected to be square; the result always consists of fresh lists.
    """
    size = len(matrix)
    return [[source_row[col] for col in range(size)] for source_row in reversed(matrix)]

def rotate_clockwise_90(matrix):
    """Return a new n x n grid equal to *matrix* rotated 90 degrees clockwise.

    Row ``r`` of the result is column ``r`` of the input read from the
    bottom row up to the top row.
    """
    size = len(matrix)
    last = size - 1
    return [[matrix[last - k][col] for k in range(size)] for col in range(size)]

def print_matrix(matrix):
    """Write *matrix* to stdout, one row per line with its cells concatenated."""
    for line in map(''.join, matrix):
        print(line)

def rule_to_matrix(rule):
    """Turn a rule string such as ``'.#/#.'`` into a list of character lists.

    Rows are separated by ``'/'``; each row becomes a list of its characters.
    """
    return [list(segment) for segment in rule.split('/')]

def matrix_to_rule(matrix):
    """Serialise *matrix* into rule-string form: each row's cells concatenated, rows joined by ``'/'``."""
    row_strings = map(''.join, matrix)
    return '/'.join(row_strings)

def all_oritentions(matrix):
    """Return the 8 orientations of the square grid *matrix*.

    The first four entries are a shallow copy of *matrix* rotated clockwise
    by 0, 90, 180 and 270 degrees; the last four are the top-down flip of
    *matrix* rotated by the same amounts. Duplicates (for symmetric grids)
    are kept. The misspelled name is kept unchanged because ``enhance``
    calls it by this name.
    """
    def _spin(grid):
        # grid followed by its three successive clockwise quarter turns
        quarter_turns = [grid]
        while len(quarter_turns) < 4:
            quarter_turns.append(rotate_clockwise_90(quarter_turns[-1]))
        return quarter_turns

    return _spin(matrix.copy()) + _spin(flip_top_down(matrix))

def split_matrix(matrix):
    """Cut the square grid *matrix* into tiles listed in row-major order.

    The tile edge is 2 when ``len(matrix)`` is even and 3 otherwise. Tiles
    are ordered left to right, then top to bottom, and each one is a fresh
    list of lists.
    """
    size = len(matrix)
    tile = 2 if size % 2 == 0 else 3
    corners = [(top, left)
               for top in range(0, size, tile)
               for left in range(0, size, tile)]
    return [
        [[matrix[top + dr][left + dc] for dc in range(tile)] for dr in range(tile)]
        for top, left in corners
    ]

def combine_matrix(subs):
    """Stitch equally sized square tiles back into one grid.

    Inverse of ``split_matrix``: *subs* holds k*k tiles in row-major order
    (left to right, then top to bottom), each an m x m sequence of rows.

    Returns:
        A (k*m) x (k*m) grid as a list of lists, or ``[]`` when *subs* is empty.

    Raises:
        ValueError: if the number of tiles is not a perfect square.
    """
    count = len(subs)
    if count == 0:
        return []
    # Round the float square root instead of truncating it, and reject counts
    # that are not perfect squares rather than silently stitching a malformed
    # grid (e.g. 3 tiles would otherwise become a 3x1 column).
    per_row = int(round(count ** 0.5))
    if per_row * per_row != count:
        raise ValueError('expected a square number of tiles, got {}'.format(count))
    tile_size = len(subs[0])
    ans = []
    for first in range(0, count, per_row):
        band = subs[first:first + per_row]  # one horizontal strip of tiles
        for r in range(tile_size):
            row = []
            for tile in band:
                row += tile[r]
            ans.append(row)
    return ans
    

# Initial 3x3 grid, stored as one list of characters per row.
start = [list(row) for row in '.#./..#/###'.split('/')]

def _match_rule(sub, rules):
    """Return the replacement grid for tile *sub*, trying its orientations in order.

    Raises:
        KeyError: if no orientation of *sub* is a key of *rules*.
    """
    for ori in all_oritentions(sub):
        r = matrix_to_rule(ori)
        if r in rules:
            return rule_to_matrix(rules[r])
    # Dropping an unmatched tile would shift every following tile in
    # combine_matrix and corrupt the grid, so fail loudly instead.
    raise KeyError('no rule matches tile {}'.format(matrix_to_rule(sub)))


def enhance(matrix, rules, iterations):
    """Apply the enhancement rules to *matrix* ``iterations`` times.

    Each round splits the grid into 2x2 or 3x3 tiles (``split_matrix``),
    replaces every tile with the output of the first rule matching one of
    its orientations (tried in ``all_oritentions`` order), and stitches the
    outputs back together (``combine_matrix``).

    Args:
        matrix: square grid given as a list of rows of characters.
        rules: mapping from pattern strings (e.g. ``'../.#'``) to
            replacement strings, as returned by ``parse_input``.
        iterations: number of enhancement rounds.

    Returns:
        The enhanced grid (the input object itself when ``iterations`` is 0).

    Raises:
        KeyError: if some tile matches no rule in any orientation.
    """
    # The number of distinct tiles is small (at most 2**4 + 2**9), while the
    # number of tiles grows quickly with each round, so cache each tile's
    # replacement by its rule string. combine_matrix copies cell values into
    # new rows, so sharing a cached replacement between tiles is safe.
    replacements = {}
    for _ in range(iterations):
        enhanced = []
        for sub in split_matrix(matrix):
            key = matrix_to_rule(sub)
            if key not in replacements:
                replacements[key] = _match_rule(sub, rules)
            enhanced.append(replacements[key])
        matrix = combine_matrix(enhanced)
    return matrix

def count_pixels(matrix):
    """Return the number of ``'#'`` cells in *matrix*.

    Every row is scanned up to the width of the first row.
    """
    return sum(row[col] == '#'
               for row in matrix
               for col in range(len(matrix[0])))

def part1(input_file, iterations=5):
    """Enhance the module-level ``start`` grid and count its ``'#'`` cells.

    Rules are loaded from *input_file* via ``parse_input``; *iterations*
    rounds of ``enhance`` are applied before counting.
    """
    return count_pixels(enhance(start, parse_input(input_file), iterations))


In [26]:
# Run on the test rules file for 2 iterations.
part1('inputs/day21_test.txt', 2)

12

In [27]:
# Run on the real rules file for 5 iterations (part1's default count).
part1('inputs/day21.txt', 5)

147

In [28]:
# Reuses part1 with 18 iterations instead of 5 -- presumably the puzzle's second part.
part1('inputs/day21.txt', 18)

1936582