In [228]:
# Read the whole input file and split it at the '--' marker: the text before
# the marker holds the conditions, the text after it holds the sequences.
# The encoding is pinned so decoding does not depend on the platform locale.
# The two-name unpacking raises ValueError unless the file contains exactly
# one '--'.
with open('input.txt', 'r', encoding='utf-8') as file:
    conditions, sequences = file.read().split('--')

In [229]:
# Break both text chunks into lists of lines.
conditions = conditions.splitlines()
# The first line of the second chunk is dropped; presumably it is the empty
# remainder of the '--' separator line -- verify against input.txt.
sequences = sequences.splitlines()[1:]

In [230]:
# Turn each condition line into a tuple of its '|'-separated fields; later
# cells unpack every tuple as a pair.
conditions = [tuple(c.split('|')) for c in conditions]

In [231]:
# Split each sequence line on commas; the entries remain strings.
sequences = [seq.split(',') for seq in sequences]

In [232]:
def validate_sequence(sequence: list, rules=None) -> bool:
    """Return True if ``sequence`` violates none of the ordering rules.

    A rule ``(n1, n2)`` is violated when both values occur in ``sequence``
    and the first occurrence of ``n1`` comes after the first occurrence of
    ``n2``. Rules naming a value that is absent from ``sequence`` are ignored.

    Args:
        sequence: The values to check (must be hashable).
        rules: Iterable of ``(n1, n2)`` pairs. Defaults to the module-level
            ``conditions``.

    Returns:
        False as soon as one violated rule is found, otherwise True.
    """
    if rules is None:
        rules = conditions

    # Record each value's first position once so every rule is checked with
    # dictionary lookups instead of rescanning the list.
    first_position = {}
    for index, value in enumerate(sequence):
        first_position.setdefault(value, index)

    for n1, n2 in rules:
        if n1 not in first_position or n2 not in first_position:
            continue
        if first_position[n1] > first_position[n2]:
            return False
    return True

## Part 1

In [233]:
# Keep only the sequences that already satisfy every applicable condition.
valid_sequences = [seq for seq in sequences if validate_sequence(seq)]

In [234]:
def get_middle_val(sequence):
    """Return the element at index ``len(sequence) // 2`` converted to int.

    For an even-length sequence this is the later of the two central
    elements.
    """
    return int(sequence[len(sequence) // 2])

In [235]:
# Middle value of every valid sequence.
middle_page_nums = list(map(get_middle_val, valid_sequences))

In [236]:
# Part 1 result: total of the middle values of the valid sequences.
sum(middle_page_nums)

6034

## Part 2

In [237]:
# Collect the sequences that break at least one applicable condition.
invalid_sequences = [seq for seq in sequences if not validate_sequence(seq)]

In [238]:
from collections import defaultdict, deque

In [239]:
# Adjacency lists and in-degree counts built over ALL conditions.
# NOTE(review): no other cell visible here reads `graph` or `in_degree`;
# topological_sort builds its own per-sequence versions.
graph = defaultdict(list)
in_degree = defaultdict(int)

for a, b in conditions:
    graph[a].append(b)
    in_degree[b] += 1
    # Make sure the source value also has an entry, without touching an
    # existing count.
    in_degree.setdefault(a, 0)

In [240]:
def topological_sort(sequence, rules=None):
    """Return the values of ``sequence`` ordered to satisfy the rules.

    Only rules whose two values both occur in ``sequence`` are applied. The
    order is produced with Kahn's algorithm: values with no unmet
    predecessor are emitted first (in their original relative order), and a
    value is queued once all of its predecessors have been emitted.

    Args:
        sequence: The values to reorder (must be hashable).
        rules: Iterable of ``(before, after)`` pairs. Defaults to the
            module-level ``conditions``.

    Returns:
        A new list with the reordered values; ``sequence`` is not modified.

    Raises:
        ValueError: If the applicable rules contain a cycle, in which case
            no complete ordering exists.
    """
    if rules is None:
        rules = conditions

    # Set membership keeps the per-rule filter O(1) instead of scanning the
    # list for every rule.
    members = set(sequence)
    successors = defaultdict(list)
    unmet = defaultdict(int)

    for before, after in rules:
        if before in members and after in members:
            successors[before].append(after)
            unmet[after] += 1

    order = []
    queue = deque(node for node in sequence if unmet[node] == 0)

    while queue:
        node = queue.popleft()
        order.append(node)

        for neighbor in successors[node]:
            unmet[neighbor] -= 1
            if unmet[neighbor] == 0:
                queue.append(neighbor)

    # Values caught in a cycle never reach zero unmet predecessors and would
    # otherwise be dropped silently from the result.
    emitted = set(order)
    stuck = [node for node in sequence if node not in emitted]
    if stuck:
        raise ValueError(f"ordering rules form a cycle among {stuck!r}")

    return order

In [241]:
def reorder_sequences(sequences):
    """Replace every entry of ``sequences`` with its topologically sorted form.

    The list is modified in place: each slot is overwritten with the result
    of ``topological_sort`` for the entry it held. The same list object is
    returned.
    """
    for slot, unordered in enumerate(sequences):
        sequences[slot] = topological_sort(unordered)
    return sequences

In [242]:
# reorder_sequences overwrites the entries of the list it receives and returns
# that same list, so invalid_sequences also holds the reordered entries after
# this call.
reordered_sequences = reorder_sequences(invalid_sequences)

In [243]:
# Middle value of every reordered sequence. This rebinds middle_page_nums,
# replacing the Part 1 values.
middle_page_nums = [get_middle_val(seq) for seq in reordered_sequences]

In [245]:
# Part 2 result: total of the middle values of the reordered sequences.
sum(middle_page_nums)

6305