In [1]:
!pip install -q git+https://github.com/Farama-Foundation/MAgent2

In [10]:
import torch
import torch.optim as optim
import numpy as np
from collections import deque
import random
import torch.nn as nn
from magent2.environments import battle_v4

In [11]:
class QNetwork(nn.Module):
    """Small convolutional Q-network: two 3x3 conv layers followed by an MLP head.

    Args:
        observation_shape: Shape of one observation; the last entry is used as
            the channel count and the shape is permuted with ``(2, 0, 1)``
            before being fed to the conv stack, i.e. a channels-last 3-D shape.
        action_shape: Number of outputs of the final linear layer.
    """

    def __init__(self, observation_shape, action_shape):
        super().__init__()
        channels = observation_shape[-1]
        self.cnn = nn.Sequential(
            nn.Conv2d(channels, channels, 3),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3),
            nn.ReLU(),
        )

        # Push one random channels-first sample through the conv stack to
        # discover how many features the first linear layer must accept.
        probe = torch.randn(observation_shape).permute(2, 0, 1)
        flatten_dim = self.cnn(probe).view(-1).shape[0]

        self.network = nn.Sequential(
            nn.Linear(flatten_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, action_shape),
        )

    def forward(self, x):
        """Return Q-values with shape ``(batch, action_shape)``.

        A 3-D input is treated as a single unbatched sample and yields a
        batch dimension of 1.
        """
        assert len(x.shape) >= 3, "only support magent input observation"
        features = self.cnn(x)
        # Unbatched conv output is 3-D; flatten it into a single row.
        rows = 1 if len(features.shape) == 3 else features.shape[0]
        flat = features.reshape(rows, -1)
        return self.network(flat)

In [12]:
class RandomAgent:
    """Baseline policy that ignores the observation and samples an action."""

    def __init__(self, action_space):
        # Any object exposing a ``sample()`` method works here.
        self.action_space = action_space

    def get_action(self, observation):
        """Return ``self.action_space.sample()``; ``observation`` is unused."""
        sampled = self.action_space.sample()
        return sampled

In [13]:
class PretrainedAgent:
    """Greedy policy backed by a ``QNetwork`` whose weights are loaded from disk.

    Args:
        n_observation: Observation shape forwarded to ``QNetwork``.
        n_actions: Number of actions forwarded to ``QNetwork``.
        device: Torch device string the network and inputs are moved to.
        weights_path: Path of the state-dict checkpoint to load. ``None``
            (the default) falls back to ``DEFAULT_WEIGHTS_PATH``, so existing
            callers keep their behaviour.
    """

    # Location of the checkpoint used when no explicit path is given.
    DEFAULT_WEIGHTS_PATH = "/kaggle/input/pretrained/pytorch/default/1/red.pt"

    def __init__(self, n_observation, n_actions, device="cpu", weights_path=None):
        self.device = torch.device(device)
        self.qnetwork = QNetwork(n_observation, n_actions).to(self.device)

        if weights_path is None:
            weights_path = self.DEFAULT_WEIGHTS_PATH
        state_dict = torch.load(weights_path, weights_only=True, map_location=self.device)
        self.qnetwork.load_state_dict(state_dict)
        # This agent is inference-only, so keep the network in eval mode.
        self.qnetwork.eval()

    def get_action(self, observation):
        """Return the index of the highest Q-value for one observation.

        ``observation`` is permuted with ``[2, 0, 1]`` (channels-last to
        channels-first) and given a leading batch dimension of 1 before the
        forward pass. The result is a NumPy integer scalar.
        """
        observation = (
            torch.as_tensor(observation, dtype=torch.float32)
            .permute([2, 0, 1])
            .unsqueeze(0)
            .to(self.device)
        )
        with torch.no_grad():
            q_values = self.qnetwork(observation)
        action = torch.argmax(q_values, dim=1).cpu().numpy()[0]

        return action


