In [320]:
from abc import ABC, abstractmethod
from collections import deque
from enum import Enum
from typing import List, Optional, Tuple


class Pulse(Enum):
    """The two signal levels a module can send or receive."""
    LOW = 0
    HIGH = 1


class Module(ABC):
    """Abstract base for a node of the pulse network.

    A module knows its own name and the names of the modules it sends pulses
    to. Subclasses decide which pulses, if any, are emitted when one arrives.
    """

    def __init__(self, name: str, destination_modules: Optional[List[str]] = None):
        """Store the module name and its destinations.

        Args:
            name: Identifier of this module.
            destination_modules: Names of the modules that receive this
                module's output. Omitting it yields a module with no
                destinations (previously the default was, by mistake, the
                ``List[str]`` typing object itself).
        """
        self.name = name
        self.destination_modules = destination_modules if destination_modules is not None else []

    @abstractmethod
    def get_hash(self) -> str:
        """Return a string describing the module's current internal state."""

    @abstractmethod
    def handle_pulse(self, input_module: str, pulse: Pulse) -> List[Tuple[str, Pulse, str]]:
        """React to ``pulse`` arriving from ``input_module``.

        Returns:
            The pulses to send next, each as a
            ``(source_name, pulse, destination_name)`` tuple.
        """

    def reset(self):
        """Restore the initial state; the base implementation does nothing."""


In [321]:
class FlipFlopModule(Module):
    """Module with an on/off state that toggles on every LOW pulse.

    HIGH pulses are ignored. After a toggle the module sends HIGH if it is
    now on and LOW if it is now off.
    """

    def __init__(self, name: str, destination_modules: List[str]):
        super().__init__(name, destination_modules)
        self.active = False  # starts switched off

    def get_hash(self) -> str:
        """Return ``flip[True]`` or ``flip[False]`` depending on the state."""
        return f"flip[{self.active}]"

    def handle_pulse(self, input_module: str, pulse: Pulse) -> List[Tuple[str, Pulse, str]]:
        """Toggle on LOW and emit the new state to every destination.

        Args:
            input_module: Name of the sender (unused by this module type).
            pulse: The incoming pulse.

        Returns:
            An empty list for a HIGH pulse, otherwise one
            ``(self.name, pulse, destination)`` tuple per destination.
        """
        if pulse == Pulse.HIGH:
            return []

        # Flip first, then report the state the module has just entered.
        self.active = not self.active
        result_pulse = Pulse.HIGH if self.active else Pulse.LOW
        return [
            (self.name, result_pulse, destination_module)
            for destination_module in self.destination_modules
        ]

    def reset(self):
        """Switch the module back off."""
        self.active = False


# Smoke test: HIGH pulses produce no output, while each LOW pulse toggles the
# state and is answered with HIGH (now on) or LOW (now off).
test_flip_flop_module = FlipFlopModule("test_flip", ['A', 'B'])
print(test_flip_flop_module.get_hash())
print(test_flip_flop_module.handle_pulse('A', Pulse.HIGH))
print(test_flip_flop_module.handle_pulse('A', Pulse.LOW))
print(test_flip_flop_module.handle_pulse('A', Pulse.LOW))
print(test_flip_flop_module.handle_pulse('A', Pulse.HIGH))

flip[False]
[]
[('test_flip', <Pulse.HIGH: 1>, 'A'), ('test_flip', <Pulse.HIGH: 1>, 'B')]
[('test_flip', <Pulse.LOW: 0>, 'A'), ('test_flip', <Pulse.LOW: 0>, 'B')]
[]


