In [9]:
import numpy as np
import collections
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import gym
from gym import spaces
import copy
import math
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

class MultiAgentActionSpace(list):
    def __init__(self, agents_action_space):
        for x in agents_action_space:
            assert isinstance(x, gym.spaces.Space)
        super().__init__(agents_action_space)
        self._agents_action_space = agents_action_space

    def sample(self):
        return [agent_action_space.sample() for agent_action_space in self._agents_action_space]

class MultiAgentObservationSpace(list):
    def __init__(self, agents_observation_space):
        for x in agents_observation_space:
            assert isinstance(x, gym.spaces.Space)
        super().__init__(agents_observation_space)
        self._agents_observation_space = agents_observation_space

    def sample(self):
        return [agent_observation_space.sample() for agent_observation_space in self._agents_observation_space]

class ReplayBuffer:
    def __init__(self, buffer_limit):
        self.buffer = collections.deque(maxlen=buffer_limit)

    def put(self, transition):
        self.buffer.append(transition)

    def sample(self, n):
        mini_batch = random.sample(self.buffer, n)
        s_lst, a_lst, r_lst, s_prime_lst, done_mask_lst = [], [], [], [], []

        for transition in mini_batch:
            s, a, r, s_prime, done = transition
            s_lst.append(s)
            a_lst.append(a)
            r_lst.append(r)
            s_prime_lst.append(s_prime)
            done_mask_lst.append((np.ones(len(done)) - done).tolist())

        return (torch.tensor(np.array(s_lst), dtype=torch.float),
                torch.tensor(np.array(a_lst), dtype=torch.long),
                torch.tensor(np.array(r_lst), dtype=torch.float),
                torch.tensor(np.array(s_prime_lst), dtype=torch.float),
                torch.tensor(np.array(done_mask_lst), dtype=torch.float))

    def size(self):
        return len(self.buffer)

class QNet(nn.Module):
    def __init__(self, observation_space, action_space):
        super(QNet, self).__init__()
        self.num_agents = len(observation_space)
        for agent_i in range(self.num_agents):
            n_obs = observation_space[agent_i].shape[0]
            setattr(self, f'agent_{agent_i}', nn.Sequential(
                nn.Linear(n_obs, 64),
                nn.ReLU(),
                nn.Linear(64, 32),
                nn.ReLU(),
                nn.Linear(32, action_space[agent_i].n)
            ))

    def forward(self, obs):
        q_values = [torch.empty(obs.shape[0], )] * self.num_agents
        for agent_i in range(self.num_agents):
            q_values[agent_i] = getattr(self, f'agent_{agent_i}')(obs[:, agent_i, :]).unsqueeze(1)
        return torch.cat(q_values, dim=1)

    def sample_action(self, obs, epsilon, roles):
        out = self.forward(obs)
        actions = torch.empty((out.shape[0], out.shape[1],), dtype=torch.long)
        for agent_i in range(out.shape[1]):
            if random.random() < epsilon:
                if roles[agent_i] == 'relay':
                    actions[:, agent_i] = 8  # STAY
                else:
                    actions[:, agent_i] = torch.randint(0, out.shape[2], (1,)).long()
            else:
                actions[:, agent_i] = out[:, agent_i].argmax().long()
        return actions

def train(q, q_target, memory, optimizer, gamma, batch_size, update_iter=10):
    for _ in range(update_iter):
        s, a, r, s_prime, done_mask = memory.sample(batch_size)
        q_out = q(s)
        a = a.unsqueeze(-1)
        q_a = q_out.gather(2, a).squeeze(-1)
        max_q_prime = q_target(s_prime).max(dim=2)[0]
        target = r + gamma * max_q_prime * done_mask
        loss = F.smooth_l1_loss(q_a, target.detach())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

class GridMultiAgent(gym.Env):
    metadata = {'render.modes': ['console']}
    XM, XP, YM, YP, XMYM, XMYP, XPYM, XPYP, STAY = range(9)
    OOE, OBS, POI, MAP, AGT, SURVIVOR = -2, -1, 0, 1, 2, 3

    def __init__(self, x_size=2, y_size=2, fov_x=3, fov_y=3, n_agents=2, n_survivors=1):
        super().__init__()
        self.x_size = x_size
        self.y_size = y_size
        self.n_agents = n_agents
        self.idx_agents = list(range(n_agents))
        self.n_survivors = n_survivors
        self.fov_x = fov_x
        self.fov_y = fov_y

        self.confidence_map = np.zeros((x_size, y_size))
        self.pheromone_map = np.zeros((x_size, y_size))
        self.confidence_decay = 0.95
        self.pheromone_decay = 0.9
        self.agent_roles = ['scout'] * n_agents
        self.stuck_counts = [0] * n_agents
        self.survivor_pos = []
        self.agent_paths = [[] for _ in range(n_agents)]

        n_actions = 9
        self.action_space = MultiAgentActionSpace([spaces.Discrete(n_actions) for _ in range(n_agents)])
        obs_size = (fov_x * fov_y - 1) + 1 + 1 + 1 + 3
        self.obs_low = np.concatenate([np.ones(fov_x * fov_y - 1) * self.OOE, [0, 0, 0, 0, 0, 0]], dtype=np.float32)
        self.obs_high = np.concatenate([np.ones(fov_x * fov_y - 1) * self.SURVIVOR, [1, 1, 1, 1, 1, 1]], dtype=np.float32)
        self.observation_space = MultiAgentObservationSpace([
            spaces.Box(self.obs_low, self.obs_high, dtype=np.float32) for _ in range(n_agents)
        ])

        self.init_grid()
        self.init_agent()
        self.init_survivors()

    def init_grid(self):
        self.grid_status = np.zeros((self.x_size, self.y_size))
        self.grid_counts = np.tile(self.grid_status, (self.n_agents, 1, 1)).reshape(self.n_agents, self.x_size, self.y_size)
        self.n_poi = self.x_size * self.y_size - np.count_nonzero(self.grid_status)
        self.grid_agents_status = copy.deepcopy(self.grid_status)

    def init_agent(self):
        self.agent_pos = []
        self.agent_paths = [[] for _ in range(self.n_agents)]
        for i in range(self.n_agents):
            while True:
                x, y = random.randrange(0, self.x_size), random.randrange(0, self.y_size)
                if self.grid_agents_status[x, y] == self.POI:
                    self.agent_pos.append([x, y])
                    self.grid_agents_status[x, y] = self.AGT
                    self.agent_paths[i].append([x, y])
                    break
        self.stuck_counts = [0] * n_agents

    def init_survivors(self):
        self.survivor_pos = []
        for _ in range(self.n_survivors):
            while True:
                x, y = random.randrange(0, self.x_size), random.randrange(0, self.y_size)
                if self.grid_status[x, y] == self.POI and [x, y] not in self.agent_pos:
                    self.survivor_pos.append([x, y])
                    self.grid_status[x, y] = self.SURVIVOR
                    self.confidence_map[x, y] = 1.0
                    break

    def grid_overlay(self):
        self.grid_agents_status = copy.deepcopy(self.grid_status)
        for i in range(self.n_agents):
            x, y = self.agent_pos[i]
            self.grid_agents_status[x, y] = self.AGT

    def get_rf_signal(self, agent_pos, survivor_pos):
        distance = np.sqrt((agent_pos[0] - survivor_pos[0])**2 + (agent_pos[1] - survivor_pos[1])**2)
        return min(1.0, 1.0 / (distance + 1)) if distance > 0 else 1.0

    def get_agent_obs(self):
        self.agent_obs = []
        for agent in range(self.n_agents):
            single_obs = np.ones((self.fov_x, self.fov_y)) * self.OOE
            x, y = self.agent_pos[agent]
            for i in range(self.fov_x):
                for j in range(self.fov_y):
                    obs_x = x + (i - 1)
                    obs_y = y + (j - 1)
                    if 0 <= obs_x < self.x_size and 0 <= obs_y < self.y_size:
                        single_obs[i][j] = self.grid_agents_status[obs_x][obs_y]
            single_obs_flat = single_obs.flatten()
            single_obs_flat = np.array([v for k, v in enumerate(single_obs_flat) if k != math.floor(self.fov_x * self.fov_y / 2)])

            rf_signal = max([self.get_rf_signal(self.agent_pos[agent], s) for s in self.survivor_pos], default=0.0)
            confidence_level = self.confidence_map[x, y]
            pheromone_level = self.pheromone_map[x, y]
            role_encoding = {'scout': [1, 0, 0], 'verification': [0, 1, 0], 'relay': [0, 0, 1]}
            role_vec = role_encoding[self.agent_roles[agent]]

            obs = np.concatenate([single_obs_flat, [confidence_level, pheromone_level, rf_signal], role_vec])
            self.agent_obs.append(obs.astype(np.float32))
        return self.agent_obs

    def update_roles(self):
        for i in range(self.n_agents):
            x, y = self.agent_pos[i]
            confidence = self.confidence_map[x, y]
            if confidence > 0.7 and self.agent_roles[i] != 'verification':
                self.agent_roles[i] = 'verification'
            elif confidence > 0.5 and self.agent_roles[i] != 'relay':
                self.agent_roles[i] = 'relay'
            elif self.agent_roles[i] != 'scout':
                self.agent_roles[i] = 'scout'

    def update_confidence_and_pheromones(self):
        self.confidence_map *= self.confidence_decay
        self.pheromone_map *= self.pheromone_decay
        for i in range(self.n_agents):
            x, y = self.agent_pos[i]
            rf_signal = max([self.get_rf_signal(self.agent_pos[i], s) for s in self.survivor_pos], default=0.0)
            confidence = rf_signal
            self.confidence_map[x, y] = min(1.0, self.confidence_map[x, y] + confidence)
            if confidence > 0.5:
                self.pheromone_map[x, y] = min(1.0, self.pheromone_map[x, y] + 0.5)

    def get_coverage(self):
        mapped_poi = np.count_nonzero(self.grid_status == self.MAP)
        return mapped_poi / self.n_poi

    def get_survivor_detection_rate(self):
        detected = sum(1 for s in self.survivor_pos if self.confidence_map[s[0], s[1]] > 0.8)
        return detected / self.n_survivors if self.n_survivors > 0 else 0.0

    def step(self, action, i):
        org_x, org_y = self.agent_pos[i]
        reward = 0
        action = int(action)
        if action == self.XM:
            self.agent_pos[i][0] -= 1
        elif action == self.XP:
            self.agent_pos[i][0] += 1
        elif action == self.YM:
            self.agent_pos[i][1] -= 1
        elif action == self.YP:
            self.agent_pos[i][1] += 1
        elif action == self.XMYM:
            self.agent_pos[i][0] -= 1
            self.agent_pos[i][1] -= 1
        elif action == self.XMYP:
            self.agent_pos[i][0] -= 1
            self.agent_pos[i][1] += 1
        elif action == self.XPYM:
            self.agent_pos[i][0] += 1
            self.agent_pos[i][1] -= 1
        elif action == self.XPYP:
            self.agent_pos[i][0] += 1
            self.agent_pos[i][1] += 1
        elif action == self.STAY:
            reward = 5 if self.agent_roles[i] == 'relay' and self.confidence_map[org_x, org_y] > 0.5 else -1
        else:
            raise ValueError(f"Invalid action={action}")

        if (self.agent_pos[i][0] >= self.x_size or self.agent_pos[i][0] < 0 or
            self.agent_pos[i][1] >= self.y_size or self.agent_pos[i][1] < 0 or
            self.grid_status[self.agent_pos[i][0], self.agent_pos[i][1]] in [self.OBS, self.AGT]):
            self.agent_pos[i] = [org_x, org_y]
            self.grid_counts[i][org_x, org_y] += 1
            reward = -2
        else:
            prev_status = self.grid_status[self.agent_pos[i][0], self.agent_pos[i][1]]
            if prev_status == self.POI:
                self.grid_status[self.agent_pos[i][0], self.agent_pos[i][1]] = self.MAP
                self.grid_counts[i][self.agent_pos[i][0], self.agent_pos[i][1]] += 1
                reward = 5
            elif prev_status == self.SURVIVOR:
                self.grid_counts[i][self.agent_pos[i][0], self.agent_pos[i][1]] += 1
                reward = 50
            elif prev_status == self.MAP:
                self.grid_counts[i][self.agent_pos[i][0], self.agent_pos[i][1]] += 1
                reward = -1

        self.agent_paths[i].append(self.agent_pos[i][:])

        nearby_agents = sum(1 for j in range(self.n_agents) if j != i and
                           np.sqrt((self.agent_pos[i][0] - self.agent_pos[j][0])**2 +
                                   (self.agent_pos[i][1] - self.agent_pos[j][1])**2) < 2)
        if nearby_agents > 0:
            reward -= 2 * nearby_agents

        if org_x == self.agent_pos[i][0] and org_y == self.agent_pos[i][1]:
            self.stuck_counts[i] += 1
        else:
            self.stuck_counts[i] = 0

        self.update_confidence_and_pheromones()
        self.update_roles()
        self.grid_overlay()

        mapped_poi = np.count_nonzero(self.grid_status == self.MAP)
        done = mapped_poi == self.n_poi or self.get_survivor_detection_rate() == 1.0
        info = {'survivor_detection_rate': self.get_survivor_detection_rate()}

        return self.get_agent_obs(), reward, done, info

    def reset(self):
        self.init_grid()
        self.init_agent()
        self.init_survivors()
        self.confidence_map = np.zeros((self.x_size, self.y_size))
        self.pheromone_map = np.zeros((self.x_size, self.y_size))
        self.agent_roles = ['scout'] * self.n_agents
        while True:
            obs = self.get_agent_obs()
            obs_tf = [obs[i][0] != 0 and obs[i][1] != 0 and obs[i][2] != 0 and obs[i][3] != 0
                      for i in range(self.n_agents)]
            if any(obs_tf):
                self.init_grid()
                self.init_agent()
                self.init_survivors()
            else:
                break
        return self.get_agent_obs()

    def close(self):
        pass

# Training
size = 2
fov = 3
n_agents = 2
n_survivors = 1
train_episodes = 1000
max_steps = 8
batch_size = 32
gamma = 0.99
buffer_limit = 50000
log_interval = 20
max_epsilon = 0.9
min_epsilon = 0.1
warm_up_steps = 1000
update_iter = 10
coverage_threshold = 0.95
lr = 0.001

env = GridMultiAgent(x_size=size, y_size=size, fov_x=fov, fov_y=fov, n_agents=n_agents, n_survivors=n_survivors)
memory = ReplayBuffer(buffer_limit)
q = QNet(env.observation_space, env.action_space)
q_target = QNet(env.observation_space, env.action_space)
q_target.load_state_dict(q.state_dict())
optimizer = optim.Adam(q.parameters(), lr=lr)

time_steps, epsilons, coverage, survivor_detection, total_reward, all_agent_paths = [], [], [], [], [], []

for episode in tqdm(range(train_episodes), desc="Training Episodes"):
    score = np.zeros(n_agents)
    state = env.reset()
    done = False
    episode_step = 0
    epsilon = max(min_epsilon, max_epsilon - (max_epsilon - min_epsilon) * (episode / (0.4 * train_episodes)))

    while not done:
        state_np = np.array(state, dtype=np.float32)
        action = q.sample_action(torch.tensor(state_np).unsqueeze(0), epsilon, env.agent_roles)[0].data.cpu().numpy()
        rewards = np.zeros(n_agents)
        next_state = state
        for agent_i in env.idx_agents:
            next_state, reward, done, info = env.step(action[agent_i], agent_i)
            rewards[agent_i] = reward
            if done:
                break
        memory.put((state, action, rewards, next_state, [done] * n_agents))
        score += rewards
        state = next_state
        episode_step += 1
        if episode_step >= max_steps or done:
            break

    if memory.size() > warm_up_steps:
        train(q, q_target, memory, optimizer, gamma, batch_size, update_iter)

    if episode % log_interval == 0:
        q_target.load_state_dict(q.state_dict())

    all_agent_paths.append([path[:] for path in env.agent_paths])
    time_steps.append(episode_step)
    epsilons.append(epsilon)
    coverage.append(env.get_coverage())
    survivor_detection.append(info['survivor_detection_rate'])
    total_reward.append(score.sum())

    if episode % log_interval == 0:
        print(f'//Episode {episode+1}// Epsilon: {epsilon:.3f}, Steps: {episode_step}, '
              f'Coverage (%): {coverage[-1]:.3f}, Survivor Detection (%): {survivor_detection[-1]:.3f}, '
              f'Total Reward: {total_reward[-1]}')

# Metrics Visualization
plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
plt.plot(coverage, label='Coverage')
plt.title('Coverage Over Episodes')
plt.xlabel('Episode')
plt.ylabel('Coverage (%)')
plt.grid(True)
plt.legend()

plt.subplot(1, 3, 2)
plt.plot(survivor_detection, label='Survivor Detection', color='orange')
plt.title('Survivor Detection Over Episodes')
plt.xlabel('Episode')
plt.ylabel('Detection Rate (%)')
plt.grid(True)
plt.legend()

plt.subplot(1, 3, 3)
plt.plot(total_reward, label='Total Reward', color='green')
plt.title('Total Reward Over Episodes')
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.grid(True)
plt.legend()

plt.tight_layout()
plt.savefig('metrics_original1000.png')
plt.close()

# Drone Path Animation (Last Episode)
def animate_drone_paths(episode_idx=-1):
    paths = all_agent_paths[episode_idx]
    final_grid = env.grid_status
    final_confidence = env.confidence_map
    final_survivors = env.survivor_pos

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xlim(-0.5, env.x_size - 0.5)
    ax.set_ylim(-0.5, env.y_size - 0.5)
    ax.set_xticks(range(env.x_size))
    ax.set_yticks(range(env.y_size))
    ax.grid(True)
    ax.set_title(f'Drone Paths (Episode {episode_idx + 1})')

    obs_x, obs_y = np.where(final_grid == env.OBS)
    ax.scatter(obs_x, obs_y, c='black', marker='s', s=100, label='Obstacles')
    heatmap = ax.imshow(final_confidence.T, cmap='hot', alpha=0.5, origin='lower',
                        extent=(-0.5, env.x_size - 0.5, -0.5, env.y_size - 0.5))
    survivor_x, survivor_y = zip(*final_survivors) if final_survivors else ([], [])
    survivors = ax.scatter(survivor_x, survivor_y, c='blue', marker='*', s=200, label='Survivors')
    colors = ['red', 'green']
    agent_plots = [ax.plot([], [], c=colors[i], marker='o', linestyle='-', label=f'Agent {i}')[0]
                   for i in range(n_agents)]
    ax.legend()

    def update(frame):
        for i, plot in enumerate(agent_plots):
            if frame < len(paths[i]):
                x, y = zip(*paths[i][:frame + 1]) if paths[i] else ([], [])
                plot.set_data(x, y)
        return agent_plots + [heatmap, survivors]

    ani = FuncAnimation(fig, update, frames=min(len(p) for p in paths), interval=200, blit=True)
    plt.close(fig)
    return ani

