In [None]:
try:
    import pysat
except:
    !pip install python-sat
    import pysat


In [None]:
from pysat.solvers import Glucose3
# Minimal example: create a solver by hand, add two clauses, and solve.
g = Glucose3()
# Literals are signed integers; a negative integer is the complement of the
# corresponding variable, so each clause below has the form "not a, or b".
g.add_clause([-1, 2])
g.add_clause([-2, 3])
# Print the result of solve() and then the model returned by the solver.
print(g.solve())
print(g.get_model())
# The solver was not opened in a `with` block, so it is deleted explicitly.
g.delete()


In [None]:
# Same example as above, but using the solver as a context manager, so no
# explicit g.delete() call is needed at the end.
with Glucose3() as g:
    g.add_clause([-1, 2])
    g.add_clause([-2, 3])
    print(g.solve())
    print(g.get_model())


In [None]:
from pysat.solvers import Glucose4, Minisat22
# The same two clauses, this time handed to the Minisat22 solver: the calls
# used here (add_clause, solve, get_model) are the same as for Glucose3.
with Minisat22() as g:
    g.add_clause([-1, 2])
    g.add_clause([-2, 3])
    print(g.solve())
    print(g.get_model())


In [None]:
import numpy as np

class SudokuViaSAT(object):
    """A 9x9 Sudoku board, stored as a numpy array, to be solved via SAT."""

    def __init__(self, sudoku_string):
        """
        @param sudoku_string: a digit string listing the board row by row:
            0 represents an unknown digit, and 1..9 represent the respective
            digit.  It must hold at least 81 characters; only the first 81
            are read.
        """
        assert len(sudoku_string) > 80
        # Parse the first 81 characters and fold them into 9 rows of 9 cells.
        digits = [int(ch) for ch in sudoku_string[:81]]
        self.board = np.array(digits, dtype=np.uint8).reshape(9, 9)
        self.sat = None # This will be the SAT instance.
        # Perform here any other initialization you think you need.
        # YOUR CODE HERE

    def show(self):
        """Prints out the board, with '.' for unknown cells and a frame
        drawn around every 3x3 block."""
        separator = "+---+---+---+"
        print(separator)
        for row_index, row in enumerate(self.board):
            cells = ["." if value == 0 else str(value) for value in row]
            # Group the nine cells three by three, closing each group with '|'.
            groups = ["".join(cells[k:k + 3]) + "|" for k in range(0, 9, 3)]
            print("|" + "".join(groups))
            if row_index % 3 == 2:
                print(separator)


In [None]:
# Build a board from an 81-character puzzle string (0 = unknown cell) and
# print it.
problem = "000000061350000000400050000020000800000601000000700000000080200600400000007000010"
sd = SudokuViaSAT(problem)
sd.show()


In [None]:
def encode_variable(d, i, j):
    """This function creates the variable (the integer) representing the
    fact that digit d appears in position i, j.
    Of course, to obtain the complement variable, you can just negate
    (take the negative) of the returned integer.
    Note that it must be: 1 <= d <= 9, 0 <= i <= 8, 0 <= j <= 8.

    @param d: the digit, 1..9 (a Python int or a numpy integer).
    @param i: the row index, 0..8.
    @param j: the column index, 0..8.
    @returns: the plain Python int d * 100 + i * 10 + j.
    @raises AssertionError: if d, i or j is out of range."""
    # Convert to plain Python integers BEFORE doing any arithmetic.  If d is
    # a numpy.uint8, as an element of the board is, then with NumPy >= 2 the
    # product d * 100 is itself computed as a uint8 and wraps around modulo
    # 256 (e.g. uint8(3) * 100 gives 44), so converting only the final result
    # would be too late.  It also guarantees that the generated literal is a
    # normal Python integer, as the SAT solvers expect.
    d, i, j = int(d), int(i), int(j)
    assert 1 <= d <= 9
    assert 0 <= i < 9
    assert 0 <= j < 9
    return d * 100 + i * 10 + j


In [None]:
# Let's define a testing helper.

