In [1]:
# !pip install -q git+https://github.com/Farama-Foundation/MAgent2

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for magent2 (pyproject.toml) ... [?25l[?25hdone


In [2]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from collections import deque
import random
import torch.nn as nn
from magent2.environments import battle_v4
from torch.utils.data import Dataset, DataLoader
from time import time 

In [3]:

class PretrainedQNetwork(nn.Module):
    def __init__(self, observation_shape, action_shape):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(observation_shape[-1], observation_shape[-1], 3),
            nn.ReLU(),
            nn.Conv2d(observation_shape[-1], observation_shape[-1], 3),
            nn.ReLU(),
        )
        dummy_input = torch.randn(observation_shape).permute(2, 0, 1)
        dummy_output = self.cnn(dummy_input)
        flatten_dim = dummy_output.view(-1).shape[0]
        self.network = nn.Sequential(
            nn.Linear(flatten_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, action_shape),
        )
       # self.apply(kaiming_init)

    def forward(self, x):
        assert len(x.shape) >= 3
        x = self.cnn(x)
        if len(x.shape) == 3:
            batchsize = 1
        else:
            batchsize = x.shape[0]
        x = x.reshape(batchsize, -1)
        return self.network(x)

  and should_run_async(code)


In [4]:
class Final_QNets(nn.Module):
    def __init__(self, observation_shape, action_shape):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(observation_shape[-1], observation_shape[-1], 3),
            nn.ReLU(),
            nn.Conv2d(observation_shape[-1], observation_shape[-1], 3),
            nn.ReLU(),
        )
        dummy_input = torch.randn(observation_shape).permute(2, 0, 1)
        dummy_output = self.cnn(dummy_input)
        flatten_dim = dummy_output.view(-1).shape[0]
        self.network = nn.Sequential(
            nn.Linear(flatten_dim, 120),
            # nn.LayerNorm(120),
            nn.ReLU(),
            nn.Linear(120, 84),
            # nn.LayerNorm(84),
            nn.Tanh(),
        )
        self.last_layer = nn.Linear(84, action_shape)

    def forward(self, x):
        assert len(x.shape) >= 3, "only support magent input observation"
        x = self.cnn(x)
        if len(x.shape) == 3:
            batchsize = 1
        else:
            batchsize = x.shape[0]
        x = x.reshape(batchsize, -1)
        x = self.network(x)
        self.last_latent = x
        return self.last_layer(x)

In [5]:
class TestQAgent: 

    def __init__(self, n_observation, n_actions, model_path: str): 
        self.qnetwork = PretrainedQNetwork(n_observation, n_actions)
        self.n_action = n_actions
        self.qnetwork.load_state_dict(
            torch.load(model_path, weights_only=True, map_location="cpu")
        ) 

    def get_action(self, observation):

        if np.random.rand() < 0.05:
            return np.random.randint(self.n_action)
        else:
            observation = (
                        torch.Tensor(observation).float().permute([2, 0, 1]).unsqueeze(0)
                    )
            with torch.no_grad():
                q_values = self.qnetwork(observation)
                action = torch.argmax(q_values, dim=1).numpy()[0]

        return action


In [6]:
from tqdm import tqdm

