In [None]:
from pathlib import Path
import os
import copy
import math

In [None]:
# Puzzle input: one module definition per line, e.g. "%a -> b, c".
fp = Path.cwd() / "inputs" / "input20.txt"
# fp = Path.cwd() / "inputs" / "input20_test.txt"
# fp = Path.cwd() / "inputs" / "input20_test2.txt"

with open(fp, "r") as f:
    # splitlines() handles both "\n" and "\r\n" endings; filtering blank lines drops the
    # trailing newline without discarding the last definition when the file lacks one.
    data = [line.strip() for line in f.read().splitlines() if line.strip()]

In [None]:
# Raw puzzle lines, one module definition ("<prefix><name> -> <destinations>") per line.
data

# Part 1

In [None]:
class Node:
    """A module in the pulse network.

    Plain ``Node`` instances serve as the button and as untyped sinks (destinations
    such as ``output`` that have no definition line of their own); they never emit
    a pulse.
    """

    def __init__(self, name):
        self.name = name
        self.inputs = []   # modules that send pulses to this one
        self.outputs = []  # modules this one sends pulses to, in wiring order
        self.state = None  # subclasses keep their memory here

    def __repr__(self):
        # Readable display in notebook outputs, e.g. a list of a module's outputs.
        return f"{type(self).__name__}({self.name!r})"

    def add_child(self, node):
        """Append ``node`` to this module's destinations."""
        assert isinstance(node, Node)
        self.outputs.append(node)

    def add_parent(self, node):
        """Record ``node`` as a module that sends pulses to this one."""
        assert isinstance(node, Node)
        self.inputs.append(node)

    def process_pulse(self, input_node, type):
        """Handle a pulse of ``type`` ("low" or "high") sent by ``input_node``.

        Returns the pulse type to send to every output, or ``None`` to send nothing.
        The base module never sends anything.
        """
        return None


class Broadcaster(Node):
    """Forwards every incoming pulse, unchanged, to all of its outputs."""

    def process_pulse(self, input_node, type):
        return type


class FlipFlop(Node):
    """Toggles between "off" and "on" on each low pulse; ignores high pulses.

    Turning on sends a high pulse, turning off sends a low pulse.
    """

    def __init__(self, name):
        super().__init__(name)
        self.state = "off"

    def process_pulse(self, input_node, type):
        if type != "low":
            return None  # high pulses are ignored entirely
        if self.state == "off":
            self.state = "on"
            return "high"
        if self.state == "on":
            self.state = "off"
            return "low"
        return None


class Conjunction(Node):
    """Remembers the most recent pulse from each input, initially "low".

    Sends a low pulse when every remembered pulse is high, otherwise a high pulse.
    """

    def __init__(self, name):
        super().__init__(name)
        # input module name -> type of the last pulse received from it
        self.state = {}

    def add_parent(self, node):
        super().add_parent(node)
        # A newly wired input starts out remembered as low, so the module is usable as
        # soon as the graph is built, without a separate initialization pass.
        self.state[node.name] = "low"

    def process_pulse(self, input_node, type):
        self.state[input_node.name] = type
        return "low" if all(pulse == "high" for pulse in self.state.values()) else "high"


class Pulse:
    """A single pulse of ``pulse_type`` travelling from ``from_node`` to ``to_node``."""

    def __init__(self, from_node, pulse_type, to_node):
        self.from_node = from_node
        self.pulse_type = pulse_type
        self.to_node = to_node

    def __str__(self):
        return f"{self.from_node.name} -> {self.to_node.name} ({self.pulse_type})"

In [None]:
# Module prefixes in the puzzle input; "broadcaster" is the only unprefixed definition.
MODULE_TYPES = {"%": FlipFlop, "&": Conjunction}


def _parse_module(line):
    """Split a definition like ``"%a -> b, c"`` into (node class, name, destination names).

    Raises ValueError for a definition that is not a flip-flop, a conjunction or the
    broadcaster.
    """
    spec, destinations = line.split(" -> ")
    if spec == "broadcaster":
        node_cls, name = Broadcaster, spec
    elif spec[:1] in MODULE_TYPES:
        node_cls, name = MODULE_TYPES[spec[0]], spec[1:]
    else:
        raise ValueError(f"unrecognised module definition: {line!r}")
    return node_cls, name, [dest.strip() for dest in destinations.split(",")]


module_specs = [_parse_module(line) for line in data]

# First pass: create every defined module, plus the button that starts each press.
all_nodes = {name: node_cls(name) for node_cls, name, _ in module_specs}
all_nodes["button"] = Node("button")

# Second pass: wire the graph; destinations without a definition become plain sinks.
for _, name, destinations in module_specs:
    source = all_nodes[name]
    for dest_name in destinations:
        if dest_name not in all_nodes:
            all_nodes[dest_name] = Node(dest_name)
        dest = all_nodes[dest_name]
        source.add_child(dest)
        dest.add_parent(source)

In [None]:
# Every conjunction starts out remembering a low pulse from each of its inputs.
for conj in (n for n in all_nodes.values() if isinstance(n, Conjunction)):
    conj.state = dict.fromkeys((parent.name for parent in conj.inputs), "low")

In [None]:
# Names of the modules the broadcaster forwards each button pulse to.
[node.name for node in all_nodes["broadcaster"].outputs]

In [None]:
# Start from the initial circuit state (flip-flops off, conjunctions remembering low
# from every input) so this cell gives the same answer when it is re-run.
for node in all_nodes.values():
    if isinstance(node, FlipFlop):
        node.state = "off"
    elif isinstance(node, Conjunction):
        node.state = {input_node.name: "low" for input_node in node.inputs}
    node.initial_state = copy.deepcopy(node.state)


def _state_key():
    """Return a hashable snapshot of every module's current state.

    Flip-flop states are strings; conjunction states are dicts, converted to tuples of
    items so the snapshot can be a dict key. Conjunction keys come from the wiring and
    are set before the first press, so their item order is stable and equal tuples
    mean equal states. Building a fresh tuple also avoids aliasing the live,
    in-place-mutated conjunction dicts.
    """
    return tuple(
        (name, tuple(node.state.items()) if isinstance(node.state, dict) else node.state)
        for name, node in all_nodes.items()
    )


# Circuit state after k presses -> k (the first time that state was seen).
first_seen_at = {_state_key(): 0}

num_low_pulses_per_cycle = []
num_high_pulses_per_cycle = []
period_found = False

for num_cycle in range(1000):
    if num_cycle % 10 == 0:
        print(num_cycle)

    pulse_queue = [Pulse(all_nodes["button"], "low", all_nodes["broadcaster"])]

    num_low_pulses = 0
    num_high_pulses = 0

    max_num_iter = 1000000  # safety valve against runaway propagation
    num_iter = 0
    while pulse_queue and num_iter < max_num_iter:
        pulse = pulse_queue.pop(0)  # FIFO: pulses are handled in the order they were sent
        if pulse.pulse_type == "low":
            num_low_pulses += 1
        else:
            num_high_pulses += 1

        # Every node implements process_pulse and sinks such as "output" return None,
        # so no special case is needed (and no value leaks from a previous pulse).
        next_pulse_type = pulse.to_node.process_pulse(pulse.from_node, pulse.pulse_type)
        if next_pulse_type is not None:
            for next_to_node in pulse.to_node.outputs:
                pulse_queue.append(Pulse(pulse.to_node, next_pulse_type, next_to_node))

        num_iter += 1

    num_low_pulses_per_cycle.append(num_low_pulses)
    num_high_pulses_per_cycle.append(num_high_pulses)

    # If the state after this press was seen before, presses repeat from that point on.
    key = _state_key()
    if key in first_seen_at:
        offset = first_seen_at[key]
        period = num_cycle + 1 - offset
        print(f"Period = {period}, offset = {offset}")
        period_found = True
        break
    first_seen_at[key] = num_cycle + 1

In [None]:
if not period_found:
    # No repeat within 1000 presses: every press was simulated explicitly.
    res = sum(num_low_pulses_per_cycle) * sum(num_high_pulses_per_cycle)
else:
    full_periods, remainder = divmod(1000 - offset, period)

    def _extrapolate(per_press):
        """Pulses over 1000 presses: pre-cycle prefix, whole periods, then a partial period."""
        prefix_total = sum(per_press[:offset])
        period_total = sum(per_press[offset:offset + period])
        partial_total = sum(per_press[offset:offset + remainder])
        return prefix_total + full_periods * period_total + partial_total

    res = _extrapolate(num_low_pulses_per_cycle) * _extrapolate(num_high_pulses_per_cycle)

print(res)

# Part 2

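In this puzzle input, the sole parent of rx is the conjunction module th, whose own parents are the four conjunction modules xn, qn, xf and zl. th sends rx a low pulse only when it remembers a high pulse from all four of them, so we find the number of button presses after which each of the four sends a high pulse. Assuming each of them fires regularly with that period from the first press, the answer is the lowest common multiple of those periods.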

In [None]:
# Restore every module to its initial state so press numbers count from the start of
# the circuit, not from wherever the Part 1 simulation left it.
for node in all_nodes.values():
    if isinstance(node, FlipFlop):
        node.state = "off"
    elif isinstance(node, Conjunction):
        node.state = {input_node.name: "low" for input_node in node.inputs}
    node.initial_state = copy.deepcopy(node.state)

# rx is fed by a single conjunction (unpacking raises ValueError otherwise); we track
# the modules feeding that conjunction instead of hard-coding their names.
(final_conjunction,) = all_nodes["rx"].inputs
feeder_names = {input_node.name for input_node in final_conjunction.inputs}

periods_final_conjunction_modules = {}

max_num_cycles = 20000
for num_cycle in range(max_num_cycles):
    if num_cycle % 1000 == 0:
        print(num_cycle)

    pulse_queue = [Pulse(all_nodes["button"], "low", all_nodes["broadcaster"])]

    max_num_iter = 1000000  # safety valve against runaway propagation
    num_iter = 0
    while pulse_queue and num_iter < max_num_iter:
        pulse = pulse_queue.pop(0)

        sender = pulse.from_node.name
        if (
            pulse.pulse_type == "high"
            and sender in feeder_names
            and sender not in periods_final_conjunction_modules
        ):
            # Assumes each feeder fires high with a fixed period starting from press 0,
            # so the (1-based) press of its first high pulse is its period.
            periods_final_conjunction_modules[sender] = num_cycle + 1
            print(f"{sender} -> {pulse.to_node.name} ({pulse.pulse_type}) at {num_cycle = }")

        # Sinks return None from process_pulse, so every destination is handled uniformly.
        next_pulse_type = pulse.to_node.process_pulse(pulse.from_node, pulse.pulse_type)
        if next_pulse_type is not None:
            for next_to_node in pulse.to_node.outputs:
                pulse_queue.append(Pulse(pulse.to_node, next_pulse_type, next_to_node))

        num_iter += 1

    # Every feeder's period is known; further presses add nothing.
    if len(periods_final_conjunction_modules) == len(feeder_names):
        break

In [None]:
# First press (1-based) on which each tracked module sent a high pulse.
periods_final_conjunction_modules

In [None]:
def lcm(a, b):
    """Return the least common multiple of two positive integers.

    Floor division keeps the arithmetic exact; true division would round through a
    float once a * b exceeds 2**53.
    """
    return a * b // math.gcd(a, b)

# Assuming each module fires with a fixed period from the first press, they all line up
# for the first time on the press equal to the LCM of their periods.
nums = list(periods_final_conjunction_modules.values())
if not nums:
    raise ValueError("no periods recorded; run the Part 2 simulation first")

lcm_res = nums[0]
for num in nums[1:]:
    lcm_res = lcm(lcm_res, num)

print(lcm_res)