In [211]:
import os
import time
import random
import threading

In [212]:
# function to read clauses from cnf file "fname".

def read_clause (fname):
    f = open(fname,"r")
    read_count = 0
    clauses = []
    readflag = 1                 # =1, indicate new clause. The definition of a clause may extend beyond a single line of text as per standard format 
    for x in f :                 # for each line
        if x[0] != 'c' :         #if not comment
            if x[0] == 'p' :     #if 'p' i.e problem def. line
                x = x.split()
                num_var = int(x[2])      #read number of variables and clauses
                num_claus = int(x[3])
            else :                     #if clause
                if readflag == 1 :     #new clause starts
                    newclause = []
                x = x.split()
                for j in x :
                    i = int(j)
                    readflag = 0
                    if i == 0 :          #if 0,i.e clause end, push this clause to list clauses and start a new clause
                        clauses.append(newclause)
                        readflag = 1
                        read_count += 1
                        break
                    if -i in newclause:
                        readflag = 1
                        read_count += 1
                        break
                    if i not in newclause:
                        newclause.append(i)       #while Integer not zero append to current clause
                        
                    #Following conditions are for checking Inconsistent or Invalid input cnf file
                    if i > num_var :
                        print("Illegal variable number "+str(i))
                        return 0,num_var,num_claus
                    if read_count >= num_claus :
                        print("Number of clauses exceeds definition")
                        return 0,num_var,num_claus
    return 1,num_var,num_claus,clauses

In [213]:
# Following functions are for hueristic VSIDS

# Initiatilaztion  : Generated counter for number of times a literal appears
def VSIDS_init(clauses,num_var):
    counter = {}
    for x in range(-num_var,num_var+1):
        counter[x]=0
    for clause in clauses:
        for literal in clause:
            counter[literal] += 1
    return counter

# conflict : Incerements counter of literalts in conflict clause to increase there chances of getting selected
def VSIDS_conflict(counter,conflictClause):
    for literal in conflictClause:
        counter[literal]+=1
    return counter

# decay : Counter is reduced by 5% for all literals at each conflict
def VSIDS_decay(counter,num_var):
    for i in range(-num_var,num_var+1):
        counter[i]=counter[i]*95/100
    return counter

# decide : Picks a Variable NOT yet in M based on max counter value
def VSIDS_decide(counter,M,num_var):
    max=0
    var=0
    for x in range(-num_var,num_var+1):
        if counter[x]>=max and x not in M and -x not in M:
                max=counter[x]
                var=x
    return var

In [214]:
# NOTE : BCP and Unit Propogation in ONLY used in beginning to get rid of unit clauses and Their implications

def bcp(clauses, literal):                    #Boolean Constant Propagation on Literal
    new_claus_set = [x[:] for x in clauses]   #Using SLicing Technique: Fastest available in Python
    for x in reversed(new_claus_set):
        if literal in x:                      #if clause satified ,
            new_claus_set.remove(x)                    #Remove that clause
        if -literal in x:                     #if -literal present , remaining should satisfy . Hence,
            x.remove(-literal)                         #Remove -literal from that clause
            if not x:                         #if this makes a clause Empty , UNSAT
                return -1
    return new_claus_set



def unit_propagation(clauses):               # Propogate Unit Clauses and add implications to M
    assignment = []
    flag=1
    while flag!=0:                           #till Implications are found
        flag=0
        for x in clauses:                    #for each clause
            if len(x) == 1 :                 # if UNIT clause , propagate and add to assignment
                unit=x[0]
                clauses = bcp(clauses, unit) 
                assignment += [unit]
                flag=1
            if clauses == -1:                #if UNSAT after propogate, return -1
                return -1, []
            if not clauses:                   
                return clauses, assignment
    return clauses, assignment

In [215]:
def create_watchList(clauses,M,num_var):          # Create the 2-literal watch data structure
    literal_watch = {}                    # Will contain the main Literal-> Clause number mapping
    clauses_literal_watched = []          # The reverse,i.e. Clause-> Literal mapping
    for i in range (-num_var,num_var+1):
        literal_watch[i] = []
    for i in range (0,len(clauses)):      #for each clause pick two literals
        newc = []
        first = 0
        for j in range(0,len(clauses[i])):
            if clauses[i][j] not in M and first==0:
                A = clauses[i][j]
                first=1
                continue
            if clauses[i][j] not in M and first==1:
                B = clauses[i][j]
                break
        newc.append(A)
        newc.append(B)
        clauses_literal_watched.append(newc)       #add both to watched of that clause 
        literal_watch[A].append(i)                 #add clause to watch of those literals
        literal_watch[B].append(i)
    return literal_watch,clauses_literal_watched

