In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
from pettingzoo import ParallelEnv
import matplotlib.pyplot as plt
import random
from env.custom_environment import CustomEnvironment

# 定义Agent网络
class AgentNetwork(nn.Module):
    def __init__(self, obs_shape, n_actions):
        super(AgentNetwork, self).__init__()
        self.fc1 = nn.Linear(obs_shape, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, n_actions)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

# 定义QMIX网络
class QMixNet(nn.Module):
    def __init__(self, n_agents, state_shape, mixing_embed_dim):
        super(QMixNet, self).__init__()
        self.n_agents = n_agents
        self.state_shape = state_shape
        self.embed_dim = mixing_embed_dim
        
        self.hyper_w_1 = nn.Sequential(nn.Linear(state_shape, self.embed_dim),
                                       nn.ReLU(),
                                       nn.Linear(self.embed_dim, n_agents * self.embed_dim))
        self.hyper_w_final = nn.Sequential(nn.Linear(state_shape, self.embed_dim),
                                           nn.ReLU(),
                                           nn.Linear(self.embed_dim, self.embed_dim))
        
        self.hyper_b_1 = nn.Linear(state_shape, self.embed_dim)
        self.hyper_b_final = nn.Sequential(nn.Linear(state_shape, self.embed_dim),
                                           nn.ReLU(),
                                           nn.Linear(self.embed_dim, 1))
        
        self.V = nn.Sequential(nn.Linear(state_shape, self.embed_dim),
                               nn.ReLU(),
                               nn.Linear(self.embed_dim, 1))
        
    def forward(self, agent_qs, states):
        bs = agent_qs.size(0)
        states = states.reshape(-1, self.state_shape)
        
        w1 = torch.abs(self.hyper_w_1(states)).view(-1, self.n_agents, self.embed_dim)
        b1 = self.hyper_b_1(states).view(-1, 1, self.embed_dim)
        hidden = torch.relu(torch.bmm(agent_qs.view(-1, 1, self.n_agents), w1) + b1)
        
        w_final = torch.abs(self.hyper_w_final(states)).view(-1, self.embed_dim, 1)
        b_final = self.hyper_b_final(states).view(-1, 1, 1)
        
        y = torch.bmm(hidden, w_final) + b_final
        q_tot = y.view(bs, -1, 1).squeeze(-1)
        
        q_tot = q_tot + self.V(states).view(bs, -1)
        
        return q_tot

# 经验回放缓冲区
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)
    
    def push(self, state, obs, action, reward, next_state, next_obs, done):
        self.buffer.append((state, obs, action, reward, next_state, next_obs, done))
    
    def sample(self, batch_size):
        state, obs, action, reward, next_state, next_obs, done = zip(*random.sample(self.buffer, batch_size))
        return np.array(state), np.array(obs), np.array(action), np.array(reward), np.array(next_state), np.array(next_obs), np.array(done)
    
    def __len__(self):
        return len(self.buffer)

# QMIX 训练代码
def train_qmix(env, n_episodes=1000, batch_size=32, gamma=0.99, lr=0.0005):
    n_agents = len(env.agents)
    obs_shape = env.observation_space(env.agents[0]).shape[0]
    state_shape = env.get_state().shape[0]
    n_actions = env.action_space(env.agents[0]).n
    
    agent_net = AgentNetwork(obs_shape, n_actions)
    qmix_net = QMixNet(n_agents, state_shape, 32)
    
    optimizer = optim.Adam(list(agent_net.parameters()) + list(qmix_net.parameters()), lr=lr)
    replay_buffer = ReplayBuffer(5000)
    
    returns = []
    q_values = []
    
    for episode in range(n_episodes):
        obs = env.reset()
        state = env.get_state()
        episode_reward = 0
        episode_q_values = []
        
        done = False
        while not done:
            actions = {}
            for agent in env.agents:
                obs_tensor = torch.tensor(obs[agent], dtype=torch.float32).unsqueeze(0)
                q_values_agent = agent_net(obs_tensor)
                action = q_values_agent.argmax().item()
                actions[agent] = action
                episode_q_values.append(q_values_agent.max().item())
            
            next_obs, rewards, dones, infos = env.step(actions)
            next_state = env.get_state()
            
            for agent in env.agents:
                replay_buffer.push(state, obs[agent], actions[agent], rewards[agent], next_state, next_obs[agent], dones[agent])
            
            obs = next_obs
            state = next_state
            episode_reward += sum(rewards.values())
            done = all(dones.values())
        
        returns.append(episode_reward)
        q_values.append(np.mean(episode_q_values))
        
        if len(replay_buffer) > batch_size:
            state, obs, action, reward, next_state, next_obs, done = replay_buffer.sample(batch_size)
            state = torch.tensor(state, dtype=torch.float32)
            obs = torch.tensor(obs, dtype=torch.float32)
            action = torch.tensor(action, dtype=torch.long)
            reward = torch.tensor(reward, dtype=torch.float32)
            next_state = torch.tensor(next_state, dtype=torch.float32)
            next_obs = torch.tensor(next_obs, dtype=torch.float32)
            done = torch.tensor(done, dtype=torch.float32)
            
            agent_qs = agent_net(obs).gather(1, action.unsqueeze(-1)).squeeze(-1)

            agent_qs_next = agent_net(next_obs).max(1)[0]
            target_q_tot = reward + gamma * agent_qs_next * (1 - done)
            
            q_tot = qmix_net(agent_qs, state)
            loss = nn.MSELoss()(q_tot, target_q_tot)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
    return returns, q_values

# 绘制结果
def plot_results(returns, q_values):
    fig, ax = plt.subplots(2, 1, figsize=(12, 8))
    ax[0].plot(returns)
    ax[0].set_title("Total Returns Over Time")
    ax[0].set_xlabel("Episodes")
    ax[0].set_ylabel("Total Return")
    
    ax[1].plot(q_values)
    ax[1].set_title("Average Q-Value Over Time")
    ax[1].set_xlabel("Episodes")
    ax[1].set_ylabel("Average Q-Value")
    
    plt.tight_layout()
    plt.show()

# 主函数
if __name__ == "__main__":
    env = CustomEnvironment()
    returns, q_values = train_qmix(env)
    plot_results(returns, q_values)

RuntimeError: shape '[-1, 1, 5]' is invalid for input of size 32