# Save animation
ani = animate_drone_paths(-1)
with open('animation_original.html', 'w') as f:
    f.write(ani.to_jshtml())

# Save metrics for comparison
np.savez('metrics_original.npz', coverage=coverage, survivor_detection=survivor_detection, total_reward=total_reward)

Training Episodes:   2%|▏         | 19/1000 [00:00<00:05, 169.66it/s]

//Episode 1// Epsilon: 0.900, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: -4.0
//Episode 21// Epsilon: 0.860, Steps: 8, Coverage (%): 0.500, Survivor Detection (%): 1.000, Total Reward: 12.0


Training Episodes:   7%|▋         | 73/1000 [00:00<00:05, 167.19it/s]

//Episode 41// Epsilon: 0.820, Steps: 5, Coverage (%): 0.750, Survivor Detection (%): 1.000, Total Reward: 33.0
//Episode 61// Epsilon: 0.780, Steps: 5, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 23.0


Training Episodes:  11%|█         | 107/1000 [00:00<00:05, 156.72it/s]

//Episode 81// Epsilon: 0.740, Steps: 2, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 43.0
//Episode 101// Epsilon: 0.700, Steps: 2, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  12%|█▏        | 123/1000 [00:00<00:05, 151.04it/s]

//Episode 121// Epsilon: 0.660, Steps: 3, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 32.0


Training Episodes:  16%|█▌        | 158/1000 [00:01<00:06, 121.15it/s]

//Episode 141// Epsilon: 0.620, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 3.0
//Episode 161// Epsilon: 0.580, Steps: 5, Coverage (%): 0.750, Survivor Detection (%): 1.000, Total Reward: 33.0


Training Episodes:  20%|██        | 205/1000 [00:01<00:05, 133.50it/s]

//Episode 181// Epsilon: 0.540, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: -4.0
//Episode 201// Epsilon: 0.500, Steps: 6, Coverage (%): 0.500, Survivor Detection (%): 1.000, Total Reward: 18.0


Training Episodes:  24%|██▍       | 243/1000 [00:01<00:04, 154.08it/s]

//Episode 221// Epsilon: 0.460, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 44.0
//Episode 241// Epsilon: 0.420, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: -4.0


Training Episodes:  27%|██▋       | 273/1000 [00:01<00:03, 188.71it/s]

//Episode 261// Epsilon: 0.380, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: -4.0
//Episode 281// Epsilon: 0.340, Steps: 6, Coverage (%): 0.750, Survivor Detection (%): 1.000, Total Reward: 27.0


Training Episodes:  31%|███▏      | 314/1000 [00:02<00:04, 170.38it/s]

//Episode 301// Epsilon: 0.300, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: -4.0
//Episode 321// Epsilon: 0.260, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  33%|███▎      | 332/1000 [00:02<00:05, 129.37it/s]

//Episode 341// Epsilon: 0.220, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 3.0


Training Episodes:  36%|███▌      | 358/1000 [00:05<00:34, 18.36it/s]

//Episode 361// Epsilon: 0.180, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  38%|███▊      | 385/1000 [00:08<00:42, 14.35it/s]

//Episode 381// Epsilon: 0.140, Steps: 3, Coverage (%): 0.750, Survivor Detection (%): 1.000, Total Reward: 51.0


Training Episodes:  40%|████      | 403/1000 [00:09<00:32, 18.53it/s]

//Episode 401// Epsilon: 0.100, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 3.0


Training Episodes:  42%|████▏     | 424/1000 [00:10<00:27, 21.03it/s]

//Episode 421// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 44.0


Training Episodes:  44%|████▍     | 445/1000 [00:11<00:27, 20.28it/s]

//Episode 441// Epsilon: 0.100, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 3.0


Training Episodes:  46%|████▋     | 463/1000 [00:12<00:25, 21.08it/s]

//Episode 461// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  48%|████▊     | 484/1000 [00:13<00:24, 20.70it/s]

//Episode 481// Epsilon: 0.100, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 3.0


Training Episodes:  50%|█████     | 505/1000 [00:14<00:23, 21.08it/s]

//Episode 501// Epsilon: 0.100, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 3.0


Training Episodes:  52%|█████▏    | 523/1000 [00:15<00:23, 20.08it/s]

//Episode 521// Epsilon: 0.100, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 3.0


Training Episodes:  54%|█████▍    | 544/1000 [00:15<00:21, 21.56it/s]

//Episode 541// Epsilon: 0.100, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 3.0


Training Episodes:  56%|█████▋    | 565/1000 [00:17<00:21, 20.38it/s]

//Episode 561// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  58%|█████▊    | 582/1000 [00:18<00:31, 13.44it/s]

//Episode 581// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  60%|██████    | 602/1000 [00:19<00:27, 14.43it/s]

//Episode 601// Epsilon: 0.100, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 3.0


Training Episodes:  62%|██████▏   | 622/1000 [00:21<00:28, 13.21it/s]

//Episode 621// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  64%|██████▍   | 644/1000 [00:22<00:16, 21.02it/s]

//Episode 641// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  66%|██████▋   | 665/1000 [00:23<00:15, 21.81it/s]

//Episode 661// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  68%|██████▊   | 683/1000 [00:23<00:14, 21.21it/s]

//Episode 681// Epsilon: 0.100, Steps: 3, Coverage (%): 0.500, Survivor Detection (%): 1.000, Total Reward: 44.0


Training Episodes:  70%|███████   | 704/1000 [00:24<00:14, 20.23it/s]

//Episode 701// Epsilon: 0.100, Steps: 2, Coverage (%): 0.500, Survivor Detection (%): 1.000, Total Reward: 51.0


Training Episodes:  72%|███████▎  | 725/1000 [00:25<00:12, 21.67it/s]

//Episode 721// Epsilon: 0.100, Steps: 3, Coverage (%): 0.500, Survivor Detection (%): 1.000, Total Reward: 44.0


Training Episodes:  74%|███████▍  | 744/1000 [00:26<00:12, 19.76it/s]

//Episode 741// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  76%|███████▋  | 764/1000 [00:27<00:11, 20.87it/s]

//Episode 761// Epsilon: 0.100, Steps: 2, Coverage (%): 0.500, Survivor Detection (%): 1.000, Total Reward: 54.0


Training Episodes:  78%|███████▊  | 785/1000 [00:28<00:10, 20.73it/s]

//Episode 781// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  80%|████████  | 803/1000 [00:29<00:09, 20.12it/s]

//Episode 801// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  82%|████████▏ | 824/1000 [00:30<00:09, 19.22it/s]

//Episode 821// Epsilon: 0.100, Steps: 2, Coverage (%): 0.500, Survivor Detection (%): 1.000, Total Reward: 54.0


Training Episodes:  84%|████████▍ | 843/1000 [00:32<00:10, 14.60it/s]

//Episode 841// Epsilon: 0.100, Steps: 2, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 47.0


Training Episodes:  86%|████████▋ | 863/1000 [00:33<00:09, 13.73it/s]

//Episode 861// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  88%|████████▊ | 883/1000 [00:35<00:08, 13.63it/s]

//Episode 881// Epsilon: 0.100, Steps: 2, Coverage (%): 0.500, Survivor Detection (%): 1.000, Total Reward: 54.0


Training Episodes:  90%|█████████ | 904/1000 [00:36<00:05, 18.88it/s]

//Episode 901// Epsilon: 0.100, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 3.0


Training Episodes:  92%|█████████▏| 924/1000 [00:37<00:04, 18.92it/s]

//Episode 921// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  94%|█████████▍| 943/1000 [00:38<00:02, 19.94it/s]

//Episode 941// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  96%|█████████▋| 963/1000 [00:39<00:01, 20.87it/s]

//Episode 961// Epsilon: 0.100, Steps: 2, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 47.0


Training Episodes:  98%|█████████▊| 984/1000 [00:40<00:00, 20.35it/s]

//Episode 981// Epsilon: 0.100, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 3.0


Training Episodes: 100%|██████████| 1000/1000 [00:40<00:00, 24.45it/s]


In [10]:
import numpy as np
import collections
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import gym
from gym import spaces
import copy
import math
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
import ipywidgets as widgets

class MultiAgentActionSpace(list):
    def __init__(self, agents_action_space):
        for x in agents_action_space:
            assert isinstance(x, gym.spaces.Space)
        super().__init__(agents_action_space)
        self._agents_action_space = agents_action_space

    def sample(self):
        return [agent_action_space.sample() for agent_action_space in self._agents_action_space]

class MultiAgentObservationSpace(list):
    def __init__(self, agents_observation_space):
        for x in agents_observation_space:
            assert isinstance(x, gym.spaces.Space)
        super().__init__(agents_observation_space)
        self._agents_observation_space = agents_observation_space

    def sample(self):
        return [agent_observation_space.sample() for agent_observation_space in self._agents_observation_space]

class ReplayBuffer:
    def __init__(self, buffer_limit):
        self.buffer = collections.deque(maxlen=buffer_limit)

    def put(self, transition):
        self.buffer.append(transition)

    def sample(self, n):
        mini_batch = random.sample(self.buffer, n)
        s_lst, a_lst, r_lst, s_prime_lst, done_mask_lst = [], [], [], [], []

        for transition in mini_batch:
            s, a, r, s_prime, done = transition
            s_lst.append(s)
            a_lst.append(a)
            r_lst.append(r)
            s_prime_lst.append(s_prime)
            done_mask_lst.append((np.ones(len(done)) - done).tolist())

        return (torch.tensor(np.array(s_lst), dtype=torch.float),
                torch.tensor(np.array(a_lst), dtype=torch.long),
                torch.tensor(np.array(r_lst), dtype=torch.float),
                torch.tensor(np.array(s_prime_lst), dtype=torch.float),
                torch.tensor(np.array(done_mask_lst), dtype=torch.float))

    def size(self):
        return len(self.buffer)

class QNet(nn.Module):
    def __init__(self, observation_space, action_space):
        super(QNet, self).__init__()
        self.num_agents = len(observation_space)
        for agent_i in range(self.num_agents):
            n_obs = observation_space[agent_i].shape[0]
            setattr(self, f'agent_{agent_i}', nn.Sequential(
                nn.Linear(n_obs, 64),
                nn.ReLU(),
                nn.Linear(64, 32),
                nn.ReLU(),
                nn.Linear(32, action_space[agent_i].n)
            ))

    def forward(self, obs):
        q_values = [torch.empty(obs.shape[0], )] * self.num_agents
        for agent_i in range(self.num_agents):
            q_values[agent_i] = getattr(self, f'agent_{agent_i}')(obs[:, agent_i, :]).unsqueeze(1)
        return torch.cat(q_values, dim=1)

    def sample_action(self, obs, epsilon, roles):
        out = self.forward(obs)
        actions = torch.empty((out.shape[0], out.shape[1],), dtype=torch.long)
        for agent_i in range(out.shape[1]):
            if random.random() < epsilon:
                if roles[agent_i] == 'relay':
                    actions[:, agent_i] = 8  # STAY
                else:
                    actions[:, agent_i] = torch.randint(0, out.shape[2], (1,)).long()
            else:
                actions[:, agent_i] = out[:, agent_i].argmax().long()
        return actions

def train(q, q_target, memory, optimizer, gamma, batch_size, update_iter=10):
    for _ in range(update_iter):
        s, a, r, s_prime, done_mask = memory.sample(batch_size)
        q_out = q(s)
        a = a.unsqueeze(-1)
        q_a = q_out.gather(2, a).squeeze(-1)
        max_q_prime = q_target(s_prime).max(dim=2)[0]
        target = r + gamma * max_q_prime * done_mask
        loss = F.smooth_l1_loss(q_a, target.detach())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

class GridMultiAgent(gym.Env):
    metadata = {'render.modes': ['console']}
    XM, XP, YM, YP, XMYM, XMYP, XPYM, XPYP, STAY = range(9)
    OOE, OBS, POI, MAP, AGT, SURVIVOR = -2, -1, 0, 1, 2, 3

    def __init__(self, x_size=2, y_size=2, fov_x=3, fov_y=3, n_agents=2, n_survivors=1):
        super().__init__()
        self.x_size = x_size
        self.y_size = y_size
        self.n_agents = n_agents
        self.idx_agents = list(range(n_agents))
        self.n_survivors = n_survivors
        self.fov_x = fov_x
        self.fov_y = fov_y

        self.confidence_map = np.zeros((x_size, y_size))
        self.pheromone_map = np.zeros((x_size, y_size))
        self.confidence_decay = 0.95
        self.pheromone_decay = 0.9
        self.agent_roles = ['scout'] * n_agents
        self.stuck_counts = [0] * n_agents
        self.survivor_pos = []
        self.agent_paths = [[] for _ in range(n_agents)]
        self.agent_roles_history = [[] for _ in range(n_agents)]

        n_actions = 9
        self.action_space = MultiAgentActionSpace([spaces.Discrete(n_actions) for _ in range(n_agents)])
        obs_size = (fov_x * fov_y - 1) + 1 + 1 + 1 + 3
        self.obs_low = np.concatenate([np.ones(fov_x * fov_y - 1) * self.OOE, [0, 0, 0, 0, 0, 0]], dtype=np.float32)
        self.obs_high = np.concatenate([np.ones(fov_x * fov_y - 1) * self.SURVIVOR, [1, 1, 1, 1, 1, 1]], dtype=np.float32)
        self.observation_space = MultiAgentObservationSpace([
            spaces.Box(self.obs_low, self.obs_high, dtype=np.float32) for _ in range(n_agents)
        ])

        self.init_grid()
        self.init_agent()
        self.init_survivors()

    def init_grid(self):
        self.grid_status = np.zeros((self.x_size, self.y_size))
        self.grid_counts = np.tile(self.grid_status, (self.n_agents, 1, 1)).reshape(self.n_agents, self.x_size, self.y_size)
        self.n_poi = self.x_size * self.y_size - np.count_nonzero(self.grid_status)
        self.grid_agents_status = copy.deepcopy(self.grid_status)

    def init_agent(self):
        self.agent_pos = []
        self.agent_paths = [[] for _ in range(self.n_agents)]
        self.agent_roles_history = [[] for _ in range(self.n_agents)]
        for i in range(self.n_agents):
            while True:
                x, y = random.randrange(0, self.x_size), random.randrange(0, self.y_size)
                if self.grid_agents_status[x, y] == self.POI:
                    self.agent_pos.append([x, y])
                    self.grid_agents_status[x, y] = self.AGT
                    self.agent_paths[i].append([x, y])
                    self.agent_roles_history[i].append(self.agent_roles[i])
                    break
        self.stuck_counts = [0] * n_agents

    def init_survivors(self):
        self.survivor_pos = []
        for _ in range(self.n_survivors):
            while True:
                x, y = random.randrange(0, self.x_size), random.randrange(0, self.y_size)
                if self.grid_status[x, y] == self.POI and [x, y] not in self.agent_pos:
                    self.survivor_pos.append([x, y])
                    self.grid_status[x, y] = self.SURVIVOR
                    self.confidence_map[x, y] = 1.0
                    break

    def grid_overlay(self):
        self.grid_agents_status = copy.deepcopy(self.grid_status)
        for i in range(self.n_agents):
            x, y = self.agent_pos[i]
            self.grid_agents_status[x, y] = self.AGT

    def get_rf_signal(self, agent_pos, survivor_pos):
        distance = np.sqrt((agent_pos[0] - survivor_pos[0])**2 + (agent_pos[1] - survivor_pos[1])**2)
        return min(1.0, 1.0 / (distance + 1)) if distance > 0 else 1.0

    def get_agent_obs(self):
        self.agent_obs = []
        for agent in range(self.n_agents):
            single_obs = np.ones((self.fov_x, self.fov_y)) * self.OOE
            x, y = self.agent_pos[agent]
            for i in range(self.fov_x):
                for j in range(self.fov_y):
                    obs_x = x + (i - 1)
                    obs_y = y + (j - 1)
                    if 0 <= obs_x < self.x_size and 0 <= obs_y < self.y_size:
                        single_obs[i][j] = self.grid_agents_status[obs_x][obs_y]
            single_obs_flat = single_obs.flatten()
            single_obs_flat = np.array([v for k, v in enumerate(single_obs_flat) if k != math.floor(self.fov_x * self.fov_y / 2)])

            rf_signal = max([self.get_rf_signal(self.agent_pos[agent], s) for s in self.survivor_pos], default=0.0)
            confidence_level = self.confidence_map[x, y]
            pheromone_level = self.pheromone_map[x, y]
            role_encoding = {'scout': [1, 0, 0], 'verification': [0, 1, 0], 'relay': [0, 0, 1]}
            role_vec = role_encoding[self.agent_roles[agent]]

            obs = np.concatenate([single_obs_flat, [confidence_level, pheromone_level, rf_signal], role_vec])
            self.agent_obs.append(obs.astype(np.float32))
        return self.agent_obs

    def update_roles(self):
        for i in range(self.n_agents):
            x, y = self.agent_pos[i]
            confidence = self.confidence_map[x, y]
            if confidence > 0.7 and self.agent_roles[i] != 'verification':
                self.agent_roles[i] = 'verification'
            elif confidence > 0.5 and self.agent_roles[i] != 'relay':
                self.agent_roles[i] = 'relay'
            elif self.agent_roles[i] != 'scout':
                self.agent_roles[i] = 'scout'
            self.agent_roles_history[i].append(self.agent_roles[i])

    def update_confidence_and_pheromones(self):
        self.confidence_map *= self.confidence_decay
        self.pheromone_map *= self.pheromone_decay
        for i in range(self.n_agents):
            x, y = self.agent_pos[i]
            rf_signal = max([self.get_rf_signal(self.agent_pos[i], s) for s in self.survivor_pos], default=0.0)
            confidence = rf_signal
            self.confidence_map[x, y] = min(1.0, self.confidence_map[x, y] + confidence)
            if confidence > 0.5:
                self.pheromone_map[x, y] = min(1.0, self.pheromone_map[x, y] + 0.5)

    def get_coverage(self):
        mapped_poi = np.count_nonzero(self.grid_status == self.MAP)
        return mapped_poi / self.n_poi

    def get_survivor_detection_rate(self):
        detected = sum(1 for s in self.survivor_pos if self.confidence_map[s[0], s[1]] > 0.8)
        return detected / self.n_survivors if self.n_survivors > 0 else 0.0

    def step(self, action, i):
        org_x, org_y = self.agent_pos[i]
        reward = 0
        action = int(action)
        if action == self.XM:
            self.agent_pos[i][0] -= 1
        elif action == self.XP:
            self.agent_pos[i][0] += 1
        elif action == self.YM:
            self.agent_pos[i][1] -= 1
        elif action == self.YP:
            self.agent_pos[i][1] += 1
        elif action == self.XMYM:
            self.agent_pos[i][0] -= 1
            self.agent_pos[i][1] -= 1
        elif action == self.XMYP:
            self.agent_pos[i][0] -= 1
            self.agent_pos[i][1] += 1
        elif action == self.XPYM:
            self.agent_pos[i][0] += 1
            self.agent_pos[i][1] -= 1
        elif action == self.XPYP:
            self.agent_pos[i][0] += 1
            self.agent_pos[i][1] += 1
        elif action == self.STAY:
            reward = 5 if self.agent_roles[i] == 'relay' and self.confidence_map[org_x, org_y] > 0.5 else -1
        else:
            raise ValueError(f"Invalid action={action}")

        if (self.agent_pos[i][0] >= self.x_size or self.agent_pos[i][0] < 0 or
            self.agent_pos[i][1] >= self.y_size or self.agent_pos[i][1] < 0 or
            self.grid_status[self.agent_pos[i][0], self.agent_pos[i][1]] in [self.OBS, self.AGT]):
            self.agent_pos[i] = [org_x, org_y]
            self.grid_counts[i][org_x, org_y] += 1
            reward = -2
        else:
            prev_status = self.grid_status[self.agent_pos[i][0], self.agent_pos[i][1]]
            if prev_status == self.POI:
                self.grid_status[self.agent_pos[i][0], self.agent_pos[i][1]] = self.MAP
                self.grid_counts[i][self.agent_pos[i][0], self.agent_pos[i][1]] += 1
                reward = 5
            elif prev_status == self.SURVIVOR:
                self.grid_counts[i][self.agent_pos[i][0], self.agent_pos[i][1]] += 1
                reward = 50
            elif prev_status == self.MAP:
                self.grid_counts[i][self.agent_pos[i][0], self.agent_pos[i][1]] += 1
                reward = -1

        self.agent_paths[i].append(self.agent_pos[i][:])

        nearby_agents = sum(1 for j in range(self.n_agents) if j != i and
                           np.sqrt((self.agent_pos[i][0] - self.agent_pos[j][0])**2 +
                                   (self.agent_pos[i][1] - self.agent_pos[j][1])**2) < 2)
        if nearby_agents > 0:
            reward -= 2 * nearby_agents

        if org_x == self.agent_pos[i][0] and org_y == self.agent_pos[i][1]:
            self.stuck_counts[i] += 1
        else:
            self.stuck_counts[i] = 0

        self.update_confidence_and_pheromones()
        self.update_roles()
        self.grid_overlay()

        mapped_poi = np.count_nonzero(self.grid_status == self.MAP)
        done = mapped_poi == self.n_poi or self.get_survivor_detection_rate() == 1.0
        info = {'survivor_detection_rate': self.get_survivor_detection_rate()}

        return self.get_agent_obs(), reward, done, info

    def reset(self):
        self.init_grid()
        self.init_agent()
        self.init_survivors()
        self.confidence_map = np.zeros((self.x_size, self.y_size))
        self.pheromone_map = np.zeros((self.x_size, self.y_size))
        self.agent_roles = ['scout'] * self.n_agents
        while True:
            obs = self.get_agent_obs()
            obs_tf = [obs[i][0] != 0 and obs[i][1] != 0 and obs[i][2] != 0 and obs[i][3] != 0
                      for i in range(self.n_agents)]
            if any(obs_tf):
                self.init_grid()
                self.init_agent()
                self.init_survivors()
            else:
                break
        return self.get_agent_obs()

    def close(self):
        pass