In [216]:
# Function to propogate using 2-literal watch

def two_watch_propogate(clauses,literal_watch,clauses_literal_watched,M,variable): 
    prop_list = [variable]             # add current change to list of updates
    while len(prop_list) != 0 :        # while updates remain to propogate
        variable = prop_list.pop()     #pick one variable
        for affected_claus_num in reversed(literal_watch[-variable]) : #for all clauses in its watch list
            affected_claus = clauses[affected_claus_num][:]
            A = clauses_literal_watched[affected_claus_num][0]
            B = clauses_literal_watched[affected_claus_num][1]
            A_prev=A
            B_prev=B
            status,M,A,B,unit = check_status(affected_claus,M,A,B)     # check status of each clause
            if status == "Unit" :
                prop_list.append(unit)
                M.append(unit)                                         # if unit, add to updates
            elif status == "Unsatisfied":                              #if unsat, return conflict clause
                return affected_claus,literal_watch
                                                                       #else the clause is still unresolve, remove from current add to proper watch
            literal_watch [A_prev].remove(affected_claus_num)
            literal_watch [B_prev].remove(affected_claus_num)
            clauses_literal_watched[affected_claus_num][0] = A
            clauses_literal_watched[affected_claus_num][1] = B
            literal_watch [A].append(affected_claus_num)
            literal_watch [B].append(affected_claus_num)
            
    return -1,literal_watch

In [217]:
def check_status(clause,M,A,B):
    unit = 0
    if A in M or B in M:                   # if one watch satisfied, nothing to do 
        return "Satisied",M,A,B,unit
    sym=[]                                  # symbols not defined yet
    for literal in clause:                  # find symbols not defined
        if -literal not in M:
            sym.append(literal)
        if literal in M :
            if -A not in M :
                return "Satisied",M,A,literal,unit
            return "Satisied",M,literal,B,unit
    if len(sym) == 1:                              # if one such symbol -> Unit Clause
        return "Unit",M,A,B,sym[0]
    if len(sym) == 0:                              # if no such symbol -> Unsatisfied (conflict) clause
        return "Unsatisfied",M,A,B,unit
    else :
        return "Unresolved",M,sym[0],sym[1],unit   # else return two new unsatisfied variables to use for Literal_watch

In [218]:
def RandomRestart(M,back,decide_pos,probability,Restart_count):  
    if random.random() < probability :          # If Generated random probability less than current : RESTART
        M = back[:]
        decide_pos = []
        probability *= 0.5                      # Decay next Restart probability by 50%
        Restart_count += 1
        if probability < 0.001 :
            probability = 0.2
        if Restart_count > len(M) + 10:         #avoid restarts if already restarted many times
            probability=0
    return probability,Restart_count

In [219]:
def verify(M,clauses) :                  # Verify the Solution in M for SAT
    for c in clauses :                   # for each clause
        flag = 0
        for lit in c:
            if lit in M:                 # atleast one literal should be true
                flag = 1
                break
        if flag == 0:
            return False
    return True

In [220]:
def Analyze_Conflict(M, conflict,decide_pos):  #for simplicity : ALL DECISIONs made till now are a Learned Clause 
    learn = []
    for x in decide_pos:
        learn.append(-M[x])
    return learn

In [221]:
def all_vars_assigned(num_var ,M_len):        # Returns True if all variables already assigne , False otherwise
    if M_len >= num_var:
        return True
    return False

In [222]:
def assign(variable,M,decide_pos):             # Adds the decision literal to M and correponding update to decision level
    decide_pos.append(len(M))
    print(variable)
    M.append(variable)

In [223]:
def add_learned_clause_to(clauses,literal_watch,clauses_literal_watched,Learned_c,M):
    if len(Learned_c) == 0:
        return -1
    if len(Learned_c) == 1:             # if unit clause is learnt : add it as a decision 
        M.append(Learned_c[0])
        return 1,Learned_c[0]
    clauses.append(Learned_c)           # for others, add two literals A,B to literal watch data structure
    A = Learned_c[0]
    B = Learned_c[1]
    i = len(clauses)-1
    newc = []
    newc.append(A)
    newc.append(B)
    clauses_literal_watched.append(newc)
    literal_watch[A].append(i)
    literal_watch[B].append(i)
    return 0

