In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import numpy as np
import matplotlib.pyplot as plt
from collections import deque
import random
from tqdm import tqdm
from env.custom_environment import CustomEnvironment

# Replay Buffer定义
class ReplayBuffer:
    def __init__(self, buffer_size):
        self.buffer = deque(maxlen=buffer_size)

    def add(self, experience):
        self.buffer.append(experience)

    def sample(self, batch_size):
        indices = np.random.choice(len(self.buffer), batch_size, replace=False)
        batch = [self.buffer[idx] for idx in indices]
        states, actions, rewards, next_states, dones = map(np.array, zip(*batch))
        return states, actions, rewards, next_states, dones

    def __len__(self):
        return len(self.buffer)

# Hierarchical Policy定义
class HierarchicalPolicy(nn.Module):
    def __init__(self, args):
        super(HierarchicalPolicy, self).__init__()
        self.fc_1 = nn.Linear(args.state_shape, 128)
        self.fc_2 = nn.Linear(128, args.noise_dim)

    def forward(self, state):
        x = F.relu(self.fc_1(state))
        q = self.fc_2(x)
        prob = F.softmax(q, dim=-1)
        return prob

# Bootstrapped RNN定义
class BootstrappedRNN(nn.Module):
    def __init__(self, input_shape, args):
        super(BootstrappedRNN, self).__init__()
        self.args = args
        self.fc = nn.Linear(input_shape, args.rnn_hidden_dim)
        self.rnn = nn.GRUCell(args.rnn_hidden_dim, args.rnn_hidden_dim)
        self.hyper_w = nn.Linear(args.noise_dim + args.n_agents, args.rnn_hidden_dim * args.n_actions)
        self.hyper_b = nn.Linear(args.noise_dim + args.n_agents, args.n_actions)

    def forward(self, obs, hidden_state, z):
        batch_size = obs.size(0)
        agent_id = obs[:, -self.args.n_agents:]
        z = z.expand(batch_size, -1)
        hyper_input = torch.cat([z, agent_id], dim=-1)

        x = F.relu(self.fc(obs))
        h_in = hidden_state.reshape(-1, self.args.rnn_hidden_dim)
        h = self.rnn(x, h_in)
        h = h.view(-1, 1, self.args.rnn_hidden_dim)

        hyper_w = self.hyper_w(hyper_input)
        hyper_b = self.hyper_b(hyper_input)
        hyper_w = hyper_w.view(-1, self.args.rnn_hidden_dim, self.args.n_actions)
        hyper_b = hyper_b.view(-1, 1, self.args.n_actions)

        q = torch.bmm(h, hyper_w) + hyper_b
        q = q.view(-1, self.args.n_actions)
        return q, h

# Variational Distribution定义
class VarDistribution(nn.Module):
    def __init__(self, args):
        super(VarDistribution, self).__init__()
        self.args = args
        self.GRU = nn.GRU(args.n_agents * args.n_actions + args.state_shape, 64)
        self.fc_1 = nn.Linear(64, 32)
        self.fc_2 = nn.Linear(32, args.noise_dim)

    def forward(self, inputs):
        _, h = self.GRU(inputs)
        x = F.relu(self.fc_1(h.squeeze(0)))
        x = self.fc_2(x)
        output = F.softmax(x, dim=-1)
        return output

