In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
from env.custom_environment import CustomEnvironment

# 定义Q网络
class QNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, output_dim)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 定义策略网络
class PolicyNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, output_dim)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.tanh(self.fc3(x))
        return x

# 定义经验回放
class ReplayBuffer:
    def __init__(self, buffer_size):
        self.buffer = deque(maxlen=buffer_size)
    
    def add(self, experience):
        self.buffer.append(experience)
    
    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)
    
    def __len__(self):
        return len(self.buffer)

# 多智能体DDPG训练
class MultiAgentDDPG:
    def __init__(self, env, num_agents, state_size, action_size, buffer_size=10000, batch_size=64, gamma=0.99, lr=0.001, tau=0.01):
        self.env = env
        self.num_agents = num_agents
        self.state_size = state_size
        self.action_size = action_size
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.gamma = gamma
        self.lr = lr
        self.tau = tau
        
        self.q_networks = [QNetwork(state_size + action_size * num_agents, 1) for _ in range(num_agents)]
        self.target_q_networks = [QNetwork(state_size + action_size * num_agents, 1) for _ in range(num_agents)]
        self.policy_networks = [PolicyNetwork(state_size, action_size) for _ in range(num_agents)]
        self.target_policy_networks = [PolicyNetwork(state_size, action_size) for _ in range(num_agents)]
        
        self.q_optimizers = [optim.Adam(q_net.parameters(), lr=self.lr) for q_net in self.q_networks]
        self.policy_optimizers = [optim.Adam(policy_net.parameters(), lr=self.lr) for policy_net in self.policy_networks]
        
        self.replay_buffer = ReplayBuffer(buffer_size)
        
        for target_q_net, q_net in zip(self.target_q_networks, self.q_networks):
            target_q_net.load_state_dict(q_net.state_dict())
        
        for target_policy_net, policy_net in zip(self.target_policy_networks, self.policy_networks):
            target_policy_net.load_state_dict(policy_net.state_dict())
    
    def select_actions(self, states, epsilon):
        actions = []
        for i in range(self.num_agents):
            state = torch.FloatTensor(states[i]).unsqueeze(0)
            action = self.policy_networks[i](state).detach().numpy()[0]
            if random.random() < epsilon:
                action = random.randint(0, self.action_size - 1)
            actions.append(action)
        return actions
    
    def train(self, num_episodes, max_steps, epsilon_start=1.0, epsilon_end=0.1, epsilon_decay=0.995):
        epsilon = epsilon_start
        for episode in range(num_episodes):
            states = self.env.reset()
            episode_rewards = np.zeros(self.num_agents)
            for step in range(max_steps):
                actions = self.select_actions([states[agent] for agent in self.env.agents], epsilon)
                next_states, rewards, dones, _ = self.env.step({agent: action for agent, action in zip(self.env.agents, actions)})
                
                self.replay_buffer.add((states, actions, rewards, next_states, dones))
                states = next_states
                episode_rewards += np.array([rewards[agent] for agent in self.env.agents])
                
                if len(self.replay_buffer) >= self.batch_size:
                    self.update_networks()
                
                if all(dones.values()):
                    break
            
            epsilon = max(epsilon_end, epsilon_decay * epsilon)
            print(f"Episode {episode + 1}/{num_episodes}, Reward: {np.sum(episode_rewards)}")
    
    def update_networks(self):
        batch = self.replay_buffer.sample(self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        
        for i in range(self.num_agents):
            state_batch = torch.FloatTensor([s[i] for s in states])
            action_batch = torch.FloatTensor([[a[i]] for a in actions])
            reward_batch = torch.FloatTensor([r[i] for r in rewards])
            next_state_batch = torch.FloatTensor([ns[i] for ns in next_states])
            done_batch = torch.FloatTensor([d[i] for d in dones])
            
            all_states = torch.FloatTensor([np.concatenate(s) for s in states])
            all_actions = torch.FloatTensor([np.concatenate(a) for a in actions])
            all_next_states = torch.FloatTensor([np.concatenate(ns) for ns in next_states])
            
            current_q_values = self.q_networks[i](torch.cat([state_batch, all_actions.view(self.batch_size, -1)], 1))
            next_actions = [self.target_policy_networks[j](next_state_batch) for j in range(self.num_agents)]
            target_q_values = self.target_q_networks[i](torch.cat([next_state_batch, torch.cat(next_actions, 1)], 1)).detach()
            target_q_values = reward_batch + (self.gamma * target_q_values * (1 - done_batch))
            
            q_loss = nn.MSELoss()(current_q_values, target_q_values)
            self.q_optimizers[i].zero_grad()
            q_loss.backward()
            self.q_optimizers[i].step()
            
            all_current_actions = [self.policy_networks[j](state_batch) if j == i else self.policy_networks[j](torch.FloatTensor([s[j] for s in states])) for j in range(self.num_agents)]
            policy_loss = -self.q_networks[i](torch.cat([state_batch, torch.cat(all_current_actions, 1)], 1)).mean()
            self.policy_optimizers[i].zero_grad()
            policy_loss.backward()
            self.policy_optimizers[i].step()
            
            self.soft_update(self.q_networks[i], self.target_q_networks[i])
            self.soft_update(self.policy_networks[i], self.target_policy_networks[i])
    
    def soft_update(self, local_model, target_model):
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(self.tau * local_param.data + (1.0 - self.tau) * target_param.data)

# 使用自定义环境进行训练
env = CustomEnvironment()
state_size = env.max_obs_size
action_size = env.action_space(env.agents[0]).n
num_agents = len(env.agents)

multi_agent_ddpg = MultiAgentDDPG(env, num_agents, state_size, action_size)
multi_agent_ddpg.train(num_episodes=1000, max_steps=200)

KeyError: 0