In [None]:
!pip install git+https://github.com/Farama-Foundation/MAgent2

## Import

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import numpy as np
import os
from tqdm import tqdm

from magent2.environments import battle_v4
import cv2
from collections import deque
import time
import random
import warnings
warnings.filterwarnings('ignore')

## DQN

In [None]:
class QNetwork(nn.Module):
    """Small convolutional Q-value network.

    Two 3x3 convolutions (channel count preserved) feed a three-layer MLP
    that outputs one Q-value per action.

    Args:
        observation_shape: Channel-last observation shape; its last entry is
            used as the number of convolution channels.
        action_shape: Number of discrete actions (size of the output layer).
        device: Stored on the instance only; the module is not moved here.
    """

    def __init__(self, observation_shape, action_shape, device='cpu'):
        super().__init__()
        self.observation_shape = observation_shape
        self.action_shape = action_shape
        self.device = device

        channels = observation_shape[-1]
        self.cnn = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3),
            nn.ReLU(),
        )

        # Push one channel-first probe through the conv stack to discover how
        # many features the first linear layer has to accept.
        probe = torch.randn(observation_shape).permute(2, 0, 1)
        flatten_dim = self.cnn(probe).numel()

        self.network = nn.Sequential(
            nn.Linear(flatten_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, action_shape),
        )

    def forward(self, x):
        """Return Q-values of shape (batch, action_shape).

        `x` must be channel-first with at least three dimensions; an
        unbatched 3-D input is treated as a batch of one.
        """
        assert x.dim() >= 3, "only support magent input observation"
        features = self.cnn(x)
        batchsize = 1 if features.dim() == 3 else features.shape[0]
        return self.network(features.reshape(batchsize, -1))

## Replay Buffer

## Trainer

In [None]:
class ReplayMemory(Dataset):
    """Replay buffer organised as one bounded deque per step index.

    Transitions pushed with the same `step_idx` share a deque capped at
    `maxlen` entries (oldest dropped first). As a `Dataset`, the buffer is
    indexed as if all deques were concatenated in step order.
    """

    def __init__(self, maxlen,device = "cpu"):
        super().__init__()
        self.maxlen = maxlen
        self.device = device
        # Bucket 0 exists up front; later buckets are created by push().
        self.step_memory = [deque([], maxlen=self.maxlen)]

    def push(self, step_idx, state, action, reward, next_state, done):
        """Store one transition in the bucket for `step_idx`.

        A new bucket is opened only when `step_idx` is exactly one past the
        last existing bucket.
        """
        if step_idx == len(self.step_memory):
            self.step_memory.append(deque([], maxlen=self.maxlen))
        transition = (state, action, reward, next_state, done)
        self.step_memory[step_idx].append(transition)

    def __len__(self):
        """Total number of stored transitions across all buckets."""
        return sum(len(bucket) for bucket in self.step_memory)

    def __getitem__(self, idx):
        """Return the idx-th transition as tensors on `self.device`.

        State and next state are permuted from channel-last to channel-first.
        """
        buckets = self.step_memory
        bucket_no, offset = 0, idx
        # Walk the buckets until the offset lands inside one of them.
        while offset >= len(buckets[bucket_no]):
            offset -= len(buckets[bucket_no])
            bucket_no += 1

        state, action, reward, next_state, done = buckets[bucket_no][offset]

        state_t = torch.Tensor(state).float().permute([2, 0, 1])
        action_t = torch.tensor(action)
        reward_t = torch.tensor(reward, dtype=torch.float)
        next_state_t = torch.tensor(next_state).float().permute([2,0,1])
        done_t = torch.tensor(done, dtype=torch.float32)

        return (
            state_t.to(self.device),
            action_t.to(self.device),
            reward_t.to(self.device),
            next_state_t.to(self.device),
            done_t.to(self.device),
        )