def check_equal(x, y, msg=None):
    """Asserts that x equals y, printing a short report first when they differ.

    @param x: the value that was obtained.
    @param y: the value that was expected.
    @param msg: optional label naming what is being checked; it is included
        in the printed report.
    @raises AssertionError: if x and y are different."""
    mismatch = x != y
    if mismatch:
        if msg is not None:
            print("Error in", msg, ":")
        else:
            print("Error:")
        print("    Your answer was:", x)
        print("    Correct answer: ", y)
    assert x == y, "%r and %r are different" % (x, y)


In [None]:
# Digit 3 at row 6, column 7 must be encoded as 3 * 100 + 6 * 10 + 7 = 367.
check_equal(encode_variable(3, 6, 7), 367)


In [None]:
def decode_variable(p):
    """Given an integer constructed as by encode_variable above,
    returns the tuple (d, i, j), where d is the digit, and i, j are
    the row and column of the cell where the digit is.

    Returns None if p does not encode a valid triple, that is, unless
    1 <= d <= 9, 0 <= i <= 8, 0 <= j <= 8.  In particular None is returned
    when p lies outside the range from 100 to 988 (the lowest and highest
    values that p can assume), and also for integers inside that range whose
    tens or units digit is 9 (e.g. 109 or 190), since no cell has index 9.

    @param p: the integer to decode.
    @returns: the tuple (d, i, j) of Python ints, or None."""
    if not 100 <= p <= 988:
        return None
    # Peel off the hundreds (digit), then the tens (row) and units (column).
    d, rest = divmod(int(p), 100)
    i, j = divmod(rest, 10)
    # d is already known to be in 1..9 thanks to the range check above;
    # the cell indices still have to be validated.
    if i > 8 or j > 8:
        return None
    return (d, i, j)


In [None]:
# Round-trip check: decoding the encoding of every valid (digit, row, column)
# triple must give the triple back.
for d in range(1, 10):
    for i in range(9):
        for j in range(9):
            r = decode_variable(encode_variable(d, i, j))
            check_equal((d, i, j), r)



In [None]:
def every_cell_contains_at_least_one_digit():
    """Returns a list of clauses, stating that every cell must contain
    at least one digit.

    There is one clause per cell (81 in total, listed row by row); each
    clause is the disjunction of the nine variables "digit d is in this
    cell", for d = 1..9."""
    return [
        [encode_variable(digit, row, col) for digit in range(1, 10)]
        for row in range(9)
        for col in range(9)
    ]
    


In [None]:
def prepare(g):
    """Adds to solver g the clauses stating that every cell holds at least
    one digit."""
    for c in every_cell_contains_at_least_one_digit():
        g.add_clause(c)

with Glucose3() as g:
    prepare(g)
    # This can be solved.
    check_equal(g.solve(), True)
    for d in range(1, 10):
        # These clauses state that no digit appears at 4, 5.
        # You can change the coordinates if you like.
        g.add_clause([-encode_variable(d, 4, 5)])
    # With every digit excluded from cell 4, 5, the "at least one digit"
    # clause of that cell cannot be satisfied any more.
    check_equal(g.solve(), False)



In [None]:
def every_cell_contains_at_most_one_digit():
    """Returns a list of clauses, stating that every cell contains
    at most one digit.

    For every cell i, j and every pair of distinct digits d < e, the clause
    [-v(d, i, j), -v(e, i, j)] states that d and e cannot both be in the
    cell.  Each unordered pair of digits is emitted exactly once: the clause
    for (e, d) would be logically identical to the one for (d, e)."""
    total = []
    for i in range(9):
        for j in range(9):
            for d in range(1, 10):
                # Starting from d + 1 skips both the pair (d, d) and the
                # mirror image of a pair that has already been generated.
                for e in range(d + 1, 10):
                    total.append([-encode_variable(d, i, j),
                                  -encode_variable(e, i, j)])
    return total


In [None]:
def prepare(g):
    """Adds to solver g the clauses stating that every cell holds at most
    one digit."""
    for c in every_cell_contains_at_most_one_digit():
        g.add_clause(c)

