In [1]:
import time
import random

In [2]:
class Statistics:
    """
    Class used to store the various statistics
    measured while solving the SAT problem.
    """
    
    '''
    Constructor for the Statistics class
    '''    
    def __init__(self):
        self._input_file = ""
        self._result = ""
        self._output_file = ""
        
        self._num_vars = 0
        self._num_orig_clauses = 0
        self._num_clauses = 0    
        self._num_learned_clauses = 0 # Number of clauses learned by the solver
        self._num_decisions = 0
        self._num_implications = 0
        
        self._start_time = 0
        self._read_time = 0
        self._complete_time = 0
    
    def print_stats(self):
        print("=========================== STATISTICS ===============================")
        print("Solving formula from file: ",self._input_file)
        print("Vars:{}, Clauses:{} Stored Clauses:{}".format(str(self._num_vars),str(self._num_orig_clauses),str(self._num_clauses)))
        print("Input Reading Time: ",self._read_time - self._start_time)
        print("-------------------------------")
        print("Learned clauses: ",self._num_learned_clauses)
        print("Decisions made: ",self._num_decisions)
        print("Implications made: ",self._num_implications)
        print("Time taken: ",self._complete_time-self._start_time)
        print("-------------------------------")
        print("RESULT: ",self._result)
        print("======================================================================")  

In [3]:
class AssignedNode:
    """
    Class used to store the information
    about the variables being assigned
    """
    
    '''
    Constructor for the AssignedNode class
    '''
    def __init__(self,var,value,level,clause):        
        self.var = var       # Variable that is assigned 
        self.value = value   # Value assigned to the variable (True/False)
        self.level = level   # Level at which the variable is assigned
        
        # The index of the clause which implies the variable var
        # if var is assigned through Implication,
        # None if var is decided
        self.clause = clause 
        
        # Index at which a node is placed in the assignment stack
        # Initially it is -1 when node is created and has to be
        # updated when pushed in assignment_stack.
        self.index = -1
        
    def __str__(self):
        return str(self.var)+" "+str(self.value)+" "+str(self.level)+" "+str(self.clause)

In [4]:
class SAT:
    """
    Class to store the data structures that are maintained while
    solving the SAT problem and the methods to solve the SAT
    """
    
    '''
    Constructor for the SAT class
    '''
    def __init__(self,to_log):
        self._num_clauses = 0 # Number of clauses 
        self._num_vars = 0    # Number of variables 
        self._level = 0       # Decision level (level at which the solver is in backtracking tree)
        self._clauses = []
        self._clauses_watched_by_l = {}
        self._literals_watching_c = {}
        
        self._unassigned_vars = set()
        self._variable_to_assignment_nodes = {}
        self._assignment_stack = []
        
        self._is_log = to_log
        
        # Stats object
        self.stats = Statistics()

In [5]:
def is_negative_literal(self,literal):
    return literal > self._num_vars

SAT._is_negative_literal = is_negative_literal    

In [6]:
def get_var_from_literal(self,literal):
    if self._is_negative_literal(literal):
        return literal - self._num_vars
    
    return literal

SAT._get_var_from_literal = get_var_from_literal

In [7]:
def add_clause(self,clause):
    clause = clause[:-1]

    #Remove duplicates
    clause = list(set(clause))
    
    if len(clause)==1:
        lit = clause[0]
        value_to_set = True
        if lit[0]=='-':
            value_to_set = False
            var = int(lit[1:])
        else:
            var = int(lit)
            
        if var not in self._variable_to_assignment_nodes:
            self.stats._num_implications += 1
            node = AssignedNode(var,value_to_set,0,None)
            self._variable_to_assignment_nodes[var] = node
            self._assignment_stack.append(node)
            node.index = len(self._assignment_stack)-1
            self._unassigned_vars.remove(var)
            if self._is_log:
                print("Implied(unary): ",node)
        else:
            node = self._variable_to_assignment_nodes[var]
            if node.value != value_to_set:
                #### HANDLE
                print("UNSAT")
        return
        
    
    clause_with_literals = []
    
    for lit in clause:
        if lit[0]=='-':
            var = int(lit[1:])
            clause_with_literals.append(var+self._num_vars)
        else:
            var = int(lit)
            clause_with_literals.append(var)
    
    clause_id = self._num_clauses
    self._clauses.append(clause_with_literals)
    self._num_clauses += 1
    
    # setup watchers
    watch_literal1 = clause_with_literals[0]
    watch_literal2 = clause_with_literals[1]
    
    self._literals_watching_c[clause_id] = [watch_literal1,watch_literal2]
    
    self._clauses_watched_by_l.setdefault(watch_literal1,[]).append(clause_id)
    self._clauses_watched_by_l.setdefault(watch_literal2,[]).append(clause_id)
    

