# Sudoku as a Constraint Satisfaction Problem (CSP)

This notebook is an educational tool to explore how Sudoku can be framed and solved as a Constraint Satisfaction Problem in AI.

You will:
- Define Sudoku as a CSP: variables, domains, constraints.
- Implement AC-3 arc consistency to prune domains before search.
- Implement Backtracking search with MRV/degree and LCV heuristics.
- Add Forward-Checking and MAC (Maintaining Arc Consistency) during search.
- Compare performance across approaches.

Feel free to edit and experiment with the code and puzzles.

## 1) Representing Sudoku as a CSP

- **Variables**: each cell in the 9x9 grid, labeled `A1..I9` (rows A–I, cols 1–9).
- **Domains**: for each variable, the set `{1..9}` unless the puzzle fixes a value.
- **Constraints**: all-different within each row, each column, and each 3x3 subgrid.

We'll precompute **peers** of each variable (those that share a row, column, or box) and enforce binary `!=` constraints pairwise.
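As a quick sanity check on the peer structure described above, the following self-contained sketch rebuilds the units and peers and counts them. The names `cells`, `units`, `peers`, and `binary_arcs` are local to this illustration and are not the notebook's own globals.

```python
# Self-contained sketch: rebuild the peer map and check its size.
rows, cols = "ABCDEFGHI", "123456789"

def cross(a, b):
    """All two-character labels formed from one char of a and one of b."""
    return [x + y for x in a for y in b]

cells = cross(rows, cols)
units = (
    [cross(r, cols) for r in rows]          # 9 rows
    + [cross(rows, c) for c in cols]        # 9 columns
    + [cross(rs, cs) for rs in ("ABC", "DEF", "GHI")
       for cs in ("123", "456", "789")]     # 9 boxes
)
# A cell's peers are every other cell that shares at least one unit with it.
peers = {c: {p for u in units if c in u for p in u} - {c} for c in cells}
# Each (cell, peer) pair is one directed arc of a binary != constraint.
binary_arcs = sum(len(p) for p in peers.values())

print(len(units), len(peers["E5"]), binary_arcs)  # 27 20 1620
```

Every cell has 8 row peers, 8 column peers, and 4 further box peers, which gives the 20 peers per cell and 81 × 20 = 1620 directed arcs that AC-3 starts with.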

In [14]:
from time import perf_counter
from collections import deque, defaultdict
import math

# Helpers for Sudoku variable naming
ROWS = 'ABCDEFGHI'
COLS = '123456789'

def cross(A, B):
    return [a + b for a in A for b in B]

VARS = cross(ROWS, COLS)
ROW_UNITS = [cross(r, COLS) for r in ROWS]
COL_UNITS = [cross(ROWS, c) for c in COLS]
BOX_UNITS = [cross(rs, cs) for rs in ('ABC','DEF','GHI') for cs in ('123','456','789')]
UNITS = ROW_UNITS + COL_UNITS + BOX_UNITS
UNITS_OF = {v: [u for u in UNITS if v in u] for v in VARS}
PEERS = {v: set(sum(UNITS_OF[v], [])) - {v} for v in VARS}
ARCS = [(xi, xj) for xi in VARS for xj in PEERS[xi]]

def parse_grid(grid_str):
    """
    Parse a Sudoku string into initial domains.
    Accepts a string with 81 cell characters: digits 1-9 for given cells and
    0, ., or - for empties. Any other character (e.g. whitespace) is ignored.
    Returns dict var -> set of possible values (as strings).
    """
    chars = [c for c in grid_str if c in '0.-123456789']
    if len(chars) != 81:
        raise ValueError(f'Grid must contain 81 cells (digits 0-9, ., or -); found {len(chars)}')
    digits = '123456789'
    domains = {v: set(digits) for v in VARS}
    for v, c in zip(VARS, chars):
        if c in digits:
            domains[v] = {c}
    return domains

def display(domains):
    """Pretty print current assignment/domains as a grid.
    Each cell shows its remaining candidates; a singleton shows just its value.
    """
    width = 1 + max(len(''.join(sorted(domains[v]))) for v in VARS)
    line = '+'.join(['-'*(width*3)]*3)
    for r in ROWS:
        row = ''
        for c in COLS:
            v = r+c
            val = ''.join(sorted(domains[v]))
            row += val.center(width)
            if c in '36': row += '|'
        print(row)
        if r in 'CF': print(line)

def is_assignment_complete(domains):
    return all(len(domains[v]) == 1 for v in VARS)

def consistent(domains, var, val):
    """Check consistency of setting var=val with neighbors under != constraints.
    """
    for n in PEERS[var]:
        if val in domains[n] and len(domains[n]) == 1:
            return False
    return True


## 2) AC-3: Enforcing Arc Consistency

AC-3 iteratively prunes domain values that have no support in neighbor domains for each arc `(Xi, Xj)`.

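- `revise(domains, Xi, Xj)`: remove `v` from `Xi`'s domain if `Xj`'s domain has no value different from `v`. Under `!=` constraints this happens in practice only when `Xj`'s domain has been reduced to the single value `v`.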
- Queue all arcs initially; when `Xi` changes, re-enqueue `(Xk, Xi)` for all other neighbors `Xk`.

`ac3` returns a success flag plus two counters (arcs checked and values pruned) so the amount of work is visible.
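The pruning rule and queue discipline above can be traced on a tiny problem before applying them to a full grid. This is a minimal sketch on a hypothetical three-variable problem with pairwise `!=` constraints; `toy_revise`, `toy_ac3`, and the `neighbors` map are illustrative names, not part of the notebook's code.

```python
from collections import deque

def toy_revise(domains, xi, xj):
    """Drop values of xi that have no different value left in xj's domain."""
    removed = {x for x in domains[xi] if all(y == x for y in domains[xj])}
    domains[xi] -= removed
    return bool(removed)

def toy_ac3(domains, neighbors):
    """Make every arc consistent; return False if some domain is emptied."""
    queue = deque((xi, xj) for xi in domains for xj in neighbors[xi])
    while queue:
        xi, xj = queue.popleft()
        if toy_revise(domains, xi, xj):
            if not domains[xi]:
                return False
            # xi shrank, so arcs pointing at xi must be rechecked.
            queue.extend((xk, xi) for xk in neighbors[xi] - {xj})
    return True

# Hypothetical problem: X, Y, Z must all take different values.
domains = {"X": {1}, "Y": {1, 2}, "Z": {1, 2, 3}}
neighbors = {"X": {"Y", "Z"}, "Y": {"X", "Z"}, "Z": {"X", "Y"}}

ok = toy_ac3(domains, neighbors)
print(ok, domains)  # True {'X': {1}, 'Y': {2}, 'Z': {3}}
```

Here `X` is fixed, which forces `Y` to `2`, and only once `Y` has become a singleton can `2` be pruned from `Z`. That second step is why arcs into a changed variable are re-enqueued.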

In [15]:
def revise(domains, Xi, Xj):
    removed = set()
    for x in set(domains[Xi]):
        # Find any y in Xj's domain that satisfies x != y
        if not any(y != x for y in domains[Xj]):
            domains[Xi].discard(x)
            removed.add(x)
    return removed

def ac3(domains):
    queue = deque(ARCS)
    checks = 0
    pruned = 0
    while queue:
        Xi, Xj = queue.popleft()
        checks += 1
        removed = revise(domains, Xi, Xj)
        if removed:
            pruned += len(removed)
            if len(domains[Xi]) == 0:
                return False, checks, pruned
            for Xk in PEERS[Xi] - {Xj}:
                queue.append((Xk, Xi))
    return True, checks, pruned


## 3) Search: Backtracking with Heuristics

We’ll implement three variants:
- Plain Backtracking (BT)
- BT + Forward-Checking (FC)
- BT + MAC (run AC-3 on affected arcs after each assignment)
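
To see how forward checking and MAC differ, the following sketch uses a hypothetical three-variable chain. The helper names (`forward_check_toy`, `mac_toy`) and the toy problem are assumptions made for illustration, not the notebook's solver. The MAC part relies on the fact that under `!=` constraints only a singleton domain can prune a neighbor.

```python
def forward_check_toy(domains, neighbors, var, val):
    """Assign var=val and prune val from var's direct neighbors only."""
    new = {v: set(d) for v, d in domains.items()}  # work on a copy
    new[var] = {val}
    for n in neighbors[var]:
        new[n].discard(val)
    return new

def mac_toy(domains, neighbors, var, val):
    """Forward check, then keep propagating singletons until nothing changes."""
    new = forward_check_toy(domains, neighbors, var, val)
    changed = True
    while changed:
        changed = False
        for v, d in new.items():
            if len(d) != 1:
                continue
            (only,) = d
            for n in neighbors[v]:
                if only in new[n]:
                    new[n].discard(only)
                    changed = True
    return new

# Hypothetical chain X - Y - Z with "!=" on each link (X and Z are unrelated).
neighbors = {"X": {"Y"}, "Y": {"X", "Z"}, "Z": {"Y"}}
domains = {"X": {1, 2}, "Y": {1, 2}, "Z": {2, 3}}

fc = forward_check_toy(domains, neighbors, "X", 1)
mac = mac_toy(domains, neighbors, "X", 1)
print(fc["Z"], mac["Z"])  # {2, 3} {3}
```

Forward checking stops after pruning `Y`, while MAC notices that `Y` is now fixed to `2` and removes `2` from `Z` as well.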

Heuristics:
- MRV: choose the unassigned variable with the smallest domain.
- Degree heuristic: break ties by picking the variable with the most unassigned neighbors.
- LCV: try values that eliminate the fewest neighbor values first.
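
As a rough illustration of MRV and LCV (the degree tie-break is left out for brevity), here is a sketch on a hypothetical four-variable problem in which every pair of variables must differ. `pick_mrv`, `order_lcv`, and the toy domains are assumptions for this example only.

```python
def pick_mrv(domains):
    """MRV: among unassigned variables, pick the one with the fewest values."""
    unassigned = [v for v in domains if len(domains[v]) > 1]
    return min(unassigned, key=lambda v: len(domains[v]))

def order_lcv(domains, neighbors, var):
    """LCV: try first the values that appear in the fewest neighbor domains."""
    def ruled_out(val):
        return sum(val in domains[n] for n in neighbors[var])
    return sorted(domains[var], key=ruled_out)

# Hypothetical problem: P, Q, R, S are all mutually constrained by "!=".
domains = {"P": {1, 2, 3}, "Q": {2, 3}, "R": {3}, "S": {1, 3, 4}}
neighbors = {v: set("PQRS") - {v} for v in "PQRS"}

var = pick_mrv(domains)
print(var, order_lcv(domains, neighbors, var))  # Q [2, 3]
```

`Q` is chosen because it has only two candidates. Value `2` is tried first because it would be removed from one neighbor (`P`), whereas `3` appears in all three neighbor domains.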

In [16]:
def select_unassigned_variable(domains):
    unassigned = [v for v in VARS if len(domains[v]) > 1]
    # MRV
    m = min(len(domains[v]) for v in unassigned)
    candidates = [v for v in unassigned if len(domains[v]) == m]
    if len(candidates) == 1:
        return candidates[0]
    # Degree heuristic
    def degree(v):
        return sum(1 for n in PEERS[v] if len(domains[n]) > 1)
    return max(candidates, key=degree)

def order_values_lcv(domains, var):
    def conflicts(val):
        # How many neighbor values would be eliminated by assigning var=val?
        count = 0
        for n in PEERS[var]:
            if val in domains[n]:
                count += 1
        return count
    return sorted(domains[var], key=conflicts)

def deep_copy_domains(domains):
    return {v: set(vals) for v, vals in domains.items()}

class Metrics:
    def __init__(self, ac3_checks=0, ac3_pruned=0):
        self.assignments = 0
        self.backtracks = 0
        self.inferences = 0
        self.ac3_checks = ac3_checks
        self.ac3_pruned = ac3_pruned

def forward_check(domains, var, val):
    inferences = []
    for n in PEERS[var]:
        if val in domains[n]:
            if len(domains[n]) == 1:
                return None  # failure
            domains[n].remove(val)
            inferences.append((n, val))
    return inferences

def undo_inferences(domains, inferences):
    if not inferences: return
    for v, val in inferences:
        domains[v].add(val)

def backtrack(domains, use_fc=False, use_mac=False, metrics=None):
    if is_assignment_complete(domains):
        return domains
    var = select_unassigned_variable(domains)
    for val in order_values_lcv(domains, var):
        if consistent(domains, var, val):
            # Try assign
            snapshot = deep_copy_domains(domains)
            domains[var] = {val}
            if metrics: metrics.assignments += 1

            inferences = None
            if use_fc:
                inferences = forward_check(domains, var, val)
                if inferences is None:
                    domains.update(snapshot)
                    continue
                if metrics: metrics.inferences += len(inferences)

            if use_mac:
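                # Run AC-3 restricted to arcs affected by this assignment.
                # Forward checking (if enabled) has already removed val from the
                # peers, so also seed arcs into every peer it pruned; otherwise
                # revise() finds nothing new and MAC never propagates further.
                queue = deque([(n, var) for n in PEERS[var]])
                if inferences:
                    queue.extend((k, n) for n, _ in inferences for k in PEERS[n] - {var})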
                checks = pruned = 0
                while queue:
                    Xi, Xj = queue.popleft()
                    checks += 1
                    removed = revise(domains, Xi, Xj)
                    if removed:
                        pruned += len(removed)
                        if len(domains[Xi]) == 0:
                            # Restore and try next value
                            domains.update(snapshot)
                            break
                        for Xk in PEERS[Xi] - {Xj}:
                            queue.append((Xk, Xi))
                else:
                    if metrics:
                        metrics.ac3_checks += checks
                        metrics.ac3_pruned += pruned
                    result = backtrack(domains, use_fc, use_mac, metrics)
                    if result: return result
                    # Restore on backtrack
                    domains.update(snapshot)
                    if metrics: metrics.backtracks += 1
                    continue
                # If MAC failed, move to next value
                if metrics: metrics.backtracks += 1
                continue

            # Plain BT or FC without MAC
            result = backtrack(domains, use_fc, use_mac, metrics)
            if result: return result

            # Undo if failed
            if use_fc and inferences is not None:
                undo_inferences(domains, inferences)
            domains.update(snapshot)
            if metrics: metrics.backtracks += 1
    return None


## 4) Convenience: Run and Compare Solvers

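Use `run_solver` to apply a variant to a puzzle and measure time and effort. Every variant starts from the same AC-3 preprocessing pass, and `time_sec` covers only the search that follows it.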
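
Single timings of short runs can be noisy. One possible way to steady them, shown here as a sketch rather than as part of the notebook, is to repeat a call and keep the fastest run. `best_of` is a hypothetical helper, and it is demonstrated on `sorted` so that the block runs on its own.

```python
from time import perf_counter

def best_of(fn, *args, repeats=5, **kwargs):
    """Call fn repeatedly; return its last result and the fastest wall time."""
    best = float("inf")
    result = None
    for _ in range(repeats):
        start = perf_counter()
        result = fn(*args, **kwargs)
        best = min(best, perf_counter() - start)
    return result, best

# Stand-in workload so the sketch is runnable without the solver.
result, seconds = best_of(sorted, [3, 1, 2], repeats=3)
print(result, f"{seconds:.6f}")
```

In the notebook the same helper could wrap a solver call, for example `best_of(run_solver, puzzle, method='BT')`. The measured time would then include AC-3 preprocessing as well as search.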

In [17]:
def run_solver(grid_str, method='BT'):
    """Solve grid_str with method 'BT', 'BT+FC', or 'BT+MAC' and return metrics."""
    base = parse_grid(grid_str)
    # AC-3 preprocessing pass, shared by every method
    d0 = deep_copy_domains(base)
    ok, checks, pruned = ac3(d0)

    if not ok:
        return {'status': 'contradiction after AC3', 'ac3_checks': checks, 'ac3_pruned': pruned}

    metrics = Metrics(checks, pruned)
    start = perf_counter()
    if method == 'BT':
        sol = backtrack(d0, use_fc=False, use_mac=False, metrics=metrics)
    elif method == 'BT+FC':
        sol = backtrack(d0, use_fc=True, use_mac=False, metrics=metrics)
    elif method == 'BT+MAC':
        sol = backtrack(d0, use_fc=True, use_mac=True, metrics=metrics)
    else:
        raise ValueError(f'Unknown method: {method!r}')
    elapsed = perf_counter() - start
    return {
        'solution': sol,
        'time_sec': elapsed,
        'assignments': metrics.assignments,
        'backtracks': metrics.backtracks,
        'inferences': metrics.inferences,
        'ac3_checks': metrics.ac3_checks,
        'ac3_pruned': metrics.ac3_pruned,
    }

def print_result(title, result):
    print(f'=== {title} ===')
    if result.get('solution'):
        display(result['solution'])
        print('time_sec     :', f"{result['time_sec']:.6f}")
        print('assignments  :', result['assignments'])
        print('backtracks   :', result['backtracks'])
        print('inferences   :', result['inferences'])
        print('ac3_checks   :', result['ac3_checks'])
        print('ac3_pruned   :', result['ac3_pruned'])
    else:
        print(result)


## 5) Try It: Input a Sudoku

- Provide a string with 81 cell characters: digits `1-9` for givens and `0`, `.`, or `-` for empty cells.
- The example below is a moderately hard puzzle.

You can run each solver and compare results and performance.
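
Before solving a hand-typed puzzle, it can help to confirm that the givens do not already clash. The sketch below is an optional pre-check, not something the notebook's solver requires. `find_given_conflicts` is a hypothetical helper that uses the same cell characters as described above.

```python
def find_given_conflicts(grid_str):
    """Return (unit_name, digit) pairs where a given digit repeats in a unit."""
    cells = [c for c in grid_str if c in "0.-123456789"]
    if len(cells) != 81:
        raise ValueError(f"expected 81 cells, found {len(cells)}")
    rows = [cells[9 * r: 9 * r + 9] for r in range(9)]
    units = {}
    for i in range(9):
        units[f"row {i + 1}"] = rows[i]
        units[f"column {i + 1}"] = [rows[r][i] for r in range(9)]
        br, bc = 3 * (i // 3), 3 * (i % 3)  # top-left corner of box i
        units[f"box {i + 1}"] = [rows[br + r][bc + c]
                                 for r in range(3) for c in range(3)]
    conflicts = []
    for name, unit in units.items():
        givens = [c for c in unit if c in "123456789"]
        for d in sorted(set(givens)):
            if givens.count(d) > 1:
                conflicts.append((name, d))
    return conflicts

print(find_given_conflicts("." * 81))         # []
print(find_given_conflicts("11" + "." * 79))  # [('row 1', '1'), ('box 1', '1')]
```

An empty list means no given digit is repeated within a row, column, or box. It does not guarantee that the puzzle has a solution.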

In [19]:
# You can replace this with your own puzzle string (81 chars)
puzzle = (
    '..3.2.6..'
    '9..3.5..1'
    '..18.64..'
    '..81.29..'
    '7.......8'
    '..67.82..'
    '..26.95..'
    '8..2.3..9'
    '..5.1.3..'
)
print('Initial puzzle:')
display(parse_grid(puzzle))

print('Solving with Backtracking (BT)...')
res_bt = run_solver(puzzle, method='BT')
print_result('Backtracking (BT)', res_bt)

print('Solving with Backtracking + Forward-Checking (BT+FC)...')
res_fc = run_solver(puzzle, method='BT+FC')
print_result('Backtracking + Forward-Checking (BT+FC)', res_fc)

print('Solving with Backtracking + MAC (BT+MAC)...')
res_mac = run_solver(puzzle, method='BT+MAC')
print_result('Backtracking + MAC (BT+MAC)', res_mac)


Initial puzzle:
123456789 123456789     3     |123456789     2     123456789 |    6     123456789 123456789 
    9     123456789 123456789 |    3     123456789     5     |123456789 123456789     1     
123456789 123456789     1     |    8     123456789     6     |    4     123456789 123456789 
------------------------------+------------------------------+------------------------------
123456789 123456789     8     |    1     123456789     2     |    9     123456789 123456789 
    7     123456789 123456789 |123456789 123456789 123456789 |123456789 123456789     8     
123456789 123456789     6     |    7     123456789     8     |    2     123456789 123456789 
------------------------------+------------------------------+------------------------------
123456789 123456789     2     |    6     123456789     9     |    5     123456789 123456789 
    8     123456789 123456789 |    2     123456789     3     |123456789 123456789     9     
123456789 123456789     5     |123456789     1     123456789 |    3     123456789 123456789 

In [20]:
# You can replace this with your own puzzle string (81 chars)
puzzle = (
    '...1..7.2'
    '.3.95....'
    '..1..2..3'
    '59....3.1'
    '.2.....7.'
    '7.3....98'
    '8..2..1..'
    '....85.6.'
    '6.5..9...'
)
print('Initial puzzle:')
display(parse_grid(puzzle))
 
print('Solving with Backtracking (BT)...')
res_bt = run_solver(puzzle, method='BT')
print_result('Backtracking (BT)', res_bt)

print('Solving with Backtracking + Forward-Checking (BT+FC)...')
res_fc = run_solver(puzzle, method='BT+FC')
print_result('Backtracking + Forward-Checking (BT+FC)', res_fc)

print('Solving with Backtracking + MAC (BT+MAC)...')
res_mac = run_solver(puzzle, method='BT+MAC')
print_result('Backtracking + MAC (BT+MAC)', res_mac)


Initial puzzle:
123456789 123456789 123456789 |    1     123456789 123456789 |    7     123456789     2     
123456789     3     123456789 |    9         5     123456789 |123456789 123456789 123456789 
123456789 123456789     1     |123456789 123456789     2     |123456789 123456789     3     
------------------------------+------------------------------+------------------------------
    5         9     123456789 |123456789 123456789 123456789 |    3     123456789     1     
123456789     2     123456789 |123456789 123456789 123456789 |123456789     7     123456789 
    7     123456789     3     |123456789 123456789 123456789 |123456789     9         8     
------------------------------+------------------------------+------------------------------
    8     123456789 123456789 |    2     123456789 123456789 |    1     123456789 123456789 
123456789 123456789 123456789 |123456789     8         5     |123456789     6     123456789 
    6     123456789     5     |123456789 123456789     9     |123456789 123456789 123456789 