In [73]:
#!pip install pettingzoo


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
from magent2.environments import battle_v4
from tqdm import tqdm



In [3]:
class QNetwork_pre(nn.Module):
    """Convolutional Q-network whose layout matches the pretrained checkpoint.

    Two 3x3 convolutions (channel count preserved) feed a three-layer MLP that
    emits one Q-value per action. Sub-module names and ``nn.Sequential``
    positions define the state-dict keys, so they must stay as they are for
    previously saved weights to load.

    Args:
        observation_shape: Channel-last observation shape ``(H, W, C)``.
        action_shape: Number of discrete actions (size of the output layer).
    """

    def __init__(self, observation_shape, action_shape):
        super().__init__()
        channels = observation_shape[-1]
        self.cnn = nn.Sequential(
            nn.Conv2d(channels, channels, 3),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3),
            nn.ReLU(),
        )
        # Push one random channel-first sample through the conv stack to find
        # out how many features the first linear layer receives.
        probe = torch.randn(observation_shape).permute(2, 0, 1)
        flatten_dim = self.cnn(probe).numel()
        self.network = nn.Sequential(
            nn.Linear(flatten_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, action_shape),
        )

    def forward(self, x):
        """Return Q-values of shape ``(batch, action_shape)``.

        ``x`` is channel-first: either ``(C, H, W)`` for a single observation
        (treated as a batch of one) or ``(B, C, H, W)``.
        """
        assert x.dim() >= 3, "only support magent input observation"
        features = self.cnn(x)
        # An unbatched input yields a 3-D feature map; flatten it to one row.
        rows = 1 if features.dim() == 3 else features.shape[0]
        return self.network(features.reshape(rows, -1))


In [6]:

class QNetwork(nn.Module):
    """Convolutional Q-network with a separately exposed output layer.

    Two 3x3 convolutions (channel count preserved) are followed by a
    Linear-ReLU-Linear-Tanh trunk producing an 84-dimensional latent, and a
    final linear head ``last_layer`` mapping that latent to one Q-value per
    action. No normalization layers are used.

    Args:
        observation_shape: Channel-last observation shape ``(H, W, C)``.
        action_shape: Number of discrete actions (size of the output layer).
    """

    def __init__(self, observation_shape, action_shape):
        super().__init__()
        channels = observation_shape[-1]
        self.cnn = nn.Sequential(
            nn.Conv2d(channels, channels, 3),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3),
            nn.ReLU(),
        )
        # Push one random channel-first sample through the conv stack to find
        # out how many features the first linear layer receives.
        probe = torch.randn(observation_shape).permute(2, 0, 1)
        flatten_dim = self.cnn(probe).numel()
        self.network = nn.Sequential(
            nn.Linear(flatten_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.Tanh(),
        )
        self.last_layer = nn.Linear(84, action_shape)

    def forward(self, x):
        """Return Q-values of shape ``(batch, action_shape)``.

        ``x`` is channel-first: either ``(C, H, W)`` for a single observation
        (treated as a batch of one) or ``(B, C, H, W)``. The trunk output of
        the most recent call is kept on ``self.last_latent``.
        """
        assert x.dim() >= 3, "only support magent input observation"
        features = self.cnn(x)
        # An unbatched input yields a 3-D feature map; flatten it to one row.
        rows = 1 if features.dim() == 3 else features.shape[0]
        latent = self.network(features.reshape(rows, -1))
        self.last_latent = latent
        return self.last_layer(latent)
class DQNAgent:
    """Epsilon-greedy DQN learner with a replay buffer and a target network.

    Args:
        observation_shape: Channel-last observation shape ``(H, W, C)``.
        action_shape: Number of discrete actions.
        target_update_freq: Number of gradient steps between hard copies of the
            online weights into the target network. Values below 1 are clamped
            to 1 (copy after every step).
    """

    def __init__(self, observation_shape, action_shape, target_update_freq=1000):
        self.q_network = QNetwork(observation_shape, action_shape)
        self.target_network = QNetwork(observation_shape, action_shape)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.memory = deque(maxlen=10000)
        self.batch_size = 32
        self.gamma = 0.99
        self.epsilon = 0.1
        self.learning_rate = 1e-3
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=self.learning_rate)
        self.target_update_freq = max(1, int(target_update_freq))
        # Gradient steps taken so far; drives the periodic target sync.
        self.learn_steps = 0

    def act(self, observation):
        """Pick an action for one channel-last ``(H, W, C)`` observation.

        With probability ``self.epsilon`` a uniformly random action is
        returned; otherwise the action with the highest predicted Q-value.

        Returns:
            int: Index of the selected action.
        """
        if random.random() < self.epsilon:
            # Explore: no forward pass needed, so skip building the tensor.
            return random.randrange(self.q_network.last_layer.out_features)

        # Exploit: convert to a (1, C, H, W) float tensor for the CNN.
        state = (
            torch.as_tensor(np.asarray(observation), dtype=torch.float32)
            .permute(2, 0, 1)
            .unsqueeze(0)
        )
        with torch.no_grad():
            q_values = self.q_network(state)
        return torch.argmax(q_values, dim=1).item()

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def replay(self):
        """Run one gradient step on a uniformly sampled minibatch.

        Returns:
            float | None: The MSE TD loss, or ``None`` when the buffer holds
            fewer than ``self.batch_size`` transitions (no update happens).
        """
        if len(self.memory) < self.batch_size:
            return None

        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        # Stack into single ndarrays first: building a tensor directly from a
        # sequence of ndarrays goes through a slow element-wise path.
        states = torch.as_tensor(np.stack(states), dtype=torch.float32).permute(0, 3, 1, 2)
        next_states = torch.as_tensor(np.stack(next_states), dtype=torch.float32).permute(0, 3, 1, 2)
        actions = torch.as_tensor(np.asarray(actions), dtype=torch.long)
        rewards = torch.as_tensor(np.asarray(rewards, dtype=np.float32))
        dones = torch.as_tensor(np.asarray(dones, dtype=np.float32))

        # Bootstrapped target from the (lagging) target network; terminal
        # transitions contribute only their immediate reward.
        with torch.no_grad():
            max_target_q_values = self.target_network(next_states).max(dim=1)[0]
            target = rewards + self.gamma * max_target_q_values * (1 - dones)

        # Q-value the online network assigns to the action actually taken.
        q_value = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        loss = nn.functional.mse_loss(q_value, target)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Sync the target network only every `target_update_freq` steps.
        # Copying after every step would make it identical to the online
        # network and remove the stabilising effect of a lagging target.
        self.learn_steps += 1
        if self.learn_steps % self.target_update_freq == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())

        return loss.item()

