# Sudoku STC Solver

In [1]:
import sys
from pathlib import Path

cwd = Path.cwd()
project_root = cwd if (cwd / "stc_sudoku").is_dir() else cwd.parent
sys.path.insert(0, str(project_root))

In [2]:
import numpy as np
from stc_sudoku.funcs import (
    run_stc_solver,
    run_parallel_tempering_stc,
    init_stc_state,
    generate_samples_smc,
    sinkhorn_projection,
    calculate_energy_dynamic,
    update_tensor,
    targeted_cluster_swap,
)

In [3]:
from sudoku import Sudoku

In [4]:
def display_board(board_array, size):
    box_size = int(np.sqrt(size))
    board_list = board_array.tolist()
    puzzle = Sudoku(box_size, box_size, board=board_list)
    puzzle.show()

## 4x4 Board

In [5]:
puzzle_4x4 = np.array([
    [0, 3, 4, 0],
    [4, 0, 0, 2],
    [1, 0, 0, 3],
    [0, 2, 1, 0]
])

print("Initial 4x4 Puzzle:")
display_board(puzzle_4x4, size=4)

Initial 4x4 Puzzle:
Puzzle has exactly one solution
+-----+-----+
|   3 | 4   |
| 4   |   2 |
+-----+-----+
| 1   |   3 |
|   2 | 1   |
+-----+-----+



In [6]:
solution_4x4 = run_stc_solver(size=4, given_puzzle=puzzle_4x4, iterations=100, batch_size=500)

print("\nFinal 4x4 Output:")
display_board(solution_4x4, size=4)


Starting 4x4 Dynamic GLS STC Solver...
Step   0 | Actual Collisions:   2 | Weighted E:   2.0 | Patience: 0
Step   1 | Actual Collisions:   0 | Weighted E:   0.0 | Patience: 0
>>> Perfect Solution Found! <<<

Final 4x4 Output:
Puzzle has exactly one solution
+-----+-----+
| 2 3 | 4 1 |
| 4 1 | 3 2 |
+-----+-----+
| 1 4 | 2 3 |
| 3 2 | 1 4 |
+-----+-----+



## 9x9 Board

In [7]:
# Classic 9x9 puzzle (unique solution)
puzzle_9x9 = np.array([
    [5, 3, 0, 0, 7, 0, 0, 0, 0],
    [6, 0, 0, 1, 9, 5, 0, 0, 0],
    [0, 9, 8, 0, 0, 0, 0, 6, 0],
    [8, 0, 0, 0, 6, 0, 0, 0, 3],
    [4, 0, 0, 8, 0, 3, 0, 0, 1],
    [7, 0, 0, 0, 2, 0, 0, 0, 6],
    [0, 6, 0, 0, 0, 0, 2, 8, 0],
    [0, 0, 0, 4, 1, 9, 0, 0, 5],
    [0, 0, 0, 0, 8, 0, 0, 7, 9],
], dtype=int)

print("Initial 9x9 Puzzle:")
display_board(puzzle_9x9, size=9)

Initial 9x9 Puzzle:
Puzzle has exactly one solution
+-------+-------+-------+
| 5 3   |   7   |       |
| 6     | 1 9 5 |       |
|   9 8 |       |   6   |
+-------+-------+-------+
| 8     |   6   |     3 |
| 4     | 8   3 |     1 |
| 7     |   2   |     6 |
+-------+-------+-------+
|   6   |       | 2 8   |
|       | 4 1 9 |     5 |
|       |   8   |   7 9 |
+-------+-------+-------+



In [8]:
solution_9x9 = run_stc_solver(size=9, given_puzzle=puzzle_9x9, iterations=2000, batch_size=4000)

print("\nFinal 9x9 Output:")
display_board(solution_9x9, size=9)


Starting 9x9 Dynamic GLS STC Solver...
Step   0 | Actual Collisions:  30 | Weighted E:  30.0 | Patience: 0
Step   7 | Actual Collisions:   0 | Weighted E:   0.0 | Patience: 0
>>> Perfect Solution Found! <<<

