In [None]:
!pip install git+https://github.com/Farama-Foundation/MAgent2
!git clone https://github.com/giangbang/RL-final-project-AIT-3007.git

In [None]:
import torch
import torch.nn as nn
import numpy as np
import random
import os
from magent2.environments import battle_v4
import sys
sys.path.append('RL-final-project-AIT-3007/')
import final_torch_model
import torch_model

class MyQNetwork(nn.Module):
    """Convolutional Q-value network for channel-last MAgent observations.

    Three 3x3 convolutions (13 feature maps each) feed a three-layer MLP that
    emits one value per action.

    Args:
        observation_shape: Shape of a single observation; the last entry is
            used as the number of input channels.
        action_shape: Number of outputs of the final linear layer.
    """

    def __init__(self, observation_shape, action_shape):
        super().__init__()
        in_channels = observation_shape[-1]
        conv_stack = [
            nn.Conv2d(in_channels, 13, kernel_size=3, padding=1, stride=1),
            nn.ReLU(),
            nn.Conv2d(13, 13, kernel_size=3, padding=1, stride=1),
            nn.ReLU(),
            # No padding on the last convolution.
            nn.Conv2d(13, 13, kernel_size=3),
            nn.ReLU(),
        ]
        self.cnn = nn.Sequential(*conv_stack)

        # Push one random channel-first observation through the conv stack to
        # discover how many features the first linear layer receives.
        probe = torch.randn(observation_shape).permute(2, 0, 1)
        flatten_dim = self.cnn(probe).numel()

        head = [
            nn.Linear(flatten_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, action_shape),
        ]
        self.fc = nn.Sequential(*head)

    def forward(self, x):
        """Compute action values for a channel-first observation or batch.

        Args:
            x: Tensor with at least 3 dimensions. A 3-D tensor is treated as a
                single observation; otherwise the first dimension is the batch.

        Returns:
            Tensor of shape ``(batch, action_shape)``; ``batch`` is 1 for a
            3-D input.
        """
        assert x.dim() >= 3, "only support magent input observation"
        features = self.cnn(x)
        n_rows = 1 if x.dim() == 3 else x.shape[0]
        return self.fc(features.reshape(n_rows, -1))


In [3]:
# from torch_model import QNetwork
try:
    from tqdm import tqdm
except ImportError:
    tqdm = lambda x, *args, **kwargs: x  # Fallback: tqdm becomes a no-op
    
def _load_q_network(network, weights_path, device):
    """Load ``weights_path`` into ``network``, move it to ``device``, set eval mode."""
    state_dict = torch.load(weights_path, weights_only=True, map_location="cpu")
    network.load_state_dict(state_dict)
    network.to(device)
    # Inference only: switch to eval mode once instead of on every env step.
    network.eval()
    return network


def _make_greedy_policy(network, device):
    """Build a ``policy(env, agent, obs)`` that picks the arg-max Q-value action."""

    def policy(env, agent, obs):
        # Observations are channel-last; the networks take a channel-first
        # batch, so permute and add a batch dimension of one.
        observation = (
            torch.as_tensor(obs, dtype=torch.float32)
            .permute(2, 0, 1)
            .unsqueeze(0)
            .to(device)
        )
        with torch.inference_mode():
            q_values = network(observation)
        return torch.argmax(q_values, dim=1).cpu().numpy()[0]

    return policy


def _random_policy(env, agent, obs):
    """Sample an action from ``agent``'s action space, ignoring the observation."""
    return env.action_space(agent).sample()


