## Install the required libraries and repository

In [1]:
!pip install git+https://github.com/Farama-Foundation/MAgent2
!git clone https://github.com/giangbang/RL-final-project-AIT-3007.git

Collecting git+https://github.com/Farama-Foundation/MAgent2
  Cloning https://github.com/Farama-Foundation/MAgent2 to /tmp/pip-req-build-da1gasp_
  Running command git clone --filter=blob:none --quiet https://github.com/Farama-Foundation/MAgent2 /tmp/pip-req-build-da1gasp_
  Resolved https://github.com/Farama-Foundation/MAgent2 to commit b2ddd49445368cf85d4d4e1edcddae2e28aa1406
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting pygame>=2.1.0 (from magent2==0.3.3)
  Downloading pygame-2.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading pygame-2.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (14.0 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 14.0/14.0 MB 89.6 MB/s eta 0:00:00
Building wheels for collected packages

In [2]:
import sys
sys.path.append('/kaggle/working/RL-final-project-AIT-3007')

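## Model definition

- **The model has two main parts**: a *CNN* and a *Fully Connected* (FC) head.

### 1. **CNN (Convolutional Neural Network)**:
   - Consists of **3 `Conv2d` layers**, each with:
     - **Output channels:** 13.
     - **Kernel size:** 3 (the first two layers use `padding=1`; the third has no padding, so it shrinks the feature map by 2 in each spatial dimension).
     - **Non-linear activation:** ReLU.

### 2. **Fully Connected (FC)**:
   - Consists of **3 linear (`Linear`) layers**:
     - Layer 1: output size **256**.
     - Layer 2: output size **128**.
     - Layer 3: output size **`action_shape`** (the number of discrete actions).
   - The first two linear layers are each followed by a **ReLU** activation.
   - The final `Linear` layer **projects** the 128-dimensional **features** onto one Q-value per action and has no activation.
   - The input size of the first `Linear` layer is computed by passing a dummy observation through the CNN and flattening the result.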
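As a sanity check on the layer sizes, the sketch below redoes the shape arithmetic in plain Python. It assumes the default `battle_v4` observation shape `(13, 13, 5)` (height, width, channels) and 21 discrete actions; both numbers are assumptions here, since the notebook reads them from `env.observation_space("red_0")` and `env.action_space("red_0")` at runtime. The helper names `conv2d_out` and `qnetwork_shapes` are hypothetical.

```python
def conv2d_out(size: int, kernel: int, padding: int = 0, stride: int = 1) -> int:
    """Spatial output size of a Conv2d along one dimension (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def qnetwork_shapes(obs_shape=(13, 13, 5), n_actions=21, channels=13):
    """Trace MyQNetwork's layer sizes and total parameter count.

    obs_shape and n_actions are assumed defaults, not values read from the env.
    """
    h, w, c_in = obs_shape
    # Conv layers as (kernel, padding): the first two keep the size,
    # the third (no padding) shrinks it by 2.
    convs = [(3, 1), (3, 1), (3, 0)]
    params = 0
    c = c_in
    for kernel, padding in convs:
        h = conv2d_out(h, kernel, padding)
        w = conv2d_out(w, kernel, padding)
        params += c * channels * kernel * kernel + channels  # weights + biases
        c = channels
    flatten_dim = c * h * w
    # Fully connected head: flatten_dim -> 256 -> 128 -> n_actions.
    dims = [flatten_dim, 256, 128, n_actions]
    for d_in, d_out in zip(dims, dims[1:]):
        params += d_in * d_out + d_out  # weights + biases
    return (c, h, w), flatten_dim, params


print(qnetwork_shapes())
# ((13, 11, 11), 1573, 442215)
```

Under these assumptions the first `Linear` layer receives 13 × 11 × 11 = 1573 features and holds about 91% of the network's parameters, so it dominates the model size.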


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
from collections import deque, Counter
import os
from magent2.environments import battle_v4
import time

class MyQNetwork(nn.Module):
    def __init__(self, observation_shape, action_shape):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(observation_shape[-1], 13, kernel_size=3, padding=1, stride=1),
            nn.ReLU(),
            nn.Conv2d(13, 13, kernel_size=3, padding=1, stride=1),
            nn.ReLU(),
            nn.Conv2d(13, 13, kernel_size=3),
            nn.ReLU(),
        )
        
        dummy_input = torch.randn(observation_shape).permute(2, 0, 1)
        dummy_output = self.cnn(dummy_input)
        flatten_dim = dummy_output.view(-1).shape[0]
        
        self.fc = nn.Sequential(
            nn.Linear(flatten_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, action_shape)
        )

    def forward(self, x):
        assert len(x.shape) >= 3, "only support magent input observation"
        out = self.cnn(x)
        
        if len(x.shape) == 3:
            batchsize = 1
            
        else:
            batchsize = x.shape[0]
            
        out = out.reshape(batchsize, -1)
        
        return self.fc(out)

class ReplayBuffer(Dataset):
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = deque(maxlen=capacity)
        
    def add(self, state, action, reward, next_state, done):
        experience = (state, action, reward, next_state, done)
        self.buffer.append(experience)

    def __len__(self):
        return len(self.buffer)

    def __getitem__(self, index):
        return self.buffer[index]

## Training method

### **Algorithm**: DQN (Deep Q-Network)
- `q_network` is updated once per episode: after the episode ends, a `DataLoader` is built over the whole replay buffer and `update_model` makes one pass over it in mini-batches.
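For each mini-batch, the `update_model` method regresses $Q(s, a)$ onto the standard DQN target $y = r + \gamma (1 - d) \max_{a'} Q_{\text{target}}(s', a')$ with an MSE loss. A NumPy sketch of that computation on a made-up toy batch, with hypothetical arrays standing in for network outputs:

```python
import numpy as np


def dqn_targets(rewards, next_q_target, dones, gamma=0.9):
    """TD targets y = r + gamma * (1 - done) * max_a' Q_target(s', a').

    next_q_target has shape (batch, n_actions): target-network Q-values for s'.
    """
    rewards = np.asarray(rewards, dtype=np.float32)
    dones = np.asarray(dones, dtype=np.float32)
    max_next_q = np.asarray(next_q_target, dtype=np.float32).max(axis=1)
    return rewards + gamma * max_next_q * (1.0 - dones)


def mse_loss(q_taken, targets):
    """Mean squared error between Q(s, a) for the taken actions and the targets."""
    diff = np.asarray(q_taken, dtype=np.float32) - targets
    return float(np.mean(diff ** 2))


# Toy batch: two transitions, three actions each (made-up numbers).
q_values = np.array([[0.5, 1.0, -0.2], [0.1, 0.0, 0.3]], dtype=np.float32)
actions = np.array([1, 2])
# Pick Q(s, a) for the taken actions (same role as .gather(1, actions) in PyTorch).
q_taken = q_values[np.arange(len(actions)), actions]
targets = dqn_targets(rewards=[1.0, -1.0],
                      next_q_target=[[2.0, 0.0, 1.0], [5.0, 4.0, 3.0]],
                      dones=[0.0, 1.0])
print(targets)
print(mse_loss(q_taken, targets))
```

The `(1 - d)` factor turns off bootstrapping on terminal transitions, so for the second transition the target is just its reward.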

### **Training procedure**:
1. **Add data to the ReplayBuffer**:
   - During each episode, the blue agents' transitions $(s, a, r, s', \text{done})$ are stored in the `ReplayBuffer`; the red agents act randomly and are not trained.
2. **Update `q_network`**:
   - After the episode, `update_model` is called to update the parameters of `q_network`.
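In PettingZoo's AEC API, the reward that `env.last()` returns on an agent's turn is the reward accumulated since that agent last acted, so a transition can only be completed on the agent's *next* turn. The notebook does this bookkeeping with the `prev_observation` / `prev_action` dictionaries. The sketch below replays the same idea on a hand-written, hypothetical sequence of turns (no environment involved); `Turn` and `collect_transitions` are illustrative names, not part of any library.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple


@dataclass
class Turn:
    agent: str
    observation: Any        # what env.last() returned for this agent
    reward: float           # reward earned since the agent's previous action
    termination: bool
    action: Optional[int]   # action chosen now (None once the agent is done)


def collect_transitions(turns: List[Turn]) -> List[Tuple]:
    """Pair each turn with the same agent's previous turn to form transitions."""
    prev_obs: Dict[str, Any] = {}
    prev_act: Dict[str, int] = {}
    transitions = []
    for t in turns:
        if t.agent in prev_obs:
            # (s, a, r, s', done): r and s' are the outcome of last turn's action.
            transitions.append((prev_obs[t.agent], prev_act[t.agent],
                                t.reward, t.observation, t.termination))
        if t.action is not None:
            prev_obs[t.agent] = t.observation
            prev_act[t.agent] = t.action
    return transitions


turns = [
    Turn("blue_0", "o0", 0.0, False, 3),     # first turn: nothing to store yet
    Turn("blue_1", "p0", 0.0, False, 5),
    Turn("blue_0", "o1", -0.1, False, 7),    # completes ("o0", 3, -0.1, "o1", False)
    Turn("blue_1", "p1", -1.0, True, None),  # blue_1 died: terminal transition
]
for tr in collect_transitions(turns):
    print(tr)
```

Recording the terminal turn as a transition, instead of only the turns where the agent still chooses an action, is what allows the `done` flag to be 1 in the buffer.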

### **Hyperparameters**:
- **`steplr`** (learning-rate scheduler):
  - The learning rate (initially `1e-3`) is multiplied by `gamma = 0.9` every `update_target_every` episodes, i.e. each time the target network is synchronized.
- **Discount factor**: `0.9`.
- **Epsilon ($\epsilon$) in $\epsilon$-greedy**:
  - Multiplied by `0.96` after every episode.
  - Never drops below `0.1`.
- **Updating the `target` model**:
  - The `target` network is synchronized with `q_network` every `update_target_every = 2` episodes.
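The two schedules above can be checked without running the environment. This plain-Python sketch (hypothetical helper names) computes the ε used during episode `k` and the learning rate printed on that episode's log line, assuming the `Trainer` defaults `epsilon_decay = 0.96`, `epsilon_min = 0.1`, `learning_rate = 1e-3`, StepLR `gamma = 0.9` and `update_target_every = 2`.

```python
import math


def epsilon_at(episode, start=1.0, decay=0.96, minimum=0.1):
    """Epsilon used during `episode` (decayed once per finished episode)."""
    return max(start * decay ** episode, minimum)


def lr_logged_at(episode, base_lr=1e-3, gamma=0.9, update_every=2):
    """Learning rate printed for `episode`.

    StepLR.step() runs once per target sync, i.e. after episodes 1, 3, 5, ...,
    and the log line is printed after that step.
    """
    return base_lr * gamma ** ((episode + 1) // update_every)


def first_episode_at_floor(decay=0.96, minimum=0.1):
    """First episode whose epsilon is clamped to the minimum."""
    return math.ceil(math.log(minimum) / math.log(decay))


for k in range(4):
    print(k, f"{epsilon_at(k):.2f}", lr_logged_at(k))
print(first_episode_at_floor())  # 57
print(lr_logged_at(99))          # roughly 5.15e-06
```

These values agree (up to floating-point rounding) with the ε and lr printed in this notebook's training log. With these settings ε reaches its floor of 0.1 at episode 57, while the learning rate has fallen to roughly 5e-6 by the last of the default 100 episodes.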


In [4]:
import torch
import torch.optim.lr_scheduler as lr_scheduler
class Trainer:
    def __init__(self, env, config_qnet=None, input_shape=None, action_shape=None, learning_rate=1e-3):
        self.env = env
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.q_network = MyQNetwork(input_shape, action_shape).to(self.device)
        
        self.target_network = MyQNetwork(input_shape, action_shape).to(self.device)
        self.target_network.load_state_dict(self.q_network.state_dict())

        self.criterion = nn.MSELoss()
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=learning_rate)
        self.steplr = lr_scheduler.StepLR(optimizer=self.optimizer, step_size=1, gamma=0.9)
        self.replay_buffer = ReplayBuffer(capacity=16200 * 10)

        self.gamma = 0.9
        self.epsilon = 1.0
        self.epsilon_min = 0.1
        self.epsilon_decay = 0.96
        self.update_target_every = 2

    def select_action(self, observation, agent):
        if np.random.rand() <= self.epsilon:
            return self.env.action_space(agent).sample()

        observation = (
            torch.FloatTensor(observation).unsqueeze(0).to(self.device)
        )
        self.q_network.eval()
        with torch.inference_mode():
            q_values = self.q_network(observation)
            # print(q_values)
        return torch.argmax(q_values, dim=1).item()

    # def pretrained_action(self, observation):
    #     observation = (
    #         torch.FloatTensor(observation).unsqueeze(0).to(self.device)
    #     )
    #     self.red_pretrained_network.eval()
    #     with torch.inference_mode():
    #         q_values = self.red_pretrained_network(observation)
    #     return torch.argmax(q_values, dim=1).item()

    def training(self, episodes=100, batch_size=2 ** 12):        
        for episode in range(episodes):
            self.env.reset()
            
            total_reward = 0
            reward_for_agent = {agent: 0 for agent in self.env.agents if agent.startswith('blue')}
            prev_observation = {}
            prev_action = {}
            step = 0

            for idx, agent in enumerate(self.env.agent_iter()):
                step += 1
                observation, reward, termination, truncation, info = self.env.last()
                observation = np.transpose(observation, (2, 0, 1))
                
                agent_handle = agent.split('_')[0]
                
                if agent_handle == 'blue':
                    total_reward += reward
                    reward_for_agent[agent] += reward
                    
                if termination or truncation:
                    action = None
                else:
                    if agent_handle == 'blue':
                        action = self.select_action(observation, agent)
                    else:
                        action = self.env.action_space(agent).sample()
                        # action = self.pretrained_action(observation)

                if agent_handle == 'blue':
                    prev_observation[agent] = observation
                    prev_action[agent] = action
                
                self.env.step(action)
                
                if (idx + 1) % self.env.num_agents == 0:
                    break
                
            for agent in self.env.agent_iter():
                step += 1
                
                observation, reward, termination, truncation, info = self.env.last()
                observation = np.transpose(observation, (2, 0, 1))
                
                agent_handle = agent.split('_')[0]
                
                if agent_handle == 'blue':
                    total_reward += reward
                    reward_for_agent[agent] += reward
                    
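                if agent_handle == 'blue' and agent in prev_observation:
                    # The reward and observation returned by env.last() are the
                    # outcome of this agent's previous action, so the transition
                    # (s, a, r, s', done) is completed here. It is stored before
                    # the termination check so that terminal transitions
                    # (done = 1) also reach the replay buffer.
                    self.replay_buffer.add(
                        prev_observation[agent],
                        prev_action[agent],
                        reward,
                        observation,
                        termination,
                    )

                if termination or truncation:
                    action = None
                else:
                    if agent_handle == 'blue':
                        action = self.select_action(observation, agent)
                        prev_observation[agent] = observation
                        prev_action[agent] = action
                    else:
                        action = self.env.action_space(agent).sample()
                        # action = self.pretrained_action(observation)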
    
                self.env.step(action)
            
            dataloader = DataLoader(self.replay_buffer, batch_size=batch_size, shuffle=True, drop_last=True)
            self.update_model(dataloader)
                
            if (episode + 1) % self.update_target_every == 0:
                self.target_network.load_state_dict(self.q_network.state_dict())
                self.steplr.step()
    
            max_reward = max(reward_for_agent.values())
            
            print(f"Episode {episode}, Epsilon: {self.epsilon:.2f}, Total Reward: {total_reward}, Steps: {step}, Max Reward: {max_reward}, lr: {self.steplr.get_last_lr()} ")
            self.epsilon = max(self.epsilon * self.epsilon_decay, self.epsilon_min)

    def update_model(self, dataloader):
        self.q_network.train()
        for states, actions, rewards, next_states, dones in dataloader:
            # print(states.shape)

            states = states.to(dtype=torch.float32, device=self.device)
            actions = actions.to(dtype=torch.long, device=self.device)
            rewards = rewards.to(dtype=torch.float32, device=self.device)
            next_states = next_states.to(dtype=torch.float32, device=self.device)
            dones = dones.to(dtype=torch.float32, device=self.device)

            current_q_values = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
            with torch.inference_mode():
                next_q_values = self.target_network(next_states).max(1)[0]
            expected_q_values = rewards + (self.gamma * next_q_values * (1 - dones))

            loss = self.criterion(current_q_values, expected_q_values)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

## Start training

In [5]:
env = battle_v4.env(map_size=45, render_mode=None)

trainer = Trainer(env, config_qnet=None, input_shape=env.observation_space("red_0").shape, action_shape=env.action_space("red_0").n)
trainer.training()

Episode 0, Epsilon: 1.00, Total Reward: -3283.7001208886504, Steps: 159056, Max Reward: -32.00000128429383, lr: [0.001] 
Episode 1, Epsilon: 0.96, Total Reward: -3166.280114626512, Steps: 158571, Max Reward: -9.060000314377248, lr: [0.0009000000000000001] 
Episode 2, Epsilon: 0.92, Total Reward: -2992.7851090747863, Steps: 156817, Max Reward: -11.125000424683094, lr: [0.0009000000000000001] 
Episode 3, Epsilon: 0.88, Total Reward: -2891.8201054576784, Steps: 156284, Max Reward: -18.520000678487122, lr: [0.0008100000000000001] 
Episode 4, Epsilon: 0.85, Total Reward: -2851.5701018059626, Steps: 156379, Max Reward: -19.400001253932714, lr: [0.0008100000000000001] 
Episode 5, Epsilon: 0.82, Total Reward: -2599.220096149482, Steps: 149153, Max Reward: -0.0700002321973443, lr: [0.000729] 
Episode 6, Epsilon: 0.78, Total Reward: -2719.7000952856615, Steps: 155466, Max Reward: -24.90000123810023, lr: [0.000729] 
Episode 7, Epsilon: 0.75, Total Reward: -2440.2800892386585, Steps: 141278, Max R

## Save the model

In [6]:
os.makedirs("models", exist_ok=True)
torch.save(trainer.q_network.state_dict(), "models/blue_resnet_vs_random.pt")
print("Training complete. Model saved.")

Training complete. Model saved.
