In [1]:
# Project bootstrap: import and run the setup helper from the local `conftest` module.
# NOTE(review): presumably this prepares the import path / environment for the project
# imports in the next cell (the cell output reports CUDA_VISIBLE_DEVICES=0) — confirm
# against conftest.py.
from conftest import setup_project
setup_project()

env: CUDA_VISIBLE_DEVICES=0


In [2]:
import jax
import jax.numpy as jnp
import chex
import numpy as np
#disable jax JIT
#jax.config.update("jax_disable_jit", True)

from tqdm.autonotebook import tqdm
from JAxtar.hash import hash_func_builder, HashTable
from JAxtar.bgpq import BGPQ, HashTableIdx_HeapValue, HeapValue
from functools import partial
from puzzle.slidepuzzle import SlidePuzzle
from heuristic.slidepuzzle_heuristic import SlidePuzzleHeuristic

  from tqdm.autonotebook import tqdm


In [3]:
# Search configuration.
# `size` is passed to the hash-table and heap builders and also bounds the search loop
# and the progress bar; `batch_size` is the number of nodes handled per step.
size = 20_000_000
batch_size = 8192
# Weight applied to the path cost in the priority key (key = cost_weight * g + h).
# Being slightly below 1 it breaks ties between equal g + h in favour of larger g.
cost_weight = 1.0 - 1e-3

In [4]:
# Puzzle instance shared by every cell below (state type, neighbours, goal test).
# NOTE(review): the argument looks like the board side length — the start state used
# later has 16 cells, i.e. a 4x4 board — confirm against SlidePuzzle.
puzzle = SlidePuzzle(4)

In [5]:
# Hash function built for this puzzle's state type; the jitted table helpers defined
# in the next cell bind it with functools.partial.
hash_func = hash_func_builder(puzzle.State)
# Hash table that stores every generated state.
# NOTE(review): the meaning of the second argument (1) is not visible in this
# notebook — confirm against HashTable.build.
table = HashTable.build(puzzle.State, 1, size)
size_table = table.capacity
n_table = table.n_table
# Priority queue whose values are (index, table_index) references into `table`.
heap = BGPQ.build(size, batch_size, HashTableIdx_HeapValue)
# Per-slot search bookkeeping, addressed as [index, table_index]:
# best known path cost of the state in that slot (inf = not reached yet)
cost = jnp.full((size_table, n_table), jnp.inf)
# True until the state in that slot has been popped and expanded
not_closed = jnp.full((size_table, n_table), True, dtype=bool)
# (index, table_index) of the predecessor slot; -1 means "no parent recorded"
parant = jnp.full((size_table, n_table, 2), -1, dtype=int)

In [6]:
# Jitted, batched wrappers around the project primitives used by the search.
# Table lookup: the table argument is shared, the states are mapped over axis 0.
lookup = jax.jit(jax.vmap(partial(HashTable.lookup, hash_func), in_axes=(None, 0)))
# Insert a whole batch of states into the hash table.
parallel_insert = jax.jit(partial(HashTable.parallel_insert, hash_func))
# Heuristic distance of every state in a batch to one shared target state.
heuristic = jax.jit(jax.vmap(SlidePuzzleHeuristic(puzzle).distance, in_axes=(0, None)))
# Goal test of every state in a batch against one shared target state.
solved_fn = jax.jit(jax.vmap(puzzle.is_solved, in_axes=(0, None)))
# Neighbour expansion, mapped over both arguments (called below with a batch of
# states and the matching `filled` mask).
neighbours_fn = jax.jit(jax.vmap(puzzle.get_neighbours, in_axes=(0,0)))
# Priority-queue pop of the minimum batch and batched push.
delete_fn = jax.jit(BGPQ.delete_mins)
insert_fn = jax.jit(BGPQ.insert)