In [322]:
class ConjunctionModule(Module):
    """Module that remembers the last pulse received from each input.

    After recording an incoming pulse it sends LOW when every remembered
    pulse is HIGH, and HIGH otherwise.
    """

    def __init__(self, name: str, destination_modules: List[str]):
        super().__init__(name, destination_modules)
        # input module name -> most recent pulse received from it
        self.last_pulses = dict()

    def get_hash(self) -> str:
        """Return ``conjunction[<name>:<0|1>;...]`` for the remembered inputs."""
        state = ";".join(
            f"{module}:{last_pulse.value}"
            for module, last_pulse in self.last_pulses.items()
        )
        return f"conjunction[{state}]"

    def add_link(self, input_module: str):
        """Register ``input_module`` as an input, initially remembered as LOW."""
        self.last_pulses[input_module] = Pulse.LOW

    def handle_pulse(self, input_module: str, pulse: Pulse) -> List[Tuple[str, Pulse, str]]:
        """Record ``pulse`` for ``input_module`` and emit the combined result.

        Inputs that were never registered through ``add_link`` are added to
        the memory the first time they send a pulse.

        Returns:
            One ``(self.name, pulse, destination)`` tuple per destination.
        """
        self.last_pulses[input_module] = pulse
        all_high = all(last_pulse == Pulse.HIGH for last_pulse in self.last_pulses.values())
        response_pulse = Pulse.LOW if all_high else Pulse.HIGH
        return [
            (self.name, response_pulse, destination_module)
            for destination_module in self.destination_modules
        ]

    def reset(self):
        """Set the remembered pulse of every known input back to LOW."""
        for key in self.last_pulses:
            self.last_pulses[key] = Pulse.LOW


# Smoke test: no inputs are registered with add_link here, so the module only
# remembers inputs once they have sent a pulse.
test_conjunction_module = ConjunctionModule("conjuction_test", ['A', 'B'])
print(test_conjunction_module.handle_pulse('A', Pulse.HIGH))
print(test_conjunction_module.handle_pulse('B', Pulse.LOW))
print(test_conjunction_module.handle_pulse('B', Pulse.HIGH))
print(test_conjunction_module.get_hash())

[('conjuction_test', <Pulse.LOW: 0>, 'A'), ('conjuction_test', <Pulse.LOW: 0>, 'B')]
[('conjuction_test', <Pulse.HIGH: 1>, 'A'), ('conjuction_test', <Pulse.HIGH: 1>, 'B')]
[('conjuction_test', <Pulse.LOW: 0>, 'A'), ('conjuction_test', <Pulse.LOW: 0>, 'B')]
conjunction[A:1;B:1]


In [323]:
class BroadcastModule(Module):
    """Stateless module that forwards each incoming pulse to all destinations."""

    def get_hash(self) -> str:
        """Return the constant ``"broadcast"``; there is no state to describe."""
        return "broadcast"

    def handle_pulse(self, input_module: str, pulse: Pulse) -> List[Tuple[str, Pulse, str]]:
        """Relay ``pulse`` unchanged.

        Args:
            input_module: Name of the sender (unused by this module type).
            pulse: The pulse to forward.

        Returns:
            One ``(self.name, pulse, destination)`` tuple per destination.
        """
        return [
            (self.name, pulse, destination_module)
            for destination_module in self.destination_modules
        ]


# Smoke test: the broadcaster relays whatever pulse it receives to every destination.
test_broadcast_module = BroadcastModule("test_broadcast", ['A', 'B'])
print(test_broadcast_module.get_hash())
print(test_broadcast_module.handle_pulse('A', Pulse.HIGH))
print(test_broadcast_module.handle_pulse('A', Pulse.LOW))

broadcast
[('test_broadcast', <Pulse.HIGH: 1>, 'A'), ('test_broadcast', <Pulse.HIGH: 1>, 'B')]
[('test_broadcast', <Pulse.LOW: 0>, 'A'), ('test_broadcast', <Pulse.LOW: 0>, 'B')]


In [324]:
# Sample network description: one "<name> -> <dest>, <dest>, ..." entry per line,
# where a "%" prefix marks a flip-flop and "&" marks a conjunction.
test_data_raw = """broadcaster -> a, b, c
%a -> b
%b -> c
%c -> inv
&inv -> a
"""


