In [1]:
import sys
import os

# 프로젝트 루트 디렉토리 경로를 추가
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [2]:
from functools import partial
from typing import Callable

import jax
import jax.numpy as jnp

from JAxtar.hash import hash_func_builder
from puzzle.rubikscube import RubiksCube
from heuristic.rubikscube_heuristic import RubiksCubeHeuristic

In [3]:
# Cube edge length used throughout the notebook.
# NOTE(review): the neural-heuristic parameters loaded later ("rubikscube_3.pkl")
# appear to be for the 3x3x3 cube — keep that path in sync if this changes.
CUBE_SIZE = 3
puzzle = RubiksCube(CUBE_SIZE)
# Hand-crafted (non-learned) heuristic for the same puzzle instance.
heuristic = RubiksCubeHeuristic(puzzle)

In [4]:
# Sanity-check neighbour generation on a single state, the solved cube:
# show the target, then every neighbour state and the cost of each move.
states = puzzle.get_target_state()
print(states)

next_states, costs = puzzle.get_neighbours(states)
print(next_states, costs, sep="\n")

up    :[97m■[0m   ┏━━━up━━┓
down  :[33m■[0m   ┃ [97m■[0m [97m■[0m [97m■[0m ┃
left  :[31m■[0m   ┃ [97m■[0m [97m■[0m [97m■[0m ┃
right :[35m■[0m   ┃ [97m■[0m [97m■[0m [97m■[0m ┃
front :[32m■[0m   ┗━━━━━━━┛
back  :[34m■[0m
┏━━left━┓  ┏━front━┓  ┏━right━┓  ┏━━back━┓
┃ [31m■[0m [31m■[0m [31m■[0m ┃  ┃ [32m■[0m [32m■[0m [32m■[0m ┃  ┃ [35m■[0m [35m■[0m [35m■[0m ┃  ┃ [34m■[0m [34m■[0m [34m■[0m ┃
┃ [31m■[0m [31m■[0m [31m■[0m ┃  ┃ [32m■[0m [32m■[0m [32m■[0m ┃  ┃ [35m■[0m [35m■[0m [35m■[0m ┃  ┃ [34m■[0m [34m■[0m [34m■[0m ┃
┃ [31m■[0m [31m■[0m [31m■[0m ┃  ┃ [32m■[0m [32m■[0m [32m■[0m ┃  ┃ [35m■[0m [35m■[0m [35m■[0m ┃  ┃ [34m■[0m [34m■[0m [34m■[0m ┃
┗━━━━━━━┛  ┗━━━━━━━┛  ┗━━━━━━━┛  ┗━━━━━━━┛
           ┏━━down━┓
           ┃ [33m■[0m [33m■[0m [33m■[0m ┃
           ┃ [33m■[0m [33m■[0m [33m■[0m ┃
           ┃ [33m■[0m [33m■[0m [33m■[0m ┃
           ┗━━━━━━━┛
up    :[97m■[0m   ┏━━━up

In [5]:
# Goal-test sanity check: the solved cube must be solved with respect to itself,
# while a state one move away from it must not be.
solved_vs_self = puzzle.is_solved(states, states)
print(solved_vs_self)
solved_vs_neighbour = puzzle.is_solved(states, next_states[0])
print(solved_vs_neighbour)

# Make the check self-enforcing instead of relying on someone reading the prints.
# bool() also fails loudly if `states` is no longer a single state (e.g. after
# the batch cell below was run and this cell re-executed).
assert bool(solved_vs_self), "target state should be solved against itself"
assert not bool(solved_vs_neighbour), "a one-move neighbour should not be solved"

True
False


In [6]:
# Check batched neighbour generation: scramble N_RANDOM_STATES cubes and expand each.
# 1,000 scrambles -> 1,000 x (moves per state) neighbours (12,000 in the recorded run);
# the previous "10 million" comment was wrong.
N_RANDOM_STATES = 1_000
keys = jax.random.split(jax.random.PRNGKey(0), N_RANDOM_STATES)
# jax.vmap always maps keyword arguments over their leading axis, so each
# scramble receives its own PRNG key.
states = jax.vmap(puzzle.get_initial_state, in_axes=0)(key=keys)
next_states, costs = jax.vmap(puzzle.get_neighbours, in_axes=0)(states)
print(next_states)

# Collapse the (scramble, move) batch into one flat batch of neighbour states / move costs.
next_states = next_states.flatten()
costs = costs.flatten()
print(next_states.shape)
print(next_states.dtype)
print(costs.shape)

up    :[97m■[0m   ┏━━━up━━┓                        up    :[97m■[0m   ┏━━━up━━┓                        up    :[97m■[0m   ┏━━━up━━┓                        ...                   up    :[97m■[0m   ┏━━━up━━┓                        up    :[97m■[0m   ┏━━━up━━┓                        up    :[97m■[0m   ┏━━━up━━┓
down  :[33m■[0m   ┃ [32m■[0m [97m■[0m [32m■[0m ┃                        down  :[33m■[0m   ┃ [35m■[0m [97m■[0m [32m■[0m ┃                        down  :[33m■[0m   ┃ [97m■[0m [97m■[0m [32m■[0m ┃                        (batch : (1000, 12))  down  :[33m■[0m   ┃ [97m■[0m [34m■[0m [31m■[0m ┃                        down  :[33m■[0m   ┃ [35m■[0m [35m■[0m [33m■[0m ┃                        down  :[33m■[0m   ┃ [31m■[0m [97m■[0m [33m■[0m ┃
left  :[31m■[0m   ┃ [32m■[0m [97m■[0m [34m■[0m ┃                        left  :[31m■[0m   ┃ [34m■[0m [97m■[0m [34m■[0m ┃                        left  :[31m■[0m   ┃ [97m■[0m [97m■[0m 

In [7]:
# Hash function for puzzle.State values; the next cell calls it as
# puzzle_hash_fun(state, 1) — the second argument is presumably a seed (confirm in JAxtar.hash).
# `Callable` (a type) replaces the builtin function `callable`, which is not a valid annotation.
puzzle_hash_fun: Callable = hash_func_builder(puzzle.State)

In [8]:
# Check hashing: compare the number of distinct hashes with the number of distinct states.
def report_hash_uniqueness(batch, seed=1):
    """Hash every state in `batch` and print collision diagnostics.

    Prints the hash array's shape and dtype, then the number of distinct hashes
    and the number of distinct states (by `faces`). Equal states always hash
    equally, so when the two distinct counts agree the hash introduced no
    collisions; any shortfall versus the batch size is due to duplicate states.

    `seed` is forwarded as the second argument of `puzzle_hash_fun`
    (presumably a hash seed — confirm in JAxtar.hash). Returns the hashes.
    """
    batch_hashes = jax.vmap(puzzle_hash_fun, in_axes=(0, None))(batch, seed)
    print(batch_hashes.shape)
    print(batch_hashes.dtype)
    print(jnp.unique(batch_hashes).shape)  # distinct hashes
    print(jnp.unique(batch.faces, axis=0).shape)  # distinct states
    return batch_hashes


# Random scrambles.
hashes = report_hash_uniqueness(states, 1)
# Their neighbours: in the recorded run a few states repeat, but the distinct-hash
# and distinct-state counts match, i.e. there are no hash collisions.
hashes = report_hash_uniqueness(next_states, 1)

(1000,)
uint32
(1000,)
(1000, 6, 9)
(12000,)
uint32
(11998,)
(11998, 6, 9)


In [9]:
# Heuristic sanity check: estimated distance from every neighbour state to the
# first scrambled cube, states[0].
# NOTE(review): if flatten() is row-major, the first block of next_states holds the
# neighbours of states[0] itself, so those entries should score lowest — confirm.
print("Heuristic")
dist = jax.vmap(
    heuristic.distance,
    in_axes=(0, None),  # map over the neighbour batch, broadcast the single target
)(next_states, states[0])
print(dist)

Heuristic
[2.    2.    2.125 ... 5.125 4.375 4.875]


In [10]:
from heuristic.DAVI.neuralheuristic.rubikscube_neuralheuristic import RubiksCubeNeuralHeuristic


def report_neural_distances(batch, target):
    """Print the neural heuristic's estimate for each state in `batch` against `target`.

    Prints the raw estimates followed by their shape, dtype and mean, and
    returns the estimates. Uses the global `neural_heuristic_fn` set in this cell.
    """
    estimates = jax.vmap(neural_heuristic_fn, in_axes=(0, None))(batch, target)
    print(estimates)
    print(estimates.shape)
    print(estimates.dtype)
    print(jnp.mean(estimates))
    return estimates


# NOTE(review): parameters come from a .pkl file; if load_model unpickles it,
# only load parameter files from a trusted source.
neural_heuristic = RubiksCubeNeuralHeuristic.load_model(
    puzzle, "../heuristic/DAVI/neuralheuristic/params/rubikscube_3.pkl"
)
neural_heuristic_fn = neural_heuristic.distance

# States one move away from the solved cube.
target_state = puzzle.get_target_state()
# NOTE: this rebinds the global `costs`, replacing the flattened batch costs from
# the batch-neighbour cell.
target_neighbours, costs = puzzle.get_neighbours(target_state)
print(target_neighbours)
dist = report_neural_distances(target_neighbours.flatten(), target_state)

# The first 100 (flattened) neighbours of the random scrambles.
dist = report_neural_distances(next_states[:100], target_state)

up    :[97m■[0m   ┏━━━up━━┓                        up    :[97m■[0m   ┏━━━up━━┓                        up    :[97m■[0m   ┏━━━up━━┓                        up    :[97m■[0m   ┏━━━up━━┓                        up    :[97m■[0m   ┏━━━up━━┓                        up    :[97m■[0m   ┏━━━up━━┓                        up    :[97m■[0m   ┏━━━up━━┓                        up    :[97m■[0m   ┏━━━up━━┓                        up    :[97m■[0m   ┏━━━up━━┓                        up    :[97m■[0m   ┏━━━up━━┓                        up    :[97m■[0m   ┏━━━up━━┓                        up    :[97m■[0m   ┏━━━up━━┓                        batch : (12,)
down  :[33m■[0m   ┃ [32m■[0m [97m■[0m [97m■[0m ┃                        down  :[33m■[0m   ┃ [34m■[0m [97m■[0m [97m■[0m ┃                        down  :[33m■[0m   ┃ [97m■[0m [97m■[0m [97m■[0m ┃                        down  :[33m■[0m   ┃ [97m■[0m [97m■[0m [97m■[0m ┃                        down  :[33m■[0m   ┃ [97m■[