In [1]:
# Notebook bootstrap: run the project's setup helper from conftest.py.
# NOTE(review): presumably this makes the project packages importable from the
# notebook's working directory — confirm in conftest.py. The recorded output
# of this cell shows it also sets CUDA_VISIBLE_DEVICES=0.
from conftest import setup_project
setup_project()

env: CUDA_VISIBLE_DEVICES=0


In [2]:
from typing import Callable

import jax
import jax.numpy as jnp

from JAxtar.hash import hash_func_builder
from puzzle.maze import Maze
from heuristic.maze_heuristic import MazeHeuristic

2024-09-24 08:26:17.463243: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.5.82). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [3]:
# Build the puzzle under test and the heuristic bound to it.
# MAZE_SIZE is presumably the side length of the maze (the recorded render
# in the next cell has 20 rows) — confirm against Maze.__init__.
MAZE_SIZE = 20
puzzle = Maze(MAZE_SIZE)
heuristic = MazeHeuristic(puzzle)

In [4]:
# Single-state sanity check (nothing is batched in this cell): render the
# default initial state and the target state, then expand the initial state
# once and show the resulting neighbours together with their step costs.
states = puzzle.get_initial_state()
target = puzzle.get_target_state()
print(states, target, sep="\n")

next_states, costs = puzzle.get_neighbours(states)
print(next_states, costs, sep="\n")

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃                 ■ ■ ■           ■       ┃
┃       ■         ■       ■           ■ ■ ┃
┃       ■ ■ ■ ■   ■ ■     ■ ■ ■ ■       ■ ┃
┃     ■   ■ ■       ■     ■   ■ ■     ■   ┃
┃ ■     ■ ■     ■     ■   ■         ■ ■ ■ ┃
┃   ■   ■ ■   ■   ■ ■         ■           ┃
┃ ■ ■ ■             ■       ■ ■     ■     ┃
┃             ■ ■   ■ ■ ■   ■       ■     ┃
┃     ■     ■   ■               ■ ■   [31m●[0m   ┃
┃       ■ ■ ■   ■ ■           ■ ■         ┃
┃       ■ ■     ■ ■           ■       ■ ■ ┃
┃ ■   ■ ■                 ■     ■       ■ ┃
┃ ■   ■             ■ ■   ■ ■   ■     ■ ■ ┃
┃ ■           ■   ■ ■   ■       ■   ■     ┃
┃             ■   ■           ■     ■     ┃
┃ ■           ■                   ■       ┃
┃     ■       ■                 ■         ┃
┃ ■         ■       ■ ■     ■   ■     ■   ┃
┃ ■                 ■ ■   ■ ■ ■ ■         ┃
┃   ■         ■                         ■ ┃
┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛
┏━━━━━━━━━━━━━━━━━━━━━━

In [5]:
# Ask the puzzle whether the initial state, and then its first neighbour,
# count as solved with respect to `target`. The recorded run prints False
# for both.
# NOTE(review): `target` is passed as the first argument here — confirm this
# matches the parameter order of Maze.is_solved.
print(
    puzzle.is_solved(target, states),
    puzzle.is_solved(target, next_states[0]),
    sep="\n",
)

False
False


In [6]:
# Batched neighbour check.
# Split PRNGKey(0) into 1,000 keys and build one initial state per key
# (1,000 states in total). vmap's default in_axes is 0, and keyword
# arguments are always mapped over their leading axis.
batch_keys = jax.random.split(jax.random.PRNGKey(0), 1_000)
batched_initial_state = jax.vmap(puzzle.get_initial_state)
states = batched_initial_state(key=batch_keys)
print(states)

# Expand every state in the batch at once.
batched_neighbours = jax.vmap(puzzle.get_neighbours)
next_states, costs = batched_neighbours(states)
print(next_states)

# Collapse the neighbour batch (and its costs) with flatten(); the hashing
# cell further down reports 4,000 entries for the flattened neighbours.
next_states = next_states.flatten()
costs = costs.flatten()
print(next_states.shape, next_states.dtype, costs.shape, sep="\n")

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓  ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓  ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓  ...                ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓  ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓  ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃                 ■ ■ ■           ■       ┃  ┃                 ■ ■ ■           ■       ┃  ┃                 ■ ■ ■           ■       ┃  (batch : (1000,))  ┃                 ■ ■ ■           ■       ┃  ┃                 ■ ■ ■           ■       ┃  ┃                 ■ ■ ■           ■       ┃
┃       ■         ■       ■           ■ ■ ┃  ┃       ■         ■       ■           ■ ■ ┃  ┃       ■         ■       ■           ■ ■ ┃                     ┃       ■         ■       ■           ■ ■ ┃  ┃       ■         ■       ■           ■ ■ ┃  ┃       ■         ■       ■           ■ ■ ┃
┃       ■ ■ ■ ■   ■ ■     ■ ■ ■ ■       ■ ┃  ┃       ■ ■ ■ ■   ■ ■     ■ ■ ■ ■       ■ ┃  ┃       ■ ■ ■ ■   ■ ■     ■ ■ ■ ■       ■ ┃   

In [7]:
# Build the hash function specialised for this puzzle's State type.
# It is later called as puzzle_hash_fun(state, 1) under jax.vmap.
# The annotation uses typing.Callable: the lowercase builtin `callable` is a
# predicate function, not a type, and is rejected by static type checkers.
puzzle_hash_fun: Callable = hash_func_builder(puzzle.State)

In [8]:
# Hash check: hash every state in a batch and compare the number of distinct
# hashes with the number of distinct positions in that batch. Run first on
# the 1,000 initial states, then on the flattened neighbour states.
# In the recorded run the two counts agree for both batches (270 vs 270 and
# 274 vs 274), so the duplicate hashes line up with duplicate positions.
# The second argument (1) is broadcast to every call (in_axes=None);
# presumably it is the hash seed — TODO confirm against hash_func_builder.
batched_hash = jax.vmap(puzzle_hash_fun, in_axes=(0, None))
for batch in (states, next_states):
    hashes = batched_hash(batch, 1)
    print(hashes.shape)
    print(hashes.dtype)
    print(jnp.unique(hashes).shape)  # number of distinct hash values
    print(jnp.unique(batch.pos, axis=0).shape)  # number of distinct positions

(1000,)
uint32
(270,)
(270, 2)
(4000,)
uint32
(274,)
(274, 2)


In [9]:
# Heuristic check: evaluate heuristic.distance for every flattened neighbour
# state against one fixed reference state. The reference is states[0] (the
# first state of the batch), not `target`; it is broadcast to every call
# via in_axes=None.
print("Heuristic")
batched_distance = jax.vmap(heuristic.distance, in_axes=(0, None))
dist = batched_distance(next_states, states[0])
print(dist)

Heuristic
[1 0 0 ... 4 2 4]