In [7]:
def train(num_episodes=100, batch_size=32):
    """Train one shared DQN policy for the blue team of ``battle_v4``.

    Every blue agent acts through the same ``DQNAgent`` and writes into the
    same replay buffer; red agents act uniformly at random. After each episode
    the online network is saved to ``blue.pt`` if the summed blue reward beat
    the best episode so far.

    Args:
        num_episodes: Number of full environment episodes to play.
        batch_size: Minibatch size used for each replay/gradient step.
    """
    env = battle_v4.env(map_size=45, render_mode="rgb_array")
    observation_shape = env.observation_space("red_0").shape
    action_shape = env.action_space("red_0").n
    agent = DQNAgent(observation_shape, action_shape)
    # Make the function argument actually control the replay minibatch size.
    agent.batch_size = batch_size
    best_reward = float('-inf')

    # NOTE(review): the pretrained red checkpoint is loaded (so a missing or
    # incompatible 'red.pt' still fails up front) but is not used as the red
    # policy; red samples random actions below. Confirm which is intended.
    q_network_red = QNetwork_pre(observation_shape, action_shape)
    q_network_red.load_state_dict(
        torch.load('red.pt', weights_only=True, map_location="cpu")
    )

    progress = tqdm(range(num_episodes))
    try:
        for episode in progress:
            episode_reward = 0
            episode_losses = []
            env.reset()
            # agent_id -> (observation, action) whose outcome is not yet known.
            # In the AEC loop the result of an agent's action only becomes
            # visible the next time that same agent is selected.
            pending = {}

            for agent_id in env.agent_iter():
                observation, reward, termination, truncation, info = env.last()
                done = termination or truncation
                if observation is not None:
                    observation = np.array(observation)
                is_blue = agent_id.split("_")[0] == 'blue'

                if is_blue:
                    episode_reward += reward
                    previous = pending.pop(agent_id, None)
                    if previous is not None and observation is not None:
                        # `reward`/`observation`/`done` are the consequences of
                        # this agent's previous action: store that transition.
                        prev_observation, prev_action = previous
                        agent.remember(prev_observation, prev_action,
                                       reward, observation, done)
                        loss = agent.replay()
                        if loss is not None:
                            episode_losses.append(loss)

                if done:
                    # Finished agents must be stepped with None.
                    action = None
                elif is_blue:
                    action = agent.act(observation)
                    pending[agent_id] = (observation, action)
                else:
                    action = env.action_space(agent_id).sample()

                # Exactly one step per selected agent.
                env.step(action)

            stats = {'reward': float(episode_reward)}
            if episode_losses:
                stats['loss'] = float(np.mean(episode_losses))
            progress.set_postfix(stats)

            # Compare the whole-episode return, not the last agent's reward.
            if episode_reward > best_reward:
                best_reward = episode_reward
                torch.save(agent.q_network.state_dict(), 'blue.pt')
    finally:
        env.close()

In [9]:
# Run the training loop for 50 episodes.
train(num_episodes=50)

100%|██████████| 50/50 [06:08<00:00,  7.37s/it]
