In [1]:
# !pip install -q git+https://github.com/Farama-Foundation/MAgent2

In [2]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from collections import deque
import random
import torch.nn as nn
from magent2.environments import battle_v4
from torch.utils.data import Dataset, DataLoader
from time import time 

# Q Networks
- Phần này chứa cài đặt của các mạng Q 


In [3]:
def kaiming_init(m):
    """
    Apply Kaiming-normal initialization (ReLU gain) to Conv2d / Linear layers.

    Meant to be passed to ``module.apply``. Weights are re-drawn with
    ``kaiming_normal_`` and biases (when present) are set to zero; modules of
    any other type are left untouched.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
    if m.bias is not None:
        nn.init.zeros_(m.bias)

## Pretrained Networks
- Trong phần này, chúng tôi cài đặt lại các mạng được sử dụng trong mô hình pretrained (`red.pt`) và final_pretrained (`red_final.pt`)

In [4]:
"""
Đây là kiến trúc pretrained được sử dụng cho red.pt 
"""


class PretrainedQNetwork(nn.Module):
    """Q-network whose layout matches the pretrained ``red.pt`` checkpoint.

    Two unpadded 3x3 convolutions (channel count preserved), each followed by
    ReLU, then an MLP: flatten -> 120 -> ReLU -> 84 -> ReLU -> ``action_shape``.

    ``observation_shape`` is channel-last (H, W, C); ``forward`` expects
    channel-first input, either one (C, H, W) sample or a (B, C, H, W) batch.
    """

    def __init__(self, observation_shape, action_shape):
        super().__init__()
        n_channels = observation_shape[-1]
        self.cnn = nn.Sequential(
            nn.Conv2d(n_channels, n_channels, 3),
            nn.ReLU(),
            nn.Conv2d(n_channels, n_channels, 3),
            nn.ReLU(),
        )
        # Push one random (C, H, W) sample through the conv stack to learn the flattened width.
        sample = torch.randn(observation_shape).permute(2, 0, 1)
        flatten_dim = self.cnn(sample).numel()
        self.network = nn.Sequential(
            nn.Linear(flatten_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, action_shape),
        )
        # Default PyTorch initialization is kept; kaiming_init is not applied here.

    def forward(self, x):
        """Return Q-values of shape (batch, action_shape); an unbatched input gives batch 1."""
        assert len(x.shape) >= 3
        conv_out = self.cnn(x)
        if conv_out.dim() == 3:
            flat = conv_out.reshape(1, -1)
        else:
            flat = conv_out.reshape(conv_out.shape[0], -1)
        return self.network(flat)

In [5]:
"""
Đây là kiến trúc được cài đặt trong final_red.pt 
"""


class Final_QNets(nn.Module):
    """Q-network whose layout matches the ``red_final.pt`` checkpoint.

    Two unpadded 3x3 convolutions (channel count preserved) feed an MLP trunk
    (flatten -> 120 -> ReLU -> 84 -> Tanh) and a separate linear head
    ``last_layer``. The 84-d trunk output of the most recent forward pass is
    stored on ``self.last_latent``.
    """

    def __init__(self, observation_shape, action_shape):
        super().__init__()
        channels = observation_shape[-1]
        self.cnn = nn.Sequential(
            nn.Conv2d(channels, channels, 3),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3),
            nn.ReLU(),
        )
        # Probe with one (H, W, C) -> (C, H, W) random sample to size the first Linear.
        probe = torch.randn(observation_shape).permute(2, 0, 1)
        flatten_dim = self.cnn(probe).numel()
        # Sequential indices (Linear at 0 and 2) determine the state_dict keys;
        # inserting layers such as LayerNorm here would shift them.
        self.network = nn.Sequential(
            nn.Linear(flatten_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.Tanh(),
        )
        self.last_layer = nn.Linear(84, action_shape)

    def forward(self, x):
        """Return Q-values of shape (batch, action_shape) and cache the trunk output."""
        assert len(x.shape) >= 3, "only support magent input observation"
        features = self.cnn(x)
        batch = 1 if features.dim() == 3 else features.shape[0]
        latent = self.network(features.reshape(batch, -1))
        self.last_latent = latent
        return self.last_layer(latent)

In [6]:
"""
Đây là kiến trúc mạng sử đổi (được sử dụng trong thí nghiệm 2)
"""

class MyQNetwork(nn.Module):
    """Modified Q-network (experiment 2); used as the per-agent network of QMIXAgent.

    Layout: two unpadded 3x3 convolutions (channel count preserved) with ReLU,
    then flatten -> 120 -> ReLU -> 84 -> ReLU -> ``action_shape``. PyTorch's
    default initialization is used (kaiming_init is not applied).
    """

    def __init__(self, observation_shape, action_shape):
        super().__init__()
        c = observation_shape[-1]
        conv_layers = [nn.Conv2d(c, c, 3), nn.ReLU(), nn.Conv2d(c, c, 3), nn.ReLU()]
        self.cnn = nn.Sequential(*conv_layers)
        # Size the MLP input from a single random channel-first sample.
        probe_out = self.cnn(torch.randn(observation_shape).permute(2, 0, 1))
        head_layers = [
            nn.Linear(probe_out.numel(), 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, action_shape),
        ]
        self.network = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map (C, H, W) or (B, C, H, W) input to Q-values of shape (batch, action_shape)."""
        assert len(x.shape) >= 3
        h = self.cnn(x)
        rows = 1 if h.dim() == 3 else h.shape[0]
        return self.network(h.reshape(rows, -1))

## QMix Networks 

In [7]:
"""
Đây là kiến trúc mạng dùng chung giữa các Agent trong thuật toán QMix - tương tự như mạng Pretrained 
"""