# Training
size = 2
fov = 3
n_agents = 2
n_survivors = 1
train_episodes = 1000
max_steps = 8
batch_size = 32
gamma = 0.99
buffer_limit = 50000
log_interval = 20
max_epsilon = 0.9
min_epsilon = 0.1
warm_up_steps = 1000
update_iter = 10
coverage_threshold = 0.95
lr = 0.001

env = GridMultiAgent(x_size=size, y_size=size, fov_x=fov, fov_y=fov, n_agents=n_agents, n_survivors=n_survivors)
memory = ReplayBuffer(buffer_limit)
q = QNet(env.observation_space, env.action_space)
q_target = QNet(env.observation_space, env.action_space)
q_target.load_state_dict(q.state_dict())
optimizer = optim.Adam(q.parameters(), lr=lr)

time_steps, epsilons, coverage, survivor_detection, total_reward, all_agent_paths, all_agent_roles = [], [], [], [], [], [], []

for episode in tqdm(range(train_episodes), desc="Training Episodes"):
    score = np.zeros(n_agents)
    state = env.reset()
    done = False
    episode_step = 0
    epsilon = max(min_epsilon, max_epsilon - (max_epsilon - min_epsilon) * (episode / (0.4 * train_episodes)))

    while not done:
        state_np = np.array(state, dtype=np.float32)
        action = q.sample_action(torch.tensor(state_np).unsqueeze(0), epsilon, env.agent_roles)[0].data.cpu().numpy()
        rewards = np.zeros(n_agents)
        next_state = state
        for agent_i in env.idx_agents:
            next_state, reward, done, info = env.step(action[agent_i], agent_i)
            rewards[agent_i] = reward
            if done:
                break
        memory.put((state, action, rewards, next_state, [done] * n_agents))
        score += rewards
        state = next_state
        episode_step += 1
        if episode_step >= max_steps or done:
            break

    if memory.size() > warm_up_steps:
        train(q, q_target, memory, optimizer, gamma, batch_size, update_iter)

    if episode % log_interval == 0:
        q_target.load_state_dict(q.state_dict())

    all_agent_paths.append([path[:] for path in env.agent_paths])
    all_agent_roles.append([roles[:] for roles in env.agent_roles_history])
    time_steps.append(episode_step)
    epsilons.append(epsilon)
    coverage.append(env.get_coverage())
    survivor_detection.append(info['survivor_detection_rate'])
    total_reward.append(score.sum())

    if episode % log_interval == 0:
        print(f'//Episode {episode+1}// Epsilon: {epsilon:.3f}, Steps: {episode_step}, '
              f'Coverage (%): {coverage[-1]:.3f}, Survivor Detection (%): {survivor_detection[-1]:.3f}, '
              f'Total Reward: {total_reward[-1]}')

# Metrics Visualization
plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
plt.plot(coverage, label='Coverage')
plt.title('Coverage Over Episodes')
plt.xlabel('Episode')
plt.ylabel('Coverage (%)')
plt.grid(True)
plt.legend()

plt.subplot(1, 3, 2)
plt.plot(survivor_detection, label='Survivor Detection', color='orange')
plt.title('Survivor Detection Over Episodes')
plt.xlabel('Episode')
plt.ylabel('Detection Rate (%)')
plt.grid(True)
plt.legend()

plt.subplot(1, 3, 3)
plt.plot(total_reward, label='Total Reward', color='green')
plt.title('Total Reward Over Episodes')
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.grid(True)
plt.legend()

plt.tight_layout()
plt.savefig('metrics_upgraded1000.png')
plt.close()

# Drone Path Animation
def animate_drone_paths(episode_idx):
    if episode_idx >= len(all_agent_paths):
        print(f"Episode {episode_idx + 1} data not available.")
        return None

    paths = all_agent_paths[episode_idx]
    roles = all_agent_roles[episode_idx]
    final_grid = env.grid_status.copy()
    final_confidence = env.confidence_map.copy()
    final_survivors = env.survivor_pos[:]

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xlim(-0.5, env.x_size - 0.5)
    ax.set_ylim(-0.5, env.y_size - 0.5)
    ax.set_xticks(range(env.x_size))
    ax.set_yticks(range(env.y_size))
    ax.grid(True)
    ax.set_title(f'Drone Paths (Episode {episode_idx + 1})')

    obs_x, obs_y = np.where(final_grid == env.OBS)
    ax.scatter(obs_x, obs_y, c='black', marker='s', s=100, label='Obstacles')
    heatmap = ax.imshow(final_confidence.T, cmap='hot', alpha=0.5, origin='lower',
                        extent=(-0.5, env.x_size - 0.5, -0.5, env.y_size - 0.5))
    survivor_x, survivor_y = zip(*final_survivors) if final_survivors else ([], [])
    survivors = ax.scatter(survivor_x, survivor_y, c='blue', marker='*', s=200, label='Survivors')
    colors = ['red', 'green']
    agent_plots = [ax.plot([], [], c=colors[i], marker='o', linestyle='-', label=f'Agent {i}')[0]
                   for i in range(n_agents)]
    role_texts = [ax.text(0, 0, '', fontsize=8, color=colors[i], ha='center', va='bottom')
                  for i in range(n_agents)]
    ax.legend()

    max_frames = max(len(p) for p in paths) if paths else 0

    def update(frame):
        artists = [heatmap, survivors]
        for i, plot in enumerate(agent_plots):
            if i < len(paths):
                if frame < len(paths[i]):
                    x, y = zip(*paths[i][:frame + 1]) if paths[i][:frame + 1] else ([], [])
                    plot.set_data(x, y)
                    if frame < len(roles[i]):
                        if x and y:
                            role_texts[i].set_position((x[-1], y[-1] + 0.2))
                            role_texts[i].set_text(roles[i][frame])
                        else:
                            role_texts[i].set_text('')
                    else:
                        role_texts[i].set_text('')
                    artists.append(plot)
                    artists.append(role_texts[i])
                else:
                    if paths[i]:
                        last_x, last_y = paths[i][-1]
                        plot.set_data(*zip(*paths[i]))
                        role_texts[i].set_position((last_x, last_y + 0.2))
                        if roles[i]:
                            role_texts[i].set_text(roles[i][-1])
                        else:
                            role_texts[i].set_text('')
                        artists.append(plot)
                        artists.append(role_texts[i])
        return artists

    ani = FuncAnimation(fig, update, frames=max_frames, interval=200, blit=True)
    plt.close(fig)
    return ani

# Save animation
ani = animate_drone_paths(train_episodes - 1)
if ani is not None:
    with open('animation_upgraded.html', 'w') as f:
        f.write(ani.to_jshtml())

# Save metrics for comparison
np.savez('metrics_upgraded.npz', coverage=coverage, survivor_detection=survivor_detection, total_reward=total_reward)

# Interactive widget
episode_selector = widgets.IntSlider(
    value=train_episodes - 1,
    min=0,
    max=train_episodes - 1,
    step=1,
    description='Episode:',
    continuous_update=False
)

output_widget = widgets.Output()

def on_episode_change(change):
    episode_idx = change['new']
    with output_widget:
        output_widget.clear_output(wait=True)
        ani = animate_drone_paths(episode_idx)
        if ani is not None:
            display(HTML(ani.to_jshtml()))

episode_selector.observe(on_episode_change, names='value')
display(episode_selector, output_widget)

with output_widget:
    ani = animate_drone_paths(episode_selector.value)
    if ani is not None:
        display(HTML(ani.to_jshtml()))

Training Episodes:   4%|▍         | 45/1000 [00:00<00:04, 226.94it/s]

//Episode 1// Epsilon: 0.900, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: -4.0
//Episode 21// Epsilon: 0.860, Steps: 7, Coverage (%): 0.750, Survivor Detection (%): 1.000, Total Reward: 23.0
//Episode 41// Epsilon: 0.820, Steps: 8, Coverage (%): 0.750, Survivor Detection (%): 0.000, Total Reward: -38.0


Training Episodes:  10%|█         | 105/1000 [00:00<00:03, 278.77it/s]

//Episode 61// Epsilon: 0.780, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0
//Episode 81// Epsilon: 0.740, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0
//Episode 101// Epsilon: 0.700, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 3.0
//Episode 121// Epsilon: 0.660, Steps: 8, Coverage (%): 0.500, Survivor Detection (%): 0.000, Total Reward: -45.0


Training Episodes:  19%|█▉        | 189/1000 [00:00<00:03, 266.33it/s]

//Episode 141// Epsilon: 0.620, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 3.0
//Episode 161// Epsilon: 0.580, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: -4.0
//Episode 181// Epsilon: 0.540, Steps: 4, Coverage (%): 0.500, Survivor Detection (%): 1.000, Total Reward: 34.0


Training Episodes:  25%|██▍       | 247/1000 [00:00<00:02, 267.11it/s]

//Episode 201// Epsilon: 0.500, Steps: 4, Coverage (%): 0.750, Survivor Detection (%): 1.000, Total Reward: 43.0
//Episode 221// Epsilon: 0.460, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 3.0
//Episode 241// Epsilon: 0.420, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  28%|██▊       | 275/1000 [00:01<00:02, 270.13it/s]

//Episode 261// Epsilon: 0.380, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 44.0
//Episode 281// Epsilon: 0.340, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: -4.0
//Episode 301// Epsilon: 0.300, Steps: 5, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 23.0


Training Episodes:  35%|███▍      | 347/1000 [00:01<00:03, 200.32it/s]

//Episode 321// Epsilon: 0.260, Steps: 2, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 47.0
//Episode 341// Epsilon: 0.220, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: -4.0
//Episode 361// Epsilon: 0.180, Steps: 2, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 47.0


Training Episodes:  38%|███▊      | 375/1000 [00:02<00:05, 106.77it/s]

//Episode 381// Epsilon: 0.140, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 44.0


Training Episodes:  40%|███▉      | 396/1000 [00:03<00:11, 51.75it/s] 

//Episode 401// Epsilon: 0.100, Steps: 2, Coverage (%): 0.500, Survivor Detection (%): 1.000, Total Reward: 51.0


Training Episodes:  42%|████▏     | 422/1000 [00:04<00:16, 34.47it/s]

//Episode 421// Epsilon: 0.100, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 51.0


Training Episodes:  44%|████▍     | 438/1000 [00:05<00:19, 28.17it/s]

//Episode 441// Epsilon: 0.100, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 3.0


Training Episodes:  46%|████▋     | 463/1000 [00:07<00:30, 17.40it/s]

//Episode 461// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: -4.0


Training Episodes:  48%|████▊     | 483/1000 [00:08<00:40, 12.65it/s]

//Episode 481// Epsilon: 0.100, Steps: 2, Coverage (%): 0.500, Survivor Detection (%): 1.000, Total Reward: 54.0


Training Episodes:  50%|█████     | 504/1000 [00:09<00:26, 18.72it/s]

//Episode 501// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  52%|█████▏    | 524/1000 [00:10<00:22, 20.84it/s]

//Episode 521// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  55%|█████▍    | 545/1000 [00:11<00:22, 20.66it/s]

//Episode 541// Epsilon: 0.100, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 3.0


Training Episodes:  56%|█████▋    | 563/1000 [00:12<00:21, 20.38it/s]

//Episode 561// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  58%|█████▊    | 585/1000 [00:13<00:19, 21.44it/s]

//Episode 581// Epsilon: 0.100, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 3.0


Training Episodes:  60%|██████    | 603/1000 [00:14<00:19, 20.57it/s]

//Episode 601// Epsilon: 0.100, Steps: 2, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 47.0


Training Episodes:  62%|██████▏   | 624/1000 [00:15<00:18, 19.94it/s]

//Episode 621// Epsilon: 0.100, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 3.0


Training Episodes:  64%|██████▍   | 643/1000 [00:17<00:30, 11.81it/s]

//Episode 641// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  66%|██████▋   | 663/1000 [00:18<00:20, 16.20it/s]

//Episode 661// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  68%|██████▊   | 682/1000 [00:19<00:21, 14.73it/s]

//Episode 681// Epsilon: 0.100, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 3.0


Training Episodes:  70%|███████   | 704/1000 [00:21<00:19, 15.04it/s]

//Episode 701// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  72%|███████▏  | 722/1000 [00:22<00:20, 13.61it/s]

//Episode 721// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  74%|███████▍  | 744/1000 [00:23<00:12, 21.19it/s]

//Episode 741// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  76%|███████▋  | 765/1000 [00:24<00:11, 20.37it/s]

//Episode 761// Epsilon: 0.100, Steps: 4, Coverage (%): 0.500, Survivor Detection (%): 1.000, Total Reward: 38.0


Training Episodes:  78%|███████▊  | 783/1000 [00:25<00:10, 21.00it/s]

//Episode 781// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  80%|████████  | 804/1000 [00:26<00:09, 20.54it/s]

//Episode 801// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  82%|████████▏ | 822/1000 [00:27<00:08, 21.21it/s]

//Episode 821// Epsilon: 0.100, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 3.0


Training Episodes:  84%|████████▍ | 843/1000 [00:28<00:09, 17.44it/s]

//Episode 841// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  86%|████████▋ | 864/1000 [00:29<00:07, 19.28it/s]

//Episode 861// Epsilon: 0.100, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 3.0


Training Episodes:  88%|████████▊ | 885/1000 [00:30<00:05, 19.84it/s]

//Episode 881// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  90%|█████████ | 904/1000 [00:31<00:04, 20.94it/s]

//Episode 901// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  92%|█████████▎| 925/1000 [00:32<00:03, 21.11it/s]

//Episode 921// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  94%|█████████▍| 942/1000 [00:33<00:03, 15.66it/s]

//Episode 941// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  96%|█████████▌| 962/1000 [00:35<00:02, 13.09it/s]

//Episode 961// Epsilon: 0.100, Steps: 2, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 47.0


Training Episodes:  98%|█████████▊| 984/1000 [00:36<00:00, 18.35it/s]

//Episode 981// Epsilon: 0.100, Steps: 3, Coverage (%): 0.500, Survivor Detection (%): 1.000, Total Reward: 42.0


Training Episodes: 100%|██████████| 1000/1000 [00:37<00:00, 26.69it/s]


IntSlider(value=999, continuous_update=False, description='Episode:', max=999)

Output()

In [11]:
import numpy as np
import collections
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import gym
from gym import spaces
import copy
import math
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

class MultiAgentActionSpace(list):
    def __init__(self, agents_action_space):
        for x in agents_action_space:
            assert isinstance(x, gym.spaces.Space)
        super().__init__(agents_action_space)
        self._agents_action_space = agents_action_space

    def sample(self):
        return [agent_action_space.sample() for agent_action_space in self._agents_action_space]

class MultiAgentObservationSpace(list):
    def __init__(self, agents_observation_space):
        for x in agents_observation_space:
            assert isinstance(x, gym.spaces.Space)
        super().__init__(agents_observation_space)
        self._agents_observation_space = agents_observation_space

    def sample(self):
        return [agent_observation_space.sample() for agent_observation_space in self._agents_observation_space]

class ReplayBuffer:
    def __init__(self, buffer_limit):
        self.buffer = collections.deque(maxlen=buffer_limit)

    def put(self, transition):
        self.buffer.append(transition)

    def sample(self, n):
        mini_batch = random.sample(self.buffer, n)
        s_lst, a_lst, r_lst, s_prime_lst, done_mask_lst = [], [], [], [], []

        for transition in mini_batch:
            s, a, r, s_prime, done = transition
            s_lst.append(s)
            a_lst.append(a)
            r_lst.append(r)
            s_prime_lst.append(s_prime)
            done_mask_lst.append((np.ones(len(done)) - done).tolist())

        return (torch.tensor(np.array(s_lst), dtype=torch.float),
                torch.tensor(np.array(a_lst), dtype=torch.long),
                torch.tensor(np.array(r_lst), dtype=torch.float),
                torch.tensor(np.array(s_prime_lst), dtype=torch.float),
                torch.tensor(np.array(done_mask_lst), dtype=torch.float))

    def size(self):
        return len(self.buffer)

