In [None]:
from aocd import get_data
import networkx as nx

In [None]:
# Real puzzle input for day 5, fetched with aocd's `get_data`.
# NOTE(review): no `year=` is passed, so presumably aocd infers the year; it also
# presumably needs an AoC session token / network or cache access — confirm.
raw_input_data = get_data(day=5)
# Small example in the format `process_data` parses: "X|Y" ordering-rule lines,
# a blank line, then one comma-separated update (list of page numbers) per line.
example_data = """47|53
97|13
97|61
97|47
75|29
61|13
75|53
29|13
97|29
53|29
61|53
97|53
61|29
47|13
75|47
97|75
47|61
75|61
47|29
75|13
53|13

75,47,61,53,29
97,61,53,29,13
75,29,13
75,97,47,61,53
61,13,29
97,13,75,29,47"""

In [None]:
def process_data(raw_data):
    """Parse the puzzle text into ordering rules and page updates.

    The text has two sections separated by a blank line: first one ``X|Y``
    ordering rule per line, then one comma-separated update per line.

    Args:
        raw_data: The raw puzzle text. Surrounding whitespace, a trailing
            newline, CRLF line endings and extra blank lines are tolerated.

    Returns:
        ``(ordering_rules, updates)`` where ``ordering_rules`` is a list of
        ``(before, after)`` int tuples and ``updates`` is a list of lists of
        int page numbers, both in input order.
    """
    # splitlines() handles both "\n" and "\r\n". Stripping first keeps a
    # trailing newline from producing an empty last line, because int("")
    # would raise.
    lines = [line.strip() for line in raw_data.strip().splitlines()]
    # The first blank line separates the rules section from the updates section.
    separator = lines.index("")

    ordering_rules = []
    for line in lines[:separator]:
        before, after = line.split("|")
        ordering_rules.append((int(before), int(after)))

    updates = [
        [int(page) for page in line.split(",")]
        for line in lines[separator + 1:]
        if line  # tolerate extra blank lines between/after updates
    ]
    return ordering_rules, updates

In [None]:
def check_rule(update_map: dict[int, int], rule: tuple[int, int]) -> bool:
    """Return whether a single ordering rule is satisfied by an update.

    Args:
        update_map: Mapping from page number to its index within the update.
        rule: ``(before, after)`` pair meaning ``before`` must not come after
            ``after``.

    Returns:
        False only when both pages are present and ``after`` sits at a lower
        index than ``before``; True otherwise. A rule mentioning a page that
        is not in the update does not apply and counts as satisfied.
    """
    before, after = rule[0], rule[1]
    if before in update_map and after in update_map:
        return update_map[before] <= update_map[after]
    return True

def check_all_rules(update, ordering_rules):
    """Return True when `update` satisfies every rule in `ordering_rules`.

    Each rule is checked with `check_rule`, and evaluation stops at the first
    violated rule.
    """
    # Page -> index lookup so every rule check is a constant-time dict access.
    position_of = dict(zip(update, range(len(update))))
    return all(check_rule(position_of, rule) for rule in ordering_rules)
    
def get_all_valid_updates(ordering_rules, updates):
    """Return, in input order, the updates that satisfy every ordering rule."""
    return [
        candidate
        for candidate in updates
        if check_all_rules(candidate, ordering_rules)
    ]

def get_invalid_updates(ordering_rules, updates):
    """Return, in input order, the updates that violate at least one rule.

    This is the complement of `get_all_valid_updates`.
    """
    # Filter directly on the rule check. Searching a list of valid updates
    # with `not in` costs O(n^2) list-of-list comparisons and checks every
    # update twice.
    return [
        update
        for update in updates
        if not check_all_rules(update, ordering_rules)
    ]


In [None]:
# Part 1: parse the real puzzle input and keep only the correctly-ordered updates.
(ordering_rules, updates) = process_data(raw_input_data)
valid_updates = get_all_valid_updates(ordering_rules=ordering_rules, updates=updates)

In [None]:
# Part 1 answer: sum of the page at index len // 2 (the middle page) of each valid update.
sum(update[len(update) // 2] for update in valid_updates)

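## Part 2

Re-order each incorrectly-ordered update so that it satisfies every applicable rule, then sum the middle page numbers of those re-ordered updates.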

In [None]:
# Sanity-check invalid-update detection on the worked example. Separate names
# keep the Part 1 globals `ordering_rules` / `updates` (built from the real
# input) from being overwritten with example data.
example_rules, example_updates = process_data(example_data)
invalid_updates = get_invalid_updates(example_rules, example_updates)
assert invalid_updates == [
    [75, 97, 47, 61, 53],
    [61, 13, 29],
    [97, 13, 75, 29, 47],
], invalid_updates

In [None]:
def reorder_graph(update, ordering_rules):
    """Return the pages of `update` in an order that satisfies the applicable rules.

    Builds a directed graph with one node per page of `update` and an edge
    ``before -> after`` for each rule whose two pages both appear in the
    update, then returns a topological ordering of that graph.

    Args:
        update: Sequence of page numbers.
        ordering_rules: Iterable of ``(before, after)`` page pairs.

    Returns:
        A list containing every distinct page of `update`, ordered so that
        each applicable rule holds.

    Raises:
        networkx.NetworkXUnfeasible: if the applicable rules form a cycle.
    """
    pages = set(update)  # O(1) membership when filtering rules
    graph = nx.DiGraph()
    # Add nodes in update order, so pages without any applicable rule still
    # appear and node insertion order is deterministic.
    graph.add_nodes_from(update)
    graph.add_edges_from(
        (rule[0], rule[1])
        for rule in ordering_rules
        if rule[0] in pages and rule[1] in pages
    )
    # topological_sort always yields every node. A longest path covers all
    # pages only when the rules totally order them. Callers index the result
    # by len(update) // 2, so it must contain every page. For a total order
    # both methods give the same unique sequence.
    return list(nx.topological_sort(graph))

def get_answer(data):
    """Return the Part 2 answer for the puzzle text `data`.

    Parses `data`, keeps only the updates that break at least one ordering
    rule, re-orders each of them with `reorder_graph`, and sums the page at
    index ``len(update) // 2`` of every re-ordered update.
    """
    rules, parsed_updates = process_data(data)
    return sum(
        reorder_graph(bad_update, rules)[len(bad_update) // 2]
        for bad_update in get_invalid_updates(rules, parsed_updates)
    )

In [None]:
# Part 2 answer for the real puzzle input.
get_answer(data=raw_input_data)