with Glucose3() as g:
    prepare(g)
    # The "at most one digit" clauses alone can be satisfied.
    check_equal(g.solve(), True)
    # This states that both 3 and 4 appear at position 6, 7.
    g.add_clause([encode_variable(3, 6, 7)])
    g.add_clause([encode_variable(4, 6, 7)])
    check_equal(g.solve(), False)



In [None]:
def no_identical_digits_in_same_row():
    """Returns a list of clauses, stating that if a digit appears
    in a cell, the same digit cannot appear elsewhere in the same row.

    For every digit d, every row i and every pair of distinct columns j < e,
    the clause [-v(d, i, j), -v(d, i, e)] forbids d from being in both
    cells.  Each unordered pair of cells is emitted exactly once."""
    total = []
    for d in range(1, 10):
        for i in range(9):
            for j in range(9):
                # Only columns to the right of j: this skips the cell itself
                # and avoids generating every clause twice.
                for e in range(j + 1, 9):
                    total.append([-encode_variable(d, i, j),
                                  -encode_variable(d, i, e)])
    return total


In [None]:
def prepare(g):
    """Adds to solver g the clauses forbidding a repeated digit in a row."""
    for c in no_identical_digits_in_same_row():
        g.add_clause(c)

with Glucose3() as g:
    prepare(g)
    # The row clauses alone can be satisfied.
    check_equal(g.solve(), True)
    # This states that 3 appears twice in row 5.
    g.add_clause([encode_variable(3, 5, 7)])
    g.add_clause([encode_variable(3, 5, 8)])
    check_equal(g.solve(), False)

# But columns are not forbidden.
with Glucose3() as g:
    prepare(g)
    # Digit 3 twice in column 7 (rows 5 and 2): the row clauses allow it.
    g.add_clause([encode_variable(3, 5, 7)])
    g.add_clause([encode_variable(3, 2, 7)])
    check_equal(g.solve(), True)



In [None]:
def no_identical_digits_in_same_column():
    """Returns a list of clauses, stating that if a digit appears
    in a cell, the same digit cannot appear elsewhere in the same column.

    For every digit d, every column j and every pair of distinct rows i < e,
    the clause [-v(d, i, j), -v(d, e, j)] forbids d from being in both
    cells.  Each unordered pair of cells is emitted exactly once."""
    total = []
    for d in range(1, 10):
        for j in range(9):
            for i in range(9):
                # Only rows below i: this skips the cell itself and avoids
                # generating every clause twice.
                for e in range(i + 1, 9):
                    total.append([-encode_variable(d, i, j),
                                  -encode_variable(d, e, j)])
    return total


In [None]:
def prepare(g):
    """Adds to solver g the clauses forbidding a repeated digit in a column."""
    for c in no_identical_digits_in_same_column():
        g.add_clause(c)

with Glucose3() as g:
    prepare(g)
    # The column clauses alone can be satisfied.
    check_equal(g.solve(), True)
    # This states that 3 appears twice in column 7.
    g.add_clause([encode_variable(3, 5, 7)])
    g.add_clause([encode_variable(3, 2, 7)])
    check_equal(g.solve(), False)

# But rows are not forbidden.
with Glucose3() as g:
    prepare(g)
    # Digit 3 twice in row 5 (columns 7 and 8): the column clauses allow it.
    g.add_clause([encode_variable(3, 5, 7)])
    g.add_clause([encode_variable(3, 5, 8)])
    check_equal(g.solve(), True)