In [None]:
class Trainer:
    """Epsilon-greedy DQN trainer with a softly updated target network.

    The same policy network acts for every agent yielded by the environment;
    all resulting transitions go into a shared replay memory.

    Args:
        policy_dqn: Online Q-network that is optimised.
        target_dqn: Target Q-network used for bootstrapped targets. It must
            share `policy_dqn`'s state-dict keys.
        n_action: Number of discrete actions (upper bound for random actions).
        loss_fn: Callable `(predicted_q, target_q) -> loss`.
        optimizer: Optimizer over `policy_dqn`'s parameters.
        scheduler: LR scheduler; stepped once per `optimize_model` call.
        epsilon_start: Initial exploration probability.
        epsilon_end: Lower bound for the exploration probability.
        epsilon_decay: Multiplicative decay applied after every episode.
        device: Device both networks and all training batches are moved to.
    """

    def __init__(
        self,
        policy_dqn, target_dqn,
        n_action,
        loss_fn, optimizer, scheduler,
        epsilon_start, epsilon_end, epsilon_decay,
        device='cpu'
    ):

        self.policy_dqn = policy_dqn.to(device)
        self.target_dqn = target_dqn.to(device)
        self.target_dqn.eval()

        self.n_action = n_action

        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.scheduler = scheduler

        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.epsilon = self.epsilon_start

        self.device = device

        self.policy_dqn.apply(self.weights_init)
        # weights_init just replaced the policy's conv weights, so any earlier
        # policy->target copy made by the caller is stale. Re-sync so both
        # networks start from the same parameters.
        self.target_dqn.load_state_dict(self.policy_dqn.state_dict())

    def weights_init(self, m):
        """Xavier-initialise (transposed) conv weights; set their biases to 0.01."""
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.xavier_uniform_(m.weight)
            if torch.is_tensor(m.bias):
                m.bias.data.fill_(0.01)

    def policy(self, observation):
        """Pick an action for one channel-last observation, epsilon-greedily.

        With probability `self.epsilon` a uniformly random action index is
        returned; otherwise the argmax of the policy network's Q-values.
        """
        if np.random.rand() < self.epsilon:
            return np.random.randint(low=0, high=self.n_action)

        with torch.no_grad():
            q_values = self.policy_dqn(
                torch.Tensor(observation).float().permute([2, 0, 1]).unsqueeze(0).to(self.device)
            )
        return torch.argmax(q_values, dim=1).cpu().numpy()[0]

    def optimize_model(self, replay_memory, batch_size, gamma):
        """Run one shuffled pass over `replay_memory`, one update per mini-batch.

        Does nothing (returns None) while the memory holds fewer than
        `batch_size` transitions. The LR scheduler is advanced once at the end
        of the pass, i.e. once per call.
        """
        if len(replay_memory) < batch_size:
            return
        train_loader = DataLoader(replay_memory, batch_size=batch_size, shuffle=True)
        self.policy_dqn.train()

        for observations, actions, rewards, next_observations, dones in train_loader:

            self.policy_dqn.zero_grad()

            # Use the trainer's own device rather than a notebook-level global.
            observations = observations.to(self.device)
            actions = actions.unsqueeze(1).to(self.device)
            rewards = rewards.unsqueeze(1).to(self.device)
            next_observations = next_observations.to(self.device)
            dones = dones.unsqueeze(1).to(self.device)

            # Q(s, a) for the actions that were actually taken.
            current_q_values = self.policy_dqn(observations).gather(1, actions)

            # Bootstrapped target; terminal transitions contribute reward only.
            with torch.no_grad():
                next_q_max = self.target_dqn(next_observations).max(1, keepdim=True)[0]
                target_q_values = rewards + gamma * (1 - dones) * next_q_max

            loss = self.loss_fn(current_q_values, target_q_values)

            loss.backward()
            self.optimizer.step()

        # One scheduler step per optimisation pass. `train` calls this method
        # once per episode, so the schedule advances per episode instead of
        # once per mini-batch.
        self.scheduler.step()

    @staticmethod
    def _signed_reward(target_agent, agent, reward):
        """Episode-score contribution: +reward for the target team, -|reward| otherwise."""
        return reward if target_agent in agent else -abs(reward)

    def _soft_update_target(self, TAU):
        """Blend policy weights into the target: target = TAU*policy + (1-TAU)*target."""
        target_dqn_state_dict = self.target_dqn.state_dict()
        policy_dqn_state_dict = self.policy_dqn.state_dict()
        for key in policy_dqn_state_dict:
            target_dqn_state_dict[key] = (
                policy_dqn_state_dict[key] * TAU + target_dqn_state_dict[key] * (1 - TAU)
            )
        self.target_dqn.load_state_dict(target_dqn_state_dict)

    def train(self,
              env, episodes,
              target_agent, batch_size, gamma, replay_memory,
              update_tg_freq, TAU
             ):
        """Play `episodes` episodes, filling `replay_memory` and training after each.

        Args:
            env: Environment exposing `reset()`, `agent_iter()`, `last()`,
                `step(action)` and `num_agents`.
            episodes: Number of episodes to play.
            target_agent: Substring identifying the team whose rewards count
                positively in the logged episode reward.
            batch_size: Mini-batch size for `optimize_model`.
            gamma: Discount factor.
            replay_memory: Buffer with `push(step_idx, s, a, r, s2, done)` that
                is also usable as a `Dataset`.
            update_tg_freq: Soft-update the target network every this many episodes.
            TAU: Interpolation factor of the soft update.

        Returns:
            `(train_rewards, train_durations)`: per-episode signed reward totals
            and per-episode counts of agent turns.
        """
        train_rewards = []
        train_durations = []

        for episode in tqdm(range(episodes)):
            ep_reward = 0
            ep_steps = 0

            # Per-agent bookkeeping: last observation, last action, and how
            # many transitions that agent has stored so far.
            observations = {}
            actions = {}
            step_idx = {}

            env.reset()

            # First round: every agent takes its initial action. There is no
            # previous (state, action) yet, so nothing is stored.
            for idx, agent in enumerate(env.agent_iter()):
                ep_steps += 1
                observation, reward, termination, truncation, info = env.last()

                ep_reward += self._signed_reward(target_agent, agent, reward)

                step_idx[agent] = 0
                action = self.policy(observation)

                observations[agent] = observation
                actions[agent] = action
                env.step(action)

                if (idx+1) % env.num_agents == 0:
                    break

            # Remaining turns: what an agent observes now completes the
            # transition that started with its previous observation/action.
            for agent in env.agent_iter():
                ep_steps += 1

                next_observation, reward, termination, truncation, info = env.last()

                ep_reward += self._signed_reward(target_agent, agent, reward)

                # A finished agent must step with None.
                if termination or truncation:
                    action = None
                else:
                    action = self.policy(next_observation)

                replay_memory.push(
                    step_idx[agent],
                    observations[agent],
                    actions[agent],
                    reward,
                    next_observation,
                    termination or truncation
                )

                step_idx[agent] += 1
                observations[agent] = next_observation
                actions[agent] = action
                env.step(action)

            # Train the model on the current replay memory.
            self.optimize_model(replay_memory, batch_size, gamma)

            # Periodically pull the target network towards the policy network.
            if episode % update_tg_freq == 0:
                self._soft_update_target(TAU)

            self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)

            print(f"\nEpisode {episode + 1}, Episode Reward: {ep_reward}, Steps: {ep_steps}, Epsilon: {self.epsilon}")

            train_rewards.append(ep_reward)
            train_durations.append(ep_steps)

        return train_rewards, train_durations