SAT._add_clause = add_clause

In [8]:
def read_dimacs_cnf_file(self,cnf_filename):    
    cnf_file = open(cnf_filename,"r")
    for line in cnf_file.readlines():
        line = line.rstrip()
        line = line.split()
        
        first_word = line[0]
        
        # Comments
        if first_word == "c":
            continue
        elif first_word == "p":
            self._num_vars = int(line[2])
            self._unassigned_vars = {i for i in range(1,self._num_vars+1)}
            self.stats._num_orig_clauses = int(line[3])
        else:
            self._add_clause(line)        
    cnf_file.close()

SAT._read_dimacs_cnf_file = read_dimacs_cnf_file

In [9]:
def decide(self): 
    
    num_unassigned_vars = len(self._unassigned_vars)
    
    if num_unassigned_vars == 0:
        return -1
    
    var = self._unassigned_vars.pop()
    value_to_set = True
    
#     var = random.choice(tuple(self._unassigned_vars))
#     value_to_set = random.choice([True,False])
#     self._unassigned_vars.remove(var)
    
    self._level += 1
    new_node = AssignedNode(var,value_to_set,self._level,None)
    self._variable_to_assignment_nodes[var] = new_node
    self._assignment_stack.append(new_node)
    new_node.index = len(self._assignment_stack)-1
    self.stats._num_decisions += 1
    if self._is_log:
        print("Choosen decision: ",end="")
        print(new_node)
    return var
    return -1

SAT._decide = decide

In [10]:
def boolean_constraint_propogation(self,is_first_time):
    last_assignment_pointer = len(self._assignment_stack)-1
    
    if is_first_time:
        last_assignment_pointer = 0
        
    while last_assignment_pointer < len(self._assignment_stack):
        last_assigned_node = self._assignment_stack[last_assignment_pointer]
        
        if last_assigned_node.value == True:
            literal_that_is_falsed = last_assigned_node.var + self._num_vars
        else:
            literal_that_is_falsed = last_assigned_node.var
        
        # Now change watch literals for all clauses watched by literal_that_is_falsed
        itr = 0
        clauses_watched_by_falsed_literal = self._clauses_watched_by_l.setdefault(literal_that_is_falsed,[]).copy()
        while itr < len(clauses_watched_by_falsed_literal):
            clause_id = clauses_watched_by_falsed_literal[itr]
            watch_list_of_clause = self._literals_watching_c[clause_id]
            
            other_watch_literal = watch_list_of_clause[0]
            if other_watch_literal == literal_that_is_falsed:
                other_watch_literal = watch_list_of_clause[1]
            
            other_watch_var = self._get_var_from_literal(other_watch_literal)
            is_negative_other = self._is_negative_literal(other_watch_literal)
            
            if other_watch_var in self._variable_to_assignment_nodes:
                value_assgned = self._variable_to_assignment_nodes[other_watch_var].value
                if (is_negative_other and value_assgned == False) or (not is_negative_other and value_assgned == True):
                    itr += 1
                    continue
            
            new_literal_to_watch = -1
            clause = self._clauses[clause_id]
            for lit in clause:
                if lit not in watch_list_of_clause:
                    var_of_lit = self._get_var_from_literal(lit)
                    if var_of_lit not in self._variable_to_assignment_nodes:
                        new_literal_to_watch = lit
                        break
                    else:
                        node = self._variable_to_assignment_nodes[var_of_lit]
                        is_negative = self._is_negative_literal(lit)
                        if (is_negative and node.value == False) or (not is_negative and node.value == True):
                            new_literal_to_watch = lit
                            break
            
            if new_literal_to_watch != -1:
                # We have got a new watcher
                
                self._literals_watching_c[clause_id].remove(literal_that_is_falsed)
                self._literals_watching_c[clause_id].append(new_literal_to_watch)
                
                self._clauses_watched_by_l.setdefault(literal_that_is_falsed,[]).remove(clause_id)
                self._clauses_watched_by_l.setdefault(new_literal_to_watch,[]).append(clause_id)
                
            else:
                # We got no other watcher, so other_watch_literal has to be true
                if other_watch_var not in self._variable_to_assignment_nodes:
                    value_to_set = not is_negative_other
                    assign_var = AssignedNode(other_watch_var,value_to_set,self._level,clause_id)
                    self._variable_to_assignment_nodes[other_watch_var] = assign_var
                    self._assignment_stack.append(assign_var)
                    assign_var.index = len(self._assignment_stack)-1
                    self._unassigned_vars.remove(assign_var.var)
                    self.stats._num_implications += 1
                    if self._is_log:
                        print("Implied decision:", end="")
                        print(assign_var)
                else:
                    # Conflict is detected, create a conflict node and push it to assignment stack
                    # conflic node needed to store which clause caused the conflict and the
                    # level at which the conflict occured
                    conflict_node = AssignedNode(None,None,self._level,clause_id)
                    self._assignment_stack.append(conflict_node)
                    conflict_node.index = len(self._assignment_stack)-1
                    if self._is_log:
                        print("CONFLICT")
                    return "CONFLICT"
            
            itr += 1
        last_assignment_pointer += 1
    
    return "NO_CONFLICT"

