# 優先度付きキュー


In [1]:
from logging import getLogger, StreamHandler, DEBUG

# Module-level debug logger with its own stream handler (StreamHandler
# writes to sys.stderr when no stream is given).
logger = getLogger(__name__)
logger.setLevel(DEBUG)
# Keep records away from the root logger so they are not emitted twice.
logger.propagate = False

handler = StreamHandler()
handler.setLevel(DEBUG)

# getLogger() returns the same logger object every time this cell/module is
# executed, so drop handlers attached by a previous run before adding the new
# one; otherwise every message would be printed once per run.
for stale_handler in logger.handlers[:]:
    logger.removeHandler(stale_handler)
logger.addHandler(handler)


In [25]:
import sys


class PrioritizedQueue:
    def __init__(self, list: list[int] = []):
        self._array = [sys.maxsize] + list

    @property
    def array(self) -> list:
        return self._array[1:]

    def max_heapify(self, root: int):
        logger.debug(f"{self._array=}")
        left = root * 2
        right = left + 1
        max_i = root

        if len(self._array) - 1 < left:
            return

        if self._array[root] < self._array[left]:
            max_i = left

        if right < len(self._array) and self._array[max_i] < self._array[right]:
            max_i = right

        if max_i != root:
            self._array[root], self._array[max_i] = (
                self._array[max_i],
                self._array[root],
            )
            self.max_heapify(max_i)

    def insert(self, num: int):
        logger.debug(f"{self._array=}")
        # 螺旋本では`increaseKey`メソッドに分けて実装している箇所
        self._array.append(num)
        logger.debug(f"{self._array=}")
        current = len(self._array) - 1
        parent = current // 2
        while self._array[parent] < self._array[current]:
            self._array[parent], self._array[current] = (
                self._array[current],
                self._array[parent],
            )
            logger.debug(f"{self._array=}")
            current = parent
            parent = current // 2

    def extract(self) -> int:
        last = self._array.pop()
        if len(self._array) == 1:
            return last
        extracted, self._array[1] = self._array[1], last
        self.max_heapify(1)
        return extracted


In [26]:
# Sifting the root down must bring the largest value to the top.
pq = PrioritizedQueue([2, 3, 1])
pq.max_heapify(1)
expected, actual = [3, 2, 1], pq.array
assert expected == actual


self._array=[9223372036854775807, 2, 3, 1]
self._array=[9223372036854775807, 3, 2, 1]


In [27]:
# Inserting a value larger than the root must sift it up to the top.
pq = PrioritizedQueue([5, 4])
pq.insert(6)
expected, actual = [6, 4, 5], pq.array
assert expected == actual


self._array=[9223372036854775807, 5, 4]
self._array=[9223372036854775807, 5, 4, 6]
self._array=[9223372036854775807, 6, 4, 5]


In [28]:
# Extracting must return the root and leave the remaining values in heap order.
pq = PrioritizedQueue([10, 6, 8, 4])
expected_extracted, expected_array = 10, [8, 6, 4]
actual_extracted = pq.extract()
actual_array = pq.array
assert expected_extracted == actual_extracted
assert expected_array == actual_array


self._array=[9223372036854775807, 4, 6, 8]
self._array=[9223372036854775807, 8, 6, 4]


In [29]:
def main(commands: list[tuple[str, int | None]]) -> list[int]:
    """Run a sequence of queue commands and collect the extracted values.

    Args:
        commands: ``(command, argument)`` pairs. ``"insert"`` pushes the
            argument, ``"extract"`` pops the top value, and any other command
            (e.g. ``"end"``) stops processing.

    Returns:
        The extracted values in extraction order. The list is returned even
        when the commands run out without a terminating command.
    """
    logger.debug("commands=%r", commands)
    pq = PrioritizedQueue()
    extracted = []
    for cmd, num in commands:
        if cmd == "insert":
            pq.insert(num)
        elif cmd == "extract":
            extracted.append(pq.extract())
        else:
            # "end" or an unrecognized command: stop processing.
            break
    return extracted


In [30]:
def line_to_set(line: str) -> set[str, int | None]:
    splitted = line.split(" ")
    if len(splitted) == 2:
        return (splitted[0], int(splitted[1]))
    else:
        return (splitted[0], None)


def parse(input: str) -> list[tuple[str, int | None]]:
    """Parse multi-line command text into ``(command, argument)`` tuples.

    Args:
        input: Command text with one command per line.

    Returns:
        One tuple per line, in input order, as produced by ``line_to_set``.
    """
    return [line_to_set(line) for line in input.splitlines()]


In [31]:
# parse() must turn every line into a (command, argument) tuple.
input = "insert 10\nend"
expected = [("insert", 10), ("end", None)]
actual = parse(input)
assert expected == actual


In [32]:
# End-to-end check: parse the command text, run it, and compare the values
# returned by the "extract" commands.
input = "\n".join(
    [
        "insert 8",
        "insert 2",
        "extract",
        "insert 10",
        "extract",
        "insert 11",
        "extract",
        "extract",
        "end",
    ]
)
expected = [8, 10, 11, 2]
actual = main(parse(input))
assert expected == actual


commands=[('insert', 8), ('insert', 2), ('extract', None), ('insert', 10), ('extract', None), ('insert', 11), ('extract', None), ('extract', None), ('end', None)]
self._array=[9223372036854775807]
self._array=[9223372036854775807, 8]
self._array=[9223372036854775807, 8]
self._array=[9223372036854775807, 8, 2]
self._array=[9223372036854775807, 2]
self._array=[9223372036854775807, 2]
self._array=[9223372036854775807, 2, 10]
self._array=[9223372036854775807, 10, 2]
self._array=[9223372036854775807, 2]
self._array=[9223372036854775807, 2]
self._array=[9223372036854775807, 2, 11]
self._array=[9223372036854775807, 11, 2]
self._array=[9223372036854775807, 2]


## 標準ライブラリによる優先度付きキュー


In [37]:
import heapq

# heapq works on a plain list and treats it as a min-heap.
pq = [5, 9, 6, 7, 4, 1, 0]
# WARNING: heapq.heapify() mutates the list in place (it does not return a
# new heap).
heapq.heapify(pq)
# heappop() removes and returns the smallest element.
popped = heapq.heappop(pq)
assert popped == 0


In [39]:
# heapq only provides a min-heap: pushing 99 and popping would return the
# smallest value (10), not 99. Storing negated values makes the smallest
# stored item correspond to the largest original value, which gives max-heap
# behavior using only the public API.
#
# NOTE: `heapq._heapify_max` is deliberately not used. A later
# `heapq.heappush(pq, int)` breaks the max-heap ordering, and it is not part
# of the public API in the first place.
# https://discuss.python.org/t/make-max-heap-functions-public-in-heapq/16944/12
pq = [-value for value in [10, 20, 30, 40, 50]]
heapq.heapify(pq)

heapq.heappush(pq, -99)
# Negate again on the way out to recover the original value.
popped = -heapq.heappop(pq)
assert 99 == popped