In [251]:
from collections import deque
import copy
import pdb

In [574]:
class Node(object):
    adj_dict = {}
    player1_color_score ={}
    player2_color_score ={}
    
    def __init__(self, level, states_list, colors_list):
        self.player1_states = []
        self.player2_states = []
        self.level = level
        self.state_colored_dict = {}
        self.state_none_colored_dict = {}
        for state in states_list:
            self.state_none_colored_dict[state] = set(colors_list)
            
    def expand(self, player):
        """Generate a queue of next nodes, all are checked by the ac3 
        to make sure the validity."""
        neibor_state_list = []
        for state in self.state_colored_dict:
            for adj_state in Node.adj_dict[state]:
                if adj_state not in self.state_colored_dict:
                    neibor_state_list.append(adj_state)
                
        next_nodes_list = []
        for state in sorted(neibor_state_list):
            for color in sorted(self.state_none_colored_dict[state]):
                next_node = copy.deepcopy(self)
                next_node.set_color(player, (state, color))
                if next_node.arc_consistency3():
                    next_node.level += 1
                    next_nodes_list.append(next_node)
        #pdb.set_trace()
        return next_nodes_list
    
    def arc_consistency3(self):
        """Perform arc-consistency check on the node."""
        
        # construct the queue of arcs
        queue = deque([])
        for start in Node.adj_dict.keys():
            for end in Node.adj_dict[start]:
                # only check those with color
                if start in self.state_colored_dict and end in self.state_colored_dict:
                    queue.append((start, end))
                
        while queue:
            # get the edge start -> end
            start, end = queue.popleft()
            
            # if the end has not been set a color, and it has more than 1 color left for choosing
            # then it must be legit
            # if end in self.state_none_colored_dict and len(self.state_none_colored_dict[end]) > 1:
            #    continue
            
            # if the start node has been set a color
            # if start not in self.state_none_colored_dict:
                
                # get the end color
            #    if end in self.state_none_colored_dict:
            #        end_color = self.state_none_colored_dict[end].pop()
            #        self.state_none_colored_dict[end].add(end_color)
            #    else:
            end_color = self.state_colored_dict[end]
                    
                # if the end has the same color it is not consistency
            if self.state_colored_dict[start] == end_color:
                return False
            # continue
                
            # the start node has not been set and we need to reduce the possible choice
            """
            revised = False
            color_reomve = []
            for start_color in self.state_none_colored_dict[start]:
                if end in self.state_none_colored_dict:
                    end_color = self.state_none_colored_dict[end].pop()
                    self.state_none_colored_dict[end].add(end_color)
                else:
                    end_color = self.state_colored_dict[end]
                if start_color == end_color:
                    #self.state_none_colored_dict[start].remove(start_color)
                    color_reomve.append(start_color)
                    revised = True
                        
            if revised:
                
                # remove unconsistency color
                for color in color_reomve:
                    self.state_none_colored_dict[start].remove(color)
                    
                # check whether the node is empty
                if not self.state_none_colored_dict[start]:
                    return False
                
                
                for parent in Node.adj_dict[start]:
                    if parent not in self.state_none_colored_dict:
                        if len(self.state_none_colored_dict[start]) == 1:
                            start_left_color = self.state_none_colored_dict[start].pop()
                            if start_left_color == self.state_colored_dict[parent]:
                                #pdb.set_trace()
                                return False
                            else:
                                self.state_none_colored_dict[start].add(start_left_color)
                                continue
                        else:
                            continue
                    queue.append((parent, start))
            """
        return True
    
    def set_color(self, player, action):
        """Update the node by the action, (state, color). If the node 
        is valid after ac3, then we return true, else return false"""
        state, color = action
        del self.state_none_colored_dict[state]
        self.state_colored_dict[state] = color
        if player == 1:
            self.player1_states.append(state)
        else:
            self.player2_states.append(state)
        
    def get_eval(self):
        score = sum([Node.player1_color_score[self.state_colored_dict[t]] for t in self.player1_states])
        score -= sum([Node.player2_color_score[self.state_colored_dict[t]] for t in self.player2_states])
        return score

In [575]:
# minmax with alpha-beta pruning implementation
log = []

def alpha_beta_search(node, max_level):
    """Search the path to ge the max value"""
    return max_value(node, -float('inf'), float('inf'), max_level)


