In [1]:
import sys
import os

# 프로젝트 루트 디렉토리 경로를 추가
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [2]:
import jax
import jax.numpy as jnp
import time
#disable jax JIT
#jax.config.update("jax_disable_jit", True)

from tqdm.autonotebook import trange
from functools import partial
from JAxtar.hash import hash_func_builder
from puzzle.slidepuzzle import SlidePuzzle
from heuristic.slidepuzzle_heuristic import SlidePuzzleHeuristic
from JAxtar.hash import HashTable

  from tqdm.autonotebook import trange


In [3]:
count = 1000
puzzle = SlidePuzzle(4)
hash_func = hash_func_builder(puzzle.State)
sample = jax.vmap(puzzle.get_initial_state)(key=jax.random.split(jax.random.PRNGKey(2),count))
new_sample = jax.vmap(puzzle.get_initial_state)(key=jax.random.split(jax.random.PRNGKey(1),count))
table = HashTable.build(puzzle.State, 1, int(1e5))

lookup = jax.jit(partial(HashTable.lookup, hash_func))
start = time.time()
idx, table_idx, found = jax.vmap(lookup, in_axes=(None, 0))(table, sample)
print(time.time()-start)

0.20058012008666992


In [4]:
batch = 4000
parallel_insert = jax.jit(partial(HashTable.parallel_insert, hash_func))
for i in range(18):
    sample = jax.vmap(puzzle.get_initial_state)(key=jax.random.split(jax.random.PRNGKey(i + 256),count))
    idx, table_idx, old_found = jax.vmap(lookup, in_axes=(None, 0))(table, sample)
    batched_sample, filled = HashTable.make_batched(puzzle.State, sample, batch)
    start = time.time()
    table, inserted, _, _ = parallel_insert(table, batched_sample, filled)
    print(time.time()-start)
    idx, table_idx, found = jax.vmap(lookup, in_axes=(None, 0))(table, sample)
    print(jnp.mean(found), jnp.mean(old_found), jnp.mean(inserted))

0.8646855354309082
1.0 0.0 0.25
0.8525784015655518
1.0 0.0 0.25
0.0007708072662353516
1.0 0.0 0.25
0.0005652904510498047
1.0 0.0 0.25
0.0005927085876464844
1.0 0.0 0.25
0.0006411075592041016
1.0 0.0 0.25
0.0006673336029052734
1.0 0.0 0.25
0.0006196498870849609
1.0 0.0 0.25
0.0007860660552978516
1.0 0.0 0.25
0.0009541511535644531
1.0 0.0 0.25
0.0005936622619628906
1.0 0.0 0.25
0.00063323974609375
1.0 0.0 0.25
0.0005793571472167969
1.0 0.0 0.25
0.0006735324859619141
1.0 0.0 0.25
0.0006277561187744141
1.0 0.0 0.25
0.0007169246673583984
1.0 0.0 0.25
0.0007035732269287109
1.0 0.0 0.25
0.0007371902465820312
1.0 0.0 0.25


In [5]:
jnp.max(table.table_idx)

Array(4, dtype=uint8)

In [6]:
count = int(1e6)
puzzle = SlidePuzzle(4)
hash_func = hash_func_builder(puzzle.State)
sample = jax.vmap(puzzle.get_initial_state)(key=jax.random.split(jax.random.PRNGKey(2),count))
new_sample = jax.vmap(puzzle.get_initial_state)(key=jax.random.split(jax.random.PRNGKey(1),count))
table = HashTable.build(puzzle.State, 1, int(1e7))

lookup = jax.jit(partial(HashTable.lookup, hash_func))
start = time.time()
idx, table_idx, found = jax.vmap(lookup, in_axes=(None, 0))(table, sample)
print(time.time()-start)

0.20178890228271484


