In [1]:
import sys
import os

# 프로젝트 루트 디렉토리 경로를 추가
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [2]:
import jax
import jax.numpy as jnp

from JAxtar.hash import hash_func_builder, HashTable
from JAxtar.bgpq import BGPQ, HashTableHeapValue
from functools import partial
from puzzle.slidepuzzle import SlidePuzzle
from heuristic.slidepuzzle_heuristic import SlidePuzzleHeuristic

2024-08-01 08:59:26.707637: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.5.82). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [3]:
# Board width passed to SlidePuzzle; 4 gives the 4x4 board seen in the printed states.
PUZZLE_SIZE = 4
puzzle = SlidePuzzle(PUZZLE_SIZE)
# NOTE(review): `heuristic` is not referenced by the search loop in this notebook
# (the loop only accumulates path cost), so the search is not A* yet — confirm intent.
heuristic = SlidePuzzleHeuristic(puzzle)

In [4]:
# Passed to both the hash table and the heap below (presumably their capacity) and
# used as the upper bound of the search loop.
size = 10**7
# Number of states handled per batched step (batch width for make_batched / heap ops).
group_size = 1024

# Sanity-check batched generation: draw a single initial state through vmap.
init_keys = jax.random.split(jax.random.PRNGKey(123), 1)
states = jax.vmap(puzzle.get_initial_state)(key=init_keys)
print(states[0])

hash_func = hash_func_builder(puzzle.State)
# NOTE(review): the meaning of the literal `1` argument is not visible here — confirm.
table = HashTable.make_lookup_table(puzzle.State, 1, size)
heap = BGPQ.make_heap(size, group_size, HashTableHeapValue)

# Bind the hash function once and JIT-compile the table operations.
lookup = jax.jit(partial(HashTable.lookup, hash_func))
parallel_insert = jax.jit(partial(HashTable.parallel_insert, hash_func))

┏━━━┳━━━┳━━━┳━━━┓
┃ 5 ┃   ┃ A ┃ E ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 8 ┃ 7 ┃ 2 ┃ D ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 1 ┃ 6 ┃ 9 ┃ 3 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 4 ┃ F ┃ C ┃ B ┃
┗━━━┻━━━┻━━━┻━━━┛


In [5]:
# Seed the search: batch the initial state(s), store them in the hash table, and push
# their table coordinates onto the heap with path cost 0.
states, filled = HashTable.make_batched(puzzle.State, states, group_size)
# Cost 0 for every real (filled) slot and inf for padding. This matches the previous
# `.at[0].set(0)` for a single start state, and stays correct if several are batched.
cost = jnp.where(filled, 0.0, jnp.full((group_size,), jnp.inf))
table, _ = parallel_insert(table, states, filled)
idx, table_idx, found = jax.vmap(lookup, in_axes=(None, 0))(table, states)
found = found & filled
# Every state that was just inserted must be retrievable from the table.
assert bool(jnp.all(found == filled)), "initial states missing from the hash table"
hash_idxs = HashTableHeapValue(
    index=jnp.expand_dims(idx, axis=-1),
    table_index=jnp.expand_dims(table_idx, axis=-1),
)
heap = BGPQ.insert(heap, cost, hash_idxs)

In [6]:
def visualize_state(state, filled):
    """Print every batch entry of ``state`` whose ``filled`` flag is set.

    Args:
        state: batched states, indexable by position (``state[i]``).
        filled: per-slot validity mask; its length determines how many slots are
            scanned. Iterating over this length, rather than the global table
            ``size``, keeps the loop inside the batch.
    """
    for i in range(len(filled)):
        if filled[i]:
            print(state[i])

In [7]:
# Batched best-first expansion on accumulated path cost: pop the cheapest batch,
# expand every neighbour, and push unseen neighbours back with their path cost.
# NOTE(review): the `heuristic` built earlier is not used here, and there is no goal
# test, so the loop only stops when the heap empties or reaches `size`.
# TODO: add a solved-state check once the puzzle API for it is confirmed.

# Build the batched callables once rather than re-wrapping them on every iteration.
batched_get_neighbours = jax.vmap(puzzle.get_neighbours, in_axes=(0, 0))
batched_lookup = jax.vmap(lookup, in_axes=(None, 0))
batched_neighbour_lookup = jax.vmap(batched_lookup, in_axes=(None, 0))

print(states[0])
# Also stop when the frontier is exhausted. With only `heap.size < size`, an empty heap
# would keep the loop running with nothing to expand.
while heap.size > 0 and heap.size < size:
    print(heap.size)
    heap, min_key, min_val = BGPQ.delete_mins(heap)
    cost = min_key
    # Slots with a non-finite key are treated as empty.
    filled = jnp.isfinite(min_key)
    # Gather the popped states from the table by their (index, table_index) coordinates.
    states = table.table[min_val.index.squeeze(), min_val.table_index.squeeze()]

    neighbours, ncost = batched_get_neighbours(states, filled)
    # Neighbours already present in the table are treated as visited: their cost becomes inf.
    _, _, nfound = batched_neighbour_lookup(table, neighbours)
    # Neighbour count comes from the cost array (axis 1) instead of a hard-coded 4.
    n_neighbours = ncost.shape[1]
    for i in range(n_neighbours):
        n_state = neighbours[:, i]
        n_cost = ncost[:, i] + cost + jnp.where(nfound[:, i], jnp.inf, 0)
        filled = jnp.isfinite(n_cost)
        table, _ = parallel_insert(table, n_state, filled)
        idx, table_idx, _ = batched_lookup(table, n_state)
        hash_idxs = HashTableHeapValue(
            index=jnp.expand_dims(idx, axis=-1),
            table_index=jnp.expand_dims(table_idx, axis=-1),
        )
        heap = BGPQ.insert(heap, n_cost, hash_idxs)
print(states[0])

┏━━━┳━━━┳━━━┳━━━┓
┃ 5 ┃   ┃ A ┃ E ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 8 ┃ 7 ┃ 2 ┃ D ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 1 ┃ 6 ┃ 9 ┃ 3 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 4 ┃ F ┃ C ┃ B ┃
┗━━━┻━━━┻━━━┻━━━┛
1
3
6
14
32
66
136
286
609
1277
2391
3412
2388
2997
1973
949
1982
3153
4185
3161
2187
1163
139
249
599
1351
2541
3573
2549
3352
2328
1304
280
439
861
1832
3304
4313
3289
2924
1900
876
1987
3000
4007
2983
1959
935
1869
3009
3792
2768
2545
1521
497
1031
2170
3138
2114
1406
382
778
1649
2764
3672
2648
1676
652
1239
2452
3796
2772
3667
2643
1619
595
1146
2007
2928
4350
3531
2507
1483
459
1032
2081
3176
2152
1235
211
454
983
2201
3502
2478
1711
687
1301
2317
3552
2528
3397
2373
1349
325
598
1117
2227
3466
2442
2179
1155
131
243
374
767
1908
2946
4347
3323
2299
1275
251
455
958
2243
3720
2696
2966
1942
918
2058
3193
2169
1165
141
311
579
1149
2443
3409
2385
2724
1700
676
1363
2568
3685
2661
3410
2386
1362
338
742
1626
2726
3779
2755
2217
1193
169
321
535
1087


KeyboardInterrupt: 

In [None]:
# Inspect the validity mask left over from the last expansion step of the search loop.
print(str(filled))

[ True  True  True ... False  True False]
