In [2]:
!pip install git+https://github.com/Farama-Foundation/MAgent2

Collecting git+https://github.com/Farama-Foundation/MAgent2
  Cloning https://github.com/Farama-Foundation/MAgent2 to /tmp/pip-req-build-3fk0th1j
  Running command git clone --filter=blob:none --quiet https://github.com/Farama-Foundation/MAgent2 /tmp/pip-req-build-3fk0th1j
  Resolved https://github.com/Farama-Foundation/MAgent2 to commit b2ddd49445368cf85d4d4e1edcddae2e28aa1406
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting pettingzoo>=1.23.1 (from magent2==0.3.3)
  Downloading pettingzoo-1.24.3-py3-none-any.whl.metadata (8.5 kB)
Collecting gymnasium>=0.28.0 (from pettingzoo>=1.23.1->magent2==0.3.3)
  Downloading gymnasium-1.0.0-py3-none-any.whl.metadata (9.5 kB)
Collecting farama-notifications>=0.0.1 (from gymnasium>=0.28.0->pettingzoo>=1.23.1->magent2==0.3.3)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)
Downloading pe

## Import

In [3]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import numpy as np
import os
from tqdm import tqdm

from magent2.environments import battle_v4
import cv2
from collections import deque
import time
import random
import warnings
warnings.filterwarnings('ignore')

## DQN

