In [1]:
import sys
import os

# 프로젝트 루트 디렉토리 경로를 추가
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [2]:
import jax
import jax.numpy as jnp
import chex
import numpy as np

from tqdm.autonotebook import tqdm
from JAxtar.hash import hash_func_builder, HashTable
from JAxtar.bgpq import BGPQ, heapcalue_dataclass
from functools import partial
from puzzle.slidepuzzle import SlidePuzzle
from heuristic.slidepuzzle_heuristic import SlidePuzzleHeuristic

  from tqdm.autonotebook import tqdm


In [3]:
# 4x4 sliding-tile puzzle. The goal board holds tiles 1..15 in reading order
# with the blank (0) in the last cell: (arange(16) + 1) % 16 == [1, ..., 15, 0].
puzzle = SlidePuzzle(4)
target = puzzle.State(board=(jnp.arange(16, dtype=jnp.uint8) + 1) % 16)
# NOTE(review): `heuristic` is rebound later in this notebook to a jitted,
# vmapped distance function, so this instance binding is short-lived.
heuristic = SlidePuzzleHeuristic(puzzle)

In [4]:
@heapcalue_dataclass
class AstarTableHeapValue:
    """
    Heap payload for the batched A* search.

    Each entry locates a state inside the hash table and carries the
    search bookkeeping for it. Fields:
    1. index: jnp.uint32 / slot index inside the hash table
    2. table_index: jnp.uint8 / cuckoo table index
    3. cost: jnp.float32 / path cost g accumulated to reach the state
    4. heuristic: jnp.float32 / heuristic estimate h for the state
    """
    index: chex.Array
    table_index: chex.Array
    cost: chex.Array
    heuristic: chex.Array

    @staticmethod
    def default(_=None) -> "AstarTableHeapValue":
        """Return a placeholder value whose fields are length-1 zero arrays.

        The single positional argument is accepted and ignored.
        """
        def zeros_of(dtype):
            return jnp.zeros(1, dtype=dtype)

        return AstarTableHeapValue(
            index=zeros_of(jnp.uint32),
            table_index=zeros_of(jnp.uint8),
            cost=zeros_of(jnp.float32),
            heuristic=zeros_of(jnp.float32),
        )

In [5]:
def get_shuffled_puzzle(num_steps: int = 1000, rng=None):
    """Return a start state produced by a random walk from the solved board.

    Starting from the goal configuration, ``num_steps`` random moves are
    applied, each chosen uniformly among the neighbours whose move cost is
    finite. Since the state is reached from the goal by puzzle moves it is
    presumably always solvable (assumes moves are reversible — confirm
    against SlidePuzzle.get_neighbours).

    Args:
        num_steps: number of random moves to apply (default 1000, as before).
        rng: any object with a ``choice(n)`` method used to pick each move,
            e.g. ``np.random.default_rng(seed)`` for a reproducible shuffle.
            Defaults to the legacy global ``np.random`` state, which
            reproduces the original behaviour exactly.

    Returns:
        A single ``puzzle.State``.
    """
    if rng is None:
        rng = np.random
    state = puzzle.State(board=jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0], dtype=jnp.uint8))
    for _ in range(num_steps):
        neighbours, move_cost = puzzle.get_neighbours(state)
        # Neighbours with a non-finite cost are skipped (presumably illegal moves).
        valid = jnp.isfinite(move_cost)
        n_valid = int(jnp.sum(valid))
        state = neighbours[valid][rng.choice(n_valid)]
    return state

In [6]:
# Passed as the size of both the hash table and the heap below.
size = 10_000_000
# Batch width: used when creating the heap and when batching inserts.
group_size = 10_000

# Start state. Alternatives kept for reference:
#   random initial state:
#     jax.vmap(puzzle.get_initial_state, in_axes=0)(key=jax.random.split(jax.random.PRNGKey(123), 1))
#   fixed board:
#     puzzle.State(board=jnp.array([10, 2, 0, 8, 15, 12, 13, 14, 6, 7, 3, 9, 11, 1, 5, 4], dtype=jnp.uint8))
states = get_shuffled_puzzle()
print(states.shape)
print(states)

# Storage for the search: hash table of visited states plus the batched heap.
hash_func = hash_func_builder(puzzle.State)
table = HashTable.make_lookup_table(puzzle.State, 1, size)
heap = BGPQ.make_heap(size, group_size, AstarTableHeapValue)

# JIT-compiled helpers with the hash function bound as their first argument.
lookup_fn = partial(HashTable.lookup, hash_func)
insert_fn = partial(HashTable.parallel_insert, hash_func)
lookup = jax.jit(lookup_fn)
parallel_insert = jax.jit(insert_fn)
# Heuristic distance vmapped over a batch of states against a single target.
heuristic = jax.jit(jax.vmap(SlidePuzzleHeuristic(puzzle).distance, in_axes=(0, None)))

shape(board=(16,))
┏━━━┳━━━┳━━━┳━━━┓
┃ F ┃ 8 ┃ 5 ┃ E ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 2 ┃ D ┃ C ┃ 3 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 6 ┃ 4 ┃ 7 ┃ B ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 9 ┃   ┃ A ┃ 1 ┃
┗━━━┻━━━┻━━━┻━━━┛