# QMixNet定义
class QMixNet(nn.Module):
    def __init__(self, args):
        super(QMixNet, self).__init__()
        self.args = args
        if args.two_hyper_layers:
            self.hyper_w1 = nn.Sequential(nn.Linear(args.state_shape, args.hyper_hidden_dim),
                                          nn.ReLU(),
                                          nn.Linear(args.hyper_hidden_dim, args.n_agents * args.qmix_hidden_dim))
            self.hyper_w2 = nn.Sequential(nn.Linear(args.state_shape, args.hyper_hidden_dim),
                                          nn.ReLU(),
                                          nn.Linear(args.hyper_hidden_dim, args.qmix_hidden_dim))
        else:
            self.hyper_w1 = nn.Linear(args.state_shape, args.n_agents * args.qmix_hidden_dim)
            self.hyper_w2 = nn.Linear(args.state_shape, args.qmix_hidden_dim * 1)
        self.hyper_b1 = nn.Linear(args.state_shape, args.qmix_hidden_dim)
        self.hyper_b2 = nn.Sequential(nn.Linear(args.state_shape, args.qmix_hidden_dim),
                                      nn.ReLU(),
                                      nn.Linear(args.qmix_hidden_dim, 1))

    def forward(self, q_values, states):
        episode_num = q_values.size(0)
        q_values = q_values.view(episode_num, 1, self.args.n_agents)
        states = states.view(episode_num, -1)
        w1 = torch.abs(self.hyper_w1(states))
        b1 = self.hyper_b1(states)
        w1 = w1.view(episode_num, self.args.n_agents, self.args.qmix_hidden_dim)
        b1 = b1.view(episode_num, 1, self.args.qmix_hidden_dim)
        hidden = F.elu(torch.bmm(q_values, w1) + b1)
        w2 = torch.abs(self.hyper_w2(states))
        b2 = self.hyper_b2(states)
        w2 = w2.view(episode_num, self.args.qmix_hidden_dim, 1)
        b2 = b2.view(episode_num, 1, 1)
        q_total = torch.bmm(hidden, w2) + b2
        q_total = q_total.view(episode_num, -1, 1)
        return q_total

