In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque
from tqdm import tqdm
import matplotlib.pyplot as plt
from env_new.MAenv import CustomMAEnvironment


# 定义 Replay Buffer
class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = deque(maxlen=capacity)

    def add(self, state, obs, actions, rewards, next_state, next_obs, dones):
        self.buffer.append({
            "state": state,
            "obs": obs,
            "actions": actions,
            "rewards": rewards,
            "next_state": next_state,
            "next_obs": next_obs,
            "dones": dones
        })

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states = np.array([item["state"] for item in batch], dtype=np.float32)
        obs = np.array([item["obs"] for item in batch], dtype=np.float32)
        actions = np.array([item["actions"] for item in batch], dtype=np.int64)
        rewards = np.array([item["rewards"] for item in batch], dtype=np.float32)
        next_states = np.array([item["next_state"] for item in batch], dtype=np.float32)
        next_obs = np.array([item["next_obs"] for item in batch], dtype=np.float32)
        dones = np.array([item["dones"] for item in batch], dtype=np.float32)

        return (
            torch.tensor(states, dtype=torch.float32),
            torch.tensor(obs, dtype=torch.float32),
            torch.tensor(actions, dtype=torch.long),
            torch.tensor(rewards, dtype=torch.float32),
            torch.tensor(next_states, dtype=torch.float32),
            torch.tensor(next_obs, dtype=torch.float32),
            torch.tensor(dones, dtype=torch.float32)
        )

    def __len__(self):
        return len(self.buffer)


# 定义智能体网络
class AgentNetwork(nn.Module):
    def __init__(self, obs_dim, action_dim, hidden_dim=128):
        super(AgentNetwork, self).__init__()
        self.fc1 = nn.Linear(obs_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)

    def forward(self, obs):
        x = torch.relu(self.fc1(obs))
        x = torch.relu(self.fc2(x))
        q_values = self.fc3(x)
        return q_values


# 定义 Mixing Network（包含单调性约束）
class AttentionMixingNetwork(nn.Module):
    def __init__(self, num_agents, state_dim, hidden_dim=64):
        super(AttentionMixingNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.agent_qs_proj = nn.Linear(num_agents, hidden_dim)
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads=4, batch_first=True)
        self.q_combine = nn.Linear(hidden_dim, 1)

    def forward(self, agent_qs, state):
        x = torch.relu(self.fc1(state))  # [batch_size, hidden_dim]
        agent_qs_proj = torch.relu(self.agent_qs_proj(agent_qs))  # 保证单调性
        x, _ = self.attention(x.unsqueeze(1), agent_qs_proj.unsqueeze(1), agent_qs_proj.unsqueeze(1))
        x = torch.relu(self.fc2(x.squeeze(1)))  # [batch_size, hidden_dim]
        q_total = self.q_combine(x)  # [batch_size, 1]
        return q_total


