In [1]:
# Project bootstrap imported from the repo's `conftest` module. It presumably makes the
# project packages (JAxtar, puzzle, heuristic) importable from this notebook — TODO confirm.
# As the last expression of the cell, its return value (the project root path) is echoed below.
from conftest import setup_project
setup_project()

PosixPath('/home/tinker/JAxtar')

In [2]:
from collections.abc import Callable

import jax
import jax.numpy as jnp

from JAxtar.hash import hash_func_builder
from puzzle.dotknot import DotKnot
from heuristic.dotknot_heuristic import DotKnotHeuristic

In [3]:
# Board side length: the boards rendered below are 7x7, and `board` flattens to 49 cells.
PUZZLE_SIZE = 7

puzzle = DotKnot(PUZZLE_SIZE)
heuristic = DotKnotHeuristic(puzzle)

In [4]:
# Check single-state generation: build the target state, an initial state, and the
# neighbours (with their move costs) reachable from that initial state, then show all four.
target_states = puzzle.get_target_state()
init_state = puzzle.get_initial_state()
next_states, costs = puzzle.get_neighbours(init_state)

for shown in (target_states, init_state, next_states, costs):
    print(shown)

┏━━━━━━━━━━━━━━━┓
┃ ? ? ? ? ? ? ? ┃
┃ ? ? ? ? ? ? ? ┃
┃ ? ? ? ? ? ? ? ┃
┃ ? ? ? ? ? ? ? ┃
┃ ? ? ? ? ? ? ? ┃
┃ ? ? ? ? ? ? ? ┃
┃ ? ? ? ? ? ? ? ┃
┗━━━━━━━━━━━━━━━┛
┏━━━━━━━━━━━━━━━┓
┃               ┃
┃               ┃
┃   [31m●[0m     [34m●[0m [31m●[0m   ┃
┃     [32m●[0m [33m●[0m [34m●[0m     ┃
┃       [33m●[0m       ┃
┃         [32m●[0m     ┃
┃               ┃
┗━━━━━━━━━━━━━━━┛




┏━━━━━━━━━━━━━━━┓  ┏━━━━━━━━━━━━━━━┓  ...             ┏━━━━━━━━━━━━━━━┓  ┏━━━━━━━━━━━━━━━┓
┃               ┃  ┃               ┃  (batch : (8,))  ┃               ┃  ┃               ┃
┃           [31m●[0m   ┃  ┃               ┃                  ┃               ┃  ┃               ┃
┃   [31m●[0m     [34m●[0m [31m■[0m   ┃  ┃   [31m●[0m     [34m●[0m [31m●[0m   ┃                  ┃   [31m●[0m     [34m●[0m [31m●[0m   ┃  ┃   [31m●[0m     [34m●[0m [31m●[0m   ┃
┃     [32m●[0m [33m●[0m [34m●[0m     ┃  ┃     [32m●[0m [33m●[0m [34m●[0m     ┃                  ┃     [32m●[0m [33m■[0m [34m●[0m     ┃  ┃     [32m●[0m [33m●[0m [34m■[0m     ┃
┃       [33m●[0m       ┃  ┃       [33m●[0m [32m●[0m     ┃                  ┃       [33m■[0m       ┃  ┃       [33m●[0m [34m●[0m     ┃
┃         [32m●[0m     ┃  ┃         [32m■[0m     ┃                  ┃         [32m●[0m     ┃  ┃         [32m●[0m     ┃
┃               ┃  ┃               ┃            

In [5]:
# Sanity check: a freshly generated initial puzzle should not already count as solved.
print(puzzle.is_solved(init_state, target_states))

False


In [6]:
# Check batched neighbour generation: sample NUM_SAMPLES random initial states (seeded,
# so the run is reproducible), expand every state into its neighbours, then flatten the
# (NUM_SAMPLES, n_actions) batch into one flat batch for the hashing/heuristic checks below.
NUM_SAMPLES = 1_000  # 1000 initial states -> 8000 neighbour states (8 per state in this run)

sample_keys = jax.random.split(jax.random.PRNGKey(0), NUM_SAMPLES)
states = jax.vmap(puzzle.get_initial_state, in_axes=0)(key=sample_keys)

# Keep the batched result under its own name so this cell is idempotent on re-run.
next_states_batched, costs_batched = jax.vmap(puzzle.get_neighbours, in_axes=0)(states)
print(next_states_batched)

next_states = next_states_batched.flatten()
costs = costs_batched.flatten()
print(next_states.shape)
print(next_states.batch_shape)
print(next_states.dtype)
print(costs.shape)

┏━━━━━━━━━━━━━━━┓  ┏━━━━━━━━━━━━━━━┓  ...                  ┏━━━━━━━━━━━━━━━┓  ┏━━━━━━━━━━━━━━━┓
┃               ┃  ┃               ┃  (batch : (1000, 8))  ┃               ┃  ┃               ┃
┃       [31m●[0m       ┃  ┃       [31m●[0m       ┃                       ┃           [31m●[0m   ┃  ┃           [31m●[0m   ┃
┃     [34m●[0m         ┃  ┃     [34m●[0m   [32m●[0m     ┃                       ┃   [32m●[0m   [31m●[0m [34m●[0m     ┃  ┃   [32m●[0m   [31m●[0m [34m■[0m     ┃
┃     [33m●[0m   [32m●[0m     ┃  ┃     [33m●[0m   [32m■[0m     ┃                       ┃   [33m■[0m   [32m●[0m [34m●[0m     ┃  ┃   [33m●[0m   [32m●[0m [34m■[0m     ┃
┃   [34m●[0m   [31m●[0m [32m●[0m     ┃  ┃   [34m●[0m     [32m●[0m     ┃                       ┃   [33m●[0m           ┃  ┃               ┃
┃       [31m■[0m [33m●[0m     ┃  ┃       [31m●[0m [33m●[0m     ┃                       ┃   [33m●[0m           ┃  ┃   [33m●[0m           ┃
┃            

In [7]:
# Hash function specialised to this puzzle's State type. `callable` is a builtin predicate,
# not a type, so the annotation uses collections.abc.Callable instead.
puzzle_hash_fun: Callable = hash_func_builder(puzzle.State)

In [8]:
def check_hash_collisions(batch, hash_seed=1):
    """Hash every state in ``batch`` and compare distinct hashes with distinct boards.

    Prints the hash array's shape and dtype, then the shape of the unique hashes and of the
    unique boards. Equal counts mean no hash collisions on this sample. Fewer unique states
    than input states only means the batch contains duplicate states, not hash collisions.

    Args:
        batch: a batched puzzle State exposing a ``board`` array.
        hash_seed: second argument forwarded unmapped to ``puzzle_hash_fun``
            (presumably the hash seed — TODO confirm).

    Returns:
        Number of hash collisions: unique boards minus unique hashes.
    """
    hashes, _ = jax.vmap(puzzle_hash_fun, in_axes=(0, None))(batch, hash_seed)
    unique_hash_shape = jnp.unique(hashes).shape
    unique_board_shape = jnp.unique(batch.board, axis=0).shape
    print(hashes.shape)
    print(hashes.dtype)
    print(unique_hash_shape)
    print(unique_board_shape)
    return unique_board_shape[0] - unique_hash_shape[0]


# Sampled initial states: all distinct, so unique hashes should match the sample size.
initial_state_collisions = check_hash_collisions(states)
# Flattened neighbours: many neighbours coincide, so both unique counts drop below the batch
# size; a gap between the two counts would indicate real hash collisions.
neighbour_state_collisions = check_hash_collisions(next_states)

(1000,)
uint32
(1000,)
(1000, 49)
(8000,)
uint32
(7176,)
(7176, 49)


In [9]:
# Check the heuristic: estimated distance from every flattened neighbour state
# to a single reference state (the first sampled initial state).
print("Heuristic")
batched_distance = jax.vmap(heuristic.distance, in_axes=(0, None))
reference_state = states[0]
dist = batched_distance(next_states, reference_state)
print(dist)

Heuristic
[11 13 12 ...  9  8  8]
