In [None]:
from pathlib import Path
import os
from itertools import combinations, product
import copy

In [None]:
# Puzzle input (swap in the commented line to run on the example input instead).
fp = os.path.join(Path().absolute(), "inputs", "input19.txt")
# fp = os.path.join(Path().absolute(), "inputs", "input19_test.txt")

# splitlines() drops the trailing newline; split("\n")[:-1] instead silently lost
# the last line whenever the file did not end with a newline.
with open(fp, "r") as f:
    data = f.read().splitlines()

In [None]:
# Raw input lines: workflow definitions, a blank separator line, then part ratings.
data

# Part 1

In [None]:
# The input is two blocks separated by the first blank line: workflows, then parts.
sep = data.index("")
workflows, parts = data[:sep], data[sep + 1:]

In [None]:
# Each part line is wrapped in braces and holds comma-separated "name=value" fields
# (names x, m, a, s); strip the braces and build one {name: int(value)} dict per line.
parts_list = [
    {
        attr: int(value)
        for attr, value in (field.split("=") for field in line[1:-1].split(","))
    }
    for line in parts
]

In [None]:
# Parsed parts: one dict per part mapping rating name to its int value.
parts_list

In [None]:
# Raw workflow lines, each of the form name{rule,rule,...,fallback}.
workflows

In [None]:
# Map each workflow name to its ordered list of [condition, destination] rules.
# A rule without ':' always fires, so it gets the condition string "True"
# (it is later passed to eval() in Part 1).
workflow_dict = {}
for line in workflows:
    name, body = line.split("{")
    rules = []
    for rule in body[:-1].split(","):  # body[:-1] drops the closing "}"
        if ":" not in rule:
            rules.append(["True", rule])
            continue
        condition, dest = rule.split(":")
        rules.append([condition, dest])
    workflow_dict[name] = rules

In [None]:
# Parsed workflows: name -> list of [condition, destination]; fallback rules carry "True".
workflow_dict

In [None]:
# Part 1: run every part through the workflows, starting at the first rule of "in",
# and collect the parts that end in "A".
accepted_parts = []

for part in parts_list:
    current_workflow_name = "in"
    instruction_idx = 0
    terminated = False

    while not terminated:
        condition, dest = workflow_dict[current_workflow_name][instruction_idx]
        # NOTE(review): `condition` is text from the input file fed to eval(). It is
        # evaluated against this part's ratings only (no builtins, no notebook globals),
        # but eval is not a sandbox -- only run this on trusted puzzle input.
        if eval(condition, {"__builtins__": {}}, part):
            if dest == "A":
                accepted_parts.append(part)
                terminated = True
            elif dest == "R":
                terminated = True
            else:
                current_workflow_name = dest
            instruction_idx = 0
        else:
            # Condition failed: fall through to the next rule of the same workflow.
            instruction_idx += 1

In [None]:
# Parts that reached the "A" (accept) destination in the simulation above.
accepted_parts

In [None]:
# Part 1 answer: the sum of every rating value over all accepted parts.
total = sum(value for part in accepted_parts for value in part.values())
total

# Part 2

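First we build a decision graph with one node per conditional workflow rule, plus the two terminal nodes `A` (accept) and `R` (reject). The accept node accepts the full box $[1, 4000]^4$, i.e. every combination with $1 \le x, m, a, s \le 4000$; the reject node accepts the empty set.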

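We then propagate the accept areas upwards, from the terminal nodes towards the root `in0`, using set operations on unions of boxes (Cartesian products of intervals). Because the graph is acyclic, there is always an unprocessed node whose two children have already been processed (proof by induction). For such a node with condition $C$, parts satisfying $C$ follow the true edge and all other parts follow the false edge, so its accept area is

$$A(\text{node}) = \big(A(\text{true child}) \cap C\big) \cup \big(A(\text{false child}) \cap C^{c}\big).$$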

Note that this approach, while slow, works for any DAG, not just a tree (for a tree we could instead use recursion starting from the root node).

In [None]:
class Node:
    """A node of the workflow decision graph.

    Conditional rules become nodes holding their condition string; "R" and "A"
    are terminal nodes without a condition. A non-terminal node has a true child
    (taken when its condition holds) and a false child (taken otherwise), and
    every node records which parents point at it through each kind of edge.
    `accept_area` starts as None and is filled in with the Area this node accepts.
    """

    # Rating names, in the order used by AreaPart.
    ratings = ["x", "m", "a", "s"]

    def __init__(self, name, condition=None):
        self.name = name
        self.condition = condition
        self.true_child, self.false_child = None, None
        self.true_parents, self.false_parents = [], []
        self.accept_area = None  # everything outside is rejected

    def add_child(self, node, edge):
        """Link `node` as this node's child via edge "T" (true) or "F" (false).

        Also registers `self` in the child's matching parent list.
        """
        assert isinstance(node, Node)
        assert edge in ("T", "F")

        if edge == "F":
            self.false_child = node
            node.false_parents.append(self)
            return

        self.true_child = node
        node.true_parents.append(self)

In [None]:
# Build the decision graph. Every non-fallback rule `<workflow><idx>` becomes a node
# testing its condition: the true edge goes to the rule's destination, the false edge
# to the next rule -- or, for the penultimate rule, straight to the fallback rule's
# destination. Jumping to workflow `w` lands on node `w0`.
reject_node = Node("R", None)
accept_node = Node("A", None)
all_nodes = {"R": reject_node, "A": accept_node}


def _entry_node_name(target):
    """Node name reached when jumping to `target` ("R"/"A" or a workflow name)."""
    return target if target in ("R", "A") else target + "0"


def _get_or_create_node(node_name):
    """Return all_nodes[node_name], first creating a condition-less Node if absent."""
    if node_name not in all_nodes:
        all_nodes[node_name] = Node(node_name)
    return all_nodes[node_name]


# NOTE(review): a workflow consisting of a single fallback rule creates no node in this
# loop, so a jump to it leaves a condition-less, child-less `<name>0` node behind.
for prefix, instructions in workflow_dict.items():
    penultimate_idx = len(instructions) - 2

    for instruction_idx, (cond, dest) in enumerate(instructions[:-1]):
        current_name = f"{prefix}{instruction_idx}"
        current_node = all_nodes.get(current_name)
        if current_node is None:
            current_node = Node(current_name, cond)
            all_nodes[current_name] = current_node
        elif current_node.condition is None:
            # Created earlier as another rule's child; fill in its condition now.
            assert "<" in cond or ">" in cond or "=" in cond
            current_node.condition = cond

        # True edge: the rule matched, so jump to its destination.
        current_node.add_child(_get_or_create_node(_entry_node_name(dest)), "T")

        # False edge: fall through to the next rule of the same workflow ...
        if instruction_idx != penultimate_idx:
            false_child_name = f"{prefix}{instruction_idx + 1}"
        else:
            # ... unless the next rule is the unconditional fallback: jump to its target.
            fallback_cond, fallback_dest = instructions[instruction_idx + 1]
            assert fallback_cond == "True"
            false_child_name = _entry_node_name(fallback_dest)
        current_node.add_child(_get_or_create_node(false_child_name), "F")

In [None]:
# Inspect the graph: each node's name and the condition it tests (None for "R"/"A").
for node_name, graph_node in all_nodes.items():
    print(node_name, graph_node.condition)

In [None]:
class EmptySet:
    """Sentinel returned by `intersect` when two intervals or boxes do not overlap."""

class Interval:
    """Closed integer interval [left, right] inside the rating range 1..4000."""

    def __init__(self, left, right):
        assert 1 <= left <= right <= 4000
        self.left = left
        self.right = right

    def compute_length(self):
        """Number of integers contained in the interval."""
        return (self.right + 1) - self.left

    def intersect(self, interval):
        """Overlap with `interval`: an Interval, or EmptySet() when they are disjoint."""
        assert isinstance(interval, Interval)
        lo = max(self.left, interval.left)
        hi = min(self.right, interval.right)
        return Interval(lo, hi) if lo <= hi else EmptySet()

    def complement(self):
        """The rest of 1..4000 as an IntervalCollection of zero, one or two Intervals."""
        pieces = []
        if self.left >= 2:
            pieces.append(Interval(1, self.left - 1))
        if self.right <= 3999:
            pieces.append(Interval(self.right + 1, 4000))
        return IntervalCollection(pieces)

    def union(self, interval):
        """Merge with `interval`.

        Returns a single Interval when the two overlap or are adjacent; otherwise
        an IntervalCollection holding both (self first).
        """
        assert isinstance(interval, Interval)
        overlapping = not isinstance(self.intersect(interval), EmptySet)
        adjacent = self.right + 1 == interval.left or interval.right + 1 == self.left
        if not (overlapping or adjacent):
            return IntervalCollection([self, interval])
        return Interval(min(self.left, interval.left), max(self.right, interval.right))
    
class IntervalCollection:
    """A list of Intervals meant to be pairwise disjoint (disjointness is not checked)."""

    def __init__(self, intervals):
        assert isinstance(intervals, list)
        assert all(isinstance(iv, Interval) for iv in intervals)
        self.intervals = intervals
    
class AreaPart:
    """Axis-aligned box: the Cartesian product of one Interval per rating (x, m, a, s)."""

    def __init__(self, x_interval, m_interval, a_interval, s_interval):
        assert isinstance(x_interval, Interval)
        assert isinstance(m_interval, Interval)
        assert isinstance(a_interval, Interval)
        assert isinstance(s_interval, Interval)

        # Keyed by rating name; the positional order matches Node.ratings.
        self.d = {"x": x_interval, "m": m_interval, "a": a_interval, "s": s_interval}

    def compute_size(self):
        """Number of integer rating combinations inside the box."""
        prod = 1
        for rating in Node.ratings:
            interval_length = self.d[rating].compute_length()
            assert interval_length >= 1
            prod *= interval_length
        return prod

    def intersect(self, area_part):
        """Box overlap with `area_part`, or EmptySet() if it is empty in any rating."""
        assert isinstance(area_part, AreaPart)
        intersections = [self.d[rating].intersect(area_part.d[rating]) for rating in Node.ratings]
        if any(isinstance(iv, EmptySet) for iv in intersections):
            return EmptySet()
        return AreaPart(*intersections)

    def complement(self):
        """Everything in [1, 4000]^4 outside this box, as an Area of disjoint boxes.

        The complement is split by the first rating that lies outside the box: for
        the k-th rating take the box's own intervals for earlier ratings, the
        complement of the box's interval for rating k, and the full range 1..4000
        for later ratings. These pieces are pairwise disjoint and cover the
        complement exactly.
        """
        area_parts = []
        for idx_to_negate in range(len(Node.ratings)):
            choices = []
            for i, rating in enumerate(Node.ratings):
                if i < idx_to_negate:
                    choices.append([self.d[rating]])
                elif i == idx_to_negate:
                    choices.append(self.d[rating].complement().intervals)
                else:
                    choices.append([Interval(1, 4000)])
            # product() over a list containing an empty list yields nothing, so a
            # rating whose interval already spans 1..4000 contributes no pieces.
            area_parts.extend(AreaPart(*seq) for seq in product(*choices))

        # Build one Area directly instead of an intermediate Area per negated rating
        # (each of which ran its own compactification pass).
        return Area(area_parts)
    
class Area:
    """Disjoint union of AreaParts (boxes); the parts are assumed not to overlap."""

    def __init__(self, area_parts, verbose=False):
        """Wrap `area_parts` and merge neighbouring boxes via `compactify`.

        Parameters
        ----------
        area_parts : list of AreaPart
            Pairwise-disjoint boxes (disjointness is assumed, not checked).
        verbose : bool, optional
            If True, print progress messages during construction and `union`.
        """
        assert isinstance(area_parts, list)
        for area_part in area_parts:
            assert isinstance(area_part, AreaPart)

        self.area_parts = area_parts
        self.verbose = verbose

        if verbose:
            print(f"Start of compactification, {len(area_parts) = }")
        self.compactify()
        if verbose:
            print("End of compactification")

    def compute_size(self):
        """Number of integer points covered. Assumes that area parts are not overlapping."""
        total = 0
        for area_part in self.area_parts:
            total += area_part.compute_size()
        return total

    def intersect(self, area):
        """Pairwise box intersections with `area`, as a new Area."""
        assert isinstance(area, Area)

        area_parts_intersection = []
        for area_part in self.area_parts:
            for area_part_other_area in area.area_parts:
                intersection = area_part.intersect(area_part_other_area)
                if not isinstance(intersection, EmptySet):
                    area_parts_intersection.append(intersection)

        return Area(area_parts_intersection)

    def complement(self):
        """Everything in [1, 4000]^4 not covered by this Area (always returns an Area).

        Computed as the intersection of the complements of all parts, starting from
        a freshly built full box (so it no longer depends on a global `S`).
        """
        full_box = AreaPart(Interval(1, 4000), Interval(1, 4000), Interval(1, 4000), Interval(1, 4000))
        res = Area([full_box])
        for area_part in self.area_parts:
            res = res.intersect(area_part.complement())
            if not res.area_parts:
                # Already empty; further intersections cannot add anything.
                break
        return res

    def union(self, area):
        """Union of self and `area` as a disjoint Area.

        Computed as self plus (area intersected with the complement of self), so the
        pieces taken from `area` never overlap `self`.
        """
        assert isinstance(area, Area)

        # Cheap cases: avoid computing the (expensive) complement of self.
        if not area.area_parts:
            return Area(list(self.area_parts))
        if not self.area_parts:
            return Area(list(area.area_parts))

        cp = self.complement()
        if self.verbose:
            print("Computed complement")

        tmp_list = [self]
        for other_area_part in area.area_parts:
            tmp_list.append(cp.intersect(Area([other_area_part])))
        if self.verbose:
            print("Computed tmp_list")

        # Disjoint union of areas.
        res = Area([area_part for tmp in tmp_list for area_part in tmp.area_parts])
        if self.verbose:
            print("Computed res")
        return res

    def compactify(self):
        """Merge boxes until no further merge applies.

        Two boxes are merged when they agree in three ratings and their intervals in
        the fourth overlap or touch. Each box takes part in at most one merge per
        pass; passes repeat (iteratively, not recursively) until a pass merges nothing.

        Raises
        ------
        ValueError
            If two boxes are identical (the parts are supposed to be disjoint).
        """
        while True:
            merged_indices = set()
            merged_parts = []

            for (i, area_part1), (j, area_part2) in combinations(enumerate(self.area_parts), 2):
                if i in merged_indices or j in merged_indices:
                    continue

                differing = [
                    rating for rating in Node.ratings
                    if (area_part1.d[rating].left, area_part1.d[rating].right)
                    != (area_part2.d[rating].left, area_part2.d[rating].right)
                ]
                if not differing:
                    raise ValueError("Area contains two identical area parts")
                if len(differing) != 1:
                    continue

                differing_rating = differing[0]
                interval_union = area_part1.d[differing_rating].union(area_part2.d[differing_rating])
                if isinstance(interval_union, Interval):
                    # Nothing in this notebook mutates an Interval after construction,
                    # so a shallow copy of the mapping is enough.
                    d_new = dict(area_part1.d)
                    d_new[differing_rating] = interval_union
                    merged_parts.append(AreaPart(x_interval=d_new["x"], m_interval=d_new["m"], a_interval=d_new["a"], s_interval=d_new["s"]))
                    merged_indices.update((i, j))

            self.area_parts = [area_part for idx, area_part in enumerate(self.area_parts) if idx not in merged_indices] + merged_parts
            if not merged_indices:
                return

In [None]:
# S: the whole search space [1, 4000]^4 as a single box; E: the empty area.
S = Area([AreaPart(*(Interval(1, 4000) for _ in Node.ratings))])
E = Area([])

In [None]:
# Sanity check for compactification: the x-slabs [3, 4000] and [1, 2] (other ratings
# full) touch along x and must merge into the single full box. Area.__init__ already
# compactifies, so the merge is visible right after construction.
A1 = Area([AreaPart(Interval(3, 4000), Interval(1, 4000), Interval(1, 4000), Interval(1, 4000)),
          AreaPart(Interval(1, 2), Interval(1, 4000), Interval(1, 4000), Interval(1, 4000))])
assert len(A1.area_parts) == 1
assert A1.compute_size() == 4000 ** 4
[(rating, A1.area_parts[0].d[rating].left, A1.area_parts[0].d[rating].right) for rating in Node.ratings]

In [None]:
# Terminal nodes: "R" accepts nothing (E), "A" accepts the whole space (S).
reject_node.accept_area, accept_node.accept_area = E, S

In [None]:
def convert_condition_to_accept_area(condition):
    """Convert a single rule condition into the Area of parts satisfying it.

    Parameters
    ----------
    condition : str
        A comparison ``<rating><op><value>`` such as ``"x<2000"``, where
        ``<rating>`` is one of ``Node.ratings`` and ``<op>`` is ``<``, ``>`` or ``=``
        (checked in that order).

    Returns
    -------
    Area
        One box restricting the named rating (clamped to 1..4000) with every other
        rating spanning the full range, or an empty Area when the condition cannot
        be satisfied inside 1..4000 (e.g. ``"x<1"``).

    Raises
    ------
    ValueError
        If no known comparison operator is present or the rating name is unknown.
    """
    low, high = 1, 4000  # rating range used throughout this notebook

    for op in ("<", ">", "="):
        if op in condition:
            break
    else:
        raise ValueError("Unknown comparison operator")

    rating, value = condition.split(op)
    value = int(value)
    if rating not in Node.ratings:
        raise ValueError(f"Unknown rating {rating!r} in condition {condition!r}")

    if op == "<":
        left, right = low, value - 1
    elif op == ">":
        left, right = value + 1, high
    else:
        left, right = value, value

    # Clamp to the valid range; Interval asserts its bounds lie within [1, 4000].
    left, right = max(left, low), min(right, high)
    if left > right:
        return Area([])

    intervals = {r: Interval(low, high) for r in Node.ratings}
    intervals[rating] = Interval(left, right)
    accept_area = Area([AreaPart(x_interval=intervals["x"], m_interval=intervals["m"], a_interval=intervals["a"], s_interval=intervals["s"])])
    return accept_area

In [None]:
# Total node count (terminal "R"/"A" plus every node created while building the graph);
# used below to check that every node ends up processed.
num_nodes = len(all_nodes)
num_nodes

In [None]:
# Bottom-up pass over the DAG (Kahn's algorithm on reversed edges). A node becomes ready
# once both of its children have an accept area; finishing a node may make its parents
# ready. The terminal "R"/"A" nodes start out processed.
processed_nodes = [reject_node, accept_node]
processed = set(processed_nodes)

# For every other node: how many of its two outgoing edges still point at an unprocessed
# child. A missing child (None) is counted and never resolved, so such a node never runs.
pending_children = {
    node: sum(child not in processed for child in (node.true_child, node.false_child))
    for node in all_nodes.values()
    if node not in processed
}
ready = [node for node, count in pending_children.items() if count == 0]

while ready:
    current_node = ready.pop()

    condition_area = convert_condition_to_accept_area(current_node.condition)
    # Parts following the true edge must satisfy the condition, parts following the
    # false edge must violate it; the two pieces are therefore disjoint.
    true_branch_area = current_node.true_child.accept_area.intersect(condition_area)
    false_branch_area = current_node.false_child.accept_area.intersect(condition_area.complement())
    current_node.accept_area = true_branch_area.union(false_branch_area)

    processed_nodes.append(current_node)
    processed.add(current_node)

    # One decrement per incoming edge (a parent may point here through both edges).
    for parent in current_node.true_parents + current_node.false_parents:
        pending_children[parent] -= 1
        if pending_children[parent] == 0:
            ready.append(parent)

if len(processed_nodes) < num_nodes:
    stuck = sorted(node.name for node in all_nodes.values() if node not in processed)
    raise RuntimeError(f"Could not process {len(stuck)} node(s) (cycle or missing children): {stuck[:10]}")

print(f"Processed {len(processed_nodes)} of {num_nodes} nodes")

In [None]:
# compute number of elements of root node accept area
num_combs = all_nodes["in0"].accept_area.compute_size()
num_combs