In [18]:
# Read the puzzle input: "X|Y" lines are ordering rules (X must precede Y),
# comma-separated lines are updates (lists of page numbers, kept as strings).
with open('input') as handle:
    lines = [raw.strip() for raw in handle]

rules: list[tuple[str, str]] = []
updates: list[list[str]] = []

for line in lines:
    if '|' in line:
        before, after = line.split('|')
        rules.append((before, after))
        continue
    if ',' in line:
        updates.append(line.split(','))

In [19]:
# Part 1

from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
    """One page in the ordering graph built for a single update.

    ``frozen=True`` only prevents rebinding the fields; the two sets are
    still mutated in place while the graph is built. Because the generated
    ``__hash__`` would hash the sets, instances are not usable as dict keys
    or set members.
    """

    name: str
    # Pages that must come before this one (``pre`` of rules "pre|name").
    prerequisites: set[str]
    # Pages that must come after this one (``post`` of rules "name|post").
    postrequisites: set[str]


def make_graph(
    available_nodes: set[str],
    rule_pairs: "list[tuple[str, str]] | None" = None,
) -> dict[str, Node]:
    """Build the precedence graph restricted to the pages of one update.

    Args:
        available_nodes: page names present in the update.
        rule_pairs: ``(before, after)`` ordering rules. Defaults to the
            notebook-level ``rules`` parsed from the input.

    Returns:
        A mapping from every name in ``available_nodes`` to its ``Node``.
        Only rules whose two pages are both available are applied; rules
        mentioning other pages impose no constraint on this update.
    """
    if rule_pairs is None:
        rule_pairs = rules  # global parsed in the input cell

    # Seed every page, including ones that no rule mentions: they are simply
    # unconstrained, and callers look up a node for every page of the update.
    nodes = {name: Node(name, set(), set()) for name in available_nodes}

    for pre, post in rule_pairs:
        if pre in nodes and post in nodes:
            nodes[pre].postrequisites.add(post)
            nodes[post].prerequisites.add(pre)

    return nodes

# Sum the middle page of every update that already respects all rules.
sum_ = 0

for update in updates:
    nodes = make_graph(set(update))

    # Pages already placed earlier in this update.
    seen: set[str] = set()
    for page in update:
        # A page with no applicable rule may be missing from `nodes`; it has
        # no prerequisites, so it can never be out of order.
        node = nodes.get(page)
        if node is not None and not node.prerequisites <= seen:
            break  # some prerequisite of `page` comes later: misordered
        seen.add(page)
    else:
        # Loop finished without `break`: the update is correctly ordered.
        sum_ += int(update[len(update) // 2])

print(sum_)
    

4959


In [20]:
# Part 2

# Kahn's algorithm
def topsort(graph: dict[str, Node]) -> list[str]:
    """Return the names in ``graph`` ordered so every prerequisite comes first.

    Works on a private in-degree table instead of deleting entries from each
    node's ``prerequisites`` set, so ``graph`` is left unchanged and can be
    inspected or sorted again afterwards.

    Args:
        graph: mapping from name to ``Node``; every name listed in a node's
            ``postrequisites`` must be a key of ``graph``.

    Returns:
        All names in ``graph`` in a topological order. When several orders
        are valid, which one is returned is unspecified.

    Raises:
        ValueError: if some node can never become ready (a cycle, or a
            prerequisite that is not in ``graph``), so no complete order exists.
    """
    # Number of not-yet-emitted prerequisites per node.
    pending = {name: len(node.prerequisites) for name, node in graph.items()}
    ready = [name for name, count in pending.items() if count == 0]
    order: list[str] = []

    while ready:
        current = ready.pop()
        order.append(current)
        for successor in graph[current].postrequisites:
            pending[successor] -= 1
            if pending[successor] == 0:
                ready.append(successor)

    # Anything left unemitted is stuck behind a cycle; a partial order would
    # silently produce a wrong answer, so fail loudly instead.
    if len(order) != len(graph):
        raise ValueError("graph contains a cycle; no complete topological order exists")
    return order

# Sum the middle page of every misordered update after reordering it by the rules.
sum_ = 0

for update in updates:
    nodes = make_graph(set(update))

    # Pages already placed earlier in this update.
    seen: set[str] = set()
    for page in update:
        # Pages with no applicable rule may be missing from `nodes`; they are
        # unconstrained and never make the update misordered.
        node = nodes.get(page)
        if node is not None and not node.prerequisites <= seen:
            rearranged = topsort(nodes)
            # Keep every page even if `nodes` lacks some unconstrained ones,
            # so the middle index is taken over the full update length.
            rearranged += [p for p in update if p not in nodes]
            sum_ += int(rearranged[len(rearranged) // 2])
            break
        seen.add(page)

print(sum_)

4655
