### Imports

In [48]:
import os
import subprocess
from itertools import combinations
from typing import List, NoReturn

### Types

In [49]:
# Type aliases shared by the whole notebook.
variable = int  # a DIMACS variable number (cells start at 1); its negation is written -variable
list_of_variable = List[variable]  # e.g. the cells of one line, column or zone
clauses = List[variable]  # ONE clause: a disjunction of literals (despite the plural name)
set_of_clauses = List[clauses]  # a CNF formula: conjunction of its clauses (a list, not a set)

### Logic functions

In [60]:
def get_CNF_n_true_among(values: list_of_variable, n: int) -> set_of_clauses:
    """
    Returns a CNF satisfied exactly when ``n`` of the given variables are true

    The classic binomial encoding is used, split into two halves:

    * "at most n": any ``n + 1`` variables must contain a false one, hence one
      clause ``(-a OR -b OR ...)`` per ``(n + 1)``-combination;
    * "at least n": any ``len(values) - n + 1`` variables must contain a true
      one, hence one clause ``(a OR b OR ...)`` per such combination.

    This produces C(m, n+1) + C(m, m-n+1) clauses (m = ``len(values)``)
    instead of the 2**m - C(m, n) clauses obtained by forbidding every wrong
    assignment one by one (e.g. 130 instead of 1023 for m = 10, n = 2).

    Parameters
    ----------
    values : list_of_variable
        List of logic variables (positive DIMACS literals)
    n : int
        Number of variables that must be true

    Returns
    -------
    set_of_clauses
        The "at most n" clauses followed by the "at least n" clauses.

    Raises
    ------
    ValueError
        If ``n`` is negative or greater than ``len(values)``.
    """
    if n < 0:
        raise ValueError("The number of true variables must be non-negative.")
    if len(values) < n:
        raise ValueError(f"The size of the list must be at least {n}.")

    # No n + 1 variables may be true simultaneously.
    at_most: set_of_clauses = [[-k for k in cmb] for cmb in combinations(values, n + 1)]
    # No len(values) - n + 1 variables may be false simultaneously.
    at_least: set_of_clauses = [list(cmb) for cmb in combinations(values, len(values) - n + 1)]

    return at_most + at_least


def get_CNF_no_true_among_neighbor(var: variable, neighbor: list_of_variable) -> set_of_clauses:
    """
    Returns the clauses forbidding ``var`` to be true together with any neighbor

    One binary clause ``(-var OR -v)`` is produced per neighbor ``v``: as soon
    as ``var`` is true, every neighbor is forced to false.

    Parameters
    ----------
    var : variable
        Central variable
    neighbor : list_of_variable
        Variables adjacent to ``var``

    Returns
    -------
    set_of_clauses
        One two-literal clause per neighbor, in the order of ``neighbor``.
    """
    negated_center = -var
    return [[negated_center, -adjacent] for adjacent in neighbor]

### IO functions

In [61]:
class LireGrille:
    """
    Starbattle grid read from a text file ("lire grille" = "read grid")

    The file holds one grid row per line; each cell contains the number
    (starting at 1) of the zone it belongs to, cells being separated by
    whitespace. The grid must be square. Cells are numbered row by row as
    DIMACS variables: the cell in row ``r`` and column ``c`` (both 1-based)
    is variable ``(r - 1) * dim + c``.
    """

    def __init__(self, path: str):
        """
        Read the grid of the Starbattle game file given in parameter

        Parameters
        ----------
        path : str
            Path of the grid file

        Raises
        ------
        ValueError
            If the file contains no row, or if the grid is not square.
        """
        with open(path, "r") as f:  # open the file and read it
            self.data = f.read()

        # str.split() without argument copes with repeated or trailing spaces,
        # tabs and "\r\n" endings (split(" ") produced empty cells that made
        # int() fail in get_zone); blank lines are skipped.
        self.grid: List[List[str]] = [line.split() for line in self.data.splitlines() if line.strip()]

        if not self.grid:
            raise ValueError("The grid file is empty.")

        self.dim: int = len(self.grid[0])  # get the number of columns in the grid

        # Every method below assumes a dim x dim grid: reject anything else
        # instead of silently producing wrong clauses.
        if len(self.grid) != self.dim or any(len(row) != self.dim for row in self.grid):
            raise ValueError(f"The grid must be square ({self.dim} x {self.dim}).")

    def get_line(self, line_number: int) -> list_of_variable:
        """
        Return all variables of the line (row) given in parameter (1-based)
        """
        valeur_deb = (line_number - 1) * self.dim + 1  # variable of the first cell of the row

        return list(range(valeur_deb, valeur_deb + self.dim))

    def get_column(self, col_number: int) -> list_of_variable:
        """
        Return all variables of the column given in parameter (1-based)
        """
        # One variable every `dim` steps, starting at the top cell of the column.
        return list(range(col_number, self.dim ** 2 + 1, self.dim))

    def get_zone(self, zone_number: int) -> list_of_variable:
        """
        Return all variables of the zone given in parameter, in row-major order
        """
        zone: list_of_variable = []
        k: int = 1  # variable of the current cell
        for line in self.grid:
            for elt in line:
                if int(elt) == zone_number:
                    zone.append(k)
                k += 1
        return zone

    def get_neighbor(self, var: int) -> list_of_variable:
        """
        Return the variables of the cells touching the cell of ``var``

        Diagonal cells count as neighbors. They are listed in the order right,
        left, up, down, up-left, up-right, down-left, down-right, skipping the
        ones that fall outside the grid.
        """
        row, col = divmod(var - 1, self.dim)  # 0-based coordinates of var
        offsets = [(0, 1), (0, -1), (-1, 0), (1, 0),
                   (-1, -1), (-1, 1), (1, -1), (1, 1)]

        neighbors: list_of_variable = []
        for d_row, d_col in offsets:
            r, c = row + d_row, col + d_col
            if 0 <= r < self.dim and 0 <= c < self.dim:
                neighbors.append(r * self.dim + c + 1)
        return neighbors

    def get_dimension(self) -> int:
        """
        Return the number of lines/column/zone in the grid
        """
        return self.dim


