In [None]:
# !pip install git+https://github.com/Farama-Foundation/MAgent2

Collecting git+https://github.com/Farama-Foundation/MAgent2
  Cloning https://github.com/Farama-Foundation/MAgent2 to /tmp/pip-req-build-xjpbpmle
  Running command git clone --filter=blob:none --quiet https://github.com/Farama-Foundation/MAgent2 /tmp/pip-req-build-xjpbpmle
  Resolved https://github.com/Farama-Foundation/MAgent2 to commit b2ddd49445368cf85d4d4e1edcddae2e28aa1406
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


In [None]:
import numpy as np
import torch
from magent2.environments import battle_v4
from pettingzoo.utils import random_demo
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
import random
import os
import cv2
from collections import deque
import time
import matplotlib.pyplot as plt
from collections import namedtuple
import imageio
from torch.nn import MSELoss

In [None]:
def save_model(model, file_path):
    """Write ``model``'s parameters (its ``state_dict``) to ``file_path``.

    Only the tensors are serialized, not the module class. Restore them with
    ``model.load_state_dict(torch.load(file_path))`` on a freshly built model.
    """
    weights = model.state_dict()
    torch.save(weights, file_path)
    message = f"Model saved to {file_path}"
    print(message)

# Models

### Replay Buffer

In [None]:
# Replay Buffer
class ReplayBuffer:
    """Bounded FIFO store of ``(state, action, reward, next_state, done)`` transitions.

    After ``capacity`` transitions are held, each new push evicts the oldest one.
    """

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition (evicting the oldest if the buffer is full)."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct stored transitions uniformly at random.

        Returns ``(states, actions, rewards, next_states, dones)`` as numpy arrays.
        States and next states are stacked along a new leading batch axis.
        ``random.sample`` raises ``ValueError`` if fewer than ``batch_size``
        transitions are stored.
        """
        batch = random.sample(self.buffer, batch_size)
        columns = list(zip(*batch))
        states, actions, rewards, next_states, dones = columns
        return (
            np.stack(states),
            np.array(actions),
            np.array(rewards),
            np.stack(next_states),
            np.array(dones),
        )

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

In [None]:
class RandomAgent:
    """Baseline policy: ignores the observation and samples from ``action_space``."""

    def __init__(self, action_space):
        self.action_space = action_space

    def get_action(self, observation):
        """Return one fresh ``action_space.sample()``; ``observation`` is unused."""
        sampled = self.action_space.sample()
        return sampled

### VDN networks

In [None]:
class QNetwork(nn.Module):
    """Per-agent Q-function: two 3x3 convolutions followed by a 3-layer MLP head.

    ``observation_shape`` is channels-last ``(H, W, C)``: its last entry sets the
    conv channel count, and the probe below is permuted ``(2, 0, 1)``. ``forward``
    expects channels-first input.
    """

    def __init__(self, observation_shape, action_shape):
        super().__init__()
        channels = observation_shape[-1]
        self.cnn = nn.Sequential(
            nn.Conv2d(channels, channels, 3),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3),
            nn.ReLU(),
        )
        # Run one random, unbatched (C, H, W) probe through the conv stack to
        # find how many features the MLP head receives.
        probe = torch.randn(observation_shape).permute(2, 0, 1)
        flatten_dim = self.cnn(probe).numel()
        self.network = nn.Sequential(
            nn.Linear(flatten_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, action_shape),
        )

    def forward(self, x):
        """Map ``(B, C, H, W)`` or unbatched ``(C, H, W)`` input to ``(B, action_shape)`` Q-values."""
        assert x.ndim >= 3, "Input shape error"
        features = self.cnn(x)
        # Unbatched input is treated as a batch of one.
        rows = features.shape[0] if features.ndim > 3 else 1
        return self.network(features.reshape(rows, -1))

class VDN:
    """Value-Decomposition-style learner for one team of agents.

    Each agent id in ``agents`` owns an online ``QNetwork``, an Adam optimizer,
    a StepLR scheduler and a target ``QNetwork``. All transitions go into one
    shared ``ReplayBuffer``. ``update`` regresses the *sum* of every agent's
    Q-value onto a summed TD target.

    NOTE(review): the shared buffer stores single-agent transitions with no
    agent id. ``update`` therefore evaluates every agent's network on the same
    sampled observation and action, which is a summed ensemble over one agent's
    data rather than VDN's Q_tot = sum_i Q_i(o_i, a_i) over a joint transition.
    Confirm this is intended.
    """
    def __init__(self, observation_shape, action_shape, agents, batch_size=64, lr=1e-3, gamma=0.8, device="cpu"):
        """Build online/target networks, optimizers and schedulers for every agent.

        Args:
            observation_shape: channels-last observation shape, forwarded to ``QNetwork``.
            action_shape: number of discrete actions (int), i.e. the Q-network output width.
            agents: re-iterable collection of agent ids (e.g. a list); used as dict keys.
            batch_size: transitions sampled per ``update`` call.
            lr: Adam learning rate.
            gamma: discount factor used in the TD target.
            device: torch device string such as "cpu" or "cuda".
        """
        self.device = torch.device(device)
        self.agents = agents
        self.q_networks = {
            agent: QNetwork(observation_shape, action_shape).to(self.device)
            for agent in agents
        }
        self.lr = lr
        self.optimizers = {
            agent: optim.Adam(self.q_networks[agent].parameters(), lr=self.lr)
            for agent in agents
        }
        # NOTE(review): these StepLR schedulers are never stepped inside this class,
        # so the learning rate stays at ``lr`` unless a caller calls .step() on them.
        self.schedulers = {
            agent: torch.optim.lr_scheduler.StepLR(self.optimizers[agent], step_size=10, gamma=0.9)
            for agent in agents
        }
        # Target networks start as exact copies of the online networks and are kept
        # in eval mode. They are only refreshed by update_target_networks().
        self.target_networks = {
            agent: QNetwork(observation_shape, action_shape).to(self.device)
            for agent in agents
        }
        for agent in agents:
            self.target_networks[agent].load_state_dict(self.q_networks[agent].state_dict())
            self.target_networks[agent].eval()

        self.replay_buffer = ReplayBuffer(capacity=10000)  # one buffer shared by all agents
        self.batch_size = batch_size
        self.gamma = gamma
        self.action_shape = action_shape
        # Epsilon-greedy schedule: start at 0.5, multiply by 0.9 on each
        # decay_epsilon() call, never go below 0.05.
        self.epsilon = 0.5
        self.epsilon_decay = 0.9
        self.epsilon_min = 0.05
        self.max_grad_norm = 1.0  # per-network gradient-norm clipping threshold

    def get_action(self, agent, observation):
        """Epsilon-greedy action index for ``agent``.

        With probability ``epsilon``, return a uniform random action drawn from
        numpy's global RNG. Otherwise return the argmax of the agent's online
        Q-network. ``observation`` is a single channels-last (H, W, C) array; it is
        batched and permuted to (1, C, H, W) before the forward pass.
        """
        if np.random.rand() < self.epsilon:
            return np.random.randint(self.action_shape)
        else:
            state_tensor = torch.FloatTensor(observation).unsqueeze(0).permute(0, 3, 1, 2).to(self.device)
            with torch.no_grad():
                # Output is (1, action_shape), so the flat argmax is the action index.
                return self.q_networks[agent](state_tensor).argmax().item()

    def update(self):
        """Take one gradient step on every online network from a sampled minibatch.

        This is a no-op (returns None) until the buffer holds at least
        ``batch_size`` transitions. The loss is the mean squared error between
        sum_i Q_i(s, a) and r + gamma * (1 - done) * sum_i max_a' Q_i^target(s', a').
        """
        if len(self.replay_buffer) < self.batch_size:
            return

        states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)
        # Buffer stores channels-last observations; the networks expect (B, C, H, W).
        states = torch.FloatTensor(states).permute(0, 3, 1, 2).to(self.device)
        actions = torch.LongTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(next_states).permute(0, 3, 1, 2).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)

        # Q-value of the taken action under each agent's online network, summed over agents.
        q_values = []
        for agent in self.agents:
            q_value = self.q_networks[agent](states).gather(1, actions.unsqueeze(-1)).squeeze(-1)
            q_values.append(q_value)
        q_tot = torch.sum(torch.stack(q_values, dim=0), dim=0)

        # Bootstrapped target from the (frozen) target networks; no gradients flow here.
        with torch.no_grad():
            next_q_values = []
            for agent in self.agents:
                next_q_value = self.target_networks[agent](next_states).max(dim=1)[0]
                next_q_values.append(next_q_value)
            next_q_tot = torch.sum(torch.stack(next_q_values, dim=0), dim=0)
            q_tot_target = rewards + self.gamma * (1 - dones) * next_q_tot

        loss = torch.mean((q_tot - q_tot_target) ** 2)

        # q_tot depends on every online network, so one backward pass fills all their grads.
        for agent in self.agents:
            self.optimizers[agent].zero_grad()
        loss.backward()

        for agent in self.agents:
            torch.nn.utils.clip_grad_norm_(self.q_networks[agent].parameters(), self.max_grad_norm)
        for agent in self.agents:
            self.optimizers[agent].step()

        # print(f"Loss: {loss.item()}")
        # if "red_0" in self.agents:
        #     agent = "red_0"
        #     for name, param in self.q_networks[agent].named_parameters():
        #         if param.grad is not None:
        #             print(f"Gradient for {agent} -> {name}: {param.grad.abs().mean().item()}")

    def update_target_networks(self):
        """Hard-copy each online network's weights into its target network."""
        for agent in self.agents:
            self.target_networks[agent].load_state_dict(self.q_networks[agent].state_dict())

    def decay_epsilon(self):
        """Multiply epsilon by ``epsilon_decay``, clamped below at ``epsilon_min``."""
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)