In [None]:
def no_identical_digits_in_same_block():
    """Returns a list of clauses, stating that if a digit appears
    in a cell, the same digit cannot appear elsewhere in the same
    3x3 square.

    For every 3x3 block, every pair of distinct cells of the block, and every
    digit d, the clause [-v(d, i, j), -v(d, l, m)] forbids d from being in
    both cells.  Pairs lying on the same row or column of the block are
    included too, so that these clauses enforce the block rule on their own.
    Each unordered pair of cells is emitted exactly once."""
    total = []
    for block_row in range(0, 9, 3):
        for block_col in range(0, 9, 3):
            # The nine (row, column) cells of this block.
            cells = [(i, j)
                     for i in range(block_row, block_row + 3)
                     for j in range(block_col, block_col + 3)]
            for position, (i, j) in enumerate(cells):
                # Pair each cell only with the cells listed after it, to
                # avoid generating every clause twice.
                for (l, m) in cells[position + 1:]:
                    for d in range(1, 10):
                        total.append([-encode_variable(d, i, j),
                                      -encode_variable(d, l, m)])
    return total


In [None]:
def prepare(g):
    """Adds to solver g the clauses forbidding a repeated digit in a 3x3
    block."""
    for c in no_identical_digits_in_same_block():
        g.add_clause(c)

with Glucose3() as g:
    prepare(g)
    # The block clauses alone can be satisfied.
    check_equal(g.solve(), True)
    # This states that 3 appears twice in top block
    g.add_clause([encode_variable(3, 1, 1)])
    g.add_clause([encode_variable(3, 1, 2)])
    check_equal(g.solve(), False)

# One more test.
with Glucose3() as g:
    prepare(g)
    # Cells 1, 1 and 2, 1 are in the same block (and in the same column).
    g.add_clause([encode_variable(3, 1, 1)])
    g.add_clause([encode_variable(3, 2, 1)])
    check_equal(g.solve(), False)

# But different blocks are not forbidden.
with Glucose3() as g:
    prepare(g)
    g.add_clause([encode_variable(3, 1, 1)])
    g.add_clause([encode_variable(3, 5, 8)])
    check_equal(g.solve(), True)



In [None]:
def sudoku_rules():
    """Returns the list of all clauses encoding the rules of Sudoku.

    The result is the concatenation, in this order, of the clauses for:
    at least one digit per cell, at most one digit per cell, no repeated
    digit in a row, in a column, and in a 3x3 block.  It does not depend on
    any particular puzzle."""
    rule_generators = (
        every_cell_contains_at_least_one_digit,
        every_cell_contains_at_most_one_digit,
        no_identical_digits_in_same_row,
        no_identical_digits_in_same_column,
        no_identical_digits_in_same_block,
    )
    all_clauses = []
    for generate in rule_generators:
        all_clauses += generate()
    return all_clauses


In [None]:
# The rules alone, with no digits given, must be satisfiable.
with Glucose3() as g:
    for c in sudoku_rules():
        g.add_clause(c)
    check_equal(g.solve(), True)



In [None]:
def _board_to_SAT(self):
    """Translates the currently known state of the board into a list of SAT
    clauses.  Each clause has only one literal, and expresses the fact that a
    given digit is in a given position.  The method returns the list of clauses
    corresponding to all the initially known Sudoku digits.

    @returns: a list of one-literal clauses, one for each non-zero cell of
        self.board, in row-major order."""
    clauses = []
    for i in range(9):
        for j in range(9):
            # The board stores numpy.uint8 values; convert to a plain Python
            # int before encoding, so that the arithmetic producing the
            # literal cannot wrap around in 8 bits.
            d = int(self.board[i, j])
            if d != 0:
                clauses.append([encode_variable(d, i, j)])
    return clauses

# Attach the function to the class as a method.
SudokuViaSAT._board_to_SAT = _board_to_SAT


In [None]:

problem = "000000061350000000400050000020000800000601000000700000000080200600400000007000010"
sd = SudokuViaSAT(problem)
sd.show()
clauses = sd._board_to_SAT()
# The puzzle string above has 17 non-zero digits, hence 17 unit clauses.
check_equal(len(clauses), 17)
# Row 1 of the board starts with 3, 5: digit 3 at 1, 0 and digit 5 at 1, 1.
check_equal([310] in clauses, True)
check_equal([511] in clauses, True)
# Cell 2, 4 holds a 5, not a 2, so this clause must be absent.
check_equal([224] in clauses, False)

