In [1]:
from pathlib import Path
import os
import copy
import math

In [2]:
fp = os.path.join(Path().absolute(), "inputs", "input20.txt")
# fp = os.path.join(Path().absolute(), "inputs", "input20_test.txt")
# fp = os.path.join(Path().absolute(), "inputs", "input20_test2.txt")

with open(fp, "r") as f:
    data = f.read().split("\n")[:-1]

In [3]:
data

['%cv -> xz',
 '%kt -> qx, rz',
 '%cb -> kt',
 '%pl -> sf, db',
 '%zd -> ln, gf',
 '%bf -> qx, pf',
 '%xz -> jd',
 '%xm -> db',
 '%vz -> cr, vc',
 '%qq -> qm, gf',
 '&xn -> th',
 '%nn -> ff, db',
 '%gx -> cd',
 '&qn -> th',
 '%qk -> vc',
 '&xf -> th',
 '%qj -> xm, db',
 '%fn -> pr, gf',
 '%sf -> bp',
 '%jd -> qx, vm',
 '%mc -> ds, db',
 '%tj -> lc, gf',
 '%jz -> qj, db',
 '%sb -> ks, vc',
 '%ln -> gf, qq',
 '%bx -> qx, qp',
 'broadcaster -> sr, ch, hd, bx',
 '%ch -> db, mc',
 '%ds -> cc',
 '&qx -> cb, cv, bx, xz, vm, zl',
 '%bp -> db, jz',
 '&zl -> th',
 '%vl -> gf, fj',
 '&db -> ff, ds, sf, ch, cc, xf',
 '&th -> rx',
 '%cr -> gx, vc',
 '%sr -> gf, vl',
 '%lr -> sb',
 '%hv -> lr',
 '%cl -> qx, bf',
 '%lc -> gf, fn',
 '%pm -> vc, qk',
 '%cc -> nn',
 '%gm -> tj, gf',
 '%vm -> cl',
 '%ff -> pl',
 '%qp -> cb, qx',
 '%pf -> qx',
 '&vc -> lr, hd, ks, qn, gx, nh, hv',
 '%qm -> gm',
 '%nh -> hv',
 '%rz -> qx, cv',
 '%ks -> vz',
 '%fj -> zd',
 '&gf -> fj, qm, xn, sr',
 '%pr -> gf',
 '%cd -> pm,

# Part 1

In [4]:
class Node:

    def __init__(self, name):
        self.name = name
        self.inputs = []
        self.outputs = []
        self.state = None

    def add_child(self, node):
        assert isinstance(node, Node)
        self.outputs.append(node)

    def add_parent(self, node):
        assert isinstance(node, Node)
        self.inputs.append(node)

    def process_pulse(self, input_node, type):
        return None

class Broadcaster(Node):

    def process_pulse(self, input_node, type):

        out = type
        return out

class FlipFlop(Node):

    def __init__(self, name):

        super().__init__(name)
        self.state = "off"

    def process_pulse(self, input_node, type):
        if type == "low":
            if self.state == "off":
                self.state = "on"
                out = "high"
            elif self.state == "on":
                self.state = "off"
                out = "low"
            else:
                out = None
        else:
            out = None

        return out

class Conjunction(Node):

    def process_pulse(self, input_node, type):
        self.state[input_node.name] = type

        if all(self.state[k] == "high" for k in self.state):
            out = "low"
        else:
            out = "high"

        return out

class Pulse:

    def __init__(self, from_node, pulse_type, to_node):

        self.from_node = from_node
        self.pulse_type = pulse_type
        self.to_node = to_node

    def __str__(self):

        return f"{self.from_node.name} -> {self.to_node.name} ({self.pulse_type})"

In [5]:
all_nodes = {}

for line in data:
    input, outputs = line.split(" -> ")
    if input[0] == "%":
        # flip-flop module
        name = input[1:]
        node = FlipFlop(name)
    elif input[0] == "&":
        # conjunction module
        name = input[1:]
        node = Conjunction(name)
    elif input == "broadcaster":
        name = "broadcaster"
        node = Broadcaster(name)
    else:
        raise ValueError

    all_nodes[name] = node

all_nodes["button"] = Node("button")

for line in data:
    input, outputs = line.split(" -> ")
    if input[0] == "%":
        # flip-flop module
        name = input[1:]
    elif input[0] == "&":
        # conjunction module
        name = input[1:]
    elif input == "broadcaster":
        name = "broadcaster"
    else:
        raise ValueError

    outputs = outputs.split(", ")
    if name not in all_nodes:
        raise ValueError
    
    for output in outputs:
        if output not in all_nodes:
            node = Node(output)
            all_nodes[output] = node
    

    for output in outputs:
        input_node = all_nodes[name]
        output_node = all_nodes[output]
        input_node.add_child(output_node)
        output_node.add_parent(input_node)

In [6]:
for node in all_nodes.values():
    if isinstance(node, Conjunction):
        node.state = {input_node.name: "low" for input_node in node.inputs}

In [7]:
all_nodes["broadcaster"].outputs

[<__main__.FlipFlop at 0x7fdbc9a6ab80>,
 <__main__.FlipFlop at 0x7fdbc9a6a820>,
 <__main__.FlipFlop at 0x7fdbc9a603a0>,
 <__main__.FlipFlop at 0x7fdbc9a6a760>]

In [8]:
for node in all_nodes.values():
    node.initial_state = copy.deepcopy(node.state)

state_history = [{name: node.state for name, node in all_nodes.items()}]

num_low_pulses_per_cycle = []
num_high_pulses_per_cycle = []

for num_cycle in range(1000):
    if num_cycle % 10 == 0:
        print(num_cycle)

    initial_pulse = Pulse(all_nodes["button"], "low", all_nodes["broadcaster"])
    pulse_queue = [initial_pulse]

    num_low_pulses = 0
    num_high_pulses = 0

    max_num_iter = 1000000
    num_iter = 0
    while len(pulse_queue) > 0 and num_iter < max_num_iter:
        # print(num_iter, [str(p) for p in pulse_queue])
        pulse = pulse_queue.pop(0)
        # print(pulse)
        if pulse.pulse_type == "low":
            num_low_pulses += 1
        else:
            num_high_pulses += 1

        if pulse.to_node.name != "output":
            next_pulse_type = pulse.to_node.process_pulse(pulse.from_node, pulse.pulse_type)

        if next_pulse_type is not None:
            for next_to_node in pulse.to_node.outputs:
                next_pulse = Pulse(pulse.to_node, next_pulse_type, next_to_node)
                pulse_queue.append(next_pulse)

        num_iter += 1
    
    num_low_pulses_per_cycle.append(num_low_pulses)
    num_high_pulses_per_cycle.append(num_high_pulses)

    period_found = False
    for i, hist in enumerate(state_history):
        if all(node.state == hist[name] for name, node in all_nodes.items()):
            period = num_cycle + 1 - i
            offset = i
            print(f"Period = {period}, offset = {offset}")
            period_found = True
            break
    
    if period_found:
        break
    else:
        state_history.append(copy.deepcopy({name: node.state for name, node in all_nodes.items()}))

0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230
240
250
260
270
280
290
300
310
320
330
340
350
360
370
380
390
400
410
420
430
440
450
460
470
480
490
500
510
520
530
540
550
560
570
580
590
600
610
620
630
640
650
660
670
680
690
700
710
720
730
740
750
760
770
780
790
800
810
820
830
840
850
860
870
880
890
900
910
920
930
940
950
960
970
980
990


In [11]:
if period_found:

    num_low_pulses_one_period = sum(num_low_pulses_per_cycle[offset:(offset + period)])
    num_high_pulses_one_period = sum(num_high_pulses_per_cycle[offset:(offset + period)])

    num_periods = (1000 - offset) // period
    num_additional_cycles = (1000 - offset) % period

    num_low_pulses_all_cycles = sum(num_low_pulses_per_cycle[:offset]) + num_low_pulses_one_period * num_periods + sum(num_low_pulses_per_cycle[offset:(offset + num_additional_cycles)])
    num_high_pulses_all_cycles = sum(num_high_pulses_per_cycle[:offset]) + num_high_pulses_one_period * num_periods + sum(num_high_pulses_per_cycle[offset:(offset + num_additional_cycles)])

    res = num_low_pulses_all_cycles * num_high_pulses_all_cycles

else:
    # no period
    res = sum(num_low_pulses_per_cycle) * sum(num_high_pulses_per_cycle)

print(res)

856482136


# Part 2

In [None]:
for node in all_nodes.values():
    node.initial_state = copy.deepcopy(node.state)

periods_final_conjunction_modules = {}

for num_cycle in range(20000):
    if num_cycle % 1000 == 0:
        print(num_cycle)

    initial_pulse = Pulse(all_nodes["button"], "low", all_nodes["broadcaster"])
    pulse_queue = [initial_pulse]

    max_num_iter = 1000000
    num_iter = 0
    while len(pulse_queue) > 0 and num_iter < max_num_iter:
        pulse = pulse_queue.pop(0)

        # if (pulse.pulse_type == "low" and pulse.to_node.name == "rx") or \
        if   (pulse.pulse_type == "high" and pulse.from_node.name == "xn") or \
                (pulse.pulse_type == "high" and pulse.from_node.name == "qn") or \
                    (pulse.pulse_type == "high" and pulse.from_node.name == "xf") or \
                        (pulse.pulse_type == "high" and pulse.from_node.name == "zl"):
            if pulse.from_node.name not in periods_final_conjunction_modules:
                periods_final_conjunction_modules[pulse.from_node.name] = num_cycle + 1
            
        # if (pulse.to_node.name == "rx") or \
        # if    (pulse.from_node.name == "xn") or \
        #         ( pulse.from_node.name == "qn") or \
        #             ( pulse.from_node.name == "xf") or \
        #                 (pulse.from_node.name == "zl"):
            print(f"{pulse.from_node.name} -> {pulse.to_node.name} ({pulse.pulse_type}) at {num_cycle = }")

        if pulse.to_node.name != "output":
            next_pulse_type = pulse.to_node.process_pulse(pulse.from_node, pulse.pulse_type)

        if next_pulse_type is not None:
            for next_to_node in pulse.to_node.outputs:
                next_pulse = Pulse(pulse.to_node, next_pulse_type, next_to_node)
                pulse_queue.append(next_pulse)

        num_iter += 1

In [17]:
periods_final_conjunction_modules

{'zl': 3739, 'qn': 3793, 'xf': 3923, 'xn': 4027}

In [18]:
def lcm(a, b):
    return int((a * b) / math.gcd(a, b))

nums = list(periods_final_conjunction_modules.values())
lcm_res = nums[0]
for i in range(1, 4):
    lcm_res = lcm(lcm_res, nums[i])

print(lcm_res)

224046542165867