class SharedQNetwork(nn.Module):
    """Per-agent Q-network shared by all agents in QMix (similar to the pretrained net).

    Differences from the pretrained network: both 3x3 convolutions are padded
    (spatial size preserved), the MLP head is 256 -> 128, and all Conv2d/Linear
    layers are Kaiming-initialized.

    ``forward`` accepts channel-last observations, either a single (H, W, C)
    sample or a (B, H, W, C) batch, and returns Q-values of shape (B, action_shape).
    """

    def __init__(self, observation_shape, action_shape):
        super(SharedQNetwork, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(observation_shape[-1], observation_shape[-1], 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(observation_shape[-1], observation_shape[-1], 3, padding=1),
            nn.ReLU(),
        )

        # Channel-first dummy batch used only to size the first Linear layer.
        dummy_input = torch.randn(1, observation_shape[-1], observation_shape[0], observation_shape[1])
        dummy_output = self.cnn(dummy_input)
        flatten_dim = dummy_output.view(-1).shape[0]

        self.network = nn.Sequential(
            nn.Linear(flatten_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, action_shape)
        )
        self.apply(kaiming_init)

    def forward(self, x):
        """Return Q-values for channel-last input.

        Raises:
            ValueError: if ``x`` is neither 3-D (H, W, C) nor 4-D (B, H, W, C).
        """
        if x.dim() == 4:
            x = x.permute(0, 3, 1, 2)  # [B, H, W, C] -> [B, C, H, W]
        elif x.dim() == 3:
            x = x.permute(2, 0, 1).unsqueeze(0)  # [H, W, C] -> [1, C, H, W]
        else:
            # Previously other ranks reached the conv stack unpermuted and failed obscurely.
            raise ValueError(
                f"SharedQNetwork expects a 3-D (H, W, C) or 4-D (B, H, W, C) input, got shape {tuple(x.shape)}"
            )

        x = self.cnn(x)
        x = x.reshape(x.shape[0], -1)  # flatten per sample
        return self.network(x)

In [8]:
"""
Đây là cài đặt kiến trúc mạng Mixing trong thuật toán QMix
"""


class MixingNetwork(nn.Module):
    """QMix mixing network: combines per-agent Q-values into a single Q_total.

    A small CNN reduces each agent's observation to one scalar; those scalars
    are averaged over agents into a 1-d "global state" that drives
    hyper-networks producing the mixing weights and biases. The weights go
    through ``abs`` so Q_total is non-decreasing in every agent's Q-value
    (the QMix monotonicity constraint).

    Args:
        num_agents: number of agent slots N the mixer is built for.
        embed_dim: width of the hidden mixing layer.
        channels, height, width: per-agent observation size (C, H, W) seen by the CNN.
    """

    def __init__(self, num_agents, embed_dim=32, channels=5, height=13, width=13):
        super(MixingNetwork, self).__init__()
        self.num_agents = num_agents
        self.embed_dim = embed_dim

        # State feature extractor; both convs keep the channel count, each MaxPool
        # halves H and W (floor), so the output is [B*N, C, H//4, W//4].
        self.cnn = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),  # [B*N, C, H, W]
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),  # [B*N, C, H//2, W//2]
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),  # [B*N, C, H//2, W//2]
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),  # [B*N, C, H//4, W//4]
        )
        cnn_out_dim = (height // 4) * (width // 4) * channels  # flattened CNN output size
        self.fc_state = nn.Linear(cnn_out_dim, 1)  # one scalar state feature per agent

        # Hyper-networks conditioned on the 1-d global state.
        self.hyper_w1 = nn.Linear(1, num_agents * embed_dim)
        self.hyper_b1 = nn.Linear(1, embed_dim)
        self.hyper_w2 = nn.Linear(1, embed_dim)
        self.hyper_b2 = nn.Linear(1, 1)

    def forward(self, agent_qs, states):
        """Mix per-agent Q-values.

        Args:
            agent_qs: [B, N] Q-value of each agent's chosen action.
            states: [B, N, C, H, W] per-agent observations (may be non-contiguous,
                e.g. the result of a permute).

        Returns:
            [B, 1] mixed Q_total.

        Raises:
            ValueError: on wrong ranks, mismatched leading dims, or N != num_agents.
        """
        if agent_qs.dim() != 2 or states.dim() != 5:
            raise ValueError(
                f"expected agent_qs [B, N] and states [B, N, C, H, W], got "
                f"{tuple(agent_qs.shape)} and {tuple(states.shape)}"
            )
        batch_size, num_agents = agent_qs.shape
        if num_agents != self.num_agents:
            raise ValueError(f"MixingNetwork built for {self.num_agents} agents, got {num_agents}")
        if tuple(states.shape[:2]) != (batch_size, num_agents):
            raise ValueError(
                f"states leading dims {tuple(states.shape[:2])} do not match agent_qs {tuple(agent_qs.shape)}"
            )

        # reshape (not view) so non-contiguous inputs are accepted.
        states = states.reshape(batch_size * num_agents, *states.shape[2:])

        states = self.cnn(states)  # [B*N, C, H//4, W//4]
        states = states.reshape(batch_size * num_agents, -1)  # [B*N, cnn_out_dim]
        states = self.fc_state(states)  # [B*N, 1]
        states = states.view(batch_size, num_agents, 1)  # [B, N, 1]

        global_state = states.mean(dim=1)  # mean over agents: [B, 1]

        # abs() on the weights enforces monotonicity of Q_total in agent_qs.
        w1 = torch.abs(self.hyper_w1(global_state)).view(batch_size, num_agents, self.embed_dim)
        b1 = self.hyper_b1(global_state).view(batch_size, 1, self.embed_dim)
        w2 = torch.abs(self.hyper_w2(global_state)).view(batch_size, self.embed_dim, 1)
        b2 = self.hyper_b2(global_state).view(batch_size, 1, 1)

        hidden = F.relu(torch.bmm(agent_qs.unsqueeze(1), w1) + b1)  # [B, 1, embed_dim]
        q_total = torch.bmm(hidden, w2) + b2  # [B, 1, 1]
        return q_total.squeeze(-1)  # [B, 1]

# Memory 

In [9]:
"""
Cài đặt replay buffer cho thuật toán Double Q Learning 
"""


# Replay Buffer
from torch.utils.data import Dataset, DataLoader 

class ReplayBuffer(Dataset):
    """Bounded FIFO store of single-agent transitions for DQN training.

    Can be consumed either as a map-style torch ``Dataset`` (e.g. through a
    ``DataLoader``) or via ``sample`` for a random numpy mini-batch. Once
    ``capacity`` transitions are held, each push evicts the oldest one.
    """

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Store one (state, action, reward, next_state, done) transition."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions and return them as stacked numpy arrays."""
        picked = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*picked)
        stacked_states = np.stack(states)
        stacked_next = np.stack(next_states)
        return (stacked_states, np.array(actions), np.array(rewards),
                stacked_next, np.array(dones))

    def __len__(self):
        return len(self.buffer)

    def __getitem__(self, idx):
        """Return transition ``idx`` as tensors; reward and done are cast to float32."""
        state, action, reward, next_state, done = self.buffer[idx]
        as_float = lambda value: torch.tensor(value, dtype=torch.float32)
        return (
            torch.tensor(state),
            torch.tensor(action),
            as_float(reward),
            torch.tensor(next_state),
            as_float(done),
        )