In [7]:
batch = 100000
parallel_insert = jax.jit(partial(HashTable.parallel_insert, hash_func))
for i in range(18):
    inserteds = []
    sample = jax.vmap(puzzle.get_initial_state)(key=jax.random.split(jax.random.PRNGKey(i + 256),count))
    idx, table_idx, found = jax.vmap(lookup, in_axes=(None, 0))(table, sample)
    same_ratio = jnp.mean(found)
    for j in trange(0, count, batch):
        table, inserted, _, _ = parallel_insert(table, sample[j:j+batch], jnp.ones(batch, dtype=jnp.bool_))
        inserteds.append(inserted)
    inserteds = jnp.concatenate(inserteds)
    idx, table_idx, found = jax.vmap(lookup, in_axes=(None, 0))(table, sample)
    print(jnp.mean(found), same_ratio, jnp.mean(inserteds))

  0%|          | 0/10 [00:00<?, ?it/s]

1.0 0.0 1.0


  0%|          | 0/10 [00:00<?, ?it/s]

1.0 0.0 1.0


  0%|          | 0/10 [00:00<?, ?it/s]

1.0 0.0 1.0


  0%|          | 0/10 [00:00<?, ?it/s]

1.0 1e-06 0.999999


  0%|          | 0/10 [00:00<?, ?it/s]

1.0 0.0 1.0


  0%|          | 0/10 [00:00<?, ?it/s]

1.0 2e-06 0.999998


  0%|          | 0/10 [00:00<?, ?it/s]

1.0 0.0 1.0


  0%|          | 0/10 [00:00<?, ?it/s]

1.0 1e-06 0.999998


  0%|          | 0/10 [00:00<?, ?it/s]

1.0 1e-06 0.999999


  0%|          | 0/10 [00:00<?, ?it/s]

1.0 0.0 1.0


  0%|          | 0/10 [00:00<?, ?it/s]

1.0 0.0 0.999999


  0%|          | 0/10 [00:00<?, ?it/s]

1.0 1e-06 0.999999


  0%|          | 0/10 [00:00<?, ?it/s]

1.0 2e-06 0.999998


  0%|          | 0/10 [00:00<?, ?it/s]

1.0 0.0 1.0


  0%|          | 0/10 [00:00<?, ?it/s]

1.0 0.0 1.0


  0%|          | 0/10 [00:00<?, ?it/s]

1.0 0.0 1.0


  0%|          | 0/10 [00:00<?, ?it/s]

1.0 2e-06 0.999998


  0%|          | 0/10 [00:00<?, ?it/s]

1.0 0.0 1.0


In [8]:
jnp.max(table.table_idx)

Array(10, dtype=uint8)

In [9]:
another_sample = jax.vmap(puzzle.get_initial_state)(key=jax.random.split(jax.random.PRNGKey(123),count))
idx, table_idx, found = jax.vmap(lookup, in_axes=(None, 0))(table, another_sample)
print(jnp.sum(found) / count)

1e-06


In [10]:
start = time.time()
inserteds = []
for i in trange(count // batch):
    table, inserted, _, _ = parallel_insert(table, another_sample[i*batch:(i+1)*batch], jnp.ones(batch, dtype=jnp.bool_))
    inserteds.append(inserted)
print(jnp.sum(jnp.concatenate(inserteds)) / count)
print("insert time:", time.time()-start)
start = time.time()
idx, table_idx, found = jax.vmap(lookup, in_axes=(None, 0))(table, sample)
print("check time:", time.time()-start)
print(jnp.sum(found) / count)
start = time.time()
idx, table_idx, found = jax.vmap(lookup, in_axes=(None, 0))(table, new_sample)
print("check time:", time.time()-start)
print(jnp.sum(found) / count)
idx, table_idx, found = jax.vmap(lookup, in_axes=(None, 0))(table, another_sample)
print("check time:", time.time()-start)
print(jnp.sum(found) / count)

  0%|          | 0/10 [00:00<?, ?it/s]

0.999999
insert time: 0.08522224426269531
check time: 0.011853456497192383
1.0
check time: 0.010859251022338867
1e-06
check time: 0.02329277992248535
0.999999
