In [None]:
from queue import Queue
from dataclasses import dataclass
import networkx as nx
import matplotlib.pyplot as plt

In [None]:
# Read the puzzle input: one module definition per line, surrounding whitespace removed.
with open("input.txt") as f:
    lines = [raw.strip() for raw in f.readlines()]
lines[:5]

In [None]:
modules = list(lines)  # shallow copy so the raw input lines stay untouched

In [None]:
# Pulses are 3-tuples (sender, receiver, pulse), where pulse is 1 for high and 0 for low.

In [None]:
@dataclass
class PulseCounter:
    """Tally of the pulses sent during a single button press (part 1)."""

    high: int = 0  # number of high pulses sent
    low: int = 0  # number of low pulses sent


class Flip:
    """Flip-flop module (prefix ``%``).

    High pulses are ignored. A low pulse toggles the on/off state; the module
    then sends high (1) to every destination if it is now on, low (0) if off.
    """

    def __init__(self, name, dests) -> None:
        self.name = name
        self.dests = dests
        self.state = False  # flip-flops start switched off

    def receive_beam(self, sender, pulse, beam_q: Queue, pc):
        """Handle one pulse from ``sender``.

        Emitted pulses are appended to ``beam_q`` as (sender, receiver, pulse)
        tuples and tallied on ``pc`` (an object with ``high``/``low`` counters).
        """
        if pulse == 1:
            return  # high pulses do nothing

        self.state = not self.state
        outgoing = int(self.state)

        fanout = len(self.dests)
        if fanout:
            if outgoing:
                pc.high += fanout
            else:
                pc.low += fanout

        for dest in self.dests:
            beam_q.put((self.name, dest, outgoing))

    def __repr__(self) -> str:
        return f"[{self.name}->{self.dests} | {self.state}]"


class Conj:
    """Conjunction module (prefix ``&``).

    Remembers the latest pulse from each input (initially low). After every
    pulse it sends low (0) if all remembered inputs are high, otherwise high (1).
    """

    def __init__(self, name, dests) -> None:
        self.name = name
        self.dests = dests
        self.incoming = {}  # input name -> True if that input's last pulse was high

    def add_incoming(self, name):
        """Register ``name`` as an input, initially remembered as low."""
        self.incoming[name] = False

    def receive_beam(self, sender, pulse, beam_q, pc):
        """Record ``sender``'s pulse, then queue and tally the resulting output."""
        self.incoming[sender] = pulse == 1
        outgoing = int(not all(self.incoming.values()))

        for dest in self.dests:
            beam_q.put((self.name, dest, outgoing))

        if self.dests:
            if outgoing:
                pc.high += len(self.dests)
            else:
                pc.low += len(self.dests)

    def __repr__(self) -> str:
        return f"[&{self.name}->{self.dests} | {self.incoming}]"


class Broad:
    """Broadcaster module: forwards every received pulse unchanged to all destinations."""

    def __init__(self, name, dests) -> None:
        self.name = name
        self.dests = dests

    def receive_beam(self, sender, pulse, beam_q, pc):
        """Forward ``pulse`` to every destination and tally the copies on ``pc``."""
        fanout = len(self.dests)
        if fanout:
            if pulse:
                pc.high += fanout
            else:
                pc.low += fanout
        for dest in self.dests:
            beam_q.put((self.name, dest, pulse))

    def __repr__(self) -> str:
        return f"[{self.name}->{self.dests}]"

In [None]:
# Build the module network for part 1.
# Each input line has the form "<prefix><name> -> <dest>, <dest>, ...", where the
# prefix is "%" (flip-flop), "&" (conjunction) or absent (the "broadcaster").
parsed_modules = []  # one (kind, name, dests) tuple per input line
for m in modules:
    head, arrow, tail = m.partition("->")
    if not arrow:
        continue  # blank or malformed line
    head = head.strip()
    dests = [n.strip() for n in tail.split(",")]
    if head.startswith(("&", "%")):
        parsed_modules.append((head[0], head[1:], dests))
    elif head.startswith("broad"):
        parsed_modules.append(("broadcaster", "broadcaster", dests))

