In [1]:
import sys
import os

# Add the project root (the parent of the notebook's working directory) to
# sys.path so the repository-local packages imported in the next cell
# (JAxtar, puzzle, heuristic) can be found.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)

# Restrict CUDA to GPU 0. CUDA reads this variable when the GPU backend is
# initialised, so it has to be set before `jax` is first imported (next cell).
# Plain os.environ is used instead of the IPython-only `%env` magic so the cell
# is also valid Python (e.g. after `jupyter nbconvert --to script`).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

env: CUDA_VISIBLE_DEVICES=0


In [2]:
import jax
import jax.numpy as jnp
import chex
import numpy as np
#disable jax JIT
#jax.config.update("jax_disable_jit", True)

from tqdm.autonotebook import tqdm
from JAxtar.hash import hash_func_builder, HashTable
from JAxtar.bgpq import BGPQ, HashTableIdx_HeapValue
from JAxtar.astar import astar_builder
from functools import partial
from puzzle.slidepuzzle import SlidePuzzle
from heuristic.slidepuzzle_heuristic import SlidePuzzleHeuristic

  from tqdm.autonotebook import tqdm


In [3]:
# Search configuration.
puzzle = SlidePuzzle(4)  # 4x4 sliding puzzle (boards printed below are 4x4)
size = 20_000_000        # capacity argument passed to astar_builder — presumably max stored states; confirm
batch_size = 10_000      # batch width passed to astar_builder and HashTable.make_batched

In [4]:
# Sample one start position from a fixed PRNG seed so the run is reproducible.
start_keys = jax.random.split(jax.random.PRNGKey(32), 1)  # a batch holding a single key
states = jax.vmap(puzzle.get_initial_state, in_axes=0)(key=start_keys)
# Alternative fixed start position (kept for debugging):
# states = puzzle.State(board=jnp.array([0, 12, 9, 13, 15, 11, 10, 14, 3, 7, 2, 5, 4, 8, 6, 1], dtype=jnp.uint8))[jnp.newaxis, ...]

# Goal board: tiles 1..15 in row-major order, blank (0) in the last cell.
target = puzzle.State(board=jnp.array([*range(1, 16), 0], dtype=jnp.uint8))

In [5]:
# Show the sampled start board followed by the goal board.
for board in (states[0], target):
    print(board)

┏━━━┳━━━┳━━━┳━━━┓
┃ C ┃ 1 ┃ B ┃ E ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ A ┃ 5 ┃ 6 ┃   ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 8 ┃ F ┃ 9 ┃ D ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 3 ┃ 4 ┃ 2 ┃ 7 ┃
┗━━━┻━━━┻━━━┻━━━┛
┏━━━┳━━━┳━━━┳━━━┓
┃ 1 ┃ 2 ┃ 3 ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 5 ┃ 6 ┃ 7 ┃ 8 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 9 ┃ A ┃ B ┃ C ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ D ┃ E ┃ F ┃   ┃
┗━━━┻━━━┻━━━┻━━━┛


In [6]:
# Build the batched A* search using the slide-puzzle heuristic's distance function.
heuristic = SlidePuzzleHeuristic(puzzle)
astar_fn = astar_builder(puzzle, heuristic.distance, batch_size, size)

# Pad the start states into a batch of width `batch_size`; `filled` presumably
# marks which slots hold real states — confirm in HashTable.make_batched.
# NOTE(review): `states` is rebound to the padded batch here (cell 9 relies on
# that), so re-running this cell alone feeds the already-padded batch back into
# make_batched. Re-run the state-creation cell first.
states, filled = HashTable.make_batched(puzzle.State, states, batch_size)
astar_result, solved, solved_idx = astar_fn(states, filled, target)

In [7]:
# Report whether the goal was found and its (index, table_index) slot.
print(" ".join(map(str, (solved, solved_idx))))

True HashTableIdx_HeapValue(index=Array([7355726], dtype=int32), table_index=Array([1], dtype=int32))


In [8]:
# Read the goal state and its accumulated path cost from the search's hash-table slot.
goal_slot = (solved_idx.index, solved_idx.table_index)
solved_st = astar_result.hashtable.table[goal_slot][0]
solved_cost = astar_result.cost[goal_slot][0]
for value in (solved_st, solved_cost):
    print(value)