# This should print the elements of the board.
for c in sd._board_to_SAT():
    print(c)



In [None]:
def _to_SAT(self):
    """Returns the complete SAT encoding of this puzzle: the clauses for the
    rules of Sudoku, followed by the unit clauses for the digits that are
    known on the board."""
    encoding = list(sudoku_rules())
    encoding.extend(self._board_to_SAT())
    return encoding

# Attach the function to the class as a method.
SudokuViaSAT._to_SAT = _to_SAT


In [None]:
# A well-formed puzzle: rules plus given digits must be satisfiable.
problem = "000000061350000000400050000020000800000601000000700000000080200600400000007000010"
sd = SudokuViaSAT(problem)
with Glucose3() as g:
    for c in sd._to_SAT():
        g.add_clause(c)
    check_equal(g.solve(), True)

# Same puzzle, except that row 2 now contains the digit 4 twice: rules plus
# given digits must be unsatisfiable.
problem = "000000061350000000404050000020000800000601000000700000000080200600400000007000010"
sd = SudokuViaSAT(problem)
with Glucose3() as g:
    # Each clause is added exactly once (there is no need to loop over the
    # literals of the clause).
    for c in sd._to_SAT():
        g.add_clause(c)
    check_equal(g.solve(), False)



In [None]:
problem = "000000061350000000400050000020000800000601000000700000000080200600400000007000010"
sd = SudokuViaSAT(problem)
with Glucose3() as g:
    for c in sd._to_SAT():
        g.add_clause(c)
    check_equal(g.solve(), True)
    # Let's get a truth assignment.
    # ps is kept as a module-level name: the next cell iterates over it.
    ps = g.get_model()
    print(ps)


In [None]:
# Print the (digit, row, column) triple of every positive literal of the
# model that decode_variable recognizes; negative literals and integers that
# do not decode to a triple are skipped.
for l in ps:
    if l > 0 and decode_variable(l) is not None:
        print(decode_variable(l))


In [None]:
def solve(self, Solver):
    """Solves the Sudoku instance using the given SAT solver
    (e.g., Glucose3, Minisat22).
    @param Solver: a solver, such as Glucose3, Minisat22.
    @returns: False, if the Sudoku problem is not solvable, and True, if it is.
        In the latter case, the solve method also completes self.board,
        using the solution of SAT to complete the board."""
    with Solver() as g:
        for c in self._to_SAT():
            g.add_clause(c)
        if not g.solve():
            return False
        for literal in g.get_model():
            # Only positive literals say that a digit IS in a cell.
            if literal <= 0:
                continue
            decoded = decode_variable(literal)
            if decoded is None:
                continue
            d, i, j = decoded
            # The model may also assign integers that are in the numeric
            # range of the encoding but correspond to no cell (e.g. 109);
            # never index the board with them.
            if i < 9 and j < 9:
                self.board[i, j] = d
    return True


# Attach the function to the class as a method.
SudokuViaSAT.solve = solve


In [None]:
# Print the puzzle, solve it with Glucose3, and print the completed board.
problem = "000000061350000000400050000020000800000601000000700000000080200600400000007000010"
sd = SudokuViaSAT(problem)
sd.show()
sd.solve(Glucose3)
sd.show()


