In [15]:
#!pip install pettingzoo


Collecting pettingzoo
  Downloading pettingzoo-1.24.3-py3-none-any.whl.metadata (8.5 kB)
Collecting gymnasium>=0.28.0 (from pettingzoo)
  Downloading gymnasium-1.0.0-py3-none-any.whl.metadata (9.5 kB)
Collecting farama-notifications>=0.0.1 (from gymnasium>=0.28.0->pettingzoo)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)
Downloading pettingzoo-1.24.3-py3-none-any.whl (847 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m847.8/847.8 kB[0m [31m16.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading gymnasium-1.0.0-py3-none-any.whl (958 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m958.1/958.1 kB[0m [31m49.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Installing collected packages: farama-notifications, gymnasium, pettingzoo
Successfully installed farama-notifications-0.0.4 gymnasium-1.0.0 pettingzoo-1.24.3


In [5]:
!pip install pygame opencv-python numpy
!pip install git+https://github.com/Farama-Foundation/MAgent2.git


Collecting git+https://github.com/Farama-Foundation/MAgent2.git
  Cloning https://github.com/Farama-Foundation/MAgent2.git to /tmp/pip-req-build-8dy7zwmu
  Running command git clone --filter=blob:none --quiet https://github.com/Farama-Foundation/MAgent2.git /tmp/pip-req-build-8dy7zwmu
  Resolved https://github.com/Farama-Foundation/MAgent2.git to commit b2ddd49445368cf85d4d4e1edcddae2e28aa1406
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting pettingzoo>=1.23.1 (from magent2==0.3.3)
  Downloading pettingzoo-1.24.3-py3-none-any.whl.metadata (8.5 kB)
Collecting gymnasium>=0.28.0 (from pettingzoo>=1.23.1->magent2==0.3.3)
  Downloading gymnasium-1.0.0-py3-none-any.whl.metadata (9.5 kB)
Collecting farama-notifications>=0.0.1 (from gymnasium>=0.28.0->pettingzoo>=1.23.1->magent2==0.3.3)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
from magent2.environments import battle_v4
from tqdm import tqdm


class QNetwork(nn.Module):
    def __init__(self, observation_shape, action_shape):
        super(QNetwork, self).__init__()
        # Một mạng nơ-ron đơn giản với 2 lớp fully connected
        self.fc1 = nn.Linear(np.prod(observation_shape), 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, action_shape)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten the observation
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

class DQNAgent:
    def __init__(self, observation_shape, action_shape):
        self.q_network = QNetwork(observation_shape, action_shape)
        self.target_network = QNetwork(observation_shape, action_shape)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.memory = deque(maxlen=10000)
        self.batch_size = 32
        self.gamma = 0.99
        self.epsilon = 0.1
        self.learning_rate = 1e-3
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=self.learning_rate)

    def act(self, observation):
        # Đổi observation thành tensor
        state = torch.Tensor(observation).float().unsqueeze(0)
        if random.random() < self.epsilon:
            # Chọn hành động ngẫu nhiên
            action = random.choice(range(self.q_network.fc3.out_features))
        else:
            # Chọn hành động tốt nhất (theo Q-value)
            with torch.no_grad():
                q_values = self.q_network(state)
                action = torch.argmax(q_values, dim=1).item()
        return action

    def remember(self, state, action, reward, next_state, done):
        # Lưu trữ trải nghiệm vào bộ nhớ
        self.memory.append((state, action, reward, next_state, done))

    def replay(self):
        if len(self.memory) < self.batch_size:
            return None

        # Chọn một batch từ bộ nhớ
        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.Tensor(states)
        actions = torch.LongTensor(actions)
        rewards = torch.Tensor(rewards)
        next_states = torch.Tensor(next_states)
        dones = torch.Tensor(dones)

        # Tính giá trị Q target
        with torch.no_grad():
            target_q_values = self.target_network(next_states)
            max_target_q_values = torch.max(target_q_values, dim=1)[0]
            target = rewards + self.gamma * max_target_q_values * (1 - dones)

        # Tính giá trị Q hiện tại
        q_values = self.q_network(states)
        q_value = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        # Tính loss
        loss = torch.mean((q_value - target) ** 2)

        # Cập nhật trọng số của mạng Q
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Cập nhật target network
        self.target_network.load_state_dict(self.q_network.state_dict())

        return loss.item()

    def save(self, filename):
        torch.save(self.q_network.state_dict(), filename)


In [13]:
def train(num_episodes=100, batch_size=32):
    env = battle_v4.env(map_size=45, render_mode="rgb_array")
    # Initialize agent for red team
    observation_shape = env.observation_space("red_0").shape
    action_shape = env.action_space("red_0").n
    agent = DQNAgent(observation_shape, action_shape)
    #agent.q_network.load_state_dict(torch.load('blue.pt', map_location=torch.device('cpu')))
    #agent.target_network.load_state_dict(torch.load('blue.pt', map_location=torch.device('cpu')))
    best_reward = float('-inf')

    for episode in tqdm(range(num_episodes)):
        episode_reward = 0
        episode_losses = []
        env.reset()
        infos = {}
        random_iters = random.sample(range(0, 80), 9)
        for agent_id in env.agent_iter():
            observation, reward, termination, truncation, info = env.last()
            # Handle terminated or truncated agents
            if termination or truncation:
                action = None
            else:
                agent_handle = agent_id.split("_")[0]
                iter = int(agent_id.split("_")[1])
                #random_iter = random.randint(0, 80)
                if agent_handle == 'blue' and iter == 0:
                    # Get action from our trained agent
                    action = agent.act(observation)
                    # Store experience in replay buffer
                    #print('len',len(env.agents))
                    if len(env.agents) > 0:  # Make sure we have a valid previous state
                        prev_observation = infos.get('prev_observation')
                        if prev_observation is not None:
                            agent.remember(prev_observation,
                                         infos.get('prev_action'),
                                         reward,
                                         observation,
                                         termination or truncation)
                    infos.update({
                        'prev_observation': observation,
                        'prev_action': action
                      })
                    # Train the agent
                    if len(agent.memory) == 5000:
                      print(len(agent.memory))
                    if len(agent.memory) > batch_size:
                        loss = agent.replay()
                        if loss is not None:
                            episode_losses.append(loss)

                    episode_reward += reward
                elif agent_handle == 'red':
                    # Random actions for blue team
                    action = env.action_space(agent_id).sample()

            env.step(action)
        # Save model if we got better results
        if reward > best_reward:
          best_reward = episode_reward
          torch.save(agent.q_network.state_dict(), 'blue.pt')

In [14]:
train(num_episodes=100)

  states = torch.Tensor(states)
  5%|▌         | 5/100 [03:01<55:24, 34.99s/it]

5000


100%|██████████| 100/100 [1:03:36<00:00, 38.16s/it]