In [224]:
def Backjump(M, dec_level, decide_pos,Imp_count):         #BackJump to decision level by deleting decisions from M and decision positions
    Imp_count = Imp_count + len(M) - len(decide_pos)
    if not decide_pos:
        return -1,-1,Imp_count
    dec_level = decide_pos.pop()
    literal = M[dec_level]
    del M[dec_level:]
    return 0,-literal,Imp_count

In [225]:
def progressBar(current, total, barLength = 20) :        # Print progress bar. Just to givee feel of work being done
    percent = float(current) * 100 / total
    arrow   = '-' * int(percent/100 * barLength - 1) + '>'
    spaces  = ' ' * (barLength - len(arrow))
    print('Progress (num_var:may backtrack): [%s%s] %d ' % (arrow, spaces, current), end='\r')

In [226]:
def MAIN(condition, var):         
    try:                                                             # to run file from directory test
        os.chdir("test")
    except:
        do_nothing_just_avoid_exception = 1
    #fname = input("Enter file name (without extension) :")           # Take input file name. Repeat till valid input
    #print()
    fname = "a.cnf"
    while (not os.path.isfile(fname)) or (not os.path.exists(fname)):
        fname = input("Error. File "+fname+" not found. Try again:")
        print()
        fname += ".cnf"

    startread = time.process_time()
    a,num_var,num_claus,clauses = read_clause(fname)                  # Read from input file
    endread = time.process_time()
    
    if a == 1 :                                                       # Status of reading input
        print("Successfully read the file")
        print("Vars : "+str(num_var)+" Clauses : "+str(num_claus))
        print("Read time :"+str(endread-startread)+" sec")
    else:
        print("Please resolve the errors")
        return
    
    print (" Solving ...")
    startSolve = time.process_time()
    solution = CDCL_solve(clauses,num_var,condition, var)                            # Solve CNF by CDCL
    EndSolve = time.process_time()
    print()
    print("Result:")
    print()
    print("Statistics :")
    print("=============================================")
    print("# Restarts : " + str(solution[1]))
    print("# Learned Clauses : " + str(solution[2]))
    print("# Decisions : " + str(solution[3]))
    print("# Implications : " + str(solution[4]))
    print("# Solve time : "+str(EndSolve-startSolve)+" sec")             # Print results
    print("=============================================")
    if solution[0] != -1:
        print("Assignment verified")
        assn = solution[0][:]
        assn.sort(key=abs)
        os.chdir('..')
        os.chdir("solutions")
        with open(fname, "w") as outfile:
            outfile.write("\n".join(str(lit) for lit in assn))
        os.chdir('..')
        print("Solution in /solutions/"+str(fname))
        print("SAT")
    else :
        print("UNSAT")
        return

El codi anterior és original d'[aquest repositori](https://github.com/Kapilhk/SatPie). Les següents modificacions s'han implementat per generar un solucionador tal que, cada vegada que s'ha de prendre una decisió, l'execució es pari, un agent extern pugui prendre la decisió corresponent, i l'execució pugui continuar.

In [17]:
"""
Solucionador SAT Solver CDCL, resol les clàsules donades a clauses, amb num_var variables.
break_at_decision == True --> El solucionador es deté quan ha de prendre una decisió, o quan troba que és SAT o UNSAT
break_at_decision == False --> El solucionador executa tots els passos de forma habitual, i es deté quan troba que és SAT o UNSAT

Si el model és SAT, retorna 1, M, Restart_count,Learned_count,Decide_count,Imp_count, None
Si el model és UNSAT, retorna -1, None, Restart_count,Learned_count,Decide_count,Imp_count, None
Si el model ha de prendre una decisió, retorna 0, M, Restart_count,Learned_count,Decide_count,Imp_count, backup_model
"""