In [10]:
"""
Cài đặt StateMemory cho thuật toán QMix
"""

class StateMemory(Dataset):
    """Time-aligned transition memory for QMix, one deque per agent group.

    Each of the ``grouped_agents`` slots stores (state, action, reward,
    next_state, done) tuples, and entry ``t`` of every slot is treated as the
    same time step. ``ensemble()`` pads shorter slots with ``None``
    placeholders to keep them aligned. ``__getitem__`` and ``sample`` replace
    placeholders with sentinels: state/next_state filled with -1, action -1,
    reward 0.0 and done 1.0.

    NOTE(review): each deque evicts independently once it reaches ``capacity``,
    which can misalign slots that fill at different rates.
    """

    def __init__(self, capacity, num_agents = 162, grouped_agents = 18):
        self.capacity = capacity
        self.memory = [deque(maxlen=capacity) for _ in range(grouped_agents)]
        self.num_agents = num_agents  # stored only; not read inside this class
        self.grouped_agents = grouped_agents

    def push(self, idx, state, action, reward, next_state, done):
        """Append one transition to slot ``idx``."""
        self.memory[idx].append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct aligned time steps uniformly at random.

        Returns:
            (states, actions, rewards, next_states, dones) numpy arrays whose
            leading shape is (batch_size, grouped_agents); dtypes are
            float32 / int64 / float32 / float32 / float32.

        Raises:
            ValueError: if ``batch_size`` is not positive or exceeds ``len(self)``.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        # Sample time-step indices (the old version sampled whole deques and
        # could not unpack the result).
        steps = random.sample(range(len(self)), batch_size)
        per_step = [self._step_arrays(t) for t in steps]
        return tuple(np.stack(column) for column in zip(*per_step))

    def ensemble(self):
        """Pad every slot with ``None`` placeholders up to the longest slot's length."""
        lengths = [len(agent_memory) for agent_memory in self.memory]
        max_len = max(lengths)
        if max_len == min(lengths):
            return
        for agent_memory in self.memory:
            agent_memory.extend([(None, None, None, None, None)] * (max_len - len(agent_memory)))

    def __len__(self):
        """Number of time steps available in every slot (the shortest slot's length)."""
        return min([len(agent_memory) for agent_memory in self.memory])

    def _padding_template(self):
        """Return the first real (non-placeholder) state stored; it shapes the -1 padding."""
        for agent_memory in self.memory:
            for entry in agent_memory:
                if entry[0] is not None:
                    return entry[0]
        raise ValueError("StateMemory holds no real state to build padding from")

    def _step_arrays(self, idx):
        """Collect time step ``idx`` from every slot as numpy arrays, filling placeholders."""
        padding = None
        states, actions, rewards, next_states, dones = [], [], [], [], []

        for agent_memory in self.memory:
            state, action, reward, next_state, done = agent_memory[idx]

            if (state is None or next_state is None) and padding is None:
                # Template is the first real state anywhere, not slot 0's first
                # entry, which may itself be a placeholder.
                padding = np.full_like(self._padding_template(), fill_value=-1)

            states.append(state if state is not None else padding)
            actions.append(action if action is not None else -1)
            rewards.append(reward if reward is not None else 0.0)
            next_states.append(next_state if next_state is not None else padding)
            dones.append(done if done is not None else 1.0)

        return (
            np.array(states, dtype=np.float32),
            np.array(actions, dtype=np.int64),
            np.array(rewards, dtype=np.float32),
            np.array(next_states, dtype=np.float32),
            np.array(dones, dtype=np.float32),
        )

    def __getitem__(self, idx):
        """Return time step ``idx`` across all slots as tensors (float32 / long / float32 / float32 / float32)."""
        states, actions, rewards, next_states, dones = self._step_arrays(idx)
        return (
            torch.tensor(states, dtype=torch.float32),
            torch.tensor(actions, dtype=torch.long),
            torch.tensor(rewards, dtype=torch.float32),
            torch.tensor(next_states, dtype=torch.float32),
            torch.tensor(dones, dtype=torch.float32)
        )