In [7]:
# Seed the search: insert the start state into the table and the heap.
# NOTE: this cell rebinds `table`, `cost` and `heap`; re-running it without re-running
# the build cell inserts the start state again.
#states = jax.vmap(puzzle.get_initial_state, in_axes=0)(key=jax.random.split(jax.random.PRNGKey(32),1))
# Fixed start position given as a flat 16-entry board, with a leading batch axis added.
states = puzzle.State(board=jnp.array([0, 12, 9, 13, 15, 11, 10, 14, 3, 7, 2, 5, 4, 8, 6, 1], dtype=jnp.uint8))[jnp.newaxis, ...]
# 80 moves to solve
target = puzzle.get_target_state()
# NOTE(review): presumably pads the single state up to `batch_size` entries and returns
# a mask of the real ones — confirm against HashTable.make_batched.
states, filled = HashTable.make_batched(puzzle.State, states, batch_size)
# Path cost per batch slot: 0 for the start state in slot 0, inf everywhere else.
cost_val = jnp.full((batch_size,), jnp.inf).at[0].set(0)
table, inserted, idx, table_idx = parallel_insert(table, states, filled)
found = inserted
print(found[:10])
heur_val = heuristic(states, target)
print(heur_val[:10])
# Heap values are references to the table slots the states were stored in.
hash_idxs = HashTableIdx_HeapValue(index=idx, table_index=table_idx)[:, jnp.newaxis]
# Record the cost only for slots that were actually inserted; others keep their value.
cost = cost.at[idx, table_idx].set(jnp.where(found, cost_val, cost[idx, table_idx]))
# Priority key f = g + h; slots with infinite cost therefore get an infinite key.
key = cost_val + heur_val
heap = BGPQ.insert(heap, key, hash_idxs)

[ True False False False False False False False False False]
[58.  0.  0.  0.  0.  0.  0.  0.  0.  0.]


In [8]:
def merge_sort_split(ak: chex.Array, av: HeapValue, bk: chex.Array, bv: HeapValue) -> tuple[chex.Array, HeapValue, chex.Array, HeapValue]:
    """
    Combine two (key, value) batches, order the union by key and cut it in two.

    The keys are concatenated (``ak`` first, then ``bk``) and argsorted with a
    stable sort, so entries with equal keys keep that concatenation order. The
    same permutation is applied to every leaf of the value pytree.

    Args:
        ak: chex.Array - keys of the first batch
        av: HeapValue - values belonging to ``ak``
        bk: chex.Array - keys of the second batch
        bv: HeapValue - values belonging to ``bk``

    Returns:
        key1: chex.Array - the first ``ak.shape[-1]`` keys of the sorted union
        val1: HeapValue - values matching ``key1``
        key2: chex.Array - the remaining keys of the sorted union
        val2: HeapValue - values matching ``key2``
    """
    half = ak.shape[-1]  # the split point is the length of the first batch
    merged_keys = jnp.concatenate([ak, bk])
    order = jnp.argsort(merged_keys, stable=True)

    def gather_sorted(left, right):
        # Join one pair of value leaves and reorder it like the keys.
        return jnp.concatenate([left, right])[order]

    ordered_keys = merged_keys[order]
    ordered_vals = jax.tree_util.tree_map(gather_sorted, av, bv)
    return ordered_keys[:half], ordered_vals[:half], ordered_keys[half:], ordered_vals[half:]

# Carry-over buffer for pop_full: it receives the upper half of each merge, i.e. the
# popped entries that did not make it into the returned batch. It starts out empty:
# every key is inf and every table reference is -1.
min_key_buffer = jnp.full((batch_size,), jnp.inf)
min_val_buffer = HashTableIdx_HeapValue(index=jnp.full((batch_size,), -1, dtype=jnp.int32), table_index=jnp.full((batch_size,), -1, dtype=jnp.int32))
def pop_full(heap, min_key_buffer, min_val_buffer):
    """
    Pop batches from the heap until one batch is full of still-open entries.

    Entries whose table slot is already closed (per the global ``not_closed``
    mask) get an infinite key, so merging pushes them behind every live entry.
    Popping continues while the heap is non-empty and the batch still contains
    infinite keys.

    Args:
        heap: the priority queue to pop from
        min_key_buffer: keys carried over from the previous call
        min_val_buffer: values belonging to ``min_key_buffer``

    Returns:
        heap: the queue after popping
        min_val: values of the selected batch
        filled: boolean mask of batch slots with a finite key
        min_key_buffer: keys left over for the next call
        min_val_buffer: values belonging to the returned ``min_key_buffer``
    """
    def open_only(keys, vals):
        # Invalidate entries that were closed after they had been pushed.
        still_open = not_closed[vals.index, vals.table_index]
        return jnp.where(still_open, keys, jnp.inf)

    heap, popped_key, popped_val = delete_fn(heap)
    batch_key, batch_val, spill_key, spill_val = merge_sort_split(
        open_only(popped_key, popped_val), popped_val, min_key_buffer, min_val_buffer
    )
    filled = jnp.isfinite(batch_key)
    while heap.size > 0 and not filled.all():
        # The batch still has an infinite key, so the sorted spill half is all-inf
        # and can be replaced by the next popped batch.
        heap, spill_key, spill_val = delete_fn(heap)
        batch_key, batch_val, spill_key, spill_val = merge_sort_split(
            batch_key, batch_val, open_only(spill_key, spill_val), spill_val
        )
        filled = jnp.isfinite(batch_key)
    return heap, batch_val, filled, spill_key, spill_val