def CDCL_solve(clauses, num_var, break_at_decision = False):

    def backup_model():
        return {
            'clauses': clauses,
            'num_var': num_var,
            'decide_pos': decide_pos,
            'M': M,
            'back': back,
            'counter': counter,
            'literal_watch': literal_watch,
            'clauses_literal_watched': clauses_literal_watched,
            'probability': probability,
            'Restart_count': Restart_count,
            'Learned_count': Learned_count,
            'Decide_count': Decide_count,
            'Imp_count': Imp_count,
            'variable': variable,
            'conflict': conflict,
            'Learned_c': Learned_c,
            'dec_level': dec_level,
            'jump_status': jump_status,
            'var': var
        }

    
    decide_pos = []                             # for Maintaing Decision Level
    M = []                                      # Current Assignments and implications
    clauses,M = unit_propagation(clauses)                        # Initial Unit Propogation : if conflict - UNSAT
    if clauses == -1 :
        return -1,0,0,0,0                                        # UNSAT
    back=M[:]                                                    # Keep Initialization Backup for RESTART
    counter = VSIDS_init(clauses,num_var)                        # Initialize Heuristic counter
    
    # Initialize TWO LITERAL WATCH data Structure :
    literal_watch,clauses_literal_watched = create_watchList(clauses,M,num_var)

    probability=0.9                                             # Random Restart Probability ,  Decays with restarts
    Restart_count = Learned_count = Decide_count = Imp_count = 0
    
    
    variable = None
    conflict = None
    Learned_c = None
    dec_level = None
    jump_status = None
    var = None
    

    while not all_vars_assigned(num_var , len(M)) :             # While variables remain to assign
        
        if break_at_decision:
            return 0, M, Restart_count,Learned_count,Decide_count,Imp_count, backup_model()
        
        variable = VSIDS_decide(counter,M,num_var)                      # Decide : Pick a variable
        Decide_count += 1
        
        assign(variable,M,decide_pos)
        conflict,literal_watch = two_watch_propogate(clauses,literal_watch,clauses_literal_watched,M,variable)         # Deduce by Unit Propogation
        
        
        while conflict != -1 :
            VSIDS_conflict(counter,conflict)                    # Incerements counter of literalts in conflict
            counter=VSIDS_decay(counter,num_var)                # Decay counters by 5%

            Learned_c = Analyze_Conflict(M, conflict,decide_pos)      #Diagnose Conflict

            dec_level = add_learned_clause_to(clauses,literal_watch,clauses_literal_watched,Learned_c,M) #add Learned clause to all data structures
            Learned_count += 1
            jump_status,var,Imp_count = Backjump(M, dec_level, decide_pos,Imp_count)      #BackJump to decision level

            if jump_status == -1:                                     # UNSAT
                return -1, None, Restart_count,Learned_count,Decide_count,Imp_count, None
            M.append(var)                                             # Append negation of last literal after backjump
            
            probability,Restart_count = RandomRestart(M,back,decide_pos,probability,Restart_count)        #Random Restart
            conflict,literal_watch = two_watch_propogate(clauses,literal_watch,clauses_literal_watched,M,var)

            
    #Reaches here if all variables assigned. 
    return 1, M, Restart_count,Learned_count,Decide_count,Imp_count, None

In [37]:
import numpy as np
import gym
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

In [18]:
"""
Solucionador SAT Solver CDCL, cridar després de CDCL_solve amb break_at_decision == True (en cas que el solver s'hagi aturat per prendre una decisió), o després de CDCL_continue
Continua la resolució on ho havia deixat CDCL_solve, amb la decisió decision presa.

El solucionador es deté quan ha de prendre una decisió, o quan troba que és SAT o UNSAT

Si el model és SAT, retorna 1, M, Restart_count,Learned_count,Decide_count,Imp_count, None
Si el model és UNSAT, retorna -1, None, Restart_count,Learned_count,Decide_count,Imp_count, None
Si el model ha de prendre una decisió, retorna 0, M, Restart_count,Learned_count,Decide_count,Imp_count, backup_model
"""