┏━━━┳━━━┳━━━┳━━━┓
┃ 1 ┃ 2 ┃ 3 ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 5 ┃ 6 ┃ 7 ┃ 8 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 9 ┃ A ┃ B ┃ C ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ D ┃ E ┃ F ┃   ┃
┗━━━┻━━━┻━━━┻━━━┛
66.0


In [9]:
# Run the search again with the same inputs; the results overwrite the previous ones.
# NOTE(review): redundant unless intended to time the call after compilation
# (if astar_fn is jit-compiled) — consider timing it explicitly or removing it.
(
    astar_result,
    solved,
    solved_idx,
) = astar_fn(states, filled, target)

In [10]:
# Short aliases for the parent-pointer table, stored states and per-slot costs.
parants, table, cost = astar_result.parant, astar_result.hashtable.table, astar_result.cost

In [11]:
import warnings

# Reconstruct the solution by following parent pointers back from the goal.
# Each entry of `parants` is an (index, table_index) pair addressing the
# predecessor's slot in the hash table (the same addressing used for `table`
# and `cost`); a first component of -1 means "no parent", i.e. the start state.
# `path` collects the goal's ancestors, goal excluded, from nearest to farthest.
max_path_len = 100  # guard against an unterminated parent chain
path = []
parant_last = parants[solved_idx.index, solved_idx.table_index][0]
for _ in range(max_path_len):
    if parant_last[0] == -1:
        break
    path.append(parant_last)
    # Explicit two-axis indexing; `parants[*parant_last]` only parses on Python >= 3.11.
    parant_last = parants[parant_last[0], parant_last[1]]

if parant_last[0] != -1:
    # The cap was reached before the start state: the printed path is incomplete.
    warnings.warn(
        f"parent chain longer than {max_path_len} steps; printed path is truncated"
    )

# Print the path from the start state towards the goal, each state with its cost.
for p in path[::-1]:
    state = table[p[0], p[1]]
    c = cost[p[0], p[1]]
    print(state)
    print(c)

# Finish with the goal itself, read from the *current* search result rather than
# `solved_st` / `solved_cost` from an earlier cell (the search was re-run since).
goal_slot = (solved_idx.index, solved_idx.table_index)
print(table[goal_slot][0])
print(cost[goal_slot][0])

┏━━━┳━━━┳━━━┳━━━┓
┃ C ┃ 1 ┃ B ┃ E ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ A ┃ 5 ┃ 6 ┃   ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 8 ┃ F ┃ 9 ┃ D ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 3 ┃ 4 ┃ 2 ┃ 7 ┃
┗━━━┻━━━┻━━━┻━━━┛
0.0
┏━━━┳━━━┳━━━┳━━━┓
┃ C ┃ 1 ┃ B ┃   ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ A ┃ 5 ┃ 6 ┃ E ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 8 ┃ F ┃ 9 ┃ D ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 3 ┃ 4 ┃ 2 ┃ 7 ┃
┗━━━┻━━━┻━━━┻━━━┛
1.0
┏━━━┳━━━┳━━━┳━━━┓
┃ C ┃ 1 ┃   ┃ B ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ A ┃ 5 ┃ 6 ┃ E ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 8 ┃ F ┃ 9 ┃ D ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 3 ┃ 4 ┃ 2 ┃ 7 ┃
┗━━━┻━━━┻━━━┻━━━┛
2.0
┏━━━┳━━━┳━━━┳━━━┓
┃ C ┃   ┃ 1 ┃ B ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ A ┃ 5 ┃ 6 ┃ E ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 8 ┃ F ┃ 9 ┃ D ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 3 ┃ 4 ┃ 2 ┃ 7 ┃
┗━━━┻━━━┻━━━┻━━━┛
3.0
┏━━━┳━━━┳━━━┳━━━┓
┃ C ┃ 5 ┃ 1 ┃ B ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ A ┃   ┃ 6 ┃ E ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 8 ┃ F ┃ 9 ┃ D ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 3 ┃ 4 ┃ 2 ┃ 7 ┃
┗━━━┻━━━┻━━━┻━━━┛
4.0
┏━━━┳━━━┳━━━┳━━━┓
┃ C ┃ 5 ┃ 1 ┃ B ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ A ┃ F ┃ 6 ┃ E ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 8 ┃   ┃ 9 ┃ D ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 3 ┃ 4 ┃ 2 ┃ 7 ┃
┗━━━┻━━━┻━━━┻━━━┛
5.0
┏━━━