# Solving Sudoku using Constraint Satisfaction

# Problem Definition

+ Sudoku is a 9x9 grid of numbers.
+ Rows are labelled as A through I and columns are labelled as 1 through 9.
+ The intersection of a row and a column is called a square; each square is filled with a digit from the set {1,2,3,4,5,6,7,8,9}.
+ A collection of 9 squares forming a column, a row, or a 3x3 box is called a unit. Every square belongs to exactly 3 units: its column, its row, and its box.
+ Squares that share a unit are called peers. By this definition, every square has exactly 20 peers (8 in its column, 8 in its row, and 4 more in its box).

A Sudoku puzzle is solved when the squares in each unit are filled with a permutation of the digits 1 through 9.


## Creation step

In [5]:
def cross(A, B):
    """Return every concatenation a + b with a taken from A and b from B.

    Pairs are produced row-major: all combinations for A[0] first, then A[1], ...
    e.g. cross('AB', '12') == ['A1', 'A2', 'B1', 'B2'].
    """
    combos = []
    for a in A:
        combos.extend(a + b for b in B)
    return combos

digits = '123456789'
rows = 'ABCDEFGHI'
cols = digits
squares = cross(rows, cols)

# All 27 units, in this order: 9 columns, then 9 rows, then 9 3x3 boxes.
unitList = (
    [cross(rows, col) for col in cols]
    + [cross(row, cols) for row in rows]
    + [cross(band, stack)
       for band in ('ABC', 'DEF', 'GHI')
       for stack in ('123', '456', '789')]
)

# square -> list of the units containing it (kept in unitList order).
units = {sq: [unit for unit in unitList if sq in unit] for sq in squares}

# square -> set of every other square that shares at least one unit with it.
peers = {sq: set().union(*units[sq]) - {sq} for sq in squares}

In [6]:
# Sanity check on sizes: squares, units, and one units/peers entry per square.
print("\n".join(
    f"Number of elements in {label} =  {len(collection)}"
    for label, collection in (
        ("squares list", squares),
        ("unitList", unitList),
        ("units dictionary", units),
        ("peers dictionary", peers),
    )
))

Number of elements in squares list =  81
Number of elements in unitList =  27
Number of elements in units dictionary =  81
Number of elements in peers dictionary =  81


In [7]:
# Units in unitList that contain D2 (D2 sits at index 28 of squares)
list(filter(lambda unit: 'D2' in unit, unitList))

[['A2', 'B2', 'C2', 'D2', 'E2', 'F2', 'G2', 'H2', 'I2'],
 ['D1', 'D2', 'D3', 'D4', 'D5', 'D6', 'D7', 'D8', 'D9'],
 ['D1', 'D2', 'D3', 'E1', 'E2', 'E3', 'F1', 'F2', 'F3']]

In [8]:
# The same three units, read from the precomputed units dictionary
units.get('D2')

[['A2', 'B2', 'C2', 'D2', 'E2', 'F2', 'G2', 'H2', 'I2'],
 ['D1', 'D2', 'D3', 'D4', 'D5', 'D6', 'D7', 'D8', 'D9'],
 ['D1', 'D2', 'D3', 'E1', 'E2', 'E3', 'F1', 'F2', 'F3']]

In [9]:
# The 20 peers of D2; it is a set, so elements are not printed in sorted order
print(peers.get('D2'))

{'D5', 'G2', 'D3', 'D1', 'D7', 'D4', 'H2', 'I2', 'F1', 'F3', 'F2', 'D6', 'E1', 'E2', 'C2', 'D8', 'E3', 'B2', 'A2', 'D9'}


## Grid Representation

In [10]:
# A puzzle is one 81-character string read row by row (A1..A9, then B1..B9, ...):
# digits are the given values and '.' marks an empty square.
# Written here as its nine rows of nine squares each.
example_grid = (
    "4.....8.5"
    ".3......."
    "...7....."
    ".2.....6."
    "....8.4.."
    "....1...."
    "...6.3.7."
    "5..2....."
    "1.4......"
)

In [11]:
#extracting square values from the grid