def parse_data(data_raw):
    """Build the module network described by ``data_raw``.

    Each non-blank line has the form ``<name> -> <dest>, <dest>, ...``. A
    ``%`` prefix creates a ``FlipFlopModule``, a ``&`` prefix creates a
    ``ConjunctionModule``, and the exact name ``broadcaster`` creates a
    ``BroadcastModule``. Lines with any other name are ignored, so
    destinations without a definition line do not appear in the result.

    Args:
        data_raw: The textual network description.

    Returns:
        A dict mapping module name (without prefix) to the module instance,
        in input order. Every conjunction module has had ``add_link`` called
        once for each module that lists it as a destination.
    """
    modules = dict()

    for line in data_raw.splitlines():
        if not line.strip():
            # Blank lines (e.g. between pasted sections) would otherwise
            # break the two-way unpacking below.
            continue
        base_module_name, destination_modules_raw = line.split(" -> ")
        destination_modules = destination_modules_raw.split(", ")
        module_name = base_module_name.lstrip("%&")
        module = None
        if base_module_name.startswith("%"):
            module = FlipFlopModule(module_name, destination_modules)
        elif base_module_name.startswith("&"):
            module = ConjunctionModule(module_name, destination_modules)
        elif base_module_name == "broadcaster":
            module = BroadcastModule(module_name, destination_modules)
        if module:
            modules[module_name] = module

    # Conjunctions must know all of their inputs up front, which is only
    # possible once every module has been created.
    for module_name, module in modules.items():
        for destination_module in module.destination_modules:
            target = modules.get(destination_module)
            if isinstance(target, ConjunctionModule):
                target.add_link(module_name)

    return modules


# Parse the sample network and check that "inv" was linked to its single input "c".
test_data = parse_data(test_data_raw)
print(test_data["inv"].last_pulses)
test_data

{'c': <Pulse.LOW: 0>}


{'broadcaster': <__main__.BroadcastModule at 0x7fb4754c1a50>,
 'a': <__main__.FlipFlopModule at 0x7fb4754c0fd0>,
 'b': <__main__.FlipFlopModule at 0x7fb4754c2710>,
 'c': <__main__.FlipFlopModule at 0x7fb4754c1f90>,
 'inv': <__main__.ConjunctionModule at 0x7fb4754c2390>}

In [325]:
def reset_modules(modules):
    """Call ``reset()`` on every module stored in the ``modules`` dict."""
    for module_name in modules:
        modules[module_name].reset()

In [336]:
class RXLowException(Exception):
    """Signals that a LOW pulse addressed to the "rx" module was encountered."""


def count_pulses(modules, pulses, except_rx_low=False):
    """Propagate pulses through the network and count them by level.

    Pulses are processed strictly first-in, first-out: the pulses a module
    emits are appended behind everything that is already waiting. A pulse
    whose destination is not a key of ``modules`` is counted but triggers
    nothing further.

    Args:
        modules: Dict mapping module name to a module object whose
            ``handle_pulse(source_name, pulse)`` returns a list of
            ``(source_name, pulse, destination_name)`` tuples.
        pulses: The initial ``(source_name, pulse, destination_name)`` tuples.
        except_rx_low: When truthy, raise instead of counting as soon as a
            LOW pulse addressed to an undefined ``"rx"`` destination is
            reached.

    Returns:
        ``(high_count, low_count)`` over all processed pulses; ``(0, 0)``
        when ``pulses`` is empty.

    Raises:
        RXLowException: See ``except_rx_low``. Modules handled before that
            point keep the state changes they already made.
    """
    result = [0, 0]  # index 0 counts HIGH pulses, index 1 counts LOW pulses
    pending = deque(pulses)

    # An explicit queue replaces one recursive call (plus a list copy) per
    # pulse, so long cascades can neither hit the recursion limit nor pay
    # quadratic copying costs.
    while pending:
        src_module_name, pulse, dest_module_name = pending.popleft()

        if dest_module_name not in modules:
            if except_rx_low and pulse == Pulse.LOW and dest_module_name == "rx":
                raise RXLowException()
            result[int(pulse == Pulse.LOW)] += 1
            continue

        result[int(pulse == Pulse.LOW)] += 1
        pending.extend(modules[dest_module_name].handle_pulse(src_module_name, pulse))

    return tuple(result)



# One button press on the sample network; the printed tuple is (high_count, low_count).
reset_modules(test_data)
print(count_pulses(test_data, [("button", Pulse.LOW, "broadcaster")]))

(4, 8)