class MAVEN:
    def __init__(self, env, args, use_cuda=False):
        self.env = env
        self.args = args
        self.use_cuda = use_cuda
        self.device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")
        self.replay_buffer = ReplayBuffer(args.buffer_size)
        self.z_policy = HierarchicalPolicy(args).to(self.device)
        self.eval_rnn = BootstrappedRNN(args.obs_shape, args).to(self.device)
        self.target_rnn = BootstrappedRNN(args.obs_shape, args).to(self.device)
        self.eval_qmix_net = QMixNet(args).to(self.device)
        self.target_qmix_net = QMixNet(args).to(self.device)
        self.mi_net = VarDistribution(args).to(self.device)
        self.optimizer = Adam(list(self.z_policy.parameters()) +
                              list(self.eval_qmix_net.parameters()) +
                              list(self.eval_rnn.parameters()) +
                              list(self.mi_net.parameters()), lr=args.lr)
        self.loss_fn = nn.MSELoss()
        self.eval_hidden = None
        self.target_hidden = None
        self.model_dir = args.model_dir + '/' + args.alg + '/' + args.map
        if self.args.load_model:
            if os.path.exists(self.model_dir + '/rnn_net_params.pkl'):
                path_z_policy = self.model_dir + '/z_policy_params.pkl'
                path_rnn = self.model_dir + '/rnn_net_params.pkl'
                path_qmix = self.model_dir + '/qmix_net_params.pkl'
                path_mi = self.model_dir + '/mi_net_params.pkl'
                map_location = 'cuda:0' if self.args.cuda else 'cpu'
                self.z_policy.load_state_dict(torch.load(path_z_policy, map_location=map_location))
                self.eval_rnn.load_state_dict(torch.load(path_rnn, map_location=map_location))
                self.eval_qmix_net.load_state_dict(torch.load(path_qmix, map_location=map_location))
                self.mi_net.load_state_dict(torch.load(path_mi, map_location=map_location))
                print('Successfully load the model: {}, {}, {} and {}'.format(path_z_policy, path_rnn, path_qmix, path_mi))
            else:
                raise Exception("No model!")
        self.target_rnn.load_state_dict(self.eval_rnn.state_dict())
        self.target_qmix_net.load_state_dict(self.eval_qmix_net.state_dict())

    def choose_action(self, obs, hidden_state, z, epsilon=0.01):
        actions = {}
        new_hidden_state = {}
        num_agents = len(self.env.agents)
        batch_size = obs[list(obs.keys())[0]].shape[0]  # assuming all agents have the same batch size
        z_expanded = z.expand(batch_size, -1)  # Expand z to match the batch size

        for i, agent in enumerate(self.env.agents):
            agent_obs = torch.tensor(obs[agent], dtype=torch.float32).to(self.device).unsqueeze(0)  # Add batch dimension
            agent_hidden_state = torch.tensor(hidden_state[agent], dtype=torch.float32).to(self.device).unsqueeze(0)  # Add batch dimension
            z_agent = z_expanded  # Use expanded z for the current agent

            q_values, new_agent_hidden_state = self.eval_rnn(agent_obs, agent_hidden_state, z_agent)
            if random.random() < epsilon:
                action = random.choice(range(self.args.n_actions))
            else:
                action = q_values.argmax(dim=-1).item()
            actions[agent] = action
            new_hidden_state[agent] = new_agent_hidden_state.detach().cpu().numpy()
        return actions, new_hidden_state

    def train(self):
        if len(self.replay_buffer) < self.args.minimal_size:
            return
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.args.batch_size)
        states = {agent: torch.tensor([s[agent] for s in states], dtype=torch.float32).to(self.device) for agent in self.env.agents}
        actions = {agent: torch.tensor([a[agent] for a in actions], dtype=torch.int64).to(self.device) for agent in self.env.agents}
        rewards = {agent: torch.tensor([r[agent] for r in rewards], dtype=torch.float32).to(self.device) for agent in self.env.agents}
        next_states = {agent: torch.tensor([ns[agent] for ns in next_states], dtype=torch.float32).to(self.device) for agent in self.env.agents}
        dones = {agent: torch.tensor([d[agent] for d in dones], dtype=torch.float32).to(self.device) for agent in self.env.agents}
        hidden_state = {agent: torch.zeros((self.args.batch_size, self.args.rnn_hidden_dim), dtype=torch.float32).to(self.device) for agent in self.env.agents}
        next_hidden_state = {agent: torch.zeros((self.args.batch_size, self.args.rnn_hidden_dim), dtype=torch.float32).to(self.device) for agent in self.env.agents}

        q_values = {agent: self.eval_rnn(states[agent], hidden_state[agent], z)[0].gather(1, actions[agent].unsqueeze(-1)).squeeze(-1) for agent in self.env.agents}
        next_q_values = {agent: self.target_rnn(next_states[agent], next_hidden_state[agent], z)[0].max(dim=-1)[0] for agent in self.env.agents}

        targets = {agent: rewards[agent] + self.args.gamma * next_q_values[agent] * (1 - dones[agent]) for agent in self.env.agents}

        q_values_tensor = torch.stack(list(q_values.values()), dim=1)
        targets_tensor = torch.stack(list(targets.values()), dim=1)
        states_tensor = torch.cat([states[agent].view(self.args.batch_size, -1) for agent in self.env.agents], dim=1)

        q_total = self.eval_qmix_net(q_values_tensor, states_tensor)
        targets_total = self.target_qmix_net(targets_tensor, states_tensor)

        loss = self.loss_fn(q_total, targets_total.detach())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def run(self, num_episodes):
        total_q_values = []
        total_returns = []

        for i in range(10):
            with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:
                for episode in range(int(num_episodes / 10)):
                    obs = self.env.reset()
                    hidden_state = {agent: np.zeros(self.args.rnn_hidden_dim) for agent in self.env.agents}
                    state_tensor = torch.tensor([obs[agent] for agent in self.env.agents], dtype=torch.float32).to(self.device)
                    z = self.z_policy(state_tensor).unsqueeze(0)
                    episode_q_values = []
                    episode_return = 0
                    actions_over_time = []

                    for t in range(int(self.args.num_iterations)):
                        actions, hidden_state = self.choose_action(obs, hidden_state, z)
                        next_obs, rewards, dones, _ = self.env.step(actions)
                        self.replay_buffer.add((obs, actions, rewards, next_obs, dones))
                        obs = next_obs

                        actions_over_time.append(actions)

                        q_values = []
                        for agent in self.env.agents:
                            agent_obs = torch.tensor(obs[agent], dtype=torch.float32).unsqueeze(0).to(self.device)
                            agent_hidden_state = torch.tensor(hidden_state[agent], dtype=torch.float32).unsqueeze(0).to(self.device)
                            agent_q_values, _ = self.eval_rnn(agent_obs, agent_hidden_state, z)
                            q_values.append(agent_q_values.max().item())
                        episode_q_values.append(np.mean(q_values))

                        episode_return += sum(rewards.values())

                        self.train()

                        if all(dones.values()):
                            break

                    total_q_values.append(np.mean(episode_q_values))
                    total_returns.append(episode_return)

                    if episode_return > 10000:
                        print(f"Episode {episode + 1} reached a total return of {episode_return}")
                        for t, actions in enumerate(actions_over_time):
                            actions_list = [actions[agent] for agent in self.env.agents]
                            print(actions_list)
                        print("done")
                    if (episode + 1) % 10 == 0:
                        pbar.set_postfix({'episode': '%d' % (num_episodes / 10 * i + episode + 1), 'returns': '%.3f' % np.mean(total_returns[-10:])})
                    pbar.update(1)

        return total_q_values, total_returns

    def run_one_iteration(self):
        obs = self.env.reset()
        hidden_state = {agent: np.zeros(self.args.rnn_hidden_dim) for agent in self.env.agents}
        z = self.z_policy(torch.tensor([obs[agent] for agent in self.env.agents], dtype=torch.float32).to(self.device))

        actions_over_time = []

        for t in range(self.args.num_iterations):
            actions, hidden_state = self.choose_action(obs, hidden_state, z)
            next_obs, rewards, dones, _ = self.env.step(actions)
            actions_over_time.append(actions)
            if all(dones.values()):
                break
            obs = next_obs

        for t, actions in enumerate(actions_over_time):
            print(f"Time step {t}:")
            for agent, action in actions.items():
                print(f"  Agent {agent}: Action {action}")

        return actions, next_obs, rewards, dones