class QNet(nn.Module):
    def __init__(self, observation_space, action_space):
        super(QNet, self).__init__()
        self.num_agents = len(observation_space)
        for agent_i in range(self.num_agents):
            n_obs = observation_space[agent_i].shape[0]
            setattr(self, f'agent_{agent_i}', nn.Sequential(
                nn.Linear(n_obs, 64),
                nn.ReLU(),
                nn.Linear(64, 32),
                nn.ReLU(),
                nn.Linear(32, action_space[agent_i].n)
            ))

    def forward(self, obs):
        q_values = [torch.empty(obs.shape[0], )] * self.num_agents
        for agent_i in range(self.num_agents):
            q_values[agent_i] = getattr(self, f'agent_{agent_i}')(obs[:, agent_i, :]).unsqueeze(1)
        return torch.cat(q_values, dim=1)

    def sample_action(self, obs, epsilon):
        out = self.forward(obs)
        actions = torch.empty((out.shape[0], out.shape[1],), dtype=torch.long)
        for agent_i in range(out.shape[1]):
            if random.random() < epsilon:
                actions[:, agent_i] = torch.randint(0, out.shape[2], (1,)).long()
            else:
                actions[:, agent_i] = out[:, agent_i].argmax().long()
        return actions

def train(q, q_target, memory, optimizer, gamma, batch_size, update_iter=10):
    for _ in range(update_iter):
        s, a, r, s_prime, done_mask = memory.sample(batch_size)
        q_out = q(s)
        a = a.unsqueeze(-1)
        q_a = q_out.gather(2, a).squeeze(-1)
        max_q_prime = q_target(s_prime).max(dim=2)[0]
        target = r + gamma * max_q_prime * done_mask
        loss = F.smooth_l1_loss(q_a, target.detach())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

class GridMultiAgent(gym.Env):
    metadata = {'render.modes': ['console']}
    UP, DOWN, LEFT, RIGHT, STAY = range(5)
    OOE, OBS, POI, MAP, AGT, SURVIVOR = -2, -1, 0, 1, 2, 3

    def __init__(self, x_size=2, y_size=2, fov_x=3, fov_y=3, n_agents=2, n_survivors=1):
        super().__init__()
        self.x_size = x_size
        self.y_size = y_size
        self.n_agents = n_agents
        self.idx_agents = list(range(n_agents))
        self.n_survivors = n_survivors
        self.fov_x = fov_x
        self.fov_y = fov_y

        self.confidence_map = np.zeros((x_size, y_size))
        self.agent_paths = [[] for _ in range(n_agents)]

        n_actions = 5
        self.action_space = MultiAgentActionSpace([spaces.Discrete(n_actions) for _ in range(n_agents)])
        obs_size = (fov_x * fov_y - 1) + 1
        self.obs_low = np.concatenate([np.ones(fov_x * fov_y - 1) * self.OOE, [0]], dtype=np.float32)
        self.obs_high = np.concatenate([np.ones(fov_x * fov_y - 1) * self.SURVIVOR, [1]], dtype=np.float32)
        self.observation_space = MultiAgentObservationSpace([
            spaces.Box(self.obs_low, self.obs_high, dtype=np.float32) for _ in range(n_agents)
        ])

        self.init_grid()
        self.init_agent()
        self.init_survivors()

    def init_grid(self):
        self.grid_status = np.zeros((self.x_size, self.y_size))
        self.grid_counts = np.tile(self.grid_status, (self.n_agents, 1, 1)).reshape(self.n_agents, self.x_size, self.y_size)
        self.n_poi = self.x_size * self.y_size - np.count_nonzero(self.grid_status)
        self.grid_agents_status = copy.deepcopy(self.grid_status)

    def init_agent(self):
        self.agent_pos = []
        self.agent_paths = [[] for _ in range(self.n_agents)]
        for i in range(self.n_agents):
            while True:
                x, y = random.randrange(0, self.x_size), random.randrange(0, self.y_size)
                if self.grid_agents_status[x, y] == self.POI:
                    self.agent_pos.append([x, y])
                    self.grid_agents_status[x, y] = self.AGT
                    self.agent_paths[i].append([x, y])
                    break

    def init_survivors(self):
        self.survivor_pos = []
        for _ in range(self.n_survivors):
            while True:
                x, y = random.randrange(0, self.x_size), random.randrange(0, self.y_size)
                if self.grid_status[x, y] == self.POI and [x, y] not in self.agent_pos:
                    self.survivor_pos.append([x, y])
                    self.grid_status[x, y] = self.SURVIVOR
                    self.confidence_map[x, y] = 1.0
                    break

    def grid_overlay(self):
        self.grid_agents_status = copy.deepcopy(self.grid_status)
        for i in range(self.n_agents):
            x, y = self.agent_pos[i]
            self.grid_agents_status[x, y] = self.AGT

    def get_rf_signal(self, agent_pos, survivor_pos):
        distance = np.sqrt((agent_pos[0] - survivor_pos[0])**2 + (agent_pos[1] - survivor_pos[1])**2)
        return min(1.0, 1.0 / (distance + 1)) if distance > 0 else 1.0

    def get_agent_obs(self):
        self.agent_obs = []
        for agent in range(self.n_agents):
            single_obs = np.ones((self.fov_x, self.fov_y)) * self.OOE
            x, y = self.agent_pos[agent]
            for i in range(self.fov_x):
                for j in range(self.fov_y):
                    obs_x = x + (i - 1)
                    obs_y = y + (j - 1)
                    if 0 <= obs_x < self.x_size and 0 <= obs_y < self.y_size:
                        single_obs[i][j] = self.grid_agents_status[obs_x][obs_y]
            single_obs_flat = single_obs.flatten()
            single_obs_flat = np.array([v for k, v in enumerate(single_obs_flat) if k != math.floor(self.fov_x * self.fov_y / 2)])
            rf_signal = max([self.get_rf_signal(self.agent_pos[agent], s) for s in self.survivor_pos], default=0.0)
            obs = np.concatenate([single_obs_flat, [rf_signal]])
            self.agent_obs.append(obs.astype(np.float32))
        return self.agent_obs

    def update_confidence(self):
        self.confidence_map *= 0.95
        for i in range(self.n_agents):
            x, y = self.agent_pos[i]
            rf_signal = max([self.get_rf_signal(self.agent_pos[i], s) for s in self.survivor_pos], default=0.0)
            self.confidence_map[x, y] = min(1.0, self.confidence_map[x, y] + rf_signal)

    def get_coverage(self):
        mapped_poi = np.count_nonzero(self.grid_status == self.MAP)
        return mapped_poi / self.n_poi

    def get_survivor_detection_rate(self):
        detected = sum(1 for s in self.survivor_pos if self.confidence_map[s[0], s[1]] > 0.8)
        return detected / self.n_survivors if self.n_survivors > 0 else 0.0

    def step(self, action, i):
        org_x, org_y = self.agent_pos[i]
        reward = 0
        action = int(action)
        if action == self.UP:
            self.agent_pos[i][1] += 1
        elif action == self.DOWN:
            self.agent_pos[i][1] -= 1
        elif action == self.LEFT:
            self.agent_pos[i][0] -= 1
        elif action == self.RIGHT:
            self.agent_pos[i][0] += 1
        elif action == self.STAY:
            reward = -1
        else:
            raise ValueError(f"Invalid action={action}")

        if (self.agent_pos[i][0] >= self.x_size or self.agent_pos[i][0] < 0 or
            self.agent_pos[i][1] >= self.y_size or self.agent_pos[i][1] < 0 or
            self.grid_status[self.agent_pos[i][0], self.agent_pos[i][1]] in [self.OBS, self.AGT]):
            self.agent_pos[i] = [org_x, org_y]
            self.grid_counts[i][org_x, org_y] += 1
            reward = -2
        else:
            prev_status = self.grid_status[self.agent_pos[i][0], self.agent_pos[i][1]]
            if prev_status == self.POI:
                self.grid_status[self.agent_pos[i][0], self.agent_pos[i][1]] = self.MAP
                self.grid_counts[i][self.agent_pos[i][0], self.agent_pos[i][1]] += 1
                reward = 5
            elif prev_status == self.SURVIVOR:
                self.grid_counts[i][self.agent_pos[i][0], self.agent_pos[i][1]] += 1
                reward = 50
            elif prev_status == self.MAP:
                self.grid_counts[i][self.agent_pos[i][0], self.agent_pos[i][1]] += 1
                reward = -1

        self.agent_paths[i].append(self.agent_pos[i][:])

        nearby_agents = sum(1 for j in range(self.n_agents) if j != i and
                           np.sqrt((self.agent_pos[i][0] - self.agent_pos[j][0])**2 +
                                   (self.agent_pos[i][1] - self.agent_pos[j][1])**2) < 2)
        if nearby_agents > 0:
            reward -= 2 * nearby_agents

        self.update_confidence()
        self.grid_overlay()

        mapped_poi = np.count_nonzero(self.grid_status == self.MAP)
        done = mapped_poi == self.n_poi or self.get_survivor_detection_rate() == 1.0
        info = {'survivor_detection_rate': self.get_survivor_detection_rate()}

        return self.get_agent_obs(), reward, done, info

    def reset(self):
        self.init_grid()
        self.init_agent()
        self.init_survivors()
        self.confidence_map = np.zeros((self.x_size, self.y_size))
        return self.get_agent_obs()

    def close(self):
        pass

# Training
size = 2
fov = 3
n_agents = 2
n_survivors = 1
train_episodes = 1000
max_steps = 8
batch_size = 32
gamma = 0.99
buffer_limit = 50000
log_interval = 20
max_epsilon = 0.9
min_epsilon = 0.1
warm_up_steps = 1000
update_iter = 10
lr = 0.001

env = GridMultiAgent(x_size=size, y_size=size, fov_x=fov, fov_y=fov, n_agents=n_agents, n_survivors=n_survivors)
memory = ReplayBuffer(buffer_limit)
q = QNet(env.observation_space, env.action_space)
q_target = QNet(env.observation_space, env.action_space)
q_target.load_state_dict(q.state_dict())
optimizer = optim.Adam(q.parameters(), lr=lr)

time_steps, epsilons, coverage, survivor_detection, total_reward, all_agent_paths = [], [], [], [], [], []

for episode in tqdm(range(train_episodes), desc="Training Episodes"):
    score = np.zeros(n_agents)
    state = env.reset()
    done = False
    episode_step = 0
    epsilon = max(min_epsilon, max_epsilon - (max_epsilon - min_epsilon) * (episode / (0.4 * train_episodes)))

    while not done:
        state_np = np.array(state, dtype=np.float32)
        action = q.sample_action(torch.tensor(state_np).unsqueeze(0), epsilon)[0].data.cpu().numpy()
        rewards = np.zeros(n_agents)
        next_state = state
        for agent_i in env.idx_agents:
            next_state, reward, done, info = env.step(action[agent_i], agent_i)
            rewards[agent_i] = reward
            if done:
                break
        memory.put((state, action, rewards, next_state, [done] * n_agents))
        score += rewards
        state = next_state
        episode_step += 1
        if episode_step >= max_steps or done:
            break

    if memory.size() > warm_up_steps:
        train(q, q_target, memory, optimizer, gamma, batch_size, update_iter)

    if episode % log_interval == 0:
        q_target.load_state_dict(q.state_dict())

    all_agent_paths.append([path[:] for path in env.agent_paths])
    time_steps.append(episode_step)
    epsilons.append(epsilon)
    coverage.append(env.get_coverage())
    survivor_detection.append(info['survivor_detection_rate'])
    total_reward.append(score.sum())

    if episode % log_interval == 0:
        print(f'//Episode {episode+1}// Epsilon: {epsilon:.3f}, Steps: {episode_step}, '
              f'Coverage (%): {coverage[-1]:.3f}, Survivor Detection (%): {survivor_detection[-1]:.3f}, '
              f'Total Reward: {total_reward[-1]}')

# Metrics Visualization
plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
plt.plot(coverage, label='Coverage')
plt.title('Coverage Over Episodes')
plt.xlabel('Episode')
plt.ylabel('Coverage (%)')
plt.grid(True)
plt.legend()

plt.subplot(1, 3, 2)
plt.plot(survivor_detection, label='Survivor Detection', color='orange')
plt.title('Survivor Detection Over Episodes')
plt.xlabel('Episode')
plt.ylabel('Detection Rate (%)')
plt.grid(True)
plt.legend()

plt.subplot(1, 3, 3)
plt.plot(total_reward, label='Total Reward', color='green')
plt.title('Total Reward Over Episodes')
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.grid(True)
plt.legend()

plt.tight_layout()
plt.savefig('metrics_simple1000.png')
plt.close()

# Drone Path Animation (Last Episode)
def animate_drone_paths(episode_idx=-1):
    paths = all_agent_paths[episode_idx]
    final_grid = env.grid_status
    final_confidence = env.confidence_map
    final_survivors = env.survivor_pos

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xlim(-0.5, env.x_size - 0.5)
    ax.set_ylim(-0.5, env.y_size - 0.5)
    ax.set_xticks(range(env.x_size))
    ax.set_yticks(range(env.y_size))
    ax.grid(True)
    ax.set_title(f'Drone Paths (Episode {episode_idx + 1})')

    obs_x, obs_y = np.where(final_grid == env.OBS)
    ax.scatter(obs_x, obs_y, c='black', marker='s', s=100, label='Obstacles')
    heatmap = ax.imshow(final_confidence.T, cmap='hot', alpha=0.5, origin='lower',
                        extent=(-0.5, env.x_size - 0.5, -0.5, env.y_size - 0.5))
    survivor_x, survivor_y = zip(*final_survivors) if final_survivors else ([], [])
    survivors = ax.scatter(survivor_x, survivor_y, c='blue', marker='*', s=200, label='Survivors')
    colors = ['red', 'green']
    agent_plots = [ax.plot([], [], c=colors[i], marker='o', linestyle='-', label=f'Agent {i}')[0]
                   for i in range(n_agents)]
    ax.legend()

    def update(frame):
        for i, plot in enumerate(agent_plots):
            if frame < len(paths[i]):
                x, y = zip(*paths[i][:frame + 1]) if paths[i] else ([], [])
                plot.set_data(x, y)
        return agent_plots + [heatmap, survivors]

    ani = FuncAnimation(fig, update, frames=min(len(p) for p in paths), interval=200, blit=True)
    plt.close(fig)
    return ani

# Save animation
ani = animate_drone_paths(-1)
with open('animation_simple.html', 'w') as f:
    f.write(ani.to_jshtml())

# Save metrics for comparison
np.savez('metrics_simple.npz', coverage=coverage, survivor_detection=survivor_detection, total_reward=total_reward)

Training Episodes:   2%|▏         | 16/1000 [00:00<00:06, 154.67it/s]

//Episode 1// Epsilon: 0.900, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 51.0
//Episode 21// Epsilon: 0.860, Steps: 2, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 47.0


Training Episodes:   5%|▍         | 46/1000 [00:00<00:07, 122.29it/s]

//Episode 41// Epsilon: 0.820, Steps: 5, Coverage (%): 0.500, Survivor Detection (%): 1.000, Total Reward: 31.0


Training Episodes:   7%|▋         | 70/1000 [00:00<00:11, 81.97it/s] 

//Episode 61// Epsilon: 0.780, Steps: 5, Coverage (%): 0.750, Survivor Detection (%): 1.000, Total Reward: 36.0


Training Episodes:   9%|▊         | 87/1000 [00:01<00:14, 64.63it/s]

//Episode 81// Epsilon: 0.740, Steps: 5, Coverage (%): 0.750, Survivor Detection (%): 1.000, Total Reward: 37.0


Training Episodes:  10%|█         | 105/1000 [00:01<00:12, 69.79it/s]

//Episode 101// Epsilon: 0.700, Steps: 3, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 39.0


Training Episodes:  12%|█▎        | 125/1000 [00:01<00:11, 74.99it/s]

//Episode 121// Epsilon: 0.660, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  15%|█▍        | 148/1000 [00:02<00:15, 55.45it/s]

//Episode 141// Epsilon: 0.620, Steps: 4, Coverage (%): 0.500, Survivor Detection (%): 1.000, Total Reward: 39.0


Training Episodes:  16%|█▋        | 164/1000 [00:02<00:15, 54.47it/s]

//Episode 161// Epsilon: 0.580, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  19%|█▉        | 189/1000 [00:02<00:16, 49.43it/s]

//Episode 181// Epsilon: 0.540, Steps: 5, Coverage (%): 0.750, Survivor Detection (%): 1.000, Total Reward: 39.0


Training Episodes:  22%|██▏       | 224/1000 [00:03<00:09, 84.68it/s]

//Episode 201// Epsilon: 0.500, Steps: 3, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 39.0
//Episode 221// Epsilon: 0.460, Steps: 8, Coverage (%): 0.750, Survivor Detection (%): 0.000, Total Reward: -39.0


Training Episodes:  25%|██▌       | 252/1000 [00:03<00:06, 109.61it/s]

//Episode 241// Epsilon: 0.420, Steps: 3, Coverage (%): 0.500, Survivor Detection (%): 1.000, Total Reward: 47.0
//Episode 261// Epsilon: 0.380, Steps: 2, Coverage (%): 0.500, Survivor Detection (%): 1.000, Total Reward: 51.0


Training Episodes:  28%|██▊       | 280/1000 [00:05<00:38, 18.91it/s]

//Episode 281// Epsilon: 0.340, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  30%|███       | 302/1000 [00:07<00:45, 15.38it/s]

//Episode 301// Epsilon: 0.300, Steps: 8, Coverage (%): 0.500, Survivor Detection (%): 1.000, Total Reward: 18.0


Training Episodes:  32%|███▎      | 325/1000 [00:08<00:40, 16.49it/s]

//Episode 321// Epsilon: 0.260, Steps: 2, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 43.0


Training Episodes:  34%|███▍      | 344/1000 [00:09<00:30, 21.17it/s]

//Episode 341// Epsilon: 0.220, Steps: 2, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 47.0


Training Episodes:  36%|███▋      | 365/1000 [00:10<00:29, 21.17it/s]

//Episode 361// Epsilon: 0.180, Steps: 2, Coverage (%): 0.500, Survivor Detection (%): 1.000, Total Reward: 51.0


Training Episodes:  38%|███▊      | 383/1000 [00:11<00:30, 20.34it/s]

//Episode 381// Epsilon: 0.140, Steps: 3, Coverage (%): 0.500, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  40%|████      | 404/1000 [00:12<00:29, 20.48it/s]

//Episode 401// Epsilon: 0.100, Steps: 2, Coverage (%): 0.500, Survivor Detection (%): 1.000, Total Reward: 51.0


Training Episodes:  42%|████▎     | 425/1000 [00:13<00:27, 20.92it/s]

//Episode 421// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  44%|████▍     | 443/1000 [00:14<00:27, 20.21it/s]

//Episode 441// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  46%|████▋     | 464/1000 [00:15<00:26, 20.43it/s]