Final 9x9 Output:
Puzzle has exactly one solution
+-------+-------+-------+
| 5 3 4 | 6 7 8 | 9 1 2 |
| 6 7 2 | 1 9 5 | 3 4 8 |
| 1 9 8 | 3 4 2 | 5 6 7 |
+-------+-------+-------+
| 8 5 9 | 7 6 1 | 4 2 3 |
| 4 2 6 | 8 5 3 | 7 9 1 |
| 7 1 3 | 9 2 4 | 8 5 6 |
+-------+-------+-------+
| 9 6 1 | 5 3 7 | 2 8 4 |
| 2 8 7 | 4 1 9 | 6 3 5 |
| 3 4 5 | 2 8 6 | 1 7 9 |
+-------+-------+-------+



In [None]:
def solve_9x9(puzzle, method="stc", **kwargs):
    """Modular solver: method in ('stc', 'parallel_tempering'). Uses cluster swap when close."""
    size = 9
    if method == "stc":
        return run_stc_solver(size=size, given_puzzle=puzzle, **kwargs)
    if method == "parallel_tempering":
        return run_parallel_tempering_stc(size=size, given_puzzle=puzzle, **kwargs)
    raise ValueError("method must be 'stc' or 'parallel_tempering'")

solution_9x9_alt = solve_9x9(puzzle_9x9, method="parallel_tempering", iterations=2000, batch_per_replica=4000)
print("\nFinal 9x9 Output:")
display_board(solution_9x9_alt, size=9)

In [9]:
puzzle2_9x9 = np.array([
    [0, 0, 2, 1, 0, 0, 0, 3, 0],
    [3, 0, 0, 0, 0, 0, 0, 0, 8],
    [0, 9, 1, 0, 0, 0, 0, 0, 0],
    [0, 8, 0, 9, 0, 0, 0, 2, 5],
    [0, 6, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 8, 0, 0, 7],
    [9, 0, 0, 0, 0, 0, 7, 0, 2],
    [0, 0, 0, 0, 1, 4, 0, 0, 0],
    [4, 0, 0, 3, 7, 0, 6, 8, 0],
], dtype=int)
print("Initial 9x9 Puzzle:")
display_board(puzzle2_9x9, size=9)

'''solution2_9x9 = run_parallel_tempering_stc(size=9, given_puzzle=puzzle2_9x9, iterations=2000, batch_per_replica=4000)

print("\nFinal 9x9 Output:")
display_board(solution2_9x9, size=9)'''

Initial 9x9 Puzzle:
Puzzle has exactly one solution
+-------+-------+-------+
|     2 | 1     |   3   |
| 3     |       |     8 |
|   9 1 |       |       |
+-------+-------+-------+
|   8   | 9     |   2 5 |
|   6   |       |       |
|       |     8 |     7 |
+-------+-------+-------+
| 9     |       | 7   2 |
|       |   1 4 |       |
| 4     | 3 7   | 6 8   |
+-------+-------+-------+



'solution2_9x9 = run_parallel_tempering_stc(size=9, given_puzzle=puzzle2_9x9, iterations=2000, batch_per_replica=4000)\n\nprint("\nFinal 9x9 Output:")\ndisplay_board(solution2_9x9, size=9)'

In [None]:
# Solve puzzle2_9x9 with GLS STC (includes cluster swap when 1–4 collisions remain)
solution2_9x9 = run_stc_solver(size=9, given_puzzle=puzzle2_9x9, iterations=2000, batch_size=4000)
print("\nFinal puzzle2_9x9 Output:")
display_board(solution2_9x9, size=9)


