In [None]:
import torch 
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
import matplotlib.pyplot as plt
import magent2
import os
import gymnasium as gym

In [2]:
from magent2.environments import battle_v4

In [3]:
from tqdm import tqdm

In [4]:
# Expose only GPU index 0 to this process via the CUDA_VISIBLE_DEVICES variable.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})


# Model


In [5]:
class BlueAgent(nn.Module):
    """Q-network for the blue team: two 3x3 convolutions followed by an MLP head.

    Args:
        observation_shape: Shape of a single observation. Its last entry is used as
            the channel count and the sample is permuted with (2, 0, 1) before the
            convolutions, i.e. it is treated as channels-last.
        action_shape: Number of outputs produced by the final linear layer.
    """

    def __init__(self, observation_shape, action_shape):
        super().__init__()
        channels = observation_shape[-1]
        self.cnn = nn.Sequential(
            nn.Conv2d(channels, channels, 3),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3),
            nn.ReLU(),
        )
        # Run one random, channels-first sample through the conv stack to find out
        # how many features the first linear layer has to accept.
        probe = torch.randn(observation_shape).permute(2, 0, 1)
        flatten_dim = self.cnn(probe).numel()
        self.network = nn.Sequential(
            nn.Linear(flatten_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, action_shape),
        )

    def forward(self, x):
        """Return one row of action values per observation.

        `x` is channels-first, either a single observation (3 dims) or a batch
        (leading batch dim); the result always has a leading batch dimension.
        """
        assert len(x.shape) >= 3, "only support magent input observation"
        features = self.cnn(x)
        # An unbatched input produces a 3-d feature map; treat it as a batch of one.
        rows = 1 if len(features.shape) == 3 else features.shape[0]
        return self.network(features.reshape(rows, -1))


In [6]:
class RedAgent(nn.Module):
    """Q-network for the red team: a small conv feature extractor plus an MLP head.

    Args:
        observation_shape: Shape of a single observation. The last entry is taken as
            the number of channels; the sample is permuted with (2, 0, 1) before it
            reaches the convolutions (channels-last input layout).
        action_shape: Width of the output layer (one value per action).
    """

    def __init__(self, observation_shape, action_shape):
        super().__init__()
        n_channels = observation_shape[-1]
        conv_layers = [
            nn.Conv2d(n_channels, n_channels, 3),
            nn.ReLU(),
            nn.Conv2d(n_channels, n_channels, 3),
            nn.ReLU(),
        ]
        self.cnn = nn.Sequential(*conv_layers)

        # Size the first linear layer from the conv output of one random sample.
        sample = torch.randn(observation_shape).permute(2, 0, 1)
        flatten_dim = self.cnn(sample).reshape(-1).shape[0]

        head_layers = [
            nn.Linear(flatten_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, action_shape),
        ]
        self.network = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map channels-first observation(s) to a (batch, action_shape) tensor."""
        assert len(x.shape) >= 3, "only support magent input observation"
        x = self.cnn(x)
        # A 3-d conv output means the caller passed a single, unbatched observation.
        if x.dim() == 3:
            x = x.reshape(1, -1)
        else:
            x = x.reshape(x.shape[0], -1)
        return self.network(x)


In [11]:
class FinalRedAgent(nn.Module):
    """Red-team Q-network variant with a Tanh-bounded latent and a separate output layer.

    The layout matches the other agents up to the 84-unit hidden layer, which uses
    Tanh here; the final projection lives in `last_layer` so that the latent feeding
    it can be kept on the module as `last_latent` after every forward pass.

    Args:
        observation_shape: Shape of a single observation. The last entry is the
            channel count; the sample is permuted with (2, 0, 1) before the convs.
        action_shape: Number of outputs of `last_layer`.
    """

    def __init__(self, observation_shape, action_shape):
        super().__init__()
        channels = observation_shape[-1]
        self.cnn = nn.Sequential(
            nn.Conv2d(channels, channels, 3),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3),
            nn.ReLU(),
        )
        # Infer the flattened conv output size from one random channels-first sample.
        probe = torch.randn(observation_shape).permute(2, 0, 1)
        flatten_dim = self.cnn(probe).numel()
        # No normalisation layers are used between the linear layers.
        self.network = nn.Sequential(
            nn.Linear(flatten_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.Tanh(),
        )
        self.last_layer = nn.Linear(84, action_shape)

    def forward(self, x):
        """Return action values and remember the latent that produced them.

        Side effect: `self.last_latent` is overwritten with the (batch, 84) output of
        `self.network` for this call.
        """
        assert len(x.shape) >= 3, "only support magent input observation"
        conv_out = self.cnn(x)
        # Unbatched observations come out of the convs as 3-d; use a batch of one.
        rows = 1 if len(conv_out.shape) == 3 else conv_out.shape[0]
        latent = self.network(conv_out.reshape(rows, -1))
        self.last_latent = latent
        return self.last_layer(latent)

# Eval function 

In [7]:
# Evaluation setup: compute device, blue checkpoint path, episode length cap, arena.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

max_cycles = 300
final_weight_path = "q_network_20241216184908-100point.pt"
env = battle_v4.env(map_size=45, max_cycles=max_cycles)


    

In [8]:
def random_policy(env, agent, obs):
    """Return a sample drawn from `agent`'s action space; `obs` is ignored."""
    space = env.action_space(agent)
    return space.sample()