In [14]:
class DQNAgent:
    """Epsilon-greedy DQN learner with an online network, a target network and a replay buffer.

    Args:
        observation_shape: Observation shape forwarded to ``QNetwork``.
        action_shape: Number of discrete actions.
        buffer_size: Capacity passed to ``ReplayBuffer``.
        batch_size: Number of transitions sampled per ``train`` call; training
            is skipped while the buffer holds fewer transitions than this.
        lr: Adam learning rate.
        gamma: Discount factor used in the bootstrap target.
        device: Torch device string for both networks and all batches.
        epsilon_start: Initial exploration probability.
        epsilon_decay: Multiplicative factor applied by ``decay_epsilon``.
        epsilon_min: Floor below which epsilon is never decayed.
    """

    def __init__(self, observation_shape, action_shape, buffer_size=10000, batch_size=64, lr=1e-3, gamma=0.6, device="cpu",
                 epsilon_start=1.0, epsilon_decay=0.995, epsilon_min=0.1):
        self.device = torch.device(device)
        self.q_net = QNetwork(observation_shape, action_shape).float().to(self.device)
        self.target_net = QNetwork(observation_shape, action_shape).float().to(self.device)
        # The target network starts as an exact copy and is only used for inference.
        self.target_net.load_state_dict(self.q_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)
        self.replay_buffer = ReplayBuffer(buffer_size)
        self.batch_size = batch_size
        self.gamma = gamma
        self.action_shape = action_shape
        self.epsilon = epsilon_start
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min

    def get_action(self, observation):
        """Pick an action epsilon-greedily for a single channels-last observation.

        With probability ``self.epsilon`` a uniformly random action index is
        returned; otherwise the argmax of the online network's Q-values.
        """
        if np.random.rand() < self.epsilon:
            return np.random.randint(self.action_shape)

        state_tensor = (
            torch.as_tensor(observation, dtype=torch.float32)
            .unsqueeze(0)
            .permute(0, 3, 1, 2)
            .to(self.device)
        )
        with torch.no_grad():
            return self.q_net(state_tensor).argmax(dim=1).item()

    def train(self):
        """Run one gradient step on a batch sampled from the replay buffer.

        Does nothing while the buffer holds fewer than ``batch_size``
        transitions, or when every sampled transition has ``action is None``.
        Returns ``None``.
        """
        if len(self.replay_buffer) < self.batch_size:
            return

        # Sample the replay buffer
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)

        # Transitions stored with action None carry no action to evaluate, so
        # mask them out in one vectorised step.
        actions = np.asarray(actions, dtype=object)
        valid = np.fromiter((a is not None for a in actions), dtype=bool, count=len(actions))
        if not valid.any():  # If no valid samples, return
            return

        states = np.asarray(states)[valid]
        next_states = np.asarray(next_states)[valid]
        # gather() needs int64 indices; do not rely on NumPy's platform default
        # integer width or on an object-dtype array being inferred correctly.
        actions = actions[valid].astype(np.int64)
        rewards = np.asarray(rewards)[valid].astype(np.float32)
        dones = np.asarray(dones)[valid].astype(np.float32)

        # Convert to tensors (channels-first for the conv stack) and move to device
        states = torch.as_tensor(states, dtype=torch.float32).permute(0, 3, 1, 2).to(self.device)
        next_states = torch.as_tensor(next_states, dtype=torch.float32).permute(0, 3, 1, 2).to(self.device)
        actions = torch.as_tensor(actions, dtype=torch.long).unsqueeze(1).to(self.device)
        rewards = torch.as_tensor(rewards).unsqueeze(1).to(self.device)
        dones = torch.as_tensor(dones).unsqueeze(1).to(self.device)

        # Q(s, a) for the actions actually taken, and the bootstrapped target
        q_values = self.q_net(states).gather(1, actions)
        with torch.no_grad():
            next_q_values = self.target_net(next_states).max(1, keepdim=True)[0]
            # (1 - done) removes the bootstrap term for terminal transitions.
            target_q_values = rewards + self.gamma * next_q_values * (1 - dones)

        # Compute loss and update the network
        loss = nn.functional.mse_loss(q_values, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target_network(self):
        """Copy the online network's weights into the target network."""
        self.target_net.load_state_dict(self.q_net.state_dict())

    def decay_epsilon(self):
        """Multiply epsilon by ``epsilon_decay``, never going below ``epsilon_min``."""
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)


In [15]:
# Replay Buffer
class ReplayBuffer:
    """Bounded FIFO store of ``(state, action, reward, next_state, done)`` tuples."""

    def __init__(self, capacity):
        # A deque with maxlen drops its oldest entry once capacity is reached.
        self.buffer = deque(maxlen=capacity)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

    def push(self, state, action, reward, next_state, done):
        """Append one transition to the buffer."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions and return them column-wise.

        Returns a 5-tuple of NumPy arrays: stacked states, actions, rewards,
        stacked next states and done flags, all in the same sample order.
        """
        picked = random.sample(self.buffer, batch_size)
        # Transpose the list of transitions into one sequence per field.
        states, actions, rewards, next_states, dones = (list(column) for column in zip(*picked))
        return (
            np.stack(states),
            np.array(actions),
            np.array(rewards),
            np.stack(next_states),
            np.array(dones),
        )

