In [24]:
from utils import read_lines

class Rule:
    """One grammar rule parsed from the puzzle input.

    A rule is either a literal (``fix``) or one or two alternative
    sequences of sub-rule ids (``rule1`` and, optionally, ``rule2``).
    """

    def __init__(self, rid, fix=None, rule1=None, rule2=None):
        # Fields are stored verbatim; no validation happens here.
        self.rid, self.fix = rid, fix
        self.rule1, self.rule2 = rule1, rule2

    def __repr__(self):
        return f'Rule({self.rid=}, {self.fix=}, {self.rule1=}, {self.rule2=})'
    

def parse_rule(line):
    """Parse one rule line such as ``0: 4 1 5``, ``1: 2 3 | 3 2`` or ``4: "a"``.

    Args:
        line: A single rule definition of the form ``<id>: <body>``.

    Returns:
        A ``Rule`` whose ``fix`` holds the quoted literal, or whose ``rule1``
        (and, for an alternation, ``rule2``) hold lists of sub-rule ids.

    Raises:
        ValueError: if the line is malformed, or if it has more than two
            alternatives (``Rule`` can only store two).
    """
    rid_text, body = line.split(': ', 1)
    rid = int(rid_text)
    body = body.strip()
    if '"' in body:
        # Literal rule: keep everything between the quotes, not just the
        # first character after the opening quote.
        return Rule(rid, fix=body.strip('"'))
    # split() with no argument tolerates repeated/irregular whitespace.
    alternatives = [[int(r) for r in alt.split()] for alt in body.split('|')]
    if len(alternatives) > 2:
        # Previously extra alternatives were silently dropped.
        raise ValueError(
            f'rule {rid} has {len(alternatives)} alternatives; at most 2 are supported')
    if len(alternatives) == 2:
        return Rule(rid, rule1=alternatives[0], rule2=alternatives[1])
    return Rule(rid, rule1=alternatives[0])

def parse_input(input_file):
    """Read a puzzle input file and split it into rules and messages.

    The rule section runs up to the first empty line. Every line after that
    blank separator is returned as a message.

    Args:
        input_file: Path handed straight to ``read_lines``.

    Returns:
        ``(rules, messages)``: ``rules`` maps rule id -> ``Rule``, and
        ``messages`` is the list of lines after the separator (empty if the
        file has no separator line).
    """
    # NOTE(review): assumes read_lines returns a list of lines with newlines
    # stripped, so the blank separator is '' — confirm in utils.
    lines = read_lines(input_file)
    rules = {}
    i = 0
    # Bounds check: without it a file lacking the blank separator raised IndexError.
    while i < len(lines) and lines[i]:
        rule = parse_rule(lines[i])
        rules[rule.rid] = rule
        i += 1
    return rules, lines[i + 1:]

def part1(input_file):
    """Count the messages that completely match rule 0.

    Every rule reachable from rule 0 is expanded into the finite set of
    strings it matches (memoized per rule id). Each message is then tested
    for membership in rule 0's set.

    Args:
        input_file: Path to the puzzle input (rules, blank line, messages).

    Returns:
        Number of messages that are exactly one of rule 0's strings.

    Raises:
        ValueError: if the rules reachable from rule 0 are recursive, because
            the set of matching strings would then be infinite.
    """
    rules, messages = parse_input(input_file)
    expansions = {}      # rule id -> set of strings that rule matches
    in_progress = set()  # rule ids on the current recursion path (cycle guard)

    def expand_sequence(sequence):
        # All concatenations that pick one string per sub-rule, in order.
        results = {''}
        for sub_rid in sequence:
            options = dfs(sub_rid)
            results = {prefix + option for prefix in results for option in options}
        return results

    def dfs(rid):
        if rid in expansions:
            return expansions[rid]
        if rid in in_progress:
            raise ValueError(f'rule {rid} is recursive; part1 needs a finite grammar')
        rule = rules[rid]
        if rule.fix is not None:
            expansions[rid] = {rule.fix}
            return expansions[rid]
        in_progress.add(rid)
        matched = expand_sequence(rule.rule1)
        if rule.rule2:
            matched |= expand_sequence(rule.rule2)
        in_progress.discard(rid)
        expansions[rid] = matched
        return matched

    valid0 = dfs(0)
    return sum(1 for m in messages if m in valid0)

In [25]:
part1(input_file='inputs/day19_test.txt')

2

In [26]:
part1(input_file='inputs/day19.txt')

129

In [50]:
def part2(input_file):
    """Count messages of the form 42^a 31^b with a > b >= 1.

    This is the language rule 0 matches once rules 8 and 11 become looping.
    NOTE(review): this assumes rule 0 is ``8 11`` as in the puzzle; that is
    not checked here.

    Only rules 42 and 31 are expanded into finite string sets. Expanding
    rule 0 would build the full product of those sets for nothing, and would
    recurse without bound if the file already contains the looping 8/11
    rules. Each message is then checked chunk by chunk.

    Args:
        input_file: Path to the puzzle input (rules, blank line, messages).

    Returns:
        Number of messages matching the looped grammar.

    Raises:
        ValueError: if rule 42 or 31 is recursive, or does not match
            non-empty strings of one single length.
    """
    rules, messages = parse_input(input_file)
    expansions = {}      # rule id -> set of strings that rule matches
    in_progress = set()  # rule ids on the current recursion path (cycle guard)

    def expand_sequence(sequence):
        # All concatenations that pick one string per sub-rule, in order.
        results = {''}
        for sub_rid in sequence:
            options = dfs(sub_rid)
            results = {prefix + option for prefix in results for option in options}
        return results

    def dfs(rid):
        if rid in expansions:
            return expansions[rid]
        if rid in in_progress:
            raise ValueError(f'rule {rid} is recursive; cannot expand it to a finite set')
        rule = rules[rid]
        if rule.fix is not None:
            expansions[rid] = {rule.fix}
            return expansions[rid]
        in_progress.add(rid)
        matched = expand_sequence(rule.rule1)
        if rule.rule2:
            matched |= expand_sequence(rule.rule2)
        in_progress.discard(rid)
        expansions[rid] = matched
        return matched

    def chunk_length(strings, rid):
        # The chunked matching below relies on every string of the rule
        # having the same, non-zero length.
        lengths = {len(s) for s in strings}
        if len(lengths) != 1 or 0 in lengths:
            raise ValueError(
                f'rule {rid} must match non-empty strings of a single length, '
                f'got lengths {sorted(lengths)}')
        return lengths.pop()

    valid42 = dfs(42)
    valid31 = dfs(31)
    len42 = chunk_length(valid42, 42)
    len31 = chunk_length(valid31, 31)

    def matches(message):
        # Try every split point: count42 leading chunks of rule 42, then the
        # remainder must be count31 whole chunks of rule 31 with
        # count42 > count31 >= 1. Unlike greedily consuming 42-chunks, this
        # stays correct even if some chunk matches both rule 42 and rule 31.
        pos = 0
        count42 = 0
        # A truncated chunk at the end can never be in valid42 (uniform lengths).
        while message[pos:pos + len42] in valid42:
            pos += len42
            count42 += 1
            rest = message[pos:]
            if not rest or len(rest) % len31:
                continue
            count31 = len(rest) // len31
            if count31 < count42 and all(
                    rest[k:k + len31] in valid31 for k in range(0, len(rest), len31)):
                return True
        return False

    return sum(1 for m in messages if matches(m))

In [51]:
part2(input_file='inputs/day19_test2.txt')

12

In [52]:
part2(input_file='inputs/day19.txt')

243