In [9]:
# Batched A* search: repeatedly pop a batch of open nodes, test them for the goal,
# close them, expand their neighbours and push every improved neighbour.
print(states[0])
# Pre-bind the goal mask so the check after the loop is well-defined even when the
# loop body never runs (for example when the heap is already empty).
solved = jnp.zeros((batch_size,), dtype=bool)
# Using the bar as a context manager closes it even when the search raises or is
# interrupted from the keyboard.
with tqdm(total=size) as pbar:
    pbar.update(1)
    # NOTE(review): the loop stops as soon as the heap is empty, even if
    # min_key_buffer still holds finite keys from the previous pop — confirm whether
    # that can drop open nodes.
    while heap.size < size and not heap.size == 0 and table.size < size:
        pbar_str = f"heap_size: {heap.size:8d}, total_nodes: {table.size:8d}, "
        # get the minimum key
        heap, min_val, filled, min_key_buffer, min_val_buffer = pop_full(heap, min_key_buffer, min_val_buffer)
        if not filled.any():
            continue
        min_idx, min_table_idx = min_val.index, min_val.table_index
        # (index, table_index) of each popped node; stored as the parent of its neighbours
        parant_idx = jnp.stack((min_idx, min_table_idx), axis=-1).astype(jnp.int32)
        # get the state
        cost_val = cost[min_idx, min_table_idx]
        states = table.table[min_idx, min_table_idx]
        # Only slots that hold a live popped node may count as a solution; unfilled
        # slots carry placeholder or already-closed references.
        solved = solved_fn(states, target) & filled
        if solved.any():
            print("solved")
            break
        # Close the popped nodes: filled slots are set to False, others are unchanged.
        not_closed = not_closed.at[min_idx, min_table_idx].min(~filled)
        # Fraction of batch slots that did not hold a live node.
        closed_ratio = jnp.mean(~filled)
        pbar_str += f"cost: {jnp.mean(cost_val):.2f}, closed_ratio: {closed_ratio:.2f}"
        pbar.set_description_str(pbar_str)

        neighbours, ncost = neighbours_fn(states, filled)
        nextcosts = cost_val[:, jnp.newaxis] + ncost
        nextheur = jax.vmap(heuristic, in_axes=(0, None))(neighbours, target)
        nextkeys = cost_weight * nextcosts + nextheur
        filleds = jnp.isfinite(nextkeys)
        # Axis 1 enumerates the neighbours of each node; handle one neighbour per pass.
        for i in range(nextkeys.shape[1]):
            nextkey = nextkeys[:, i]
            nextcost = nextcosts[:, i]
            nextstates = neighbours[:, i]
            filled = filleds[:, i]

            table, inserted, idx, table_idx = parallel_insert(table, nextstates, filled)
            vals = HashTableIdx_HeapValue(index=idx, table_index=table_idx)[:, jnp.newaxis]
            # A neighbour is worth pushing only if it improves on the best known cost.
            optimal = jnp.less(nextcost, cost[idx, table_idx])
            cost = cost.at[idx, table_idx].min(nextcost)
            parant = parant.at[idx, table_idx].set(jnp.where(optimal[:,jnp.newaxis], parant_idx, parant[idx, table_idx]))
            # Improved nodes are re-opened; every other slot keeps its state.
            not_closed_update = not_closed[idx, table_idx] | optimal
            not_closed = not_closed.at[idx, table_idx].set(not_closed_update)
            nextkey = jnp.where(not_closed_update, nextkey, jnp.inf)
            added = int(jnp.sum(optimal))

            heap = insert_fn(heap, nextkey, vals, added_size=added)
            pbar.update(added)