# Blue Q-network: built from the "red_0" observation/action spaces, restored from the
# checkpoint at `final_weight_path` (loaded onto CPU first), then moved to `device`.
q_network = BlueAgent(
    env.observation_space("red_0").shape,
    env.action_space("red_0").n,
)
q_network.load_state_dict(
    torch.load(final_weight_path, map_location="cpu", weights_only=True)
)
q_network = q_network.to(device)

def blue_pretrain_policy(env, agent, obs):
    """Greedy action of the pretrained blue network for one observation.

    `env` and `agent` are accepted for signature compatibility with the other
    policies and are not used. `obs` is converted to a float tensor, moved from
    channels-last to channels-first, and given a batch dimension of one.
    """
    channels_first = torch.Tensor(obs).float().permute(2, 0, 1)
    batch = channels_first[None].to(device)
    with torch.no_grad():
        scores = q_network(batch)
    best = scores.argmax(dim=1)
    return best.cpu().numpy()[0]




In [9]:
# Red Q-network: same spaces as the blue one, weights restored from "red.pt" on CPU
# and then moved to `device`.
red_network = RedAgent(
    env.observation_space("red_0").shape,
    env.action_space("red_0").n,
)
red_network.load_state_dict(
    torch.load("red.pt", map_location="cpu", weights_only=True)
)
red_network = red_network.to(device)



def red_pretrain_policy(env, agent, obs):
    """Greedy action of `red_network` for one observation.

    Only `obs` is used; `env` and `agent` exist so the function matches the policy
    signature expected by the evaluation loop.
    """
    # Channels-last observation -> float, channels-first, batch of one, on `device`.
    net_input = torch.Tensor(obs).float()
    net_input = net_input.permute(2, 0, 1)[None].to(device)
    with torch.no_grad():
        action_values = red_network(net_input)
    greedy = torch.argmax(action_values, dim=1)
    return greedy.cpu().numpy()[0]



In [12]:
# Final red network: weights restored from "red_final.pt" on CPU, then moved to `device`.
red_final_network = FinalRedAgent(
    env.observation_space("red_0").shape,
    env.action_space("red_0").n,
)
red_final_network.load_state_dict(
    torch.load("red_final.pt", map_location="cpu", weights_only=True)
)
red_final_network = red_final_network.to(device)

