In [2]:
!pip install git+https://github.com/Farama-Foundation/MAgent2
!pip install torch
!pip install opencv-python

Collecting git+https://github.com/Farama-Foundation/MAgent2
  Cloning https://github.com/Farama-Foundation/MAgent2 to /tmp/pip-req-build-a_zxsm4l
  Running command git clone --filter=blob:none --quiet https://github.com/Farama-Foundation/MAgent2 /tmp/pip-req-build-a_zxsm4l
  Resolved https://github.com/Farama-Foundation/MAgent2 to commit b2ddd49445368cf85d4d4e1edcddae2e28aa1406
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Collecting pygame>=2.1.0 (from magent2==0.3.3)
  Downloading pygame-2.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading pygame-2.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (14.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.0/14.0 MB[0m [31m81.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hBuilding wheels for collected packages: magent2
  Building wheel f

In [3]:
from magent2.environments import battle_v4
import os
import numpy as np
import pandas as pd
import cv2
import torch
import torch.nn as nn
import random
from collections import deque
from IPython.display import HTML
from IPython.display import FileLink
import copy

## Config

In [24]:
# --- Environment ------------------------------------------------------------
# 45x45 battle map; "rgb_array" mode makes env.render() return image frames,
# which are later encoded into the result video.
env = battle_v4.env(map_size=45, render_mode="rgb_array")

# --- Pretrained Q-networks --------------------------------------------------
model_folder = '/kaggle/input/pretrained-qnet'  # Kaggle input directory with the checkpoints
red_model_name = 'red.pt'
blue_model_name = '45-1024.pth'
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# --- Result video -----------------------------------------------------------
vid_dir = "video"
os.makedirs(vid_dir, exist_ok=True)
fps = 35
# Dots are stripped from the checkpoint file names, e.g. "red.pt" -> "redpt".
video_name = 'red_{} vs blue_{}'.format(
    red_model_name.replace(".", ""),
    blue_model_name.replace(".", ""),
)

In [25]:
class QNetwork(nn.Module):
    """Convolutional Q-network mapping a MAgent observation to per-action Q-values.

    Two 3x3 convolutions (channel count preserved) feed a three-layer MLP with
    hidden sizes 120 and 84. The attribute names ``cnn`` / ``network`` and the
    layer order determine the ``state_dict`` keys, so they must not change or
    pretrained checkpoints will no longer load.

    Args:
        observation_shape: observation shape with channels last, i.e. (H, W, C);
            the last entry is used as the convolution channel count.
        action_shape: number of discrete actions (size of the output layer).
    """

    def __init__(self, observation_shape, action_shape):
        super().__init__()
        channels = observation_shape[-1]
        self.cnn = nn.Sequential(
            nn.Conv2d(channels, channels, 3),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3),
            nn.ReLU(),
        )
        # Push a random channels-first probe through the conv stack to find
        # how many features the first linear layer receives.
        probe = torch.randn(observation_shape).permute(2, 0, 1)
        flatten_dim = self.cnn(probe).numel()
        self.network = nn.Sequential(
            nn.Linear(flatten_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, action_shape),
        )

    def forward(self, x):
        """Return Q-values with shape (batch, action_shape).

        ``x`` is either one channels-first observation (C, H, W), which is
        treated as a batch of one, or a batch (N, C, H, W).
        """
        assert len(x.shape) >= 3, "only support magent input observation"
        features = self.cnn(x)
        batchsize = 1 if len(features.shape) == 3 else features.shape[0]
        return self.network(features.reshape(batchsize, -1))

In [26]:
def _load_qnetwork(agent_id, checkpoint_name):
    """Build a QNetwork sized for ``agent_id`` and load pretrained weights into it.

    Args:
        agent_id: name of an agent in ``env`` (e.g. "red_0"); its observation
            and action spaces define the network's input and output sizes.
        checkpoint_name: file name of the saved state_dict inside ``model_folder``.

    Returns:
        The network on ``device``, switched to eval mode for inference.

    Raises:
        RuntimeError: from ``load_state_dict`` if the checkpoint keys or shapes
            do not match the network.
    """
    model = QNetwork(env.observation_space(agent_id).shape, env.action_space(agent_id).n).to(device)
    state_dict = torch.load(
        os.path.join(model_folder, checkpoint_name), weights_only=True, map_location=device
    )
    model.load_state_dict(state_dict)
    return model.eval()


red_model = _load_qnetwork("red_0", red_model_name)
# Size the blue network from a blue agent's spaces rather than reusing red_0's.
blue_model = _load_qnetwork("blue_0", blue_model_name)

<All keys matched successfully>

In [27]:
def _select_action(agent, observation, red_model, blue_model, device):
    """Return the greedy (argmax-Q) action for ``agent`` using its team's network."""
    # Observations are channels-last; the networks expect (N, C, H, W).
    obs = torch.tensor(observation, dtype=torch.float32, device=device).permute(2, 0, 1).unsqueeze(0)
    model = blue_model if agent.split("_")[0] == "blue" else red_model
    with torch.no_grad():
        q_values = model(obs)
    return torch.argmax(q_values, dim=1).cpu().numpy()[0]


def _collect_battle_frames(env, red_model, blue_model, device):
    """Play one episode with greedy policies and return the rendered frames.

    A frame is captured on the first red-agent step and again on the first
    red-agent step after any non-red step, plus one final frame once the
    episode loop ends.
    """
    frames = []
    env.reset()
    add_frame = True
    for agent in env.agent_iter():
        observation, reward, termination, truncation, info = env.last()
        if termination or truncation:
            action = None  # finished agents are stepped with a None action
        else:
            action = _select_action(agent, observation, red_model, blue_model, device)
        env.step(action)

        if 'red' in agent:
            if add_frame:
                frames.append(env.render())
                add_frame = False
        else:
            add_frame = True
    frames.append(env.render())
    return frames


def _write_video(frames, path, fps):
    """Encode RGB ``frames`` into an mp4v video file at ``path``.

    Raises:
        OSError: if OpenCV cannot open a writer for ``path``.
    """
    height, width, _ = frames[0].shape
    writer = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
    if not writer.isOpened():
        raise OSError(f"could not open video writer for {path}")
    try:
        for frame in frames:
            # Frames are converted RGB -> BGR before writing, as OpenCV expects.
            writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        writer.release()


def make_battle_video(env, red_model, blue_model, vid_dir, fps=24, name='result', device='cpu'):
    """Record one greedy red-vs-blue episode to ``<vid_dir>/<name>.mp4``.

    Args:
        env: the battle environment (must be created with render_mode="rgb_array").
        red_model: Q-network used for agents whose name prefix is not "blue".
        blue_model: Q-network used for agents whose name prefix is "blue".
        vid_dir: directory in which the video file is written.
        fps: frame rate of the output video.
        name: output file name without the ".mp4" extension.
        device: torch device the observations are moved to before inference.
    """
    frames = _collect_battle_frames(env, red_model, blue_model, device)
    _write_video(frames, os.path.join(vid_dir, f"{name}.mp4"), fps)
    print("Done recording pretrained agents")

In [28]:
# Roll out one episode with both pretrained policies and save it under vid_dir.
make_battle_video(
    env,
    red_model,
    blue_model,
    vid_dir,
    fps=fps,
    name=video_name,
    device=device,
)

Done recording pretrained agents