if solved.any():
    # First solved slot of the last popped batch and its path cost.
    solved_st = states[solved][0]
    n_cost = cost_val[solved][0]
    print(solved_st)
    print(n_cost)
else:
    # The loop ended without reaching the target; there is no solved slot to index.
    solved_st, n_cost = None, None
    print("not solved: the search stopped before reaching the target")

┏━━━┳━━━┳━━━┳━━━┓
┃   ┃ C ┃ 9 ┃ D ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ F ┃ B ┃ A ┃ E ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 3 ┃ 7 ┃ 2 ┃ 5 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 4 ┃ 8 ┃ 6 ┃ 1 ┃
┗━━━┻━━━┻━━━┻━━━┛


  0%|          | 0/20000000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Rebuild the solution path by following the parent references backwards from the
# solved node, then print it from the start state to the goal.
path = []
# `solved` is a 1-D mask over the last popped batch (it is used to index `states`
# in the search cell), so argmax gives the position of the first solved slot —
# the same slot that `solved_st` was taken from.
idx = jnp.argmax(solved)
parant_last = parant_idx[idx]
# Safety cap on the number of steps followed.
max_path_len = 100
for _ in range(max_path_len):
    if parant_last[0] == -1:
        # -1 is the initial value of `parant`, i.e. no parent was recorded.
        break
    path.append(parant_last)
    parant_last = parant[parant_last[0], parant_last[1]]
if parant_last[0] != -1:
    print(f"warning: path truncated after {max_path_len} steps")

for p in path[::-1]:
    state = table.table[p[0], p[1]]
    c = cost[p[0], p[1]]
    print(state)
    print(c)
print(solved_st)
print(n_cost)

┏━━━┳━━━┳━━━┳━━━┓
┃ 2 ┃ F ┃ 3 ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃   ┃ 5 ┃ 7 ┃ B ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ C ┃ 9 ┃ 1 ┃ A ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 8 ┃ E ┃ D ┃ 6 ┃
┗━━━┻━━━┻━━━┻━━━┛
0.0
┏━━━┳━━━┳━━━┳━━━┓
┃ 2 ┃ F ┃ 3 ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ C ┃ 5 ┃ 7 ┃ B ┃
┣━━━╋━━━╋━━━╋━━━┫
┃   ┃ 9 ┃ 1 ┃ A ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 8 ┃ E ┃ D ┃ 6 ┃
┗━━━┻━━━┻━━━┻━━━┛
1.0
┏━━━┳━━━┳━━━┳━━━┓
┃ 2 ┃ F ┃ 3 ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ C ┃ 5 ┃ 7 ┃ B ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 8 ┃ 9 ┃ 1 ┃ A ┃
┣━━━╋━━━╋━━━╋━━━┫
┃   ┃ E ┃ D ┃ 6 ┃
┗━━━┻━━━┻━━━┻━━━┛
2.0
┏━━━┳━━━┳━━━┳━━━┓
┃ 2 ┃ F ┃ 3 ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ C ┃ 5 ┃ 7 ┃ B ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 8 ┃ 9 ┃ 1 ┃ A ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ E ┃   ┃ D ┃ 6 ┃
┗━━━┻━━━┻━━━┻━━━┛
3.0
┏━━━┳━━━┳━━━┳━━━┓
┃ 2 ┃ F ┃ 3 ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ C ┃ 5 ┃ 7 ┃ B ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 8 ┃ 9 ┃ 1 ┃ A ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ E ┃ D ┃   ┃ 6 ┃
┗━━━┻━━━┻━━━┻━━━┛
4.0
┏━━━┳━━━┳━━━┳━━━┓
┃ 2 ┃ F ┃ 3 ┃ 4 ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ C ┃ 5 ┃ 7 ┃ B ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ 8 ┃ 9 ┃ 1 ┃ A ┃
┣━━━╋━━━╋━━━╋━━━┫
┃ E ┃ D ┃ 6 ┃   ┃
┗━━━┻━━━┻━━━┻━━━┛
5.0
┏━━━