# Agent

- Đây là module sử dụng để huấn luyện, thực hiện action dựa trên policy đã được huấn luyện 

In [11]:
class RandomAgent:
    """Baseline policy: ignores the observation and picks an action uniformly at random."""

    def __init__(self, action_space):
        # Number of discrete actions (the notebook passes ``action_space(...).n``).
        self.n_action = action_space

    def get_action(self, observation):
        """Return a uniform random action index in ``[0, n_action)`` as a Python int."""
        draw = torch.randint(low=0, high=self.n_action, size=(1,))
        return draw.item()

## Pretrained Agent 
- Cài đặt các agent với tham số được cho trước, dùng để eval()

In [12]:
class PretrainedAgent:
    """Greedy (no exploration) agent backed by a PretrainedQNetwork checkpoint.

    Args:
        n_observation: channel-last observation shape (H, W, C).
        n_actions: number of discrete actions.
        device: torch device string.
        model_path: checkpoint to load; defaults to the Kaggle ``red.pt`` location
            that was previously hard-coded.
    """

    def __init__(self, n_observation, n_actions, device="cpu",
                 model_path="/kaggle/input/pretrained/pytorch/default/1/red.pt"):
        self.device = torch.device(device)
        self.qnetwork = PretrainedQNetwork(n_observation, n_actions).to(self.device)
        self.n_action = n_actions
        self.qnetwork.load_state_dict(
            torch.load(model_path, weights_only=True, map_location=self.device)
        )
        self.qnetwork.eval()

    def get_action(self, observation):
        """Return the argmax-Q action for one channel-last (H, W, C) observation."""
        observation = (
            torch.Tensor(observation).float().permute([2, 0, 1]).unsqueeze(0).to(self.device)
        )
        with torch.no_grad():
            q_values = self.qnetwork(observation)
        action = torch.argmax(q_values, dim=1).cpu().numpy()[0]

        return action

In [13]:
class FinalAgent: 
    """Greedy (no exploration) agent backed by a Final_QNets checkpoint.

    Args:
        n_observation: channel-last observation shape (H, W, C).
        n_actions: number of discrete actions.
        device: torch device string.
        model_path: checkpoint to load; defaults to the Kaggle ``red_final.pt``
            location that was previously hard-coded.
    """

    def __init__(self, n_observation, n_actions, device = "cpu",
                 model_path="/kaggle/input/final_rl/pytorch/default/1/red_final.pt"):
        self.device = torch.device(device)
        self.n_action = n_actions  # same attribute as PretrainedAgent

        self.final_networks = Final_QNets(n_observation, n_actions).to(self.device)

        self.final_networks.load_state_dict(
            torch.load(model_path, weights_only = True, map_location = self.device)
        )
        self.final_networks.eval()

    def get_action(self, observation): 
        """Return the argmax-Q action for one channel-last (H, W, C) observation."""
        observation = (
            torch.Tensor(observation).float().permute([2, 0, 1]).unsqueeze(0).to(self.device)
        )
        with torch.no_grad():
            q_values = self.final_networks(observation)
        action = torch.argmax(q_values, dim=1).cpu().numpy()[0]

        return action
    

## Training Agent 
- Chứa các agent được cài đặt thuật toán Double Q Learning + QMix

In [14]:
class DQNAgent:
    """Epsilon-greedy DQN learner with a periodically synced target network.

    TD target: ``r + gamma * (1 - done) * max_a Q_target(s', a)``. This is DQN
    with a target network; the online network is not used to select the
    next action, so it is not the Double-DQN target.

    Args:
        observation_shape: channel-last observation shape (H, W, C).
        action_shape: number of discrete actions.
        batch_size: stored for callers; the DataLoader passed to ``train`` sets the real batch size.
        lr: Adam learning rate.
        gamma: discount factor.
        device: torch device string.
    """

    def __init__(self, observation_shape, action_shape, batch_size=64, lr=1e-3, gamma=0.6, device="cpu"):
        self.device = torch.device(device)
        self.q_net = PretrainedQNetwork(observation_shape, action_shape).float().to(self.device)
        self.target_net = PretrainedQNetwork(observation_shape, action_shape).float().to(self.device)
        self.target_net.load_state_dict(self.q_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)
        self.batch_size = batch_size
        self.gamma = gamma
        self.action_shape = action_shape
        self.epsilon = 1.0
        self.epsilon_decay = 0.97
        self.epsilon_min = 0.05
        self.loss_fn = nn.MSELoss()

    def get_action(self, observation):
        """Epsilon-greedy action for one channel-last (H, W, C) observation."""
        if np.random.rand() < self.epsilon:
            return np.random.randint(self.action_shape)
        else:
            state_tensor = torch.FloatTensor(observation).unsqueeze(0).permute(0, 3, 1, 2).to(self.device)
            with torch.no_grad():
                return self.q_net(state_tensor).argmax().item()

    def train(self, dataloader):
        """Run one pass over ``dataloader`` and update the online network.

        Each batch is (obs, action, reward, next_obs, done) with channel-last
        observations, as produced by ReplayBuffer.

        Returns:
            Mean MSE loss over the batches, or 0.0 if the dataloader is empty.
        """
        self.q_net.train()
        total_loss = torch.zeros((), device=self.device)
        n_batches = 0

        for obs, action, reward, next_obs, done in dataloader:
            obs = obs.float().permute(0, 3, 1, 2).to(self.device)
            next_obs = next_obs.float().permute(0, 3, 1, 2).to(self.device)
            # gather() requires int64 indices; stored actions may be int32 on some platforms.
            action = action.long().unsqueeze(1).to(self.device)
            reward = reward.float().unsqueeze(1).to(self.device)
            done = done.float().unsqueeze(1).to(self.device)

            with torch.no_grad():
                next_max_q = self.target_net(next_obs).max(1, keepdim=True)[0]
                target_q_values = reward + self.gamma * (1 - done) * next_max_q

            q_values = self.q_net(obs).gather(1, action)

            loss = self.loss_fn(q_values, target_q_values)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Accumulate on-device to avoid a host sync every batch.
            total_loss += loss.detach()
            n_batches += 1

        return (total_loss / n_batches).item() if n_batches else 0.0

    def update_target_network(self):
        """Copy the online network's weights into the target network."""
        self.target_net.load_state_dict(self.q_net.state_dict())

    def decay_epsilon(self):
        """Multiply epsilon by ``epsilon_decay``, never going below ``epsilon_min``."""
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)