conj_module_dict = {}
module_dict = {}
# Conjunctions are created first, because every module that feeds one must be
# registered as one of its inputs before any pulse is sent.
for kind, name, dests in parsed_modules:
    if kind == "&":
        conj_mod = Conj(name, dests)
        conj_module_dict[name] = conj_mod
        module_dict[name] = conj_mod

for kind, name, dests in parsed_modules:
    if kind == "%":
        module_dict[name] = Flip(name, dests)
    elif kind == "broadcaster":
        module_dict[name] = Broad(name, dests)
    # Every sender (the broadcaster included) is registered as an input of each
    # conjunction it feeds, so the conjunction starts out remembering "low" for it.
    for d in dests:
        if d in conj_module_dict:
            conj_module_dict[d].add_incoming(name)

In [None]:
# Shared FIFO of pending pulses, each a (sender, receiver, pulse) tuple.
# Pulses are processed in the order they were sent.
pulse_q = Queue()

In [None]:
def push_button():
    """Press the button once and propagate pulses until the queue is empty.

    Returns a PulseCounter with the number of high and low pulses sent during
    this press, including the button's initial low pulse to the broadcaster.
    """
    counter = PulseCounter()
    # The button itself sends one low pulse to the broadcaster.
    pulse_q.put(("button", "broadcaster", 0))
    counter.low += 1
    while not pulse_q.empty():
        sender, dest, p = pulse_q.get()
        receiver = module_dict.get(dest)
        # Pulses to names without a module definition were already counted
        # by their sender; they simply go nowhere.
        if receiver is not None:
            receiver.receive_beam(sender, p, pulse_q, counter)
    return counter

In [None]:
# module_dict is mutated by every press, so put every module back into its
# initial state first; otherwise re-running this cell would continue from where
# the previous run stopped and produce a different (wrong) answer.
while not pulse_q.empty():
    pulse_q.get()  # drop leftovers from an interrupted run
for module in module_dict.values():
    if hasattr(module, "state"):  # flip-flops start switched off
        module.state = False
    if hasattr(module, "incoming"):  # conjunctions start remembering low for every input
        module.incoming = dict.fromkeys(module.incoming, False)

BUTTON_PRESSES = 1000
high_count = 0
low_count = 0
for _ in range(BUTTON_PRESSES):
    pc = push_button()
    high_count += pc.high
    low_count += pc.low

high_count * low_count

## Part 2

In [None]:
# First naive idea: press the button many times and count the pulses sent to rx

In [None]:
@dataclass
class PulseCounterRX:
    """Tally of the pulses delivered to ``rx`` during a single button press (part 2)."""

    high: int = 0  # high pulses sent to rx
    low: int = 0  # low pulses sent to rx


class Flip:
    """Flip-flop module (part 2 variant): only pulses sent to ``rx`` are tallied.

    High pulses are ignored; a low pulse toggles the state and sends high (1)
    if the module is now on, low (0) if it is now off.
    """

    def __init__(self, name, dests) -> None:
        self.name = name
        self.dests = dests
        self.state = False  # flip-flops start switched off

    def receive_beam(self, sender, pulse, beam_q: Queue, pc):
        """Handle one pulse; queue the outputs and tally the ones addressed to rx."""
        if pulse != 1:
            self.state = not self.state
            outgoing = int(self.state)
            for dest in self.dests:
                beam_q.put((self.name, dest, outgoing))
            rx_hits = self.dests.count("rx")
            if rx_hits:
                if outgoing:
                    pc.high += rx_hits
                else:
                    pc.low += rx_hits

    def __repr__(self) -> str:
        return f"[{self.name}->{self.dests} | {self.state}]"