def max_value(node, alpha, beta, max_level):
    """Select the next step to maximize the value, 
    return the path and value"""
    #pdb.set_trace()
    next_nodes_list = node.expand(1)
    #pdb.set_trace()
    
    value = -float('inf')
    if node.level == max_level or not next_nodes_list:
        value = node.get_eval()
        
    path = None
    log.append(node.player2_states[-1] 
               + ', ' 
               + node.state_colored_dict[node.player2_states[-1]]
               + ', '
               + str(node.level)
               + ', '
               + str(value)
               + ', '
               + str(alpha)
               + ', '
               + str(beta)
              )

    if node.level == max_level or not next_nodes_list:
        return node.get_eval(), []
    
    for next_node in next_nodes_list:
        next_value, next_path = min_value(next_node, alpha, beta, max_level)
        
        if next_value > value:
            value = next_value
            path = next_path
            path.insert(0, (next_node.player1_states[-1], next_node.state_colored_dict[next_node.player1_states[-1]]))
            
        
        
        if value >= beta:
            if not path:
                pdb.set_trace()
            log.append(node.player2_states[-1] 
               + ', ' 
               + node.state_colored_dict[node.player2_states[-1]]
               + ', '
               + str(node.level)
               + ', '
               + str(value)
               + ', '
               + str(alpha)
               + ', '
               + str(beta)
              )
            return value, path
        alpha = max(alpha, value)
        log.append(node.player2_states[-1] 
               + ', ' 
               + node.state_colored_dict[node.player2_states[-1]]
               + ', '
               + str(node.level)
               + ', '
               + str(value)
               + ', '
               + str(alpha)
               + ', '
               + str(beta)
              )
    return value, path


def min_value(node, alpha, beta, max_level):
    """Select the next step to minimize the value,
    return the path and value"""
    #pdb.set_trace()
    next_nodes_list = node.expand(2)
    
    value = +float('inf')
    if node.level == max_level or not next_nodes_list:
        value = node.get_eval()
    
    path = None
    
    log.append(node.player1_states[-1] 
               + ', ' 
               + node.state_colored_dict[node.player1_states[-1]]
               + ', '
               + str(node.level)
               + ', '
               + str(value)
               + ', '
               + str(alpha)
               + ', '
               + str(beta)
              )
            
    if node.level == max_level or not next_nodes_list:
        return node.get_eval(), []
    
    for next_node in next_nodes_list:
        next_value, next_path = max_value(next_node, alpha, beta, max_level)
        
        if next_value < value:
            value = min(value, next_value)
            path = next_path
            path.insert(0, (next_node.player2_states[-1], next_node.state_colored_dict[next_node.player2_states[-1]]))
            
        if value <= alpha:
            log.append(node.player1_states[-1] 
               + ', ' 
               + node.state_colored_dict[node.player1_states[-1]]
               + ', '
               + str(node.level)
               + ', '
               + str(value)
               + ', '
               + str(alpha)
               + ', '
               + str(beta)
              )
            return value, path
            #if not path:
                #pdb.set_trace()
                
        beta = min(beta, value)
        log.append(node.player1_states[-1] 
               + ', ' 
               + node.state_colored_dict[node.player1_states[-1]]
               + ', '
               + str(node.level)
               + ', '
               + str(value)
               + ', '
               + str(alpha)
               + ', '
               + str(beta)
              )
    #if not path:
        #pdb.set_trace()
    return value, path

In [576]:
node = Node(0, ['SA', 'NT', 'NSW', 'WA', 'Q', 'V'], ['R','G','B'])

In [577]:
node.set_color(1, ('WA','R'))
node.set_color(2, ('SA','G'))

In [578]:
Node.player1_color_score = {'R': 10, 'G':5, 'B':0}
Node.player2_color_score = {'R': 0, 'G':2, 'B':8}
Node.adj_dict = {'SA':['WA', 'NT', 'Q', 'NSW', 'V'], 
                 'NT':['WA', 'SA', 'Q'], 
                 'NSW':['Q', 'V', 'SA'], 
                 'WA': ['SA', 'NT'], 
                 'Q': ['NT', 'SA', 'NSW'], 
                 'V': ['SA', 'NSW'] }

In [579]:
value, path = alpha_beta_search(node, 3)

In [580]:
path

[('NSW', 'B'), ('NT', 'B'), ('Q', 'R')]

In [581]:
for line in log:
    print line
print path[0][0] + ', ' + path[0][1] + ', ' + str(value)