In [None]:
def verify_solution(Solver, sudoku_string):
    """Solves the given puzzle and checks that the result is a valid Sudoku
    solution that is consistent with the puzzle.

    @param Solver: a solver class, such as Glucose3, Minisat22.
    @param sudoku_string: the puzzle, in the format taken by SudokuViaSAT.
    @raises AssertionError: if the puzzle is reported unsolvable, or if the
        completed board violates any of the checks."""
    sd = SudokuViaSAT(sudoku_string)
    assert sd.solve(Solver), "the puzzle was reported as unsolvable"
    # Check that we leave alone the original assignment.
    for i in range(9):
        for j in range(9):
            d = int(sudoku_string[i * 9 + j])
            if d > 0:
                assert sd.board[i, j] == d
    # Check that there is a digit in every cell.
    for i in range(9):
        for j in range(9):
            assert 1 <= sd.board[i, j] <= 9
    # Check the exclusion rules of Sudoku.
    for i in range(9):
        for j in range(9):
            # No repetition in column: compare with the cells below, in
            # the same column j.
            for ii in range(i + 1, 9):
                assert sd.board[i, j] != sd.board[ii, j]
            # No repetition in row: compare with the cells to the right, in
            # the same row i.
            for jj in range(j + 1, 9):
                assert sd.board[i, j] != sd.board[i, jj]
            # No repetition in block: compare with all the other cells of
            # the 3x3 block (three rows by three columns) containing i, j.
            ci, cj = i // 3, j // 3
            for bi in range(3):
                ii = ci * 3 + bi
                for bj in range(3):
                    jj = cj * 3 + bj
                    if i != ii or j != jj:
                        assert sd.board[i, j] != sd.board[ii, jj]


In [None]:
# Verify the solution of the puzzle currently bound to `problem`.
verify_solution(Glucose3, problem)


In [None]:
import requests

# Download the puzzle collection.  The timeout keeps the cell from hanging
# forever on a stalled connection, and raise_for_status() makes an HTTP error
# fail here instead of leaving an error page to be parsed as puzzles.
r = requests.get("https://raw.githubusercontent.com/shadaj/sudoku/master/sudoku17.txt", timeout=30)
r.raise_for_status()
# One puzzle per whitespace-separated token.
puzzles = r.text.split()


In [None]:
# Verify the first 100 downloaded puzzles.  Note that this loop rebinds the
# module-level name `problem`.
for problem in puzzles[:100]:
    verify_solution(Glucose3, problem)



In [None]:
import time

def compute_performance(problem_idx):
    """Times each solver on one downloaded puzzle and prints the results.

    For each of Glucose3, Glucose4 and Minisat22, builds the board for
    puzzles[problem_idx], solves it, and prints the solver name and the
    elapsed time in seconds (board construction included).

    @param problem_idx: index of the puzzle in the global list `puzzles`."""
    for Solver in [Glucose3, Glucose4, Minisat22]:
        # perf_counter() is the clock intended for measuring short intervals:
        # unlike time.time(), it is not affected by system clock updates.
        t = time.perf_counter()
        sd = SudokuViaSAT(puzzles[problem_idx])
        sd.solve(Solver)
        dt = time.perf_counter() - t
        print(Solver.__name__, ":", dt)


In [None]:
# Compare the three solvers on the puzzle at index 6361.
compute_performance(6361)


In [None]:
# Compare the three solvers on a few more puzzles.
# NOTE(review): assumes the downloaded list has more than 48287 entries;
# a shorter list would make the last call fail with an IndexError.
compute_performance(1258)
compute_performance(5302)
compute_performance(10980)
compute_performance(25632)
compute_performance(48287)


In [None]:
def compute_performance(Solver, num_problems):
    """Solves the first num_problems downloaded puzzles with the given
    solver, printing a line each time a new slowest puzzle is found.

    NOTE: this definition replaces the earlier one-argument
    compute_performance(problem_idx); once this cell has run, calls using
    the earlier signature fail until the earlier cell is run again.

    @param Solver: a solver class, such as Glucose3, Glucose4, Minisat22.
    @param num_problems: how many puzzles, from the start of the global list
        `puzzles`, to solve."""
    max_time = 0
    for idx, problem in enumerate(puzzles[:num_problems]):
        # perf_counter() is the clock intended for measuring short intervals:
        # unlike time.time(), it is not affected by system clock updates.
        t = time.perf_counter()
        sd = SudokuViaSAT(problem)
        sd.solve(Solver)
        dt = time.perf_counter() - t
        # Report only the puzzles that set a new record for solving time.
        if dt > max_time:
            print(idx, ":", dt)
            max_time = dt


In [None]:
# Scan the first 1000 puzzles with Glucose4, reporting the slowest ones.
compute_performance(Glucose4, 1000)