In [16]:
def train_dqn(env, red_agent: "DQNAgent", blue_agent: "RandomAgent", episodes=500, target_update_freq=5):
    """Train ``red_agent`` as the agent named ``"red_0"`` in an AEC-style environment.

    Only ``"red_0"`` is driven by ``red_agent``; every other agent yielded by
    ``env.agent_iter()`` is driven by ``blue_agent``.
    NOTE(review): in battle_v4 that presumably includes the remaining red
    agents as well as all blue agents - confirm this is intended.

    In the PettingZoo AEC API, the reward returned by ``env.last()`` is what
    the agent accumulated since it last acted, and the observation returned
    after ``env.step`` belongs to the *next* agent. A transition for
    ``"red_0"`` is therefore completed on its following turn and stored as
    ``(previous observation, previous action, reward, current observation, done)``.

    Args:
        env: Environment exposing ``reset``, ``agent_iter``, ``last``, ``step``
            and ``close``.
        red_agent: Learner providing ``get_action``, ``train``,
            ``decay_epsilon``, ``update_target_network``, ``epsilon`` and a
            ``replay_buffer`` with ``push``.
        blue_agent: Any object with ``get_action(observation)``.
        episodes: Number of episodes to run.
        target_update_freq: The target network is synced on every episode
            whose index is a multiple of this value.

    Returns:
        list: Sum of the rewards observed by ``"red_0"``, one entry per episode.

    Side effects:
        Prints one progress line per episode and closes ``env`` when done.
    """
    total_rewards = []

    for episode in range(episodes):
        env.reset()
        total_reward = 0

        # The learner's most recent (observation, action), waiting for the
        # reward and next observation that only arrive on its next turn.
        prev_observation = None
        prev_action = None
        has_pending = False

        for agent in env.agent_iter():
            observation, reward, termination, truncation, _ = env.last()
            done = termination or truncation
            is_learner = agent == "red_0"

            if is_learner:
                if has_pending:
                    red_agent.replay_buffer.push(
                        prev_observation, prev_action, reward, observation, done
                    )
                    has_pending = False
                total_reward += reward

            if done:
                # Finished agents must be stepped with None.
                action = None
            elif is_learner:
                action = red_agent.get_action(observation)
                prev_observation, prev_action, has_pending = observation, action, True
            else:
                action = blue_agent.get_action(observation)

            env.step(action)
            # One optimisation step per environment step, whichever agent moved.
            red_agent.train()

        red_agent.decay_epsilon()
        if episode % target_update_freq == 0:
            red_agent.update_target_network()

        total_rewards.append(total_reward)
        print(f"Episode {episode}, Total Reward: {total_reward}, Epsilon: {red_agent.epsilon:.2f}")

    env.close()
    return total_rewards


In [17]:
# Run configuration
MAP_SIZE = 45
EPISODES = 500

# The off-screen render mode is spelled "rgb_array" (underscore); "rgb-array"
# is not a render mode the environment recognises.
env = battle_v4.env(map_size=MAP_SIZE, render_mode="rgb_array")

device = "cuda" if torch.cuda.is_available() else "cpu"

# Observation and action spaces are read from the "red_0" agent and reused
# for both agents below.
observation_shape = env.observation_space("red_0").shape
action_shape = env.action_space("red_0").n

red_agent = DQNAgent(observation_shape, action_shape, device=device)
blue_agent = PretrainedAgent(observation_shape, action_shape, device=device)

train_dqn(env, red_agent, blue_agent, episodes=EPISODES)


Episode 444, Total Reward: -2.33000008482486, Epsilon: 0.11
Episode 445, Total Reward: -3.0350001147016883, Epsilon: 0.11
Episode 446, Total Reward: -3.905000147409737, Epsilon: 0.11
Episode 447, Total Reward: -3.700000138953328, Epsilon: 0.11
Episode 448, Total Reward: -2.6500000907108188, Epsilon: 0.11
Episode 449, Total Reward: -2.4250000827014446, Epsilon: 0.10
Episode 450, Total Reward: -1.3650000412017107, Epsilon: 0.10
Episode 451, Total Reward: -3.150000118650496, Epsilon: 0.10
Episode 452, Total Reward: -2.9500001100823283, Epsilon: 0.10
Episode 453, Total Reward: -2.5450000930577517, Epsilon: 0.10
Episode 454, Total Reward: -2.5550000928342342, Epsilon: 0.10
Episode 455, Total Reward: -2.790000100620091, Epsilon: 0.10
Episode 456, Total Reward: -3.7100001387298107, Epsilon: 0.10
Episode 457, Total Reward: -6.785000259056687, Epsilon: 0.10
Episode 458, Total Reward: -1.7550000585615635, Epsilon: 0.10
Episode 459, Total Reward: -3.1950001176446676, Epsilon: 0.10
Episode 460, To

In [18]:
def save_model(model, file_path):
    """
    Write a model's ``state_dict`` to disk and report where it went.

    Only the parameters/buffers are serialised, not the module object itself.

    Args:
        model (nn.Module): Model whose ``state_dict()`` is saved.
        file_path (str): Destination path handed to ``torch.save``.
    """
    weights = model.state_dict()
    torch.save(weights, file_path)
    print(f"Model saved to {file_path}")

In [19]:
# Persist the weights of the trained online Q-network (not the target network)
# to a file in the current working directory.
save_model(red_agent.q_net, "q_net_trained.pt")

Model saved to q_net_trained.pt