### Parent functions

In [127]:
def get_battle_star_clauses(grid: LireGrille, n: int) -> set_of_clauses:
    """
    Returns the complete CNF encoding of a Starbattle puzzle

    For every index ``i`` in ``1..dimension``, zone ``i``, line ``i`` and
    column ``i`` must each contain exactly ``n`` stars. Then, for every cell,
    a star excludes a star on any cell returned by ``grid.get_neighbor``.

    Parameters
    ----------
    grid : LireGrille
        Grid of the Starbattle game
    n : int
        Number of stars per line/column/zone

    Returns
    -------
    set_of_clauses
        The zone/line/column constraints (grouped per index), followed by the
        neighborhood constraints of every cell.
    """
    size = grid.get_dimension()
    group_getters = (grid.get_zone, grid.get_line, grid.get_column)
    cnf: set_of_clauses = []

    for index in range(1, size + 1):
        for getter in group_getters:
            cnf.extend(get_CNF_n_true_among(getter(index), n))

    for cell in range(1, size * size + 1):
        cnf.extend(get_CNF_no_true_among_neighbor(cell, grid.get_neighbor(cell)))

    return cnf


def write_DIMAC_clause(path: str, clauses_list: set_of_clauses, nb_of_variable: int) -> None:
    """
    Write the clauses in DIMACS CNF format in the file given in parameter

    The file starts with the header ``p cnf <nb_of_variable> <nb_of_clauses>``,
    followed by one clause per line: its literals separated by spaces and
    terminated by ``0``. An existing file at ``path`` is overwritten.

    Parameters
    ----------
    path : str
        Path of the file where the clauses will be written
    clauses_list : set_of_clauses
        Clauses to write in DIMACS format
    nb_of_variable : int
        Number of variables declared in the header
    """
    with open(path, 'w') as f:
        f.write(f"p cnf {nb_of_variable} {len(clauses_list)}\n")
        # One string per clause instead of one write per literal; the output
        # is unchanged, e.g. "1 -2 3 0\n" (and "0\n" for an empty clause).
        f.writelines("".join(f"{var} " for var in clause) + "0\n" for clause in clauses_list)