## Config

In [None]:

# --- Runtime -----------------------------------------------------------------
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Environment -------------------------------------------------------------
env = battle_v4.env(
    map_size=45,
    render_mode="rgb_array",
    step_reward=-0.05,
    max_cycles=300,
)

# --- Training schedule -------------------------------------------------------
episodes = 50
target_agent = 'blue'   # team whose rewards count positively in the episode score
gamma = 0.9             # discount factor
update_tg_freq = 1      # soft-update the target network every N episodes
TAU = 0.3               # soft-update interpolation factor

# --- Replay memory / batching ------------------------------------------------
# NOTE(review): 81 and 162 presumably are the per-team and total agent counts
# on this map size — confirm before changing map_size.
batch_size = 81 * episodes
maxlen = 162 * episodes

# --- Optimisation ------------------------------------------------------------
learning_rate = 1e-3
theta = 0.000001        # minimum learning rate handed to the LR scheduler

# --- Exploration -------------------------------------------------------------
epsilon_start = 1.0
epsilon_end = 0.01
epsilon_decay = 0.9


In [None]:
# Seed every RNG this notebook imports so runs are repeatable: Python's
# `random`, NumPy (used for epsilon-greedy action sampling) and PyTorch.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

## Training Loop

In [None]:
# Replay buffer shared by all agents; each per-step deque holds at most
# `maxlen` transitions and sampled tensors are placed on `device`.
replay_memory = ReplayMemory(maxlen, device)

