In [16]:
from pettingzoo.classic import connect_four_v3

env = connect_four_v3.env()
env.reset(seed=42)

for agent in env.agent_iter():
    observation, reward, termination, truncation, info = env.last()

    if termination or truncation:
        action = None
    else:
        mask = observation["action_mask"]
        # this is where you would insert your policy
        action = env.action_space(agent).sample(mask)

    env.step(action)
env.close()

In [29]:
class ConnectFour:
    def __init__(self):
        self.board = [[0]*7 for _ in range(6)]
        self.current_player = 1  # Start with player 1

    def make_move(self, column):
        if self.board[0][column] != 0:
            return False  # Column is full
        for row in reversed(range(6)):
            if self.board[row][column] == 0:
                self.board[row][column] = self.current_player
                return True
        return False

    def available_moves(self):
        return [c for c in range(7) if self.board[0][c] == 0]

    def is_winner(self, player):
        # Check horizontal, vertical, and diagonal for wins
        # Check for horizontal wins
        for row in range(6):
            for col in range(4):  # Only need to start check from first 4 columns
                if all(self.board[row][col+i] == player for i in range(4)):
                    return True
        # Check for vertical wins
        for col in range(7):
            for row in range(3):  # Only need to start check from first 3 rows
                if all(self.board[row+i][col] == player for i in range(4)):
                    return True
        # Check for diagonal wins
        for row in range(3):
            for col in range(4):
                if all(self.board[row+i][col+i] == player for i in range(4)):
                    return True
                if all(self.board[row+3-i][col+i] == player for i in range(4)):
                    return True
        return False

    def is_draw(self):
        return all(self.board[0][col] != 0 for col in range(7))

    def is_terminal(self):
        return self.is_winner(1) or self.is_winner(2) or self.is_draw()

    def switch_player(self):
        self.current_player = 2 if self.current_player == 1 else 1


In [None]:
class Node:
    def __init__(self, game_state, move=None, parent=None, action_mask=None):
        self.game_state = deepcopy(game_state)
        self.move = move
        self.parent = parent
        self.children = []
        self.wins = 0
        self.visits = 0
        if action_mask is not None:
            self.untried_moves = [move for move in game_state.available_moves() if action_mask[move]]
        else:
            self.untried_moves = game_state.available_moves()
        self.player_just_moved = game_state.current_player


    def expand(self):
        move = self.untried_moves.pop()
        next_state = deepcopy(self.game_state)
        next_state.make_move(move)
        next_state.switch_player()
        child_node = Node(next_state, move, self)
        self.children.append(child_node)
        return child_node

    def update(self, result):
        self.visits += 1
        if self.player_just_moved == result:
            self.wins += 1

    def best_child(self, c_param=1.41):
        choices_weights = [
            (child.wins / child.visits) + c_param * math.sqrt((2 * math.log(self.visits) / child.visits))
            for child in self.children
        ]
        return self.children[choices_weights.index(max(choices_weights))]



In [44]:
def monte_carlo_tree_search(root, action_mask=None, max_seconds=3, max_milliseconds=0):
    end_time = datetime.now() + timedelta(seconds=max_seconds, milliseconds=max_milliseconds)
    while datetime.now() < end_time:
        node = root
        while not node.untried_moves and node.children:
            node = node.best_child()
        if node.untried_moves:
            node = node.expand()
        while not node.game_state.is_terminal():
            possible_moves = node.game_state.available_moves()
            legal_moves = [move for move in possible_moves if action_mask[move]]
            node.game_state.make_move(random.choice(legal_moves))
            node.game_state.switch_player()
        result = 1 if node.game_state.is_winner(1) else 2 if node.game_state.is_winner(2) else 0
        while node is not None:
            node.update(result)
            node = node.parent
    return root.best_child(c_param=0).move

In [42]:
class Trainer:
    def __init__(self, iterations=1000):
        self.game = ConnectFour()
        self.iterations = iterations

    def train(self):
        for i in range(self.iterations):
            game = ConnectFour()
            root = Node(game)
            while not game.is_terminal():
                action_mask = [True] * 7  # In training, assume all moves are initially legal
                move = monte_carlo_tree_search(root, action_mask=action_mask)
                game.make_move(move)
                game.switch_player()
                root = Node(game)  # Reset the root for the next decision
            print(f"Training game {i+1} completed")

from pettingzoo.classic import connect_four_v3

class PettingZooAgent:
    def __init__(self, train_iterations=1):
        self.trainer = Trainer(train_iterations)
        self.trainer.train()  # Train the agent during initialization

    def update_game_state(self, observation):
        # Reset the board to an empty state
        self.game.board = [[0] * 7 for _ in range(6)]
    
        # Determine who's the next player to move
        # Count the number of tokens for each player
        player1_count = 0
        player2_count = 0
    
        for row in range(6):
            for col in range(7):
                if observation[row, col, 0] == 1:
                    self.game.board[row][col] = 1
                    player1_count += 1
                elif observation[row, col, 1] == 1:
                    self.game.board[row][col] = 2
                    player2_count += 1
    
        # Determine who's turn it is based on the number of tokens
        if player1_count <= player2_count:
            self.trainer.game.current_player = 1
        else:
            self.trainer.game.current_player = 2


    def choose_action(self, observation, action_mask):
        self.update_game_state(observation)  # Update the game state based on the latest observation

        if action_mask is None:
            action_mask = [True] * 7  # Fallback in case no mask is provided

        # Initialize a root node for MCTS with the current game state
        root_node = Node(self.game, action_mask=action_mask)
        best_move = monte_carlo_tree_search(root_node, action_mask)
        return best_move


In [43]:
pz_agent = PettingZooAgent()

<__main__.Node object at 0x72858f75de10>
<__main__.Node object at 0x72858f75fbe0>
<__main__.Node object at 0x7285ad8195d0>
<__main__.Node object at 0x72858edaabf0>
<__main__.Node object at 0x7285adf32c80>
<__main__.Node object at 0x72858f52b160>
<__main__.Node object at 0x72858f51b100>
<__main__.Node object at 0x72858edf8190>
<__main__.Node object at 0x7285addb1720>
<__main__.Node object at 0x7285adba2ec0>
<__main__.Node object at 0x72858f75fa00>
<__main__.Node object at 0x72858ed47130>
<__main__.Node object at 0x7285bcb0b5e0>
<__main__.Node object at 0x72858f75e680>
<__main__.Node object at 0x7285af25fc70>
<__main__.Node object at 0x72858fffadd0>
<__main__.Node object at 0x72858e3df880>
<__main__.Node object at 0x72858fa88610>
<__main__.Node object at 0x7285af4c62c0>
<__main__.Node object at 0x7285ae49e440>
<__main__.Node object at 0x72858e7a0160>
<__main__.Node object at 0x7285ae4fae00>
<__main__.Node object at 0x7285adb7b9d0>
<__main__.Node object at 0x7285ac562ef0>
<__main__.Node o

In [41]:
# Create an instance of the PettingZoo agent and use it
env = connect_four_v3.env()
env.reset()
observation, reward, termination, truncation, info = env.last()

for agent_name in env.agent_iter():
    observation, reward, termination, truncation, info = env.last()
    action_mask = observation['action_mask']
    if termination or truncation:
        action = None
    else:
        action = pz_agent.choose_action(observation, action_mask)
    env.step(action)
env.close()

AttributeError: 'PettingZooAgent' object has no attribute 'game'