def CDCL_continue(backup, decision):

    def backup_model():
        return {
            'clauses': clauses,
            'num_var': num_var,
            'decide_pos': decide_pos,
            'M': M,
            'back': back,
            'counter': counter,
            'literal_watch': literal_watch,
            'clauses_literal_watched': clauses_literal_watched,
            'probability': probability,
            'Restart_count': Restart_count,
            'Learned_count': Learned_count,
            'Decide_count': Decide_count,
            'Imp_count': Imp_count,
            'variable': variable,
            'conflict': conflict,
            'Learned_c': Learned_c,
            'dec_level': dec_level,
            'jump_status': jump_status,
            'var': var
        }
   
    clauses = backup['clauses']
    num_var = backup['num_var']
    decide_pos = backup['decide_pos']
    M = backup['M']
    back = backup['back']
    counter = backup['counter']
    literal_watch = backup['literal_watch']
    clauses_literal_watched = backup['clauses_literal_watched']
    probability = backup['probability']
    Restart_count = backup['Restart_count']
    Learned_count = backup['Learned_count']
    Decide_count = backup['Decide_count']
    Imp_count = backup['Imp_count']
    variable = backup['variable']
    conflict = backup['conflict']
    Learned_c = backup['Learned_c']
    dec_level = backup['dec_level']
    jump_status = backup['jump_status']
    var = backup['var']

    first_decision_already_made = False
    
    while not all_vars_assigned(num_var , len(M)) :             # While variables remain to assign
        
        if first_decision_already_made:
            return 0, M, Restart_count,Learned_count,Decide_count,Imp_count, backup_model()    #Decision, break !!
        
        first_decision_already_made = True
        variable = decision
        Decide_count += 1
        
        assign(variable,M,decide_pos)
        conflict,literal_watch = two_watch_propogate(clauses,literal_watch,clauses_literal_watched,M,variable)         # Deduce by Unit Propogation
        
        
        while conflict != -1 :
            VSIDS_conflict(counter,conflict)                    # Incerements counter of literalts in conflict
            counter=VSIDS_decay(counter,num_var)                # Decay counters by 5%

            Learned_c = Analyze_Conflict(M, conflict,decide_pos)      #Diagnose Conflict

            dec_level = add_learned_clause_to(clauses,literal_watch,clauses_literal_watched,Learned_c,M) #add Learned clause to all data structures
            Learned_count += 1
            jump_status,var,Imp_count = Backjump(M, dec_level, decide_pos,Imp_count)      #BackJump to decision level

            if jump_status == -1:                                     # UNSAT
                return -1, None, Restart_count,Learned_count,Decide_count,Imp_count, None
            M.append(var)                                             # Append negation of last literal after backjump
            
            probability,Restart_count = RandomRestart(M,back,decide_pos,probability,Restart_count)        #Random Restart
            conflict,literal_watch = two_watch_propogate(clauses,literal_watch,clauses_literal_watched,M,var)

            
    #Reaches here if all variables assigned. 
    return 1, M, Restart_count,Learned_count,Decide_count,Imp_count, None

## **Reinforcement Learning**

In [306]:
class SatSolverEnv():
    
    def __init__(self):
        self.clauses = []
        self.current_state = []
        self.M = []
        self.backup = []
        self.num_var = 0

    def reset(self):
        
        self.clauses = []
        
        file_names = ["a.cnf", "sample.cnf", "unsat.cnf", "unsat1.cnf"]
        fname = "test/" + random.choice(file_names)

        a,num_var,num_claus,clauses = read_clause(fname)
        
        
        
        solution = CDCL_solve(clauses,num_var, True)                            # Solve CNF by CDCL
        is_sat, M, Restart_count,Learned_count,Decide_count,Imp_count,backup = solution
        
        self.backup = backup
        self.clauses = clauses
        self.num_var = num_var
        
        state = self.clauses2array(clauses, M)
        self.current_state = state
        self.M = M
        
        return state
    
    def step(self, action):
        
        
        variable = int(action) / 2
        if action % 2 == 0:
            sign = -1
        else:
            sign = 1
        
        decision = int(variable * sign) + 1
        
        if decision in self.M or abs(decision) > num_var-1:
            reward = -500
            next_state = self.current_state
            is_terminal = False
            return next_state, reward, is_terminal
            
        is_sat, M, Restart_count,Learned_count,Decide_count,Imp_count,backup = CDCL_continue(self.backup, decision)
        
        self.backup = backup
        self.M = M
        
        if is_sat != 0:
            reward = 0
            is_terminal = True
            next_state = np.zeros((16,16))
        else:
            reward = -1
            next_state = self.clauses2array(self.clauses, self.M)
            is_terminal = False
            
        self.current_state = next_state
            
        return next_state, reward, is_terminal
        
        
    def clauses2array(self, clauses, M):

        array_clauses = np.zeros((16,16))

        current_clause = 0
        for clause in clauses:
            for variable in clause:
                if variable < 0:
                    variable = -1*variable
                    array_clauses[current_clause][variable-1] = -1
                else:
                    array_clauses[current_clause][variable-1] = 1
            current_clause += 1
            
        if len(M) <= 0:
            return array_clauses

        for i in M:
            for row_index in range(len(array_clauses)):
                    i_to_read = abs(i)-1
                    if array_clauses[row_index][i_to_read] == 1 and i > 0:
                        array_clauses[row_index] = 0
                    elif array_clauses[row_index][i_to_read] == 1 and i < 0:
                        array_clauses[row_index][i_to_read] = 0
                    elif array_clauses[row_index][i_to_read] == -1 and i < 0:
                        array_clauses[row_index] = 0
                    elif array_clauses[row_index][i_to_read] == -1 and i > 0:
                        array_clauses[row_index][i_to_read] = 0


        return array_clauses