class Conj:
    """Conjunction module (part 2 variant): only pulses sent to ``rx`` are tallied.

    Sends low (0) when every remembered input is high, otherwise high (1).
    """

    def __init__(self, name, dests) -> None:
        self.name = name
        self.dests = dests
        self.incoming = {}  # input name -> True if that input's last pulse was high

    def add_incoming(self, name):
        """Register ``name`` as an input, initially remembered as low."""
        self.incoming[name] = False

    def receive_beam(self, sender, pulse, beam_q, pc):
        """Record ``sender``'s pulse, queue the output and tally copies sent to rx."""
        self.incoming[sender] = pulse == 1
        outgoing = int(not all(self.incoming.values()))

        for dest in self.dests:
            beam_q.put((self.name, dest, outgoing))

        rx_hits = self.dests.count("rx")
        if rx_hits:
            if outgoing:
                pc.high += rx_hits
            else:
                pc.low += rx_hits

    def __repr__(self) -> str:
        return f"[&{self.name}->{self.dests} | {self.incoming}]"


class Broad:
    """Broadcaster (part 2 variant): forwards pulses unchanged, tallying only those sent to rx."""

    def __init__(self, name, dests) -> None:
        self.name = name
        self.dests = dests

    def receive_beam(self, sender, pulse, beam_q, pc):
        """Forward ``pulse`` to every destination; tally the copies addressed to rx."""
        for dest in self.dests:
            beam_q.put((self.name, dest, pulse))
        rx_hits = self.dests.count("rx")
        if rx_hits:
            if pulse:
                pc.high += rx_hits
            else:
                pc.low += rx_hits

    def __repr__(self) -> str:
        return f"[{self.name}->{self.dests}]"


def init_module_dict(module_lines=None):
    """Build a fresh module network from the puzzle input.

    Parameters
    ----------
    module_lines : list[str] | None
        Stripped input lines of the form ``"<prefix><name> -> <dest>, ..."``,
        where the prefix is ``%`` (flip-flop), ``&`` (conjunction) or absent
        (the ``broadcaster``). Defaults to the global ``modules`` list.
        Blank or malformed lines (no ``->``) are skipped.

    Returns
    -------
    dict
        Maps each module name to its ``Flip``/``Conj``/``Broad`` instance, all
        freshly constructed (flip-flops off, conjunction inputs remembered as low).
    """
    if module_lines is None:
        module_lines = modules

    # Parse every line once into (kind, name, dests).
    parsed = []
    for line in module_lines:
        head, arrow, tail = line.partition("->")
        if not arrow:
            continue
        head = head.strip()
        dests = [n.strip() for n in tail.split(",")]
        if head.startswith(("&", "%")):
            parsed.append((head[0], head[1:], dests))
        elif head.startswith("broad"):
            parsed.append(("broadcaster", "broadcaster", dests))

    conj_module_dict = {}
    module_dict = {}
    # Conjunctions first, because every sender must be registered as one of
    # their inputs before any pulse is sent.
    for kind, name, dests in parsed:
        if kind == "&":
            conj_mod = Conj(name, dests)
            conj_module_dict[name] = conj_mod
            module_dict[name] = conj_mod

    for kind, name, dests in parsed:
        if kind == "%":
            module_dict[name] = Flip(name, dests)
        elif kind == "broadcaster":
            module_dict[name] = Broad(name, dests)
        # Every sender (the broadcaster included) becomes an input of each
        # conjunction it feeds.
        for d in dests:
            if d in conj_module_dict:
                conj_module_dict[d].add_incoming(name)
    return module_dict


module_dict = init_module_dict()

In [None]:
# Start from a fresh network so this cell can be re-run.
module_dict = init_module_dict()

NAIVE_PRESSES = 100  # only a demonstration; see the note on the loop below


