In [1]:
import sys
import os

# 프로젝트 루트 디렉토리 경로를 추가
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [2]:
import jax
import jax.numpy as jnp
import time
#disable jax JIT
#jax.config.update("jax_disable_jit", True)

from tqdm.autonotebook import trange
from functools import partial
from JAxtar.hash import hash_func_builder
from puzzle.slidepuzzle import SlidePuzzle
from heuristic.slidepuzzle_heuristic import SlidePuzzleHeuristic
from JAxtar.hash import HashTable

  from tqdm.autonotebook import trange
2024-08-05 09:01:18.318907: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.5.82). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [3]:
count = 1000
puzzle = SlidePuzzle(4)
hash_func = hash_func_builder(puzzle.State)
sample = jax.vmap(puzzle.get_initial_state)(key=jax.random.split(jax.random.PRNGKey(2),count))
new_sample = jax.vmap(puzzle.get_initial_state)(key=jax.random.split(jax.random.PRNGKey(1),count))
table = HashTable.build(puzzle.State, 1, int(1e5))

lookup = jax.jit(partial(HashTable.lookup, hash_func))
start = time.time()
idx, table_idx, found = jax.vmap(lookup, in_axes=(None, 0))(table, sample)
print(time.time()-start)

0.4201200008392334


In [4]:
batch = 4000
parallel_insert = jax.jit(partial(HashTable.parallel_insert, hash_func))
for i in range(18):
    sample = jax.vmap(puzzle.get_initial_state)(key=jax.random.split(jax.random.PRNGKey(i + 256),count))
    idx, table_idx, old_found = jax.vmap(lookup, in_axes=(None, 0))(table, sample)
    batched_sample, filled = HashTable.make_batched(puzzle.State, sample, batch)
    start = time.time()
    table, inserted, _, _ = parallel_insert(table, batched_sample, filled)
    print(time.time()-start)
    idx, table_idx, found = jax.vmap(lookup, in_axes=(None, 0))(table, sample)
    print(jnp.mean(found), jnp.mean(old_found), jnp.mean(inserted))

1.518310308456421
1.0 0.0 0.25
1.4585397243499756
0.9990001 0.0 0.25
0.0007376670837402344
1.0 0.0 0.25
0.0006704330444335938
1.0 0.0 0.25
0.0005617141723632812
1.0 0.0 0.25
0.0006668567657470703
1.0 0.0 0.25
0.0006608963012695312
0.9990001 0.0 0.25
0.0006058216094970703
1.0 0.0 0.25
0.0005161762237548828
0.9990001 0.0 0.25
0.0005009174346923828
1.0 0.0 0.25
0.0005142688751220703
0.9990001 0.0 0.25
0.0005397796630859375
1.0 0.0 0.25
0.0005323886871337891
0.9990001 0.0 0.25
0.0005404949188232422
1.0 0.0 0.25
0.0006415843963623047
0.9990001 0.0 0.25
0.0005421638488769531
1.0 0.0 0.25
0.0005505084991455078
1.0 0.0 0.25
0.0005810260772705078
1.0 0.0 0.25


In [5]:
jnp.max(table.table_idx)

Array(2, dtype=uint8)

In [6]:
count = int(1e6)
puzzle = SlidePuzzle(4)
hash_func = hash_func_builder(puzzle.State)
sample = jax.vmap(puzzle.get_initial_state)(key=jax.random.split(jax.random.PRNGKey(2),count))
new_sample = jax.vmap(puzzle.get_initial_state)(key=jax.random.split(jax.random.PRNGKey(1),count))
table = HashTable.build(puzzle.State, 1, int(1e7))

lookup = jax.jit(partial(HashTable.lookup, hash_func))
start = time.time()
idx, table_idx, found = jax.vmap(lookup, in_axes=(None, 0))(table, sample)
print(time.time()-start)

0.3768134117126465


