# Sudoku Logical Solver Tutorial

Basically, the main idea is, use the following 2-step method to fill all the empty cells in a sudoku:
step 1: in each row/column/square, calculate the candidate list of the current state
step 2: simplify the candidate list by doing the following:
        check if there is this condition satisfied: if in n empty cells, only n total candidates are possible, then delete other candidates in those cells, meanwhile delete the n such candidates in other empty cells of this row/column/square. Then check for each digit from 1 to 9, if it only appears in the candidate list of one single empty cell for each row/column/square, if so, then fill in it. if there's also empty cell that only has one possible candidates,then fill it in.
Loop step 1 and 2 through all the rows and columns and squares until all cells are filled


In [1]:
import numpy as np
import pandas as pd
from itertools import combinations
from collections import defaultdict

In [2]:
# Example Sudoku board (medium difficulty)
sudoku_medium = [
    [None, 2, None, 6, None, 8, None, None, None],
    [5, 8, None, None, None, 9, 7, None, None],
    [None, None, None, None, 4, None, None, None, None],
    [3, 7, None, None, None, None, 5, None, None],
    [6, None, None, None, None, None, None, None, 4],
    [None, None, 8, None, None, None, None, 1, 3],
    [None, None, None, None, 2, None, None, None, None],
    [None, None, 9, 8, None, None, None, 3, 6],
    [None, None, None, 3, None, 6, None, 9, None],
]

In [3]:
sudoku_from_image = [
    [9, 1, None, None, None, 7, None, None, None],
    [None, 7, None, 1, None, 3, None, None, 8],
    [6, None, None, None, None, None, 4, None, None],
    [None, None, 2, None, None, None, None, 8, None],
    [None, None, None, None, 5, None, 7, 3, 4],
    [None, None, None, None, None, None, None, 1, None],
    [3, 4, 7, 2, None, None, 8, None, None],
    [None, None, None, None, None, 9, None, 6, None],
    [None, None, None, 8, None, None, None, None, 7],
]

In [4]:
# print the result as form of sudoko
def print_sudoku(board):
    """Prints the Sudoku board in a readable format."""
    for row in board:
        print(" ".join(str(num) if num is not None else "." for num in row))

