In [1]:
import sys
import os

# Make the parent of the current working directory importable by appending it
# to sys.path (only once, so re-running this cell does not add duplicates).
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if sys.path.count(project_root) == 0:
    sys.path.append(project_root)
# Set CUDA_VISIBLE_DEVICES=0 for this kernel. This cell runs before the cell
# that imports jax. Keep the comment on its own line: text after the value on
# the %env line would become part of the value.
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [2]:
import jax
import jax.numpy as jnp

from JAxtar.hash import hash_func_builder, HashTable
from functools import partial
from puzzle.slidepuzzle import SlidePuzzle
from heuristic.slidepuzzle_heuristic import SlidePuzzleHeuristic

In [3]:
# Constructor argument for the puzzle; the states printed by later cells are
# 4x4 boards, so this is presumably the side length (confirm in SlidePuzzle).
BOARD_SIDE = 4
puzzle = SlidePuzzle(BOARD_SIDE)

# Heuristic object built from the puzzle instance above.
heuristic = SlidePuzzleHeuristic(puzzle)

In [4]:
# Batch width used by the batching, visualisation and expansion cells below.
size = 100

# Check batch generation: vmap the initial-state builder over a batch of PRNG
# keys. The seed key is split into a single key, so the batch holds one state.
seed_key = jax.random.PRNGKey(123)
init_keys = jax.random.split(seed_key, 1)
states = jax.vmap(puzzle.get_initial_state, in_axes=0)(key=init_keys)
print(states[0])

# Hash function and lookup table for this puzzle's State type.
# NOTE(review): the meaning of the positional arguments 1 and 100 is defined by
# HashTable.make_lookup_table and is not visible here; confirm before tuning.
hash_func = hash_func_builder(puzzle.State)
table = HashTable.make_lookup_table(puzzle.State, 1, 100)

# Bind hash_func as the leading argument of each table operation, then jit the
# resulting callables.
lookup = jax.jit(partial(HashTable.lookup, hash_func))
parallel_insert = jax.jit(partial(HashTable.parallel_insert, hash_func))

┏━━━┳━━━┳━━━┳━━━┓
┃ 5 ┃   ┃ A ┃ E ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 8 ┃ 7 ┃ 2 ┃ D ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 1 ┃ 6 ┃ 9 ┃ 3 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 4 ┃ F ┃ C ┃ B ┃
┗━━━┻━━━┻━━━┻━━━┛


In [5]:
# Batch `states` with width argument `size`; `filled` is used below as a
# per-slot mask (it sums to 1 on the first pass of the expansion loop).
# NOTE(review): presumably pads the batch up to `size` slots; confirm in
# HashTable.make_batched.
# This rebinds `states`, so re-running this cell on its own feeds it the
# already-batched value.
states, filled = HashTable.make_batched(puzzle.State, states, size)

In [6]:
def visualize_state(state, filled):
    """Print every entry of ``state`` whose slot is marked in ``filled``.

    Args:
        state: Indexable batch; ``state[i]`` is printed for each marked slot.
        filled: Sequence or 1-D array of truthy/falsy flags, one per slot.

    The number of slots is taken from ``filled`` itself rather than from a
    notebook-level variable, so the function works for any batch width and
    does not depend on which other cells have been run.
    """
    for i in range(len(filled)):
        if filled[i]:
            print(state[i])

In [7]:
def first_flat(x):
    """Merge the two leading axes of ``x`` into one; trailing axes are kept."""
    return jnp.reshape(x, (-1, *x.shape[2:]))


# Loop-invariant batched callables, built once.
# Neighbour generation is mapped over the leading axis of both arguments.
expand_batch = jax.vmap(puzzle.get_neighbours, in_axes=(0, 0))
# Lookup shares one table (None) and maps over both leading axes of the
# neighbour batch.
lookup_neighbours = jax.vmap(
    jax.vmap(lookup, in_axes=(None, 0)), in_axes=(None, 0)
)

for step in range(6):
    # Show the current batch and how many slots are in use, then insert it.
    visualize_state(states, filled)
    print(filled.sum())
    table, inserted = parallel_insert(table, states, filled)

    # Generate neighbours and look each one up before flattening.
    neighbours, cost = expand_batch(states, filled)
    idx, table_idx, found = lookup_neighbours(table, neighbours)

    # Collapse the two leading axes into a single candidate axis. A candidate
    # is kept when its cost is finite and the lookup did not report it found.
    neighbours = jax.tree_util.tree_map(first_flat, neighbours)
    neighbours_filled = first_flat(jnp.isfinite(cost)) & ~first_flat(found)

    # argsort on a boolean mask puts False first; reversing the permutation
    # moves the kept candidates to the front.
    filled_sort = jnp.argsort(neighbours_filled)[::-1]
    neighbours = jax.tree_util.tree_map(lambda x: x[filled_sort], neighbours)

    # The first `size` candidates become the next batch.
    filled = neighbours_filled[filled_sort][:size]
    states = jax.tree_util.tree_map(lambda x: x[:size], neighbours)

┏━━━┳━━━┳━━━┳━━━┓
┃ 5 ┃   ┃ A ┃ E ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 8 ┃ 7 ┃ 2 ┃ D ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 1 ┃ 6 ┃ 9 ┃ 3 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 4 ┃ F ┃ C ┃ B ┃
┗━━━┻━━━┻━━━┻━━━┛
1
┏━━━┳━━━┳━━━┳━━━┓
┃ 5 ┃ 7 ┃ A ┃ E ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 8 ┃   ┃ 2 ┃ D ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 1 ┃ 6 ┃ 9 ┃ 3 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 4 ┃ F ┃ C ┃ B ┃
┗━━━┻━━━┻━━━┻━━━┛
┏━━━┳━━━┳━━━┳━━━┓
┃   ┃ 5 ┃ A ┃ E ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 8 ┃ 7 ┃ 2 ┃ D ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 1 ┃ 6 ┃ 9 ┃ 3 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 4 ┃ F ┃ C ┃ B ┃
┗━━━┻━━━┻━━━┻━━━┛
┏━━━┳━━━┳━━━┳━━━┓
┃ 5 ┃ A ┃   ┃ E ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 8 ┃ 7 ┃ 2 ┃ D ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 1 ┃ 6 ┃ 9 ┃ 3 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 4 ┃ F ┃ C ┃ B ┃
┗━━━┻━━━┻━━━┻━━━┛
3
┏━━━┳━━━┳━━━┳━━━┓
┃ 5 ┃ A ┃ 2 ┃ E ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 8 ┃ 7 ┃   ┃ D ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 1 ┃ 6 ┃ 9 ┃ 3 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 4 ┃ F ┃ C ┃ B ┃
┗━━━┻━━━┻━━━┻━━━┛
┏━━━┳━━━┳━━━┳━━━┓
┃ 5 ┃ A ┃ E ┃   ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 8 ┃ 7 ┃ 2 ┃ D ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 1 ┃ 6 ┃ 9 ┃ 3 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 4 ┃ F ┃ C ┃ B ┃
┗━━━┻━━━┻━━━┻━━━┛
┏━━━┳━━━┳━━━┳━━━┓
┃ 8 ┃ 

In [8]:
# Mask left by the last pass of the expansion loop; in the recorded output
# every one of the 100 slots is True.
print(filled)

[ True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True]