def convert_NSAT_to_3SAT(path: str) -> None:
    """
    Convert, in place, the DIMACS CNF file given in parameter to a 3SAT file

    The result is equisatisfiable with the input and every non-empty clause
    has exactly three literals. Fresh variables are numbered after the largest
    variable already in use:

    * ``(x)`` becomes ``(x y1 y2) (x y1 -y2) (x -y1 y2) (x -y1 -y2)``;
    * ``(x1 x2)`` becomes ``(x1 x2 y) (x1 x2 -y)``;
    * a longer clause is shrunk by replacing each pair of literals ``a b`` by
      a fresh ``y`` together with ``(a b -y)`` (i.e. ``y`` implies ``a OR b``)
      until at most three literals remain; a 2-literal remainder is then
      padded like above.

    An empty clause is kept as is (it makes the formula unsatisfiable).

    Parameters
    ----------
    path : str
        Path of the NSAT file; it is overwritten with the 3SAT clauses

    Raises
    ------
    ValueError
        If the file has no ``p cnf`` header.
    """
    with open(path, 'r') as f:
        lines = f.readlines()

    # --- Parse: skip comments, read the header, collect every literal. ---
    maxvalue: int = 0
    header_found = False
    literals: list_of_variable = []
    for line in lines:
        tokens = line.split()
        if not tokens or tokens[0].startswith("c"):  # blank line or DIMACS comment
            continue
        if tokens[0] == "p":  # "p cnf <nb_of_variables> <nb_of_clauses>"
            maxvalue = int(tokens[2])
            header_found = True
            continue
        literals.extend(map(int, tokens))

    if not header_found:
        raise ValueError("The file has no 'p cnf' header.")

    # Clauses are "0"-terminated and may span several lines in DIMACS.
    cnf: set_of_clauses = []
    current: clauses = []
    for lit in literals:
        if lit == 0:
            cnf.append(current)
            current = []
        else:
            current.append(lit)
    if current:  # tolerate a last clause missing its terminating 0
        cnf.append(current)

    # Auxiliary variables must never collide with existing ones, even if the
    # header under-declares the variable count.
    maxvalue = max([maxvalue] + [abs(lit) for clause in cnf for lit in clause])

    # --- Convert every clause to clauses of exactly three literals. ---
    new_cnf: set_of_clauses = []
    for clause in cnf:
        while len(clause) > 3:
            reduced: clauses = []
            for i in range(0, len(clause) - 1, 2):
                maxvalue += 1
                new_cnf.append([clause[i], clause[i + 1], -maxvalue])  # y -> (a OR b)
                reduced.append(maxvalue)
            if len(clause) % 2:  # odd length: the last literal is carried over
                reduced.append(clause[-1])
            clause = reduced

        if len(clause) == 1:
            maxvalue += 2
            x1, y1, y2 = clause[0], maxvalue, maxvalue - 1
            new_cnf.extend([[x1, y1, y2],
                            [x1, y1, -y2],
                            [x1, -y1, y2],
                            [x1, -y1, -y2]])
        elif len(clause) == 2:
            # Also reached by long clauses reduced to two literals (e.g. length 4).
            maxvalue += 1
            x1, x2, y1 = clause[0], clause[1], maxvalue
            new_cnf.extend([[x1, x2, y1],
                            [x1, x2, -y1]])
        else:  # exactly three literals, or the empty clause
            new_cnf.append(clause)

    # Opening in 'w' mode truncates the file, so no prior removal is needed.
    write_DIMAC_clause(path, new_cnf, maxvalue)


def run_sat_solver(path: str, n: int, convert_to_3sat: bool = False) -> list_of_variable:
    """
    Run the SAT solver on the clauses of the Starbattle game and return the result

    The clauses are written next to the grid file, in ``<base>.dimac``, and
    minisat writes its answer to ``<base>.solved``. ``<base>`` is ``path``
    without its extension. On Windows, minisat is run through WSL.

    Parameters
    ----------
    path : str
        Path of the grid file
    n : int
        Number of star by line/column/zone
    convert_to_3sat: bool
        If True, convert the clauses to 3SAT clauses

    Returns
    -------
    list_of_variable
        List of the true variables of the grid (if the grid is solved)

    Raises
    ------
    ValueError
        If the problem is UNSAT, if the solver cannot be launched, or if it
        exits with any status other than SAT.
    """
    grid: LireGrille = LireGrille(path)
    nb_of_variable: int = grid.get_dimension() ** 2

    # splitext only strips the extension ("./dir/test1.txt" -> "./dir/test1");
    # splitting on every "." broke relative or dotted paths.
    base, _ = os.path.splitext(path)
    dimac_path = base + ".dimac"
    solved_path = base + ".solved"

    write_DIMAC_clause(dimac_path, get_battle_star_clauses(grid, n), nb_of_variable)
    if convert_to_3sat:
        convert_NSAT_to_3SAT(dimac_path)

    # An argument list (no shell) passes paths with spaces/metacharacters verbatim.
    command = ["minisat", dimac_path, solved_path]
    if os.name == 'nt':
        command = ["wsl", "--"] + command
    try:
        # returncode is the solver's real exit status on every OS, whereas
        # os.system returns an encoded wait status on POSIX (exit 10 -> 2560).
        ret_state = subprocess.run(command).returncode
    except OSError as err:
        raise ValueError(f"The SAT solver could not be launched: {err}") from err

    if ret_state == 20:
        raise ValueError("The SAT solver returned a UNSAT problem")
    if ret_state != 10:
        raise ValueError("The SAT solver returned an error")

    with open(solved_path, "r") as f:
        lines = f.readlines()

    # The model is read from the second line of the result file. Keep only the
    # positive literals that are grid cells (3SAT auxiliaries are > nb_of_variable).
    return [var for var in map(int, lines[1].split()) if 0 < var <= nb_of_variable]

In [131]:
%%time
run_sat_solver("Exemples/test1.txt", 2, False)

CPU times: total: 484 ms
Wall time: 5.59 s


[6, 8, 10, 12, 23, 25, 30, 36, 37, 41, 52, 54, 56, 58, 69, 71, 74, 76]