# Sudoku solver

In [58]:
# Row labels (letters A-I) and column labels (digits 1-9) of the 9x9 grid.
rows, cols = 'ABCDEFGHI', '123456789'

In [59]:
def cross(a, b):
    """Return every concatenation ``s + t`` with ``s`` taken from *a* and ``t`` from *b*.

    The outer loop runs over *a*, so results are grouped by element of *a*,
    e.g. ``cross('AB', '12') == ['A1', 'A2', 'B1', 'B2']``.
    """
    combos = []
    for head in a:
        for tail in b:
            combos.append(head + tail)
    return combos

## Set up boxes, units and peers

Let's start by naming the elements, formed by these rows and columns, that are relevant to solving a Sudoku:

- The individual squares at the intersection of rows and columns will be called __boxes__. These boxes will have labels 'A1', 'A2', ..., 'I9'.
- The complete rows, columns, and 3x3 squares will be called __units__. Thus, each unit is a set of 9 boxes, and there are 27 units in total (9 rows, 9 columns and 9 squares).
- For a particular box (such as 'A1'), its __peers__ are all other boxes that share at least one unit with it (namely, those in the same row, column, or 3x3 square). Every box has 20 peers: 8 in its row, 8 in its column, and 4 more in its square.

In [60]:
# All 81 box labels in row-major order: 'A1', 'A2', ..., 'I9'.
boxes = [row + col for row in rows for col in cols]

In [61]:
# One unit per row ('A1'..'A9', ...) and one per column ('A1'..'I1', ...).
row_units = [cross(letter, cols) for letter in rows]

column_units = [cross(rows, digit) for digit in cols]

# 3x3 squares: every pairing of a band of three rows with a stack of three
# columns, bands in the outer loop (top-left square first, bottom-right last).
square_units = [
    cross(band, stack)
    for band in (rows[0:3], rows[3:6], rows[6:9])
    for stack in (cols[0:3], cols[3:6], cols[6:9])
]

# All 27 units: 9 rows, then 9 columns, then 9 squares.
unitlist = row_units + column_units + square_units

In [62]:
#units_dict = dict((b, [u for u in unit_list if b in u]) for b in boxes)

In [63]:
def create_units_dict():
    """Map every box to the list of units that contain it.

    Units appear in ``unitlist`` order, so each box maps to
    ``[its row unit, its column unit, its square unit]``. The stored lists are
    the unit objects from ``unitlist`` themselves, not copies.
    """
    mapping = {box: [] for box in boxes}
    for unit in unitlist:
        for member in unit:
            mapping[member].append(unit)
    return mapping


units_dict = create_units_dict()

In [64]:
#peers_dict = dict((s, set(sum(units_dict[s], [])) - set([s])) for s in boxes)

In [65]:
def create_peers_dict():
    """Map every box to the set of its peers.

    A box's peers are all boxes in any of its units (from ``units_dict``),
    excluding the box itself.
    """
    return {box: set().union(*units_dict[box]) - {box} for box in boxes}


peers_dict = create_peers_dict()

In [66]:
def display(values):
    """Print *values* as a 9x9 grid with separators between the 3x3 squares.

    Every cell is centred in a column one character wider than the longest
    value, so unsolved boxes with several candidates still line up.
    """
    cell_width = max(len(values[box]) for box in boxes) + 1
    separator = '+'.join(['-' * (3 * cell_width)] * 3)

    for row in rows:
        pieces = []
        for col in cols:
            pieces.append(values[row + col].center(cell_width))
            # Vertical bar after the 3rd and 6th columns.
            if col in '36':
                pieces.append('|')
        print(''.join(pieces))
        # Horizontal rule after the 3rd and 6th rows.
        if row in 'CF':
            print(separator)

In [67]:
# Puzzles as 81-character strings in row-major order; '.' is an empty box.
# Written one 9-character row per line so the grid is readable at a glance.
puzzle1 = ('..3.2.6..'
           '9..3.5..1'
           '..18.64..'
           '..81.29..'
           '7.......8'
           '..67.82..'
           '..26.95..'
           '8..2.3..9'
           '..5.1.3..')
puzzle2 = ('4.....8.5'
           '.3.......'
           '...7.....'
           '.2.....6.'
           '....8.4..'
           '....1....'
           '...6.3.7.'
           '5..2.....'
           '1.4......')

In [68]:
def grid_values(grid):
    """Convert a puzzle string into a ``{box: value}`` dict.

    Characters are assigned to boxes in row-major order ('A1', 'A2', ..., 'I9').
    Characters are stored unchanged; '.' (an empty box) is expanded to all
    candidates later by ``init_state``.

    Args:
        grid: a string with exactly one character per box (81 characters).

    Returns:
        dict mapping each box label to its single-character value.

    Raises:
        ValueError: if ``grid`` does not have exactly one character per box.
    """
    # Explicit check rather than ``assert``: assertions are removed when Python
    # runs with -O, which would let a malformed grid be silently truncated by zip().
    if len(grid) != len(boxes):
        raise ValueError(
            'expected {} characters, got {}'.format(len(boxes), len(grid))
        )
    return dict(zip(boxes, grid))

In [69]:
def init_state(values):
    """Return a copy of *values* in which every empty box ('.') holds all candidates.

    Boxes that already have a value are left as they are; the input dict is not
    modified.
    """
    all_digits = '123456789'
    state = values.copy()
    state.update({box: all_digits for box in boxes if state[box] == '.'})
    return state

