In [1]:

from pettingzoo.utils import aec_to_parallel, parallel_to_aec
from pettingzoo.utils.wrappers import AssertOutOfBoundsWrapper, OrderEnforcingWrapper

from env.aquarium import raw_env

In [2]:

def env2(
    render_mode: str = "human",
    observable_walls: int = 2,
    width: int = 800,
    height: int = 800,
    caption: str = "Aquarium",
    fps: int = 60,
    max_time_steps: int = 3000,
    action_count: int = 16,
    predator_count: int = 1,
    prey_count: int = 16,
    predator_observe_count: int = 1,
    prey_observe_count: int = 3,
    draw_force_vectors: bool = False,
    draw_action_vectors: bool = False,
    draw_view_cones: bool = False,
    draw_hit_boxes: bool = False,
    draw_death_circles: bool = False,
    fov_enabled: bool = True,
    keep_prey_count_constant: bool = True,
    prey_radius: int = 20,
    prey_max_acceleration: float = 1,
    prey_max_velocity: float = 4,
    prey_view_distance: int = 100,
    prey_replication_age: int = 200,
    prey_max_steer_force: float = 0.6,
    prey_fov: int = 120,
    prey_reward: int = 1,
    prey_punishment: int = 1000,
    max_prey_count: int = 20,
    predator_max_acceleration: float = 0.6,
    predator_radius: int = 30,
    predator_max_velocity: float = 5,
    predator_view_distance: int = 200,
    predator_max_steer_force: float = 0.6,
    predator_max_age: int = 3000,
    predator_fov: int = 150,
    predator_reward: int = 10,
    catch_radius: int = 100,
    procreate: bool = False,
):
    """Returns the AEC environment"""
    env_aec = parallel_to_aec(
        raw_env(
            render_mode=render_mode,
            observable_walls=observable_walls,
            width=width,
            height=height,
            caption=caption,
            fps=fps,
            max_time_steps=max_time_steps,
            action_count=action_count,
            predator_count=predator_count,
            prey_count=prey_count,
            predator_observe_count=predator_observe_count,
            prey_observe_count=prey_observe_count,
            draw_force_vectors=draw_force_vectors,
            draw_action_vectors=draw_action_vectors,
            draw_view_cones=draw_view_cones,
            draw_hit_boxes=draw_hit_boxes,
            draw_death_circles=draw_death_circles,
            fov_enabled=fov_enabled,
            keep_prey_count_constant=keep_prey_count_constant,
            prey_radius=prey_radius,
            prey_max_acceleration=prey_max_acceleration,
            prey_max_velocity=prey_max_velocity,
            prey_view_distance=prey_view_distance,
            prey_replication_age=prey_replication_age,
            prey_max_steer_force=prey_max_steer_force,
            prey_fov=prey_fov,
            prey_reward=prey_reward,
            prey_punishment=prey_punishment,
            max_prey_count=max_prey_count,
            predator_max_acceleration=predator_max_acceleration,
            predator_radius=predator_radius,
            predator_max_velocity=predator_max_velocity,
            predator_view_distance=predator_view_distance,
            predator_max_steer_force=predator_max_steer_force,
            predator_max_age=predator_max_age,
            predator_fov=predator_fov,
            predator_reward=predator_reward,
            catch_radius=catch_radius,
            procreate=procreate,
        )
    )
    env_aec = AssertOutOfBoundsWrapper(env_aec)
    env_aec = OrderEnforcingWrapper(env_aec)

    return env_aec



In [3]:

def parallel_env(
    render_mode: str = "human",
    observable_walls: int = 2,
    width: int = 800,
    height: int = 800,
    caption: str = "Aquarium",
    fps: int = 60,
    max_time_steps: int = 3000,
    action_count: int = 16,
    predator_count: int = 1,
    prey_count: int = 16,
    predator_observe_count: int = 1,
    prey_observe_count: int = 3,
    draw_force_vectors: bool = False,
    draw_action_vectors: bool = False,
    draw_view_cones: bool = False,
    draw_hit_boxes: bool = False,
    draw_death_circles: bool = False,
    fov_enabled: bool = True,
    keep_prey_count_constant: bool = True,
    prey_radius: int = 20,
    prey_max_acceleration: float = 1,
    prey_max_velocity: float = 4,
    prey_view_distance: int = 100,
    prey_replication_age: int = 200,
    prey_max_steer_force: float = 0.6,
    prey_fov: int = 120,
    prey_reward: int = 1,
    prey_punishment: int = 1000,
    max_prey_count: int = 20,
    predator_max_acceleration: float = 0.6,
    predator_radius: int = 30,
    predator_max_velocity: float = 5,
    predator_view_distance: int = 200,
    predator_max_steer_force: float = 0.6,
    predator_max_age: int = 3000,
    predator_fov: int = 150,
    predator_reward: int = 10,
    catch_radius: int = 100,
    procreate: bool = False,
):
    """Returns the parallel environment"""
    return aec_to_parallel(
        env(
            render_mode=render_mode,
            observable_walls=observable_walls,
            width=width,
            height=height,
            caption=caption,
            fps=fps,
            max_time_steps=max_time_steps,
            action_count=action_count,
            predator_count=predator_count,
            prey_count=prey_count,
            predator_observe_count=predator_observe_count,
            prey_observe_count=prey_observe_count,
            draw_force_vectors=draw_force_vectors,
            draw_action_vectors=draw_action_vectors,
            draw_view_cones=draw_view_cones,
            draw_hit_boxes=draw_hit_boxes,
            draw_death_circles=draw_death_circles,
            fov_enabled=fov_enabled,
            keep_prey_count_constant=keep_prey_count_constant,
            prey_radius=prey_radius,
            prey_max_acceleration=prey_max_acceleration,
            prey_max_velocity=prey_max_velocity,
            prey_view_distance=prey_view_distance,
            prey_replication_age=prey_replication_age,
            prey_max_steer_force=prey_max_steer_force,
            prey_fov=prey_fov,
            prey_reward=prey_reward,
            prey_punishment=prey_punishment,
            max_prey_count=max_prey_count,
            predator_max_acceleration=predator_max_acceleration,
            predator_radius=predator_radius,
            predator_max_velocity=predator_max_velocity,
            predator_view_distance=predator_view_distance,
            predator_max_steer_force=predator_max_steer_force,
            predator_max_age=predator_max_age,
            predator_fov=predator_fov,
            predator_reward=predator_reward,
            catch_radius=catch_radius,
            procreate=procreate,
        )
    )


