# Solve Killer Sudoku

Import z3 library

In [1]:
from z3 import *

Define global parameters

In [2]:
# Board dimensions: the grid has 9 rows and 9 columns.
row_size, column_size = 9, 9

Create Z3 Solver

In [3]:
# Solver instance that accumulates every puzzle constraint through s.add(...).
s = Solver()

Initialize the cell matrix

In [4]:
# Grid of Z3 integer variables, indexed as cells[row][col]; each variable is
# named 'cell<row><col>' from its 0-based indices.
cells = [
    [Int('cell%d%d' % (row, col)) for col in range(column_size)]
    for row in range(row_size)
]

Add digit constraints: every cell takes an integer value in the range [1, 9]

In [5]:
# Restrict every cell variable to the digit range 1..9 (both bounds inclusive).
for row in cells:
    for cell in row:
        s.add(cell >= 1, cell <= 9)

Add row constraints

In [6]:
# Every row must hold pairwise-distinct values. Unpacking the whole row keeps
# the constraint in step with column_size instead of hard-coding nine indices.
for r in range(row_size):
    s.add(Distinct(*cells[r]))

Add column constraints

In [7]:
# Every column must hold pairwise-distinct values. The column index ranges
# over column_size (not row_size), and the column is gathered from every row
# so the constraint stays correct for any grid dimensions.
for c in range(column_size):
    s.add(Distinct(*[cells[r][c] for r in range(row_size)]))

Add 3x3 grid constraints

In [8]:
# Every 3x3 box, identified by its top-left corner (r, c), must hold
# pairwise-distinct values. Cells are listed row by row within the box.
for r in range(0, row_size, 3):
    for c in range(0, column_size, 3):
        box = [cells[r + dr][c + dc] for dr in range(3) for dc in range(3)]
        s.add(Distinct(*box))

Add cage constraints read from the `cage_constraints` file

In [9]:
# Each non-blank line of 'cage_constraints' describes one cage as
# whitespace-separated tokens: the cage total first, then one token per cell
# whose first character is the 1-based row and second character the 1-based
# column of that cell.
with open('cage_constraints', 'r') as f:  # closed even if parsing fails
    for line in f:
        # split() with no argument drops the trailing newline and tolerates
        # repeated spaces, so no empty or newline-tainted tokens appear.
        e = line.split()
        if not e:
            continue  # skip blank lines (e.g. a trailing empty line)
        if len(e) < 2:
            raise ValueError('cage line has a total but no cells: %r' % line)
        cage = [cells[int(tok[0]) - 1][int(tok[1]) - 1] for tok in e[1:]]
        # Cells inside one cage must differ and add up to the cage total.
        s.add(Distinct(*cage))
        # Convert the total explicitly rather than comparing against a string.
        s.add(Sum(*cage) == int(e[0]))

Start solving

In [10]:
# A model is only available when check() reports sat, so stop with a clear
# message otherwise instead of letting s.model() fail.
result = s.check()
if result != sat:
    raise RuntimeError('No solution found: solver returned %s' % result)
m = s.model()

Output result matrix

In [11]:
# Print the solved grid, one row per line; every digit is followed by a space.
for row in cells:
    print(''.join(str(m[cell]) + ' ' for cell in row))


2 1 5 6 4 7 3 9 8 
3 6 8 9 5 2 1 7 4 
7 9 4 3 8 1 6 5 2 
5 8 6 2 7 4 9 3 1 
1 4 2 5 9 3 8 6 7 
9 7 3 8 1 6 4 2 5 
8 2 1 7 3 9 5 4 6 
6 5 9 4 2 8 7 1 3 
4 3 7 1 6 5 2 8 9 