In [337]:
def part1(data, push_count=1000):
    """Press the button ``push_count`` times and multiply the pulse totals.

    Args:
        data: Dict of modules as returned by ``parse_data``; every module is
            reset before the first press and keeps its state between presses.
        push_count: Number of button presses to simulate.

    Returns:
        The total number of HIGH pulses multiplied by the total number of
        LOW pulses across all presses.
    """
    reset_modules(data)
    high_pulse_count = 0
    low_pulse_count = 0

    for _ in range(push_count):
        # Each press sends a single LOW pulse from the button to the broadcaster.
        # No third argument: the old unused `cache` dict landed in the
        # `except_rx_low` slot and only worked because an empty dict is falsy.
        pulses = [("button", Pulse.LOW, "broadcaster")]
        current_high_pulse_count, current_low_pulse_count = count_pulses(data, pulses)
        high_pulse_count += current_high_pulse_count
        low_pulse_count += current_low_pulse_count

    return high_pulse_count * low_pulse_count


# Part 1 on the first sample network with the default 1000 presses.
part1(test_data)

32000000

In [338]:
# Second sample network. "output" has no definition line, so parse_data
# creates no module for it.
test_data2_raw = """broadcaster -> a
%a -> inv, con
&inv -> b
%b -> con
&con -> output"""

test_data2 = parse_data(test_data2_raw)

test_data2

{'broadcaster': <__main__.BroadcastModule at 0x7fb45ca00050>,
 'a': <__main__.FlipFlopModule at 0x7fb45ca00ed0>,
 'inv': <__main__.ConjunctionModule at 0x7fb45ca03e50>,
 'b': <__main__.FlipFlopModule at 0x7fb45ca01250>,
 'con': <__main__.ConjunctionModule at 0x7fb45ca01350>}

In [339]:
# Part 1 on the second sample network.
part1(test_data2)

11687500

In [243]:
# Load and parse the puzzle input; the path is relative to the working directory.
with open('input.txt') as f:
    data = parse_data(f.read())

data["broadcaster"]

<__main__.BroadcastModule at 0x7fb4750d29d0>

In [340]:
# Part 1 answer for the puzzle input.
part1(data)

938065580

In [341]:
# Variant of the second sample whose final destination is "rx", used to try out part2.
test_data3_raw = """broadcaster -> a
%a -> inv, con
&inv -> b
%b -> con
&con -> rx"""

test_data3 = parse_data(test_data3_raw)
test_data3

{'broadcaster': <__main__.BroadcastModule at 0x7fb45c9fad10>,
 'a': <__main__.FlipFlopModule at 0x7fb45c9f9490>,
 'inv': <__main__.ConjunctionModule at 0x7fb45c9f8e10>,
 'b': <__main__.FlipFlopModule at 0x7fb45c9fa190>,
 'con': <__main__.ConjunctionModule at 0x7fb45c9fbd90>}

In [343]:
def part2(data, max_presses=None):
    """Count button presses until a LOW pulse is sent to ``"rx"``.

    This is a brute-force simulation: it presses the button repeatedly and
    stops at the first press during which ``count_pulses`` raises
    ``RXLowException``.

    Args:
        data: Dict of modules as returned by ``parse_data``; every module is
            reset before the first press.
        max_presses: Optional upper bound on the number of presses. ``None``
            (the default) searches without limit.

    Returns:
        The 1-based number of the press that delivered a LOW pulse to "rx".

    Raises:
        ValueError: If "rx" is itself a module in ``data`` or no module lists
            "rx" as a destination. In both cases the exception this search
            waits for can never be raised, so the loop would run forever.
        RuntimeError: If ``max_presses`` presses pass without a LOW pulse
            reaching "rx".
    """
    rx_is_undefined = "rx" not in data
    rx_is_fed = any("rx" in module.destination_modules for module in data.values())
    if not (rx_is_undefined and rx_is_fed):
        raise ValueError('no module sends pulses to an undefined "rx" destination')

    reset_modules(data)
    presses = 0
    while max_presses is None or presses < max_presses:
        presses += 1
        pulses = [("button", Pulse.LOW, "broadcaster")]
        try:
            count_pulses(data, pulses, except_rx_low=True)
        except RXLowException:
            return presses

    raise RuntimeError(f'"rx" received no LOW pulse within {max_presses} button presses')


# Part 2 on the "rx" sample network.
part2(test_data3)

1