<a href="https://colab.research.google.com/github/srini229/EE5333_tutorials/blob/master/misc/SAT_examples.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install python-sat

In [None]:
from pysat.solvers import Cadical195
# Enumerate every satisfying assignment of the single clause (x1 OR x2).
clauses = [[1, 2]]

solver = Cadical195(bootstrap_with=clauses)
while solver.solve():
  model = solver.get_model()
  print(model)
  # Blocking clause: the negation of the whole model rules out exactly this
  # assignment, so the next solve() has to return a different one.
  blocking = [-lit for lit in model]
  solver.append_formula([blocking])
solver.delete()

In [None]:
def atmost_one(l, solver):
  """Add pairwise clauses so that at most one literal of l can be true.

  For every pair of literals (a, b) in l the clause [-a, -b] is added.
  The clauses are handed to solver.append_formula in one batch per
  literal: each batch pairs that literal with all literals after it
  (the batch for the last literal is empty).
  """
  for pos, lit in enumerate(l):
    batch = [[-lit, -later] for later in l[pos + 1:]]
    solver.append_formula(batch)

def atleast_one(l, solver):
  """Add the single clause l, i.e. require at least one literal of l to be true."""
  formula = [l]  # a formula holding exactly one clause
  solver.append_formula(formula)

def exactly_one(l, solver):
  """Require exactly one literal of l to be true.

  Combines the two cardinality helpers: first the pairwise at-most-one
  clauses, then the single at-least-one clause.
  """
  for add_constraint in (atmost_one, atleast_one):
    add_constraint(l, solver)

In [79]:
# Row/column indices run from 0 to n-1.
# Each row and each column holds exactly one queen.
# One variable per board position, numbered from 1 in row-major order,
# e.g. for n = 4:
# |  1  |  2  |  3  |  4  |
# |  5  |  6  |  7  |  8  |
# |  9  | 10  | 11  | 12  |
# | 13  | 14  | 15  | 16  |
#
# and for n = 5:
# |  1  |  2  |  3  |  4  |  5  |
# |  6  |  7  |  8  |  9  | 10  |
# | 11  | 12  | 13  | 14  | 15  |
# | 16  | 17  | 18  | 19  | 20  |
# | 21  | 22  | 23  | 24  | 25  |

def print_sol(model, n):
  """Print an n x n board with a 'Q' on every square set to true in model.

  model is a sequence of signed variable numbers; a positive number k
  marks the square at row (k-1)//n, column (k-1)%n (row-major, 1-based
  variables). Negative entries are ignored. A blank line follows the board.
  """
  board = [[' '] * n for _ in range(n)]
  for lit in model:
    if lit <= 0:
      continue
    row, col = divmod(lit - 1, n)
    board[row][col] = 'Q'
  for cells in board:
    print('{0}{1}{0}'.format('|', '|'.join(cells)))
  print()

def nqueens(n):
  """Enumerate and print every solution of the n-queens puzzle via SAT.

  One Boolean variable per square (true = queen on that square), numbered
  1..n*n in row-major order. The encoding requires exactly one queen per
  row, exactly one per column, and at most one per diagonal in either
  direction. Each model found by the solver is printed with
  print_sol(model, n).
  """
  # index of var at (i,j) = n * i + j + 1; None when (i,j) is off the board
  def index(i, j):
    if 0 <= i < n and 0 <= j < n:
      return n * i + j + 1
    return None

  def on_board(cells):
    # variables of the given (row, col) cells, skipping off-board ones
    return [v for v in (index(r, c) for r, c in cells) if v is not None]

  # The context manager releases the native solver even if encoding or
  # printing raises part-way through.
  with Cadical195() as solver:
    # each row has exactly one queen
    for i in range(n):
      exactly_one([index(i, j) for j in range(n)], solver)
    # each col has exactly one queen
    for j in range(n):
      exactly_one([index(i, j) for i in range(n)], solver)
    # each diagonal has at most one queen (a diagonal may be empty);
    # the single-square corner diagonals need no constraint and are skipped
    # "\" diagonals: row - col == d
    for d in range(-n + 2, n - 1):
      atmost_one(on_board((d + j, j) for j in range(n)), solver)
    # "/" diagonals: row + col == s
    for s in range(1, 2 * n - 2):
      atmost_one(on_board((s - c, c) for c in range(s, -1, -1)), solver)

    for m in solver.enum_models():
      print_sol(m, n)