In [15]:
class QMIXAgent:
    """QMix learner: one Q-network shared by all agents plus a monotonic mixing network.

    ``get_action`` is epsilon-greedy per agent. ``train`` fits the mixed
    Q_total to a TD target computed with target copies of both networks.

    Args:
        obs_shape: channel-last observation shape (H, W, C).
        num_agents: number of agent slots N per sample (must match the memory's grouped_agents).
        action_dim: number of discrete actions.
        lr: Adam learning rate (shared by both networks).
        gamma: discount factor.
        device: torch device string.
    """

    def __init__(self, obs_shape, num_agents, action_dim, lr=1e-3, gamma=0.99, device="cpu"):
        self.device = torch.device(device)
        self.num_agents = num_agents
        self.action_dim = action_dim

        self.q_net = MyQNetwork(obs_shape, action_dim).to(self.device)
        self.target_q_net = MyQNetwork(obs_shape, action_dim).to(self.device)
        self.target_q_net.load_state_dict(self.q_net.state_dict())

        self.mixing_net = MixingNetwork(num_agents).to(self.device)
        self.target_mixing_net = MixingNetwork(num_agents).to(self.device)
        self.target_mixing_net.load_state_dict(self.mixing_net.state_dict())

        self.target_q_net.eval()
        self.target_mixing_net.eval()

        self.optimizer = optim.Adam(list(self.q_net.parameters()) + list(self.mixing_net.parameters()), lr=lr)
        self.gamma = gamma
        self.loss_fn = nn.MSELoss()

        self.epsilon = 1.0
        self.epsilon_decay = 0.98
        self.epsilon_min = 0.05

    def load_pretrained(self):
        """Placeholder: currently a no-op (no weights are loaded)."""
        pass

    def get_action(self, observation):
        """Epsilon-greedy action for one channel-last (H, W, C) observation."""
        if np.random.rand() < self.epsilon:
            return np.random.randint(self.action_dim)
        else:
            state_tensor = torch.FloatTensor(observation).unsqueeze(0).permute(0, 3, 1, 2).to(self.device)
            with torch.no_grad():
                return self.q_net(state_tensor).argmax().item()

    def train(self, dataloader):
        """Run one pass over ``dataloader`` and update the Q-network and mixer jointly.

        Batches are expected as produced by StateMemory: obs/next_obs
        [B, N, H, W, C], action [B, N] (-1 marks padded slots), reward/done [B, N].

        Returns:
            Mean MSE loss over the batches, or 0.0 if the dataloader is empty.
        """
        self.q_net.train()
        self.mixing_net.train()
        total_loss = torch.zeros((), device=self.device)
        n_batches = 0

        for obs, action, reward, next_obs, done in dataloader:
            obs = obs.permute(0, 1, 4, 2, 3).to(self.device)  # [B, N, C, H, W]
            next_obs = next_obs.permute(0, 1, 4, 2, 3).to(self.device)  # [B, N, C, H, W]
            action = action.long().to(self.device)  # [B, N]
            reward = reward.to(self.device)  # [B, N]
            done = done.to(self.device)  # [B, N]
            batch_size, n_slots = obs.shape[0], obs.shape[1]

            # Padded slots carry action == -1; mask them out of the mixer inputs.
            alive_agent_mask = (action != -1).float()  # [B, N]

            obs_flat = obs.reshape(-1, *obs.shape[2:])  # [B*N, C, H, W]
            obs_q_values = self.q_net(obs_flat).view(batch_size, n_slots, -1)  # [B, N, action_dim]
            # clamp keeps gather in range for -1 actions; those entries are zeroed by the mask.
            current_q_values = obs_q_values.gather(-1, action.unsqueeze(-1).clamp(min=0)).squeeze(-1)  # [B, N]
            current_q_values = current_q_values * alive_agent_mask

            with torch.no_grad():
                next_obs_flat = next_obs.reshape(-1, *next_obs.shape[2:])  # [B*N, C, H, W]
                next_q_values = self.target_q_net(next_obs_flat).view(batch_size, n_slots, -1)  # [B, N, action_dim]
                max_next_q_values = next_q_values.max(dim=-1)[0]  # [B, N]
                masked_next_q_values = max_next_q_values * alive_agent_mask

                target_q_totals = self.target_mixing_net(masked_next_q_values, next_obs)  # [B, 1]
                # Team-level reward and done: averaged over slots -> [B, 1].
                targets = reward.mean(dim=1, keepdim=True) + self.gamma * target_q_totals * (1 - done.mean(dim=1, keepdim=True))

            current_q_totals = self.mixing_net(current_q_values, obs)  # [B, 1]

            loss = self.loss_fn(current_q_totals, targets)
            # The optimizer owns both networks' parameters.
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            total_loss += loss.detach()
            n_batches += 1

        return (total_loss / n_batches).item() if n_batches else 0.0

    def update_target_network(self):
        """Copy online Q-network and mixer weights into their target copies."""
        self.target_q_net.load_state_dict(self.q_net.state_dict())
        self.target_mixing_net.load_state_dict(self.mixing_net.state_dict())

    def decay_epsilon(self):
        """Multiply epsilon by ``epsilon_decay``, never going below ``epsilon_min``."""
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)


