In [1]:
import sys
import os

# 프로젝트 루트 디렉토리 경로를 추가
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [2]:
import jax
import jax.numpy as jnp
import chex
import numpy as np
#disable jax JIT
#jax.config.update("jax_disable_jit", True)

from tqdm.autonotebook import tqdm
from JAxtar.hash import hash_func_builder, HashTable
from JAxtar.bgpq import BGPQ, HashTableIdx_HeapValue
from functools import partial
from puzzle.slidepuzzle import SlidePuzzle
from heuristic.slidepuzzle_heuristic import SlidePuzzleHeuristic

  from tqdm.autonotebook import tqdm
2024-08-06 02:19:01.985421: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.5.82). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [3]:
size = int(1e7)
batch_size = int(10000)

In [4]:
puzzle = SlidePuzzle(4)

In [5]:
hash_func = hash_func_builder(puzzle.State)
table = HashTable.build(puzzle.State, 1, size)
n_table = table.n_table
heap = BGPQ.build(size, batch_size, HashTableIdx_HeapValue)
cost = jnp.full((size, n_table), jnp.inf)
closed = jnp.full((size, n_table), False, dtype=bool)
parant = jnp.full((size, n_table, 2), -1, dtype=int)

In [6]:
lookup = jax.jit(jax.vmap(partial(HashTable.lookup, hash_func), in_axes=(None, 0)))
parallel_insert = jax.jit(partial(HashTable.parallel_insert, hash_func))
heuristic = jax.jit(jax.vmap(SlidePuzzleHeuristic(puzzle).distance, in_axes=(0, None)))
solved_fn = jax.jit(jax.vmap(jax.vmap(puzzle.is_solved, in_axes=(0, None)), in_axes=(0, None)))
neighbours_fn = jax.jit(jax.vmap(puzzle.get_neighbours, in_axes=(0,0)))
delete_fn = jax.jit(BGPQ.delete_mins)
insert_fn = jax.jit(BGPQ.insert)

In [7]:
states = jax.vmap(puzzle.get_initial_state, in_axes=0)(key=jax.random.split(jax.random.PRNGKey(3),1))
target = puzzle.State(board=jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0], dtype=jnp.uint8))
states, filled = HashTable.make_batched(puzzle.State, states, batch_size)
cost_val = jnp.full((batch_size,), jnp.inf).at[0].set(0)
table, inserted, idx, table_idx = parallel_insert(table, states, filled)
found = inserted
print(found[:10])
heur_val = heuristic(states, target)
print(heur_val[:10])
hash_idxs = HashTableIdx_HeapValue(index=idx, table_index=table_idx)[:, jnp.newaxis]
cost = cost.at[idx, table_idx].set(jnp.where(found, cost_val, cost[idx, table_idx]))
key = cost_val + heur_val
heap = BGPQ.insert(heap, key, hash_idxs)

[ True False False False False False False False False False]
[34  0  0  0  0  0  0  0  0  0]


In [8]:
print(states[0])
pbar = tqdm(total=size)
pbar.update(1)
while heap.size < size and not heap.size == 0 and table.size < size:
    pbar_str = f"heap_size: {heap.size:8d}, total_nodes: {table.size:8d}, "
    # get the minimum key
    heap, min_key, min_val = delete_fn(heap)
    min_idx, min_table_idx = min_val.index, min_val.table_index
    parant_idx = jnp.stack((min_idx, min_table_idx), axis=-1).astype(jnp.int32)
    # get the state
    cost_val, closed_val = cost[min_idx, min_table_idx], closed[min_idx, min_table_idx]
    states = table.table[min_idx, min_table_idx]

    # check if the state is already closed
    filled = jnp.logical_and(jnp.isfinite(min_key),~closed_val)
    if not filled.any():
        continue
    closed = closed.at[min_idx, min_table_idx].set(jnp.where(filled, True, closed[min_idx, min_table_idx]))
    pbar_str += f"cost: {jnp.mean(cost_val):.2f}"
    pbar.set_description_str(pbar_str)

    neighbours, ncost = neighbours_fn(states, filled)
    nextcosts = cost_val[:, jnp.newaxis] + ncost
    solved = solved_fn(neighbours, target)
    if solved.any():
        print("solved")
        break
    nextheur = jax.vmap(heuristic, in_axes=(0, None))(neighbours, target)
    nextkeys = 1.0 * nextcosts + nextheur
    for i in range(4):
        nextkey = nextkeys[:, i]
        nextcost = nextcosts[:, i]
        nextstates = neighbours[:, i]
        filled = jnp.isfinite(nextkey)

        table, inserted, idx, table_idx = parallel_insert(table, nextstates, filled)
        added = int(jnp.sum(inserted))
        vals = HashTableIdx_HeapValue(index=idx, table_index=table_idx)[:, jnp.newaxis]
        more_optimal = (nextcost <= cost[idx, table_idx])
        cost = cost.at[idx, table_idx].set(jnp.minimum(nextcost, cost[idx, table_idx]))
        parant = parant.at[idx, table_idx].set(jnp.where(more_optimal[:,jnp.newaxis], parant_idx, parant[idx, table_idx]))
        closed = closed.at[idx, table_idx].set(jnp.where(more_optimal, False, closed[idx, table_idx]))
        nextkey = jnp.where(closed[idx, table_idx], jnp.inf, nextkey)

        heap = insert_fn(heap, nextkey, vals)
        pbar.update(added)
