In [1]:
!pip install git+https://github.com/Farama-Foundation/MAgent2

Collecting git+https://github.com/Farama-Foundation/MAgent2
  Cloning https://github.com/Farama-Foundation/MAgent2 to /tmp/pip-req-build-a0qxyxij
  Running command git clone --filter=blob:none --quiet https://github.com/Farama-Foundation/MAgent2 /tmp/pip-req-build-a0qxyxij
  Resolved https://github.com/Farama-Foundation/MAgent2 to commit b2ddd49445368cf85d4d4e1edcddae2e28aa1406
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Collecting pygame>=2.1.0 (from magent2==0.3.3)
  Downloading pygame-2.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading pygame-2.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (14.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.0/14.0 MB[0m [31m95.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m:01[0m
[?25hBuilding wheels for collected packages: magent2
  Building wheel for

In [2]:
from collections import deque, namedtuple

# DQN NETWORK


In [3]:

import os
from magent2.environments import battle_v4
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import numpy as np
import cv2

import math
import matplotlib.pyplot as plt

In [4]:
class QNetwork(nn.Module):
    def __init__(self, observation_shape, action_shape):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(observation_shape[-1], observation_shape[-1], 3),
            nn.ReLU(),
            nn.Conv2d(observation_shape[-1], observation_shape[-1], 3),
            nn.ReLU(),
        )
        dummy_input = torch.randn(observation_shape).permute(2, 0, 1)
        dummy_output = self.cnn(dummy_input)
        flatten_dim = dummy_output.view(-1).shape[0]
        self.network = nn.Sequential(
            nn.Linear(flatten_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, action_shape),
        )

    def forward(self, x):
        assert len(x.shape) >= 3, "only support magent input observation"
        x = self.cnn(x)
        if len(x.shape) == 3:
            batchsize = 1
        else:
            batchsize = x.shape[0]
        x = x.reshape(batchsize, -1)
        return self.network(x)

# Replay Memory

In [5]:
Transition = namedtuple('Transition',
                        ('observation', 'action', 'reward', 'next_observation', 'done'))

In [6]:
# Define memory for Experience Replay
class ReplayMemory(object):
    def __init__(self, maxlen):
        self.memory = deque([], maxlen=maxlen)

    def append(self, *args):
        self.memory.append(Transition(*args))

    def sample(self, sample_size):
        return random.sample(self.memory, sample_size)

    def __len__(self):
        return len(self.memory)

# Training

In [7]:
# Hyperparameters
learning_rate = 0.0003
gamma = 0.9
epsilon_start = 0.9
epsilon_end = 0.05
network_sync_rate = 20
epsilon_decay = 150000
batch_size = 128
episodes = 20
TAU = 0.005

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


## Epsilon-Greedy

In [9]:
steps_done = 0
def policy(observation, agent, env, q_network, device):
    global steps_done
    sample = random.random()
    epsilon = epsilon_end + (epsilon_start - epsilon_end) * \
        math.exp(-1. * steps_done / epsilon_decay)
    steps_done += 1
    if sample < epsilon:
        return env.action_space(agent).sample()
    else:
        observation = (
            torch.Tensor(observation).float().permute([2, 0, 1]).unsqueeze(0).to(device)
        )
        with torch.no_grad():
            q_values = q_network(observation)
        return torch.argmax(q_values, dim=1).cpu().numpy()[0]

## Init environment and network

In [10]:
# Initialize environment and network
env = battle_v4.env(map_size=45, minimap_mode=False, step_reward=-0.005,
dead_penalty=-0.1, attack_penalty=-0.005, attack_opponent_reward=0.05,
max_cycles=300, extra_features=False)
env.reset()

policy_dqn = QNetwork(
    observation_shape=env.observation_space("blue_0").shape,
    action_shape=env.action_space("blue_0").n,
)
policy_dqn = policy_dqn.to(device)

target_dqn = QNetwork(
    observation_shape=env.observation_space("blue_0").shape,
    action_shape=env.action_space("blue_0").n,
)
target_dqn = target_dqn.to(device)


# Make the target and policy networks the same (copy weights/biases from one network to the other)
target_dqn.load_state_dict(policy_dqn.state_dict())

<All keys matched successfully>

In [11]:

optimizer = optim.Adam(policy_dqn.parameters(), lr=learning_rate)                # NN Optimizer.

In [12]:
loss_fn = nn.SmoothL1Loss()

# Training Loop

In [13]:
replay_buffer = ReplayMemory(10000)

In [14]:
def optimize_model(policy_dqn, target_dqn):
    if len(replay_buffer) < batch_size:
        return
    transitions = replay_buffer.sample(batch_size)

    batch = Transition(*zip(*transitions))

    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                            batch.next_observation)), device=device, dtype=torch.bool)
    non_final_next_observations = torch.cat([s for s in batch.next_observation if s is not None])

    observation_batch = torch.cat(batch.observation).to(device)
    action_batch = torch.cat(batch.action).to(device)
    reward_batch = torch.cat(batch.reward).to(device)


    state_action_values = policy_dqn(observation_batch).gather(1, action_batch)

    next_state_values = torch.zeros(batch_size, device=device)
    with torch.no_grad():
        next_state_values[non_final_mask] = target_dqn(non_final_next_observations).max(1).values

    expected_state_action_values = (next_state_values * gamma) + reward_batch

    loss = loss_fn(state_action_values, expected_state_action_values.unsqueeze(1))
    
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_value_(policy_dqn.parameters(), 100)
    optimizer.step()



In [15]:




def train_q_network(env, policy_dqn, target_dqn, optimizer, loss_fn, device):
    """
    Training loop for Q-learning without a replay buffer.

    """
    episode_rewards = []
    # List to keep track of epsilon decay
    epsilon_history = []

    try:
        for episode in range(episodes):
            env.reset()
            ep_reward = 0
            # ep_loss = 0
            ep_steps = 0
            for agent in env.agent_iter():

                observation, reward, termination, truncation, info = env.last()

                if termination or truncation:
                    action = None  # Agent is dead
                    env.step(action)
                else:
                    agent_handle = agent.split("_")[0]
                    if agent_handle == "blue":
                        # Get action
                        action = blue_policy(observation, agent, env, policy_dqn, device)

                        # Take action
                        env.step(action)

                        # Get next state information
                        next_observation, reward, termination, truncation, info = env.last()

                        ep_reward += reward


                        # Append transition to replay buffer
                        observation = torch.Tensor(observation).float().permute([2, 0, 1]).unsqueeze(0).to(device)
                        next_observation = torch.Tensor(next_observation).float().permute([2, 0, 1]).unsqueeze(0).to(device)
                        action = torch.tensor([action], device=device, dtype=torch.int64).unsqueeze(0)
                        reward = torch.tensor([reward], device=device, dtype=torch.float32)
                        replay_buffer.append(observation, action, reward, next_observation, termination)

                    else:
                        # Random policy for red team
                        action = env.action_space(agent).sample()
                        env.step(action)

                        # Get next state information
                        next_observation, reward, termination, truncation, info = env.last()

                        # ep_reward -= reward

                        # Append transition to replay buffer
                        observation = torch.Tensor(observation).float().permute([2, 0, 1]).unsqueeze(0).to(device)
                        next_observation = torch.Tensor(next_observation).float().permute([2, 0, 1]).unsqueeze(0).to(device)
                        action = torch.tensor([action], device=device, dtype=torch.int64).unsqueeze(0)
                        reward = torch.tensor([reward], device=device, dtype=torch.float32)
                        replay_buffer.append(observation, action, reward, next_observation, termination)

                        
                    # Optimize Model
                    optimize_model(policy_dqn, target_dqn)

                    ep_steps += 1
                    if ep_steps % network_sync_rate == 0:
                        target_net_state_dict = target_dqn.state_dict()
                        policy_net_state_dict = policy_dqn.state_dict()
                        for key in policy_net_state_dict:
                            target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
                        target_dqn.load_state_dict(target_net_state_dict)

            # Decay epsilon
            epsilon = epsilon_end + (epsilon_start - epsilon_end) * \
                math.exp(-1. * steps_done / epsilon_decay)
            epsilon_history.append(epsilon)

            # Record episode statistics
            episode_rewards.append(ep_reward)

            print(f"Episode {episode+1}/{episodes} end after {ep_steps} steps -Reward: {ep_reward} -Epsilon: {epsilon}")

        # Save trained model
        torch.save({
            'model_state_dict': policy_dqn.state_dict(),
        }, "blue_agent.pt")
        print("Model saved as 'blue_agent.pt'")
        return episode_rewards

    except KeyboardInterrupt:
        print("\nTraining interrupted. Saving model checkpoint...")
        torch.save({
            'model_state_dict': policy_dqn.state_dict(),
        }, "blue_agent_interrupted.pt")
        print("Model saved as 'blue_agent_interrupted.pt'")
        return episode_rewards

## Training Log

In [16]:
# Train the network
episode_rewards = train_q_network(env, policy_dqn, target_dqn, optimizer, loss_fn, device)

Episode 1/20 end after 48600 steps -Reward: -161.15499639790505 -Epsilon: 0.772875023859198
Episode 2/20 end after 48338 steps -Reward: -159.26999644748867 -Epsilon: 0.6650209605890582
Episode 3/20 end after 48498 steps -Reward: -153.7849969258532 -Epsilon: 0.5730391665408499
Episode 4/20 end after 48301 steps -Reward: -151.8899968545884 -Epsilon: 0.49484371407380506
Episode 5/20 end after 48481 steps -Reward: -163.49499634932727 -Epsilon: 0.42861367176132725
Episode 6/20 end after 48383 steps -Reward: -159.66999643296003 -Epsilon: 0.3724548144384288
Episode 7/20 end after 48596 steps -Reward: -157.49499660078436 -Epsilon: 0.3242288608008147
Episode 8/20 end after 48600 steps -Reward: -162.8199963606894 -Epsilon: 0.28321552269914074
Episode 9/20 end after 48599 steps -Reward: -164.17999633401632 -Epsilon: 0.2483360900417373
Episode 10/20 end after 48600 steps -Reward: -166.4299962799996 -Epsilon: 0.21867318331889518
Episode 11/20 end after 48156 steps -Reward: -135.63499745633453 -Epsi