## Trainer

In [16]:
# import wandb
# from kaggle_secrets import UserSecretsClient

# user_secrets = UserSecretsClient()
# wandb_key = user_secrets.get_secret("wandb-key")

# wandb.login(key = wandb_key)

# wandb.init(project="RL_TRAINING", name="QMix_27", 
#             config={"epochs_num": 200, "opponents": "random, training with blue + red data", "batch_size" : 128, "num_agent": 81})


In [17]:
from time import time 

class Trainer : 
    """
    Episode-based training loop that trains ``blue_agent``.

    Each episode is played in ``env`` (iterated with ``agent_iter`` / ``last`` /
    ``step``). Every agent's transitions are pushed into ``buffer``, and the
    blue agent is then trained on the whole buffer. ``red_agent`` controls the
    red team unless ``is_self_play`` is True, in which case ``blue_agent`` acts
    for both teams. ``buffer`` may be a ReplayBuffer or a StateMemory; the
    latter gets extra per-group bookkeeping.
    """
    def __init__(self, env, red_agent, blue_agent, buffer, batch_size, is_self_play = False): 
        self.red_agent = red_agent
        self.blue_agent = blue_agent
        self.buffer = buffer 
        self.batch_size = batch_size 
        self.env = env 
        self.is_self_play = is_self_play 

    def agent_give_action(self, name: str, observation):
        """Return the action for team ``name`` ("red" or "blue") given ``observation``.

        In self-play mode the blue agent answers for every team.
        """
        if self.is_self_play : 
            return  self.blue_agent.get_action(observation)
        if name == "blue": 
            return  self.blue_agent.get_action(observation)
        return self.red_agent.get_action(observation)


    
    def update_memory(self, is_longterm: bool = False): 
        """
        Play one full episode and push every agent's transitions into ``self.buffer``.

        ``is_longterm`` is currently unused.

        Returns:
            (blue_reward - red_reward, n_kills, blue_reward). Rewards are summed
            only over turns on which the agent was still acting (neither
            terminated nor truncated). ``n_kills`` counts, per team, the turns
            whose reward exceeded 4.5.
        """
        self.env.reset()
        prev_obs = {}
        prev_actions = {}
        red_reward = 0 
        blue_reward = 0 

        prev_team = "red"

        n_kills = {"red": 0, "blue": 0}
        # Loop 1: one round over the agents. Each agent acts once and its
        # (observation, action) pair is remembered. No transition can be stored
        # yet, because the resulting observation is only seen on the agent's next turn.
        for idx, agent in enumerate(self.env.agent_iter()): 
            prev_ob, reward, termination, truncation, _ = self.env.last()
            team = agent.split("_")[0]
            # Heuristic kill counter: a per-turn reward above 4.5 is counted as a kill
            # (assumes the env's kill reward dominates other rewards - TODO confirm).
            n_kills[team] += (reward > 4.5)

            if truncation or termination: 
                # Agents that are done are stepped with a None action.
                prev_action = None
            else: 
                if agent.split("_")[0] == "red": 
                    prev_action =  self.agent_give_action("red", prev_ob)
                    red_reward += reward
                else: 
                    prev_action = self.agent_give_action("blue", prev_ob)
                    blue_reward += reward 
    

        
            prev_obs[agent] = prev_ob 
            prev_actions[agent] = prev_action 
            self.env.step(prev_action)

            # Stop after num_agents turns (one full round). NOTE(review): num_agents
            # can shrink while this loop runs if agents die - confirm this is intended.
            if (idx + 1) % self.env.num_agents == 0: break 

        # Loop 2: rest of the episode. The observation seen on each turn is the
        # next state of the agent's remembered (observation, action); that
        # transition is pushed, then the new pair is remembered.
        for agent in self.env.agent_iter(): 

            obs, reward, termination, truncation, _ = self.env.last()
            team = agent.split("_")[0]
            n_kills[team] += (reward > 4.5)
            
            if truncation or termination: 
                action = None 
            else: 
                if agent.split("_")[0] == "red" : 
                    action = self.agent_give_action("red", obs)
                    red_reward += reward 
                
                else: 
                    action = self.agent_give_action("blue", obs)
                    blue_reward += reward
                

            self.env.step(action)
            if isinstance(self.buffer, StateMemory):
                # When iteration switches team, pad the group slots so they stay time-aligned.
                if team != prev_team : 
                    self.buffer.ensemble()  
                    prev_team = team
                # Group slot = numeric agent id modulo the number of slots.
                idx = int(agent.split("_")[1]) % self.buffer.grouped_agents
                self.buffer.push(
                    idx,
                    prev_obs[agent], 
                    prev_actions[agent], 
                    reward, 
                    obs, 
                    termination 
                )
            else: 
                 self.buffer.push(
                    prev_obs[agent], 
                    prev_actions[agent], 
                    reward, 
                    obs, 
                    termination 
                )

            prev_obs[agent] = obs 
            prev_actions[agent] = action

        # Positive gap means blue collected more reward than red (original note said "red wins"; confirm intended sign).
        return  blue_reward - red_reward,  n_kills, blue_reward # red thắng  


    def save_model (self, file_path):
        """Save ``blue_agent.q_net``'s state_dict to ``file_path``."""
        
        torch.save(self.blue_agent.q_net.state_dict(), file_path)
        print(f"Model saved to {file_path}")
    
    def train(self, episodes=500, target_update_freq=2, is_type = "dqn"):
        """Collect-and-train loop.

        For each episode: play it via ``update_memory`` and (when
        ``is_type == "qmix"``) pad the StateMemory. The blue agent is then
        trained on a shuffled DataLoader over the whole buffer, and epsilon is
        decayed. The target network is synced every ``target_update_freq``
        episodes. Progress is printed, and the env is closed at the end.
        """
        gap_rewards = []


        for eps in range(episodes): 
            start = time()
            gap_reward, n_kills, blue_reward = self.update_memory()

            
            if is_type == "qmix": 
                self.buffer.ensemble()
            
            dataloader = DataLoader(self.buffer, batch_size = self.batch_size, shuffle = True)
            # print(f"Out of dataloader {len(self.buffer)}")
            self.blue_agent.train(dataloader)
            
            self.blue_agent.decay_epsilon()
            if eps % target_update_freq == 0:
                self.blue_agent.update_target_network()
    
            end = time() - start 
            
            # wandb.log({
            #     "episode": eps,
            #     "gap_rewards": gap_reward,
            #     "epsilon": self.blue_agent.epsilon,
            #     "time": end,
            #     "red_kill": n_kills["red"], 
            #     "blue_kill": n_kills["blue"]
            # })
    
            
            gap_rewards.append(gap_reward)
            print(f"Episode {eps}, Gap Reward: {gap_reward}, Total Reward: {blue_reward}, Epsilon: {self.blue_agent.epsilon:.2f}, Time: {end}, Kill: {n_kills}")
    
        self.env.close()



