In [1]:
!pip install -q git+https://github.com/Farama-Foundation/MAgent2

# Q-Learning 

In [7]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from collections import deque
import random
import torch.nn as nn
from magent2.environments import battle_v4

## Kiến trúc mạng Q 

In [3]:
class PretrainedQNetwork(nn.Module):
    """Convolutional Q-network whose layout matches the pretrained opponent checkpoint.

    Two unpadded 3x3 convolutions (channel count preserved) are followed by a
    three-layer MLP that emits one Q-value per action.

    Args:
        observation_shape: observation shape with the channel axis last.
        action_shape: number of discrete actions (width of the output layer).
    """

    def __init__(self, observation_shape, action_shape):
        super().__init__()
        channels = observation_shape[-1]
        self.cnn = nn.Sequential(
            nn.Conv2d(channels, channels, 3),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3),
            nn.ReLU(),
        )

        # Push one random channel-first sample through the conv stack to learn
        # how many features the first linear layer receives.
        probe = torch.randn(observation_shape).permute(2, 0, 1)
        flatten_dim = self.cnn(probe).numel()

        head = [
            nn.Linear(flatten_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, action_shape),
        ]
        self.network = nn.Sequential(*head)

    def forward(self, x):
        """Return Q-values of shape (batch, action_shape).

        Args:
            x: channel-first observation, either (C, H, W) or (B, C, H, W).
        """
        assert x.dim() >= 3
        features = self.cnn(x)
        # An unbatched sample is treated as a batch of one.
        batchsize = 1 if features.dim() == 3 else features.shape[0]
        return self.network(features.reshape(batchsize, -1))