In [70]:
def eliminate(values):
    """Remove each solved box's digit from the candidates of all its peers.

    A box counts as solved when its value is a single character. The set of
    solved boxes is taken from the input; the input dict is not modified.
    """
    reduced = values.copy()
    for box in [b for b, digits in values.items() if len(digits) == 1]:
        # ``box`` is never its own peer, so its value is stable while its peers are updated.
        digit = reduced[box]
        for peer in peers_dict[box]:
            reduced[peer] = reduced[peer].replace(digit, '')
    return reduced

In [71]:
def only_choice(values):
    """Assign digits that have exactly one possible box within a unit.

    For each unit, find the boxes whose candidates include each digit (based on
    the input *values*). When a digit fits in only one box, that box is set to the
    digit in the returned copy.
    """
    digits = '123456789'
    updated = values.copy()
    for unit in unitlist:
        places = {digit: [] for digit in digits}
        for box in unit:
            for candidate in values[box]:
                places[candidate].append(box)
        for digit, spots in places.items():
            if len(spots) == 1:
                (only_box,) = spots
                updated[only_box] = digit
    return updated

In [72]:
def reduce_puzzle(values):
    """Apply ``eliminate`` then ``only_choice`` repeatedly until progress stalls.

    Returns:
        The reduced values dict once an iteration solves no new boxes, or False
        as soon as any box is left with no candidates.
    """
    while True:
        solved_before = sum(len(values[box]) == 1 for box in values)
        values = only_choice(eliminate(values))
        solved_after = sum(len(values[box]) == 1 for box in values)

        # An empty candidate string means the current state is contradictory.
        if any(len(values[box]) == 0 for box in values):
            return False
        if solved_before == solved_after:
            return values

In [73]:
def search(values):
    """Solve by constraint propagation plus depth-first search.

    First reduces ``values`` with ``reduce_puzzle``. If boxes remain unsolved,
    picks an unsolved box with the fewest candidates and recursively tries each
    of its candidates in turn.

    Returns:
        A values dict in which every box has exactly one digit, or False if no
        branch reachable from ``values`` succeeds.
    """
    values = reduce_puzzle(values)
    if values is False:
        return False
    if all(len(values[b]) == 1 for b in boxes):
        return values

    # Branch on the box with the fewest remaining candidates to keep the tree small.
    _, box = min((len(values[b]), b) for b in boxes if len(values[b]) > 1)

    for digit in values[box]:
        trial = values.copy()
        trial[box] = digit
        result = search(trial)
        if result:
            return result

    # Every candidate for ``box`` failed. Return False explicitly instead of
    # falling off the end (implicit None) so all failure paths use one sentinel.
    return False

In [74]:
def solve(values):
    """Solve a puzzle given as a ``{box: value}`` dict with '.' for empty boxes.

    Returns the solved dict, or a falsy value if no solution is found.
    """
    return search(init_state(values))

In [75]:
values = grid_values(puzzle1)

print("Initial puzzle state...")
display(values)

values = solve(values)

print()
print("Final puzzle state...")
# solve() yields a falsy value for an unsolvable puzzle; display() would raise
# TypeError on it (it indexes ``values[...]``), so report the failure instead.
if values:
    display(values)
else:
    print("No solution found.")

Initial puzzle state...
. . 3 |. 2 . |6 . . 
9 . . |3 . 5 |. . 1 
. . 1 |8 . 6 |4 . . 
------+------+------
. . 8 |1 . 2 |9 . . 
7 . . |. . . |. . 8 
. . 6 |7 . 8 |2 . . 
------+------+------
. . 2 |6 . 9 |5 . . 
8 . . |2 . 3 |. . 9 
. . 5 |. 1 . |3 . . 

Final puzzle state...
4 8 3 |9 2 1 |6 5 7 
9 6 7 |3 4 5 |8 2 1 
2 5 1 |8 7 6 |4 9 3 
------+------+------
5 4 8 |1 3 2 |9 7 6 
7 2 9 |5 6 4 |1 3 8 
1 3 6 |7 9 8 |2 4 5 
------+------+------
3 7 2 |6 8 9 |5 1 4 
8 1 4 |2 5 3 |7 6 9 
6 9 5 |4 1 7 |3 8 2 


In [76]:
values = grid_values(puzzle2)

print("Initial puzzle state...")
display(values)

values = solve(values)

print()
print("Final puzzle state...")
# solve() yields a falsy value for an unsolvable puzzle; display() would raise
# TypeError on it (it indexes ``values[...]``), so report the failure instead.
if values:
    display(values)
else:
    print("No solution found.")

Initial puzzle state...
4 . . |. . . |8 . 5 
. 3 . |. . . |. . . 
. . . |7 . . |. . . 
------+------+------
. 2 . |. . . |. 6 . 
. . . |. 8 . |4 . . 
. . . |. 1 . |. . . 
------+------+------
. . . |6 . 3 |. 7 . 
5 . . |2 . . |. . . 
1 . 4 |. . . |. . . 

Final puzzle state...
4 1 7 |3 6 9 |8 2 5 
6 3 2 |1 5 8 |9 4 7 
9 5 8 |7 2 4 |3 1 6 
------+------+------
8 2 5 |4 3 7 |1 6 9 
7 9 1 |5 8 6 |4 3 2 
3 4 6 |9 1 2 |7 5 8 
------+------+------
2 8 9 |6 4 3 |5 7 1 
5 7 3 |2 9 1 |6 8 4 
1 6 4 |8 7 5 |2 9 3 