## Training 

- Với thuật toán Double Deep Q -> khởi tạo đối tượng DQNAgent + ReplayBuffer
- Với thuật toán QMix -> khởi tạo QMix Agent + StateMemory
- Với việc huấn luyện self-play -> set tham số is_self_play = True
  (code hiện đang được huấn luyện cho blue agent)

In [18]:
# Training run (current setup): blue learns with DQN against a uniformly random red team.
# The commented-out lines are the alternative setups listed in the markdown above
# (QMix: QMIXAgent + StateMemory; evaluating the pretrained agent as blue).

# Seed the Python / NumPy / PyTorch RNGs, which drive weight init, epsilon-greedy
# exploration, random red actions and DataLoader shuffling, so runs are repeatable.
# NOTE(review): Trainer calls env.reset() without a seed, so randomness inside the
# MAgent2 env itself (if any) is still uncontrolled.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

env = battle_v4.env(map_size=45, render_mode="rgb_array", attack_opponent_reward=0.5)

device = "cuda" if torch.cuda.is_available() else "cpu"

observation_shape = env.observation_space("red_0").shape
action_shape = env.action_space("red_0").n
num_agents = 27  # only used by the QMix alternative below

blue_agent = DQNAgent(observation_shape, action_shape, device=device)

# blue_agent = QMIXAgent(observation_shape, num_agents , action_shape, device=device)

# blue_agent = PretrainedAgent(n_observation = observation_shape, n_actions = action_shape, device = device)
red_agent = RandomAgent(action_shape)
buffer = ReplayBuffer(capacity=10000)
# buffer = StateMemory(capacity = 10000, grouped_agents = num_agents)

trainer = Trainer(env, red_agent, blue_agent, buffer, batch_size = 64, is_self_play=False)
trainer.train(episodes = 70)


Episode 0, Gap Reward: -2.4750006729736924, Total Reward: -2987.9551359387115, Epsilon: 0.97, Time: 4.9963459968566895, Kill: {'red': 1, 'blue': 1}
Episode 1, Gap Reward: 129.58500491362065, Total Reward: -2907.700127983466, Epsilon: 0.94, Time: 5.836490154266357, Kill: {'red': 4, 'blue': 2}
Episode 2, Gap Reward: 7.53498422075063, Total Reward: -2852.0951310805976, Epsilon: 0.91, Time: 6.948957204818726, Kill: {'red': 0, 'blue': 23}
Episode 3, Gap Reward: 174.7499812869355, Total Reward: -2502.8851289544255, Epsilon: 0.89, Time: 8.242162466049194, Kill: {'red': 3, 'blue': 35}
Episode 4, Gap Reward: -1030.555068277754, Total Reward: -2239.6701236618683, Epsilon: 0.86, Time: 8.72317910194397, Kill: {'red': 3, 'blue': 72}
Episode 5, Gap Reward: -1469.000080970116, Total Reward: -2340.995119580999, Epsilon: 0.83, Time: 10.10652470588684, Kill: {'red': 0, 'blue': 79}
Episode 6, Gap Reward: -913.8050667922944, Total Reward: -1948.710114103742, Epsilon: 0.81, Time: 10.806957483291626, Kill: 

In [19]:
# Persist the trained blue Q-network; eval() below reloads it from
# /kaggle/working/my_final.pt (presumably the Kaggle working directory - confirm).
trainer.save_model(file_path="my_final.pt")

Model saved to my_final3.pt


# Eval 

In [33]:
class TestQAgent: 
    """Evaluation agent: loads a PretrainedQNetwork checkpoint on CPU and acts epsilon-greedily.

    Args:
        n_observation: channel-last observation shape (H, W, C).
        n_actions: number of discrete actions.
        model_path: path of the state_dict to load.
        epsilon: probability of a uniformly random action (previously hard-coded to 0.05).
    """

    def __init__(self, n_observation, n_actions, model_path: str, epsilon: float = 0.05): 
        self.qnetwork = PretrainedQNetwork(n_observation, n_actions)
        self.n_action = n_actions
        self.epsilon = epsilon
        self.qnetwork.load_state_dict(
            torch.load(model_path, weights_only=True, map_location="cpu")
        ) 
        self.qnetwork.eval()

    def get_action(self, observation):
        """Return an epsilon-greedy action for one channel-last (H, W, C) observation."""
        if np.random.rand() < self.epsilon:
            return np.random.randint(self.n_action)

        obs_tensor = torch.Tensor(observation).float().permute([2, 0, 1]).unsqueeze(0)
        with torch.no_grad():
            q_values = self.qnetwork(obs_tensor)
        return torch.argmax(q_values, dim=1).numpy()[0]