In [248]:
class AgentBrain(nn.Module):
    def __init__(self, lr_input):

        super(AgentBrain, self).__init__()

        self.conv1 = nn.Conv2d(1, 16, 3) # 14x14x16
        self.relu1 = nn.ReLU(inplace = True)
        self.conv2 = nn.Conv2d(16, 32, 3) # 12x12x32
        self.relu2 = nn.ReLU(inplace = True)

        self.fc1 = nn.Linear(12*12*32, 128)
        self.relu5 = nn.ReLU(inplace = True)
        self.fc2 = nn.Linear(128, 2*16)

        self.device = torch.device('cuda')
        self.optimizer = optim.Adam(self.parameters(), lr = lr_input)
        self.loss = nn.MSELoss()
        self.to(self.device)
    
    def forward(self, x):
        
        x = torch.from_numpy(x).float()
        x = x.view((1, 1, 16, 16)).to(self.device)
    
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        
        x = torch.flatten(x, 1) 

        x = self.fc1(x)
        x = self.relu5(x)
        x = self.fc2(x)

        return x

In [316]:
class Agent():
    def __init__(self, gamma, epsilon):
        self.gamma = gamma
        self.epsilon = epsilon
        
        self.mem_size = 10000
        self.mem_counter = 0
        
        self.action_space = list(range(32))
        
        self.lr = 1e-4
        
        self.batch_size = 128
        
        self.Q_eval = AgentBrain(self.lr)
        
        self.state_memory = [0] * self.mem_size
        self.new_state_memory = [0] * self.mem_size

        self.action_memory = [0] * self.mem_size
        self.reward_memory = [0] * self.mem_size
        self.terminal_memory = [0] * self.mem_size
        
        
    def store_transition(self, state, action, reward, next_state, is_done):
        
        i = self.mem_counter % self.mem_size
        
        self.state_memory[i] = state
        self.new_state_memory[i] = next_state
        self.reward_memory[i] = reward
        self.action_memory[i] = action
        self.terminal_memory[i] = is_done
        
        self.mem_counter += 1
        
    def choose_action(self, obs):
        
        random_number = np.random.random()
        
        if random_number > self.epsilon:
            
            action = torch.argmax(self.Q_eval.forward(obs)).item()
            
        else:
            action = np.random.choice(self.action_space)
            self.epsilon *= 0.99
            
            
        return action
    
    
    def learn(self):
        
        if self.mem_counter < self.batch_size:
            return
        
        
        self.Q_eval.optimizer.zero_grad()
        
        max_mem = min(self.mem_counter, self.mem_size)
        
        
        batch_exemples = np.random.choice(max_mem, self.batch_size,replace = False)
        batch_index = np.arange(self.batch_size, dtype = np.int32)
        
        disp = self.Q_eval.device
        
        state_batch = self.state_memory[batch_index]
        next_state_batch = self.new_state_memory[batch_index]
        reward_batch = self.reward_memory[batch_index]
        terminal_batch = self.terminal_memory[batch_index]
        action_batch = self.action_memory[batch_index]
        
        q_eval = self.Q_eval.forward(state_batch)[batch_index, action_batch]
        q_next = self.Q_eval.forward(next_state_batch)
        q_next[terminal_batch] = 0.0
        
        q_target = reward_batch + self.gamma + torch.max(q_next, dim = 1)[0]
        
        loss = self.Q_eval.loss(q_target, q_eval).to(disp)
        loss.backward()
        self.Q_eval.optimizer.step()

In [318]:
agent = Agent(0.99, 1.0)
env = SatSolverEnv()

for i in range(10):
    score = 0
    done = False
    obs = env.reset()
    
    while not done:
        action = agent.choose_action(obs)
        obs_next, reward, terminated = env.step(action)
        
        score += reward
        
        done = terminated
        
        agent.store_transition(obs, action, reward, obs_next, terminated)
        
        agent.learn()
        
        obs = obs_next
        
    print("Score: " , score)

-1
-2
-1
0
Score:  -10503
1
Score:  0
1
-2
-1
Score:  -10502
0
2
-1
-2
///bi 127


TypeError: only integer scalar arrays can be converted to a scalar index