def eval(model_path : str):
    max_cycles = 300
    env = battle_v4.env(map_size=45, max_cycles=max_cycles)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def random_policy(env, agent, obs):
        return env.action_space(agent).sample()
    
    
    q_network = PretrainedQNetwork(
        env.observation_space("red_0").shape, env.action_space("red_0").n
    )
    q_network.load_state_dict(
        torch.load("/kaggle/input/pretrained/pytorch/default/1/red.pt", weights_only=True, map_location="cpu")
    )
    q_network.to(device)

    final_q_network = Final_QNets(
        env.observation_space("red_0").shape, env.action_space("red_0").n
    )
    final_q_network.load_state_dict(
        torch.load("/kaggle/input/final_rl/pytorch/default/1/red_final.pt", weights_only=True, map_location="cpu")
    )
    final_q_network.to(device)

    def my_policy(env, agent, obs):
        my_agent = TestQAgent(env.observation_space("red_0").shape,  env.action_space("red_0").n, model_path= model_path)
        return my_agent.get_action(obs)


    def pretrain_policy(env, agent, obs):
        observation = (
            torch.Tensor(obs).float().permute([2, 0, 1]).unsqueeze(0).to(device)
        )
        with torch.no_grad():
            q_values = q_network(observation)
        return torch.argmax(q_values, dim=1).cpu().numpy()[0]

    def final_pretrain_policy(env, agent, obs):
        observation = (
            torch.Tensor(obs).float().permute([2, 0, 1]).unsqueeze(0).to(device)
        )
        with torch.no_grad():
            q_values = final_q_network(observation)
        return torch.argmax(q_values, dim=1).cpu().numpy()[0]

    def run_eval(env, red_policy, blue_policy, n_episode: int = 100):
        red_win, blue_win = [], []
        red_tot_rw, blue_tot_rw = [], []
        n_agent_each_team = len(env.env.action_spaces) // 2
        blue_agents = []
        red_agents = []

        for _ in tqdm(range(n_episode)):
            env.reset()
            n_kill = {"red": 0, "blue": 0}
            red_reward, blue_reward = 0, 0

            for agent in env.agent_iter():
                observation, reward, termination, truncation, info = env.last()
                agent_team = agent.split("_")[0]

                n_kill[agent_team] += (
                    reward > 4.5
                )  
                if agent_team == "red":
                    red_reward += reward
                else:
                    blue_reward += reward

                if termination or truncation:
                    action = None  
                else:
                    if agent_team == "red":
                        action = red_policy(env, agent, observation)
                    else:
                        action = blue_policy(env, agent, observation)

                env.step(action)

            who_wins = "red" if n_kill["red"] >= n_kill["blue"] + 5 else "draw"
            who_wins = "blue" if n_kill["red"] + 5 <= n_kill["blue"] else who_wins
            red_win.append(who_wins == "red")
            blue_win.append(who_wins == "blue")

            blue_agents.append(n_kill["blue"])
            red_agents.append(n_kill["red"])

            red_tot_rw.append(red_reward / n_agent_each_team)
            blue_tot_rw.append(blue_reward / n_agent_each_team)

        return {
            "winrate_red": np.mean(red_win),
            "winrate_blue": np.mean(blue_win),
            "average_rewards_red": np.mean(red_tot_rw),
            "average_rewards_blue": np.mean(blue_tot_rw),
            "red_kill": np.mean(red_agents) / n_agent_each_team,
            "blue_kill": np.mean(blue_agents) / n_agent_each_team,
        }

    print("=" * 20)
    print("Eval with random policy")
    print(
        run_eval(
            env=env, red_policy=random_policy, blue_policy=my_policy, n_episode=5
        )
    )
    print("=" * 20)

    print("Eval with trained policy")
    print(
        run_eval(
            env=env, red_policy=pretrain_policy, blue_policy=my_policy, n_episode=5
        )
    )
    print("=" * 20)

    print("Eval with final trained policy")
    print(
        run_eval(
            env=env,
            red_policy=final_pretrain_policy,
            blue_policy=my_policy,
            n_episode=5,
        )
    )
    print("=" * 20)

In [7]:
eval("/kaggle/input/my_model/pytorch/default/1/my_random5.pt")



Eval with random policy


100%|██████████| 5/5 [02:40<00:00, 32.12s/it]


{'winrate_red': 0.0, 'winrate_blue': 1.0, 'average_rewards_red': -1.155246953903065, 'average_rewards_blue': 4.63825922574572, 'red_kill': 0.0, 'blue_kill': 1.0}
Eval with trained policy


100%|██████████| 5/5 [01:46<00:00, 21.22s/it]


{'winrate_red': 0.0, 'winrate_blue': 1.0, 'average_rewards_red': -0.0058148162462461125, 'average_rewards_blue': 4.857901201187921, 'red_kill': 0.03209876543209877, 'blue_kill': 1.0}
Eval with final trained policy


100%|██████████| 5/5 [02:13<00:00, 26.73s/it]

{'winrate_red': 0.0, 'winrate_blue': 1.0, 'average_rewards_red': 2.206259242019811, 'average_rewards_blue': 4.754888859198049, 'red_kill': 0.48395061728395067, 'blue_kill': 0.9851851851851852}



