In [1]:
# Import the project's setup helper and run it before any other project module is imported.
# NOTE(review): the recorded output of this cell shows `env: CUDA_VISIBLE_DEVICES=0`, so the
# helper presumably pins the visible GPU -- confirm against conftest.py.
from conftest import setup_project
setup_project()

env: CUDA_VISIBLE_DEVICES=0


In [2]:
import jax
import jax.numpy as jnp

from JAxtar.hash import hash_func_builder
from puzzle.slidepuzzle import SlidePuzzle
from heuristic.slidepuzzle_heuristic import SlidePuzzleHeuristic

In [3]:
# Side length passed to SlidePuzzle; the recorded outputs render a 4x4 board (16 cells).
PUZZLE_SIZE = 4

# Puzzle instance shared by every check in this notebook, plus the heuristic built on top of it.
puzzle = SlidePuzzle(PUZZLE_SIZE)
heuristic = SlidePuzzleHeuristic(puzzle)

In [4]:
# Check batched state generation.

# Build 10000 default states in a single vmapped call and display the first one.
# (The misspelled name `defualt_state` is kept as-is in case other cells refer to it.)
dummy_batch = jnp.zeros(10000)
defualt_state = jax.vmap(puzzle.State.default)(dummy_batch)
print(defualt_state[0])

# Draw 10 random initial states, one per PRNG key split from seed 0, and inspect the first.
init_keys = jax.random.split(jax.random.PRNGKey(0), 10)
states = jax.vmap(puzzle.get_initial_state, in_axes=0)(key=init_keys)
first_state = states[0]
print(first_state)
print("Solverable : ", puzzle._solvable(first_state))

┏━━━┳━━━┳━━━┳━━━┓
┃   ┃   ┃   ┃   ┃
┣━━━╋━━━╋━━━╋━━━┫
┃   ┃   ┃   ┃   ┃
┣━━━╋━━━╋━━━╋━━━┫
┃   ┃   ┃   ┃   ┃
┣━━━╋━━━╋━━━╋━━━┫
┃   ┃   ┃   ┃   ┃
┗━━━┻━━━┻━━━┻━━━┛
┏━━━┳━━━┳━━━┳━━━┓
┃ 4 ┃ 2 ┃ C ┃ B ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ F ┃   ┃ 8 ┃ 3 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ A ┃ 9 ┃ 1 ┃ 7 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ E ┃ 5 ┃ D ┃ 6 ┃
┗━━━┻━━━┻━━━┻━━━┛
Solverable :  True


In [5]:
# Check that the solvability test works on two hand-built boards:
#   * the ordered board 1..15 with the blank (0) last, and
#   * the same board with tiles 14 and 15 swapped.
# The recorded output reports True for the first and False for the second.
ordered_board = jnp.array(list(range(1, 16)) + [0])
swapped_board = jnp.array(list(range(1, 14)) + [15, 14, 0])

for board in (ordered_board, swapped_board):
    states = puzzle.State(board=board)
    print(states)
    print("Solverable : ", puzzle._solvable(states))

┏━━━┳━━━┳━━━┳━━━┓
┃ 1 ┃ 2 ┃ 3 ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 5 ┃ 6 ┃ 7 ┃ 8 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 9 ┃ A ┃ B ┃ C ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ D ┃ E ┃ F ┃   ┃
┗━━━┻━━━┻━━━┻━━━┛
Solverable :  True
┏━━━┳━━━┳━━━┳━━━┓
┃ 1 ┃ 2 ┃ 3 ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 5 ┃ 6 ┃ 7 ┃ 8 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 9 ┃ A ┃ B ┃ C ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ D ┃ F ┃ E ┃   ┃
┗━━━┻━━━┻━━━┻━━━┛
Solverable :  False


In [6]:
# Check neighbour expansion starting from the ordered board (blank in the last cell).
# The recorded output shows a batch of 4 neighbour states; moves that are not possible
# leave the board unchanged and are reported with an infinite cost.
ordered_board = jnp.array(list(range(1, 16)) + [0])
states = puzzle.State(board=ordered_board)
print(states)