def push_button():
    """Press the button once and propagate pulses until the queue is empty.

    Returns a PulseCounterRX whose ``high``/``low`` fields count only the
    pulses sent to ``rx`` during this press.
    """
    pulse_q.put(("button", "broadcaster", 0))
    pulse_counter = PulseCounterRX()
    while not pulse_q.empty():
        sender, dest, p = pulse_q.get()
        if dest in module_dict:
            module_dict[dest].receive_beam(sender, p, pulse_q, pulse_counter)
    return pulse_counter


# This doesn't work, takes way too long
for i in range(NAIVE_PRESSES):
    pc = push_button()
    if pc.low == 1:
        # i is 0-based; the number of presses so far is i + 1.
        print(i + 1, pc)

In [None]:
# Draw the module graph to inspect its structure.
LAYOUT_SEED = None  # set to an int for a reproducible layout; None gives a new layout each run

G = nx.DiGraph()
G.add_edges_from((mod.name, d) for mod in module_dict.values() for d in mod.dests)
pos = nx.spring_layout(G, seed=LAYOUT_SEED)

fig, ax = plt.subplots(figsize=(25, 12))
nx.draw(G, pos, ax=ax, node_size=250)
nx.draw_networkx_labels(G, pos, ax=ax)
ax.set_title("Module network (edges point from sender to receiver)")
plt.show()

In [None]:
# The spring layout changes on every run and is sometimes messy,
# but re-running the cell eventually gives a readable picture.
# It's easy to see that there are four groups of modules that influence rx,
# one for each output of the broadcaster.

# rx is fed by a single conjunction module, so rx receives a low pulse only when
# that conjunction remembers a high pulse from every one of its inputs.

# Each of the four sub-circuits probably repeats with a fixed cycle length.

In [None]:
broadcaster_module = module_dict["broadcaster"]
# Each broadcaster output is the entry point of one sub-circuit to analyse.
circuits_to_check = broadcaster_module.dests
circuits_to_check

In [None]:
module_dict = init_module_dict()
# The single module that sends to rx (a conjunction, per the graph above).
node_before_target = next((m for m in module_dict.values() if "rx" in m.dests), None)
assert node_before_target is not None, "no module sends pulses to rx"

PRESSES_PER_CIRCUIT = 15000


# Connect the button directly to one circuit's entry module.
# The broadcaster does the same for all circuits at once, which would also
# reveal the cycle lengths, but without telling which circuit produced them.
def push_button(target):
    """Send one low pulse straight to ``target`` and propagate until the queue is empty.

    Returns True if, right after any pulse delivered to ``node_before_target``
    during this press, that conjunction remembered a high pulse from at least
    one of its inputs.
    """
    pulse_q.put(("button", target, 0))
    pulse_counter = PulseCounterRX()
    node_before_target_received_1 = False
    while not pulse_q.empty():
        sender, dest, p = pulse_q.get()
        if dest in module_dict:
            module_dict[dest].receive_beam(sender, p, pulse_q, pulse_counter)
            # Check whether the conjunction before rx currently remembers any high input.
            if dest == node_before_target.name and any(
                module_dict[node_before_target.name].incoming.values()
            ):
                node_before_target_received_1 = True
    return node_before_target_received_1


# Press the button PRESSES_PER_CIRCUIT times for each circuit.
# Each circuit reports a few hits, spaced exactly one cycle length apart;
# the LCM of the first hits tells when all circuits fire on the same press.
first_hits = {}
for circuit in circuits_to_check:
    # Rebuild the network so leftover state from the previous circuit
    # (e.g. a remembered high input on node_before_target) cannot cause false hits.
    module_dict = init_module_dict()
    for i in range(PRESSES_PER_CIRCUIT):
        if push_button(circuit):
            print(f"{circuit}: {i+1}")
            first_hits.setdefault(circuit, i + 1)
first_hits

In [None]:
from math import gcd

# Fill in the first hit (press count) per circuit, as printed by the cell above.
cycle_lengths = []
lcm = 1
for length in cycle_lengths:
    # Divide first: gcd divides lcm exactly, and the intermediate stays smaller.
    lcm = lcm // gcd(lcm, length) * length
print(lcm)