In [34]:
from tqdm import tqdm

def eval(
    n_episode: int = 30,
    my_model_path: str = "/kaggle/working/my_final.pt",
    pretrained_path: str = "/kaggle/input/pretrained/pytorch/default/1/red.pt",
    final_pretrained_path: str = "/kaggle/input/final_rl/pytorch/default/1/red_final.pt",
    max_cycles: int = 300,
):
    """Evaluate our blue policy against three red opponents and print the metrics.

    Opponents, in order: a random policy, the pretrained ``red.pt`` network and
    the ``red_final.pt`` network. Our blue policy is a TestQAgent loaded from
    ``my_model_path``.

    NOTE: the name shadows Python's builtin ``eval``. It is kept because the
    notebook calls ``eval()``.

    Args:
        n_episode: episodes played per matchup.
        my_model_path: checkpoint for our blue agent.
        pretrained_path: checkpoint for the pretrained red network.
        final_pretrained_path: checkpoint for the final red network.
        max_cycles: episode length limit passed to the env.
    """
    env = battle_v4.env(map_size=45, max_cycles=max_cycles)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    obs_shape = env.observation_space("red_0").shape
    n_actions = env.action_space("red_0").n

    def random_policy(env, agent, obs):
        return env.action_space(agent).sample()

    q_network = PretrainedQNetwork(obs_shape, n_actions)
    q_network.load_state_dict(
        torch.load(pretrained_path, weights_only=True, map_location="cpu")
    )
    q_network.to(device).eval()

    final_q_network = Final_QNets(obs_shape, n_actions)
    final_q_network.load_state_dict(
        torch.load(final_pretrained_path, weights_only=True, map_location="cpu")
    )
    final_q_network.to(device).eval()

    # Build our agent once. Previously a new TestQAgent (re-reading the
    # checkpoint from disk) was constructed for every single action.
    my_agent = TestQAgent(obs_shape, n_actions, model_path=my_model_path)

    def my_policy(env, agent, obs):
        return my_agent.get_action(obs)

    def _greedy_action(network, obs):
        """Argmax-Q action of ``network`` for one channel-last observation."""
        observation = (
            torch.Tensor(obs).float().permute([2, 0, 1]).unsqueeze(0).to(device)
        )
        with torch.no_grad():
            q_values = network(observation)
        return torch.argmax(q_values, dim=1).cpu().numpy()[0]

    def pretrain_policy(env, agent, obs):
        return _greedy_action(q_network, obs)

    def final_pretrain_policy(env, agent, obs):
        return _greedy_action(final_q_network, obs)

    def run_eval(env, red_policy, blue_policy, n_episode: int = 100):
        """Play ``n_episode`` episodes and return win rates, per-agent rewards and kills.

        A team wins an episode when it has at least 5 more kills than the other;
        otherwise the episode is a draw. Rewards and kills are divided by the
        number of agents per team.
        """
        red_win, blue_win = [], []
        red_tot_rw, blue_tot_rw = [], []
        n_agent_each_team = len(env.env.action_spaces) // 2
        blue_agents = []
        red_agents = []

        for _ in tqdm(range(n_episode)):
            env.reset()
            n_kill = {"red": 0, "blue": 0}
            red_reward, blue_reward = 0, 0

            for agent in env.agent_iter():
                observation, reward, termination, truncation, info = env.last()
                agent_team = agent.split("_")[0]

                # Heuristic: a per-turn reward above 4.5 is counted as a kill.
                n_kill[agent_team] += (
                    reward > 4.5
                )
                if agent_team == "red":
                    red_reward += reward
                else:
                    blue_reward += reward

                if termination or truncation:
                    action = None
                else:
                    if agent_team == "red":
                        action = red_policy(env, agent, observation)
                    else:
                        action = blue_policy(env, agent, observation)

                env.step(action)

            who_wins = "red" if n_kill["red"] >= n_kill["blue"] + 5 else "draw"
            who_wins = "blue" if n_kill["red"] + 5 <= n_kill["blue"] else who_wins
            red_win.append(who_wins == "red")
            blue_win.append(who_wins == "blue")

            blue_agents.append(n_kill["blue"])
            red_agents.append(n_kill["red"])

            red_tot_rw.append(red_reward / n_agent_each_team)
            blue_tot_rw.append(blue_reward / n_agent_each_team)

        return {
            "winrate_red": np.mean(red_win),
            "winrate_blue": np.mean(blue_win),
            "average_rewards_red": np.mean(red_tot_rw),
            "average_rewards_blue": np.mean(blue_tot_rw),
            "red_kill": np.mean(red_agents) / n_agent_each_team,
            "blue_kill": np.mean(blue_agents) / n_agent_each_team,
        }

    matchups = [
        ("Eval with random policy", random_policy),
        ("Eval with trained policy", pretrain_policy),
        ("Eval with final trained policy", final_pretrain_policy),
    ]

    try:
        print("=" * 20)
        for title, red_policy in matchups:
            print(title)
            print(
                run_eval(
                    env=env, red_policy=red_policy, blue_policy=my_policy, n_episode=n_episode
                )
            )
            print("=" * 20)
    finally:
        env.close()

In [None]:
# Run all three evaluation matchups (random, pretrained and final red vs. our blue agent).
eval()