next_states, costs = puzzle.get_neighbours(states)
for result in (next_states, costs):
    print(result)

┏━━━┳━━━┳━━━┳━━━┓
┃ 1 ┃ 2 ┃ 3 ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 5 ┃ 6 ┃ 7 ┃ 8 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 9 ┃ A ┃ B ┃ C ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ D ┃ E ┃ F ┃   ┃
┗━━━┻━━━┻━━━┻━━━┛
┏━━━┳━━━┳━━━┳━━━┓  ┏━━━┳━━━┳━━━┳━━━┓  ┏━━━┳━━━┳━━━┳━━━┓  ┏━━━┳━━━┳━━━┳━━━┓  batch : (4,)
┃ 1 ┃ 2 ┃ 3 ┃ 4 ┃  ┃ 1 ┃ 2 ┃ 3 ┃ 4 ┃  ┃ 1 ┃ 2 ┃ 3 ┃ 4 ┃  ┃ 1 ┃ 2 ┃ 3 ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫  ┣━━━╋━━━╋━━━╋━━━┫  ┣━━━╋━━━╋━━━╋━━━┫  ┣━━━╋━━━╋━━━╋━━━┫
┃ 5 ┃ 6 ┃ 7 ┃ 8 ┃  ┃ 5 ┃ 6 ┃ 7 ┃ 8 ┃  ┃ 5 ┃ 6 ┃ 7 ┃ 8 ┃  ┃ 5 ┃ 6 ┃ 7 ┃ 8 ┃
┣━━━╋━━━╋━━━╋━━━┫  ┣━━━╋━━━╋━━━╋━━━┫  ┣━━━╋━━━╋━━━╋━━━┫  ┣━━━╋━━━╋━━━╋━━━┫
┃ 9 ┃ A ┃ B ┃ C ┃  ┃ 9 ┃ A ┃ B ┃ C ┃  ┃ 9 ┃ A ┃ B ┃ C ┃  ┃ 9 ┃ A ┃ B ┃   ┃
┣━━━╋━━━╋━━━╋━━━┫  ┣━━━╋━━━╋━━━╋━━━┫  ┣━━━╋━━━╋━━━╋━━━┫  ┣━━━╋━━━╋━━━╋━━━┫
┃ D ┃ E ┃ F ┃   ┃  ┃ D ┃ E ┃   ┃ F ┃  ┃ D ┃ E ┃ F ┃   ┃  ┃ D ┃ E ┃ F ┃ C ┃
┗━━━┻━━━┻━━━┻━━━┛  ┗━━━┻━━━┻━━━┻━━━┛  ┗━━━┻━━━┻━━━┻━━━┛  ┗━━━┻━━━┻━━━┻━━━┛
[inf  1. inf  1.]


In [7]:
# Check batched neighbour expansion.
# 1e6 random initial states are generated and each is expanded into its neighbours; the
# recorded output shows a (1000000, 4) batch, i.e. 4e6 neighbour states once flattened.
NUM_INITIAL_STATES = int(1e6)
init_keys = jax.random.split(jax.random.PRNGKey(0), NUM_INITIAL_STATES)
states = jax.vmap(puzzle.get_initial_state, in_axes=0)(key=init_keys)

batched_get_neighbours = jax.vmap(puzzle.get_neighbours, in_axes=0)
next_states, costs = batched_get_neighbours(states)
print(next_states)

# Collapse the (state, move) batch axes into one so later cells see a flat batch.
next_states = next_states.flatten()
costs = costs.flatten()
for summary in (next_states.shape, next_states.dtype, costs.shape):
    print(summary)

