In [1]:
!pip install git+https://github.com/Farama-Foundation/MAgent2

Collecting git+https://github.com/Farama-Foundation/MAgent2
  Cloning https://github.com/Farama-Foundation/MAgent2 to /tmp/pip-req-build-w9twgot3
  Running command git clone --filter=blob:none --quiet https://github.com/Farama-Foundation/MAgent2 /tmp/pip-req-build-w9twgot3
  Resolved https://github.com/Farama-Foundation/MAgent2 to commit b2ddd49445368cf85d4d4e1edcddae2e28aa1406
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Collecting pygame>=2.1.0 (from magent2==0.3.3)
  Downloading pygame-2.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading pygame-2.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (14.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.0/14.0 MB[0m [31m90.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hBuilding wheels for collected packages: magent2
  Building wheel f

In [14]:
!git clone https://github.com/giangbang/RL-final-project-AIT-3007.git

Cloning into 'RL-final-project-AIT-3007'...
remote: Enumerating objects: 47, done.[K
remote: Counting objects: 100% (15/15), done.[K
remote: Compressing objects: 100% (9/9), done.[K
remote: Total 47 (delta 11), reused 6 (delta 6), pack-reused 32 (from 1)[K
Receiving objects: 100% (47/47), 13.67 MiB | 41.55 MiB/s, done.
Resolving deltas: 100% (22/22), done.


In [15]:
import sys
sys.path.append('/kaggle/working/RL-final-project-AIT-3007')

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
from collections import deque, Counter
import os
from magent2.environments import battle_v4
import time
# from torch_model import QNetwork

class MyQNetwork(nn.Module):
    def __init__(self, observation_shape, action_shape):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(observation_shape[-1], 13, kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(13, 13, kernel_size=3),
            nn.ReLU()
        )
        dummy_input = torch.randn(observation_shape).permute(2, 0, 1)
        dummy_output = self.cnn(dummy_input)
        flatten_dim = dummy_output.view(-1).shape[0]
        self.fc = nn.Sequential(
            nn.Linear(flatten_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, action_shape)
        )

    def forward(self, x):
        assert len(x.shape) >= 3, "only support magent input observation"
        x = self.cnn(x)
        if len(x.shape) == 3:
            batchsize = 1
        else:
            batchsize = x.shape[0]
        x = x.reshape(batchsize, -1)
        return self.fc(x)

class ReplayBuffer(Dataset):
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = deque(maxlen=capacity)
        
    def add(self, state, action, reward, next_state, done):
        experience = (state, action, reward, next_state, done)
        self.buffer.append(experience)

    def __len__(self):
        return len(self.buffer)

    def __getitem__(self, index):
        return self.buffer[index]

In [9]:
class Trainer:
    def __init__(self, env, input_shape, action_shape, learning_rate=1e-3):
        self.env = env
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.q_network = MyQNetwork(input_shape, action_shape).to(self.device)
        # self.q_network.load_state_dict(
        #     torch.load("/kaggle/input/blue-improved/blue_improved.pt", weights_only=True)
        # )
        self.target_network = MyQNetwork(input_shape, action_shape).to(self.device)
        self.target_network.load_state_dict(self.q_network.state_dict())

        # self.red_pretrained_network = QNetwork(input_shape, action_shape).to(self.device)
        # self.red_pretrained_network.load_state_dict(
        #     torch.load("/kaggle/working/RL-final-project-AIT-3007/red.pt", weights_only=True)
        # )

        self.criterion = nn.MSELoss()
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=learning_rate)
        self.replay_buffer = ReplayBuffer(capacity=16200 * 10)

        self.gamma = 0.99
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.9
        self.update_target_every = 2

    def select_action(self, observation, agent):
        if np.random.rand() <= self.epsilon:
            return self.env.action_space(agent).sample()

        observation = (
            torch.FloatTensor(observation).unsqueeze(0).to(self.device)
        )
        with torch.inference_mode():
            q_values = self.q_network(observation)
        return torch.argmax(q_values, dim=1).item()

    def pretrained_action(self, observation):
        observation = (
            torch.FloatTensor(observation).unsqueeze(0).to(self.device)
        )
        with torch.inference_mode():
            q_values = self.red_pretrained_network(observation)
        return torch.argmax(q_values, dim=1).item()

    def training(self, episodes=50, batch_size=2 ** 10):        
        for episode in range(episodes):
            self.env.reset()
            
            total_reward = 0
            reward_for_agent = {agent: 0 for agent in self.env.agents if agent.startswith('blue')}
            prev_observation = {}
            prev_action = {}
            self.env.reset()
            step = 0

            for idx, agent in enumerate(self.env.agent_iter()):
                step += 1
                observation, reward, termination, truncation, info = self.env.last()
                observation = np.transpose(observation, (2, 0, 1))
                
                agent_handle = agent.split('_')[0]
                
                if agent_handle == 'blue':
                    total_reward += reward
                    reward_for_agent[agent] += reward
                    
                if termination or truncation:
                    action = None
                else:
                    if agent_handle == 'blue':
                        action = self.select_action(observation, agent)
                    else:
                        action = self.env.action_space(agent).sample()
                        # action = self.pretrained_action(observation)

                if agent_handle == 'blue':
                    prev_observation[agent] = observation
                    prev_action[agent] = action
                
                self.env.step(action)
                
                if (idx + 1) % self.env.num_agents == 0:
                    break
                
            for agent in self.env.agent_iter():
                step += 1
                
                observation, reward, termination, truncation, info = self.env.last()
                observation = np.transpose(observation, (2, 0, 1))
                
                agent_handle = agent.split('_')[0]
                
                if agent_handle == 'blue':
                    total_reward += reward
                    reward_for_agent[agent] += reward
                    
                if termination or truncation:
                    action = None
                else:
                    if agent_handle == 'blue':
                        action = self.select_action(observation, agent)
                    else:
                        action = self.env.action_space(agent).sample()
                        # action = self.pretrained_action(observation)
    
                    if agent_handle == 'blue':
                        self.replay_buffer.add(
                            prev_observation[agent],
                            prev_action[agent],
                            reward,  
                            observation,
                            termination
                        )

                        prev_observation[agent] = observation
                        prev_action[agent] = action
    
                self.env.step(action)
            
            dataloader = DataLoader(self.replay_buffer, batch_size=batch_size, shuffle=True)
            self.update_model(dataloader)
                
            if (episode + 1) % self.update_target_every == 0:
                self.target_network.load_state_dict(self.q_network.state_dict())
    
            max_reward = max(reward_for_agent.values())
            
            print(f"Episode {episode}, Epsilon: {self.epsilon:.2f}, Total Reward: {total_reward}, Steps: {step}, Max Reward: {max_reward} ")
            self.epsilon = max(self.epsilon * self.epsilon_decay, self.epsilon_min)

    def update_model(self, dataloader):
        for states, actions, rewards, next_states, dones in dataloader:

            states = torch.tensor(states, dtype=torch.float32).to(self.device)
            actions = torch.tensor(actions, dtype=torch.long).to(self.device)
            rewards = torch.tensor(rewards, dtype=torch.float32).to(self.device)
            next_states = torch.tensor(next_states, dtype=torch.float32).to(self.device)
            dones = torch.tensor(dones, dtype=torch.float32).to(self.device)

            current_q_values = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
            with torch.inference_mode():
                next_q_values = self.target_network(next_states).max(1)[0]
            expected_q_values = rewards + (self.gamma * next_q_values * (1 - dones))

            loss = self.criterion(current_q_values, expected_q_values)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

In [10]:
env = battle_v4.env(map_size=45, render_mode=None)
trainer = Trainer(env, env.observation_space("red_0").shape, env.action_space("red_0").n)
trainer.training()

  states = torch.tensor(states, dtype=torch.float32).to(self.device)
  actions = torch.tensor(actions, dtype=torch.long).to(self.device)
  rewards = torch.tensor(rewards, dtype=torch.float32).to(self.device)
  next_states = torch.tensor(next_states, dtype=torch.float32).to(self.device)
  dones = torch.tensor(dones, dtype=torch.float32).to(self.device)


Episode 0, Epsilon: 1.00, Total Reward: -3244.045118597336, Steps: 159574, Max Reward: -4.495000167749822 
Episode 1, Epsilon: 0.90, Total Reward: -3051.900108466856, Steps: 158424, Max Reward: -31.700001289136708 
Episode 2, Epsilon: 0.81, Total Reward: -2698.4000980220735, Steps: 149892, Max Reward: -21.500001321546733 
Episode 3, Epsilon: 0.73, Total Reward: -2163.8750888127834, Steps: 105450, Max Reward: 25.3999986872077 
Episode 4, Epsilon: 0.66, Total Reward: -2322.300077858381, Steps: 155153, Max Reward: -20.700001030229032 
Episode 5, Epsilon: 0.59, Total Reward: -1984.8000696692616, Steps: 133078, Max Reward: -10.400000834837556 
Episode 6, Epsilon: 0.53, Total Reward: -1622.09005906526, Steps: 117556, Max Reward: 1.9999992102384567 
Episode 7, Epsilon: 0.48, Total Reward: -1786.0000558216125, Steps: 153216, Max Reward: -6.200000794604421 
Episode 8, Epsilon: 0.43, Total Reward: -1502.6000496596098, Steps: 139066, Max Reward: 22.299999219365418 
Episode 9, Epsilon: 0.39, Total

In [11]:
os.makedirs("models", exist_ok=True)
torch.save(trainer.q_network.state_dict(), "models/blue_vs_random.pt")
print("Training complete. Model saved.")

Training complete. Model saved.


In [12]:
# make video
import cv2

env = battle_v4.env(map_size=45, render_mode="rgb_array", max_cycles=300)
vid_dir = "video"
os.makedirs(vid_dir, exist_ok=True)
fps = 35
frames = []
my_q_network = MyQNetwork(
    env.observation_space("red_0").shape, env.action_space("red_0").n
)
my_q_network.load_state_dict(
    torch.load("/kaggle/working/models/blue_vs_random.pt", weights_only=True, map_location="cpu")
)
my_q_network.eval()

red_pretrained_network = QNetwork(
    env.observation_space("red_0").shape, env.action_space("red_0").n
)
red_pretrained_network.load_state_dict(
    torch.load("/kaggle/working/RL-final-project-AIT-3007/red.pt", weights_only=True, map_location="cpu")
)
env.reset()
for agent in env.agent_iter():

    observation, reward, termination, truncation, info = env.last()
    observation = (
        torch.Tensor(observation).float().permute([2, 0, 1]).unsqueeze(0)
    )
    if termination or truncation:
        action = None  # this agent has died
    else:
        agent_handle = agent.split("_")[0]
        if agent_handle == "blue":
            with torch.inference_mode():
                q_values = my_q_network(observation)
            action = torch.argmax(q_values, dim=1).numpy()[0]
        else:
            action = env.action_space(agent).sample()
            # with torch.inference_mode():
            #     q_values = red_pretrained_network(observation)
            # action = torch.argmax(q_values, dim=1).item()
            
    env.step(action)

    if agent == "red_0":
        frames.append(env.render())

height, width, _ = frames[0].shape
out = cv2.VideoWriter(
    os.path.join(vid_dir, f"blue_vs_random_another.mp4"),
    cv2.VideoWriter_fourcc(*"mp4v"),
    fps,
    (width, height),
)
for frame in frames:
    frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
    out.write(frame_bgr)
out.release()
print("Done recording pretrained agents")

env.close()

Done recording pretrained agents


In [16]:
class vQNetwork(nn.Module):
    def __init__(self, observation_shape, action_shape):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(observation_shape[-1], observation_shape[-1], 3),
            nn.ReLU(),
            nn.Conv2d(observation_shape[-1], observation_shape[-1], 3),
            nn.ReLU(),
        )
        dummy_input = torch.randn(observation_shape).permute(2, 0, 1)
        dummy_output = self.cnn(dummy_input)
        flatten_dim = dummy_output.view(-1).shape[0]
        self.network = nn.Sequential(
            nn.Linear(flatten_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, action_shape),
        )

    def forward(self, x):
        assert len(x.shape) >= 3, "only support magent input observation"
        x = self.cnn(x)
        if len(x.shape) == 3:
            batchsize = 1
        else:
            batchsize = x.shape[0]
        x = x.reshape(batchsize, -1)
        return self.network(x)

In [19]:
# from torch_model import QNetwork
try:
    from tqdm import tqdm
except ImportError:
    tqdm = lambda x, *args, **kwargs: x  # Fallback: tqdm becomes a no-op
    
def eval():
    max_cycles = 300
    env = battle_v4.env(map_size=45, max_cycles=max_cycles)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def random_policy(env, agent, obs):
        return env.action_space(agent).sample()

    MyNetwork = MyQNetwork(
        env.observation_space("red_0").shape, env.action_space("red_0").n
    )
    MyNetwork.load_state_dict(
        torch.load("/kaggle/input/blue-improved/blue_improved.pt", weights_only=True, map_location='cpu')
    )
    MyNetwork.to(device)

    red_pretrained_network = vQNetwork(
        env.observation_space("red_0").shape, env.action_space("red_0").n
    )
    red_pretrained_network.load_state_dict(
        torch.load("/kaggle/working/RL-final-project-AIT-3007/red.pt", weights_only=True, map_location="cpu")
    )
    red_pretrained_network.to(device)
    
    v1 = vQNetwork(
        env.observation_space("red_0").shape, env.action_space("red_0").n
    )
    v1.load_state_dict(
        torch.load("/kaggle/input/another-model/v1.pth", weights_only=True, map_location="cpu")
    )
    v1.to(device)

    v2 = vQNetwork(
        env.observation_space("red_0").shape, env.action_space("red_0").n
    )
    v2.load_state_dict(
        torch.load("/kaggle/input/another-model/v2.pth", weights_only=True, map_location="cpu")
    )
    v2.to(device)

    def pretrain_policy(env, agent, obs):
        observation = (
            torch.Tensor(obs).float().permute([2, 0, 1]).unsqueeze(0).to(device)
        )
        with torch.inference_mode():
            q_values = red_pretrained_network(observation)
        return torch.argmax(q_values, dim=1).cpu().numpy()[0]
        
    def v1_policy(env, agent, obs):
        observation = (
            torch.Tensor(obs).float().permute([2, 0, 1]).unsqueeze(0).to(device)
        )
        with torch.inference_mode():
            q_values = v1(observation)
        return torch.argmax(q_values, dim=1).cpu().numpy()[0]

    def v2_policy(env, agent, obs):
        observation = (
            torch.Tensor(obs).float().permute([2, 0, 1]).unsqueeze(0).to(device)
        )
        with torch.inference_mode():
            q_values = v2(observation)
        return torch.argmax(q_values, dim=1).cpu().numpy()[0]

    def my_policy(env, agent, obs):
        observation = (
            torch.Tensor(obs).float().permute([2, 0, 1]).unsqueeze(0).to(device)
        )
        with torch.inference_mode():
            q_values = MyNetwork(observation)
        return torch.argmax(q_values, dim=1).cpu().numpy()[0]

    def run_eval(env, red_policy, blue_policy, n_episode: int = 100):
        red_win, blue_win = [], []
        red_tot_rw, blue_tot_rw = [], []
        n_agent_each_team = len(env.env.action_spaces) // 2

        for _ in tqdm(range(n_episode)):
            env.reset()
            n_kill = {"red": 0, "blue": 0}
            red_reward, blue_reward = 0, 0

            for agent in env.agent_iter():
                observation, reward, termination, truncation, info = env.last()
                agent_team = agent.split("_")[0]

                n_kill[agent_team] += (
                    reward > 4.5
                )  # This assumes default reward settups
                if agent_team == "red":
                    red_reward += reward
                else:
                    blue_reward += reward

                if termination or truncation:
                    action = None  # this agent has died
                else:
                    if agent_team == "red":
                        action = red_policy(env, agent, observation)
                    else:
                        action = blue_policy(env, agent, observation)

                env.step(action)

            who_wins = "red" if n_kill["red"] >= n_kill["blue"] + 5 else "draw"
            who_wins = "blue" if n_kill["red"] + 5 <= n_kill["blue"] else who_wins
            red_win.append(who_wins == "red")
            blue_win.append(who_wins == "blue")

            red_tot_rw.append(red_reward / n_agent_each_team)
            blue_tot_rw.append(blue_reward / n_agent_each_team)

        return {
            "winrate_red": np.mean(red_win),
            "winrate_blue": np.mean(blue_win),
            "average_rewards_red": np.mean(red_tot_rw),
            "average_rewards_blue": np.mean(blue_tot_rw),
        }

    print("=" * 20)
    print("Eval with pretrain policy")
    print(
        run_eval(
            env=env, red_policy=pretrain_policy, blue_policy=my_policy, n_episode=30
        )
    )

    print("=" * 20)
    print("Eval with v1 policy")
    print(
        run_eval(
            env=env, red_policy=v1_policy, blue_policy=my_policy, n_episode=30
        )
    )
    print("=" * 20)

    print("Eval with v2 policy")
    print(
        run_eval(
            env=env, red_policy=v2_policy, blue_policy=my_policy, n_episode=30
        )
    )
    print("=" * 20)

if __name__ == "__main__":
    eval()

Eval with pretrain policy


100%|██████████| 30/30 [01:46<00:00,  3.54s/it]


{'winrate_red': 0.0, 'winrate_blue': 1.0, 'average_rewards_red': 0.7385699528127663, 'average_rewards_blue': 4.133952614933313}
Eval with v1 policy


100%|██████████| 30/30 [03:21<00:00,  6.71s/it]


{'winrate_red': 0.0, 'winrate_blue': 0.9333333333333333, 'average_rewards_red': 3.603516443757869, 'average_rewards_blue': 2.691331182867519}
Eval with v2 policy


100%|██████████| 30/30 [02:14<00:00,  4.47s/it]

{'winrate_red': 0.4, 'winrate_blue': 0.23333333333333334, 'average_rewards_red': 4.591510265306359, 'average_rewards_blue': 3.4396933533274283}