SAT._boolean_constraint_propogation = boolean_constraint_propogation

In [11]:
def binary_resolute(self,clause1,clause2,var):    
    full_clause = clause1 + clause2
    full_clause.remove(var)
    full_clause.remove(var+self._num_vars)
    return list(set(full_clause))

SAT._binary_resolute = binary_resolute

In [12]:
def is_valid_clause(self,clause,level):
    counter = 0
    maxi = -1
    cand = -1
    for lit in clause:
        var = self._get_var_from_literal(lit)
        node = self._variable_to_assignment_nodes[var]
        if node.level == level:
            counter += 1
            if node.index > maxi:
                maxi = node.index
                cand = node
    return counter == 1,cand

SAT._is_valid_clause = is_valid_clause

In [13]:
def get_backtrack_level(self,conflict_clause,conflict_level):
    maximum_level_before_conflict_level = -1
    literal_at_conflict_level = -1
    
    for lit in conflict_clause:
        var = self._get_var_from_literal(lit)
        assigned_node = self._variable_to_assignment_nodes[var]
        
        if assigned_node.level == conflict_level:
            literal_at_conflict_level = lit
        else:
            if assigned_node.level > maximum_level_before_conflict_level:
                maximum_level_before_conflict_level = assigned_node.level
    
    return maximum_level_before_conflict_level, literal_at_conflict_level

SAT._get_backtrack_level = get_backtrack_level

In [14]:
def analyze_conflict(self):
    # assert that last node is a conflict node

    assigment_stack_pointer = len(self._assignment_stack)-1
    conflict_node = self._assignment_stack[assigment_stack_pointer]
    conflict_level = conflict_node.level
    conflict_clause = self._clauses[conflict_node.clause]
    
    # pop the conflict node
    self._assignment_stack.pop()
    
   
    if self._is_log:
        print("Analyzing Conflict in the node: ",end="")
        print(conflict_node)
    
    if conflict_level == 0:
        return -1,None
    