class Args:
    def  __init__(self, env):
        self.state_shape = env.max_obs_size
        self.obs_shape = env.max_obs_size
        self.n_agents = 5
        self.n_actions = 2
        self.qmix_hidden_dim = 32
        self.hyper_hidden_dim = 64
        self.rnn_hidden_dim = 64
        self.two_hyper_layers = False
        self.buffer_size = 100000
        self.minimal_size = 65
        self.batch_size = 64
        self.gamma = 0.95
        self.lr = 1e-3
        self.num_iterations = env.num_iterations
        self.num_episodes = 1000
        self.noise_dim = 10
        self.entropy_coefficient = 0.01
        self.lambda_mi = 1.0
        self.lambda_ql = 1.0
        self.grad_norm_clip = 10
        self.model_dir = ''
        self.alg = ''
        self.map = ''
        self.load_model = False
        self.cuda = False

env = CustomEnvironment()
torch.manual_seed(0)
args = Args(env)

use_cuda = False
maven = MAVEN(env, args, use_cuda)
total_q_values, total_returns = maven.run(args.num_episodes)

# 可视化总回报变化
plt.plot(total_returns)
plt.xlabel('Episodes')
plt.ylabel('Total Return')
plt.title('Total Return over episodes')
plt.show()

# 可视化Q值变化
plt.plot(total_q_values)
plt.xlabel('Episodes')
plt.ylabel('Total Q value')
plt.title('Total Q value over episodes')
plt.show()

Iteration 0:   0%|          | 0/100 [00:00<?, ?it/s]


RuntimeError: expand(torch.FloatTensor{[1, 5, 10]}, size=[4, -1]): the number of sizes provided (2) must be greater or equal to the number of dimensions in the tensor (3)