In [1]:
import sys
import os

# 프로젝트 루트 디렉토리 경로를 추가
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [2]:
import jax
import jax.numpy as jnp
import chex
import time
#disable jax JIT
#jax.config.update("jax_disable_jit", True)

from tqdm.autonotebook import trange, tqdm
from functools import partial
from JAxtar.bgpq import HashTableHeapValue, BGPQ

  from tqdm.autonotebook import trange, tqdm


In [3]:
max_size = int(1e5)
group_size = 1000
filled_sample = 1000

In [4]:
#cpu heap test
import heapq as hq
import numpy as np
heap = []
for i in trange(max_size):
    hq.heappush(heap, np.random.rand())

for i in trange(max_size):
    out = hq.heappop(heap)

100%|██████████| 100000/100000 [00:00<00:00, 2771334.56it/s]
100%|██████████| 100000/100000 [00:00<00:00, 2420144.48it/s]


In [5]:
heap = BGPQ.make_heap(int(1e6), group_size, HashTableHeapValue)
insert = jax.jit(heap.insert)
delete_mins = jax.jit(heap.delete_mins)

In [6]:
heap_key = jax.random.uniform(jax.random.PRNGKey(0), shape=(filled_sample,), minval=0, maxval=10)
value = jax.vmap(HashTableHeapValue.default)(jnp.arange(filled_sample))
make_batch = BGPQ.make_batched(heap_key, value, group_size)

In [7]:
pbar = tqdm(total=int(max_size))
for i in range(0, max_size, filled_sample):
    heap_key = jax.random.uniform(jax.random.PRNGKey(i), shape=(filled_sample,), minval=0, maxval=10)
    value = jax.vmap(HashTableHeapValue.default)(jnp.arange(filled_sample))
    heap = insert(heap, heap_key, value)
    pbar.update(filled_sample)
pbar.close()

100%|██████████| 100000/100000 [00:01<00:00, 77169.92it/s]


In [8]:
pbar = tqdm(total=int(heap.size))
last_min = jnp.inf
last_max = -jnp.inf
while heap.size > 0:
    heap, min_key, min_val = delete_mins(heap)
    minimum = jnp.min(min_key)
    maximum = jnp.max(min_key)
    if minimum < last_max:
        print("Error last_max", last_max, "minimum", minimum)
        print("Not sorted")
        break
    last_min = minimum
    last_max = maximum
    pbar.update(group_size)
pbar.close()

100%|██████████| 100000/100000 [00:06<00:00, 15817.21it/s]


In [9]:
print(heap.key_buffer)

[inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf
 inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf
 inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf
 inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf
 inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf
 inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf
 inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf
 inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf
 inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf
 inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf
 inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf
 inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf
 inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf
 inf inf inf inf inf inf inf inf inf inf inf inf in