pbar.close()
solved_st = neighbours[solved][0]
n_cost = nextcosts[solved][0]
print(solved_st)
print(n_cost)

┏━━━┳━━━┳━━━┳━━━┓
┃   ┃ 3 ┃ 5 ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ C ┃ 8 ┃ 6 ┃ 7 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 2 ┃ E ┃ A ┃ F ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ B ┃ D ┃ 1 ┃ 9 ┃
┗━━━┻━━━┻━━━┻━━━┛


heap_size:    86750, total_nodes:   681186, cost: 27.42:   7%|▋         | 681186/10000000 [00:13<03:11, 48753.40it/s] 


solved
┏━━━┳━━━┳━━━┳━━━┓
┃ 1 ┃ 2 ┃ 3 ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 5 ┃ 6 ┃ 7 ┃ 8 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 9 ┃ A ┃ B ┃ C ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ D ┃ E ┃ F ┃   ┃
┗━━━┻━━━┻━━━┻━━━┛
48.0


In [9]:
path = []
idx = jnp.argmax(jnp.max(solved, axis=1))
parant_last = parant_idx[idx]
for i in range(100):
    if parant_last[0] == -1:
        break
    path.append(parant_last)
    parant_last = parant[parant_last[0], parant_last[1]]

for p in path[::-1]:
    state = table.table[p[0], p[1]]
    c = cost[p[0], p[1]]
    print(state)
    print(c)
print(solved_st)

┏━━━┳━━━┳━━━┳━━━┓
┃   ┃ 3 ┃ 5 ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ C ┃ 8 ┃ 6 ┃ 7 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 2 ┃ E ┃ A ┃ F ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ B ┃ D ┃ 1 ┃ 9 ┃
┗━━━┻━━━┻━━━┻━━━┛
0.0
┏━━━┳━━━┳━━━┳━━━┓
┃ 3 ┃   ┃ 5 ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ C ┃ 8 ┃ 6 ┃ 7 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 2 ┃ E ┃ A ┃ F ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ B ┃ D ┃ 1 ┃ 9 ┃
┗━━━┻━━━┻━━━┻━━━┛
1.0
┏━━━┳━━━┳━━━┳━━━┓
┃ 3 ┃ 5 ┃   ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ C ┃ 8 ┃ 6 ┃ 7 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 2 ┃ E ┃ A ┃ F ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ B ┃ D ┃ 1 ┃ 9 ┃
┗━━━┻━━━┻━━━┻━━━┛
2.0
┏━━━┳━━━┳━━━┳━━━┓
┃ 3 ┃ 5 ┃ 6 ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ C ┃ 8 ┃   ┃ 7 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 2 ┃ E ┃ A ┃ F ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ B ┃ D ┃ 1 ┃ 9 ┃
┗━━━┻━━━┻━━━┻━━━┛
3.0
┏━━━┳━━━┳━━━┳━━━┓
┃ 3 ┃ 5 ┃ 6 ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ C ┃   ┃ 8 ┃ 7 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 2 ┃ E ┃ A ┃ F ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ B ┃ D ┃ 1 ┃ 9 ┃
┗━━━┻━━━┻━━━┻━━━┛
4.0
┏━━━┳━━━┳━━━┳━━━┓
┃ 3 ┃ 5 ┃ 6 ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ C ┃ E ┃ 8 ┃ 7 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 2 ┃   ┃ A ┃ F ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ B ┃ D ┃ 1 ┃ 9 ┃
┗━━━┻━━━┻━━━┻━━━┛
5.0
┏━━━