In [80]:
# Enumerate and print every solution of the 4-queens puzzle.
nqueens(4)

| |Q| | |
| | | |Q|
|Q| | | |
| | |Q| |

| | |Q| |
|Q| | | |
| | | |Q|
| |Q| | |



In [81]:
# Side length of the Sudoku grid; read as a module-level global by print_sol_.
n = 9
from math import sqrt
# Rounded square root of n (3 for n = 9) -- presumably the side length of one
# sub-box; it is not referenced by the code shown in this file.
ns = round(sqrt(n))

def print_sol_(s):
  """Print an n x n grid filled from (row, col, value) entries.

  s is an iterable of indexable entries whose first three items are the
  1-based row, the 1-based column and the value to show; all other cells
  are printed blank. The grid size comes from the module-level n. A blank
  line follows the grid.
  """
  grid = [[' '] * n for _ in range(n)]
  for entry in s:
    row, col = entry[0] - 1, entry[1] - 1
    grid[row][col] = str(entry[2])
  for cells in grid:
    print('{0}{1}{0}'.format('|', '|'.join(cells)))
  print()

def print_sol(model):
  """Decode a SAT model into (row, col, value) triples and print the grid.

  Each positive variable k in model is decoded as the inverse of the
  numbering k = n*n*i + n*j + v + 1 (81*i + 9*j + v + 1 for the 9x9
  board), giving the 1-based triple (i + 1, j + 1, v + 1). Negative
  entries are ignored. The strides are derived from the module-level n so
  that decoding stays in step with the grid size used by print_sol_.

  NOTE: this function has the same name as the two-argument N-queens
  printer print_sol(model, n) defined earlier in this file; once this
  definition runs, the name print_sol refers to this one.
  """
  cells = []
  for lit in model:
    if lit > 0:
      k = lit - 1
      cells.append((k // (n * n) + 1, (k // n) % n + 1, k % n + 1))
  print_sol_(cells)

def sudoku(hints=()):
  """Print the hint grid, then solve the puzzle with SAT and print the result.

  hints is an iterable of 1-based (row, col, value) triples; each one is
  forced with a unit clause. Variables are numbered
  n*n*i + n*j + v + 1 for 0-based row i, column j and value index v, with
  n the module-level grid size (81*i + 9*j + v + 1 for the 9x9 board).
  Nothing further is printed when the formula is unsatisfiable.

  The Sudoku rules themselves are not encoded here: the
  "fill in the constraints" marker is where they are meant to be added.
  """
  print_sol_(hints)

  # index of var at (i,j) taking value v + 1 = n*n * i + n * j + v + 1
  def index(i, j, v):
    return n * n * i + n * j + v + 1

  # The context manager releases the native solver on every exit path.
  with Cadical195() as solver:

    # fill in the constraints

    for h in hints:
      solver.add_clause([index(h[0] - 1, h[1] - 1, h[2] - 1)])

    if solver.solve():
      print_sol(solver.get_model())


In [82]:
# Two hints as 1-based (row, col, value) triples: a 9 in the top-left cell
# and an 8 in the bottom-right cell.
sudoku([(1,1,9), (9,9,8)])

|9| | | | | | | | |
| | | | | | | | | |
| | | | | | | | | |
| | | | | | | | | |
| | | | | | | | | |
| | | | | | | | | |
| | | | | | | | | |
| | | | | | | | | |
| | | | | | | | |8|

|9| | | | | | | | |
| | | | | | | | | |
| | | | | | | | | |
| | | | | | | | | |
| | | | | | | | | |
| | | | | | | | | |
| | | | | | | | | |
| | | | | | | | | |
| | | | | | | | |8|