In [7]:
# Seed the search: pad the single start state out to a full batch, insert it
# into the hash table and push it onto the heap.
states, filled = HashTable.make_batched(puzzle.State, states[jnp.newaxis, :], group_size)
# Only slot 0 holds the real start state (g = 0); padding slots get g = inf.
cost = jnp.full((group_size,), jnp.inf).at[0].set(0)
table, inserted = parallel_insert(table, states, filled)
idx, table_idx, found = jax.vmap(lookup, in_axes=(None, 0))(table, states)
found = found & filled
start_heuristic = heuristic(states, target)
hash_idxs = AstarTableHeapValue(
    index=idx[:, jnp.newaxis],
    table_index=table_idx[:, jnp.newaxis],
    cost=cost[:, jnp.newaxis],
    heuristic=start_heuristic[:, jnp.newaxis],
)
# Use the A* priority f = g + h, consistent with how the expansion loop keys
# every other node; padding slots are forced to +inf explicitly so a
# heuristic evaluated on padding can never leak into the key.
start_key = jnp.where(filled, cost + start_heuristic, jnp.inf)
heap = BGPQ.insert(heap, start_key, hash_idxs)

In [8]:
def visualize_state(state, key):
    """Print every batch entry whose key is finite, followed by its key.

    Args:
        state: batched states indexable by an integer; entry ``i`` is printed
            for each finite ``key[i]``.
        key: 1-D array of keys; non-finite entries are treated as empty
            padding slots and skipped.
    """
    # Build the mask on the host once (instead of one device sync per element)
    # and take the batch length from `key` itself rather than the global
    # `group_size`: when `key` is a JAX array, out-of-range indices are clamped,
    # so a shorter `key` would otherwise repeat its last entry.
    filled = np.asarray(jnp.isfinite(key))
    for i in np.flatnonzero(filled):
        i = int(i)
        print(state[i])
        print(f" cost: {key[i]}")

In [9]:
print(states[0])
total_nodes = 1
pbar = tqdm(total=size)
pbar.update(1)

# Nested vmaps over the (batch, neighbour) grid returned by get_neighbours.
batched_lookup = jax.vmap(jax.vmap(lookup, in_axes=(None, 0)), in_axes=(None, 0))
batched_is_solved = jax.vmap(jax.vmap(puzzle.is_solved, in_axes=(0, None)), in_axes=(0, None))

# Stays None if the loop never runs, so the report below cannot crash.
solved_mask = None
while heap.size < size and not heap.size == 0 and total_nodes < size:
    pbar_str = f"heap_size: {heap.size:8d}, total_nodes: {total_nodes:8d}, "
    # Pop the next batch of minimum-key entries from the heap.
    heap, min_key, min_val = BGPQ.delete_mins(heap)
    cost = min_val.cost
    pbar_str += f"cost: {jnp.mean(cost):.2f}"
    filled = jnp.isfinite(min_key)
    states = table.table[min_val.index.squeeze(), min_val.table_index.squeeze()]
    pbar.set_description_str(pbar_str)

    neighbours, ncost = jax.vmap(puzzle.get_neighbours, in_axes=(0, 0))(states, filled)
    _, _, nfound = batched_lookup(table, neighbours)
    # g of each child; children already present in the table get +inf.
    ncost_sum = ncost + cost + jnp.where(nfound, jnp.inf, 0)
    heuristic_cost = jax.vmap(heuristic, in_axes=(0, None))(neighbours, target)
    # Only children with a finite cost count as reaching the goal; padding
    # parents (filled == False) can still produce real-looking neighbours.
    solved_mask = batched_is_solved(neighbours, target) & jnp.isfinite(ncost_sum)
    if solved_mask.any():
        print("solved")
        break
    for i in range(ncost_sum.shape[1]):
        n_state = neighbours[:, i]
        n_cost = ncost_sum[:, i]
        n_heuristic = heuristic_cost[:, i]
        filled = jnp.isfinite(n_cost)
        table, inserted = parallel_insert(table, n_state, filled)
        idx, table_idx, found = jax.vmap(lookup, in_axes=(None, 0))(table, n_state)
        vals = AstarTableHeapValue(index=idx[:, jnp.newaxis], table_index=table_idx[:, jnp.newaxis], cost=n_cost[:, jnp.newaxis], heuristic=n_heuristic[:, jnp.newaxis])
        # A* priority f = g + h; anything not newly inserted is pushed at +inf.
        key = n_cost + n_heuristic + jnp.where(inserted, 0, jnp.inf)
        added = int(jnp.sum(inserted))
        total_nodes += added
        pbar.update(added)
        heap = BGPQ.insert(heap, key, vals)
pbar.close()

if solved_mask is not None and bool(solved_mask.any()):
    solved_st = neighbours[solved_mask][0]
    # Several parents in the same batch may reach the goal; report the cheapest.
    n_cost = jnp.min(ncost_sum[solved_mask])
    print(solved_st)
    print(n_cost)
else:
    solved_st, n_cost = None, None
    print("no solution found before the search limits were reached")

┏━━━┳━━━┳━━━┳━━━┓
┃ F ┃ 8 ┃ 5 ┃ E ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 2 ┃ D ┃ C ┃ 3 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 6 ┃ 4 ┃ 7 ┃ B ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 9 ┃   ┃ A ┃ 1 ┃
┗━━━┻━━━┻━━━┻━━━┛


  0%|          | 0/10000000 [00:00<?, ?it/s]

solved
┏━━━┳━━━┳━━━┳━━━┓
┃ 1 ┃ 2 ┃ 3 ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 5 ┃ 6 ┃ 7 ┃ 8 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 9 ┃ A ┃ B ┃ C ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ D ┃ E ┃ F ┃   ┃
┗━━━┻━━━┻━━━┻━━━┛
76.0
