In [1]:
import sys
import os

# 프로젝트 루트 디렉토리 경로를 추가
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [2]:
import jax
import jax.numpy as jnp

from JAxtar.hash import hash_func_builder
from puzzle.slidepuzzle import SlidePuzzle
from heuristic.slidepuzzle_heuristic import SlidePuzzleHeuristic

In [3]:
puzzle = SlidePuzzle(4)
heuristic = SlidePuzzleHeuristic(puzzle)

In [4]:
#check batch generation
defualt_state = jax.vmap(puzzle.State.default)(jnp.zeros(10000))
print(defualt_state[0])
states = jax.vmap(puzzle.get_initial_state, in_axes=0)(key=jax.random.split(jax.random.PRNGKey(0),10))
print(states[0])
print("Solverable : ", puzzle._solverable(states[0]))

┏━━━┳━━━┳━━━┳━━━┓
┃   ┃   ┃   ┃   ┃
┣━━━╋━━━╋━━━╋━━━┫
┃   ┃   ┃   ┃   ┃
┣━━━╋━━━╋━━━╋━━━┫
┃   ┃   ┃   ┃   ┃
┣━━━╋━━━╋━━━╋━━━┫
┃   ┃   ┃   ┃   ┃
┗━━━┻━━━┻━━━┻━━━┛
┏━━━┳━━━┳━━━┳━━━┓
┃ 4 ┃ 2 ┃ C ┃ B ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ F ┃   ┃ 8 ┃ 3 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ A ┃ 9 ┃ 1 ┃ 7 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ E ┃ 5 ┃ D ┃ 6 ┃
┗━━━┻━━━┻━━━┻━━━┛
Solverable :  True


In [5]:
#check solverable is working
states = puzzle.State(board=jnp.array([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,0]))
print(states)
print("Solverable : ", puzzle._solverable(states))
states = puzzle.State(board=jnp.array([1,2,3,4,5,6,7,8,9,10,11,12,13,15,14,0]))
print(states)
print("Solverable : ", puzzle._solverable(states))

┏━━━┳━━━┳━━━┳━━━┓
┃ 1 ┃ 2 ┃ 3 ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 5 ┃ 6 ┃ 7 ┃ 8 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 9 ┃ A ┃ B ┃ C ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ D ┃ E ┃ F ┃   ┃
┗━━━┻━━━┻━━━┻━━━┛
Solverable :  True
┏━━━┳━━━┳━━━┳━━━┓
┃ 1 ┃ 2 ┃ 3 ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 5 ┃ 6 ┃ 7 ┃ 8 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 9 ┃ A ┃ B ┃ C ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ D ┃ F ┃ E ┃   ┃
┗━━━┻━━━┻━━━┻━━━┛
Solverable :  False


In [6]:
#check neighbours
states = puzzle.State(board=jnp.array([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,0]))
print(states)
next_states, costs = puzzle.get_neighbours(states)
for i in range(4):
    print(next_states[i])
    print(costs[i])

┏━━━┳━━━┳━━━┳━━━┓
┃ 1 ┃ 2 ┃ 3 ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 5 ┃ 6 ┃ 7 ┃ 8 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 9 ┃ A ┃ B ┃ C ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ D ┃ E ┃ F ┃   ┃
┗━━━┻━━━┻━━━┻━━━┛
┏━━━┳━━━┳━━━┳━━━┓
┃ 1 ┃ 2 ┃ 3 ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 5 ┃ 6 ┃ 7 ┃ 8 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 9 ┃ A ┃ B ┃ C ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ D ┃ E ┃ F ┃   ┃
┗━━━┻━━━┻━━━┻━━━┛
inf
┏━━━┳━━━┳━━━┳━━━┓
┃ 1 ┃ 2 ┃ 3 ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 5 ┃ 6 ┃ 7 ┃ 8 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 9 ┃ A ┃ B ┃ C ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ D ┃ E ┃   ┃ F ┃
┗━━━┻━━━┻━━━┻━━━┛
1.0
┏━━━┳━━━┳━━━┳━━━┓
┃ 1 ┃ 2 ┃ 3 ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 5 ┃ 6 ┃ 7 ┃ 8 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 9 ┃ A ┃ B ┃ C ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ D ┃ E ┃ F ┃   ┃
┗━━━┻━━━┻━━━┻━━━┛
inf
┏━━━┳━━━┳━━━┳━━━┓
┃ 1 ┃ 2 ┃ 3 ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 5 ┃ 6 ┃ 7 ┃ 8 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 9 ┃ A ┃ B ┃   ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ D ┃ E ┃ F ┃ C ┃
┗━━━┻━━━┻━━━┻━━━┛
1.0


In [7]:
# check batch neighbours
states = jax.vmap(puzzle.get_initial_state, in_axes=0)(key=jax.random.split(jax.random.PRNGKey(0),int(1e7))) # total 10 million states
next_states, costs = jax.vmap(puzzle.get_neighbours, in_axes=0)(states)
first_flat = lambda x: jnp.reshape(x, (-1, *x.shape[2:]))
next_states = jax.tree_util.tree_map(first_flat, next_states)
costs = first_flat(costs)
print(next_states.shape)
print(next_states.dtype)
print(costs.shape)

shape(board=(40000000, 16))
dtype(board=dtype('uint8'))
(40000000,)


In [8]:
puzzle_hash_fun: callable = hash_func_builder(puzzle.State)

In [9]:
#check hashing
print("States Hashing, this should be not collision")
hashes = jax.vmap(puzzle_hash_fun, in_axes=(0, None))(states, 1)
#count hash collision
print(hashes.shape)
print(hashes.dtype)
print(jnp.unique(hashes).shape) # No collision
print(jnp.unique(states.board, axis=0).shape) # No collision

print("Next states Hashing, this should be collision")
hashes = jax.vmap(puzzle_hash_fun, in_axes=(0, None))(next_states, 1)
#count hash collision
print(hashes.shape)
print(hashes.dtype)
print(jnp.unique(hashes).shape) # Collision
print(jnp.unique(next_states.board, axis=0).shape) # No collision

States Hashing, this should be not collision
(10000000,)
uint32
(7952747,)
(9999997, 16)
Next states Hashing, this should be collision
(40000000,)
uint32
(17933809,)
(37620490, 16)


In [10]:
#check heuristic
print("Heuristic")
dist = jax.vmap(heuristic.distance, in_axes=(0, None))(next_states, states[0])
print(dist)

Heuristic


2024-07-20 22:10:33.088680: W external/xla/xla/tsl/framework/bfc_allocator.cc:482] Allocator (GPU_0_bfc) ran out of memory trying to allocate 305.18MiB (rounded to 320000000)requested by op 
2024-07-20 22:10:33.088762: W external/xla/xla/tsl/framework/bfc_allocator.cc:494] *_*************************************************************************************************_
E0720 22:10:33.088780    2937 pjrt_stream_executor_client.cc:2985] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 320000000 bytes.


ValueError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 320000000 bytes.