//Episode 461// Epsilon: 0.100, Steps: 3, Coverage (%): 0.750, Survivor Detection (%): 1.000, Total Reward: 53.0


Training Episodes:  48%|████▊     | 485/1000 [00:16<00:23, 21.72it/s]

//Episode 481// Epsilon: 0.100, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 51.0


Training Episodes:  50%|█████     | 503/1000 [00:17<00:23, 21.36it/s]

//Episode 501// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  52%|█████▏    | 524/1000 [00:18<00:22, 20.96it/s]

//Episode 521// Epsilon: 0.100, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 51.0


Training Episodes:  54%|█████▍    | 542/1000 [00:19<00:33, 13.81it/s]

//Episode 541// Epsilon: 0.100, Steps: 2, Coverage (%): 0.500, Survivor Detection (%): 1.000, Total Reward: 54.0


Training Episodes:  56%|█████▌    | 562/1000 [00:21<00:33, 12.90it/s]

//Episode 561// Epsilon: 0.100, Steps: 2, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 47.0


Training Episodes:  58%|█████▊    | 584/1000 [00:22<00:24, 16.95it/s]

//Episode 581// Epsilon: 0.100, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 51.0


Training Episodes:  60%|██████    | 603/1000 [00:23<00:18, 21.46it/s]

//Episode 601// Epsilon: 0.100, Steps: 3, Coverage (%): 0.500, Survivor Detection (%): 1.000, Total Reward: 46.0


Training Episodes:  62%|██████▏   | 624/1000 [00:24<00:16, 22.70it/s]

//Episode 621// Epsilon: 0.100, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 51.0


Training Episodes:  64%|██████▍   | 645/1000 [00:25<00:16, 21.67it/s]

//Episode 641// Epsilon: 0.100, Steps: 2, Coverage (%): 0.500, Survivor Detection (%): 1.000, Total Reward: 50.0


Training Episodes:  66%|██████▋   | 663/1000 [00:26<00:14, 22.61it/s]

//Episode 661// Epsilon: 0.100, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 51.0


Training Episodes:  68%|██████▊   | 684/1000 [00:27<00:14, 21.75it/s]

//Episode 681// Epsilon: 0.100, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 51.0


Training Episodes:  70%|███████   | 705/1000 [00:28<00:13, 21.58it/s]

//Episode 701// Epsilon: 0.100, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 51.0


Training Episodes:  72%|███████▏  | 723/1000 [00:29<00:13, 21.09it/s]

//Episode 721// Epsilon: 0.100, Steps: 2, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 47.0


Training Episodes:  74%|███████▍  | 744/1000 [00:30<00:11, 21.39it/s]

//Episode 741// Epsilon: 0.100, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 51.0


Training Episodes:  76%|███████▋  | 765/1000 [00:31<00:11, 20.87it/s]

//Episode 761// Epsilon: 0.100, Steps: 2, Coverage (%): 0.750, Survivor Detection (%): 1.000, Total Reward: 57.0


Training Episodes:  78%|███████▊  | 783/1000 [00:31<00:10, 21.06it/s]

//Episode 781// Epsilon: 0.100, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 51.0


Training Episodes:  80%|████████  | 804/1000 [00:33<00:12, 15.29it/s]

//Episode 801// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  82%|████████▏ | 822/1000 [00:34<00:10, 16.23it/s]

//Episode 821// Epsilon: 0.100, Steps: 2, Coverage (%): 0.750, Survivor Detection (%): 1.000, Total Reward: 57.0


Training Episodes:  84%|████████▍ | 842/1000 [00:35<00:11, 13.72it/s]

//Episode 841// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  86%|████████▌ | 862/1000 [00:36<00:06, 21.42it/s]

//Episode 861// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  88%|████████▊ | 883/1000 [00:37<00:05, 21.52it/s]

//Episode 881// Epsilon: 0.100, Steps: 2, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  90%|█████████ | 904/1000 [00:38<00:04, 21.52it/s]

//Episode 901// Epsilon: 0.100, Steps: 3, Coverage (%): 0.750, Survivor Detection (%): 1.000, Total Reward: 54.0


Training Episodes:  92%|█████████▎| 925/1000 [00:39<00:03, 20.93it/s]

//Episode 921// Epsilon: 0.100, Steps: 1, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 51.0


Training Episodes:  94%|█████████▍| 943/1000 [00:40<00:02, 23.01it/s]

//Episode 941// Epsilon: 0.100, Steps: 1, Coverage (%): 0.000, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  96%|█████████▋| 964/1000 [00:41<00:01, 21.48it/s]

//Episode 961// Epsilon: 0.100, Steps: 3, Coverage (%): 0.500, Survivor Detection (%): 1.000, Total Reward: 48.0


Training Episodes:  98%|█████████▊| 985/1000 [00:42<00:00, 20.73it/s]

//Episode 981// Epsilon: 0.100, Steps: 2, Coverage (%): 0.500, Survivor Detection (%): 1.000, Total Reward: 50.0


Training Episodes: 100%|██████████| 1000/1000 [00:43<00:00, 23.11it/s]


In [7]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from IPython.display import display, HTML

# Load metrics from all implementations
metrics_original = np.load('metrics_original.npz')
metrics_upgraded = np.load('metrics_upgraded.npz')
metrics_simple = np.load('metrics_simple.npz')

# Extract metrics
coverage = {
    'Original': metrics_original['coverage'],
    'Upgraded': metrics_upgraded['coverage'],
    'Simple': metrics_simple['coverage']
}
survivor_detection = {
    'Original': metrics_original['survivor_detection'],
    'Upgraded': metrics_upgraded['survivor_detection'],
    'Simple': metrics_simple['survivor_detection']
}
total_reward = {
    'Original': metrics_original['total_reward'],
    'Upgraded': metrics_upgraded['total_reward'],
    'Simple': metrics_simple['total_reward']
}

# Create table
data = {
    'Implementation': ['Original', 'Upgraded', 'Simple'],
    'Avg. Coverage (%)': [np.mean(coverage['Original']), np.mean(coverage['Upgraded']), np.mean(coverage['Simple'])],
    'Avg. Survivor Detection (%)': [np.mean(survivor_detection['Original']), np.mean(survivor_detection['Upgraded']), np.mean(survivor_detection['Simple'])],
    'Avg. Total Reward': [np.mean(total_reward['Original']), np.mean(total_reward['Upgraded']), np.mean(total_reward['Simple'])]
}
df = pd.DataFrame(data)
df = df.round(2)
table_html = df.to_html(index=False, classes='table table-striped', justify='center')

# Plot metrics
plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
for key in coverage:
    plt.plot(coverage[key], label=key)
plt.title('Coverage Over Episodes')
plt.xlabel('Episode')
plt.ylabel('Coverage (%)')
plt.grid(True)
plt.legend()

plt.subplot(1, 3, 2)
for key in survivor_detection:
    plt.plot(survivor_detection[key], label=key)
plt.title('Survivor Detection Over Episodes')
plt.xlabel('Episode')
plt.ylabel('Detection Rate (%)')
plt.grid(True)
plt.legend()

plt.subplot(1, 3, 3)
for key in total_reward:
    plt.plot(total_reward[key], label=key)
plt.title('Total Reward Over Episodes')
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.grid(True)
plt.legend()

plt.tight_layout()
plt.savefig('comparison_metrics.png')
plt.close()

# Display table
display(HTML(table_html))

# Instructions to view animations
print("To view animations, open the following files in a browser:")
print("- Original: animation_original.html")
print("- Upgraded: animation_upgraded.html")
print("- Simple: animation_simple.html")

Implementation,Avg. Coverage (%),Avg. Survivor Detection (%),Avg. Total Reward
Original,0.16,0.65,-10.38
Upgraded,0.25,0.68,-4.97
Simple,0.34,0.6,5.14


To view animations, open the following files in a browser:
- Original: animation_original.html
- Upgraded: animation_upgraded.html
- Simple: animation_simple.html