In [4]:

env = env2(
    # draw_force_vectors=True,
    # draw_action_vectors=True,
    # draw_view_cones=True,
    # draw_hit_boxes=True,
    # draw_death_circles=True,
    procreate = True
)
env.reset(seed=42)


In [11]:
for agent in env.agent_iter():
    observation, reward, termination, truncation, info = env.last()
    len(env.agents)

    if termination or truncation:
        action = None
    else:
        # this is where you would insert your policy
        action = env.action_space(agent).sample()

    env.step(action)
    env.render()
env.close()

error: display Surface quit

In [12]:
class AquariumRunner:
    def __init__(self, env):
        self.env = env
        self.episode_count = 0
        self.step_count = 0
        
    def run_episode(self):
        # 重置环境
        self.env.reset()
        episode_rewards = {agent: 0 for agent in self.env.possible_agents}
        
        # 使用agent_iter()来迭代每个agent
        for agent in self.env.agent_iter():
            # 获取当前agent的状态
            observation, reward, termination, truncation, info = self.env.last()
            
            # 更新奖励
            if agent in episode_rewards:
                episode_rewards[agent] += reward
            
            # 决定action
            if termination or truncation:
                action = None
            else:
                # 这里可以替换为实际的策略
                action = self.env.action_space(agent).sample()
            
            # 执行action
            self.env.step(action)
            
            try:
                self.env.render()
            except Exception as e:
                print(f"Render error: {e}")
            
            # 更新步数
            self.step_count += 1
            
            # 检查是否所有agent都终止了
            if len(self.env.agents) == 0:
                break
                
        self.episode_count += 1
        return episode_rewards

def run_aquarium(episodes=10, render=True):
    try:
        # 创建环境
        env = env2(
            render_mode="human" if render else None,
            predator_count=1,
            prey_count=16,
            max_time_steps=3000
        )
        
        runner = AquariumRunner(env)
        
        for episode in range(episodes):
            print(f"\nStarting Episode {episode + 1}")
            
            # 运行一个episode
            rewards = runner.run_episode()
            
            # 打印统计信息
            print(f"\nEpisode {episode + 1} Summary:")
            print(f"Steps completed: {runner.step_count}")
            print(f"Active agents: {len(env.agents)}")
            print("Rewards:", {k: round(v, 2) for k, v in rewards.items()})
            
            # 打印存活统计
            predators = len([a for a in env.agents if a.startswith("predator")])
            prey = len([a for a in env.agents if a.startswith("prey")])
            print(f"Surviving predators: {predators}")
            print(f"Surviving prey: {prey}")
            
    except Exception as e:
        print(f"Error: {e}")
        
    finally:
        env.close()

# 运行示例
if __name__ == "__main__":
    run_aquarium(episodes=10, render=True)


Starting Episode 1


SystemExit: 

In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random

# DQN网络
class DQN(nn.Module):
    def __init__(self, input_size, output_size):
        super(DQN, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, output_size)
        )
        
    def forward(self, x):
        return self.network(x)

# 经验回放缓冲区
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)
    
    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))
    
    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)
    
    def __len__(self):
        return len(self.buffer)