┏━━━┳━━━┳━━━┳━━━┓  ┏━━━┳━━━┳━━━┳━━━┓  ┏━━━┳━━━┳━━━┳━━━┓  ...                     ┏━━━┳━━━┳━━━┳━━━┓  ┏━━━┳━━━┳━━━┳━━━┓  ┏━━━┳━━━┳━━━┳━━━┓
┃ D ┃ 9 ┃ 4 ┃ 2 ┃  ┃ D ┃ 9 ┃ 4 ┃ 2 ┃  ┃ D ┃ 9 ┃ 4 ┃ 2 ┃  (batch : (1000000, 4))  ┃ B ┃ 7 ┃ 1 ┃ F ┃  ┃ B ┃ 7 ┃ 1 ┃ F ┃  ┃ B ┃ 7 ┃ 1 ┃ F ┃
┣━━━╋━━━╋━━━╋━━━┫  ┣━━━╋━━━╋━━━╋━━━┫  ┣━━━╋━━━╋━━━╋━━━┫                          ┣━━━╋━━━╋━━━╋━━━┫  ┣━━━╋━━━╋━━━╋━━━┫  ┣━━━╋━━━╋━━━╋━━━┫
┃ 8 ┃ 3 ┃   ┃ C ┃  ┃   ┃ 8 ┃ 3 ┃ C ┃  ┃ 8 ┃ B ┃ 3 ┃ C ┃                          ┃ 9 ┃ A ┃ C ┃ 4 ┃  ┃ 9 ┃ A ┃ C ┃ 4 ┃  ┃ 9 ┃ A ┃ C ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫  ┣━━━╋━━━╋━━━╋━━━┫  ┣━━━╋━━━╋━━━╋━━━┫                          ┣━━━╋━━━╋━━━╋━━━┫  ┣━━━╋━━━╋━━━╋━━━┫  ┣━━━╋━━━╋━━━╋━━━┫
┃ 5 ┃ B ┃ A ┃ 7 ┃  ┃ 5 ┃ B ┃ A ┃ 7 ┃  ┃ 5 ┃   ┃ A ┃ 7 ┃                          ┃ 6 ┃ 3 ┃ E ┃ 2 ┃  ┃ 6 ┃ 3 ┃ E ┃ 2 ┃  ┃ 6 ┃ 3 ┃   ┃ 2 ┃
┣━━━╋━━━╋━━━╋━━━┫  ┣━━━╋━━━╋━━━╋━━━┫  ┣━━━╋━━━╋━━━╋━━━┫                          ┣━━━╋━━━╋━━━╋━━━┫  ┣━━━╋━━━╋━━━╋━━━┫  ┣━━━╋━━━╋━━━╋━━━┫
┃ E ┃ 1 ┃ 6 ┃ F ┃  ┃ E ┃ 1 ┃ 6 ┃ F ┃  ┃ E

In [8]:
# Build the hash function for `puzzle.State`; the hashing check in this notebook calls it
# as puzzle_hash_fun(state, 1).
# NOTE(review): `callable` is the builtin predicate rather than a typing type;
# `typing.Callable` would be the conventional annotation here.
puzzle_hash_fun: callable = hash_func_builder(puzzle.State)

In [9]:
# Check hashing: for each batch, compare the number of distinct hashes with the number
# of distinct boards. The closer the two counts, the fewer hash collisions.
#   * `states`      -- the random initial states; nearly all boards are distinct.
#   * `next_states` -- the flattened neighbours; this batch contains many repeated boards,
#                      so both counts fall well below the batch size, and only the gap
#                      between the two counts reflects hash collisions.
# NOTE(review): the second argument (1) is presumably a hash seed -- confirm in JAxtar.hash.
batched_hash = jax.vmap(puzzle_hash_fun, in_axes=(0, None))

for state_batch in (states, next_states):
    hashes = batched_hash(state_batch, 1)
    print(hashes.shape)
    print(hashes.dtype)
    print(jnp.unique(hashes).shape)  # distinct hash values
    print(jnp.unique(state_batch.board, axis=0).shape)  # distinct boards

(1000000,)
uint32
(999886,)
(999999, 16)
(4000000,)
uint32
(3748792,)
(3750470, 16)


In [10]:
# Check the hand-written heuristic: evaluate it for every flattened neighbour state
# against a single fixed state (the first initial state), broadcast across the batch.
# NOTE(review): the second argument is presumably the target state -- confirm in the heuristic.
print("Heuristic")
reference_state = states[0]
batched_distance = jax.vmap(heuristic.distance, in_axes=(0, None))
dist = batched_distance(next_states, reference_state)
print(dist)