In [13]:
!pip install torch-geometric
!pip install torch-scatter torch-sparse torch-cluster

Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m20.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-geometric
Successfully installed torch-geometric-2.6.1
Collecting torch-scatter
  Downloading torch_scatter-2.1.2.tar.gz (108 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m108.0/108.0 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting torch-sparse
  Downloading torch_sparse-0.6.18.tar.gz (209 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21

In [14]:
!pip cache purge
!pip install -q torch-geometric torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
!pip install -q numpy torch gym tqdm matplotlib ipython

Files removed: 19
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.9/10.9 MB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.1/5.1 MB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.4/3.4 MB[0m [31m10.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m989.8/989.8 kB[0m [31m14.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m79.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m26.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m32.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━

In [15]:
# Dependencies should be installed (torch-geometric, torch-scatter, torch-sparse, torch-cluster).
# If session reset, reinstall with:
# !pip cache purge
# !pip install -q torch-geometric torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
# !pip install -q numpy torch gym tqdm matplotlib ipython

import collections
import random
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from gym import spaces
import gym
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
from torch_geometric.nn import GATConv
from torch_geometric.data import Data, Batch
import os
import zipfile

# Verify environment
print(f"PyTorch Version: {torch.__version__}, CUDA Available: {torch.cuda.is_available()}")
print("torch-geometric imported successfully!")

class MultiAgentActionSpace(list):
    def __init__(self, agents_action_space):
        for x in agents_action_space:
            assert isinstance(x, gym.spaces.Space)
        super().__init__(agents_action_space)
        self._agents_action_space = agents_action_space

    def sample(self):
        return [agent_action_space.sample() for agent_action_space in self._agents_action_space]

class MultiAgentObservationSpace(list):
    def __init__(self, agents_observation_space):
        for x in agents_observation_space:
            assert isinstance(x, gym.spaces.Space)
        super().__init__(agents_observation_space)
        self._agents_observation_space = agents_observation_space

    def sample(self):
        return [agent_observation_space.sample() for agent_observation_space in self._agents_observation_space]

class ReplayBuffer:
    def __init__(self, buffer_limit):
        self.buffer = collections.deque(maxlen=buffer_limit)

    def put(self, transition):
        if len(transition) != 7:
            raise ValueError(f"Transition tuple must have exactly 7 elements, got {len(transition)}")
        self.buffer.append(transition)

    def sample(self, n):
        mini_batch = random.sample(self.buffer, n)
        s_lst, a_lst, r_lst, s_prime_lst, done_mask_lst, graph_s_lst, graph_s_prime_lst = [], [], [], [], [], [], []

        for transition in mini_batch:
            s, a, r, s_prime, done, graph_s, graph_s_prime = transition
            s_lst.append(s)
            a_lst.append(a)
            r_lst.append(r)
            s_prime_lst.append(s_prime)
            done_mask_lst.append((np.ones(len(done)) - done).tolist())
            graph_s_lst.append(graph_s)
            graph_s_prime_lst.append(graph_s_prime)

        return (torch.tensor(np.array(s_lst), dtype=torch.float),
                torch.tensor(np.array(a_lst), dtype=torch.long),
                torch.tensor(np.array(r_lst), dtype=torch.float),
                torch.tensor(np.array(s_prime_lst), dtype=torch.float),
                torch.tensor(np.array(done_mask_lst), dtype=torch.float),
                graph_s_lst,
                graph_s_prime_lst)

    def size(self):
        return len(self.buffer)

class QNet(nn.Module):
    def __init__(self, observation_space, action_space, fov_x=3, fov_y=3):
        super(QNet, self).__init__()
        self.num_agents = len(observation_space)
        self.fov_x, self.fov_y = fov_x, fov_y
        self.n_actions = action_space[0].n
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * fov_x * fov_y, 64),
            nn.ReLU()
        )
        self.non_spatial_size = 1 + 3 + 3
        self.fc_non_spatial = nn.Sequential(
            nn.Linear(self.non_spatial_size, 16),
            nn.ReLU()
        )
        self.fc_shared = nn.Sequential(
            nn.Linear(64 + 16, 64),
            nn.ReLU()
        )
        self.heads = nn.ModuleList([nn.Linear(64, self.n_actions) for _ in range(self.num_agents)])

    def forward(self, obs):
        batch_size = obs.shape[0]
        spatial_obs = torch.zeros(batch_size, self.num_agents, 1, self.fov_x, self.fov_y, device=obs.device)
        for agent_i in range(self.num_agents):
            grid_status = obs[:, agent_i, :self.fov_x * fov_y].reshape(-1, self.fov_x, self.fov_y)
            spatial_obs[:, agent_i, 0] = grid_status
        spatial_out = torch.zeros(batch_size, self.num_agents, 64, device=obs.device)
        for agent_i in range(self.num_agents):
            spatial_out[:, agent_i] = self.cnn(spatial_obs[:, agent_i])
        non_spatial = obs[:, :, -self.non_spatial_size:]
        non_spatial_out = self.fc_non_spatial(non_spatial.view(-1, self.non_spatial_size)).view(batch_size, self.num_agents, 16)
        combined = torch.cat([spatial_out, non_spatial_out], dim=2)
        shared_out = self.fc_shared(combined)
        q_values = [self.heads[agent_i](shared_out[:, agent_i]).unsqueeze(1) for agent_i in range(self.num_agents)]
        return torch.cat(q_values, dim=1)

    def sample_action(self, obs, epsilon, roles, grid_status, agent_pos, x_size, y_size, stuck_counts, visited_counts):
        out = self.forward(obs)
        actions = torch.zeros((out.shape[1],), dtype=torch.long)
        exploration_bonus = 50.0
        scout_epsilon_boost = 0.4
        for agent_i in range(out.shape[1]):
            agent_epsilon = epsilon
            if roles[agent_i] == 'scout' and stuck_counts[agent_i] > 2:
                agent_epsilon = min(1.0, epsilon + scout_epsilon_boost)
            if random.random() < agent_epsilon:
                if roles[agent_i] == 'relay':
                    actions[agent_i] = 4  # STAY
                else:
                    valid_actions = []
                    poi_actions = []
                    low_visit_actions = []
                    x, y = agent_pos[agent_i]
                    for a in range(4):  # UP, DOWN, LEFT, RIGHT
                        new_x, new_y = x, y
                        if a == 0: new_y += 1
                        elif a == 1: new_y -= 1
                        elif a == 2: new_x -= 1
                        elif a == 3: new_x += 1
                        if 0 <= new_x < x_size and 0 <= new_y < y_size and grid_status[new_x, new_y] in [0, 3]:
                            valid_actions.append(a)
                            if grid_status[new_x, new_y] == 0:
                                poi_actions.append(a)
                            if visited_counts[new_x, new_y] < 2:
                                low_visit_actions.append(a)
                    actions[agent_i] = random.choice(low_visit_actions or poi_actions or valid_actions) if valid_actions else 4
            else:
                q_values = out[0, agent_i].clone()
                x, y = agent_pos[agent_i]
                max_neighbor_rf = obs[0, agent_i, -2].item()
                for a in range(4):  # UP, DOWN, LEFT, RIGHT
                    new_x, new_y = x, y
                    if a == 0: new_y += 1
                    elif a == 1: new_y -= 1
                    elif a == 2: new_x -= 1
                    elif a == 3: new_x += 1
                    if 0 <= new_x < x_size and 0 <= new_y < y_size:
                        if grid_status[new_x, new_y] == 0:
                            q_values[a] += exploration_bonus
                        elif grid_status[new_x, new_y] == 3:
                            q_values[a] += exploration_bonus * (2 + max_neighbor_rf)
                        if visited_counts[new_x, new_y] < 2:
                            q_values[a] += exploration_bonus * 0.5
                actions[agent_i] = q_values.argmax().item()
        return actions

class QCentralGNN(nn.Module):
    def __init__(self, node_feature_size, n_agents, n_actions):
        super().__init__()
        self.n_agents = n_agents
        self.n_actions = n_actions
        self.gat1 = GATConv(node_feature_size, 64, heads=4, concat=True)
        self.gat2 = GATConv(64 * 4, 64, heads=1, concat=True)
        self.fc = nn.Linear(64, n_actions)

    def forward(self, graph_data):
        x, edge_index = graph_data.x, graph_data.edge_index
        x = F.relu(self.gat1(x, edge_index))
        x = F.relu(self.gat2(x, edge_index))
        x = self.fc(x)
        return x.view(-1, self.n_agents, self.n_actions)

class GridMultiAgent(gym.Env):
    metadata = {'render.modes': ['console']}
    UP, DOWN, LEFT, RIGHT, STAY = range(5)
    OOE, OBS, POI, MAP, AGT, SURVIVOR = -2, -1, 0, 1, 2, 3

    def __init__(self, x_size=2, y_size=2, fov_x=3, fov_y=3, n_agents=2, n_survivors=1, comm_range=2.0):
        super().__init__()
        self.x_size = x_size
        self.y_size = y_size
        self.n_agents = n_agents
        self.idx_agents = list(range(n_agents))
        self.n_survivors = n_survivors
        self.fov_x = fov_x
        self.fov_y = fov_y
        self.comm_range = comm_range
        self.max_relays = n_agents // 2
        self.agent_roles = ['scout'] * n_agents
        self.stuck_counts = [0] * n_agents
        self.visited_counts = np.zeros((x_size, y_size))
        self.survivor_pos = []
        self.agent_paths = [[] for _ in range(n_agents)]
        self.grid_status_history = []
        self.adjacency_matrix = np.zeros((n_agents, n_agents))
        self.detection_map = np.zeros((x_size, y_size))
        n_actions = 5
        self.action_space = MultiAgentActionSpace([spaces.Discrete(n_actions) for _ in range(n_agents)])
        obs_size = (fov_x * fov_y) + 1 + 3 + 3
        self.obs_low = np.concatenate([np.ones(fov_x * fov_y) * self.OOE, [0, 0, 0, 0, 0, 0, 0]], dtype=np.float32)
        self.obs_high = np.concatenate([np.ones(fov_x * fov_y) * self.SURVIVOR, [1, 1, 1, 1, n_agents, 1, 3]], dtype=np.float32)
        self.observation_space = MultiAgentObservationSpace([
            spaces.Box(self.obs_low, self.obs_high, dtype=np.float32) for _ in range(n_agents)
        ])
        self.init_grid()
        self.init_agent()
        self.init_survivors()

    def init_grid(self):
        self.grid_status = np.zeros((self.x_size, self.y_size))
        self.grid_counts = np.tile(self.grid_status, (self.n_agents, 1, 1)).reshape(self.n_agents, self.x_size, self.y_size)
        self.n_poi = self.x_size * self.y_size - np.count_nonzero(self.grid_status)
        self.grid_agents_status = copy.deepcopy(self.grid_status)
        self.grid_status_history = [self.grid_status.copy()]
        self.visited_counts = np.zeros((self.x_size, self.y_size))

    def init_agent(self):
        self.agent_pos = []
        self.agent_paths = [[] for _ in range(self.n_agents)]
        for i in range(self.n_agents):
            while True:
                x, y = random.randrange(0, self.x_size), random.randrange(0, self.y_size)
                if self.grid_agents_status[x, y] == self.POI:
                    self.agent_pos.append([x, y])
                    self.grid_agents_status[x, y] = self.AGT
                    self.agent_paths[i].append([x, y])
                    self.visited_counts[x, y] += 1
                    break
        self.stuck_counts = [0] * n_agents

    def init_survivors(self):
        self.survivor_pos = []
        for _ in range(self.n_survivors):
            while True:
                x, y = random.randrange(0, self.x_size), random.randrange(0, self.y_size)
                if self.grid_status[x, y] == self.POI and [x, y] not in self.agent_pos:
                    self.survivor_pos.append([x, y])
                    self.grid_status[x, y] = self.SURVIVOR
                    break

    def compute_network(self):
        self.adjacency_matrix = np.zeros((self.n_agents, self.n_agents))
        for i in range(self.n_agents):
            for j in range(i + 1, self.n_agents):
                dist = np.sqrt((self.agent_pos[i][0] - self.agent_pos[j][0])**2 +
                               (self.agent_pos[i][1] - self.agent_pos[j][1])**2)
                if dist <= self.comm_range:
                    self.adjacency_matrix[i, j] = 1
                    self.adjacency_matrix[j, i] = 1

    def get_neighbor_messages(self, agent_idx):
        neighbors = np.where(self.adjacency_matrix[agent_idx] == 1)[0]
        n_neighbors = len(neighbors)
        max_rf = 0.0
        neighbor_role = 0
        for neighbor in neighbors:
            rf_signals = [self.get_rf_signal(self.agent_pos[neighbor], s) for s in self.survivor_pos]
            max_rf = max(max_rf, max(rf_signals, default=0.0))
            role_map = {'scout': 0, 'verification': 1, 'relay': 2}
            neighbor_role = max(neighbor_role, role_map[self.agent_roles[neighbor]])
        return n_neighbors, max_rf, neighbor_role

    def grid_overlay(self):
        self.grid_agents_status = copy.deepcopy(self.grid_status)
        for i in range(self.n_agents):
            x, y = self.agent_pos[i]
            self.grid_agents_status[x, y] = self.AGT

    def get_rf_signal(self, agent_pos, survivor_pos):
        distance = np.sqrt((agent_pos[0] - survivor_pos[0])**2 + (agent_pos[1] - survivor_pos[1])**2)
        return min(1.0, 1.0 / (distance + 0.5))

    def get_graph_state(self):
        agent_obs = self.get_agent_obs()
        node_features = np.array(agent_obs, dtype=np.float32)
        edge_index = []
        for i in range(self.n_agents):
            for j in range(self.n_agents):
                if self.adjacency_matrix[i, j] == 1:
                    edge_index.append([i, j])
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous() if edge_index else torch.empty((2, 0), dtype=torch.long)
        graph_data = Data(
            x=torch.tensor(node_features, dtype=torch.float),
            edge_index=edge_index
        )
        return graph_data

    def get_agent_obs(self):
        self.compute_network()
        self.agent_obs = []
        for agent in range(self.n_agents):
            single_obs = np.ones((self.fov_x, self.fov_y)) * self.OOE
            x, y = self.agent_pos[agent]
            for s in self.survivor_pos:
                rf = self.get_rf_signal([x, y], s)
                if rf > 0.6:
                    sx, sy = s
                    self.detection_map[sx, sy] += 1
            for i in range(self.fov_x):
                for j in range(self.fov_y):
                    obs_x = x + (i - self.fov_x // 2)
                    obs_y = y + (j - self.fov_y // 2)
                    if 0 <= obs_x < self.x_size and 0 <= obs_y < self.y_size:
                        single_obs[i][j] = self.grid_agents_status[obs_x][obs_y]
            single_obs_flat = single_obs.flatten()
            rf_signals = [self.get_rf_signal(self.agent_pos[agent], s) for s in self.survivor_pos]
            rf_signal = max(rf_signals, default=0.0)
            role_encoding = {'scout': [1, 0, 0], 'verification': [0, 1, 0], 'relay': [0, 0, 1]}
            role_vec = role_encoding[self.agent_roles[agent]]
            n_neighbors, max_neighbor_rf, neighbor_role = self.get_neighbor_messages(agent)
            obs = np.concatenate([single_obs_flat, [rf_signal], role_vec, [n_neighbors, max_neighbor_rf, neighbor_role]])
            self.agent_obs.append(obs.astype(np.float32))
        return self.agent_obs

    def update_roles(self):
        relay_count = sum(1 for r in self.agent_roles if r == 'relay')
        detected_survivors = sum(1 for s in self.survivor_pos if
                                 max(self.get_rf_signal([ax, ay], s) for ax, ay in self.agent_pos) > 0.6)
        for i in range(self.n_agents):
            rf_signals = [self.get_rf_signal(self.agent_pos[i], s) for s in self.survivor_pos]
            max_rf = max(rf_signals, default=0.0)
            n_neighbors, max_neighbor_rf, _ = self.get_neighbor_messages(i)
            if (max_rf > 0.6 or max_neighbor_rf > 0.6) and self.agent_roles[i] != 'verification':
                self.agent_roles[i] = 'verification'
            elif detected_survivors < self.n_survivors and self.agent_roles[i] != 'relay' and relay_count < self.max_relays:
                self.agent_roles[i] = 'relay'
                relay_count += 1
            elif detected_survivors < self.n_survivors and self.agent_roles[i] != 'scout':
                self.agent_roles[i] = 'scout'

    def get_coverage(self):
        mapped_poi = np.count_nonzero(self.grid_status == self.MAP)
        return mapped_poi / self.n_poi if self.n_poi > 0 else 0.0

    def get_survivor_detection_rate(self):
        detected = sum(1 for s in self.survivor_pos if
                       max(self.get_rf_signal([ax, ay], s) for ax, ay in self.agent_pos) > 0.6)
        return detected / self.n_survivors if self.n_survivors > 0 else 0.0

    def step(self, action, i):
        org_x, org_y = self.agent_pos[i]
        reward = 0
        action = int(action)
        new_x, new_y = org_x, org_y
        if action == self.UP:
            new_y += 1
        elif action == self.DOWN:
            new_y -= 1
        elif action == self.LEFT:
            new_x -= 1
        elif action == self.RIGHT:
            new_x += 1
        elif action == self.STAY:
            rf_signals = [self.get_rf_signal(self.agent_pos[i], s) for s in self.survivor_pos]
            max_rf = max(rf_signals, default=0.0)
            reward = 5 if self.agent_roles[i] == 'relay' and max_rf > 0.5 else -2
        else:
            raise ValueError(f"Invalid action={action}")
        if (0 <= new_x < self.x_size and 0 <= new_y < self.y_size and
            self.grid_status[new_x, new_y] not in [self.OBS]):
            collision = False
            for j in range(self.n_agents):
                if i != j and self.agent_pos[j][0] == new_x and self.agent_pos[j][1] == new_y:
                    collision = True
                    break
            if not collision:
                self.agent_pos[i] = [new_x, new_y]
                self.visited_counts[new_x, new_y] += 1
                prev_status = self.grid_status[new_x, new_y]
                if prev_status == self.POI:
                    self.grid_status[new_x, new_y] = self.MAP
                    self.grid_counts[i][new_x, new_y] += 1
                    reward = 15
                elif prev_status == self.SURVIVOR:
                    self.grid_counts[i][new_x, new_y] += 1
                    reward = 100
                    detected = self.get_survivor_detection_rate() * self.n_survivors
                    reward += 50 * detected
                elif prev_status == self.MAP:
                    self.grid_counts[i][new_x, new_y] += 1
                    reward = -1
            else:
                self.grid_counts[i][org_x, org_y] += 1
                reward = -5
        else:
            self.grid_counts[i][org_x, org_y] += 1
            reward = -2
        self.agent_paths[i].append(self.agent_pos[i][:])
        if org_x == self.agent_pos[i][0] and org_y == self.agent_pos[i][1]:
            self.stuck_counts[i] += 1
        else:
            self.stuck_counts[i] = 0
        self.update_roles()
        self.grid_overlay()
        self.grid_status_history.append(self.grid_status.copy())
        mapped_poi = np.count_nonzero(self.grid_status == self.MAP)
        done = mapped_poi >= self.n_poi * 0.95 or self.get_survivor_detection_rate() >= 1.0
        info = {'survivor_detection_rate': self.get_survivor_detection_rate(),
                'coverage': self.get_coverage(),
                'adjacency_matrix': self.adjacency_matrix.copy()}
        return self.get_agent_obs(), reward, done, info

    def reset(self):
        self.init_grid()
        self.init_agent()
        self.init_survivors()
        self.agent_roles = ['scout'] * self.n_agents
        self.grid_status_history = [self.grid_status.copy()]
        self.adjacency_matrix = np.zeros((self.n_agents, self.n_agents))
        self.detection_map = np.zeros((self.x_size, self.y_size))
        while any(self.grid_status[pos[0], pos[1]] in [self.OBS, self.SURVIVOR] for pos in self.agent_pos):
            self.init_grid()
            self.init_agent()
            self.init_survivors()
        self.grid_overlay()
        return self.get_agent_obs()

    def close(self):
        pass

def plot_grid(env, agent_paths, agent_roles, episode, info_history):
    fig, ax = plt.subplots(figsize=(8, 8))
    grid = env.grid_status_history[-1]
    agent_pos = [path[-1] for path in agent_paths]
    agent_roles_final = [roles[-1] for roles in agent_roles]
    display_grid = np.zeros((env.x_size, env.y_size))
    display_grid[grid == env.OBS] = -1
    display_grid[grid == env.MAP] = 0.5
    display_grid[grid == env.SURVIVOR] = 2
    colors = plt.cm.RdYlBu(np.linspace(0, 1, 256))
    colors[int(-1 + 127)] = [0, 0, 0, 1]
    colors[int(0.5 + 127)] = [1, 0.8, 0.8, 1]
    colors[int(2 + 127)] = [1, 1, 1, 1]
    custom_cmap = plt.cm.colors.ListedColormap(colors)
    ax.imshow(display_grid, cmap=custom_cmap, vmin=-1, vmax=3, origin='lower')
    for sx, sy in env.survivor_pos:
        ax.plot(sy, sx, 'b*', markersize=15, label='Survivor' if 'Survivor' not in [l.get_label() for l in ax.get_legend_handles_labels()[1]] else '')
    colors = ['r', 'g']
    for i in range(env.n_agents):
        path = np.array(agent_paths[i])
        ax.plot(path[:, 1], path[:, 0], color=colors[i], linewidth=2, label=f'Agent {i}')
        ax.plot(path[-1, 1], path[-1, 0], 'o', color=colors[i], markersize=8)
        ax.text(path[-1, 1] + 0.1, path[-1, 0], agent_roles_final[i], fontsize=10, color=colors[i])
    ax.set_xticks(np.arange(env.x_size))
    ax.set_yticks(np.arange(env.y_size))
    ax.set_title(f'Drone Paths [Episode {episode + 1}]')
    ax.grid(True)
    legend_elements = [
        Line2D([0], [0], color='k', marker='s', linestyle='None', markersize=10, label='Obstacles'),
        plt.plot([], [], 'b*', markersize=10, label='Survivors')[0],
        Line2D([0], [0], color='r', lw=2, label='Agent 0'),
        Line2D([0], [0], color='g', lw=2, label='Agent 1'),
    ]
    ax.legend(handles=legend_elements, loc='upper right')
    plt.savefig(f'drone_paths_episode_{episode + 1}.png')
    plt.close()

def plot_detection_heatmap(env, episode):
    fig, ax = plt.subplots(figsize=(8, 8))
    heatmap = ax.imshow(env.detection_map, cmap='hot', interpolation='nearest', origin='lower')
    ax.set_title(f'Survivor Detection Heatmap [Episode {episode + 1}]')
    ax.set_xticks(np.arange(env.x_size))
    ax.set_yticks(np.arange(env.y_size))
    ax.grid(True)
    plt.colorbar(heatmap, ax=ax, label='Detection Count')
    plt.savefig(f'survivor_detection_heatmap_episode_{episode + 1}.png')
    plt.close()

def plot_gnn_graph(env, agent_paths, info_history, episode):
    fig, ax = plt.subplots(figsize=(8, 8))
    agent_pos = [path[-1] for path in agent_paths]
    adjacency_matrix = info_history[-1]['adjacency_matrix']
    colors = ['r', 'g']
    for i, (x, y) in enumerate(agent_pos):
        ax.plot(y, x, 'o', color=colors[i], markersize=10, label=f'Agent {i}')
    for i in range(env.n_agents):
        for j in range(i + 1, env.n_agents):
            if adjacency_matrix[i, j] == 1:
                ax.plot([agent_pos[i][1], agent_pos[j][1]],
                        [agent_pos[i][0], agent_pos[j][0]], 'k-', alpha=0.5)
    ax.set_xlim(-0.5, env.x_size - 0.5)
    ax.set_ylim(-0.5, env.y_size - 0.5)
    ax.set_xticks(np.arange(env.x_size))
    ax.set_yticks(np.arange(env.y_size))
    ax.set_title(f'GNN Communication Graph [Episode {episode + 1}]')
    ax.grid(True)
    ax.legend()
    plt.savefig(f'gnn_communication_graph_episode_{episode + 1}.png')
    plt.close()

def animate_drone_paths(env, agent_paths, agent_roles, episode):
    paths = agent_paths
    roles = agent_roles
    grid = env.grid_status_history[-1]
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xlim(-0.5, env.x_size - 0.5)
    ax.set_ylim(-0.5, env.y_size - 0.5)
    ax.set_xticks(range(env.x_size))
    ax.set_yticks(range(env.y_size))
    ax.grid(True)
    ax.set_title(f'Drone Paths Animation [Episode {episode + 1}]')
    display_grid = np.zeros((env.x_size, env.y_size))
    display_grid[grid == env.OBS] = -1
    display_grid[grid == env.MAP] = 0.5
    display_grid[grid == env.SURVIVOR] = 2
    colors = plt.cm.RdYlBu(np.linspace(0, 1, 256))
    colors[int(-1 + 127)] = [0, 0, 0, 1]
    colors[int(0.5 + 127)] = [1, 0.8, 0.8, 1]
    colors[int(2 + 127)] = [1, 1, 1, 1]
    custom_cmap = plt.cm.colors.ListedColormap(colors)
    ax.imshow(display_grid, cmap=custom_cmap, vmin=-1, vmax=3, origin='lower')
    survivor_x, survivor_y = zip(*env.survivor_pos) if env.survivor_pos else ([], [])
    survivors = ax.scatter(survivor_y, survivor_x, c='blue', marker='*', s=200, label='Survivors')
    colors = ['red', 'green']
    agent_plots = [ax.plot([], [], c=colors[i], marker='o', linestyle='-', label=f'Agent {i}')[0]
                   for i in range(env.n_agents)]
    role_texts = [ax.text(0, 0, '', fontsize=8, color=colors[i], ha='center', va='bottom')
                  for i in range(env.n_agents)]
    ax.legend()
    max_frames = max(len(p) for p in paths) if paths else 0

    def update(frame):
        artists = [survivors]
        for i, plot in enumerate(agent_plots):
            if frame < len(paths[i]):
                x, y = zip(*paths[i][:frame + 1]) if paths[i][:frame + 1] else ([], [])
                plot.set_data(y, x)
                if frame < len(roles[i]):
                    if x and y:
                        role_texts[i].set_position((y[-1], x[-1] + 0.1))
                        role_texts[i].set_text(roles[i][frame])
                    else:
                        role_texts[i].set_text('')
                artists.append(plot)
                artists.append(role_texts[i])
        return artists

    ani = FuncAnimation(fig, update, frames=max_frames, interval=200, blit=True)
    plt.close(fig)
    return ani

def download_visualizations(episode):
    files = [
        f'drone_paths_episode_{episode + 1}.png',
        f'survivor_detection_heatmap_episode_{episode + 1}.png',
        f'gnn_communication_graph_episode_{episode + 1}.png',
        f'animation_gnn_episode_{episode + 1}.html',
        'swarm_drone_gnn_performance.png'
    ]
    zip_filename = 'gnn_visualizations.zip'
    with zipfile.ZipFile(zip_filename, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for file in files:
            if os.path.exists(file):
                zipf.write(file)
    from google.colab import files
    files.download(zip_filename)

def train(q, q_target, q_central, q_central_target, memory, optimizer, optimizer_central, gamma, batch_size, update_iter=5):
    for _ in range(update_iter):
        s, a, r, s_prime, done_mask, graph_s, graph_s_prime = memory.sample(batch_size)
        q_out = q(s)
        a = a.view(-1, q_out.shape[1], 1)
        q_a = q_out.gather(2, a).squeeze(-1)
        for g in graph_s_prime:
            g.x = g.x.to(s.device)
            g.edge_index = g.edge_index.to(s.device)
        batch_s_prime = Batch.from_data_list(graph_s_prime).to(s.device)
        with torch.no_grad():
            next_actions = q(s_prime).argmax(dim=2, keepdim=True)
            q_values = q_central_target(batch_s_prime)
            q_selected = q_values.gather(2, next_actions).squeeze(-1)
            max_q_prime = q_selected
            target = r + gamma * max_q_prime * done_mask
        loss_actor = F.smooth_l1_loss(q_a, target.detach())
        optimizer.zero_grad()
        loss_actor.backward()
        optimizer.step()
        for g in graph_s:
            g.x = g.x.to(s.device)
            g.edge_index = g.edge_index.to(s.device)
        batch_s = Batch.from_data_list(graph_s).to(s.device)
        q_central_out = q_central(batch_s)
        q_central_a = q_central_out.gather(2, a).squeeze(-1)
        loss_central = F.smooth_l1_loss(q_central_a, target.detach())
        optimizer_central.zero_grad()
        loss_central.backward()
        optimizer_central.step()

size = 2
fov = 3
n_agents = 2
n_survivors = 1
train_episodes = 500  # Reduced for faster training
max_steps = 8
batch_size = 16  # Reduced to prevent memory issues
gamma = 0.99
buffer_limit = 10000
log_interval = 100  # Adjusted for fewer episodes
max_epsilon = 0.9
min_epsilon = 0.1
warm_up_steps = 500
update_iter = 5
lr = 0.001
comm_range = 2.0

random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

env = GridMultiAgent(x_size=size, y_size=size, fov_x=fov, fov_y=fov, n_agents=n_agents,
                     n_survivors=n_survivors, comm_range=comm_range)
memory = ReplayBuffer(buffer_limit)
q = QNet(env.observation_space, env.action_space, fov_x=fov, fov_y=fov)
q_target = QNet(env.observation_space, env.action_space, fov_x=fov, fov_y=fov)
q_target.load_state_dict(q.state_dict())
agent_obs_size = (fov * fov) + 1 + 3 + 3
node_feature_size = agent_obs_size
q_central = QCentralGNN(node_feature_size, n_agents, env.action_space[0].n)
q_central_target = QCentralGNN(node_feature_size, n_agents, env.action_space[0].n)
q_central_target.load_state_dict(q_central.state_dict())
optimizer = optim.Adam(q.parameters(), lr=lr)
optimizer_central = optim.Adam(q_central.parameters(), lr=lr)
time_steps, epsilons, coverage, survivor_detection, total_reward, all_agent_paths, all_agent_roles, all_info = [], [], [], [], [], [], [], []

for episode in tqdm(range(train_episodes), desc="Training Episodes"):
    score = np.zeros(n_agents)
    state = env.reset()
    done = False
    episode_step = 0
    epsilon = max(min_epsilon, max_epsilon - (max_epsilon - min_epsilon) * (episode / (0.8 * train_episodes)))
    episode_paths = [[] for _ in range(n_agents)]
    episode_roles = [[] for _ in range(n_agents)]
    episode_info = []
    while not done and episode_step < max_steps:
        state_np = np.array(state, dtype=np.float32)
        graph_state = env.get_graph_state()
        action = q.sample_action(
            torch.tensor(state_np).unsqueeze(0),
            epsilon, env.agent_roles, env.grid_status, env.agent_pos, env.x_size, env.y_size,
            env.stuck_counts, env.visited_counts
        ).data.cpu().numpy()
        rewards = np.zeros(n_agents)
        next_state = state
        step_info = None
        for agent_i in env.idx_agents:
            if done:
                break
            step_state, reward, step_done, info = env.step(action[agent_i], agent_i)
            rewards[agent_i] = reward
            done = done or step_done
            next_state = step_state
            step_info = info
            episode_paths[agent_i].append(env.agent_pos[agent_i][:])
            episode_roles[agent_i].append(env.agent_roles[agent_i])
        if step_info is not None:
            episode_info.append(step_info)
        next_graph_state = env.get_graph_state()
        memory.put((state, action, rewards, next_state, [done] * n_agents, graph_state, next_graph_state))
        score += rewards
        state = next_state
        episode_step += 1
    if memory.size() > warm_up_steps:
        train(q, q_target, q_central, q_central_target, memory, optimizer, optimizer_central, gamma, batch_size, update_iter)
    if episode % log_interval == 0:
        q_target.load_state_dict(q.state_dict())
        q_central_target.load_state_dict(q_central.state_dict())
    all_agent_paths.append([path[:] for path in episode_paths])
    all_agent_roles.append([roles[:] for roles in episode_roles])
    all_info.append(episode_info)
    time_steps.append(episode_step)
    epsilons.append(epsilon)
    coverage.append(env.get_coverage())
    survivor_detection.append(step_info['survivor_detection_rate'] if step_info else 0.0)
    total_reward.append(score.sum())
    if episode % log_interval == 0:
        print(f'//Episode {episode+1}// Epsilon: {epsilon:.3f}, Steps: {episode_step}, '
              f'Coverage (%): {coverage[-1]:.3f}, Survivor Detection (%): {survivor_detection[-1]:.3f}, '
              f'Total Reward: {total_reward[-1]:.2f}')
    if episode == 499:  # Adjusted for 500 episodes
        plot_grid(env, episode_paths, episode_roles, episode, episode_info)
        plot_detection_heatmap(env, episode)
        plot_gnn_graph(env, episode_paths, episode_info, episode)
        ani = animate_drone_paths(env, episode_paths, episode_roles, episode)
        with open(f'animation_gnn_episode_{episode + 1}.html', 'w') as f:
            f.write(ani.to_jshtml())
        download_visualizations(episode)

plt.figure(figsize=(12, 8))
plt.subplot(3, 1, 1)
plt.plot(coverage)
plt.title('Coverage Over Episodes (Swarm Drones with GNN)')
plt.xlabel('Episode')
plt.ylabel('Coverage (%)')
plt.grid(True)
plt.subplot(3, 1, 2)
plt.plot(survivor_detection)
plt.title('Survivor Detection Rate Over Episodes (Swarm Drones with GNN)')
plt.xlabel('Episode')
plt.ylabel('Detection Rate (%)')
plt.grid(True)
plt.subplot(3, 1, 3)
plt.plot(total_reward)
plt.title('Total Reward Over Episodes (Swarm Drones with GNN)')
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.grid(True)
plt.tight_layout()
plt.savefig('swarm_drone_gnn_performance.png')
plt.close()

np.savez('metrics_gnn.npz', coverage=coverage, survivor_detection=survivor_detection, total_reward=total_reward)



PyTorch Version: 2.6.0+cu124, CUDA Available: False
torch-geometric imported successfully!


Training Episodes:   0%|          | 0/500 [00:00<?, ?it/s]


NameError: name 'fov_y' is not defined

In [16]:
import torch
print(torch.cuda.is_available())  # Should print True

False


In [17]:
!pip cache purge
!pip install -q torch==2.6.0+cu124
!pip install -q torch-geometric torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-2.6.0+cu124.html
!pip install -q numpy gym tqdm matplotlib ipython

Files removed: 84


In [18]:
import torch
print(f"PyTorch Version: {torch.__version__}, CUDA Available: {torch.cuda.is_available()}")

PyTorch Version: 2.6.0+cu124, CUDA Available: False


In [21]:
import collections
import random
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from gym import spaces
import gym
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib.lines import Line2D
from torch_geometric.nn import GATConv
from torch_geometric.data import Data, Batch  # Added Batch import

class MultiAgentActionSpace(list):
    def __init__(self, agents_action_space):
        for x in agents_action_space:
            assert isinstance(x, gym.spaces.Space)
        super().__init__(agents_action_space)
        self._agents_action_space = agents_action_space

    def sample(self):
        return [agent_action_space.sample() for agent_action_space in self._agents_action_space]

class MultiAgentObservationSpace(list):
    def __init__(self, agents_observation_space):
        for x in agents_observation_space:
            assert isinstance(x, gym.spaces.Space)
        super().__init__(agents_observation_space)
        self._agents_observation_space = agents_observation_space

    def sample(self):
        return [agent_observation_space.sample() for agent_observation_space in self._agents_observation_space]

class ReplayBuffer:
    def __init__(self, buffer_limit):
        self.buffer = collections.deque(maxlen=buffer_limit)

    def put(self, transition):
        if len(transition) != 7:
            print(f"Warning: Transition tuple has {len(transition)} elements, expected 7. Transition: {transition}")
            raise ValueError(f"Transition tuple must have exactly 7 elements, got {len(transition)}")
        self.buffer.append(transition)

    def sample(self, n):
        mini_batch = random.sample(self.buffer, n)
        s_lst, a_lst, r_lst, s_prime_lst, done_mask_lst, graph_s_lst, graph_s_prime_lst = [], [], [], [], [], [], []

        for transition in mini_batch:
            s, a, r, s_prime, done, graph_s, graph_s_prime = transition
            s_lst.append(s)
            a_lst.append(a)
            r_lst.append(r)
            s_prime_lst.append(s_prime)
            done_mask_lst.append((np.ones(len(done)) - done).tolist())
            graph_s_lst.append(graph_s)
            graph_s_prime_lst.append(graph_s_prime)

        return (torch.tensor(np.array(s_lst), dtype=torch.float),
                torch.tensor(np.array(a_lst), dtype=torch.long),
                torch.tensor(np.array(r_lst), dtype=torch.float),
                torch.tensor(np.array(s_prime_lst), dtype=torch.float),
                torch.tensor(np.array(done_mask_lst), dtype=torch.float),
                graph_s_lst,
                graph_s_prime_lst)

    def size(self):
        return len(self.buffer)

class QNet(nn.Module):
    def __init__(self, observation_space, action_space, fov_x=3, fov_y=3):
        super(QNet, self).__init__()
        self.num_agents = len(observation_space)
        self.fov_x, self.fov_y = fov_x, fov_y
        self.n_actions = action_space[0].n
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * fov_x * fov_y, 64),
            nn.ReLU()
        )
        self.non_spatial_size = 1 + 3 + 3
        self.fc_non_spatial = nn.Sequential(
            nn.Linear(self.non_spatial_size, 16),
            nn.ReLU()
        )
        self.fc_shared = nn.Sequential(
            nn.Linear(64 + 16, 64),
            nn.ReLU()
        )
        self.heads = nn.ModuleList([nn.Linear(64, self.n_actions) for _ in range(self.num_agents)])

    def forward(self, obs):
        batch_size = obs.shape[0]
        spatial_obs = torch.zeros(batch_size, self.num_agents, 1, self.fov_x, self.fov_y, device=obs.device)
        for agent_i in range(self.num_agents):
            grid_status = obs[:, agent_i, :self.fov_x * self.fov_y].reshape(-1, self.fov_x, self.fov_y)
            spatial_obs[:, agent_i, 0] = grid_status
        spatial_out = torch.zeros(batch_size, self.num_agents, 64, device=obs.device)
        for agent_i in range(self.num_agents):
            spatial_out[:, agent_i] = self.cnn(spatial_obs[:, agent_i])
        non_spatial = obs[:, :, -self.non_spatial_size:]
        non_spatial_out = self.fc_non_spatial(non_spatial.view(-1, self.non_spatial_size)).view(batch_size, self.num_agents, 16)
        combined = torch.cat([spatial_out, non_spatial_out], dim=2)
        shared_out = self.fc_shared(combined)
        q_values = [self.heads[agent_i](shared_out[:, agent_i]).unsqueeze(1) for agent_i in range(self.num_agents)]
        return torch.cat(q_values, dim=1)

    def sample_action(self, obs, epsilon, roles, grid_status, agent_pos, x_size, y_size, stuck_counts, visited_counts):
        out = self.forward(obs)
        actions = torch.zeros((out.shape[1],), dtype=torch.long)
        exploration_bonus = 50.0
        scout_epsilon_boost = 0.4
        for agent_i in range(out.shape[1]):
            agent_epsilon = epsilon
            if roles[agent_i] == 'scout' and stuck_counts[agent_i] > 2:
                agent_epsilon = min(1.0, epsilon + scout_epsilon_boost)
            if random.random() < agent_epsilon:
                if roles[agent_i] == 'relay':
                    actions[agent_i] = 8
                else:
                    valid_actions = []
                    poi_actions = []
                    low_visit_actions = []
                    x, y = agent_pos[agent_i]
                    for a in range(8):
                        new_x, new_y = x, y
                        if a == 0: new_x -= 1
                        elif a == 1: new_x += 1
                        elif a == 2: new_y -= 1
                        elif a == 3: new_y += 1
                        elif a == 4: new_x, new_y = x-1, y-1
                        elif a == 5: new_x, new_y = x-1, y+1
                        elif a == 6: new_x, new_y = x+1, y-1
                        elif a == 7: new_x, new_y = x+1, y+1
                        if 0 <= new_x < x_size and 0 <= new_y < y_size and grid_status[new_x, new_y] in [0, 3]:
                            valid_actions.append(a)
                            if grid_status[new_x, new_y] == 0:
                                poi_actions.append(a)
                            if visited_counts[new_x, new_y] < 2:
                                low_visit_actions.append(a)
                    actions[agent_i] = random.choice(low_visit_actions or poi_actions or valid_actions) if valid_actions else 8
            else:
                q_values = out[0, agent_i].clone()
                x, y = agent_pos[agent_i]
                max_neighbor_rf = obs[0, agent_i, -2].item()
                for a in range(8):
                    new_x, new_y = x, y
                    if a == 0: new_x -= 1
                    elif a == 1: new_x += 1
                    elif a == 2: new_y -= 1
                    elif a == 3: new_y += 1
                    elif a == 4: new_x, new_y = x-1, y-1
                    elif a == 5: new_x, new_y = x-1, y+1
                    elif a == 6: new_x, new_y = x+1, y-1
                    elif a == 7: new_x, new_y = x+1, y+1
                    if 0 <= new_x < x_size and 0 <= new_y < y_size:
                        if grid_status[new_x, new_y] == 0:
                            q_values[a] += exploration_bonus
                        elif grid_status[new_x, new_y] == 3:
                            q_values[a] += exploration_bonus * (2 + max_neighbor_rf)
                        if visited_counts[new_x, new_y] < 2:
                            q_values[a] += exploration_bonus * 0.5
                actions[agent_i] = q_values.argmax().item()
        return actions

class QCentralGNN(nn.Module):
    def __init__(self, node_feature_size, n_agents, n_actions):
        super().__init__()
        self.n_agents = n_agents
        self.n_actions = n_actions
        self.gat1 = GATConv(node_feature_size, 64, heads=4, concat=True)
        self.gat2 = GATConv(64 * 4, 64, heads=1, concat=True)
        self.fc = nn.Linear(64, n_actions)

    def forward(self, graph_data):
        x, edge_index = graph_data.x, graph_data.edge_index
        x = F.relu(self.gat1(x, edge_index))
        x = F.relu(self.gat2(x, edge_index))
        x = self.fc(x)
        return x.view(-1, self.n_agents, self.n_actions)

class GridMultiAgent(gym.Env):
    metadata = {'render.modes': ['console']}
    XM, XP, YM, YP, XMYM, XMYP, XPYM, XPYP, STAY = range(9)
    OOE, OBS, POI, MAP, AGT, SURVIVOR = -2, -1, 0, 1, 2, 3

    def __init__(self, x_size=10, y_size=10, fov_x=3, fov_y=3, n_agents=3, n_survivors=5, comm_range=5.0):
        super().__init__()
        self.x_size = x_size
        self.y_size = y_size
        self.n_agents = n_agents
        self.idx_agents = list(range(n_agents))
        self.n_survivors = n_survivors
        self.fov_x = fov_x
        self.fov_y = fov_y
        self.comm_range = comm_range
        self.max_relays = n_agents // 2
        self.agent_roles = ['scout'] * n_agents
        self.stuck_counts = [0] * n_agents
        self.visited_counts = np.zeros((x_size, y_size))
        self.survivor_pos = []
        self.agent_paths = [[] for _ in range(n_agents)]
        self.obstacle_move_prob = 0.05
        self.survivor_move_prob = 0.05
        self.grid_status_history = []
        self.adjacency_matrix = np.zeros((n_agents, n_agents))
        n_actions = 9
        self.action_space = MultiAgentActionSpace([spaces.Discrete(n_actions) for _ in range(n_agents)])
        obs_size = (fov_x * fov_y) + 1 + 3 + 3
        self.obs_low = np.concatenate([np.ones(fov_x * fov_y) * self.OOE, [0, 0, 0, 0, 0, 0, 0]], dtype=np.float32)
        self.obs_high = np.concatenate([np.ones(fov_x * fov_y) * self.SURVIVOR, [1, 1, 1, 1, n_agents, 1, 3]], dtype=np.float32)
        self.observation_space = MultiAgentObservationSpace([
            spaces.Box(self.obs_low, self.obs_high, dtype=np.float32) for _ in range(n_agents)
        ])
        self.init_grid()
        self.init_agent()
        self.init_survivors()

    def init_grid(self):
        self.grid_status = np.zeros((self.x_size, self.y_size))
        n_obstacle = random.randrange(0, int(self.x_size * self.y_size * 0.2))
        for _ in range(n_obstacle):
            x, y = random.randrange(1, self.x_size - 1), random.randrange(1, self.y_size - 1)
            self.grid_status[x, y] = self.OBS
        self.grid_counts = np.tile(self.grid_status, (self.n_agents, 1, 1)).reshape(self.n_agents, self.x_size, self.y_size)
        self.n_poi = self.x_size * self.y_size - np.count_nonzero(self.grid_status)
        self.grid_agents_status = copy.deepcopy(self.grid_status)
        self.grid_status_history = [self.grid_status.copy()]
        self.visited_counts = np.zeros((self.x_size, self.y_size))

    def init_agent(self):
        self.agent_pos = []
        self.agent_paths = [[] for _ in range(self.n_agents)]
        for i in range(self.n_agents):
            while True:
                x, y = random.randrange(0, self.x_size), random.randrange(0, self.y_size)
                if self.grid_agents_status[x, y] == self.POI:
                    self.agent_pos.append([x, y])
                    self.grid_agents_status[x, y] = self.AGT
                    self.agent_paths[i].append([x, y])
                    self.visited_counts[x, y] += 1
                    break
        self.stuck_counts = [0] * n_agents

    def init_survivors(self):
        self.survivor_pos = []
        for _ in range(self.n_survivors):
            while True:
                x, y = random.randrange(0, self.x_size), random.randrange(0, self.y_size)
                if self.grid_status[x, y] == self.POI and [x, y] not in self.agent_pos:
                    self.survivor_pos.append([x, y])
                    self.grid_status[x, y] = self.SURVIVOR
                    break

    def compute_network(self):
        self.adjacency_matrix = np.zeros((self.n_agents, self.n_agents))
        for i in range(self.n_agents):
            for j in range(i + 1, self.n_agents):
                dist = np.sqrt((self.agent_pos[i][0] - self.agent_pos[j][0])**2 +
                              (self.agent_pos[i][1] - self.agent_pos[j][1])**2)
                if dist <= self.comm_range:
                    self.adjacency_matrix[i, j] = 1
                    self.adjacency_matrix[j, i] = 1

    def get_neighbor_messages(self, agent_idx):
        neighbors = np.where(self.adjacency_matrix[agent_idx] == 1)[0]
        n_neighbors = len(neighbors)
        max_rf = 0.0
        neighbor_role = 0
        for neighbor in neighbors:
            rf_signals = [self.get_rf_signal(self.agent_pos[neighbor], s) for s in self.survivor_pos]
            max_rf = max(max_rf, max(rf_signals, default=0.0))
            role_map = {'scout': 0, 'verification': 1, 'relay': 2}
            neighbor_role = max(neighbor_role, role_map[self.agent_roles[neighbor]])
        return n_neighbors, max_rf, neighbor_role

    def move_obstacles(self):
        new_grid = self.grid_status.copy()
        obstacle_pos = [(i, j) for i in range(self.x_size) for j in range(self.y_size) if self.grid_status[i, j] == self.OBS]
        for x, y in obstacle_pos:
            if random.random() < self.obstacle_move_prob:
                directions = [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1)]
                random.shuffle(directions)
                for dx, dy in directions:
                    new_x, new_y = x + dx, y + dy
                    if (0 <= new_x < self.x_size and 0 <= new_y < self.y_size and
                            new_grid[new_x, new_y] == self.POI and [new_x, new_y] not in self.agent_pos):
                        new_grid[x, y] = self.POI
                        new_grid[new_x, new_y] = self.OBS
                        break
        self.grid_status = new_grid

    def move_survivors(self):
        new_survivor_pos = []
        for x, y in self.survivor_pos:
            if random.random() < self.survivor_move_prob:
                directions = [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1)]
                random.shuffle(directions)
                moved = False
                for dx, dy in directions:
                    new_x, new_y = x + dx, y + dy
                    if (0 <= new_x < self.x_size and 0 <= new_y < self.y_size and
                            self.grid_status[new_x, new_y] == self.POI and
                            [new_x, new_y] not in self.agent_pos):
                        self.grid_status[x, y] = self.POI
                        self.grid_status[new_x, new_y] = self.SURVIVOR
                        new_survivor_pos.append([new_x, new_y])
                        moved = True
                        break
                if not moved:
                    new_survivor_pos.append([x, y])
            else:
                new_survivor_pos.append([x, y])
        self.survivor_pos = new_survivor_pos

    def grid_overlay(self):
        self.grid_agents_status = copy.deepcopy(self.grid_status)
        for i in range(self.n_agents):
            x, y = self.agent_pos[i]
            self.grid_agents_status[x, y] = self.AGT

    def get_rf_signal(self, agent_pos, survivor_pos):
        distance = np.sqrt((agent_pos[0] - survivor_pos[0])**2 + (agent_pos[1] - survivor_pos[1])**2)
        return min(1.0, 1.0 / (distance + 0.5))

    def get_graph_state(self):
        agent_obs = self.get_agent_obs()
        node_features = np.array(agent_obs, dtype=np.float32)
        edge_index = []
        for i in range(self.n_agents):
            for j in range(self.n_agents):
                if self.adjacency_matrix[i, j] == 1:
                    edge_index.append([i, j])
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous() if edge_index else torch.empty((2, 0), dtype=torch.long)
        graph_data = Data(
            x=torch.tensor(node_features, dtype=torch.float),
            edge_index=edge_index
        )
        return graph_data

    def get_agent_obs(self):
        self.compute_network()
        self.agent_obs = []
        for agent in range(self.n_agents):
            single_obs = np.ones((self.fov_x, self.fov_y)) * self.OOE
            x, y = self.agent_pos[agent]
            for i in range(self.fov_x):
                for j in range(self.fov_y):
                    obs_x = x + (i - self.fov_x // 2)
                    obs_y = y + (j - self.fov_y // 2)
                    if 0 <= obs_x < self.x_size and 0 <= obs_y < self.y_size:
                        single_obs[i][j] = self.grid_agents_status[obs_x][obs_y]
            single_obs_flat = single_obs.flatten()
            rf_signals = [self.get_rf_signal(self.agent_pos[agent], s) for s in self.survivor_pos]
            rf_signal = max(rf_signals, default=0.0)
            role_encoding = {'scout': [1, 0, 0], 'verification': [0, 1, 0], 'relay': [0, 0, 1]}
            role_vec = role_encoding[self.agent_roles[agent]]
            n_neighbors, max_neighbor_rf, neighbor_role = self.get_neighbor_messages(agent)
            obs = np.concatenate([single_obs_flat, [rf_signal], role_vec, [n_neighbors, max_neighbor_rf, neighbor_role]])
            self.agent_obs.append(obs.astype(np.float32))
        return self.agent_obs

    def update_roles(self):
        relay_count = sum(1 for r in self.agent_roles if r == 'relay')
        detected_survivors = sum(1 for s in self.survivor_pos if
                                 max(self.get_rf_signal([ax, ay], s) for ax, ay in self.agent_pos) > 0.6)
        for i in range(self.n_agents):
            rf_signals = [self.get_rf_signal(self.agent_pos[i], s) for s in self.survivor_pos]
            max_rf = max(rf_signals, default=0.0)
            n_neighbors, max_neighbor_rf, _ = self.get_neighbor_messages(i)
            if (max_rf > 0.6 or max_neighbor_rf > 0.6) and self.agent_roles[i] != 'verification':
                self.agent_roles[i] = 'verification'
            elif detected_survivors < self.n_survivors and self.agent_roles[i] != 'relay' and relay_count < self.max_relays:
                self.agent_roles[i] = 'relay'
                relay_count += 1
            elif detected_survivors < self.n_survivors and self.agent_roles[i] != 'scout':
                self.agent_roles[i] = 'scout'

    def get_coverage(self):
        mapped_poi = np.count_nonzero(self.grid_status == self.MAP)
        return mapped_poi / self.n_poi if self.n_poi > 0 else 0.0

    def get_survivor_detection_rate(self):
        detected = 0
        for s in self.survivor_pos:
            for ax, ay in self.agent_pos:
                if self.get_rf_signal([ax, ay], s) > 0.6:
                    detected += 1
                    break
        return detected / self.n_survivors if self.n_survivors > 0 else 0.0

    def step(self, action, i):
        org_x, org_y = self.agent_pos[i][0], self.agent_pos[i][1]
        reward = 0
        action = int(action)
        new_x, new_y = org_x, org_y
        if action == self.XM:
            new_x -= 1
        elif action == self.XP:
            new_x += 1
        elif action == self.YM:
            new_y -= 1
        elif action == self.YP:
            new_y += 1
        elif action == self.XMYM:
            new_x -= 1
            new_y -= 1
        elif action == self.XMYP:
            new_x -= 1
            new_y += 1
        elif action == self.XPYM:
            new_x += 1
            new_y -= 1
        elif action == self.XPYP:
            new_x += 1
            new_y += 1
        elif action == self.STAY:
            rf_signals = [self.get_rf_signal(self.agent_pos[i], s) for s in self.survivor_pos]
            max_rf = max(rf_signals, default=0.0)
            reward = 5 if self.agent_roles[i] == 'relay' and max_rf > 0.5 else -2
        else:
            raise ValueError(f"Invalid action={action}")
        if (0 <= new_x < self.x_size and 0 <= new_y < self.y_size and
            self.grid_status[new_x, new_y] not in [self.OBS]):
            collision = False
            for j in range(self.n_agents):
                if i != j and self.agent_pos[j][0] == new_x and self.agent_pos[j][1] == new_y:
                    collision = True
                    break
            if not collision:
                self.agent_pos[i] = [new_x, new_y]
                self.visited_counts[new_x, new_y] += 1
                prev_status = self.grid_status[new_x, new_y]
                if prev_status == self.POI:
                    self.grid_status[new_x, new_y] = self.MAP
                    self.grid_counts[i][new_x, new_y] += 1
                    reward = 15
                elif prev_status == self.SURVIVOR:
                    self.grid_counts[i][new_x, new_y] += 1
                    reward = 100
                    detected = self.get_survivor_detection_rate() * self.n_survivors
                    reward += 50 * detected
                elif prev_status == self.MAP:
                    self.grid_counts[i][new_x, new_y] += 1
                    reward = -1
            else:
                self.grid_counts[i][org_x, org_y] += 1
                reward = -5
        else:
            self.grid_counts[i][org_x, org_y] += 1
            reward = -2
        self.agent_paths[i].append(self.agent_pos[i][:])
        if org_x == self.agent_pos[i][0] and org_y == self.agent_pos[i][1]:
            self.stuck_counts[i] += 1
        else:
            self.stuck_counts[i] = 0
        self.move_obstacles()
        self.move_survivors()
        self.update_roles()
        self.grid_overlay()
        self.grid_status_history.append(self.grid_status.copy())
        mapped_poi = np.count_nonzero(self.grid_status == self.MAP)
        done = mapped_poi >= self.n_poi * 0.95 or self.get_survivor_detection_rate() >= 1.0
        info = {'survivor_detection_rate': self.get_survivor_detection_rate(),
                'coverage': self.get_coverage(),
                'adjacency_matrix': self.adjacency_matrix.copy()}
        return self.get_agent_obs(), reward, done, info

    def reset(self):
        self.init_grid()
        self.init_agent()
        self.init_survivors()
        self.agent_roles = ['scout'] * self.n_agents
        self.grid_status_history = [self.grid_status.copy()]
        self.adjacency_matrix = np.zeros((self.n_agents, self.n_agents))
        while any(self.grid_status[pos[0], pos[1]] in [self.OBS, self.SURVIVOR] for pos in self.agent_pos):
            self.init_grid()
            self.init_agent()
            self.init_survivors()
        self.grid_overlay()
        return self.get_agent_obs()

    def close(self):
        pass

def train(q, q_target, q_central, q_central_target, memory, optimizer, optimizer_central, gamma, batch_size, update_iter=10):
    for _ in range(update_iter):
        s, a, r, s_prime, done_mask, graph_s, graph_s_prime = memory.sample(batch_size)
        q_out = q(s)
        a = a.view(-1, q_out.shape[1], 1)
        q_a = q_out.gather(2, a).squeeze(-1)

        # Process batched graph data for GNN (s_prime)
        for g in graph_s_prime:
            g.x = g.x.to(s.device)
            g.edge_index = g.edge_index.to(s.device)
        batch_s_prime = Batch.from_data_list(graph_s_prime).to(s.device)

        with torch.no_grad():
            next_actions = q(s_prime).argmax(dim=2, keepdim=True)
            q_values = q_central_target(batch_s_prime)
            q_selected = q_values.gather(2, next_actions).squeeze(-1)
            max_q_prime = q_selected
            target = r + gamma * max_q_prime * done_mask

        loss_actor = F.smooth_l1_loss(q_a, target.detach())
        optimizer.zero_grad()
        loss_actor.backward()
        optimizer.step()

        # Train GNN (s)
        for g in graph_s:
            g.x = g.x.to(s.device)
            g.edge_index = g.edge_index.to(s.device)
        batch_s = Batch.from_data_list(graph_s).to(s.device)

        q_central_out = q_central(batch_s)
        q_central_a = q_central_out.gather(2, a).squeeze(-1)
        loss_central = F.smooth_l1_loss(q_central_a, target.detach())
        optimizer_central.zero_grad()
        loss_central.backward()
        optimizer_central.step()

def animate_grid(env, agent_paths, episode, info_history, save_path='swarm_drone_gnn_animation.mp4'):
    fig, ax = plt.subplots(figsize=(8, 8))
    cmap = plt.colormaps['tab10']
    def update(frame):
        ax.clear()
        grid = env.grid_status_history[frame] if frame < len(env.grid_status_history) else env.grid_status_history[-1]
        agent_pos = [path[frame] if frame < len(path) else path[-1] for path in agent_paths]
        adjacency_matrix = info_history[frame]['adjacency_matrix'] if frame < len(info_history) else np.zeros((env.n_agents, env.n_agents))
        display_grid = np.zeros((env.x_size, env.y_size))
        display_grid[grid == env.OBS] = -1
        display_grid[grid == env.MAP] = 1
        display_grid[grid == env.SURVIVOR] = 2
        for i, (x, y) in enumerate(agent_pos):
            display_grid[x, y] = 3 + i
        cax = ax.imshow(display_grid, cmap=cmap, vmin=-1, vmax=3 + env.n_agents)
        for i in range(env.n_agents):
            for j in range(i + 1, env.n_agents):
                if adjacency_matrix[i, j] == 1:
                    ax.plot([agent_pos[i][1], agent_pos[j][1]],
                            [agent_pos[i][0], agent_pos[j][0]], 'k-', alpha=0.5)
        ax.set_xticks(np.arange(env.x_size))
        ax.set_yticks(np.arange(env.y_size))
        ax.set_title(f'Episode {episode + 1}, Step {frame + 1}')
        ax.grid(True)
        legend_elements = [
            Line2D([0], [0], color=cmap(-1), lw=4, label='Obstacle'),
            Line2D([0], [0], color=cmap(1), lw=4, label='Mapped'),
            Line2D([0], [0], color=cmap(2), lw=4, label='Survivor'),
        ]
        for i in range(env.n_agents):
            legend_elements.append(Line2D([0], [0], color=cmap(3 + i), lw=4, label=f'Agent {i + 1}'))
        legend_elements.append(Line2D([0], [0], color='k', lw=1, label='Comm Link'))
        ax.legend(handles=legend_elements, loc='upper right')
        return cax,
    frames = min(len(env.grid_status_history), max(len(path) for path in agent_paths))
    ani = FuncAnimation(fig, update, frames=frames, interval=200, blit=False)
    ani.save(save_path, writer='ffmpeg')
    plt.close()

# Main training loop
size = 10
fov = 3
n_agents = 3
n_survivors = 5
train_episodes = 2000
max_steps = size * size * 3
batch_size = 32
gamma = 0.99
buffer_limit = 50000
log_interval = 200
max_epsilon = 0.9
min_epsilon = 0.1
warm_up_steps = 1000
update_iter = 10
lr = 0.001
comm_range = 5.0

random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

env = GridMultiAgent(x_size=size, y_size=size, fov_x=fov, fov_y=fov, n_agents=n_agents,
                     n_survivors=n_survivors, comm_range=comm_range)
memory = ReplayBuffer(buffer_limit)
q = QNet(env.observation_space, env.action_space, fov_x=fov, fov_y=fov)
q_target = QNet(env.observation_space, env.action_space, fov_x=fov, fov_y=fov)
q_target.load_state_dict(q.state_dict())
agent_obs_size = (fov * fov) + 1 + 3 + 3
node_feature_size = agent_obs_size
q_central = QCentralGNN(node_feature_size, n_agents, env.action_space[0].n)
q_central_target = QCentralGNN(node_feature_size, n_agents, env.action_space[0].n)
q_central_target.load_state_dict(q_central.state_dict())
optimizer = optim.Adam(q.parameters(), lr=lr)
optimizer_central = optim.Adam(q_central.parameters(), lr=lr)
time_steps, epsilons, coverage, survivor_detection, total_reward, all_agent_paths, all_agent_roles, all_info = [], [], [], [], [], [], [], []

for episode in tqdm(range(train_episodes), desc="Training Episodes"):
    score = np.zeros(n_agents)
    state = env.reset()
    done = False
    episode_step = 0
    epsilon = max(min_epsilon, max_epsilon - (max_epsilon - min_epsilon) * (episode / (0.8 * train_episodes)))
    episode_paths = [[] for _ in range(n_agents)]
    episode_roles = [[] for _ in range(n_agents)]
    episode_info = []
    while not done and episode_step < max_steps:
        state_np = np.array(state, dtype=np.float32)
        graph_state = env.get_graph_state()
        action = q.sample_action(
            torch.tensor(state_np).unsqueeze(0),
            epsilon, env.agent_roles, env.grid_status, env.agent_pos, env.x_size, env.y_size,
            env.stuck_counts, env.visited_counts
        ).data.cpu().numpy()
        rewards = np.zeros(n_agents)
        next_state = state
        step_info = None
        for agent_i in env.idx_agents:
            if done:
                break
            step_state, reward, step_done, info = env.step(action[agent_i], agent_i)
            rewards[agent_i] = reward
            done = done or step_done
            next_state = step_state
            step_info = info
            episode_paths[agent_i].append(env.agent_pos[agent_i][:])
            episode_roles[agent_i].append(env.agent_roles[agent_i])
        if step_info is not None:
            episode_info.append(step_info)
        next_graph_state = env.get_graph_state()
        memory.put((state, action, rewards, next_state, [done] * n_agents, graph_state, next_graph_state))
        score += rewards
        state = next_state
        episode_step += 1
    if memory.size() > warm_up_steps:
        train(q, q_target, q_central, q_central_target, memory, optimizer, optimizer_central, gamma, batch_size, update_iter)
    if episode % log_interval == 0:
        q_target.load_state_dict(q.state_dict())
        q_central_target.load_state_dict(q_central.state_dict())
        animate_grid(env, episode_paths, episode, episode_info, save_path=f'swarm_drone_gnn_animation_episode_{episode}.mp4')
    all_agent_paths.append([path[:] for path in episode_paths])
    all_agent_roles.append([roles[:] for roles in episode_roles])
    all_info.append(episode_info)
    time_steps.append(episode_step)
    epsilons.append(epsilon)
    coverage.append(env.get_coverage())
    survivor_detection.append(step_info['survivor_detection_rate'] if step_info else 0.0)
    total_reward.append(score.sum())
    if episode % log_interval == 0:
        print(f'//Episode {episode+1}// Epsilon: {epsilon:.3f}, Steps: {episode_step}, '
              f'Coverage (%): {coverage[-1]:.3f}, Survivor Detection (%): {survivor_detection[-1]:.3f}, '
              f'Total Reward: {total_reward[-1]:.2f}')

plt.figure(figsize=(12, 8))
plt.subplot(3, 1, 1)
plt.plot(coverage)
plt.title('Coverage Over Episodes (Swarm Drones with GNN)')
plt.xlabel('Episode')
plt.ylabel('Coverage (%)')
plt.grid(True)
plt.subplot(3, 1, 2)
plt.plot(survivor_detection)
plt.title('Survivor Detection Rate Over Episodes (Swarm Drones with GNN)')
plt.xlabel('Episode')
plt.ylabel('Detection Rate (%)')
plt.grid(True)
plt.subplot(3, 1, 3)
plt.plot(total_reward)
plt.title('Total Reward Over Episodes (Swarm Drones with GNN)')
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.grid(True)
plt.tight_layout()
plt.savefig('swarm_drone_gnn_performance.png')
plt.close()

Training Episodes:   0%|          | 1/2000 [01:37<54:01:29, 97.29s/it]

//Episode 1// Epsilon: 0.900, Steps: 300, Coverage (%): 0.849, Survivor Detection (%): 0.800, Total Reward: 97095.00


Training Episodes:  10%|█         | 201/2000 [09:50<15:29:22, 31.00s/it]

//Episode 201// Epsilon: 0.800, Steps: 300, Coverage (%): 0.759, Survivor Detection (%): 0.600, Total Reward: 195142.00


Training Episodes:  20%|██        | 401/2000 [18:11<14:32:48, 32.75s/it]

//Episode 401// Epsilon: 0.700, Steps: 300, Coverage (%): 0.700, Survivor Detection (%): 0.600, Total Reward: 111790.00


Training Episodes:  30%|███       | 601/2000 [26:52<12:51:09, 33.07s/it]

//Episode 601// Epsilon: 0.600, Steps: 300, Coverage (%): 0.798, Survivor Detection (%): 0.400, Total Reward: 108372.00


Training Episodes:  40%|████      | 801/2000 [35:40<10:20:51, 31.07s/it]

//Episode 801// Epsilon: 0.500, Steps: 300, Coverage (%): 0.707, Survivor Detection (%): 0.400, Total Reward: 111343.00


Training Episodes:  50%|█████     | 1001/2000 [44:37<8:34:12, 30.88s/it]

//Episode 1001// Epsilon: 0.400, Steps: 300, Coverage (%): 0.604, Survivor Detection (%): 0.600, Total Reward: 204132.00


Training Episodes:  60%|██████    | 1201/2000 [53:42<6:55:04, 31.17s/it]

//Episode 1201// Epsilon: 0.300, Steps: 300, Coverage (%): 0.670, Survivor Detection (%): 0.600, Total Reward: 137831.00


Training Episodes:  70%|███████   | 1401/2000 [1:02:54<5:10:42, 31.12s/it]

//Episode 1401// Epsilon: 0.200, Steps: 300, Coverage (%): 0.619, Survivor Detection (%): 0.400, Total Reward: 99823.00


Training Episodes:  80%|████████  | 1601/2000 [1:11:51<3:25:51, 30.96s/it]

//Episode 1601// Epsilon: 0.100, Steps: 300, Coverage (%): 0.429, Survivor Detection (%): 0.200, Total Reward: 52748.00


Training Episodes:  90%|█████████ | 1801/2000 [1:20:55<1:40:15, 30.23s/it]

//Episode 1801// Epsilon: 0.100, Steps: 300, Coverage (%): 0.711, Survivor Detection (%): 0.200, Total Reward: 60761.00


Training Episodes: 100%|██████████| 2000/2000 [1:28:15<00:00,  2.65s/it]
