## N-Queens Problem

*Given a NxN chessboard, can you place N queens such that they do not kill one another?*

We will solve this as a constraint satisfaction problem, using the **minimum conflicts approach**.

In [3]:
from collections import defaultdict
import itertools

import numpy as np

In [10]:
class NQueens:
    """Represents the N Queens problem as a constraint satisfaction problem.
    
    Data structure: dict (variable, value) format.
    Variable: the row index, e.g. 0-7 in case of 8-queens
    Value: the col index, again 0-7 (we change this)
    """
    
    def __init__(self, N=8, assignment='random'):
        """Initializes a NxN chessboard. Only stores positions of queens.
        
        Two methods to initialize the problem: greedy and random. Greedy generally
        helps find a solution quicker than random initialization. Greedy initialization
        selects the first value that is compatible with variables initialized so far.
        """
        self.N = N
        cols = list(range(N))
        np.random.shuffle(cols)
        self.assignment = {i: j for i, j in zip(range(N), cols)}
        # TODO: greedy initialization
        self.cache = defaultdict(list)
        for k, v in self.assignment.items():
            self.cache[k].append(v)
    
    def is_conflict(self, x1, x2, assignment):
        """Check if two variables (representing row numbers) conflict."""
        y1 = assignment[x1]
        y2 = assignment[x2]
        return x1 == x2 or y1 == y2 or x1 - x2 == y1 - y2
    
    def is_solution(self, assignment):
        """Check if the current assignment is a valid solution."""
        return len(self.assignment) == self.N and len(self.conflicted_vars(assignment)) == 0
        
    def n_conflicts(self, var, assignment):
        """Counts the number of conflicts for a particular variable (row index)."""
        # check all combinations
        return sum([self.is_conflict(var, el, assignment)
                    for el in assignment if el != var])
    
    def conflicted_vars(self, assignment):
        """Return the list of variables that are conflicted."""
        return [v for v in assignment if self.n_conflicts(v, assignment) > 0]
    
    def record(self, var, val):
        """Track the changes we have tried."""
        self.cache[var].append(val)
    
    def min_conflicts_val(self, var, assignment):
        """Return the value for a variable that minimizes the number of conflicts."""
        m = self.n_conflicts(var, assignment)
        val = assignment[var]
        vals_to_try = [x for x in range(self.N) if x not in self.cache[var]]
        if len(vals_to_try) == 0:
            vals_to_try = list(range(self.N))
        for j in vals_to_try:
            if j == val:
                continue
            assignment[var] = j
            tmp = self.n_conflicts(var, assignment)
            if tmp <= m:
                m = tmp
                val = j
        return val
    
    def min_conflicts(self, max_steps=1000):
        """Min conflicts approach to solving this problem."""
        current = self.assignment
        for n_iter in range(max_steps):
            if self.is_solution(current):
                print(f'Found solution in {n_iter} iterations')
                return current
            var = np.random.choice(self.conflicted_vars(current))
            val = self.min_conflicts_val(var, current)
            current[var] = val
            self.record(var, val)  # to avoid running into a deadlock like situation
        return f'Failed after {max_steps} iterations. Try with more iterations or a different initial state.'

The min conflicts approach is very effective in solving this type of problem because of the dense nature of the state space.

**More things to try:**
- Try larger values of N (implement in Java or Cython)
- Compare (both memory and execution time) with solving the same problem using search methods, e.g. A*
- Random vs. greedy initialization: does greedy initialization helps arriving at the solution quicker?

In [15]:
%%timeit

n_queens = NQueens(50)
tmp = n_queens.min_conflicts()

Found solution in 27 iterations
Found solution in 20 iterations
Found solution in 19 iterations
Found solution in 24 iterations
Found solution in 31 iterations
Found solution in 59 iterations
Found solution in 19 iterations
Found solution in 36 iterations
1.69 s ± 689 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