def _run_eval(env, red_policy, blue_policy, n_episode: int = 100):
    """Play ``n_episode`` episodes of red vs. blue and aggregate the outcome.

    Args:
        env: AEC-style environment whose agent names start with ``red_`` or
            ``blue_``.
        red_policy: Callable ``(env, agent, obs) -> action`` for red agents.
        blue_policy: Callable ``(env, agent, obs) -> action`` for blue agents.
        n_episode: Number of episodes to play.

    Returns:
        Dict with ``winrate_red``, ``winrate_blue``, ``average_rewards_red``
        and ``average_rewards_blue``, each averaged over the episodes.
    """
    # A reward above this threshold is counted as a kill.
    # This assumes default reward setups.
    kill_reward_threshold = 4.5
    # A team wins an episode only when it leads by at least this many kills.
    win_margin = 5

    red_win, blue_win = [], []
    red_tot_rw, blue_tot_rw = [], []
    # Half the number of action spaces; assumes two equal-sized teams.
    n_agent_each_team = len(env.env.action_spaces) // 2
    policies = {"red": red_policy, "blue": blue_policy}

    for _ in tqdm(range(n_episode)):
        env.reset()
        n_kill = {"red": 0, "blue": 0}
        team_reward = {"red": 0, "blue": 0}

        for agent in env.agent_iter():
            observation, reward, termination, truncation, info = env.last()
            agent_team = agent.split("_")[0]

            n_kill[agent_team] += int(reward > kill_reward_threshold)
            # Any agent that is not red is accounted to blue.
            reward_key = "red" if agent_team == "red" else "blue"
            team_reward[reward_key] += reward

            if termination or truncation:
                action = None  # this agent has died
            else:
                action = policies[reward_key](env, agent, observation)

            env.step(action)

        if n_kill["red"] + win_margin <= n_kill["blue"]:
            who_wins = "blue"
        elif n_kill["red"] >= n_kill["blue"] + win_margin:
            who_wins = "red"
        else:
            who_wins = "draw"
        red_win.append(who_wins == "red")
        blue_win.append(who_wins == "blue")

        red_tot_rw.append(team_reward["red"] / n_agent_each_team)
        blue_tot_rw.append(team_reward["blue"] / n_agent_each_team)

    return {
        "winrate_red": np.mean(red_win),
        "winrate_blue": np.mean(blue_win),
        "average_rewards_red": np.mean(red_tot_rw),
        "average_rewards_blue": np.mean(blue_tot_rw),
    }


def eval(n_episode: int = 30, max_cycles: int = 300, map_size: int = 45):
    """Evaluate the blue ``MyQNetwork`` agent against three red opponents.

    Blue loads ``final_blue.pt``. It plays, in order, against the
    ``red_final.pt`` network, the ``red.pt`` network and a random policy, and
    the metrics of each match-up are printed.

    NOTE: this name shadows the built-in ``eval``; it is kept so existing
    callers continue to work.

    Args:
        n_episode: Episodes played against each opponent.
        max_cycles: ``max_cycles`` passed to ``battle_v4.env``.
        map_size: ``map_size`` passed to ``battle_v4.env``.

    Returns:
        Dict mapping the opponent label (``"red final policy"``,
        ``"red policy"``, ``"random policy"``) to its metrics dict.
    """
    env = battle_v4.env(map_size=map_size, max_cycles=max_cycles)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    try:
        # All networks are built from the spaces of agent "red_0".
        obs_shape = env.observation_space("red_0").shape
        n_actions = env.action_space("red_0").n

        ### Load my network
        my_network = _load_q_network(
            MyQNetwork(observation_shape=obs_shape, action_shape=n_actions),
            "final_blue.pt",
            device,
        )

        ### Load red_final model
        red_final_network = _load_q_network(
            final_torch_model.QNetwork(obs_shape, n_actions),
            "RL-final-project-AIT-3007/red_final.pt",
            device,
        )

        ### Load red model
        red_network = _load_q_network(
            torch_model.QNetwork(obs_shape, n_actions),
            "RL-final-project-AIT-3007/red.pt",
            device,
        )

        my_policy = _make_greedy_policy(my_network, device)
        opponents = [
            ("red final policy", _make_greedy_policy(red_final_network, device)),
            ("red policy", _make_greedy_policy(red_network, device)),
            ("random policy", _random_policy),
        ]

        results = {}
        for label, opponent_policy in opponents:
            print("=" * 20)
            print(f"Eval with {label}")
            results[label] = _run_eval(
                env=env,
                red_policy=opponent_policy,
                blue_policy=my_policy,
                n_episode=n_episode,
            )
            print(results[label])

        print("=" * 20)
        return results
    finally:
        # Release the environment even if an episode raised.
        env.close()

# Run the evaluation when this cell/module is executed as the main program.
# `eval` here is the evaluation function defined in this file, which shadows
# the built-in of the same name.
if __name__ == "__main__":
    eval()



Eval with red final policy


100%|██████████| 30/30 [01:34<00:00,  3.15s/it]


{'winrate_red': 0.0, 'winrate_blue': 1.0, 'average_rewards_red': 1.5334773535778892, 'average_rewards_blue': 4.866333302184402}
Eval with red policy


100%|██████████| 30/30 [01:06<00:00,  2.20s/it]


{'winrate_red': 0.0, 'winrate_blue': 1.0, 'average_rewards_red': 0.2403106967106063, 'average_rewards_blue': 4.983382684575707}
Eval with random policy


100%|██████████| 30/30 [00:59<00:00,  1.98s/it]

{'winrate_red': 0.0, 'winrate_blue': 1.0, 'average_rewards_red': -0.9832119690580102, 'average_rewards_blue': 4.9292242474007}