SA, G, 0, -inf, -inf, inf
NSW, B, 1, inf, -inf, inf
NT, B, 2, -inf, -inf, inf
Q, R, 3, 10, -inf, inf
NT, B, 2, 10, 10, inf
Q, R, 3, 10, 10, inf
NT, B, 2, 10, 10, inf
Q, R, 3, 10, 10, inf
NT, B, 2, 10, 10, inf
V, R, 3, 10, 10, inf
NT, B, 2, 10, 10, inf
V, R, 3, 10, 10, inf
NT, B, 2, 10, 10, inf
NSW, B, 1, 10, -inf, 10
NT, B, 2, -inf, -inf, 10
Q, R, 3, 10, -inf, 10
NT, B, 2, 10, -inf, 10
NSW, B, 1, 10, -inf, 10
Q, R, 2, -inf, -inf, 10
NT, B, 3, 8, -inf, 10
Q, R, 2, 8, 8, 10
NT, B, 3, 8, 8, 10
Q, R, 2, 8, 8, 10
NT, B, 3, 8, 8, 10
Q, R, 2, 8, 8, 10
V, R, 3, 18, 8, 10
Q, R, 2, 18, 8, 10
NSW, B, 1, 10, -inf, 10
Q, R, 2, -inf, -inf, 10
NT, B, 3, 8, -inf, 10
Q, R, 2, 8, 8, 10
NT, B, 3, 8, 8, 10
Q, R, 2, 8, 8, 10
NT, B, 3, 8, 8, 10
Q, R, 2, 8, 8, 10
V, R, 3, 18, 8, 10
Q, R, 2, 18, 8, 10
NSW, B, 1, 10, -inf, 10
V, R, 2, -inf, -inf, 10
NT, B, 3, 8, -inf, 10
V, R, 2, 8, 8, 10
NT, B, 3, 8, 8, 10
V, R, 2, 8, 8, 10
Q, R, 3, 18, 8, 10
V, R, 2, 18, 8, 10
NSW, B, 1, 10, -inf, 10
V, R, 2, -inf, -inf, 10


In [587]:
!cat testcases/t0.txt

R, G, B
XD42U3: R-1, L3G71B: G-2
3
R: 3, G: 6, B: 12
R: 6, G: 5, B: 4
M0RQA5: AGD6C7, 8QC7VG, 8DS889, Y1J1CO, GQIDX5, J6GBY7, U0CE4S
EEBSPN: BKFOIG, AGD6C7, SS3FJQ, 4BNSMR, J6GBY7, U0CE4S
UI8B76: KTEXTA, XD42U3, C38K1U, 4G3X5F, S4SJY8, F4QKUB
F4QKUB: UI8B76, L3G71B, XD42U3, AGD6C7, BKFOIG
4G3X5F: UI8B76, AGD6C7, 8JWSH1, 8DS889, Y1J1CO, 3UJ31G, J6GBY7
8DS889: 4G3X5F, M0RQA5, KTEXTA, 0F8M35
ZP02UB: 9GOQ8R
J6GBY7: M0RQA5, EEBSPN, 9GOQ8R, 4G3X5F, C38K1U, GQIDX5, U0CE4S
BKFOIG: F4QKUB, PCDM0E, 8JWSH1, C38K1U, EEBSPN
AGD6C7: L3G71B, M0RQA5, EEBSPN, XD42U3, 4G3X5F, S9C85O, F4QKUB
SS3FJQ: 8QC7VG, S9C85O, U0CE4S, EEBSPN
MRAN53: H6W1E3, C38K1U, KTEXTA, S4SJY8
8JWSH1: BKFOIG, 4G3X5F, S4SJY8, GQIDX5, S9C85O, 0F8M35, U0CE4S
8QC7VG: M0RQA5, S9C85O, SS3FJQ, Y1J1CO, U0CE4S
0F8M35: 8DS889, 8JWSH1, KTEXTA, Y1J1CO
KTEXTA: L3G71B, UI8B76, XD42U3, C38K1U, MRAN53, S4SJY8, 8DS889, 3UJ31G, 0F8M35, U0CE4S
XD42U3: L3G71B, KTEXTA, UI8B76, AGD6C7, H6W1E3, F4QKUB, U0CE4S
H6W1E3: 4BNSMR, XD42U

In [588]:
with open('testcases/t0.txt', 'r') as f:
    lines = f.read().splitlines()

color_list = lines[0].split(', ')