In [None]:
import wandb
from google.colab import userdata

# Read the W&B API key from Colab's secret store (never hard-code it), then log in.
wandb_key = userdata.get("wandb-key")
wandb.login(key=wandb_key)

# Open the tracking run; `config` records the experiment setup shown in the W&B UI.
wandb.init(
    project="RL_TRAINING",
    name="VDN",
    config={
        "epochs_num": 70,
        "opponents": "random, training with blue + red data",
        "batch_size": 128,
        "num_agent": 81,
    },
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mtheseventeengv[0m ([33mtrungviet17[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [None]:
import time
def train_vdn(env, red_agents, blue_agents, episodes=70, target_update_freq=5, log_fn=None):
    """Train the blue team's VDN learner against a scripted red opponent.

    Args:
        env: PettingZoo AEC environment whose agent ids start with "red"/"blue".
        red_agents: opponent policy exposing ``get_action(observation)``.
        blue_agents: the learner (e.g. ``VDN``) exposing ``agents``, ``epsilon``,
            ``get_action(agent, observation)``, ``replay_buffer.push(...)``,
            ``update()``, ``decay_epsilon()`` and ``update_target_networks()``.
        episodes: number of episodes to play.
        target_update_freq: sync target networks every this many episodes.
        log_fn: callable that receives one metrics dict per episode. Defaults to
            ``wandb.log``, which is looked up only when no callable is given.

    Returns:
        ``(total_rewards, kill_counts)``: the summed blue reward of each episode,
        and kills per team accumulated over all episodes.
    """
    if log_fn is None:
        log_fn = wandb.log

    total_rewards = []
    kill_counts = {"red": 0, "blue": 0}  # accumulated over every episode
    # Team reward totals, accumulated across episodes and logged as their gap.
    red_rewards = 0
    blue_rewards = 0

    for episode in range(episodes):
        env.reset()
        episode_kills = {"red": 0, "blue": 0}  # kills in the current episode
        total_reward = {agent: 0 for agent in blue_agents.agents}  # per blue agent
        # blue agent -> (observation, action) from its previous turn. It becomes a
        # full transition the next time that agent is selected, because in the AEC
        # API env.last() then reports the reward earned since that action and the
        # resulting observation.
        pending = {}

        start = time.time()
        for agent in env.agent_iter():
            observation, reward, termination, truncation, _ = env.last()
            is_blue = agent.startswith("blue")
            team = agent.split("_")[0]

            # Heuristic: a reward above 4.5 is counted as a kill for the agent's team.
            if reward > 4.5:
                episode_kills[team] += 1

            if is_blue:
                blue_rewards += reward
                total_reward[agent] += reward
                if agent in pending:
                    prev_obs, prev_action = pending.pop(agent)
                    next_obs = observation if observation is not None else np.zeros_like(prev_obs)
                    # Only a real termination stops bootstrapping. A time-limit
                    # truncation still has a valid successor value.
                    blue_agents.replay_buffer.push(
                        prev_obs, prev_action, reward, next_obs, termination
                    )
            else:
                red_rewards += reward

            # Finished agents must be stepped with None so the env removes them.
            if termination or truncation:
                env.step(None)
                continue

            if is_blue:
                # Blue team acts with the learned VDN policy.
                action = blue_agents.get_action(agent, observation)
                pending[agent] = (observation, action)
            else:
                # Red team acts with the scripted (random) policy.
                action = red_agents.get_action(observation)
            env.step(action)

        # One learning step per episode, then exploration decay.
        blue_agents.update()
        blue_agents.decay_epsilon()

        if episode % target_update_freq == 0:
            blue_agents.update_target_networks()

        for side in kill_counts:
            kill_counts[side] += episode_kills[side]

        episode_total_reward = sum(total_reward.values())
        total_rewards.append(episode_total_reward)

        print(f"Episode {episode}, Total Reward: {episode_total_reward}")
        print(f"Episode {episode} Kills - Red: {episode_kills['red']}, Blue: {episode_kills['blue']}")
        log_fn({
            "episode": episode,
            "gap_rewards": blue_rewards - red_rewards,
            "epsilon": blue_agents.epsilon,
            "time": time.time() - start,
            "red_kill": episode_kills["red"],
            "blue_kill": episode_kills["blue"],
        })

    env.close()
    print(f"Total Kills - Red: {kill_counts['red']}, Blue: {kill_counts['blue']}")
    return total_rewards, kill_counts


In [None]:
# MAgent2 render modes are "human" and "rgb_array" (underscore); "rgb-array" was a typo.
env = battle_v4.env(map_size=45, render_mode="rgb_array")

device = "cuda" if torch.cuda.is_available() else "cpu"

# Network sizes come from red_0's spaces; this assumes the blue agents share
# them (the battle setup is presumed symmetric — confirm if the env changes).
observation_shape = env.observation_space("red_0").shape
action_shape = env.action_space("red_0").n
env.reset()

# Blue is the learning team (VDN); red plays uniformly at random.
blue_agents = [agent for agent in env.agents if agent.startswith("blue")]
# Read hyperparameters from the W&B run config so the logged config matches what
# is actually trained. Previously the run reported batch_size=128 while VDN
# silently used its default of 64.
vdn = VDN(
    observation_shape,
    action_shape,
    blue_agents,
    batch_size=wandb.config["batch_size"],
    device=device,
)
red_agent = RandomAgent(env.action_space("red_0"))

# Keep the results instead of dumping the whole reward list as cell output.
episode_rewards, kill_counts = train_vdn(
    env, red_agent, vdn, episodes=wandb.config["epochs_num"]
)


Episode 0, Total Reward: -3162.0601164018735
Episode 0 Kills - Red: 4, Blue: 12
Episode 1, Total Reward: -3257.4851212650537
Episode 1 Kills - Red: 3, Blue: 13
Episode 2, Total Reward: -3195.220118932426
Episode 2 Kills - Red: 2, Blue: 18
Episode 3, Total Reward: -3190.8101193299517
Episode 3 Kills - Red: 6, Blue: 17
Episode 4, Total Reward: -3307.1401236010715
Episode 4 Kills - Red: 2, Blue: 18
Episode 5, Total Reward: -3315.2751254737377
Episode 5 Kills - Red: 4, Blue: 21
Episode 6, Total Reward: -3304.195124122314
Episode 6 Kills - Red: 2, Blue: 22
Episode 7, Total Reward: -3425.720127790235
Episode 7 Kills - Red: 4, Blue: 15
Episode 8, Total Reward: -3355.8851252188906
Episode 8 Kills - Red: 1, Blue: 19
Episode 9, Total Reward: -3396.815126657486
Episode 9 Kills - Red: 3, Blue: 19
Episode 10, Total Reward: -3440.490129268728
Episode 10 Kills - Red: 3, Blue: 21
Episode 11, Total Reward: -3504.320131923072
Episode 11 Kills - Red: 2, Blue: 21
Episode 12, Total Reward: -3678.1501383213