In [7]:
batch = 10000
parallel_insert = jax.jit(partial(HashTable.parallel_insert, hash_func))
for i in range(18):
    inserteds = []
    sample = jax.vmap(puzzle.get_initial_state)(key=jax.random.split(jax.random.PRNGKey(i + 256),count))
    idx, table_idx, found = jax.vmap(lookup, in_axes=(None, 0))(table, sample)
    same_ratio = jnp.mean(found)
    for j in trange(0, count, batch):
        table, inserted, _, _ = parallel_insert(table, sample[j:j+batch], jnp.ones(batch, dtype=jnp.bool_))
        inserteds.append(inserted)
    inserteds = jnp.concatenate(inserteds)
    idx, table_idx, found = jax.vmap(lookup, in_axes=(None, 0))(table, sample)
    print(jnp.mean(found), same_ratio, jnp.mean(inserteds))

100%|██████████| 100/100 [00:03<00:00, 27.41it/s]


0.999972 0.0 1.0


100%|██████████| 100/100 [00:00<00:00, 566.86it/s]


0.999923 0.0 1.0


100%|██████████| 100/100 [00:00<00:00, 566.66it/s]


0.9999 0.0 1.0


100%|██████████| 100/100 [00:00<00:00, 560.13it/s]


0.99987 1e-06 0.999999


100%|██████████| 100/100 [00:00<00:00, 517.39it/s]


0.999789 0.0 1.0


100%|██████████| 100/100 [00:00<00:00, 525.93it/s]


0.999758 2e-06 0.999998


100%|██████████| 100/100 [00:00<00:00, 510.12it/s]


0.999713 0.0 1.0


100%|██████████| 100/100 [00:00<00:00, 483.11it/s]


0.999638 1e-06 0.999998


100%|██████████| 100/100 [00:00<00:00, 443.78it/s]


0.999618 1e-06 0.999999


100%|██████████| 100/100 [00:00<00:00, 416.03it/s]


0.999511 0.0 1.0


100%|██████████| 100/100 [00:00<00:00, 393.49it/s]


0.999442 0.0 0.999999


100%|██████████| 100/100 [00:00<00:00, 377.75it/s]


0.999343 1e-06 0.999999


100%|██████████| 100/100 [00:00<00:00, 327.97it/s]


0.999194 2e-06 0.999998


100%|██████████| 100/100 [00:00<00:00, 295.29it/s]


0.999016 0.0 1.0


100%|██████████| 100/100 [00:00<00:00, 250.49it/s]


0.998677 0.0 1.0


100%|██████████| 100/100 [00:00<00:00, 204.74it/s]


0.998372 0.0 1.0


100%|██████████| 100/100 [00:00<00:00, 165.42it/s]


0.997811 2e-06 0.999998


100%|██████████| 100/100 [00:00<00:00, 116.98it/s]


0.996594 0.0 1.0


In [8]:
jnp.max(table.table_idx)

Array(2, dtype=uint8)

In [9]:
another_sample = jax.vmap(puzzle.get_initial_state)(key=jax.random.split(jax.random.PRNGKey(123),count))
idx, table_idx, found = jax.vmap(lookup, in_axes=(None, 0))(table, another_sample)
print(jnp.sum(found) / count)

1e-06


In [10]:
start = time.time()
inserteds = []
for i in trange(count // batch):
    table, inserted, _, _ = parallel_insert(table, another_sample[i*batch:(i+1)*batch], jnp.ones(batch, dtype=jnp.bool_))
    inserteds.append(inserted)
print(jnp.sum(jnp.concatenate(inserteds)) / count)
print("insert time:", time.time()-start)
start = time.time()
idx, table_idx, found = jax.vmap(lookup, in_axes=(None, 0))(table, sample)
print("check time:", time.time()-start)
print(jnp.sum(found) / count)
start = time.time()
idx, table_idx, found = jax.vmap(lookup, in_axes=(None, 0))(table, new_sample)
print("check time:", time.time()-start)
print(jnp.sum(found) / count)
idx, table_idx, found = jax.vmap(lookup, in_axes=(None, 0))(table, another_sample)
print("check time:", time.time()-start)
print(jnp.sum(found) / count)

100%|██████████| 100/100 [00:01<00:00, 63.70it/s]


0.999999
insert time: 1.576587200164795
check time: 0.13985466957092285
0.996594
check time: 0.15284442901611328
1e-06
check time: 0.31472277641296387
0.993682
