In [1]:
import sys
import os

# 프로젝트 루트 디렉토리 경로를 추가
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [2]:
import jax
import jax.numpy as jnp

from JAxtar.hash import hash_func_builder
from puzzle.slidepuzzle import SlidePuzzle
from heuristic.slidepuzzle_heuristic import SlidePuzzleHeuristic

In [3]:
# Presumably the board side length -- the states printed below are 4x4 boards.
BOARD_SIDE = 4
puzzle = SlidePuzzle(BOARD_SIDE)
heuristic = SlidePuzzleHeuristic(puzzle)

In [4]:
# Number of batch slots used by the padding / expansion cells below.
size = 20

# Check batched generation: build a batch holding a single random initial
# state (PRNG seed 0) and verify that it is solvable.
init_keys = jax.random.split(jax.random.PRNGKey(0), 1)
states = jax.vmap(puzzle.get_initial_state, in_axes=0)(key=init_keys)
print(states[0])
print("Solverable : ", puzzle._solverable(states[0]))

┏━━━┳━━━┳━━━┳━━━┓
┃ B ┃ 4 ┃ 6 ┃   ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 1 ┃ 5 ┃ 7 ┃ 3 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ F ┃ A ┃ 2 ┃ C ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ E ┃ D ┃ 9 ┃ 8 ┃
┗━━━┻━━━┻━━━┻━━━┛
Solverable :  True


In [5]:
# Pad the initial batch with `puzzle.State.default` entries up to `size` slots
# and build the mask marking which slots hold real states. The number of real
# states is read from `states` itself instead of being hard-coded to 1.
n_init = jax.tree_util.tree_leaves(states)[0].shape[0]

# Re-running this cell without re-running the previous one would pad an
# already-padded batch (and desynchronise it from `filled`); fail loudly instead.
assert n_init < size, (
    f"states already has {n_init} entries (>= size={size}); "
    "re-run the previous cell before this one"
)

# vmapping over a zeros vector of length `size - n_init` yields that many
# default states (the mapped value itself is presumably ignored -- confirm).
dummy = jax.vmap(puzzle.State.default)(jnp.zeros(size - n_init))
states = jax.tree_util.tree_map(lambda x, y: jnp.concatenate([x, y]), states, dummy)

# True for the first `n_init` (real) slots, False for the padding.
filled = jnp.arange(size) < n_init



In [6]:
def visualize_state(state, filled):
    """Print every state in a batch whose slot is marked as filled.

    Args:
        state: batched states; ``state[i]`` must return the i-th entry.
        filled: 1-D boolean mask (array or sequence); entry ``i`` is printed
            only when ``filled[i]`` is truthy.
    """
    # Iterate over the mask itself rather than the notebook-global `size`, so the
    # helper works for any batch length and does not depend on hidden state.
    for i, is_filled in enumerate(filled):
        if is_filled:
            print(state[i])

In [7]:
# Expand the batch for 3 steps. Each step prints the currently filled states,
# replaces every state by all of its neighbours, moves the neighbours marked
# filled to the front, and truncates back to `size` slots (filled neighbours
# beyond `size` are dropped).
batched_get_neighbours = jax.vmap(puzzle.get_neighbours, in_axes=(0, 0))


def first_flat(x):
    """Merge the two leading axes of ``x``: (B, M, ...) -> (B*M, ...)."""
    return jnp.reshape(x, (-1, *x.shape[2:]))


for step in range(3):
    visualize_state(states, filled)

    neighbours, cost = batched_get_neighbours(states, filled)
    # A neighbour counts as filled exactly when its cost is finite.
    neighbours_filled = first_flat(jnp.isfinite(cost))
    neighbours = jax.tree_util.tree_map(first_flat, neighbours)

    # argsort on a boolean mask orders False before True; reversing it brings
    # the filled neighbours to the front of the batch.
    filled_sort = jnp.argsort(neighbours_filled)[::-1]
    neighbours = jax.tree_util.tree_map(lambda leaf: leaf[filled_sort], neighbours)

    # Keep only the first `size` slots for the next step.
    states = jax.tree_util.tree_map(lambda leaf: leaf[:size], neighbours)
    filled = neighbours_filled[filled_sort][:size]
    print(filled)

┏━━━┳━━━┳━━━┳━━━┓
┃ B ┃ 4 ┃ 6 ┃   ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 1 ┃ 5 ┃ 7 ┃ 3 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ F ┃ A ┃ 2 ┃ C ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ E ┃ D ┃ 9 ┃ 8 ┃
┗━━━┻━━━┻━━━┻━━━┛


In [None]:
# Show the filled-slot mask left by the expansion loop above.
print(filled)

[ True  True  True  True  True  True False False False False False False
 False False False False False False False False]