In [None]:
# Build the online and target Q-networks from blue_0's observation/action
# spaces, then copy the online weights into the target network.
policy_dqn = QNetwork(env.observation_space("blue_0").shape, env.action_space("blue_0").n).to(device)
target_dqn = QNetwork(env.observation_space("blue_0").shape, env.action_space("blue_0").n).to(device)
target_dqn.load_state_dict(policy_dqn.state_dict())

In [None]:
# MSE loss between predicted and target Q-values; AdamW over the policy
# network's parameters only; cosine LR annealing from `learning_rate` down to
# `theta` with a period of `episodes` scheduler steps.
loss_function = nn.MSELoss()
optimizer = torch.optim.AdamW(policy_dqn.parameters(), lr=learning_rate)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=episodes, eta_min=theta)

In [None]:

# NOTE(review): n_action is read from red_0's action space while the networks
# were sized from blue_0's — presumably both teams share one action space;
# confirm.
trainer = Trainer(
    policy_dqn, target_dqn,
    env.action_space("red_0").n,
    loss_function, optimizer, lr_scheduler,
    epsilon_start, epsilon_end, epsilon_decay,
    device=device
)

In [None]:

# Run the full training loop; returns the per-episode reward totals and the
# per-episode number of agent turns.
train_rewards, train_durations = trainer.train(
    env, episodes,
    target_agent, batch_size, gamma, replay_memory,
    update_tg_freq, TAU
)

In [None]:
# Display the trained policy network's layer summary.
trainer.policy_dqn

In [None]:
# Save only the policy network's weights (state_dict) as blue.pt in the
# current working directory.
torch.save(trainer.policy_dqn.state_dict(), 'blue.pt')

## Eval Model

In [None]:
%%capture
!git clone https://github.com/giangbang/RL-final-project-AIT-3007
%cd RL-final-project-AIT-3007

In [None]:
!python main.py

In [None]:
from magent2.environments import battle_v4
from torch_model import QNetwork
from final_torch_model import QNetwork as FinalQNetwork
import torch
import numpy as np

try:
    from tqdm import tqdm
except ImportError:
    tqdm = lambda x, *args, **kwargs: x  # Fallback: tqdm becomes a no-op