def final_red_pretrain_policy(env, agent, obs):
    """Greedy action of `red_final_network` for one observation.

    `env` and `agent` are unused. Calling this also refreshes
    `red_final_network.last_latent`, since the network stores it on every forward.
    """
    as_float = torch.Tensor(obs).float()
    # Reorder to channels-first and prepend a batch dimension of one.
    model_input = as_float.permute(2, 0, 1).unsqueeze(dim=0).to(device)
    with torch.no_grad():
        values = red_final_network(model_input)
    chosen = values.argmax(dim=1).cpu().numpy()
    return chosen[0]


In [13]:
def run_eval(env, red_policy, blue_policy, n_episode: int = 100):
    """Play `n_episode` full games between two policies and summarise the results.

    Args:
        env: Environment exposing `reset`, `agent_iter`, `last`, `step` and
            `env.env.action_spaces`. The part of each agent name before the first
            underscore must be "red" or "blue".
        red_policy: Callable `(env, agent, observation) -> action` for red agents.
        blue_policy: Callable with the same signature for blue agents.
        n_episode: Number of episodes to play.

    Returns:
        Dict with `winrate_red`, `winrate_blue` (fraction of episodes won) and
        `average_rewards_red`, `average_rewards_blue` (mean over episodes of the
        team's summed reward divided by the per-team agent count). A team wins an
        episode when it has at least 5 more kills than the other; otherwise the
        episode counts as a win for neither.
    """
    policy_for = {"red": red_policy, "blue": blue_policy}
    # Half of all registered agents belong to each team.
    n_agent_each_team = len(env.env.action_spaces) // 2

    red_win, blue_win = [], []
    red_tot_rw, blue_tot_rw = [], []

    for _ in tqdm(range(n_episode)):
        env.reset()
        n_kill = {"red": 0, "blue": 0}
        team_reward = {"red": 0, "blue": 0}

        for agent in env.agent_iter():
            observation, reward, termination, truncation, info = env.last()
            agent_team = agent.split("_")[0]

            # A reward above 4.5 is counted as a kill; this assumes the default
            # reward setup.
            n_kill[agent_team] += reward > 4.5
            team_reward[agent_team] += reward

            # Finished (dead or truncated) agents must be stepped with None.
            finished = termination or truncation
            action = (
                None
                if finished
                else policy_for[agent_team](env, agent, observation)
            )
            env.step(action)

        kill_margin = n_kill["red"] - n_kill["blue"]
        red_win.append(kill_margin >= 5)
        blue_win.append(kill_margin <= -5)

        red_tot_rw.append(team_reward["red"] / n_agent_each_team)
        blue_tot_rw.append(team_reward["blue"] / n_agent_each_team)

    summary = {}
    summary["winrate_red"] = np.mean(red_win)
    summary["winrate_blue"] = np.mean(blue_win)
    summary["average_rewards_red"] = np.mean(red_tot_rw)
    summary["average_rewards_blue"] = np.mean(blue_tot_rw)
    return summary

# Eval

In [14]:
# Baseline: random red team against the pretrained blue network, 10 episodes.
print("=" * 20, "Eval with random policy", sep="\n")
print(
    run_eval(
        env=env,
        red_policy=random_policy,
        blue_policy=blue_pretrain_policy,
        n_episode=10,
    )
)



Eval with random policy


100%|██████████| 10/10 [02:00<00:00, 12.04s/it]

{'winrate_red': 0.0, 'winrate_blue': 1.0, 'average_rewards_red': -3.262290238263661, 'average_rewards_blue': 2.1981481072285938}





In [24]:
# Trained red network (`red_network`) against the pretrained blue network, 30 episodes.
print("=" * 20, "Eval with red trained policy", sep="\n")
print(
    run_eval(
        n_episode=30,
        env=env,
        red_policy=red_pretrain_policy,
        blue_policy=blue_pretrain_policy,
    )
)



Eval with red trained policy


100%|██████████| 30/30 [01:47<00:00,  3.57s/it]

{'winrate_red': 0.0, 'winrate_blue': 1.0, 'average_rewards_red': 0.8483765369760249, 'average_rewards_blue': 3.6821563055087267}





