In [9]:
# Testing code for Region B and Region AB (JAX‐compatible)
import jax

# Enable 64-bit mode *before* importing coord_index: JAX expects this flag to
# be set at startup, and if coord_index creates any arrays at import time they
# would otherwise be built with 32-bit dtypes.
jax.config.update('jax_enable_x64', True)

from coord_index import *
import jax.numpy as jnp

# Functions exercised by the cells below — presumably all provided by the
# wildcard import above (TODO: replace the wildcard with an explicit import
# list once confirmed):
#     is_coord_valid_B, coord_to_index_B, index_to_coord_B,
#     is_coord_valid_AB_P0, coord_to_index_AB_P0, index_to_coord_AB_P0,
#     get_global_index_from_coord, get_coord_from_global_index


def test_region_AB(n):
    """Test every (i,j) in the 5n×(3n−1) box for Region AB."""
    height = 5 * n
    width = 3 * n
    # main‐region count = 2n*(3n-1) + n*(2n) + 2n*(3n-1) = 14n^2 - 4n
    num_plaq_main = 14 * n * n - 7 * n + 2 + 2 * (2 * n - 1)
    # middle edges count = 3*(n−2)
    num_middle = 3 * (n - 2) + 2

    mismatches = []
    print(f"--- Region AB (n={n}) ---")
    for i in range(height):
        for j in range(width):
            gidx = get_global_index_from_coord(i, j, n, num_plaq_main, num_middle)
            if int(gidx) < 0:
                continue
            i2, j2 = get_coord_from_global_index(n, gidx, num_plaq_main, num_middle)
            correct = (int(i2) == i) and (int(j2) == j)
            print(f"(AB)({i:2d},{j:2d}) → gidx={int(gidx):4d} → ({int(i2):2d},{int(j2):2d})"
                  + (" ✅" if correct else "  ❌"))
            if not correct:
                mismatches.append(((i, j), (int(gidx), int(i2), int(j2))))
    if mismatches:
        print("AB mismatches:", mismatches)
    else:
        print("All Region AB points round‐trip correctly.\n")


if __name__ == "__main__":
    # Run the Region AB round-trip check for one representative lattice size;
    # pass a different n to exercise other sizes.
    test_region_AB(n=4)


--- Region AB (n=4) ---
(AB)( 0, 0) → gidx=   0 → ( 0, 0) ✅
(AB)( 0, 1) → gidx=   1 → ( 0, 1) ✅
(AB)( 0, 2) → gidx=   2 → ( 0, 2) ✅
(AB)( 0, 3) → gidx=   3 → ( 0, 3) ✅
(AB)( 0, 4) → gidx=   4 → ( 0, 4) ✅
(AB)( 0, 5) → gidx=   5 → ( 0, 5) ✅
(AB)( 0, 6) → gidx=   6 → ( 0, 6) ✅
(AB)( 0, 7) → gidx=   7 → ( 0, 7) ✅
(AB)( 0, 8) → gidx=   8 → ( 0, 8) ✅
(AB)( 0, 9) → gidx=   9 → ( 0, 9) ✅
(AB)( 0,10) → gidx=  10 → ( 0,10) ✅
(AB)( 0,11) → gidx=  11 → ( 0,11) ✅
(AB)( 1, 0) → gidx=  12 → ( 1, 0) ✅
(AB)( 1, 1) → gidx=  13 → ( 1, 1) ✅
(AB)( 1, 2) → gidx=  14 → ( 1, 2) ✅
(AB)( 1, 3) → gidx=  15 → ( 1, 3) ✅
(AB)( 1, 4) → gidx=  16 → ( 1, 4) ✅
(AB)( 1, 5) → gidx=  17 → ( 1, 5) ✅
(AB)( 1, 6) → gidx=  18 → ( 1, 6) ✅
(AB)( 1, 7) → gidx=  19 → ( 1, 7) ✅
(AB)( 1, 8) → gidx=  20 → ( 1, 8) ✅
(AB)( 1, 9) → gidx=  21 → ( 1, 9) ✅
(AB)( 1,10) → gidx=  22 → ( 1,10) ✅
(AB)( 1,11) → gidx=  23 → ( 1,11) ✅
(AB)( 2, 0) → gidx=  24 → ( 2, 0) ✅
(AB)( 2, 1) → gidx=  25 → ( 2, 1) ✅
(AB)( 2, 2) → gidx=  26 → ( 2, 2) ✅
(AB)

In [10]:
# Test and print every valid point
def test_AB_P0(n):
    print(f"--- Testing AB_P0 region for n={n} ---")
    height = 5 * n
    width  = 3 * n
    for i in range(height):
        for j in range(width):
            if not is_coord_valid_AB_P0(i, j, n):
                # skip invalid coordinates
                continue
            idx = coord_to_index_AB_P0(i, j, n)
            i2, j2 = index_to_coord_AB_P0(idx, n)
            correct = (i2 == i) and (j2 == j)
            symbol = "✅" if correct else "❌"
            print(f"({i:2d},{j:2d}) -> idx {idx:4d} -> ({i2:2d},{j2:2d}) {symbol}")

# Example run for a representative size (change n to test others)
test_AB_P0(n=4)