def eval(blue_weights="/kaggle/working/blue.pt", n_episode=100, max_cycles=300):
    """Evaluate the trained blue policy against three red opponents.

    Blue is played by a `QNetwork` loaded from `blue_weights`. Red is played
    in turn by a random policy, by `QNetwork` with weights `red.pt`, and by
    `FinalQNetwork` with weights `red_final.pt`. One result dict is printed
    per opponent.

    Note: this function shadows Python's built-in `eval` in this module.

    Args:
        blue_weights: Path to the blue policy's state_dict. The default is the
            absolute path this notebook was written against.
        n_episode: Number of episodes played per opponent.
        max_cycles: `max_cycles` passed to the battle environment.
    """
    env = battle_v4.env(map_size=45, max_cycles=max_cycles)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Every network is sized from red_0's observation/action spaces.
    obs_shape = env.observation_space("red_0").shape
    n_actions = env.action_space("red_0").n

    def load_network(network_cls, weights_path):
        """Instantiate `network_cls`, load its weights on CPU, move to `device`."""
        network = network_cls(obs_shape, n_actions)
        network.load_state_dict(
            torch.load(weights_path, weights_only=True, map_location="cpu")
        )
        network.to(device)
        return network

    def greedy_policy(network):
        """Wrap `network` as a policy that picks the highest Q-value action."""

        def act(env, agent, obs):
            # Channel-last observation -> channel-first batch of one.
            observation = (
                torch.Tensor(obs).float().permute([2, 0, 1]).unsqueeze(0).to(device)
            )
            with torch.no_grad():
                q_values = network(observation)
            return torch.argmax(q_values, dim=1).cpu().numpy()[0]

        return act

    def random_policy(env, agent, obs):
        return env.action_space(agent).sample()

    pretrain_policy = greedy_policy(load_network(QNetwork, "red.pt"))
    final_pretrain_policy = greedy_policy(load_network(FinalQNetwork, "red_final.pt"))
    blue_policy = greedy_policy(load_network(QNetwork, blue_weights))

    def run_eval(env, red_policy, blue_policy, n_episode: int = 100):
        """Play `n_episode` episodes and return win rates, rewards and step count."""
        red_win, blue_win = [], []
        red_tot_rw, blue_tot_rw = [], []
        n_agent_each_team = len(env.env.action_spaces) // 2

        step_game = []

        for _ in tqdm(range(n_episode)):
            env.reset()
            n_kill = {"red": 0, "blue": 0}
            red_reward, blue_reward = 0, 0
            total_step = 0

            for agent in env.agent_iter():
                observation, reward, termination, truncation, info = env.last()
                agent_team = agent.split("_")[0]

                # A reward above 4.5 is counted as a kill; this assumes the
                # default reward setup.
                n_kill[agent_team] += reward > 4.5
                if agent_team == "red":
                    red_reward += reward
                else:
                    blue_reward += reward

                if termination or truncation:
                    action = None  # this agent has died
                elif agent_team == "red":
                    action = red_policy(env, agent, observation)
                else:
                    action = blue_policy(env, agent, observation)

                env.step(action)

                # Only turns taken by still-active agents count as steps.
                if not termination and not truncation:
                    total_step += 1

            step_game.append(total_step)

            # A team wins only with a lead of at least 5 kills; otherwise draw.
            if n_kill["red"] + 5 <= n_kill["blue"]:
                who_wins = "blue"
            elif n_kill["red"] >= n_kill["blue"] + 5:
                who_wins = "red"
            else:
                who_wins = "draw"
            red_win.append(who_wins == "red")
            blue_win.append(who_wins == "blue")

            red_tot_rw.append(red_reward / n_agent_each_team)
            blue_tot_rw.append(blue_reward / n_agent_each_team)

        return {
            "winrate_red": np.mean(red_win),
            "winrate_blue": np.mean(blue_win),
            "average_rewards_red": np.mean(red_tot_rw),
            "average_rewards_blue": np.mean(blue_tot_rw),
            "average_step": int(np.mean(step_game)),
        }

    print("=" * 20)
    print("Eval with random policy")
    print(
        run_eval(
            env=env, red_policy=random_policy, blue_policy=blue_policy, n_episode=n_episode
        )
    )
    print("=" * 20)

    print("Eval with trained policy")
    print(
        run_eval(
            env=env, red_policy=pretrain_policy, blue_policy=blue_policy, n_episode=n_episode
        )
    )
    print("=" * 20)

    print("Eval with final trained policy")
    print(
        run_eval(
            env=env,
            red_policy=final_pretrain_policy,
            blue_policy=blue_policy,
            n_episode=n_episode,
        )
    )
    print("=" * 20)


# Run the evaluation when this code executes as the main module.
if __name__ == "__main__":
    eval()