#     node = conflict_node
    while True:
        is_nice,node_to_use = self._is_valid_clause(conflict_clause,conflict_level)
        if is_nice:
            break
        prev_assigned_node = node_to_use
        if self._is_log:
            print("Clause: ",conflict_clause)
            print("Node_to_use ",node_to_use)
        clause = self._clauses[prev_assigned_node.clause]
        var = prev_assigned_node.var
        conflict_clause = self._binary_resolute(conflict_clause,clause,var)
    
    if len(conflict_clause) > 1:
        # Ad learned clause
        self.stats._num_learned_clauses += 1
        clause_id = self._num_clauses
        self._num_clauses += 1
        self._clauses.append(conflict_clause)
        self._clauses_watched_by_l.setdefault(conflict_clause[0],[]).append(clause_id)
        self._clauses_watched_by_l.setdefault(conflict_clause[1],[]).append(clause_id)
        self._literals_watching_c[clause_id] = [conflict_clause[0],conflict_clause[1]]
        
        backtrack_level, conflict_level_literal = self._get_backtrack_level(conflict_clause,conflict_level)
        
        conflict_level_var = self._get_var_from_literal(conflict_level_literal)
        is_negative_conflict_lit = self._is_negative_literal(conflict_level_literal)
        
        value_to_set = True
        if is_negative_conflict_lit:
            value_to_set = False
        
        ###
        maxi = -1
        cand = -1
        
        for lit in conflict_clause:
            if lit != conflict_level_literal:
                var_lit = self._get_var_from_literal(lit)
                assigned_node = self._variable_to_assignment_nodes[var_lit]
                
                if assigned_node.index > maxi:
                    maxi = assigned_node.index
                    cand = assigned_node
        
        ###
        
        node = AssignedNode(conflict_level_var,value_to_set,backtrack_level,clause_id)
        return backtrack_level,node
    else:
        literal = conflict_clause[0]
        var = self._get_var_from_literal(literal)
        is_negative_literal = self._is_negative_literal(literal)
        assigned_node = self._variable_to_assignment_nodes[var]
        if assigned_node.level >= 1:
            backtrack_level = 0
            value_to_set = True
            if is_negative_literal:
                value_to_set = False
            node = AssignedNode(var,value_to_set,backtrack_level,None)
            return backtrack_level,node
        else:
            return -1,None
    
        
    
SAT._analyze_conflict = analyze_conflict

In [15]:
def backtrack(self,backtrack_level,node_to_add):
    # set level
    self._level = backtrack_level
    
    # Remove all nodes at level gr8r than btrack level
    itr = len(self._assignment_stack)-1
    while True:
        if itr<0:
            break
        if self._assignment_stack[itr].level <= backtrack_level:                
            break
        else:
            self._unassigned_vars.add(self._assignment_stack[itr].var)
            del self._variable_to_assignment_nodes[self._assignment_stack[itr].var]
            self._assignment_stack.pop(itr)
            itr -= 1
    
    # Add the node to the end
    self._variable_to_assignment_nodes[node_to_add.var] = node_to_add
    self._assignment_stack.append(node_to_add)
    node_to_add.index = len(self._assignment_stack)-1
    self._unassigned_vars.remove(node_to_add.var)
    self.stats._num_implications += 1
    
SAT._backtrack = backtrack

In [16]:
def solve(self,cnf_filename):
    '''
    Solve the SAT problem instance present in
    the passed filename.
    
    Parameters:
        cnf_filename: Name of the file having the SAT formula to be solved
    '''
    self.stats._input_file = cnf_filename
    self.stats._start_time = time.time()
    
    self._read_dimacs_cnf_file(cnf_filename)
    
    self.stats._read_time = time.time()
    self.stats._num_vars = self._num_vars
    self.stats._num_clauses = self._num_clauses
    
    first_time = True
    while True:     
        while self._boolean_constraint_propogation(first_time) == "CONFLICT":
            first_time = False
            backtrack_level, node_to_add = self._analyze_conflict()
            if backtrack_level == -1:
                print("UNSAT")
                self.stats._result = "UNSAT"
                self.stats._complete_time = time.time()
                return
            self._backtrack(backtrack_level,node_to_add)
        first_time = False
        var_decided = self._decide()
        if var_decided == -1:
            print("SAT")
            self.stats._result = "SAT"
            self.stats._complete_time = time.time()
            return
            

SAT.solve = solve

In [17]:
sat=SAT(False)
sat.solve("bmc-5.cnf")

SAT


In [18]:
#### Inspections
sat.stats.print_stats()

Solving formula from file:  bmc-5.cnf
Vars:9396, Clauses:41207 Stored Clauses:41134
Input Reading Time:  0.181243896484375
-------------------------------
Learned clauses:  1308
Decisions made:  17285
Implications made:  303068
Time taken:  2.8075644969940186
-------------------------------
RESULT:  SAT