pre_set_color = [] # (state, color, player)
for seg in lines[1].split(', '):
    data = seg.split(': ')
    state = data[0]
    data_2 = data[1].split('-')
    color = data_2[0]
    player = int(data_2[1])
    pre_set_color.append((state, color, player))
    
max_level = int(lines[2])

player1_score_dict = {}
for seg in lines[3].split(', '):
    data = seg.split(': ')
    player1_score_dict[data[0]] = int(data[1])
    
player2_score_dict = {}
for seg in lines[4].split(', '):
    data = seg.split(': ')
    player2_score_dict[data[0]] = int(data[1])

state_set = set()
adj_dict={}
for line in lines[5:]:
    data = line.split(': ')
    state = data[0]
    state_set.add(state)
    adj_list = []
    for adj in data[1].split(', '):
        adj_list.append(adj.strip())
    adj_dict[state] = adj_list
    

In [589]:
print color_list
print pre_set_color
print max_level
print player1_score_dict
print player2_score_dict
print adj_dict
print state_set

['R', 'G', 'B']
[('XD42U3', 'R', 1), ('L3G71B', 'G', 2)]
3
{'B': 12, 'R': 3, 'G': 6}
{'B': 4, 'R': 6, 'G': 5}
{'M0RQA5': ['AGD6C7', '8QC7VG', '8DS889', 'Y1J1CO', 'GQIDX5', 'J6GBY7', 'U0CE4S'], 'EEBSPN': ['BKFOIG', 'AGD6C7', 'SS3FJQ', '4BNSMR', 'J6GBY7', 'U0CE4S'], 'UI8B76': ['KTEXTA', 'XD42U3', 'C38K1U', '4G3X5F', 'S4SJY8', 'F4QKUB'], 'F4QKUB': ['UI8B76', 'L3G71B', 'XD42U3', 'AGD6C7', 'BKFOIG'], '3UJ31G': ['4BNSMR', 'KTEXTA', '4G3X5F'], '8DS889': ['4G3X5F', 'M0RQA5', 'KTEXTA', '0F8M35'], 'ZP02UB': ['9GOQ8R'], 'J6GBY7': ['M0RQA5', 'EEBSPN', '9GOQ8R', '4G3X5F', 'C38K1U', 'GQIDX5', 'U0CE4S'], 'BKFOIG': ['F4QKUB', 'PCDM0E', '8JWSH1', 'C38K1U', 'EEBSPN'], 'AGD6C7': ['L3G71B', 'M0RQA5', 'EEBSPN', 'XD42U3', '4G3X5F', 'S9C85O', 'F4QKUB'], 'SS3FJQ': ['8QC7VG', 'S9C85O', 'U0CE4S', 'EEBSPN'], 'MRAN53': ['H6W1E3', 'C38K1U', 'KTEXTA', 'S4SJY8'], '8JWSH1': ['BKFOIG', '4G3X5F', 'S4SJY8', 'GQIDX5', 'S9C85O', '0F8M35', 'U0CE4S'], '8QC7VG': ['M0RQA5', 'S9C85O', 'SS3FJQ', 'Y1J1CO', 'U0CE4S'], '0F8M35': [

In [593]:
node2 = Node(0, list(state_set), color_list)
Node.player1_color_score = player1_score_dict
Node.player2_color_score = player2_score_dict
Node.adj_dict = adj_dict

In [594]:
log=[]

In [595]:
value, path = alpha_beta_search(node2, max_level)

IndexError: list index out of range

> [0;32m<ipython-input-575-cbdc3653b6a7>[0m(31)[0;36mmax_value[0;34m()[0m
[0;32m     29 [0;31m               [0;34m+[0m [0mstr[0m[0;34m([0m[0malpha[0m[0;34m)[0m[0;34m[0m[0m
[0m[0;32m     30 [0;31m               [0;34m+[0m [0;34m', '[0m[0;34m[0m[0m
[0m[0;32m---> 31 [0;31m               [0;34m+[0m [0mstr[0m[0;34m([0m[0mbeta[0m[0;34m)[0m[0;34m[0m[0m
[0m[0;32m     32 [0;31m              )
[0m[0;32m     33 [0;31m[0;34m[0m[0m
[0m
ipdb> len(log)
0
ipdb> str(beta)
'inf'
ipdb> c


In [596]:
for line in log:
    print line
print path[0][0] + ', ' + path[0][1] + ', ' + str(value)

NSW, B, 10


In [598]:
noe = Node(0, [], [])