def grid_values(grid):
    """Convert a puzzle string into a dict of {square: character}.

    Only digits and the empty-square markers '0' and '.' are kept; every other
    character in grid is skipped. Empty squares keep their '0' or '.' marker.

    Raises ValueError if grid does not contain exactly one character per
    square (81 for a standard board).
    """
    chars = [c for c in grid if c in digits or c in '0.']
    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`, and zip() below would then silently drop squares (too few
    # characters) or ignore extra characters (too many).
    if len(chars) != len(squares):
        raise ValueError(
            "grid must contain exactly %d squares (digits, '0' or '.'), found %d"
            % (len(squares), len(chars))
        )
    return dict(zip(squares, chars))

In [12]:
# Raw {square: character} mapping of the example puzzle
print(grid_values(grid=example_grid))

{'A1': '4', 'A2': '.', 'A3': '.', 'A4': '.', 'A5': '.', 'A6': '.', 'A7': '8', 'A8': '.', 'A9': '5', 'B1': '.', 'B2': '3', 'B3': '.', 'B4': '.', 'B5': '.', 'B6': '.', 'B7': '.', 'B8': '.', 'B9': '.', 'C1': '.', 'C2': '.', 'C3': '.', 'C4': '7', 'C5': '.', 'C6': '.', 'C7': '.', 'C8': '.', 'C9': '.', 'D1': '.', 'D2': '2', 'D3': '.', 'D4': '.', 'D5': '.', 'D6': '.', 'D7': '.', 'D8': '6', 'D9': '.', 'E1': '.', 'E2': '.', 'E3': '.', 'E4': '.', 'E5': '8', 'E6': '.', 'E7': '4', 'E8': '.', 'E9': '.', 'F1': '.', 'F2': '.', 'F3': '.', 'F4': '.', 'F5': '1', 'F6': '.', 'F7': '.', 'F8': '.', 'F9': '.', 'G1': '.', 'G2': '.', 'G3': '.', 'G4': '6', 'G5': '.', 'G6': '3', 'G7': '.', 'G8': '7', 'G9': '.', 'H1': '5', 'H2': '.', 'H3': '.', 'H4': '2', 'H5': '.', 'H6': '.', 'H7': '.', 'H8': '.', 'H9': '.', 'I1': '1', 'I2': '.', 'I3': '4', 'I4': '.', 'I5': '.', 'I6': '.', 'I7': '.', 'I8': '.', 'I9': '.'}


In [13]:
def eliminate(values, s, d):
    """
    Eliminate digit d from the candidates of square s, then propagate.

    Propagation rules:
      1. If s is left with exactly one candidate d2, eliminate d2 from every
         peer of s.
      2. If some unit of s now has only one square that can still hold d,
         assign d to that square.

    values ({square: candidate digits}) is modified in place; search() passes
    a copy so that a failed branch does not corrupt its caller's state.
    Return values, or False if a contradiction is detected (a square left
    with no candidates, or a unit with no remaining place for d).
    """
    if d not in values[s]:
        return values  # already eliminated
    values[s] = values[s].replace(d, '')

    # Rule 1: if the square is down to one value, remove that value from its peers.
    if len(values[s]) == 0:
        return False  # contradiction: the last candidate was removed
    elif len(values[s]) == 1:
        d2 = values[s]
        if not all(eliminate(values, s2, d2) for s2 in peers[s]):
            return False

    # Rule 2: if a unit has only one place left for d, put it there.
    for u in units[s]:
        # Distinct loop name so the parameter `s` is never shadowed.
        dplaces = [sq for sq in u if d in values[sq]]
        if len(dplaces) == 0:
            return False  # contradiction: no square in this unit can hold d
        elif len(dplaces) == 1:
            if not assign(values, dplaces[0], d):
                return False
    return values
    
def assign(values, s, d):
    """
    Fix square s to digit d by eliminating every other candidate from values[s].

    Each elimination goes through eliminate(), so its consequences propagate
    to peers and units. values is modified in place.
    Return values, or False as soon as one elimination hits a contradiction.
    """
    for other in values[s].replace(d, ''):
        if not eliminate(values, s, other):
            return False
    return values

In [14]:
# Convert a grid into the possible values (candidate digits) of every square

def parse_grid(grid):
    """
    Turn a puzzle string into a dict of candidates, {square: digits}.

    Every square starts with all nine digits as candidates; each digit given
    in the grid is then assigned, propagating constraints as it goes.
    Return False if a contradiction is detected.
    """
    candidates = {sq: digits for sq in squares}
    for sq, ch in grid_values(grid).items():
        if ch not in digits:
            continue  # '0' or '.': the square is empty, nothing to assign
        if not assign(candidates, sq, ch):
            return False  # contradiction while assigning this given digit
    return candidates

In [15]:
def display(values):
    """
    Print values ({square: candidate digits}) as a 2-D grid.

    Each square shows its remaining candidates, centered in a column one
    character wider than the longest entry. '|' characters and a '-'/'+'
    rule separate the 3x3 boxes, and a blank line follows the grid.

    If values is False (what parse_grid, search and solve return on a
    contradiction), a short message is printed instead of raising TypeError.
    """
    if values is False:
        print("No solution: a contradiction was detected.")
        return
    width = 1 + max(len(values[s]) for s in squares)
    line = '+'.join(['-' * (width * 3)] * 3)
    for r in rows:
        print(''.join(values[r + c].center(width) + ('|' if c in '36' else '')
                      for c in cols))
        if r in 'CF':
            print(line)
    print()

In [16]:
# Candidates left after propagating the given digits alone, before any search
display(parse_grid(grid=example_grid))

   4      1679   12679  |  139     2369    269   |   8      1239     5    
 26789     3    1256789 | 14589   24569   245689 | 12679    1249   124679 
  2689   15689   125689 |   7     234569  245689 | 12369   12349   123469 
------------------------+------------------------+------------------------
  3789     2     15789  |  3459   34579    4579  | 13579     6     13789  
  3679   15679   15679  |  359      8     25679  |   4     12359   12379  
 36789     4     56789  |  359      1     25679  | 23579   23589   23789  
------------------------+------------------------+------------------------
  289      89     289   |   6      459      3    |  1259     7     12489  
   5      6789     3    |   2      479      1    |   69     489     4689  
   1      6789     4    |  589     579     5789  | 23569   23589   23689  



In [17]:
#search and solve

def solve(grid):
    """Solve the puzzle string grid; return the solved values dict, or False if unsolvable."""
    candidates = parse_grid(grid)
    return search(candidates)

def search(values):
    """
    Depth-first search over candidate assignments, with constraint propagation.

    values is a {square: candidate digits} dict, or False if an earlier step
    already found a contradiction. Return a solved dict (every square has
    exactly one digit), or False if no assignment leads to a solution.
    """
    # A contradiction was found upstream
    if values is False:
        return False

    # Solved: every square has exactly one candidate left
    if all(len(values[s]) == 1 for s in squares):
        return values

    # Branch on the unfilled square with the fewest candidates (smallest
    # branching factor); ties go to the alphabetically first square name.
    _, s = min((len(values[sq]), sq) for sq in squares if len(values[sq]) > 1)

    # Try each candidate d on a copy of values so a failed branch leaves this
    # level's state untouched; some() stops at the first branch that succeeds.
    return some(search(assign(values.copy(), s, d)) for d in values[s])

def some(seq):
    """
    Return the first truthy element of seq, or False if there is none.

    seq is consumed lazily, so when it is a generator of search branches,
    the remaining branches are never evaluated once one of them succeeds.
    """
    return next((e for e in seq if e), False)

In [19]:
# Complete solution produced by depth-first search
display(solve(grid=example_grid))

4 1 7 |3 6 9 |8 2 5 
6 3 2 |1 5 8 |9 4 7 
9 5 8 |7 2 4 |3 1 6 
------+------+------
8 2 5 |4 3 7 |1 6 9 
7 9 1 |5 8 6 |4 3 2 
3 4 6 |9 1 2 |7 5 8 
------+------+------
2 8 9 |6 4 3 |5 7 1 
5 7 3 |2 9 1 |6 8 4 
1 6 4 |8 7 5 |2 9 3 