--- Testing AB_P0 region for n=4 ---
( 0, 0) -> idx    0 -> ( 0, 0) ✅
( 0, 1) -> idx    1 -> ( 0, 1) ✅
( 0, 2) -> idx    2 -> ( 0, 2) ✅
( 0, 3) -> idx    3 -> ( 0, 3) ✅
( 0, 4) -> idx    4 -> ( 0, 4) ✅
( 0, 5) -> idx    5 -> ( 0, 5) ✅
( 0, 6) -> idx    6 -> ( 0, 6) ✅
( 0, 7) -> idx    7 -> ( 0, 7) ✅
( 0, 8) -> idx    8 -> ( 0, 8) ✅
( 0, 9) -> idx    9 -> ( 0, 9) ✅
( 0,10) -> idx   10 -> ( 0,10) ✅
( 0,11) -> idx   11 -> ( 0,11) ✅
( 1, 0) -> idx   12 -> ( 1, 0) ✅
( 1, 1) -> idx   13 -> ( 1, 1) ✅
( 1, 2) -> idx   14 -> ( 1, 2) ✅
( 1, 3) -> idx   15 -> ( 1, 3) ✅
( 1, 4) -> idx   16 -> ( 1, 4) ✅
( 1, 5) -> idx   17 -> ( 1, 5) ✅
( 1, 6) -> idx   18 -> ( 1, 6) ✅
( 1, 7) -> idx   19 -> ( 1, 7) ✅
( 1, 8) -> idx   20 -> ( 1, 8) ✅
( 1, 9) -> idx   21 -> ( 1, 9) ✅
( 1,10) -> idx   22 -> ( 1,10) ✅
( 1,11) -> idx   23 -> ( 1,11) ✅
( 2, 0) -> idx   24 -> ( 2, 0) ✅
( 2, 1) -> idx   25 -> ( 2, 1) ✅
( 2, 2) -> idx   26 -> ( 2, 2) ✅
( 2, 3) -> idx   27 -> ( 2, 3) ✅
( 2, 4) -> idx   28 -> ( 2, 4) ✅
( 2, 5

In [11]:
def test_region_B(n):
    """Test every (i,j) in the 3n×(2n−1) box for Region B."""
    height = 3 * n
    width = 2 * n
    mismatches = []
    print(f"--- Region B (n={n}) ---")
    for i in range(height):
        for j in range(width):
            if bool(is_coord_valid_B(i, j, n)):
                idx = coord_to_index_B(i, j, n)
                i2, j2 = index_to_coord_B(idx, n)
                correct = (int(i2) == i) and (int(j2) == j)
                print(f"(B) ({i:2d},{j:2d}) → idx={int(idx):3d} → ({int(i2):2d},{int(j2):2d})"
                      + (" ✅" if correct else "  ❌"))
                if not correct:
                    mismatches.append(((i, j), (int(idx), int(i2), int(j2))))
    if mismatches:
        print("B mismatches:", mismatches)
    else:
        print("All Region B points round‐trip correctly.\n")
test_region_B(n=4)  # example run for one representative lattice size

--- Region B (n=4) ---
(B) ( 0, 0) → idx=  0 → ( 0, 0) ✅
(B) ( 0, 1) → idx=  1 → ( 0, 1) ✅
(B) ( 0, 2) → idx=  2 → ( 0, 2) ✅
(B) ( 0, 3) → idx=  3 → ( 0, 3) ✅
(B) ( 0, 4) → idx=  4 → ( 0, 4) ✅
(B) ( 0, 5) → idx=  5 → ( 0, 5) ✅
(B) ( 0, 6) → idx=  6 → ( 0, 6) ✅
(B) ( 0, 7) → idx=  7 → ( 0, 7) ✅
(B) ( 1, 0) → idx=  8 → ( 1, 0) ✅
(B) ( 1, 1) → idx=  9 → ( 1, 1) ✅
(B) ( 1, 2) → idx= 10 → ( 1, 2) ✅
(B) ( 1, 3) → idx= 11 → ( 1, 3) ✅
(B) ( 1, 4) → idx= 12 → ( 1, 4) ✅
(B) ( 1, 5) → idx= 13 → ( 1, 5) ✅
(B) ( 1, 6) → idx= 14 → ( 1, 6) ✅
(B) ( 1, 7) → idx= 15 → ( 1, 7) ✅
(B) ( 2, 0) → idx= 16 → ( 2, 0) ✅
(B) ( 2, 1) → idx= 17 → ( 2, 1) ✅
(B) ( 2, 2) → idx= 18 → ( 2, 2) ✅
(B) ( 2, 3) → idx= 19 → ( 2, 3) ✅
(B) ( 2, 4) → idx= 20 → ( 2, 4) ✅
(B) ( 2, 5) → idx= 21 → ( 2, 5) ✅
(B) ( 2, 6) → idx= 22 → ( 2, 6) ✅
(B) ( 2, 7) → idx= 23 → ( 2, 7) ✅
(B) ( 3, 4) → idx= 24 → ( 3, 4) ✅
(B) ( 3, 5) → idx= 25 → ( 3, 5) ✅
(B) ( 3, 6) → idx= 26 → ( 3, 6) ✅
(B) ( 3, 7) → idx= 27 → ( 3, 7) ✅
(B) ( 4, 5) → idx= 28 → (