In [4]:
class QNetwork(nn.Module):
    def __init__(self, observation_shape, action_shape, device='cpu'):
        super().__init__()
        self.observation_shape = observation_shape
        self.action_shape = action_shape
        self.device = device

        self.cnn = nn.Sequential(
            nn.Conv2d(observation_shape[-1], observation_shape[-1], kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(observation_shape[-1], observation_shape[-1], kernel_size=3),
            nn.ReLU(),
        )

        dummy_input = torch.randn(observation_shape).permute(2, 0, 1)
        dummy_output = self.cnn(dummy_input)
        flatten_dim = dummy_output.view(-1).shape[0]
        self.network = nn.Sequential(
            nn.Linear(flatten_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, action_shape),
        )

    def forward(self, x):
        assert len(x.shape) >= 3, "only support magent input observation"
        x = self.cnn(x)
        if len(x.shape) == 3:
            batchsize = 1
        else:
            batchsize = x.shape[0]
        x = x.reshape(batchsize, -1)
        return self.network(x)

## Replay Buffer

In [5]:
class ReplayMemory(Dataset):
    def __init__(self, maxlen):
        super().__init__()
        self.maxlen = maxlen
        self.step_memory = [deque([],maxlen=self.maxlen)]

    def push(self, step_idx, state, action, reward, next_state, done):
        if step_idx == len(self.step_memory):
            self.step_memory.append(deque([],maxlen=self.maxlen))
        self.step_memory[step_idx].append((state, action, reward, next_state, done))

    def __len__(self):
        return sum([len(memory) for memory in self.step_memory])

    def __getitem__(self, idx):
        step_idx = 0
        while idx >= len(self.step_memory[step_idx]):
            idx -= len(self.step_memory[step_idx])
            step_idx += 1
        state, action, reward, next_state, done = self.step_memory[step_idx][idx]
        return (
            torch.Tensor(state).float().permute([2, 0, 1]),
            torch.tensor(action),
            torch.tensor(reward, dtype=torch.float),
            torch.tensor(next_state).float().permute([2,0,1]),
            torch.tensor(done, dtype=torch.float32),
        )

## Trainer

In [6]:
class Trainer:
    def __init__(
        self,
        policy_dqn, target_dqn,
        n_action,
        loss_fn, optimizer, scheduler,
        epsilon_start, epsilon_end, epsilon_decay,
        device='cpu'
    ):

        self.policy_dqn = policy_dqn.to(device)
        self.target_dqn = target_dqn.to(device)
        self.target_dqn.eval()

        self.n_action = n_action

        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.scheduler = scheduler

        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.epsilon = self.epsilon_start

        self.device = device

        self.policy_dqn.apply(self.weights_init)

    def weights_init(self, m):
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
            nn.init.xavier_uniform_(m.weight)
            if torch.is_tensor(m.bias):
                m.bias.data.fill_(0.01)


    def policy(self, observation):
        if np.random.rand() < self.epsilon:
            return np.random.randint(low=0, high=self.n_action)
        else:
            with torch.no_grad():
                q_values = self.policy_dqn(
                    torch.Tensor(observation).float().permute([2, 0, 1]).unsqueeze(0).to(self.device)
                )
            return torch.argmax(q_values, dim=1).cpu().numpy()[0]


    def optimize_model(self, replay_memory, batch_size, gamma):
        if len(replay_memory) < batch_size:
            return
        train_loader = DataLoader(replay_memory, batch_size=batch_size, shuffle=True)
        self.policy_dqn.train()

        for observations, actions, rewards, next_observations, dones in train_loader:

            self.policy_dqn.zero_grad()

            observations = observations.to(self.device)
            actions = actions.unsqueeze(1).to(self.device)
            rewards = rewards.unsqueeze(1).to(self.device)
            next_observations = next_observations.to(self.device)
            dones = dones.unsqueeze(1).to(self.device)

            current_q_values = self.policy_dqn(observations).gather(1, actions)

            with torch.no_grad():
                target_q_values = rewards + gamma * (1 - dones) * self.target_dqn(next_observations).max(1, keepdim=True)[0]

            # Compute loss
            loss = self.loss_fn(current_q_values, target_q_values)

            # Optimize the network
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()

    def train(self,
              env, episodes,
              target_agent, batch_size, gamma, replay_memory,
              update_tg_freq, TAU
             ):
        train_rewards = []
        train_durations = []

        for episode in tqdm(range(episodes)):
            ep_reward = 0

            ep_steps = 0

            observations = {}
            actions = {}
            step_idx = {}

            env.reset()

            for idx, agent in enumerate(env.agent_iter()):
                ep_steps += 1
                observation, reward, termination, truncation, info = env.last()

                if target_agent in agent:
                    ep_reward += reward
                else:
                    ep_reward -= abs(reward)

                step_idx[agent] = 0
                action = self.policy(observation)

                observations[agent] = observation
                actions[agent] = action
                env.step(action)

                if (idx+1) % env.num_agents == 0:
                    break

            for agent in env.agent_iter():
                ep_steps += 1

                next_observation, reward, termination, truncation, info = env.last()

                if target_agent in agent:
                    ep_reward += reward
                else:
                    ep_reward -= abs(reward)

                # Agent die
                if termination or truncation:
                    action = None
                else:
                    action = self.policy(next_observation)

                replay_memory.push(
                    step_idx[agent],
                    observations[agent],
                    actions[agent],
                    reward,
                    next_observation,
                    termination or truncation
                )

                step_idx[agent] += 1
                observations[agent] = next_observation
                actions[agent] = action
                env.step(action)

            # Training mô hình với memory hiện tại
            self.optimize_model(replay_memory, batch_size, gamma)

            # Cập nhật lại mô hình mục tiêu theo chu kì
            if episode % update_tg_freq == 0:
                target_dqn_state_dict = self.target_dqn.state_dict()
                policy_dqn_state_dict = self.policy_dqn.state_dict()
                for key in policy_dqn_state_dict:
                    target_dqn_state_dict[key] = policy_dqn_state_dict[key]*TAU + target_dqn_state_dict[key]*(1-TAU)
                self.target_dqn.load_state_dict(target_dqn_state_dict)


            self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)

            print(f"\nEpisode {episode + 1}, Episode Reward: {ep_reward}, Steps: {ep_steps}, Epsilon: {self.epsilon}")

            train_rewards.append(ep_reward)
            train_durations.append(ep_steps)

        return train_rewards, train_durations

## Config

In [13]:

env = battle_v4.env(map_size=45, render_mode="rgb_array")

episodes = 40
target_agent = 'blue'
batch_size = 1024
gamma = 0.89
update_tg_freq = 1
TAU = 0.3

maxlen = 162 * episodes

learning_rate = 1e-3
theta = 0.000001
epsilon_start = 1.0
epsilon_end = 0.01
epsilon_decay = 0.9
device = 'cuda' if torch.cuda.is_available() else 'cpu'


In [8]:
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

## Training Loop

In [9]:
replay_memory = ReplayMemory(maxlen)

In [10]:
policy_dqn = QNetwork(env.observation_space("red_0").shape, env.action_space("red_0").n).to(device)
target_dqn = QNetwork(env.observation_space("red_0").shape, env.action_space("red_0").n).to(device)
target_dqn.load_state_dict(policy_dqn.state_dict())

<All keys matched successfully>

In [11]:
loss_function = nn.MSELoss()
optimizer = torch.optim.AdamW(policy_dqn.parameters(), lr=learning_rate)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=episodes, eta_min=theta)

In [12]:

trainer = Trainer(
    policy_dqn, target_dqn,
    env.action_space("red_0").n,
    loss_function, optimizer, lr_scheduler,
    epsilon_start, epsilon_end, epsilon_decay,
    device=device
)

In [14]:

train_rewards, train_durations = trainer.train(
    env, episodes,
    target_agent, batch_size, gamma, replay_memory,
    update_tg_freq, TAU
)

  2%|▎         | 1/40 [00:53<34:40, 53.33s/it]


Episode 1, Episode Reward: -6647.125233091414, Steps: 158611, Epsilon: 0.9


  5%|▌         | 2/40 [02:45<55:33, 87.72s/it]


Episode 2, Episode Reward: -6047.095206912607, Steps: 158669, Epsilon: 0.81


  8%|▊         | 3/40 [05:07<1:09:20, 112.45s/it]


Episode 3, Episode Reward: -3693.5851257890463, Steps: 105894, Epsilon: 0.7290000000000001


 10%|█         | 4/40 [07:36<1:16:16, 127.14s/it]


Episode 4, Episode Reward: -2363.1900751609355, Steps: 71251, Epsilon: 0.6561000000000001


 12%|█▎        | 5/40 [10:54<1:28:56, 152.46s/it]


Episode 5, Episode Reward: -4044.89513541013, Steps: 107790, Epsilon: 0.5904900000000002


 15%|█▌        | 6/40 [14:35<1:39:39, 175.86s/it]


Episode 6, Episode Reward: -2555.2950820820406, Steps: 96029, Epsilon: 0.5314410000000002


 18%|█▊        | 7/40 [18:16<1:44:47, 190.53s/it]


Episode 7, Episode Reward: -1391.555031816475, Steps: 43618, Epsilon: 0.47829690000000014


 20%|██        | 8/40 [21:54<1:46:22, 199.47s/it]


Episode 8, Episode Reward: -799.0200171433389, Steps: 24652, Epsilon: 0.43046721000000016


 22%|██▎       | 9/40 [26:01<1:50:45, 214.37s/it]


Episode 9, Episode Reward: -1557.7200390445068, Steps: 64982, Epsilon: 0.38742048900000015


 25%|██▌       | 10/40 [30:04<1:51:30, 223.01s/it]


Episode 10, Episode Reward: -450.9500146685168, Steps: 26464, Epsilon: 0.34867844010000015


 28%|██▊       | 11/40 [34:14<1:51:51, 231.44s/it]


Episode 11, Episode Reward: -380.3900141045451, Steps: 28637, Epsilon: 0.31381059609000017


 30%|███       | 12/40 [38:26<1:50:49, 237.49s/it]


Episode 12, Episode Reward: -466.28000976890326, Steps: 24160, Epsilon: 0.28242953648100017


 32%|███▎      | 13/40 [42:49<1:50:21, 245.24s/it]


Episode 13, Episode Reward: -460.8950093910098, Steps: 26577, Epsilon: 0.25418658283290013


 35%|███▌      | 14/40 [47:36<1:51:46, 257.94s/it]


Episode 14, Episode Reward: -793.6000152053311, Steps: 49441, Epsilon: 0.22876792454961012


 38%|███▊      | 15/40 [52:17<1:50:22, 264.89s/it]


Episode 15, Episode Reward: -176.27000698260963, Steps: 22390, Epsilon: 0.2058911320946491


 40%|████      | 16/40 [56:54<1:47:26, 268.59s/it]


Episode 16, Episode Reward: -233.36000366602093, Steps: 15067, Epsilon: 0.1853020188851842


 42%|████▎     | 17/40 [1:01:56<1:46:47, 278.57s/it]


Episode 17, Episode Reward: -634.4250093987212, Steps: 46693, Epsilon: 0.16677181699666577


 45%|████▌     | 18/40 [1:06:54<1:44:16, 284.38s/it]


Episode 18, Episode Reward: -235.70500364527106, Steps: 18155, Epsilon: 0.1500946352969992


 48%|████▊     | 19/40 [1:11:52<1:40:58, 288.48s/it]


Episode 19, Episode Reward: -139.6500032087788, Steps: 16890, Epsilon: 0.13508517176729928


 50%|█████     | 20/40 [1:16:44<1:36:34, 289.71s/it]


Episode 20, Episode Reward: 61.01999752037227, Steps: 9540, Epsilon: 0.12157665459056936


 52%|█████▎    | 21/40 [1:22:17<1:35:50, 302.64s/it]