Starting 9x9 Dynamic GLS STC Solver...
Step   0 | Actual Collisions:  36 | Weighted E:  36.0 | Patience: 0
Step  40 | Stuck at 7 collisions (patience 26). Warping...
Step  50 | Actual Collisions:   7 | Weighted E:  21.0 | Patience: 10
Step  66 | Stuck at 7 collisions (patience 26). Warping...
Step  92 | Stuck at 7 collisions (patience 26). Warping...
Step 100 | Actual Collisions:   7 | Weighted E:  49.0 | Patience: 8
Step 118 | Stuck at 7 collisions (patience 26). Warping...
Step 144 | Stuck at 7 collisions (patience 26). Warping...
Step 150 | Actual Collisions:   7 | Weighted E:  77.0 | Patience: 6
Step 170 | Stuck at 7 collisions (patience 26). Warping...
Step 196 | Stuck at 7 collisions (patience 26). Warping...
Step 200 | Actual Collisions:   7 | Weighted E: 105.0 | Patience: 4
Step 222 | Stuck at 7 collisions (patience 26). Warping...
Step 248 | Stuck at 7 collisions (patience 26). Warping...
Step 250 | Actual Collisions:   7 | Weighted E: 133.0 | Patience: 2
Step 274 | Stuck at 

## SMC + Sinkhorn on puzzle2_9x9

Example: use **sinkhorn_projection** to enforce row/col/box constraints in continuous space, then **generate_samples_smc** to draw collision-free samples with forward masking. Update logits from elite samples and repeat until a complete valid grid is found.

In [None]:
# Solve puzzle2_9x9 with SMC + Sinkhorn (with cluster swap when 1–4 collisions remain)
size = 9
logits, fixed_mask = init_stc_state(size, puzzle2_9x9)
ones = np.ones(size)
batch = 2000
for step in range(500):
    logits = sinkhorn_projection(logits, iterations=10, given_puzzle=puzzle2_9x9, fixed_mask=fixed_mask)
    samples, smc_energies = generate_samples_smc(logits, batch, given_puzzle=puzzle2_9x9, fixed_mask=fixed_mask)
    energies, _ = calculate_energy_dynamic(samples, ones, ones, ones)
    best_idx = np.argmin(energies)
    min_e = energies[best_idx]
    if step % 25 == 0:
        n_complete = np.sum(smc_energies == 0)
        print(f"Step {step:3d} | Min collisions: {min_e:.0f} | SMC completions: {n_complete}/{batch}", flush=True)
    if min_e == 0:
        print(">>> Perfect solution found (SMC+Sinkhorn) <<<", flush=True)
        solution_smc = samples[best_idx]
        break
    if min_e > 0 and min_e <= 4:
        best_board = samples[best_idx].copy()
        fixed_board = targeted_cluster_swap(best_board, fixed_mask, verbose=True)
        e_fixed, _ = calculate_energy_dynamic(fixed_board.reshape(1, size, size), ones, ones, ones)
        if e_fixed[0] == 0:
            print(">>> Perfect solution found via cluster swap (SMC+Sinkhorn) <<<", flush=True)
            solution_smc = fixed_board
            break
    logits = update_tensor(logits, samples, energies, learning_rate=0.3, fixed_mask=fixed_mask, given_puzzle=puzzle2_9x9)
else:
    solution_smc = samples[np.argmin(energies)]
    print(">>> Max steps reached, returning best <<<", flush=True)

print("\nFinal 9x9 (SMC+Sinkhorn):")
display_board(solution_smc, size=9)

Step   0 | Min collisions: 11 | SMC completions: 8/2000
Step  25 | Min collisions: 2 | SMC completions: 1996/2000
Step  50 | Min collisions: 2 | SMC completions: 1999/2000
Step  75 | Min collisions: 2 | SMC completions: 2000/2000
Step 100 | Min collisions: 2 | SMC completions: 1999/2000
Step 125 | Min collisions: 2 | SMC completions: 1999/2000
Step 150 | Min collisions: 2 | SMC completions: 1999/2000
Step 175 | Min collisions: 2 | SMC completions: 1999/2000
Step 200 | Min collisions: 2 | SMC completions: 2000/2000
Step 225 | Min collisions: 2 | SMC completions: 2000/2000
Step 250 | Min collisions: 2 | SMC completions: 2000/2000
Step 275 | Min collisions: 2 | SMC completions: 1998/2000
Step 300 | Min collisions: 2 | SMC completions: 2000/2000
Step 325 | Min collisions: 2 | SMC completions: 1997/2000
Step 350 | Min collisions: 2 | SMC completions: 1999/2000
Step 375 | Min collisions: 2 | SMC completions: 1998/2000