In [4]:
class Final_QNets(nn.Module):
    """Q-network layout used by the "final" opponent checkpoint.

    Same convolutional front end as the pretrained network, but the MLP ends in
    a Tanh and the action layer is kept separate (``last_layer``) so the latent
    vector feeding it can be inspected after a forward pass.

    Args:
        observation_shape: observation shape with the channel axis last.
        action_shape: number of discrete actions.
    """

    def __init__(self, observation_shape, action_shape):
        super().__init__()
        channels = observation_shape[-1]
        self.cnn = nn.Sequential(
            nn.Conv2d(channels, channels, 3),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3),
            nn.ReLU(),
        )

        # Size the first linear layer from a random channel-first probe sample.
        probe = torch.randn(observation_shape).permute(2, 0, 1)
        flatten_dim = self.cnn(probe).numel()

        # No normalisation layers: Linear -> ReLU -> Linear -> Tanh.
        trunk = [
            nn.Linear(flatten_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.Tanh(),
        ]
        self.network = nn.Sequential(*trunk)
        self.last_layer = nn.Linear(84, action_shape)

    def forward(self, x):
        """Return Q-values of shape (batch, action_shape).

        Side effect: stores the 84-dimensional latent of this call in
        ``self.last_latent``.
        """
        assert x.dim() >= 3, "only support magent input observation"
        features = self.cnn(x)
        # An unbatched sample is treated as a batch of one.
        batchsize = 1 if features.dim() == 3 else features.shape[0]
        latent = self.network(features.reshape(batchsize, -1))
        self.last_latent = latent
        return self.last_layer(latent)

In [5]:
class MyQNetwork(nn.Module):
    """Compact convolutional Q-network trained by the DQN agent.

    Two padded 3x3 conv + max-pool stages (16 then 32 filters), an adaptive
    average pool down to 2x2, and a 128 -> 64 MLP in front of the action layer.

    Args:
        observation_shape: observation shape with the channel axis last.
        action_shape: number of discrete actions.
    """

    def __init__(self, observation_shape, action_shape):
        super().__init__()
        in_channels = observation_shape[-1]

        # Feature extractor: each stage halves the spatial resolution.
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

        # Fixed 2x2 output keeps the MLP input size independent of the map crop.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((2, 2))

        # Measure the flattened feature size with a single random probe sample.
        probe = torch.randn(observation_shape).permute(2, 0, 1).unsqueeze(0)
        flatten_dim = self.adaptive_pool(self.cnn(probe)).numel()

        self.network = nn.Sequential(
            nn.Linear(flatten_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
        )

        # One output per action.
        self.last_layer = nn.Linear(64, action_shape)

    def forward(self, x):
        """Return Q-values of shape (batch, action_shape) for a (B, C, H, W) input."""
        assert x.dim() == 4, "Input tensor must be 4D (batch_size, C, H, W)"

        pooled = self.adaptive_pool(self.cnn(x))
        flat = pooled.reshape(pooled.size(0), -1)
        return self.last_layer(self.network(flat))


In [6]:
# Sanity check: build the network for a 13x13 observation with 5 channels and
# 21 actions, then report how many parameters would be trained.
input_shape = ( 13, 13,5)  
num_actions = 21

model = MyQNetwork(input_shape, num_actions)
# Only parameters with requires_grad=True are counted.
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {num_params}")

Number of trainable parameters: 31509


In [8]:
# Replay Buffer
from torch.utils.data import Dataset, DataLoader 

class ReplayBuffer(Dataset):
    """Bounded FIFO store of (state, action, reward, next_state, done) transitions.

    It doubles as a map-style ``Dataset`` so a ``DataLoader`` can batch it.

    Args:
        capacity: maximum number of transitions kept; the oldest are dropped first.
    """

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition, evicting the oldest when the buffer is full."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions and return them as numpy arrays.

        Returns:
            Tuple (states, actions, rewards, next_states, dones), each with the
            batch on the first axis.
        """
        picked = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*picked)
        return (
            np.stack(states),
            np.array(actions),
            np.array(rewards),
            np.stack(next_states),
            np.array(dones),
        )

    def __len__(self):
        return len(self.buffer)

    def __getitem__(self, idx):
        """Return transition ``idx`` as tensors; reward and done are float32."""
        state, action, reward, next_state, done = self.buffer[idx]
        state_t = torch.tensor(state)
        action_t = torch.tensor(action)
        reward_t = torch.tensor(reward, dtype=torch.float)
        next_state_t = torch.tensor(next_state)
        done_t = torch.tensor(done, dtype=torch.float32)
        return (state_t, action_t, reward_t, next_state_t, done_t)

## Cài đặt Agent 

In [9]:
class RandomAgent:
    """Baseline policy that ignores the observation and acts uniformly at random.

    Args:
        action_space: number of discrete actions to choose from.
    """

    def __init__(self, action_space):
        self.n_action = action_space

    def get_action(self, observation):
        """Return a uniformly sampled action index as a Python int."""
        drawn = torch.randint(0, self.n_action, (1,))
        return drawn.item()

In [10]:
class PretrainedAgent:
    """Epsilon-greedy opponent backed by the pretrained Q-network checkpoint.

    Args:
        n_observation: observation shape with the channel axis last.
        n_actions: number of discrete actions.
        device: torch device string the network is placed on.
        weights_path: checkpoint file to load. Defaults to the Kaggle dataset
            path this notebook was written for; pass another path to run
            elsewhere.
        epsilon: probability of taking a uniformly random action.
    """

    def __init__(
        self,
        n_observation,
        n_actions,
        device="cpu",
        weights_path="/kaggle/input/pretrained/pytorch/default/1/red.pt",
        epsilon=0.1,
    ):
        self.device = torch.device(device)
        self.n_action = n_actions
        self.epsilon = epsilon
        self.qnetwork = PretrainedQNetwork(n_observation, n_actions).to(self.device)
        self.qnetwork.load_state_dict(
            torch.load(weights_path, weights_only=True, map_location=self.device)
        )
        # The opponent is never trained here, so keep it in inference mode.
        self.qnetwork.eval()

    def get_action(self, observation):
        """Return an action index (always a plain Python int).

        Args:
            observation: channel-last array of shape (H, W, C).
        """
        if np.random.rand() < self.epsilon:
            return int(np.random.randint(self.n_action))

        # Channel-last -> channel-first, plus a leading batch axis of one.
        obs = (
            torch.as_tensor(observation, dtype=torch.float32)
            .permute(2, 0, 1)
            .unsqueeze(0)
            .to(self.device)
        )
        with torch.no_grad():
            q_values = self.qnetwork(obs)
        # .item() keeps the return type the same as the exploration branch.
        return int(torch.argmax(q_values, dim=1)[0].item())

In [11]:
class FinalAgent:
    """Greedy opponent backed by the "final" Q-network checkpoint.

    Args:
        n_observation: observation shape with the channel axis last.
        n_actions: number of discrete actions.
        device: torch device string the network is placed on.
        weights_path: checkpoint file to load. Defaults to the Kaggle dataset
            path this notebook was written for; pass another path to run
            elsewhere.
    """

    def __init__(
        self,
        n_observation,
        n_actions,
        device="cpu",
        weights_path="/kaggle/input/final_rl/pytorch/default/1/red_final.pt",
    ):
        self.device = torch.device(device)
        self.final_networks = Final_QNets(n_observation, n_actions).to(self.device)
        self.final_networks.load_state_dict(
            torch.load(weights_path, weights_only=True, map_location=self.device)
        )
        # The opponent is never trained here, so keep it in inference mode.
        self.final_networks.eval()

    def get_action(self, observation):
        """Return the greedy action index as a plain Python int.

        Args:
            observation: channel-last array of shape (H, W, C).
        """
        # Channel-last -> channel-first, plus a leading batch axis of one.
        obs = (
            torch.as_tensor(observation, dtype=torch.float32)
            .permute(2, 0, 1)
            .unsqueeze(0)
            .to(self.device)
        )
        with torch.no_grad():
            q_values = self.final_networks(obs)
        return int(torch.argmax(q_values, dim=1)[0].item())
    

In [12]:
class DQNAgent:
    """Epsilon-greedy DQN learner with an online network and a target network.

    Args:
        observation_shape: observation shape with the channel axis last.
        action_shape: number of discrete actions.
        batch_size: stored on the agent; batching itself is done by the caller's loader.
        lr: Adam learning rate.
        gamma: discount factor used in the bootstrapped target.
        device: torch device string.
    """

    def __init__(self, observation_shape, action_shape, batch_size=64, lr=1e-3, gamma=0.6, device="cpu"):
        self.device = torch.device(device)

        # Online network plus a copy that only provides bootstrap targets.
        self.q_net = MyQNetwork(observation_shape, action_shape).float().to(self.device)
        self.target_net = MyQNetwork(observation_shape, action_shape).float().to(self.device)
        self.target_net.load_state_dict(self.q_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)
        self.loss_fn = nn.MSELoss()
        self.batch_size = batch_size
        self.gamma = gamma
        self.action_shape = action_shape

        # Exploration schedule: start at 0.9, multiply by 0.995 per decay, floor at 0.1.
        self.epsilon = 0.9
        self.epsilon_decay = 0.995
        self.epsilon_min = 0.1

    def get_action(self, observation):
        """Pick a random action with probability epsilon, otherwise the greedy one."""
        explore = np.random.rand() < self.epsilon
        if explore:
            return np.random.randint(self.action_shape)

        # Add the batch axis, then move channels in front of the spatial axes.
        state_tensor = torch.FloatTensor(observation).unsqueeze(0).permute(0, 3, 1, 2).to(self.device)
        with torch.no_grad():
            q_values = self.q_net(state_tensor)
        return q_values.argmax().item()

    def train(self, dataloader):
        """Run one optimisation step per batch yielded by ``dataloader``.

        Each batch is (obs, action, reward, next_obs, done) with channel-last
        observations; the online network is updated towards the one-step
        target computed from the target network.
        """
        self.q_net.train()
        for batch in dataloader:
            obs, action, reward, next_obs, done = batch
            self.q_net.zero_grad()

            obs = obs.permute(0, 3, 1, 2).to(self.device)
            next_obs = next_obs.to(self.device).permute(0, 3, 1, 2).to(self.device)
            # Column vectors so they line up with gather() and the (B, 1) max.
            action = action.unsqueeze(1).to(self.device)
            reward = reward.unsqueeze(1).to(self.device)
            done = done.unsqueeze(1).to(self.device)

            # Bootstrapped target; finished transitions contribute only their reward.
            with torch.no_grad():
                best_next = self.target_net(next_obs).max(1, keepdim=True)[0]
                target_q_values = reward + self.gamma * (1 - done) * best_next

            predicted = self.q_net(obs).gather(1, action)

            batch_loss = self.loss_fn(predicted, target_q_values)
            batch_loss.backward()
            self.optimizer.step()

    def update_target_network(self):
        """Copy the online network's weights into the target network."""
        online_weights = self.q_net.state_dict()
        self.target_net.load_state_dict(online_weights)

    def decay_epsilon(self):
        """Shrink epsilon by its decay factor without going below the floor."""
        decayed = self.epsilon * self.epsilon_decay
        self.epsilon = max(self.epsilon_min, decayed)


## Trainer

In [13]:
import wandb
from kaggle_secrets import UserSecretsClient

# Read the Weights & Biases API key from the Kaggle secret named "wandb-key"
# so the key itself never appears in the notebook.
user_secrets = UserSecretsClient()
wandb_key = user_secrets.get_secret("wandb-key")

wandb.login(key = wandb_key)

# Start the run that later wandb.log calls report to.
# NOTE(review): the config below is descriptive metadata only; nothing in this
# cell ties "epochs_num" to the episode count used for training - confirm they agree.
wandb.init(project="RL_TRAINING", name="DeepQ_complex_pretrained", 
           config={"epochs_num": 100, "opponents": "pretrained+random"})


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mtheseventeengv[0m ([33mtrungviet17[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [14]:
from time import time 

class Trainer:
    """Plays red vs. blue episodes, fills the replay buffer and trains the red agent.

    Args:
        env: AEC-style environment exposing reset/agent_iter/last/step/num_agents/close.
        red_agent: learning agent; needs get_action, train, decay_epsilon,
            update_target_network, an ``epsilon`` attribute and (for
            ``save_model``) a ``q_net`` module.
        blue_agent: opponent; only ``get_action`` is used.
        buffer: replay buffer with ``push(state, action, reward, next_state, done)``
            that can also be wrapped by a ``DataLoader``.
        batch_size: batch size of the ``DataLoader`` built for each training pass.
    """

    def __init__(self, env, red_agent, blue_agent, buffer, batch_size):
        self.red_agent = red_agent
        self.blue_agent = blue_agent
        self.buffer = buffer
        self.batch_size = batch_size
        self.env = env

    @staticmethod
    def _team_of(agent):
        """Map an agent name such as "red_12" to its team key ("red" or "blue")."""
        return "red" if agent.split("_")[0] == "red" else "blue"

    def _take_turn(self, agent, rewards, finished):
        """Act for the agent whose turn it is and step the environment.

        Updates the per-team ``rewards`` / ``finished`` tallies in place and
        returns (observation, action, reward, termination) for that turn.
        """
        obs, reward, termination, truncation, _ = self.env.last()
        team = self._team_of(agent)

        if termination or truncation:
            # A finished agent must be stepped with None.
            action = None
            finished[team] += 1
        else:
            policy = self.red_agent if team == "red" else self.blue_agent
            action = policy.get_action(obs)
            rewards[team] += reward

        self.env.step(action)
        return obs, action, reward, termination

    def update_memory(self):
        """Play one full episode and push every agent's transitions to the buffer.

        Returns:
            (reward_gap, red_won): total red reward minus total blue reward,
            and whether fewer red than blue agents were reported finished.
        """
        self.env.reset()
        prev_obs = {}
        prev_actions = {}
        rewards = {"red": 0, "blue": 0}
        finished = {"red": 0, "blue": 0}

        # Round 1: every agent acts once so that each has a previous
        # (observation, action) pair to complete on its next turn.
        for idx, agent in enumerate(self.env.agent_iter()):
            obs, action, _, _ = self._take_turn(agent, rewards, finished)
            prev_obs[agent] = obs
            prev_actions[agent] = action
            if (idx + 1) % self.env.num_agents == 0:
                break

        # Remaining rounds: the reward/termination seen now belongs to the
        # action the agent took on its previous turn.
        for agent in self.env.agent_iter():
            obs, action, reward, termination = self._take_turn(agent, rewards, finished)

            # Skip agents without a stored action (not seen in round 1, or
            # already finished): a None action cannot be turned into a tensor.
            last_action = prev_actions.get(agent)
            if last_action is not None:
                self.buffer.push(prev_obs[agent], last_action, reward, obs, termination)

            prev_obs[agent] = obs
            prev_actions[agent] = action

        # NOTE(review): every agent appears to be reported finished exactly once
        # (killed or truncated), so both counters may always end equal and this
        # flag may always be False - confirm against the environment.
        red_won = finished["red"] - finished["blue"] < 0
        return rewards["red"] - rewards["blue"], red_won

    def save_model(self, file_path):
        """Write the red agent's online Q-network weights to ``file_path``."""
        torch.save(self.red_agent.q_net.state_dict(), file_path)
        print(f"Model saved to {file_path}")

    def train_dqn(self, episodes=500, target_update_freq=2):
        """Alternate episode collection and training for ``episodes`` episodes.

        After each episode the red agent trains on the whole buffer, decays its
        epsilon and, every ``target_update_freq`` episodes, syncs its target
        network. Metrics go to wandb and stdout. The environment is closed at
        the end.

        Returns:
            List with the reward gap (red minus blue) of every episode.
        """
        gap_rewards = []

        for eps in range(episodes):
            start = time()
            gap_reward, winner = self.update_memory()

            # A shuffling DataLoader cannot be built over an empty dataset.
            if len(self.buffer) > 0:
                dataloader = DataLoader(self.buffer, batch_size=self.batch_size, shuffle=True)
                self.red_agent.train(dataloader)

            self.red_agent.decay_epsilon()
            if eps % target_update_freq == 0:
                self.red_agent.update_target_network()

            elapsed = time() - start
            # Report the trainer's own agent, not whatever global happens to be
            # named red_agent in the notebook.
            epsilon = self.red_agent.epsilon
            wandb.log({
                "episode": eps,
                "gap_rewards": gap_reward,
                "epsilon": epsilon,
                "time": elapsed
            })

            gap_rewards.append(gap_reward)
            print(f"Episode {eps}, Total Reward: {gap_reward}, Epsilon: {epsilon:.2f}, Time: {elapsed}, Winner : {winner}")

        self.env.close()
        return gap_rewards


In [None]:
# Build the 45x45 battle environment; frames are rendered off-screen as RGB arrays.
env = battle_v4.env(map_size=45, render_mode="rgb_array")

device = "cuda" if torch.cuda.is_available() else "cpu"

# All agents are assumed to share red_0's observation/action spaces - TODO confirm.
observation_shape = env.observation_space("red_0").shape
action_shape = env.action_space("red_0").n

# Red learns with DQN; blue plays the pretrained checkpoint.
red_agent = DQNAgent(observation_shape, action_shape, device=device)
blue_agent = PretrainedAgent(n_observation = observation_shape, n_actions = action_shape, device = device)
buffer = ReplayBuffer(capacity=10000)

trainer = Trainer(env, red_agent, blue_agent, buffer, batch_size = 64)
trainer.train_dqn(episodes = 200)


Episode 0, Total Reward: -442.8500009244308, Epsilon: 0.90, Time: 6.613807678222656, Winner : False
Episode 1, Total Reward: -459.4300009747967, Epsilon: 0.89, Time: 4.546820163726807, Winner : False
Episode 2, Total Reward: -461.89000153075904, Epsilon: 0.89, Time: 5.387091398239136, Winner : False
Episode 3, Total Reward: -322.16000114101917, Epsilon: 0.88, Time: 13.209748268127441, Winner : False
Episode 4, Total Reward: -404.33000082708895, Epsilon: 0.88, Time: 8.30617094039917, Winner : False
Episode 5, Total Reward: -452.6550003858283, Epsilon: 0.87, Time: 4.1059730052948, Winner : False
Episode 6, Total Reward: -410.6699999794364, Epsilon: 0.87, Time: 6.9331769943237305, Winner : False
Episode 7, Total Reward: -441.5150007158518, Epsilon: 0.86, Time: 4.839049816131592, Winner : False
Episode 8, Total Reward: -396.22000042535365, Epsilon: 0.86, Time: 7.333963632583618, Winner : False
Episode 9, Total Reward: -410.7400005105883, Epsilon: 0.86, Time: 8.232966184616089, Winner : Fal

In [16]:
# Checkpoint the red agent's Q-network after the first training run.
trainer.save_model("my_model5.pt")

Model saved to my_model5.pt


In [23]:
# Continue training the same red_agent (and the same replay buffer) against a
# fresh pretrained opponent; no device is passed, so this opponent uses the
# constructor's default device.
# NOTE(review): this cell depends on env, red_agent, buffer, observation_shape
# and action_shape left over from earlier cells, and its execution count is out
# of sequence with its neighbours - confirm it still works on Restart & Run All.
blue_agent = PretrainedAgent(n_observation = observation_shape, n_actions = action_shape)

trainer = Trainer(env, red_agent, blue_agent, buffer, batch_size = 64)
trainer.train_dqn(episodes = 100)


Episode 0, Total Reward: -421.09499831404537, Epsilon: 0.60
Episode 1, Total Reward: 452.3800273099914, Epsilon: 0.60
Episode 2, Total Reward: 1920.02507760562, Epsilon: 0.60
Episode 3, Total Reward: -404.4649995714426, Epsilon: 0.59
Episode 4, Total Reward: -408.74000167287886, Epsilon: 0.59
Episode 5, Total Reward: -326.11000062618405, Epsilon: 0.59
Episode 6, Total Reward: -334.4300012467429, Epsilon: 0.58
Episode 7, Total Reward: -421.1899992218241, Epsilon: 0.58
Episode 8, Total Reward: -389.87000301200897, Epsilon: 0.58
Episode 9, Total Reward: -337.24499892350286, Epsilon: 0.58
Episode 10, Total Reward: -374.6600020201877, Epsilon: 0.57
Episode 11, Total Reward: -313.0649981154129, Epsilon: 0.57
Episode 12, Total Reward: -351.7350029973313, Epsilon: 0.57
Episode 13, Total Reward: -313.44000261928886, Epsilon: 0.56
Episode 14, Total Reward: -447.4950031125918, Epsilon: 0.56
Episode 15, Total Reward: -415.5199996698648, Epsilon: 0.56
Episode 16, Total Reward: -393.9050019849092, E

In [24]:
# Checkpoint the red agent's Q-network after the continued run against the pretrained opponent.
trainer.save_model("my_model2.pt")

Model saved to my_model2.pt


In [21]:
# Switch the opponent to the "final" checkpoint and keep training the same
# red_agent with the same replay buffer.
# NOTE(review): relies on env, red_agent and buffer from earlier cells; env has
# already been through a train_dqn call, which ends with env.close() - confirm
# that reusing it here is intended.
blue_agent = FinalAgent(n_observation = observation_shape, n_actions = action_shape)

trainer = Trainer(env, red_agent, blue_agent, buffer, batch_size = 64)
trainer.train_dqn(episodes = 100)


Episode 0, Total Reward: -835.3250285536051, Epsilon: 0.33, Time: 45.99792528152466, Winner : False
Episode 1, Total Reward: -733.5650315135717, Epsilon: 0.33, Time: 47.74588489532471, Winner : False
Episode 2, Total Reward: -956.6750405393541, Epsilon: 0.32, Time: 62.174410343170166, Winner : False
Episode 3, Total Reward: -163.13499972503632, Epsilon: 0.32, Time: 21.861814737319946, Winner : False
Episode 4, Total Reward: -37.12999481894076, Epsilon: 0.32, Time: 23.885837078094482, Winner : False
Episode 5, Total Reward: -251.18500122893602, Epsilon: 0.32, Time: 16.812191247940063, Winner : False
Episode 6, Total Reward: 49.85500467475504, Epsilon: 0.32, Time: 42.093571186065674, Winner : False
Episode 7, Total Reward: -238.65000720508397, Epsilon: 0.32, Time: 21.857694387435913, Winner : False
Episode 8, Total Reward: -526.9850208768621, Epsilon: 0.31, Time: 36.23681354522705, Winner : False
Episode 9, Total Reward: -496.96502331178635, Epsilon: 0.31, Time: 47.10835075378418, Winner

KeyboardInterrupt: 

In [26]:
# Checkpoint the red agent's Q-network after the (interrupted) run against the final opponent.
trainer.save_model("my_model3.pt")

Model saved to my_model3.pt


# Q-Mix 

In [None]:
from collections import deque
import random

In [None]:
class AgentQNetwork(nn.Module):
    """Recurrent per-agent Q-network: Linear -> GRU -> Linear.

    Args:
        obs_shape: shape of a single observation; it is flattened before use.
        action_size: number of discrete actions.
    """

    def __init__(self, obs_shape, action_size):
        super().__init__()
        hidden_dim = 64
        self.fc1 = nn.Linear(np.prod(obs_shape), hidden_dim)
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.fc2 = nn.Linear(hidden_dim, action_size)

    def forward(self, obs, hidden_state):
        """Return (q_values, next_hidden_state) for a batch of observations.

        Args:
            obs: tensor whose first axis is the batch; the rest is flattened.
            hidden_state: GRU state of shape (1, batch, 64), or None for zeros.
        """
        flat_obs = obs.flatten(start_dim=1)
        embedded = torch.relu(self.fc1(flat_obs))
        # The GRU wants a sequence axis, so feed a length-1 sequence per sample.
        seq_out, next_hidden = self.gru(embedded.unsqueeze(1), hidden_state)
        return self.fc2(seq_out.squeeze(1)), next_hidden

In [None]:
class MixingNetwork(nn.Module):
    """State-conditioned linear mixer that combines per-agent Q-values.

    A small network maps the global state to one weight per agent; the joint
    value is the weighted sum of the agents' Q-values. The weights are forced
    to be non-negative so the joint value never decreases when an individual
    agent's Q-value increases (the monotonicity constraint QMIX relies on).

    Args:
        num_agents: number of agents whose Q-values are mixed.
        state_shape: shape of a single global state; it is flattened before use.
    """

    def __init__(self, num_agents, state_shape):
        super(MixingNetwork, self).__init__()
        self.state_fc1 = nn.Linear(int(np.prod(state_shape)), 64)
        self.qmix_fc = nn.Linear(64, num_agents)

    def forward(self, q_values, state):
        """Mix per-agent Q-values into a joint value.

        Args:
            q_values: tensor broadcastable against (batch, num_agents, 1).
            state: tensor whose first axis is the batch; the rest is flattened.

        Returns:
            ``q_values`` times the state-dependent weights, summed over dim 1.
        """
        state = state.flatten(start_dim=1)
        state_features = torch.relu(self.state_fc1(state))
        # abs() keeps every mixing weight >= 0; without it a negative weight
        # would let the joint value fall as an agent's own Q-value rises.
        weights = torch.abs(self.qmix_fc(state_features)).unsqueeze(-1)
        q_total = torch.sum(q_values * weights, dim=1)
        return q_total


In [None]:
class ReplayBuffer:
    """Bounded FIFO store of (state, observation, action, reward, next_state, done).

    Args:
        capacity: maximum number of transitions kept; the oldest are dropped first.
    """

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, transition):
        """Store one 6-field transition.

        A ``None`` action (the step of an already finished agent) is stored as
        the sentinel ``-1`` so the actions can later be packed into an integer
        tensor; consumers must mask that value before indexing with it.
        """
        state, observation, action, reward, next_state, done = transition
        action = -1 if action is None else action
        self.buffer.append((state, observation, action, reward, next_state, done))

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions as batched tensors.

        Returns:
            (states, observations, actions, rewards, next_states, dones);
            actions are int64, everything else float32.

        Raises:
            ValueError: if fewer than ``batch_size`` transitions are stored.
        """
        batch = random.sample(self.buffer, batch_size)
        states, observations, actions, rewards, next_states, dones = zip(*batch)

        def _batched(values, dtype):
            # Stack with numpy first: handing torch a tuple of ndarrays makes
            # it convert element by element, which is very slow for large states.
            return torch.as_tensor(np.asarray(values), dtype=dtype)

        return (_batched(states, torch.float32),
                _batched(observations, torch.float32),
                _batched(actions, torch.long),
                _batched(rewards, torch.float32),
                _batched(next_states, torch.float32),
                _batched(dones, torch.float32))

    def __len__(self):
        return len(self.buffer)


In [None]:
# Environment for the Q-Mix experiment.
# NOTE(review): render_mode="human" presumably opens a live window on every
# step - confirm this is wanted for a long training run.
env = battle_v4.env(map_size=45, render_mode="human")
env.reset()

# Hyperparameters
num_episodes = 1000
learning_rate = 0.001
gamma = 0.99

# Epsilon-greedy schedule: multiplied by epsilon_decay once per episode, floored at epsilon_min.
epsilon = 1.0
epsilon_decay = 0.995
epsilon_min = 0.05

batch_size = 64
buffer_capacity = 10000
update_target_every = 100

# Agent and mixing network setup
# These sizes are hard-coded rather than read from env - presumably they match
# battle_v4 with map_size=45; verify against env.observation_space / env.state().
obs_shape = (13, 13, 5)
state_shape = (45, 45, 5)
action_size = 21
num_agents = 81

# Initialize shared Q-network, target network, and mixing network
# (each target network starts as an exact copy of its online counterpart).
shared_q_network = AgentQNetwork(obs_shape, action_size)
target_q_network = AgentQNetwork(obs_shape, action_size)
target_q_network.load_state_dict(shared_q_network.state_dict())
mixing_network = MixingNetwork(num_agents, state_shape)
mixing_target_network = MixingNetwork(num_agents, state_shape)
mixing_target_network.load_state_dict(mixing_network.state_dict())

# Optimizers (one per online network)
optimizer = optim.Adam(shared_q_network.parameters(), lr=learning_rate)
mix_optimizer = optim.Adam(mixing_network.parameters(), lr=learning_rate)

# Replay buffer
replay_buffer = ReplayBuffer(buffer_capacity)

In [None]:
# Q-Mix training loop: play one episode with a shared epsilon-greedy recurrent
# Q-network for every agent (both teams), store each step in the replay buffer,
# then do one update of the Q-network and one of the mixing network.
for episode in range(num_episodes):
    env.reset()
    hidden_states = {agent: None for agent in env.agents}  # Initialize hidden states for GRU
    done = False
    # Epsilon is decayed once per episode, before any action is taken.
    epsilon = max(epsilon * epsilon_decay, epsilon_min)


    # NOTE(review): agent_observations / agent_rewards / agent_dones, `done`
    # and `count` are written but never read inside this cell.
    agent_observations = {}
    agent_rewards = {}
    agent_dones = {}
    count = 0 
    red_count = 0 
    blue_count = 0 
    
    for agent in env.agent_iter():
        observation, reward, termination, truncation, _ = env.last()
        
        if hidden_states[agent] is None:
            hidden_states[agent] = torch.zeros((1, 1, 64))

        if termination or truncation:
            # Finished agents are stepped with None and counted per team.
            if agent.split("_")[0] == 'red' : red_count += 1
            else: blue_count += 1 
            action = None
            print(f"Blue: {blue_count}, Red: {red_count}")
        else:
            if np.random.rand() < epsilon:
                action = np.random.randint(action_size)
            else:
                # The hidden state only advances on greedy steps; random steps
                # leave it untouched. This forward pass is not under no_grad.
                obs_tensor = torch.tensor(observation, dtype=torch.float32).unsqueeze(0)
                q_values, hidden_states[agent] = shared_q_network(obs_tensor, hidden_states[agent])
                action = torch.argmax(q_values).item()
        env.step(action)
        

        agent_observations[agent] = observation
        agent_rewards[agent] = reward
        agent_dones[agent] = termination or truncation

        # NOTE(review): both env.state() calls happen after env.step(), so the
        # stored state and next_state are the same snapshot.
        replay_buffer.push((env.state(), observation, action, reward, env.state(), termination or truncation))
        # Stop once one whole team (81 agents) has been reported finished.
        if red_count == 81 or blue_count == 81 : break 

    # Sample from replay buffer and update networks
    if len(replay_buffer) >= batch_size:
        states, observations, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)

        # Q-value prediction
        # NOTE(review): the buffer stores -1 for None actions; gather() with a
        # -1 index is likely to raise - confirm and mask those rows.
        q_values, _ = shared_q_network(observations, None)
        q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        # Target Q-value prediction
        # NOTE(review): the target network is evaluated on the same
        # `observations` as the prediction above, not on a next observation
        # (the buffer does not store one).
        with torch.no_grad():
            next_q_values, _ = target_q_network(observations, None)
            next_q_values = next_q_values.max(1)[0]
            target_q_values = rewards + gamma * (1 - dones) * next_q_values

        loss = torch.mean((q_values - target_q_values) ** 2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Update mixing network
    # NOTE(review): this section runs even when the sampling branch above was
    # skipped, in which case states / target_q_values / next_states are not
    # defined and q_values is whatever the acting loop left behind.
    # NOTE(review): q_values is (batch,) here and target_q_values.unsqueeze(1)
    # is (batch, 1), while the mixer multiplies by weights of shape
    # (batch, num_agents, 1); with batch 64 and 81 agents the second call does
    # not look broadcastable - confirm the intended shapes.
    q_total = mixing_network(q_values, states)
    with torch.no_grad():
        next_q_total = mixing_target_network(target_q_values.unsqueeze(1), next_states)
    mix_loss = torch.mean((q_total - next_q_total) ** 2)

    mix_optimizer.zero_grad()
    # NOTE(review): mix_loss depends on q_values, whose graph was already
    # consumed by loss.backward() above; a second backward through it likely
    # raises unless the graph is retained - confirm.
    mix_loss.backward()
    mix_optimizer.step()

    # Update target networks
    if episode % update_target_every == 0:
        target_q_network.load_state_dict(shared_q_network.state_dict())
        mixing_target_network.load_state_dict(mixing_network.state_dict())

    print(f"Episode {episode + 1}/{num_episodes} completed.")