Episode 21, Episode Reward: -734.095005095005, Steps: 63465, Epsilon: 0.10941898913151243


 55%|█████▌    | 22/40 [1:27:37<1:32:21, 307.87s/it]


Episode 22, Episode Reward: 182.1049973508343, Steps: 7946, Epsilon: 0.0984770902183612


 57%|█████▊    | 23/40 [1:32:52<1:27:50, 310.04s/it]


Episode 23, Episode Reward: -66.63500150013715, Steps: 11820, Epsilon: 0.08862938119652508


 60%|██████    | 24/40 [1:38:26<1:24:36, 317.27s/it]


Episode 24, Episode Reward: -451.70500161033124, Steps: 35694, Epsilon: 0.07976644307687257


 62%|██████▎   | 25/40 [1:43:58<1:20:24, 321.62s/it]


Episode 25, Episode Reward: -273.2400004938245, Steps: 12644, Epsilon: 0.07178979876918531


 65%|██████▌   | 26/40 [1:49:21<1:15:06, 321.90s/it]


Episode 26, Episode Reward: -348.8099996307865, Steps: 8869, Epsilon: 0.06461081889226679


 68%|██████▊   | 27/40 [1:54:52<1:10:21, 324.76s/it]


Episode 27, Episode Reward: -26.845001183450222, Steps: 15299, Epsilon: 0.05814973700304011


 70%|███████   | 28/40 [2:00:37<1:06:10, 330.87s/it]


Episode 28, Episode Reward: -130.81500063464046, Steps: 33954, Epsilon: 0.0523347633027361


 72%|███████▎  | 29/40 [2:06:23<1:01:29, 335.40s/it]


Episode 29, Episode Reward: -211.9950000019744, Steps: 12663, Epsilon: 0.04710128697246249


 75%|███████▌  | 30/40 [2:12:11<56:31, 339.20s/it]  


Episode 30, Episode Reward: 50.39499890431762, Steps: 11642, Epsilon: 0.042391158275216244


 78%|███████▊  | 31/40 [2:17:55<51:04, 340.48s/it]


Episode 31, Episode Reward: 75.86499898321927, Steps: 9092, Epsilon: 0.03815204244769462


 80%|████████  | 32/40 [2:23:37<45:28, 341.12s/it]


Episode 32, Episode Reward: 166.22499839682132, Steps: 8312, Epsilon: 0.03433683820292516


 82%|████████▎ | 33/40 [2:29:24<40:00, 342.87s/it]


Episode 33, Episode Reward: -28.600000829435885, Steps: 11047, Epsilon: 0.030903154382632643


 85%|████████▌ | 34/40 [2:35:11<34:24, 344.12s/it]


Episode 34, Episode Reward: 181.33499837014824, Steps: 6693, Epsilon: 0.02781283894436938


 88%|████████▊ | 35/40 [2:40:58<28:43, 344.73s/it]


Episode 35, Episode Reward: 70.40999890398234, Steps: 6562, Epsilon: 0.025031555049932444


 90%|█████████ | 36/40 [2:46:47<23:04, 346.08s/it]


Episode 36, Episode Reward: -137.4550003753975, Steps: 8786, Epsilon: 0.0225283995449392


 92%|█████████▎| 37/40 [2:52:39<17:23, 347.98s/it]


Episode 37, Episode Reward: 136.9599985582754, Steps: 6441, Epsilon: 0.020275559590445278


 95%|█████████▌| 38/40 [2:58:33<11:39, 349.59s/it]


Episode 38, Episode Reward: -315.00999911967665, Steps: 8915, Epsilon: 0.01824800363140075


 98%|█████████▊| 39/40 [3:04:33<05:52, 352.78s/it]


Episode 39, Episode Reward: 70.1149987373501, Steps: 15808, Epsilon: 0.016423203268260675


100%|██████████| 40/40 [3:10:32<00:00, 285.82s/it]


Episode 40, Episode Reward: 174.69499843847007, Steps: 6970, Epsilon: 0.014780882941434608





In [15]:

trainer.policy_dqn

QNetwork(
  (cnn): Sequential(
    (0): Conv2d(5, 5, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): Conv2d(5, 5, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU()
  )
  (network): Sequential(
    (0): Linear(in_features=405, out_features=120, bias=True)
    (1): ReLU()
    (2): Linear(in_features=120, out_features=84, bias=True)
    (3): ReLU()
    (4): Linear(in_features=84, out_features=21, bias=True)
  )
)

In [16]:
torch.save(trainer.policy_dqn.state_dict(), 'blue.pt')