# DQN Agent
class DQNAgent:
    def __init__(self, state_size, action_size, agent_type, learning_rate=0.001):
        self.state_size = state_size
        self.action_size = action_size
        self.agent_type = agent_type
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # 创建Q网络和目标网络
        self.q_network = DQN(state_size, action_size).to(self.device)
        self.target_network = DQN(state_size, action_size).to(self.device)
        self.target_network.load_state_dict(self.q_network.state_dict())
        
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=learning_rate)
        self.memory = ReplayBuffer(10000)
        
        # 超参数
        self.batch_size = 64
        self.gamma = 0.99
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.target_update = 10
        self.update_counter = 0
        
    def select_action(self, state, training=True):
        if training and random.random() < self.epsilon:
            return random.randrange(self.action_size)
        
        with torch.no_grad():
            state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            q_values = self.q_network(state)
            return q_values.argmax().item()
    
    def train(self):
        if len(self.memory) < self.batch_size:
            return
        
        # 从经验回放中采样
        batch = self.memory.sample(self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        
        # 转换为tensor
        states = torch.FloatTensor(states).to(self.device)
        actions = torch.LongTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(next_states).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)
        
        # 计算当前Q值
        current_q = self.q_network(states).gather(1, actions.unsqueeze(1))
        
        # 计算目标Q值
        with torch.no_grad():
            next_q = self.target_network(next_states).max(1)[0]
            target_q = rewards + (1 - dones) * self.gamma * next_q
        
        # 计算损失并更新
        loss = nn.MSELoss()(current_q.squeeze(), target_q)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        # 更新目标网络
        self.update_counter += 1
        if self.update_counter % self.target_update == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())
        
        # 衰减探索率
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
        
        return loss.item()

# 训练环境
class AquariumTrainer:
    def __init__(self, env):
        self.env = env
        
        # 获取观察空间和动作空间的大小
        self.predator_obs_size = len(self.env.reset()[0]['predator_0'])
        self.prey_obs_size = len(self.env.reset()[0]['prey_0'])
        self.action_size = 8  # 8个方向的动作
        
        # 创建predator和prey的agents
        self.predator_agent = DQNAgent(self.predator_obs_size, self.action_size, "predator")
        self.prey_agents = {}  # 为每个prey创建一个agent
        
    def train_episode(self, episode):
        # 重置环境
        observations = self.env.reset()[0]
        episode_rewards = {agent: 0 for agent in self.env.possible_agents}
        total_loss = 0
        steps = 0
        
        # 为新的prey创建agents
        for agent_id in self.env.agents:
            if agent_id.startswith("prey") and agent_id not in self.prey_agents:
                self.prey_agents[agent_id] = DQNAgent(self.prey_obs_size, self.action_size, "prey")
        
        # 运行一个episode
        for agent in self.env.agent_iter():
            observation, reward, termination, truncation, info = self.env.last()
            
            # 更新奖励
            if agent in episode_rewards:
                episode_rewards[agent] += reward
            
            # 选择动作
            if termination or truncation:
                action = None
            else:
                if agent.startswith("predator"):
                    action = self.predator_agent.select_action(observation)
                else:
                    action = self.prey_agents[agent].select_action(observation)
            
            # 执行动作
            self.env.step(action)
            
            # 获取新的状态
            next_observation = self.env.observe(agent) if not (termination or truncation) else None
            
            # 存储经验
            if next_observation is not None:
                if agent.startswith("predator"):
                    self.predator_agent.memory.push(
                        observation, action, reward, next_observation, 
                        termination or truncation
                    )
                else:
                    self.prey_agents[agent].memory.push(
                        observation, action, reward, next_observation, 
                        termination or truncation
                    )
            
            # 训练
            if agent.startswith("predator"):
                loss = self.predator_agent.train()
            else:
                loss = self.prey_agents[agent].train()
            
            if loss is not None:
                total_loss += loss
            
            steps += 1
            
            # 可选：渲染环境
            if episode % 100 == 0:
                self.env.render()
            
            # 检查是否结束
            if len(self.env.agents) == 0:
                break
        
        # 计算平均损失
        avg_loss = total_loss / steps if steps > 0 else 0
        
        # 打印训练信息
        if episode % 10 == 0:
            print(f"\nEpisode {episode}")
            print(f"Average Loss: {avg_loss:.4f}")
            print(f"Predator Epsilon: {self.predator_agent.epsilon:.4f}")
            print("Rewards:", {k: round(v, 2) for k, v in episode_rewards.items()})
            print(f"Steps: {steps}")
            
        return episode_rewards, avg_loss

def train_aquarium(episodes=1000):
    env = env2(render_mode="human", predator_count=1, prey_count=16)
    trainer = AquariumTrainer(env)
    
    try:
        for episode in range(episodes):
            rewards, loss = trainer.train_episode(episode)
            
            # 保存模型（每100个episode）
            if episode % 100 == 0:
                torch.save(trainer.predator_agent.q_network.state_dict(), 
                         f'predator_model_ep{episode}.pth')
                # 保存一个prey的模型作为示例
                if trainer.prey_agents:
                    first_prey = next(iter(trainer.prey_agents.values()))
                    torch.save(first_prey.q_network.state_dict(), 
                             f'prey_model_ep{episode}.pth')
    
    except KeyboardInterrupt:
        print("\nTraining interrupted by user")
    
    finally:
        env.close()

if __name__ == "__main__":
    train_aquarium()

TypeError: 'NoneType' object is not subscriptable