# QMIX 主类
class QMIX:
    def __init__(self, env, state_dim, obs_dim, action_dim, num_agents, lr=1e-4, gamma=0.99, epsilon_decay=0.995):
        self.env = env
        self.state_dim = state_dim
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.num_agents = num_agents
        self.gamma = gamma
        self.epsilon = 1.0
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = 0.05

        self.agent_network = AgentNetwork(obs_dim, action_dim)
        self.mixing_network = AttentionMixingNetwork(num_agents, state_dim)
        self.target_agent_network = AgentNetwork(obs_dim, action_dim)
        self.target_mixing_network = AttentionMixingNetwork(num_agents, state_dim)

        self.target_agent_network.load_state_dict(self.agent_network.state_dict())
        self.target_mixing_network.load_state_dict(self.mixing_network.state_dict())

        self.optimizer = optim.Adam(list(self.agent_network.parameters()) + list(self.mixing_network.parameters()), lr=lr)
        self.replay_buffer = ReplayBuffer(50000)

    def select_action(self, obs):
        actions = []
        if random.random() < self.epsilon:
            actions = [np.random.randint(0, self.action_dim) for _ in range(self.num_agents)]
        else:
            obs_array = np.array(obs)
            obs_tensor = torch.tensor(obs_array, dtype=torch.float32)
            q_values = self.agent_network(obs_tensor)
            actions = q_values.argmax(dim=1).tolist()

        self.epsilon = max(self.epsilon * self.epsilon_decay, self.epsilon_min)
        return actions

    def train(self, batch_size):
        if len(self.replay_buffer) < batch_size:
            return

        states, obs, actions, rewards, next_states, next_obs, dones = self.replay_buffer.sample(batch_size)

        q_values = []
        for agent in range(self.num_agents):
            agent_obs = obs[:, agent, :]
            agent_actions = actions[:, agent]
            agent_q = self.agent_network(agent_obs)
            q_values.append(agent_q.gather(1, agent_actions.unsqueeze(-1)).squeeze(-1))
        q_values = torch.stack(q_values, dim=1)

        q_total = self.mixing_network(q_values, states)

        with torch.no_grad():
            target_q_values = []
            for agent in range(self.num_agents):
                agent_next_obs = next_obs[:, agent, :]
                target_agent_q = self.target_agent_network(agent_next_obs)
                target_q_values.append(torch.max(target_agent_q, dim=1)[0])
            target_q_values = torch.stack(target_q_values, dim=1)
            target_q_total = self.target_mixing_network(target_q_values, next_states)

        rewards_sum = rewards.sum(dim=1)
        targets = rewards_sum + self.gamma * (1 - dones) * target_q_total.squeeze(-1)
        loss = nn.MSELoss()(q_total.squeeze(-1), targets)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def update_target_networks(self, tau=0.01):
        for target_param, param in zip(self.target_agent_network.parameters(), self.agent_network.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
        for target_param, param in zip(self.target_mixing_network.parameters(), self.mixing_network.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)


def train_qmix(env, num_episodes=1000, batch_size=32, target_update_interval=100):
    state_dim = env.get_state().shape[0]
    obs_dim = env.observation_space(env.agents[0]).shape[0]
    action_dim = env.action_space(env.agents[0]).n
    num_agents = len(env.agents)

    qmix = QMIX(env, state_dim, obs_dim, action_dim, num_agents)

    rewards_history = []
    q_values_history = []  # 用于记录每个 episode 的平均 Q 值

    with tqdm(total=num_episodes, desc="Training Progress") as pbar:
        for episode in range(num_episodes):
            obs = env.reset()
            state = env.get_state()
            done = False
            episode_reward = 0
            episode_q_value = 0  # 记录这一 episode 中的 Q 值总和
            steps = 0  # 记录这一 episode 的总步数

            while not done:
                actions = {agent: qmix.select_action([obs[agent]])[0] for agent in env.agents}
                next_obs, rewards, dones, _ = env.step(actions)
                next_state = env.get_state()

                # 存储到 Replay Buffer
                qmix.replay_buffer.add(
                    state,
                    np.array([obs[agent] for agent in env.agents]),
                    np.array([actions[agent] for agent in env.agents]),
                    np.array([rewards[agent] for agent in env.agents]),
                    next_state,
                    np.array([next_obs[agent] for agent in env.agents]),
                    float(all(dones.values()))
                )

                # 更新观测和状态
                obs = next_obs
                state = next_state
                episode_reward += sum(rewards.values())
                steps += 1

                # 训练
                loss = qmix.train(batch_size)

                # 记录当前 Q 值
                if loss is not None:
                    q_values = []
                    for agent in env.agents:
                        obs_tensor = torch.tensor([obs[agent]], dtype=torch.float32)
                        agent_q_values = qmix.agent_network(obs_tensor).max().item()
                        q_values.append(agent_q_values)
                    episode_q_value += np.mean(q_values)  # 取所有智能体 Q 值的平均值

                if all(dones.values()):
                    break

            rewards_history.append(episode_reward)
            # 记录每个 episode 的平均 Q 值（总和除以步数）
            q_values_history.append(episode_q_value / steps if steps > 0 else 0)

            # 更新目标网络
            if episode % target_update_interval == 0:
                qmix.update_target_networks()

            # 更新进度条信息
            pbar.set_postfix({
                "Reward": episode_reward,
                "Avg Q": q_values_history[-1]
            })
            pbar.update(1)

    return qmix, rewards_history, q_values_history

def plot_training_results(rewards_history, q_values_history):
    # 绘制 Reward 曲线
    plt.figure(figsize=(12, 6))
    plt.plot(rewards_history, label="Reward")
    plt.xlabel("Episode")
    plt.ylabel("Total Reward")
    plt.title("Reward over Episodes")
    plt.legend()
    plt.grid()
    plt.show()

    # 绘制 Q 值曲线
    plt.figure(figsize=(12, 6))
    plt.plot(q_values_history, label="Q Value", color="orange")
    plt.xlabel("Episode")
    plt.ylabel("Average Q Value")
    plt.title("Average Q Value over Episodes")
    plt.legend()
    plt.grid()
    plt.show()


# 创建环境
env = CustomMAEnvironment()

In [2]:
# 训练 QMIX
qmix_trained, rewards_history, q_values_history = train_qmix(env, num_episodes=1000)

# 绘制训练结果
plot_training_results(rewards_history, q_values_history)

Training Progress:   0%|          | 0/1000 [00:00<?, ?it/s]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x3 and 5x128)