In [1]:
import sys
import os

# 프로젝트 루트 디렉토리 경로를 추가
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [2]:
from typing import Callable

import jax
import jax.numpy as jnp

from JAxtar.hash import hash_func_builder
from puzzle.lightsout import LightsOut
from heuristic.lightsout_heuristic import LightsOutHeuristic

In [3]:
# Board side length; LightsOut(10) renders as a 10x10 grid in the cell below.
BOARD_SIZE = 10

puzzle = LightsOut(BOARD_SIZE)
heuristic = LightsOutHeuristic(puzzle)

In [4]:
# Check neighbour generation on the solved (target) board: print it, then the batch
# of successor states returned by get_neighbours, then the accompanying costs.
states = puzzle.get_target_state()
print(states)

next_states, costs = puzzle.get_neighbours(states)
for shown in (next_states, costs):
    print(shown)

┏━━━━━━━━━━━━━━━━━━━━┓
┃□ □ □ □ □ □ □ □ □ □ ┃
┃□ □ □ □ □ □ □ □ □ □ ┃
┃□ □ □ □ □ □ □ □ □ □ ┃
┃□ □ □ □ □ □ □ □ □ □ ┃
┃□ □ □ □ □ □ □ □ □ □ ┃
┃□ □ □ □ □ □ □ □ □ □ ┃
┃□ □ □ □ □ □ □ □ □ □ ┃
┃□ □ □ □ □ □ □ □ □ □ ┃
┃□ □ □ □ □ □ □ □ □ □ ┃
┃□ □ □ □ □ □ □ □ □ □ ┃
┗━━━━━━━━━━━━━━━━━━━━┛
+------------------------+------------------------+------------------------+-----+------------------------+------------------------+------------------------+
| ┏━━━━━━━━━━━━━━━━━━━━┓ | ┏━━━━━━━━━━━━━━━━━━━━┓ | ┏━━━━━━━━━━━━━━━━━━━━┓ | ... | ┏━━━━━━━━━━━━━━━━━━━━┓ | ┏━━━━━━━━━━━━━━━━━━━━┓ | ┏━━━━━━━━━━━━━━━━━━━━┓ |
| ┃■ ■ □ □ □ □ □ □ □ □ ┃ | ┃■ □ □ □ □ □ □ □ □ □ ┃ | ┃□ □ □ □ □ □ □ □ □ □ ┃ |     | ┃□ □ □ □ □ □ □ □ □ □ ┃ | ┃□ □ □ □ □ □ □ □ □ □ ┃ | ┃□ □ □ □ □ □ □ □ □ □ ┃ |
| ┃■ □ □ □ □ □ □ □ □ □ ┃ | ┃■ ■ □ □ □ □ □ □ □ □ ┃ | ┃■ □ □ □ □ □ □ □ □ □ ┃ |     | ┃□ □ □ □ □ □ □ □ □ □ ┃ | ┃□ □ □ □ □ □ □ □ □ □ ┃ | ┃□ □ □ □ □ □ □ □ □ □ ┃ |
| ┃□ □ □ □ □ □ □ □ □ □ ┃ | ┃■ □ □ □ □ □ □ □ □ □ ┃ | ┃■ ■ □ □ □ □ □ □ □ □ ┃ |     | ┃□ □ □ □ 

In [5]:
# is_solved sanity check: the target compared with itself (expected True), then its
# first neighbour from the previous cell (expected False).
for candidate in (states, next_states[0]):
    print(puzzle.is_solved(states, candidate))

True
False


In [6]:
# Check batched neighbour expansion: sample N_INIT_STATES random boards, expand each
# one with a vmapped get_neighbours, then merge the (board, neighbour) leading axes.
N_INIT_STATES = 1_000  # 1,000 sampled boards; the output below shows 100,000 neighbours once flattened
INIT_SEED = 0

init_keys = jax.random.split(jax.random.PRNGKey(INIT_SEED), N_INIT_STATES)
# jax.vmap always maps keyword arguments over axis 0, so each board gets its own key.
states = jax.vmap(puzzle.get_initial_state, in_axes=0)(key=init_keys)
next_states, costs = jax.vmap(puzzle.get_neighbours, in_axes=0)(states)


def first_flat(x):
    """Merge the first two axes of ``x``: (boards, neighbours, ...) -> (boards * neighbours, ...)."""
    return jnp.reshape(x, (-1, *x.shape[2:]))


next_states = jax.tree_util.tree_map(first_flat, next_states)
costs = first_flat(costs)
print(next_states.shape)
print(next_states.dtype)
print(costs.shape)

shape(board=(100000, 13))
dtype(board=dtype('uint8'))
(100000,)


In [7]:
# Build a hash function for puzzle.State. The next cell calls it as
# puzzle_hash_fun(state, 1) with the second argument held fixed (in_axes=None;
# presumably a seed — TODO confirm), and it returns uint32 hashes.
# Annotated with typing.Callable: the builtin `callable` is a function, not a type.
puzzle_hash_fun: Callable = hash_func_builder(puzzle.State)

In [8]:
# Check hashing: hash every board in a batch and compare the number of distinct hash
# values with the number of distinct boards; any shortfall means hash collisions.
# Recorded run: 1,000 distinct hashes for 1,000 distinct sampled boards (no collisions),
# and 99,999 distinct hashes for 100,000 distinct neighbours (a single collision).


def report_hash_collisions(batch, seed=1):
    """Hash every state in ``batch`` and print collision-check statistics.

    Prints, in order: the hash array's shape, its dtype, the shape of the distinct
    hash values, and the shape of the distinct ``batch.board`` rows.

    Args:
        batch: a batched ``puzzle.State`` (leading axis is the batch axis).
        seed: scalar passed unbatched as puzzle_hash_fun's second argument.

    Returns:
        The per-state hashes.
    """
    batch_hashes = jax.vmap(puzzle_hash_fun, in_axes=(0, None))(batch, seed)
    print(batch_hashes.shape)
    print(batch_hashes.dtype)
    print(jnp.unique(batch_hashes).shape)  # distinct hash values
    print(jnp.unique(batch.board, axis=0).shape)  # distinct boards (unique rows)
    return batch_hashes


report_hash_collisions(states)
hashes = report_hash_collisions(next_states)

(1000,)
uint32
(1000,)
(1000, 13)
(100000,)
uint32
(99999,)
(100000, 13)


In [9]:
# Heuristic sanity check: estimated distance from every flattened neighbour to one
# fixed board, states[0] (held constant via in_axes=None rather than mapped).
print("Heuristic")
batched_distance = jax.vmap(heuristic.distance, in_axes=(0, None))
dist = batched_distance(next_states, states[0])
print(dist)

Heuristic
[0.6       0.8       0.8       ... 7.        7.4       7.2000003]
