## N-Queens Problem

*Given an N×N chessboard, can you place N queens so that no two queens attack each other?*
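Concretely, two queens on squares $(r_1, c_1)$ and $(r_2, c_2)$ attack each other exactly when they share a row, a column, or a diagonal. These are the same three tests that `is_conflict` performs in the code below:

$$
\text{attack}(r_1, c_1, r_2, c_2) \iff r_1 = r_2 \;\lor\; c_1 = c_2 \;\lor\; |r_1 - r_2| = |c_1 - c_2|
$$

The formulation below places one queen per row, so two distinct variables never share a row. In practice, only the column and diagonal tests can fail.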

We will solve this as a constraint satisfaction problem, using the **minimum conflicts approach**. The idea is to repeatedly reduce the number of conflicts: pick a random conflicted variable and reassign it the value that minimizes its conflicts with the other variables. For comparison, we also solve the problem with backtracking search.
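The loop described above can be sketched without the class machinery. This is a minimal illustration, not the notebook's implementation. It stores the board as a list `cols`, where `cols[r]` is the column of the queen in row `r`. It breaks ties at random and does not keep the per-variable history of tried values that `min_conflicts_val` uses below. The helper names `conflicts` and `min_conflicts` are hypothetical.

```python
import random


def conflicts(cols, row, col):
    """Number of queens in other rows that attack square (row, col)."""
    return sum(
        1
        for r, c in enumerate(cols)
        if r != row and (c == col or abs(r - row) == abs(c - col))
    )


def min_conflicts(n, max_steps=10_000, seed=None):
    """Min-conflicts local search for n-queens.

    Returns (cols, steps) on success, where cols[r] is the column of the queen in
    row r, or (None, max_steps) if no solution was found within the step budget.
    """
    rng = random.Random(seed)
    cols = [rng.randrange(n) for _ in range(n)]  # random complete assignment
    for step in range(max_steps):
        conflicted = [r for r in range(n) if conflicts(cols, r, cols[r]) > 0]
        if not conflicted:
            return cols, step
        row = rng.choice(conflicted)  # pick a random conflicted variable
        scores = [conflicts(cols, row, c) for c in range(n)]
        best = min(scores)
        # move to a least-conflicted column, breaking ties at random
        cols[row] = rng.choice([c for c, s in enumerate(scores) if s == best])
    return None, max_steps
```

Plain min-conflicts can stall on a plateau or in a local minimum, so the step budget matters. `min_conflicts(8, seed=1)` returns `(cols, steps)` on success, or `(None, max_steps)` when the budget runs out.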

```python
from collections import defaultdict

import numpy as np
```

```python
class NQueens(object):
    """Represents the N-queens problem as a constraint satisfaction problem (CSP).

    A CSP is typically defined by:
    - variables
    - domains, i.e. the values each variable can take
    - constraints, i.e. restrictions on the combinations of variables/values

    N-queens formulation: since every solution has exactly one queen per row (and
    per column), we use the row index as the variable and the column index as the
    value that variable takes. Each variable therefore has the domain {0, ..., N-1}.
    """

    def __init__(self, N=8, initialization='random'):
        """Initializes an NxN chessboard.

        Data structures:
        - self.assignment: dict of the form {variable: value}
        - self.cache: for each variable, the list of values already tried

        Initialization can be complete (greedy or random) or partial (not every
        variable is assigned a value). Only complete random initialization is
        implemented here; the backtracking solver starts from an empty partial
        assignment instead.

        Args:
            N: int, the size of the chessboard
            initialization: "random" ("greedy" is not implemented yet)
        """
        if initialization != 'random':
            # TODO: greedy initialization
            raise NotImplementedError(f'unsupported initialization: {initialization!r}')
        self.N = N
        self.assignment = {i: np.random.randint(0, N) for i in range(N)}
        self.cache = defaultdict(list)
        for k, v in self.assignment.items():
            self.cache[k].append(v)

    # Generic functions

    def is_complete(self, assignment):
        """Check if the assignment is complete, i.e. every variable has a value."""
        return len(assignment) == self.N

    def is_conflict(self, x1, x2, y1, y2):
        """Check if two variable/value pairs are in conflict.

        Args:
            x1, y1: first (variable, value) pair, i.e. (row, column)
            x2, y2: second (variable, value) pair

        Returns:
            True if the two queens attack each other, False otherwise
        """
        in_same_row = (x1 == x2)
        in_same_col = (y1 == y2)
        in_same_diagonal = (abs(x1 - x2) == abs(y1 - y2))
        return in_same_row or in_same_col or in_same_diagonal

    def is_consistent(self, var, val, assignment):
        """Check if a variable/value pair is consistent with the rest of the assignment."""
        for k, v in assignment.items():
            if k != var and self.is_conflict(k, var, v, val):
                return False
        return True

    def is_solution(self, assignment):
        """Check if the assignment is complete and conflict-free."""
        return self.is_complete(assignment) and len(
            self.conflicted_vars(assignment)) == 0

    # Min conflicts solver

    def n_conflicts(self, var, val, assignment):
        """Count how many other variables conflict with `var` taking value `val`."""
        return sum(
            self.is_conflict(var, k, val, v)
            for k, v in assignment.items()
            if k != var
        )

    def conflicted_vars(self, assignment):
        """Return the list of variables that are in conflict with some other variable."""
        return [
            k for k, v in assignment.items()
            if not self.is_consistent(k, v, assignment)
        ]

    def min_conflicts_val(self, var, assignment):
        """Return the untried value for `var` that minimizes the number of conflicts.

        Values already tried for `var` (tracked in self.cache) are skipped; once every
        value has been tried, the history for `var` is reset.
        """
        metric = self.N  # a queen has at most N-1 conflicts, so any value beats this
        value = -1
        values_tried = self.cache.get(var, [])
        if len(values_tried) == self.N:
            values_tried = []
        self.cache[var] = values_tried
        for val in range(self.N):
            if val not in values_tried:
                num = self.n_conflicts(var, val, assignment)
                if num <= metric:
                    metric = num
                    value = val
        return value

    def min_conflicts_solver(self, max_steps=1000):
        """Search for a solution with min-conflicts local search.

        Starting from the current complete assignment, repeatedly pick a random
        conflicted variable and move it to the value with the fewest conflicts.
        """
        current = self.assignment
        for n_iter in range(max_steps):
            if self.is_solution(current):
                print(f'Found solution in {n_iter} iterations')
                return current
            # int() converts the NumPy integer returned by np.random.choice
            var = int(np.random.choice(self.conflicted_vars(current)))
            val = self.min_conflicts_val(var, current)
            current[var] = val
            # record this trial to avoid cycling between the same values across iterations
            self.cache[var].append(val)
        return f'Failed after {max_steps} iterations. Try with more iterations or a different initial state.'

    # Backtracking search solver

    def remaining_legal_values(self, var, assignment):
        """Return the values for `var` that are consistent with the assignment."""
        return [
            val for val in range(self.N)
            if self.is_consistent(var, val, assignment)
        ]

    def select_unassigned_variable(self, assignment):
        """Select an unassigned variable using the minimum-remaining-values (MRV) heuristic.

        Returns:
            (variable, legal_values) as a tuple
        """
        # (variable, legal_values) for every unassigned variable
        unassigned = [(var, self.remaining_legal_values(var, assignment))
                      for var in range(self.N)
                      if var not in assignment]
        # return the variable with the fewest legal values left
        var, val = unassigned[0]
        for k, v in unassigned[1:]:
            if len(v) <= len(val):
                var, val = k, v
        return var, val

    def n_choices_ruled_out(self, var, val, neighbor, assignment):
        """Count how many choices a variable:value pair rules out for a neighbor.

        Args:
            var, val: the variable:value pair
            neighbor: neighbor of the variable
            assignment: the partial assignment

        Returns:
            int, the number of the neighbor's legal values that conflict with var:val
        """
        # remaining legal values for the neighbor
        legal_values = self.remaining_legal_values(neighbor, assignment)
        # count how many of those conflict with var:val
        return sum(
            1 for v in legal_values if self.is_conflict(var, neighbor, val, v)
        )

    def order_domain_values(self, var, values, assignment):
        """Order the values according to the least-constraining-value (LCV) heuristic.

        If a value v1 rules out 5 choices for the neighbors and v2 rules out 2, v1 is
        more constraining than v2, so v2 is tried before v1. As a cheap approximation,
        only the unassigned rows directly above and below `var` count as neighbors.
        """
        # keep neighbors that are on the board and not yet assigned
        neighbors = [
            nb for nb in (var - 1, var + 1)
            if 0 <= nb < self.N and nb not in assignment
        ]
        result = []  # (value, number of neighbor choices it rules out)
        for val in values:
            n_ruled_out = sum(
                self.n_choices_ruled_out(var, val, neighbor, assignment)
                for neighbor in neighbors
            )
            result.append((val, n_ruled_out))
        # least constraining values first
        return [val for val, _ in sorted(result, key=lambda x: x[1])]

    def inference(self, var, value):
        """Hook for inference such as forward checking or arc consistency.

        Should return a dict of inferred {variable: value} pairs, or 'failure' if an
        inconsistency is detected. This placeholder performs no inference.
        """
        return {}

    def backtrack(self, assignment):
        """Recursive backtracking search for a valid solution.

        Args:
            assignment: a partial or complete assignment, modified in place

        Returns:
            the solution (a complete assignment) or 'failure'
        """
        if self.is_complete(assignment):
            return assignment
        var, values = self.select_unassigned_variable(assignment)
        for val in self.order_domain_values(var, values, assignment):
            if not self.is_consistent(var, val, assignment):
                continue
            assignment[var] = val
            inferences = self.inference(var, val)
            if inferences != 'failure':
                # merge in place so every recursion level shares (and undoes) one dict
                assignment.update(inferences)
                result = self.backtrack(assignment)
                if result != 'failure':
                    return result
                # remove the inferences before trying the next value
                for k in inferences:
                    assignment.pop(k, None)
            # remove var from the assignment before trying the next value
            assignment.pop(var, None)
        return 'failure'
```

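The `inference` method above is only a placeholder. One way to fill it in is forward checking. After placing a queen, verify that every unassigned row still has at least one legal column, and prune the branch if not. The standalone sketch below assumes that approach and is not part of the original notebook. The names `legal_columns`, `forward_check`, and `backtrack_fc` are hypothetical. `forward_check` follows the same contract as `inference`, returning `{}` or `'failure'`. It takes the assignment explicitly because the class's `inference(var, value)` signature does not receive it.

```python
def legal_columns(row, assignment, n):
    """Columns in `row` not attacked by any queen already in `assignment` ({row: col})."""
    return [
        c for c in range(n)
        if all(c != ac and abs(row - ar) != abs(c - ac)
               for ar, ac in assignment.items())
    ]


def forward_check(assignment, n):
    """Return 'failure' if some unassigned row has no legal column left, else {}."""
    for row in range(n):
        if row not in assignment and not legal_columns(row, assignment, n):
            return 'failure'
    return {}


def backtrack_fc(assignment, n):
    """Backtracking search that prunes a branch as soon as forward checking fails."""
    if len(assignment) == n:
        return dict(assignment)
    # MRV: choose the unassigned row with the fewest legal columns
    row = min((r for r in range(n) if r not in assignment),
              key=lambda r: len(legal_columns(r, assignment, n)))
    for col in legal_columns(row, assignment, n):
        assignment[row] = col
        if forward_check(assignment, n) != 'failure':
            result = backtrack_fc(assignment, n)
            if result != 'failure':
                return result
        del assignment[row]  # undo before trying the next column
    return 'failure'
```

`backtrack_fc({}, 8)` returns a dict mapping each row to a column. `backtrack_fc({}, 3)` returns `'failure'`, since no 3-queens solution exists.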
```python
# min conflicts
nqueens = NQueens(8)
print('Minimum conflicts approach:', nqueens.min_conflicts_solver())

# backtracking
print('Backtracking:', nqueens.backtrack({}))
```

```
Found solution in 427 iterations
Minimum conflicts approach: {0: 1, 1: 5, 2: 0, 3: 6, 4: 3, 5: 7, 6: 2, 7: 4}
Backtracking: {6: 2, 7: 7, 5: 0, 2: 4, 4: 5, 3: 1, 1: 6, 0: 3}
```

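Both solvers print their assignments as dicts, so a small independent check can confirm the outputs above. This sketch rebuilds the attack test from scratch and runs it on the two printed results. The function name `is_valid_solution` is hypothetical.

```python
def is_valid_solution(assignment, n):
    """True if `assignment` ({row: col}) places one queen per row with no two attacking."""
    if sorted(assignment) != list(range(n)):
        return False  # every row 0..n-1 must appear exactly once
    queens = list(assignment.items())
    for i, (r1, c1) in enumerate(queens):
        for r2, c2 in queens[i + 1:]:
            # same column or same diagonal means the two queens attack each other
            if c1 == c2 or abs(r1 - r2) == abs(c1 - c2):
                return False
    return True


# the two assignments printed by the cell above
min_conflicts_result = {0: 1, 1: 5, 2: 0, 3: 6, 4: 3, 5: 7, 6: 2, 7: 4}
backtracking_result = {6: 2, 7: 7, 5: 0, 2: 4, 4: 5, 3: 1, 1: 6, 0: 3}
print(is_valid_solution(min_conflicts_result, 8),
      is_valid_solution(backtracking_result, 8))  # prints: True True
```

The backtracking result lists its rows out of order, most likely because Python dicts keep insertion order and the MRV heuristic does not assign rows from top to bottom.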

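The min-conflicts approach is very effective on this kind of problem because solutions are densely distributed throughout the state space. As a result, a solution is usually reachable from a random complete assignment through a sequence of local repairs.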

**More things to try:**
- Try larger values of N (e.g. by implementing the solver in Java or Cython).
- Compare the min-conflicts solver with search-based methods, including backtracking, in terms of both memory use and execution time.
- Random vs. greedy initialization: does greedy initialization help the solver arrive at a solution more quickly?
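The last question can be explored with a quick experiment. The sketch below assumes one common meaning of greedy initialization: place queens row by row, each in a column with the fewest conflicts with the queens already placed, breaking ties at random. The notebook leaves `initialization='greedy'` as a TODO, so this is not its implementation. All function names are hypothetical, and the step counts it prints will vary with `n`, `trials`, and `seed`.

```python
import random
import statistics


def attackers(cols, row, col, rows):
    """Number of queens in `rows` (other than `row` itself) attacking square (row, col)."""
    return sum(
        1 for r in rows
        if r != row and (cols[r] == col or abs(r - row) == abs(cols[r] - col))
    )


def random_init(n, rng):
    """Place each queen in a uniformly random column."""
    return [rng.randrange(n) for _ in range(n)]


def greedy_init(n, rng):
    """Place queens row by row, each in a column with the fewest conflicts so far."""
    cols = [0] * n
    for row in range(n):
        # only rows 0..row-1 hold queens at this point
        scores = [attackers(cols, row, c, range(row)) for c in range(n)]
        best = min(scores)
        cols[row] = rng.choice([c for c, s in enumerate(scores) if s == best])
    return cols


def min_conflicts_steps(cols, rng, max_steps=10_000):
    """Repair `cols` in place with min-conflicts; return the steps used, or None."""
    n = len(cols)
    for step in range(max_steps):
        conflicted = [r for r in range(n) if attackers(cols, r, cols[r], range(n)) > 0]
        if not conflicted:
            return step
        row = rng.choice(conflicted)
        scores = [attackers(cols, row, c, range(n)) for c in range(n)]
        best = min(scores)
        cols[row] = rng.choice([c for c, s in enumerate(scores) if s == best])
    return None


def compare(n=32, trials=20, seed=0):
    """Print success counts and mean repair steps for random vs. greedy starts."""
    rng = random.Random(seed)
    for name, init in [('random', random_init), ('greedy', greedy_init)]:
        steps = [min_conflicts_steps(init(n, rng), rng) for _ in range(trials)]
        solved = [s for s in steps if s is not None]
        mean = statistics.mean(solved) if solved else float('nan')
        print(f'{name}: solved {len(solved)}/{trials}, mean steps {mean:.1f}')
```

Calling `compare()` prints how many runs each initialization solved, along with their mean number of repair steps. One way to extend it toward the second bullet is to time each run with `time.perf_counter()`.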