In [5]:
def if_valid(table, row, col, num):
    # check if number if valid within the row and column
    for i in range(9):
      if table[row][i]==num or table[i][col]==num:
         return False
    # check if valid in the 3*3 square
    start_row, start_col = 3 * (row // 3), 3 * (col// 3)
    for i in range(3):
        for j in range(3):
            if table[start_row + i][start_col + j] == num:
                return False
    return True

To apply techniques later, we need to take out three different candidates_dicts for each position, each corresponds to the row/col/square of that position.

In [6]:
# Find possible digits for each position
def get_candidates(table):
    candidates_dict={}
    for row in range(9):
        for col in range(9):
            if table[row][col] is None:
                candidates = set()
                for num in range(1,10):
                    if if_valid(table,row,col,num):
                        candidates.add(num)
                candidates_dict[(row,col)]=candidates
    
    return candidates_dict

In [7]:
# take out the row/col/square candidate dictionary in one function
def select_candidates(pos, candidates_dict, axis):
    """
    attribute "axis" denotes which candidates dict in this iteration we are focusing on
            axis == "r" : take out the row candidates
            axis == "c" : take out the column candidates
            axis == "s" : take out the square candidates
    """
    row,col=pos
    selected_candidates={}

    if axis == "r":
        selected_candidates = {k: v for k, v in candidates_dict.items() if k[0] == row}

    elif axis == "c":
        selected_candidates = {k: v for k, v in candidates_dict.items() if k[1] == col}

    elif axis == "s":
        start_row = 3 * (row // 3)
        start_col = 3 * (col // 3)
        selected_candidates = {
            k: v for k, v in candidates_dict.items()
            if start_row <= k[0] < start_row + 3 and start_col <= k[1] < start_col + 3
        }

    return selected_candidates


In [8]:
# naked subset elimination
def naked_subset_elimination(candidates_dict):
    
    keys = list(candidates_dict.keys()) 
    N = len(candidates_dict)
    
    for n in range(1, N):
        for combo in combinations(keys, n):
            union_set = set()

            for key in combo:
                union_set.update(candidates_dict[key])
            
            if len(union_set) == n:
                rest = [k for k in keys if k not in combo]
                for rest_key in rest:
                    candidates_dict[rest_key] = candidates_dict[rest_key] - union_set
 
    return candidates_dict

In [9]:
# hidden subset elimination
def hidden_subset_elimination(candidates_dict):
    num_positions = defaultdict(list)

    # decide which position a num appears
    for pos, candidates in candidates_dict.items():
        for num in candidates:
            num_positions[num].append(pos)

    # to check if there are hidden single/pair/triple/quad
    for n in range(1,5):
        for nums in combinations(num_positions.keys(), n):
            positions = set()
            for num in nums:
                positions.update(num_positions[num])
            if len(positions) == n:
                for pos in positions:
                    candidates_dict[pos] = candidates_dict[pos].intersection(set(nums))
    
    return candidates_dict

In [10]:
# add pointing (notice that this technique need to check square candidates first, then
# to deal with the rows or columns)
def pointing_elimination(pos, candidates_dict):
    square_candidates = select_candidates(pos, candidates_dict, axis="s")

    for num in range(1, 10):
        positions_with_num = [p for p, v in square_candidates.items() if num in v]

        if len(positions_with_num) <= 1:
            continue

        rows = {r for r, _ in positions_with_num}
        cols = {c for _, c in positions_with_num}

        if len(rows) == 1:
            target_row = next(iter(rows))
            for col in range(9):
                if (target_row, col) not in square_candidates and (target_row, col) in candidates_dict:
                    candidates_dict[(target_row, col)] -= {num}

        elif len(cols) == 1:
            target_col = next(iter(cols))
            for row in range(9):
                if (row, target_col) not in square_candidates and (row, target_col) in candidates_dict:
                    candidates_dict[(row, target_col)] -= {num}

    return candidates_dict

In [11]:
# add x wing elimination
def xwing_elimination(candidates_dict):
    
    # scan the rows for xwing structures
    for num in range(1,10):
        
        # start with an empty list 
        row_to_cols = defaultdict(list)

        for (row, col), candidates in candidates_dict.items():
            if num in candidates:
                row_to_cols[row].append(col)

        # try to find if there is a x wing exists for the given number
        rows = list(row_to_cols.keys())

        for i in range(len(rows)):

            for j in range(i + 1, len(rows)):
                r1, r2 = rows[i], rows[j]
                c1s = row_to_cols[r1]
                c2s = row_to_cols[r2]

                if len(c1s) == 2 and c1s == c2s:
                    
                    col1, col2 = c1s

                    for row in range(9):
                        if row != r1 and row != r2:
                            for col in [col1, col2]:
                                if (row, col) in candidates_dict:
                                    candidates_dict[(row, col)] -= {num}

    # scan the columns for xwing structures
    for num in range(1, 10):
        col_to_rows = defaultdict(list)

        for (row, col), candidates in candidates_dict.items():
            if num in candidates:
                col_to_rows[col].append(row)

        cols = list(col_to_rows.keys())
        for i in range(len(cols)):
            for j in range(i + 1, len(cols)):
                c1, c2 = cols[i], cols[j]
                r1s = col_to_rows[c1]
                r2s = col_to_rows[c2]

                if len(r1s) == 2 and r1s == r2s:
                    row1, row2 = r1s
                    for col in range(9):
                        if col != c1 and col != c2:
                            for row in [row1, row2]:
                                if (row, col) in candidates_dict:
                                    candidates_dict[(row, col)] -= {num}
    
    
    return candidates_dict

In [12]:
# add claiming (如果一个数字 d 在某一行（或列）中，所有的候选格都集中在同一个宫格（3×3）中，
# 那么我们就可以在这个宫格中排除掉其他格子中的该数字 d)
def claiming_elimination(candidates_dict):

    for num in range(1,10):
        for row in range(9):
            for row in range(9):
                positions = [(row, col) for col in range(9) if (row, col) in candidates_dict and num in candidates_dict[(row, col)]]
            
            if len(positions) >= 2:
                blocks = {(r // 3, c // 3) for r, c in positions}
                if len(blocks) == 1:
                    block_row, block_col = next(iter(blocks))
                    for r in range(3 * block_row, 3 * block_row + 3):
                        for c in range(3 * block_col, 3 * block_col + 3):
                            if r != row and (r, c) in candidates_dict:
                                candidates_dict[(r, c)] -= {num}

        for col in range(9):
            positions = [(row, col) for row in range(9) if (row, col) in candidates_dict and num in candidates_dict[(row, col)]]
            if len(positions) >= 2:
                blocks = {(r // 3, c // 3) for r, c in positions}
                if len(blocks) == 1:
                    block_row, block_col = next(iter(blocks))
                    for r in range(3 * block_row, 3 * block_row + 3):
                        for c in range(3 * block_col, 3 * block_col + 3):
                            if c != col and (r, c) in candidates_dict:
                                candidates_dict[(r, c)] -= {num}
    
    return candidates_dict

In [13]:
# finally, the easiest one: for cells with a single sole candidate, just fill it in
def fill_sole_candidate(table,candidates_dict):
    for (row, col), candidates in list(candidates_dict.items()):
        if len(candidates) == 1:
            value = next(iter(candidates))
            if table[row][col] is None:
                table[row][col] = value
                del candidates_dict[(row, col)]

    return table

In [14]:
def solve_sudoku(table):
    changed = True

    while changed:
        changed = False

        table_before = [row.copy() for row in table]
        candidates_dict = get_candidates(table)
        if not candidates_dict:
            break

        for axis in ["r", "c", "s"]:
            for i in range(9):
                # determine starting point
                if axis == "r":
                    pos = (i, 0)
                elif axis == "c":
                    pos = (0, i)
                else:
                    pos = (3 * (i // 3), 3 * (i % 3))

                # take out sub area candidates 
                sub_candidates = select_candidates(pos, candidates_dict, axis)

                if axis == "s":
                    candidates_dict = pointing_elimination(pos, candidates_dict)

                sub_candidates = naked_subset_elimination(sub_candidates)
                sub_candidates = hidden_subset_elimination(sub_candidates)

                # update main candidate
                for k in sub_candidates:
                    candidates_dict[k] = sub_candidates[k]

        # add xwing elimination
        candidates_dict = xwing_elimination(candidates_dict)

        # add claiming
        candidates_dict=claiming_elimination(candidates_dict)

        # try to fill in the only candidate
        table = fill_sole_candidate(table, candidates_dict)

        if table != table_before:
            changed = True

    return table

In [15]:
table=solve_sudoku(sudoku_from_image)
print_sudoku(table)

9 1 8 . . 7 . . .
. 7 . 1 . 3 . . 8
6 . . . . 8 4 7 1
. . 2 . . . . 8 .
. . . . 5 2 7 3 4
. . . . 8 . . 1 .
3 4 7 2 . . 8 . .
. . . . . 9 . 6 .
. . . 8 3 5 . 4 7