Heuristic
[ 1.  1.  1. ... 44. 43. 42.]


In [11]:
from heuristic.DAVI.neuralheuristic.slidepuzzle_neuralheuristic import SlidePuzzleNeuralHeuristic


def _report_distances(distances):
    """Print a distance array followed by its shape, dtype and mean."""
    print(distances)
    print(distances.shape)
    print(distances.dtype)
    print(jnp.mean(distances))


# Load the trained neural heuristic for this puzzle from its pickled parameters.
neural_heuristic = SlidePuzzleNeuralHeuristic.load_model(puzzle, "../heuristic/DAVI/neuralheuristic/params/n-puzzle_4.pkl")
neural_heuristic_fn = neural_heuristic.distance
batched_neural_distance = jax.vmap(neural_heuristic_fn, in_axes=(0, None))

# First check: distances from the target state's own neighbours to the target.
target_state = puzzle.get_target_state()
# NOTE: this rebinds `costs`, replacing the flattened neighbour costs computed earlier.
target_neighbours, costs = puzzle.get_neighbours(target_state)
print(target_neighbours)
dist = batched_neural_distance(target_neighbours.flatten(), target_state)
_report_distances(dist)

# Second check: distances from the first 100 flattened neighbour states to the target.
dist = batched_neural_distance(next_states[:100], target_state)
_report_distances(dist)

┏━━━┳━━━┳━━━┳━━━┓  ┏━━━┳━━━┳━━━┳━━━┓  ┏━━━┳━━━┳━━━┳━━━┓  ┏━━━┳━━━┳━━━┳━━━┓  batch : (4,)
┃ 1 ┃ 2 ┃ 3 ┃ 4 ┃  ┃ 1 ┃ 2 ┃ 3 ┃ 4 ┃  ┃ 1 ┃ 2 ┃ 3 ┃ 4 ┃  ┃ 1 ┃ 2 ┃ 3 ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫  ┣━━━╋━━━╋━━━╋━━━┫  ┣━━━╋━━━╋━━━╋━━━┫  ┣━━━╋━━━╋━━━╋━━━┫
┃ 5 ┃ 6 ┃ 7 ┃ 8 ┃  ┃ 5 ┃ 6 ┃ 7 ┃ 8 ┃  ┃ 5 ┃ 6 ┃ 7 ┃ 8 ┃  ┃ 5 ┃ 6 ┃ 7 ┃ 8 ┃
┣━━━╋━━━╋━━━╋━━━┫  ┣━━━╋━━━╋━━━╋━━━┫  ┣━━━╋━━━╋━━━╋━━━┫  ┣━━━╋━━━╋━━━╋━━━┫
┃ 9 ┃ A ┃ B ┃ C ┃  ┃ 9 ┃ A ┃ B ┃ C ┃  ┃ 9 ┃ A ┃ B ┃ C ┃  ┃ 9 ┃ A ┃ B ┃   ┃
┣━━━╋━━━╋━━━╋━━━┫  ┣━━━╋━━━╋━━━╋━━━┫  ┣━━━╋━━━╋━━━╋━━━┫  ┣━━━╋━━━╋━━━╋━━━┫
┃ D ┃ E ┃ F ┃   ┃  ┃ D ┃ E ┃   ┃ F ┃  ┃ D ┃ E ┃ F ┃   ┃  ┃ D ┃ E ┃ F ┃ C ┃
┗━━━┻━━━┻━━━┻━━━┛  ┗━━━┻━━━┻━━━┻━━━┛  ┗━━━┻━━━┻━━━┻━━━┛  ┗━━━┻━━━┻━━━┻━━━┛
4 (16,)
[0.        1.2606621 0.        1.4513779]
(4,)
float32
0.67801
[45.122543 44.578976 44.029255 45.682407 54.29162  53.69931  54.29162
 55.446392 43.21599  44.347206 44.347206 44.916676 52.93885  49.204758
 50.32633  50.339493 50.294098 51.00911  49.333225 50.657272 52.12806
 49.93093  52.17659  52.75805