In [25]:
# Final red network (`red_final_network`) against the pretrained blue network, 30 episodes.
print("=" * 20, "Eval with final red trained policy", sep="\n")
print(
    run_eval(
        n_episode=30,
        env=env,
        red_policy=final_red_pretrain_policy,
        blue_policy=blue_pretrain_policy,
    )
)

Eval with final red trained policy


100%|██████████| 30/30 [04:15<00:00,  8.52s/it]

{'winrate_red': 0.0, 'winrate_blue': 1.0, 'average_rewards_red': 1.400993809580742, 'average_rewards_blue': 1.5769504811250654}





# Video 

In [21]:

# Roll out one rendered episode (final red network vs. pretrained blue network) and
# collect one frame per cycle in `frames`.

# Both teams act greedily; epsilon is only echoed in the summary line below.
epsilon = 0

# NOTE: this rebinds the notebook-level `env` to a rendering environment, so any cell
# executed after this one sees the rgb_array environment.
env = battle_v4.env(map_size=45, max_cycles=max_cycles, render_mode="rgb_array")
env.reset()

rewards = [0, 0]  # [red total, blue total]
cycle_count = 0
last_agent_team = None
frames = []

for agent_id in env.agent_iter():
    observation, reward, termination, truncation, info = env.last()
    agent_team = agent_id.split("_")[0]

    if termination or truncation:
        # This agent is finished and must be stepped with None.
        action = None
    elif agent_team == "blue":
        action = blue_pretrain_policy(env, agent_id, observation)
    else:
        # Every non-blue agent is driven by the final red network.
        action = final_red_pretrain_policy(env, agent_id, observation)

    env.step(action)

    if agent_team == "red":
        rewards[0] += reward
    elif agent_team == "blue":
        rewards[1] += reward

    # The turn order switching to the red team marks a new cycle: count it and
    # capture a frame.
    if agent_team == "red" and last_agent_team != "red":
        cycle_count += 1
        frames.append(env.render())
    last_agent_team = agent_team

print(f"step length: {cycle_count}, red reward: {rewards[0]}, blue reward: {rewards[1]}, epsilon: {epsilon}")

    


step length: 212, red reward: 159.41499827522784, blue reward: 178.2249902104959, epsilon: 0


In [22]:
import cv2
import os
def record_game_video(frames, vid_dir="videos", video_name="random2", fps=5):
    """Encode a sequence of RGB frames into a video file at `vid_dir/video_name`.

    Args:
        frames: Non-empty sequence of 3-d image arrays in RGB channel order. The
            height and width of the first frame define the video size.
        vid_dir: Output directory; created if it does not exist yet.
        video_name: File name of the video inside `vid_dir`.
        fps: Frames per second passed to the writer.

    Raises:
        ValueError: If `frames` is empty.
        OSError: If the video writer could not be opened for the target path.
    """
    # Fail with a clear message instead of an IndexError on frames[0].
    if len(frames) == 0:
        raise ValueError("frames is empty; nothing to record")

    # Create the target directory up front so the writer does not fail on a
    # missing path.
    if vid_dir:
        os.makedirs(vid_dir, exist_ok=True)
    video_path = os.path.join(vid_dir, video_name)

    height, width, _ = frames[0].shape
    out = cv2.VideoWriter(
        video_path,
        cv2.VideoWriter_fourcc(*"mp4v"),
        fps,
        (width, height),
    )
    try:
        if not out.isOpened():
            raise OSError(f"could not open video writer for {video_path!r}")
        for frame in frames:
            # Frames arrive as RGB; OpenCV expects BGR, so convert RGB -> BGR.
            out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # Always release the writer so the file handle is closed even on errors.
        out.release()




In [23]:
# Write the captured frames to video/test_16_12_final_red2.mp4.
record_game_video(
    frames,
    video_name="test_16_12_final_red2.mp4",
    vid_dir="video",
)