In [1]:
import collections
import random
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from gym import spaces
import gym
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

# Multi-agent action space
class MultiAgentActionSpace(list):
    def __init__(self, agents_action_space):
        for x in agents_action_space:
            assert isinstance(x, gym.spaces.Space)
        super().__init__(agents_action_space)
        self._agents_action_space = agents_action_space

    def sample(self):
        return [agent_action_space.sample() for agent_action_space in self._agents_action_space]

# Multi-agent observation space
class MultiAgentObservationSpace(list):
    def __init__(self, agents_observation_space):
        for x in agents_observation_space:
            assert isinstance(x, gym.spaces.Space)
        super().__init__(agents_observation_space)
        self._agents_observation_space = agents_observation_space

    def sample(self):
        return [agent_observation_space.sample() for agent_observation_space in self._agents_observation_space]

# Replay buffer
class ReplayBuffer:
    def __init__(self, buffer_limit):
        self.buffer = collections.deque(maxlen=buffer_limit)

    def put(self, transition):
        self.buffer.append(transition)

    def sample(self, n):
        mini_batch = random.sample(self.buffer, n)
        s_lst, a_lst, r_lst, s_prime_lst, done_mask_lst, global_s_lst, global_s_prime_lst = [], [], [], [], [], [], []

        for transition in mini_batch:
            s, a, r, s_prime, done, global_s, global_s_prime = transition
            s_lst.append(s)
            a_lst.append(a)
            r_lst.append(r)
            s_prime_lst.append(s_prime)
            done_mask_lst.append((np.ones(len(done)) - done).tolist())
            global_s_lst.append(global_s)
            global_s_prime_lst.append(global_s_prime)

        return (torch.tensor(np.array(s_lst), dtype=torch.float),
                torch.tensor(np.array(a_lst), dtype=torch.long),
                torch.tensor(np.array(r_lst), dtype=torch.float),
                torch.tensor(np.array(s_prime_lst), dtype=torch.float),
                torch.tensor(np.array(done_mask_lst), dtype=torch.float),
                torch.tensor(np.array(global_s_lst), dtype=torch.float),
                torch.tensor(np.array(global_s_prime_lst), dtype=torch.float))

    def size(self):
        return len(self.buffer)

# Q-network (Actor) with CNN for scalability
class QNet(nn.Module):
    def __init__(self, observation_space, action_space, fov_x=3, fov_y=3):
        super(QNet, self).__init__()
        self.num_agents = len(observation_space)
        self.fov_x, self.fov_y = fov_x, fov_y
        self.n_actions = action_space[0].n

        # CNN for spatial FOV (channels: grid status, confidence, pheromone)
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * fov_x * fov_y, 64),
            nn.ReLU()
        )

        # Non-spatial features: RF signal + role encoding (3)
        self.non_spatial_size = 1 + 3
        self.fc_non_spatial = nn.Sequential(
            nn.Linear(self.non_spatial_size, 16),
            nn.ReLU()
        )

        # Shared backbone
        self.fc_shared = nn.Sequential(
            nn.Linear(64 + 16, 64),
            nn.ReLU()
        )

        # Agent-specific heads
        self.heads = nn.ModuleList([nn.Linear(64, self.n_actions) for _ in range(self.num_agents)])

    def forward(self, obs, confidence_map, pheromone_map):
        batch_size = obs.shape[0]
        spatial_obs = torch.zeros(batch_size, self.num_agents, 3, self.fov_x, self.fov_y, device=obs.device)

        for agent_i in range(self.num_agents):
            # Grid status channel
            grid_status = obs[:, agent_i, :self.fov_x * self.fov_y].reshape(-1, self.fov_x, self.fov_y)
            spatial_obs[:, agent_i, 0] = grid_status
            # Expand confidence and pheromone maps to match FOV dimensions
            conf = confidence_map[:, agent_i].reshape(-1, 1, 1).expand(-1, self.fov_x, self.fov_y)
            pher = pheromone_map[:, agent_i].reshape(-1, 1, 1).expand(-1, self.fov_x, self.fov_y)
            spatial_obs[:, agent_i, 1] = conf
            spatial_obs[:, agent_i, 2] = pher

        spatial_out = torch.zeros(batch_size, self.num_agents, 64, device=obs.device)
        for agent_i in range(self.num_agents):
            spatial_out[:, agent_i] = self.cnn(spatial_obs[:, agent_i])

        non_spatial = obs[:, :, -self.non_spatial_size:]
        non_spatial_out = self.fc_non_spatial(non_spatial.view(-1, self.non_spatial_size)).view(batch_size, self.num_agents, 16)

        combined = torch.cat([spatial_out, non_spatial_out], dim=2)
        shared_out = self.fc_shared(combined)

        q_values = [self.heads[agent_i](shared_out[:, agent_i]).unsqueeze(1) for agent_i in range(self.num_agents)]
        return torch.cat(q_values, dim=1)

    def sample_action(self, obs, confidence_map, pheromone_map, epsilon, roles, grid_status, agent_pos, x_size, y_size, stuck_counts):
        out = self.forward(obs, confidence_map, pheromone_map)
        actions = torch.zeros((out.shape[1],), dtype=torch.long)
        exploration_bonus = 15.0
        for agent_i in range(out.shape[1]):
            if random.random() < epsilon or (roles[agent_i] == 'scout' and stuck_counts[agent_i] > 5):
                if roles[agent_i] == 'relay':
                    actions[agent_i] = 8  # STAY
                else:
                    valid_actions = []
                    x, y = agent_pos[agent_i]
                    for a in range(8):
                        new_x, new_y = x, y
                        if a == 0: new_x -= 1
                        elif a == 1: new_x += 1
                        elif a == 2: new_y -= 1
                        elif a == 3: new_y += 1
                        elif a == 4: new_x, new_y = x-1, y-1
                        elif a == 5: new_x, new_y = x-1, y+1
                        elif a == 6: new_x, new_y = x+1, y-1
                        elif a == 7: new_x, new_y = x+1, y+1
                        if 0 <= new_x < x_size and 0 <= new_y < y_size and grid_status[new_x, new_y] == 0:
                            valid_actions.append(a)
                    if valid_actions:
                        actions[agent_i] = random.choice(valid_actions)
                    else:
                        actions[agent_i] = 8
            else:
                if roles[agent_i] == 'scout':
                    pheromone_penalty = torch.zeros(out.shape[2])
                    exploration_bonus_tensor = torch.zeros(out.shape[2])
                    x, y = agent_pos[agent_i]
                    for a in range(8):
                        new_x, new_y = x, y
                        if a == 0: new_x -= 1
                        elif a == 1: new_x += 1
                        elif a == 2: new_y -= 1
                        elif a == 3: new_y += 1
                        elif a == 4: new_x, new_y = x-1, y-1
                        elif a == 5: new_x, new_y = x-1, y+1
                        elif a == 6: new_x, new_y = x+1, y-1
                        elif a == 7: new_x, new_y = x+1, y+1
                        if 0 <= new_x < x_size and 0 <= new_y < y_size:
                            pheromone_penalty[a] = env.pheromone_map[new_x, new_y] * 0.1
                            if grid_status[new_x, new_y] == 0:
                                exploration_bonus_tensor[a] = exploration_bonus
                    adjusted_q = out[0, agent_i] - pheromone_penalty + exploration_bonus_tensor
                    actions[agent_i] = adjusted_q.argmax().item()
                else:
                    actions[agent_i] = out[0, agent_i].argmax().item()
        return actions

# Centralized Critic
class QCentral(nn.Module):
    def __init__(self, global_state_size, n_agents, n_actions):
        super().__init__()
        self.n_agents = n_agents
        self.n_actions = n_actions
        self.net = nn.Sequential(
            nn.Linear(global_state_size, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, n_agents * n_actions)
        )

    def forward(self, global_state):
        return self.net(global_state).view(-1, self.n_agents, self.n_actions)

# Training function
def train(q, q_target, q_central, q_central_target, memory, optimizer, optimizer_central, gamma, batch_size, update_iter=10):
    for _ in range(update_iter):
        s, a, r, s_prime, done_mask, global_s, global_s_prime = memory.sample(batch_size)

        # Actor training
        q_out = q(s, s[:, :, 9:10], s[:, :, 10:11])
        a = a.view(-1, q_out.shape[1], 1)
        q_a = q_out.gather(2, a).squeeze(-1)

        with torch.no_grad():
            next_actions = q(s_prime, s_prime[:, :, 9:10], s_prime[:, :, 10:11]).argmax(dim=2, keepdim=True)
            max_q_prime = q_central_target(global_s_prime).gather(2, next_actions).squeeze(-1)
            target = r + gamma * max_q_prime * done_mask

        loss_actor = F.smooth_l1_loss(q_a, target.detach())
        optimizer.zero_grad()
        loss_actor.backward()
        optimizer.step()

        # Critic training
        q_central_out = q_central(global_s)
        q_central_a = q_central_out.gather(2, a).squeeze(-1)
        loss_central = F.smooth_l1_loss(q_central_a, target.detach())
        optimizer_central.zero_grad()
        loss_central.backward()
        optimizer_central.step()

# Multi-agent grid environment
class GridMultiAgent(gym.Env):
    metadata = {'render.modes': ['console']}
    XM, XP, YM, YP, XMYM, XMYP, XPYM, XPYP, STAY = range(9)
    OOE, OBS, POI, MAP, AGT, SURVIVOR = -2, -1, 0, 1, 2, 3

    def __init__(self, x_size=10, y_size=10, fov_x=3, fov_y=3, n_agents=3, n_survivors=5):
        super().__init__()
        self.x_size = x_size
        self.y_size = y_size
        self.n_agents = n_agents
        self.idx_agents = list(range(n_agents))
        self.n_survivors = n_survivors
        self.fov_x = fov_x
        self.fov_y = fov_y
        self.max_relays = n_agents // 2

        self.confidence_map = np.zeros((x_size, y_size))
        self.pheromone_map = np.zeros((x_size, y_size))
        self.confidence_decay = 0.98
        self.pheromone_decay = 0.95
        self.agent_roles = ['scout'] * n_agents
        self.stuck_counts = [0] * n_agents
        self.survivor_pos = []
        self.agent_paths = [[] for _ in range(n_agents)]
        self.obstacle_move_prob = 0.05
        self.survivor_move_prob = 0.05
        self.grid_status_history = []

        n_actions = 9
        self.action_space = MultiAgentActionSpace([spaces.Discrete(n_actions) for _ in range(n_agents)])
        obs_size = (fov_x * fov_y) + 1 + 1 + 1 + 3  # FOV grid + confidence + pheromone + RF + role
        self.obs_low = np.concatenate([np.ones(fov_x * fov_y) * self.OOE, [0, 0, 0, 0, 0, 0]], dtype=np.float32)
        self.obs_high = np.concatenate([np.ones(fov_x * fov_y) * self.SURVIVOR, [1, 1, 1, 1, 1, 1]], dtype=np.float32)
        self.observation_space = MultiAgentObservationSpace([
            spaces.Box(self.obs_low, self.obs_high, dtype=np.float32) for _ in range(n_agents)
        ])

        self.init_grid()
        self.init_agent()
        self.init_survivors()

    def init_grid(self):
        self.grid_status = np.zeros((self.x_size, self.y_size))
        n_obstacle = random.randrange(0, int(self.x_size * self.y_size * 0.2))
        for _ in range(n_obstacle):
            x, y = random.randrange(1, self.x_size - 1), random.randrange(1, self.y_size - 1)
            self.grid_status[x, y] = self.OBS
        self.grid_counts = np.tile(self.grid_status, (self.n_agents, 1, 1)).reshape(self.n_agents, self.x_size, self.y_size)
        self.n_poi = self.x_size * self.y_size - np.count_nonzero(self.grid_status)
        self.grid_agents_status = copy.deepcopy(self.grid_status)
        self.grid_status_history = [self.grid_status.copy()]

    def init_agent(self):
        self.agent_pos = []
        self.agent_paths = [[] for _ in range(self.n_agents)]
        for i in range(self.n_agents):
            while True:
                x, y = random.randrange(0, self.x_size), random.randrange(0, self.y_size)
                if self.grid_agents_status[x, y] == self.POI:
                    self.agent_pos.append([x, y])
                    self.grid_agents_status[x, y] = self.AGT
                    self.agent_paths[i].append([x, y])
                    break
        self.stuck_counts = [0] * n_agents

    def init_survivors(self):
        self.survivor_pos = []
        for _ in range(self.n_survivors):
            while True:
                x, y = random.randrange(0, self.x_size), random.randrange(0, self.y_size)
                if self.grid_status[x, y] == self.POI and [x, y] not in self.agent_pos:
                    self.survivor_pos.append([x, y])
                    self.grid_status[x, y] = self.SURVIVOR
                    self.confidence_map[x, y] = 1.0
                    break

    def move_obstacles(self):
        new_grid = self.grid_status.copy()
        obstacle_pos = [(i, j) for i in range(self.x_size) for j in range(self.y_size) if self.grid_status[i, j] == self.OBS]
        for x, y in obstacle_pos:
            if random.random() < self.obstacle_move_prob:
                directions = [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, -1),
                             (-1, 1), (1, -1), (1, 1)]
                random.shuffle(directions)
                for dx, dy in directions:
                    new_x, new_y = x + dx, y + dy
                    if (0 <= new_x < self.x_size and 0 <= new_y < self.y_size and
                            new_grid[new_x, new_y] == self.POI and [new_x, new_y] not in self.agent_pos):
                        new_grid[x, y] = self.POI
                        new_grid[new_x, new_y] = self.OBS
                        break
        self.grid_status = new_grid

    def move_survivors(self):
        new_survivor_pos = []
        for x, y in self.survivor_pos:
            if random.random() < self.survivor_move_prob:
                directions = [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1)]
                random.shuffle(directions)
                moved = False
                for dx, dy in directions:
                    new_x, new_y = x + dx, y + dy
                    if (0 <= new_x < self.x_size and 0 <= new_y < self.y_size and
                            self.grid_status[new_x, new_y] == self.POI and
                            [new_x, new_y] not in self.agent_pos):
                        self.grid_status[x, y] = self.POI
                        self.grid_status[new_x, new_y] = self.SURVIVOR
                        self.confidence_map[new_x, new_y] = 1.0
                        new_survivor_pos.append([new_x, new_y])
                        moved = True
                        break
                if not moved:
                    new_survivor_pos.append([x, y])
            else:
                new_survivor_pos.append([x, y])
        self.survivor_pos = new_survivor_pos

    def grid_overlay(self):
        self.grid_agents_status = copy.deepcopy(self.grid_status)
        for i in range(self.n_agents):
            x, y = self.agent_pos[i]
            self.grid_agents_status[x, y] = self.AGT

    def get_rf_signal(self, agent_pos, survivor_pos):
        distance = np.sqrt((agent_pos[0] - survivor_pos[0])**2 + (agent_pos[1] - survivor_pos[1])**2)
        return min(1.0, 1.0 / (distance + 1)) if distance > 0 else 1.0

    def get_global_state(self):
        agent_obs = self.get_agent_obs()
        global_state = np.concatenate([
            np.array(agent_obs).flatten(),
            self.grid_status.flatten(),
            self.confidence_map.flatten(),
            self.pheromone_map.flatten()
        ])
        return global_state.astype(np.float32)

    def get_agent_obs(self):
        self.agent_obs = []
        for agent in range(self.n_agents):
            single_obs = np.ones((self.fov_x, self.fov_y)) * self.OOE
            x, y = self.agent_pos[agent]
            for i in range(self.fov_x):
                for j in range(self.fov_y):
                    obs_x = x + (i - self.fov_x // 2)
                    obs_y = y + (j - self.fov_y // 2)
                    if 0 <= obs_x < self.x_size and 0 <= obs_y < self.y_size:
                        single_obs[i][j] = self.grid_agents_status[obs_x][obs_y]
            single_obs_flat = single_obs.flatten()

            rf_signal = max([self.get_rf_signal(self.agent_pos[agent], s) for s in self.survivor_pos], default=0.0)
            confidence_level = self.confidence_map[x, y]
            pheromone_level = self.pheromone_map[x, y]
            role_encoding = {'scout': [1, 0, 0], 'verification': [0, 1, 0], 'relay': [0, 0, 1]}
            role_vec = role_encoding[self.agent_roles[agent]]

            obs = np.concatenate([single_obs_flat, [confidence_level, pheromone_level, rf_signal], role_vec])
            self.agent_obs.append(obs.astype(np.float32))
        return self.agent_obs

    def update_roles(self):
        relay_count = sum(1 for r in self.agent_roles if r == 'relay')
        for i in range(self.n_agents):
            x, y = self.agent_pos[i]
            confidence = self.confidence_map[x, y]
            if confidence > 0.8 and self.agent_roles[i] != 'verification':
                self.agent_roles[i] = 'verification'
            elif confidence > 0.6 and self.agent_roles[i] != 'relay' and relay_count < self.max_relays:
                self.agent_roles[i] = 'relay'
                relay_count += 1
            elif self.agent_roles[i] != 'scout' and confidence <= 0.6:
                self.agent_roles[i] = 'scout'

    def update_confidence_and_pheromones(self):
        self.confidence_map *= self.confidence_decay
        self.pheromone_map *= self.pheromone_decay
        for i in range(self.n_agents):
            x, y = self.agent_pos[i]
            rf_signal = max([self.get_rf_signal(self.agent_pos[i], s) for s in self.survivor_pos], default=0.0)
            confidence = rf_signal
            self.confidence_map[x, y] = min(1.0, self.confidence_map[x, y] + confidence)
            self.pheromone_map[x, y] = min(1.0, self.pheromone_map[x, y] + 0.05)

    def get_coverage(self):
        mapped_poi = np.count_nonzero(self.grid_status == self.MAP)
        return mapped_poi / self.n_poi if self.n_poi > 0 else 0.0

    def get_survivor_detection_rate(self):
        detected = sum(1 for s in self.survivor_pos if self.confidence_map[s[0], s[1]] > 0.8)
        return detected / self.n_survivors if self.n_survivors > 0 else 0.0

    def step(self, action, i):
        org_x, org_y = self.agent_pos[i][0], self.agent_pos[i][1]
        reward = 0
        action = int(action)

        new_x, new_y = org_x, org_y
        if action == self.XM:
            new_x -= 1
        elif action == self.XP:
            new_x += 1
        elif action == self.YM:
            new_y -= 1
        elif action == self.YP:
            new_y += 1
        elif action == self.XMYM:
            new_x -= 1
            new_y -= 1
        elif action == self.XMYP:
            new_x -= 1
            new_y += 1
        elif action == self.XPYM:
            new_x += 1
            new_y -= 1
        elif action == self.XPYP:
            new_x += 1
            new_y += 1
        elif action == self.STAY:
            reward = 5 if self.agent_roles[i] == 'relay' and self.confidence_map[org_x, org_y] > 0.5 else -2
        else:
            raise ValueError(f"Invalid action={action}")

        if (0 <= new_x < self.x_size and 0 <= new_y < self.y_size and
            self.grid_status[new_x, new_y] not in [self.OBS]):
            collision = False
            for j in range(self.n_agents):
                if i != j and self.agent_pos[j][0] == new_x and self.agent_pos[j][1] == new_y:
                    collision = True
                    break
            if not collision:
                self.agent_pos[i] = [new_x, new_y]
                prev_status = self.grid_status[new_x, new_y]
                if prev_status == self.POI:
                    self.grid_status[new_x, new_y] = self.MAP
                    self.grid_counts[i][new_x, new_y] += 1
                    reward = 15
                elif prev_status == self.SURVIVOR:
                    self.grid_counts[i][new_x, new_y] += 1
                    reward = 50
                elif prev_status == self.MAP:
                    self.grid_counts[i][new_x, new_y] += 1
                    reward = -1
            else:
                self.grid_counts[i][org_x, org_y] += 1
                reward = -5
        else:
            self.grid_counts[i][org_x, org_y] += 1
            reward = -2

        self.agent_paths[i].append(self.agent_pos[i][:])
        if org_x == self.agent_pos[i][0] and org_y == self.agent_pos[i][1]:
            self.stuck_counts[i] += 1
        else:
            self.stuck_counts[i] = 0

        self.move_obstacles()
        self.move_survivors()
        self.update_confidence_and_pheromones()
        self.update_roles()
        self.grid_overlay()
        self.grid_status_history.append(self.grid_status.copy())

        mapped_poi = np.count_nonzero(self.grid_status == self.MAP)
        done = mapped_poi >= self.n_poi * coverage_threshold or self.get_survivor_detection_rate() >= 1.0
        info = {'survivor_detection_rate': self.get_survivor_detection_rate(),
                'coverage': self.get_coverage()}

        return self.get_agent_obs(), reward, done, info

    def reset(self):
        self.init_grid()
        self.init_agent()
        self.init_survivors()
        self.confidence_map = np.zeros((self.x_size, self.y_size))
        self.pheromone_map = np.zeros((self.x_size, self.y_size))
        self.agent_roles = ['scout'] * self.n_agents
        self.grid_status_history = [self.grid_status.copy()]
        while any(self.grid_status[pos[0], pos[1]] in [self.OBS, self.SURVIVOR] for pos in self.agent_pos):
            self.init_grid()
            self.init_agent()
            self.init_survivors()
        self.grid_overlay()
        return self.get_agent_obs()

    def close(self):
        pass

# Animation function
def animate_grid(env, agent_paths, episode, save_path='grid_animation.mp4'):
    fig, ax = plt.subplots(figsize=(8, 8))
    cmap = plt.cm.get_cmap('tab10')

    def update(frame):
        ax.clear()
        grid = env.grid_status_history[frame]
        agent_pos = [path[frame] if frame < len(path) else path[-1] for path in agent_paths]

        # Create display grid
        display_grid = np.zeros((env.x_size, env.y_size))
        display_grid[grid == env.OBS] = -1  # Obstacles
        display_grid[grid == env.MAP] = 1   # Mapped
        display_grid[grid == env.SURVIVOR] = 2  # Survivors

        # Overlay agents
        for i, (x, y) in enumerate(agent_pos):
            display_grid[x, y] = 3 + i  # Unique value for each agent

        # Plot
        cax = ax.imshow(display_grid, cmap=cmap, vmin=-1, vmax=3 + env.n_agents)
        ax.set_xticks(np.arange(env.x_size))
        ax.set_yticks(np.arange(env.y_size))
        ax.set_title(f'Episode {episode + 1}, Step {frame + 1}')
        ax.grid(True)

        # Legend
        from matplotlib.lines import Line2D
        legend_elements = [
            Line2D([0], [0], color=cmap(-1), lw=4, label='Obstacle'),
            Line2D([0], [0], color=cmap(1), lw=4, label='Mapped'),
            Line2D([0], [0], color=cmap(2), lw=4, label='Survivor'),
        ]
        for i in range(env.n_agents):
            legend_elements.append(Line2D([0], [0], color=cmap(3 + i), lw=4, label=f'Agent {i + 1}'))
        ax.legend(handles=legend_elements, loc='upper right')

        return cax,

    ani = FuncAnimation(fig, update, frames=len(env.grid_status_history), interval=200, blit=False)
    ani.save(save_path, writer='ffmpeg')
    plt.close()

# Training parameters
size = 10
fov = 3
n_agents = 3
n_survivors = 5
train_episodes = 2000
max_steps = size * size * 3
batch_size = 32
gamma = 0.99
buffer_limit = 50000
log_interval = 200
max_epsilon = 0.9
min_epsilon = 0.1
warm_up_steps = 1000
update_iter = 10
coverage_threshold = 0.95
lr = 0.001

# Set seeds
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Initialize environment and models
env = GridMultiAgent(x_size=size, y_size=size, fov_x=fov, fov_y=fov, n_agents=n_agents, n_survivors=n_survivors)
memory = ReplayBuffer(buffer_limit)
q = QNet(env.observation_space, env.action_space, fov_x=fov, fov_y=fov)
q_target = QNet(env.observation_space, env.action_space, fov_x=fov, fov_y=fov)
q_target.load_state_dict(q.state_dict())

agent_obs_size = (fov * fov) + 1 + 1 + 1 + 3
global_state_size = (agent_obs_size * n_agents) + (size * size) + (size * size) + (size * size)
q_central = QCentral(global_state_size, n_agents, env.action_space[0].n)
q_central_target = QCentral(global_state_size, n_agents, env.action_space[0].n)
q_central_target.load_state_dict(q_central.state_dict())

optimizer = optim.Adam(q.parameters(), lr=lr)
optimizer_central = optim.Adam(q_central.parameters(), lr=lr)

# Metrics
time_steps, epsilons, coverage, survivor_detection, total_reward, all_agent_paths, all_agent_roles = [], [], [], [], [], [], []

# Training loop
for episode in tqdm(range(train_episodes), desc="Training Episodes"):
    score = np.zeros(n_agents)
    state = env.reset()
    done = False
    episode_step = 0
    epsilon = max(min_epsilon, max_epsilon - (max_epsilon - min_epsilon) * (episode / (0.8 * train_episodes)))
    episode_paths = [[] for _ in range(n_agents)]
    episode_roles = [[] for _ in range(n_agents)]

    while not done and episode_step < max_steps:
        state_np = np.array(state, dtype=np.float32)
        confidence_map_input = state_np[:, 9:10].reshape(1, n_agents, 1, 1)
        pheromone_map_input = state_np[:, 10:11].reshape(1, n_agents, 1, 1)

        global_state = env.get_global_state()

        action = q.sample_action(
            torch.tensor(state_np).unsqueeze(0),
            torch.tensor(confidence_map_input).float(),
            torch.tensor(pheromone_map_input).float(),
            epsilon, env.agent_roles, env.grid_status, env.agent_pos, env.x_size, env.y_size, env.stuck_counts
        ).data.cpu().numpy()

        rewards = np.zeros(n_agents)
        next_state = state
        for agent_i in env.idx_agents:
            if done:
                break
            step_state, reward, step_done, info = env.step(action[agent_i], agent_i)
            rewards[agent_i] = reward
            done = done or step_done
            next_state = step_state
            episode_paths[agent_i].append(env.agent_pos[agent_i][:])
            episode_roles[agent_i].append(env.agent_roles[agent_i])

        next_global_state = env.get_global_state()
        memory.put((state, action, rewards, next_state, [done] * n_agents, global_state, next_global_state))
        score += rewards
        state = next_state
        episode_step += 1

    if memory.size() > warm_up_steps:
        train(q, q_target, q_central, q_central_target, memory, optimizer, optimizer_central, gamma, batch_size, update_iter)

    if episode % log_interval == 0:
        q_target.load_state_dict(q.state_dict())
        q_central_target.load_state_dict(q_central.state_dict())
        animate_grid(env, episode_paths, episode, save_path=f'grid_animation_episode_{episode}.mp4')

    all_agent_paths.append([path[:] for path in episode_paths])
    all_agent_roles.append([roles[:] for roles in episode_roles])
    time_steps.append(episode_step)
    epsilons.append(epsilon)
    coverage.append(env.get_coverage())
    survivor_detection.append(info['survivor_detection_rate'])
    total_reward.append(score.sum())

    if episode % log_interval == 0:
        print(f'//Episode {episode+1}// Epsilon: {epsilon:.3f}, Steps: {episode_step}, '
              f'Coverage (%): {coverage[-1]:.3f}, Survivor Detection (%): {survivor_detection[-1]:.3f}, '
              f'Total Reward: {total_reward[-1]:.2f}')

  cmap = plt.cm.get_cmap('tab10')
Training Episodes:   0%|          | 1/2000 [03:22<112:32:20, 202.67s/it]

//Episode 1// Epsilon: 0.900, Steps: 300, Coverage (%): 0.839, Survivor Detection (%): 0.000, Total Reward: 365.00


  cmap = plt.cm.get_cmap('tab10')
Training Episodes:  10%|█         | 201/2000 [06:12<4:29:50,  9.00s/it]

//Episode 201// Epsilon: 0.800, Steps: 40, Coverage (%): 0.434, Survivor Detection (%): 1.000, Total Reward: 756.00


  cmap = plt.cm.get_cmap('tab10')
Training Episodes:  20%|██        | 402/2000 [09:06<2:22:30,  5.35s/it]

//Episode 401// Epsilon: 0.700, Steps: 32, Coverage (%): 0.484, Survivor Detection (%): 1.000, Total Reward: 696.00


  cmap = plt.cm.get_cmap('tab10')
Training Episodes:  30%|███       | 601/2000 [14:53<25:08:54, 64.71s/it]

//Episode 601// Epsilon: 0.600, Steps: 300, Coverage (%): 0.830, Survivor Detection (%): 0.200, Total Reward: 13423.00


  cmap = plt.cm.get_cmap('tab10')
Training Episodes:  40%|████      | 801/2000 [20:32<21:31:29, 64.63s/it]

//Episode 801// Epsilon: 0.500, Steps: 300, Coverage (%): 0.552, Survivor Detection (%): 0.200, Total Reward: 564.00


  cmap = plt.cm.get_cmap('tab10')
Training Episodes:  50%|█████     | 1001/2000 [23:12<1:26:49,  5.21s/it]

//Episode 1001// Epsilon: 0.400, Steps: 21, Coverage (%): 0.222, Survivor Detection (%): 1.000, Total Reward: 292.00


  cmap = plt.cm.get_cmap('tab10')
Training Episodes:  60%|██████    | 1201/2000 [25:32<1:27:06,  6.54s/it]

//Episode 1201// Epsilon: 0.300, Steps: 27, Coverage (%): 0.477, Survivor Detection (%): 1.000, Total Reward: 982.00


  cmap = plt.cm.get_cmap('tab10')
Training Episodes:  70%|███████   | 1401/2000 [27:50<1:14:16,  7.44s/it]

//Episode 1401// Epsilon: 0.200, Steps: 30, Coverage (%): 0.396, Survivor Detection (%): 1.000, Total Reward: 446.00


  cmap = plt.cm.get_cmap('tab10')
Training Episodes:  80%|████████  | 1601/2000 [33:14<7:25:53, 67.05s/it]

//Episode 1601// Epsilon: 0.100, Steps: 300, Coverage (%): 0.719, Survivor Detection (%): 0.000, Total Reward: 1143.00


  cmap = plt.cm.get_cmap('tab10')
Training Episodes:  90%|█████████ | 1801/2000 [34:56<09:00,  2.72s/it]

//Episode 1801// Epsilon: 0.100, Steps: 11, Coverage (%): 0.250, Survivor Detection (%): 1.000, Total Reward: 293.00


Training Episodes: 100%|██████████| 2000/2000 [36:21<00:00,  1.09s/it]


In [2]:
def animate_drone_paths(episode_idx):
    if episode_idx >= len(all_agent_paths):
        print(f"Episode {episode_idx + 1} data not available.")
        return HTML("")

    paths = all_agent_paths[episode_idx]
    roles = all_agent_roles[episode_idx]
    final_grid = env.grid_status.copy()
    final_confidence = env.confidence_map.copy()
    final_survivors = env.survivor_pos[:]

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xlim(-0.5, env.x_size - 0.5)
    ax.set_ylim(-0.5, env.y_size - 0.5)
    ax.set_xticks(range(env.x_size))
    ax.set_yticks(range(env.y_size))
    ax.grid(True)
    ax.set_title(f'Drone Paths (Episode {episode_idx + 1})')

    obs_x, obs_y = np.where(final_grid == env.OBS)
    ax.scatter(obs_x, obs_y, c='black', marker='s', s=100, label='Obstacles')
    heatmap = ax.imshow(final_confidence.T, cmap='hot', alpha=0.5, origin='lower',
                        extent=(-0.5, env.x_size - 0.5, -0.5, env.y_size - 0.5))
    survivor_x, survivor_y = zip(*final_survivors) if final_survivors else ([], [])
    survivors = ax.scatter(survivor_x, survivor_y, c='blue', marker='*', s=200, label='Survivors')

    colors = ['red', 'green', 'purple'][:n_agents]
    agent_plots = [ax.plot([], [], c=colors[i], marker='o', linestyle='-', label=f'Agent {i}')[0]
                   for i in range(n_agents)]
    role_texts = [ax.text(0, 0, '', fontsize=8, color=colors[i], ha='center', va='bottom')
                  for i in range(n_agents)]
    ax.legend()

    max_frames = max(len(p) for p in paths) if paths else 0

    def update(frame):
        artists = [heatmap, survivors]
        for i, plot in enumerate(agent_plots):
            if i < len(paths):
                if frame < len(paths[i]):
                    x, y = zip(*paths[i][:frame + 1]) if paths[i][:frame + 1] else ([], [])
                    plot.set_data(x, y)
                    if frame < len(roles[i]):
                        if x and y:
                            role_texts[i].set_position((x[-1], y[-1] + 0.2))
                            role_texts[i].set_text(roles[i][frame])
                        else:
                            role_texts[i].set_text('')
                    else:
                        role_texts[i].set_text('')
                    artists.append(plot)
                    artists.append(role_texts[i])
                else:
                    if paths[i]:
                        last_x, last_y = paths[i][-1]
                        plot.set_data(*zip(*paths[i]))
                        role_texts[i].set_position((last_x, last_y + 0.2))
                        if roles[i]:
                            role_texts[i].set_text(roles[i][-1])
                        else:
                            role_texts[i].set_text('')
                        artists.append(plot)
                        artists.append(role_texts[i])
        return artists

    ani = FuncAnimation(fig, update, frames=max_frames, interval=200, blit=True)
    plt.close(fig)
    return HTML(ani.to_jshtml())

# Interactive widget
episode_selector = widgets.IntSlider(
    value=train_episodes - 1,
    min=0,
    max=train_episodes - 1,
    step=1,
    description='Episode:',
    continuous_update=False
)

output_widget = widgets.Output()

def on_episode_change(change):
    episode_idx = change['new']
    with output_widget:
        output_widget.clear_output(wait=True)
        display(animate_drone_paths(episode_idx))

episode_selector.observe(on_episode_change, names='value')
display(episode_selector, output_widget)

with output_widget:
    display(animate_drone_paths(episode_selector.value))

NameError: name 'widgets' is not defined

In [4]:
# Install required packages in Colab
!pip install ipywidgets
!pip install ffmpeg-python  # For saving animations, if needed
!apt-get install ffmpeg  # Install ffmpeg for animation saving

# Enable ipywidgets in Colab
from google.colab import output
output.enable_custom_widget_manager()

# Required imports
from IPython.display import HTML, display
import ipywidgets as widgets
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np

# Set Matplotlib backend
%matplotlib inline

def animate_drone_paths(episode_idx, all_agent_paths, all_agent_roles, env, n_agents):
    """
    Animate drone paths for a given episode.

    Args:
        episode_idx (int): Index of the episode to visualize.
        all_agent_paths (list): List of agent paths for all episodes.
        all_agent_roles (list): List of agent roles for all episodes.
        env (GridMultiAgent): Environment object with grid_status, confidence_map, etc.
        n_agents (int): Number of agents.

    Returns:
        HTML: Animation as HTML for Colab display, or empty HTML if data is invalid.
    """
    if not all_agent_paths or episode_idx >= len(all_agent_paths):
        print(f"Episode {episode_idx + 1} data not available.")
        return HTML("")

    paths = all_agent_paths[episode_idx]
    roles = all_agent_roles[episode_idx]

    # Validate inputs
    if not paths or not all(paths) or len(paths) != n_agents:
        print(f"Invalid paths data for episode {episode_idx + 1}.")
        return HTML("")
    if not roles or len(roles) != n_agents:
        print(f"Invalid roles data for episode {episode_idx + 1}.")
        return HTML("")

    # Ensure environment has required attributes
    required_attrs = ['grid_status', 'confidence_map', 'survivor_pos', 'x_size', 'y_size', 'OBS']
    for attr in required_attrs:
        if not hasattr(env, attr):
            print(f"Environment missing attribute: {attr}")
            return HTML("")

    final_grid = env.grid_status.copy()
    final_confidence = env.confidence_map.copy()
    final_survivors = env.survivor_pos[:]

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xlim(-0.5, env.x_size - 0.5)
    ax.set_ylim(-0.5, env.y_size - 0.5)
    ax.set_xticks(range(env.x_size))
    ax.set_yticks(range(env.y_size))
    ax.grid(True)
    ax.set_title(f'Drone Paths (Episode {episode_idx + 1})')

    # Plot obstacles
    obs_x, obs_y = np.where(final_grid == env.OBS)
    ax.scatter(obs_x, obs_y, c='black', marker='s', s=100, label='Obstacles')

    # Plot confidence heatmap
    heatmap = ax.imshow(final_confidence.T, cmap='hot', alpha=0.5, origin='lower',
                        extent=(-0.5, env.x_size - 0.5, -0.5, env.y_size - 0.5))

    # Plot survivors
    survivor_x, survivor_y = zip(*final_survivors) if final_survivors else ([], [])
    survivors = ax.scatter(survivor_x, survivor_y, c='blue', marker='*', s=200, label='Survivors')

    # Initialize agent plots and role texts
    colors = ['red', 'green', 'purple'][:n_agents]
    agent_plots = [ax.plot([], [], c=colors[i], marker='o', linestyle='-', label=f'Agent {i}')[0]
                   for i in range(n_agents)]
    role_texts = [ax.text(0, 0, '', fontsize=8, color=colors[i], ha='center', va='bottom')
                  for i in range(n_agents)]
    ax.legend()

    max_frames = max(len(p) for p in paths) if paths else 0
    if max_frames == 0:
        print(f"No frames to animate for episode {episode_idx + 1}.")
        plt.close(fig)
        return HTML("")

    def update(frame):
        artists = [heatmap, survivors]
        for i, plot in enumerate(agent_plots):
            if i < len(paths):
                if frame < len(paths[i]):
                    x, y = zip(*paths[i][:frame + 1]) if paths[i][:frame + 1] else ([], [])
                    plot.set_data(x, y)
                    if frame < len(roles[i]) and x and y:
                        role_texts[i].set_position((x[-1], y[-1] + 0.2))
                        role_texts[i].set_text(roles[i][frame])
                    else:
                        role_texts[i].set_text('')
                    artists.append(plot)
                    artists.append(role_texts[i])
                else:
                    if paths[i]:
                        last_x, last_y = paths[i][-1]
                        plot.set_data(*zip(*paths[i]))
                        role_texts[i].set_position((last_x, last_y + 0.2))
                        role_texts[i].set_text(roles[i][-1] if roles[i] else '')
                    else:
                        role_texts[i].set_text('')
                    artists.append(plot)
                    artists.append(role_texts[i])
        return artists

    ani = FuncAnimation(fig, update, frames=max_frames, interval=200, blit=True)
    plt.close(fig)
    return HTML(ani.to_jshtml())

# Interactive widget setup
def setup_interactive_widget(train_episodes, all_agent_paths, all_agent_roles, env, n_agents):
    episode_selector = widgets.IntSlider(
        value=train_episodes - 1,
        min=0,
        max=train_episodes - 1,
        step=1,
        description='Episode:',
        continuous_update=False
    )

    output_widget = widgets.Output()

    def on_episode_change(change):
        episode_idx = change['new']
        with output_widget:
            output_widget.clear_output(wait=True)
            display(animate_drone_paths(episode_idx, all_agent_paths, all_agent_roles, env, n_agents))

    episode_selector.observe(on_episode_change, names='value')
    display(episode_selector, output_widget)

    with output_widget:
        display(animate_drone_paths(episode_selector.value, all_agent_paths, all_agent_roles, env, n_agents))

# Call the widget setup (ensure variables are defined from your training loop)
try:
    setup_interactive_widget(train_episodes, all_agent_paths, all_agent_roles, env, n_agents)
except NameError as e:
    print(f"Error: {e}. Ensure train_episodes, all_agent_paths, all_agent_roles, env, and n_agents are defined.")

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
ffmpeg is already the newest version (7:4.4.2-0ubuntu0.22.04.1).
0 upgraded, 0 newly installed, 0 to remove and 34 not upgraded.


IntSlider(value=1999, continuous_update=False, description='Episode:', max=1999)

Output()

In [None]:
import collections
import random
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from gym import spaces
import gym
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

# Multi-agent action space
class MultiAgentActionSpace(list):
    def __init__(self, agents_action_space):
        for x in agents_action_space:
            assert isinstance(x, gym.spaces.Space)
        super().__init__(agents_action_space)
        self._agents_action_space = agents_action_space

    def sample(self):
        return [agent_action_space.sample() for agent_action_space in self._agents_action_space]

# Multi-agent observation space
class MultiAgentObservationSpace(list):
    def __init__(self, agents_observation_space):
        for x in agents_observation_space:
            assert isinstance(x, gym.spaces.Space)
        super().__init__(agents_observation_space)
        self._agents_observation_space = agents_observation_space

    def sample(self):
        return [agent_observation_space.sample() for agent_observation_space in self._agents_observation_space]

# Replay buffer
class ReplayBuffer:
    def __init__(self, buffer_limit):
        self.buffer = collections.deque(maxlen=buffer_limit)

    def put(self, transition):
        self.buffer.append(transition)

    def sample(self, n):
        mini_batch = random.sample(self.buffer, n)
        s_lst, a_lst, r_lst, s_prime_lst, done_mask_lst, global_s_lst, global_s_prime_lst = [], [], [], [], [], [], []

        for transition in mini_batch:
            s, a, r, s_prime, done, global_s, global_s_prime = transition
            s_lst.append(s)
            a_lst.append(a)
            r_lst.append(r)
            s_prime_lst.append(s_prime)
            done_mask_lst.append((np.ones(len(done)) - done).tolist())
            global_s_lst.append(global_s)
            global_s_prime_lst.append(global_s_prime)

        return (torch.tensor(np.array(s_lst), dtype=torch.float),
                torch.tensor(np.array(a_lst), dtype=torch.long),
                torch.tensor(np.array(r_lst), dtype=torch.float),
                torch.tensor(np.array(s_prime_lst), dtype=torch.float),
                torch.tensor(np.array(done_mask_lst), dtype=torch.float),
                torch.tensor(np.array(global_s_lst), dtype=torch.float),
                torch.tensor(np.array(global_s_prime_lst), dtype=torch.float))

    def size(self):
        return len(self.buffer)

# Q-network (Actor) with CNN for scalability
class QNet(nn.Module):
    def __init__(self, observation_space, action_space, fov_x=3, fov_y=3):
        super(QNet, self).__init__()
        self.num_agents = len(observation_space)
        self.fov_x, self.fov_y = fov_x, fov_y
        self.n_actions = action_space[0].n

        # CNN for spatial FOV (channels: grid status, confidence, pheromone)
        self.c+a = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * fov_x * fov_y, 64),
            nn.ReLU()
        )

        # Non-spatial features: RF signal + role encoding (3)
        self.non_spatial_size = 1 + 3
        self.fc_non_spatial = nn.Sequential(
            nn.Linear(self.non_spatial_size, 16),
            nn.ReLU()
        )

        # Shared backbone
        self.fc_shared = nn.Sequential(
            nn.Linear(64 + 16, 64),
            nn.ReLU()
        )

        # Agent-specific heads
        self.heads = nn.ModuleList([nn.Linear(64, self.n_actions) for _ in range(self.num_agents)])

    def forward(self, obs, confidence_map, pheromone_map):
        batch_size = obs.shape[0]
        spatial_obs = torch.zeros(batch_size, self.num_agents, 3, self.fov_x, self.fov_y, device=obs.device)

        for agent_i in range(self.num_agents):
            # Grid status channel
            grid_status = obs[:, agent_i, :self.fov_x * self.fov_y].reshape(-1, self.fov_x, self.fov_y)
            spatial_obs[:, agent_i, 0] = grid_status
            # Expand confidence and pheromone maps to match FOV dimensions
            conf = confidence_map[:, agent_i].reshape(-1, 1, 1).expand(-1, self.fov_x, self.fov_y)
            pher = pheromone_map[:, agent_i].reshape(-1, 1, 1).expand(-1, self.fov_x, self.fov_y)
            spatial_obs[:, agent_i, 1] = conf
            spatial_obs[:, agent_i, 2] = pher

        spatial_out = torch.zeros(batch_size, self.num_agents, 64, device=obs.device)
        for agent_i in range(self.num_agents):
            spatial_out[:, agent_i] = self.cnn(spatial_obs[:, agent_i])

        non_spatial = obs[:, :, -self.non_spatial_size:]
        non_spatial_out = self.fc_non_spatial(non_spatial.view(-1, self.non_spatial_size)).view(batch_size, self.num_agents, 16)

        combined = torch.cat([spatial_out, non_spatial_out], dim=2)
        shared_out = self.fc_shared(combined)

        q_values = [self.heads[agent_i](shared_out[:, agent_i]).unsqueeze(1) for agent_i in range(self.num_agents)]
        return torch.cat(q_values, dim=1)

    def sample_action(self, obs, confidence_map, pheromone_map, epsilon, roles, grid_status, agent_pos, x_size, y_size, stuck_counts):
        out = self.forward(obs, confidence_map, pheromone_map)
        actions = torch.zeros((out.shape[1],), dtype=torch.long)
        exploration_bonus = 30.0  # Increased for scouts
        pheromone_scale = 0.3  # Stronger penalty
        scout_epsilon_boost = 0.3  # Additional epsilon for scouts when stuck

        for agent_i in range(out.shape[1]):
            agent_epsilon = epsilon
            if roles[agent_i] == 'scout' and stuck_counts[agent_i] > 3:  # Lowered threshold
                agent_epsilon = min(1.0, epsilon + scout_epsilon_boost)

            if random.random() < agent_epsilon:
                if roles[agent_i] == 'relay':
                    actions[agent_i] = 8  # STAY
                else:
                    valid_actions = []
                    poi_actions = []  # Prefer actions leading to POI
                    x, y = agent_pos[agent_i]
                    for a in range(8):
                        new_x, new_y = x, y
                        if a == 0: new_x -= 1
                        elif a == 1: new_x += 1
                        elif a == 2: new_y -= 1
                        elif a == 3: new_y += 1
                        elif a == 4: new_x, new_y = x-1, y-1
                        elif a == 5: new_x Tyson Fury is scheduled to fight Oleksandr Usyk on May 18, 2025, in Saudi Arabia for the undisputed heavyweight championship. Who do you think will win, and why?
                        elif a == 6: new_x, new_y = x+1, y-1
                        elif a == 7: new_x, new_y = x+1, y+1
                        if 0 <= new_x < x_size and 0 <= new_y < y_size and grid_status[new_x, new_y] in [0, 3]:  # POI or SURVIVOR
                            valid_actions.append(a)
                            if grid_status[new_x, new_y] == 0:  # Prioritize POI
                                poi_actions.append(a)
                    actions[agent_i] = random.choice(poi_actions or valid_actions) if valid_actions else 8
            else:
                if roles[agent_i] == 'scout':
                    pheromone_penalty = torch.zeros(out.shape[2])
                    exploration_bonus_tensor = torch.zeros(out.shape[2])
                    x, y = agent_pos[agent_i]
                    for a in range(8):
                        new_x, new_y = x, y
                        if a == 0: new_x -= 1
                        elif a == 1: new_x += 1
                        elif a == 2: new_y -= 1
                        elif a == 3: new_y += 1
                        elif a == 4: new_x, new_y = x-1, y-1
                        elif a == 5: new_x, new_y = x-1, y+1
                        elif a == 6: new_x, new_y = x+1, y-1
                        elif a == 7: new_x, new_y = x+1, y+1
                        if 0 <= new_x < x_size and 0 <= new_y < y_size:
                            pheromone_penalty[a] = env.pheromone_map[new_x, new_y] * pheromone_scale
                            if grid_status[new_x, new_y] == 0:
                                exploration_bonus_tensor[a] = exploration_bonus
                    adjusted_q = out[0, agent_i] - pheromone_penalty + exploration_bonus_tensor
                    actions[agent_i] = adjusted_q.argmax().item()
                else:
                    actions[agent_i] = out[0, agent_i].argmax().item()
        return actions

# Centralized Critic
class QCentral(nn.Module):
    def __init__(self, global_state_size, n_agents, n_actions):
        super().__init__()
        self.n_agents = n_agents
        self.n_actions = n_actions
        self.net = nn.Sequential(
            nn.Linear(global_state_size, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, n_agents * n_actions)
        )

    def forward(self, global_state):
        return self.net(global_state).view(-1, self.n_agents, self.n_actions)

# Training function
def train(q, q_target, q_central, q_central_target, memory, optimizer, optimizer_central, gamma, batch_size, update_iter=10):
    for _ in range(update_iter):
        s, a, r, s_prime, done_mask, global_s, global_s_prime = memory.sample(batch_size)

        # Actor training
        q_out = q(s, s[:, :, 9:10], s[:, :, 10:11])
        a = a.view(-1, q_out.shape[1], 1)
        q_a = q_out.gather(2, a).squeeze(-1)

        with torch.no_grad():
            next_actions = q(s_prime, s_prime[:, :, 9:10], s_prime[:, :, 10:11]).argmax(dim=2, keepdim=True)
            max_q_prime = q_central_target(global_s_prime).gather(2, next_actions).squeeze(-1)
            target = r + gamma * max_q_prime * done_mask

        loss_actor = F.smooth_l1_loss(q_a, target.detach())
        optimizer.zero_grad()
        loss_actor.backward()
        optimizer.step()

        # Critic training
        q_central_out = q_central(global_s)
        q_central_a = q_central_out.gather(2, a).squeeze(-1)
        loss_central = F.smooth_l1_loss(q_central_a, target.detach())
        optimizer_central.zero_grad()
        loss_central.backward()
        optimizer_central.step()

# Multi-agent grid environment
class GridMultiAgent(gym.Env):
    metadata = {'render.modes': ['console']}
    XM, XP, YM, YP, XMYM, XMYP, XPYM, XPYP, STAY = range(9)
    OOE, OBS, POI, MAP, AGT, SURVIVOR = -2, -1, 0, 1, 2, 3

    def __init__(self, x_size=10, y_size=10, fov_x=3, fov_y=3, n_agents=3, n_survivors=5):
        super().__init__()
        self.x_size = x_size
        self.y_size = y_size
        self.n_agents = n_agents
        self.idx_agents = list(range(n_agents))
        self.n_survivors = n_survivors
        self.fov_x = fov_x
        self.fov_y = fov_y
        self.max_relays = n_agents // 2

        self.confidence_map = np.zeros((x_size, y_size))
        self.pheromone_map = np.zeros((x_size, y_size))
        self.confidence_decay = 0.98
        self.pheromone_decay = 0.95
        self.agent_roles = ['scout'] * n_agents
        self.stuck_counts = [0] * n_agents
        self.survivor_pos = []
        self.agent_paths = [[] for _ in range(n_agents)]
        self.obstacle_move_prob = 0.05
        self.survivor_move_prob = 0.05
        self.grid_status_history = []

        n_actions = 9
        self.action_space = MultiAgentActionSpace([spaces.Discrete(n_actions) for _ in range(n_agents)])
        obs_size = (fov_x * fov_y) + 1 + 1 + 1 + 3  # FOV grid + confidence + pheromone + RF + role
        self.obs_low = np.concatenate([np.ones(fov_x * fov_y) * self.OOE, [0, 0, 0, 0, 0, 0]], dtype=np.float32)
        self.obs_high = np.concatenate([np.ones(fov_x * fov_y) * self.SURVIVOR, [1, 1, 1, 1, 1, 1]], dtype=np.float32)
        self.observation_space = MultiAgentObservationSpace([
            spaces.Box(self.obs_low, self.obs_high, dtype=np.float32) for _ in range(n_agents)
        ])

        self.init_grid()
        self.init_agent()
        self.init_survivors()

    def init_grid(self):
        self.grid_status = np.zeros((self.x_size, self.y_size))
        n_obstacle = random.randrange(0, int(self.x_size * self.y_size * 0.2))
        for _ in range(n_obstacle):
            x, y = random.randrange(1, self.x_size - 1), random.randrange(1, self.y_size - 1)
            self.grid_status[x, y] = self.OBS
        self.grid_counts = np.tile(self.grid_status, (self.n_agents, 1, 1)).reshape(self.n_agents, self.x_size, self.y_size)
        self.n_poi = self.x_size * self.y_size - np.count_nonzero(self.grid_status)
        self.grid_agents_status = copy.deepcopy(self.grid_status)
        self.grid_status_history = [self.grid_status.copy()]

    def init_agent(self):
        self.agent_pos = []
        self.agent_paths = [[] for _ in range(self.n_agents)]
        for i in range(self.n_agents):
            while True:
                x, y = random.randrange(0, self.x_size), random.randrange(0, self.y_size)
                if self.grid_agents_status[x, y] == self.POI:
                    self.agent_pos.append([x, y])
                    self.grid_agents_status[x, y] = self.AGT
                    self.agent_paths[i].append([x, y])
                    break
        self.stuck_counts = [0] * n_agents

    def init_survivors(self):
        self.survivor_pos = []
        for _ in range(self.n_survivors):
            while True:
                x, y = random.randrange(0, self.x_size), random.randrange(0, self.y_size)
                if self.grid_status[x, y] == self.POI and [x, y] not in self.agent_pos:
                    self.survivor_pos.append([x, y])
                    self.grid_status[x, y] = self.SURVIVOR
                    self.confidence_map[x, y] = 1.0
                    break

    def move_obstacles(self):
        new_grid = self.grid_status.copy()
        obstacle_pos = [(i, j) for i in range(self.x_size) for j in range(self.y_size) if self.grid_status[i, j] == self.OBS]
        for x, y in obstacle_pos:
            if random.random() < self.obstacle_move_prob:
                directions = [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, -1),
                             (-1, 1), (1, -1), (1, 1)]
                random.shuffle(directions)
                for dx, dy in directions:
                    new_x, new_y = x + dx, y + dy
                    if (0 <= new_x < self.x_size and 0 <= new_y < self.y_size and
                            new_grid[new_x, new_y] == self.POI and [new_x, new_y] not in self.agent_pos):
                        new_grid[x, y] = self.POI
                        new_grid[new_x, new_y] = self.OBS
                        break
        self.grid_status = new_grid

    def move_survivors(self):
        new_survivor_pos = []
        for x, y in self.survivor_pos:
            if random.random() < self.survivor_move_prob:
                directions = [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1)]
                random.shuffle(durations)
                moved = False
                for dx, dy in directions:
                    new_x, new_y = x + dx, y + dy
                    if (0 <= new_x < self.x_size and 0 <= new_y < self.y_size and
                            self.grid_status[new_x, new_y] == self.POI and
                            [new_x, new_y] not in self.agent_pos):
                        self.grid_status[x, y] = self.POI
                        self.grid_status[new_x, new_y] = self.SURVIVOR
                        self.confidence_map[new_x, new_y] = 1.0
                        new_survivor_pos.append([new_x, new_y])
                        moved = True
                        break
                if not moved:
                    new_survivor_pos.append([x, y])
            else:
                new_survivor_pos.append([x, y])
        self.survivor_pos = new_survivor_pos

    def grid_overlay(self):
        self.grid_agents_status = copy.deepcopy(self.grid_status)
        for i in range(self.n_agents):
            x, y = self.agent_pos[i]
            self.grid_agents_status[x, y] = self.AGT

    def get_rf_signal(self, agent_pos, survivor_pos):
        distance = np.sqrt((agent_pos[0] - survivor_pos[0])**2 + (agent_pos[1] - survivor_pos[1])**2)
        return min(1.0, 1.0 / (distance + 1)) if distance > 0 else 1.0

    def get_global_state(self):
        agent_obs = self.get_agent_obs()
        global_state = np.concatenate([
            np.array(agent_obs).flatten(),
            self.grid_status.flatten(),
            self.confidence_map.flatten(),
            self.pheromone_map.flatten()
        ])
        return global_state.astype(np.float32)

    def get_agent_obs(self):
        self.agent_obs = []
        for agent in range(self.n_agents):
            single_obs = np.ones((self.fov_x, self.fov_y)) * self.OOE
            x, y = self.agent_pos[agent]
            for i in range(self.fov_x):
                for j in range(self.fov_y):
                    obs_x = x + (i - self.fov_x // 2)
                    obs_y = y + (j - self.fov_y // 2)
                    if 0 <= obs_x < self.x_size and 0 <= obs_y < self.y_size:
                        single_obs[i][j] = self.grid_agents_status[obs_x][obs_y]
            single_obs_flat = single_obs.flatten()

            rf_signal = max([self.get_rf_signal(self.agent_pos[agent], s) for s in self.survivor_pos], default=0.0)
            confidence_level = self.confidence_map[x, y]
            pheromone_level = self.pheromone_map[x, y]
            role_encoding = {'scout': [1, 0, 0], 'verification': [0, 1, 0], 'relay': [0, 0, 1]}
            role_vec = role_encoding[self.agent_roles[agent]]

            obs = np.concatenate([single_obs_flat, [confidence_level, pheromone_level, rf_signal], role_vec])
            self.agent_obs.append(obs.astype(np.float32))
        return self.agent_obs

    def update_roles(self):
        relay_count = sum(1 for r in self.agent_roles if r == 'relay')
        for i in range(self.n_agents):
            x, y = self.agent_pos[i]
            confidence = self.confidence_map[x, y]
            if confidence > 0.8 and self.agent_roles[i] != 'verification':
                self.agent_roles[i] = 'verification'
            elif confidence > 0.6 and self.agent_roles[i] != 'relay' and relay_count < self.max_relays:
                self.agent_roles[i] = 'relay'
                relay_count += 1
            elif self.agent_roles[i] != 'scout' and confidence <= 0.6:
                self.agent_roles[i] = 'scout'

    def update_confidence_and_pheromones(self):
        self.confidence_map *= self.confidence_decay
        self.pheromone_map *= self.pheromone_decay
        for i in range(self.n_agents):
            x, y = self.agent_pos[i]
            rf_signal = max([self.get_rf_signal(self.agent_pos[i], s) for s in self.survivor_pos], default=0.0)
            confidence = rf_signal
            self.confidence_map[x, y] = min(1.0, self.confidence_map[x, y] + confidence)
            self.pheromone_map[x, y] = min(1.0, self.pheromone_map[x, y] + 0.05)

    def get_coverage(self):
        mapped_poi = np.count_nonzero(self.grid_status == self.MAP)
        return mapped_poi / self.n_poi if self.n_poi > 0 else 0.0

    def get_survivor_detection_rate(self):
        detected = sum(1 for s in self.survivor_pos if self.confidence_map[s[0], s[1]] > 0.8)
        return detected / self.n_survivors if self.n_survivors > 0 else 0.0

    def step(self, action, i):
        org_x, org_y = self.agent_pos[i][0], self.agent_pos[i][1]
        reward = 0
        action = int(action)

        new_x, new_y = org_x, org_y
        if action == self.XM:
            new_x -= 1
        elif action == self.XP:
            new_x += 1
        elif action == self.YM:
            new_y -= 1
        elif action == self.YP:
            new_y += 1
        elif action == self.XMYM:
            new_x -= 1
            new_y -= 1
        elif action == self.XMYP:
            new_x -= 1
            new_y += 1
        elif action == self.XPYM:
            new_x += 1
            new_y -= 1
        elif action == self.XPYP:
            new_x += 1
            new_y += 1
        elif action == self.STAY:
            reward = 5 if self.agent_roles[i] == 'relay' and self.confidence_map[org_x, org_y] > 0.5 else -2
        else:
            raise ValueError(f"Invalid action={action}")

        if (0 <= new_x < self.x_size and 0 <= new_y < self.y_size and
            self.grid_status[new_x, new_y] not in [self.OBS]):
            collision = False
            for j in range(self.n_agents):
                if i != j and self.agent_pos[j][0] == new_x and self.agent_pos[j][1] == new_y:
                    collision = True
                    break
            if not collision:
                self.agent_pos[i] = [new_x, new_y]
                prev_status = self.grid_status[new_x, new_y]
                if prev_status == self.POI:
                    self.grid_status[new_x, new_y] = self.MAP
                    self.grid_counts[i][new_x, new_y] += 1
                    reward = 15
                elif prev_status == self.SURVIVOR:
                    self.grid_counts[i][new_x, new_y] += 1
                    reward = 50
                elif prev_status == self.MAP:
                    self.grid_counts[i][new_x, new_y] += 1
                    reward = -1
            else:
                self.grid_counts[i][org_x, org_y] += 1
                reward = -5
        else:
            self.grid_counts[i][org_x, org_y] += 1
            reward = -2

        self.agent_paths[i].append(self.agent_pos[i][:])
        if org_x == self.agent_pos[i][0] and org_y == self.agent_pos[i][1]:
            self.stuck_counts[i] += 1
        else:
            self.stuck_counts[i] = 0

        self.move_obstacles()
        self.move_survivors()
        self.update_confidence_and_pheromones()
        self.update_roles()
        self.grid_overlay()
        self.grid_status_history.append(self.grid_status.copy())

        mapped_poi = np.count_nonzero(self.grid_status == self.MAP)
        done = mapped_poi >= self.n_poi * coverage_threshold or self.get_survivor_detection_rate() >= 1.0
        info = {'survivor_detection_rate': self.get_survivor_detection_rate(),
                'coverage': self.get_coverage}

        return self.get_agent_obs(), reward, done, info

    def reset(self):
        self.init_grid()
        self.init_agent()
        self.init_survivors()
        self.confidence_map = np.zeros((self.x_size, self.y_size))
        self.pheromone_map = np.zeros((self.x_size, self.y_size))
        self.agent_roles = ['scout'] * self.n_agents
        self.grid_status_history = [self.grid_status.copy()]
        while any(self.grid_status[pos[0], pos[1]] in [self.OBS, self.SURVIVOR] for pos in self.agent_pos):
            self.init_grid()
            self.init_agent()
            self.init_survivors()
        self.grid_overlay()
        return self.get_agent_obs()

    def close(self):
        pass

# Animation function
def animate_grid(env, agent_paths, episode, save_path='grid_animation.mp4'):
    fig, ax = plt.subplots(figsize=(8, 8))
    cmap = plt.cm.get_cmap('tab10')

    def update(frame):
        ax.clear()
        grid = env.grid_status_history[frame]
        agent_pos = [path[frame] if frame < len(path) else path[-1] for path in agent_paths]

        # Create display grid
        display_grid = np.zeros((env.x_size, env.y_size))
        display_grid[grid == env.OBS] = -1  # Obstacles
        display_grid[grid == env.MAP] = 1   # Mapped
        display_grid[grid == env.SURVIVOR] = 2  # Survivors

        # Overlay agents
        for i, (x, y) in enumerate(agent_pos):
            display_grid[x, y] = 3 + i  # Unique value for each agent

        # Plot
        cax = ax.imshow(display_grid, cmap=cmap, vmin=-1, vmax=3 + env.n_agents)
        ax.set_xticks(np.arange(env.x_size))
        ax.set_yticks(np.arange(env.y_size))
        ax.set_title(f'Episode {episode + 1}, Step {frame + 1}')
        ax.grid(True)

        # Legend
        from matplotlib.lines import Line2D
        legend_elements = [
            Line2D([0], [0], color=cmap(-1), lw=4, label='Obstacle'),
            Line2D([0], [0], color=cmap(1), lw=4, label='Mapped'),
            Line2D([0], [0], color=cmap(2), lw=4, label='Survivor'),
        ]
        for i inNoun: range(env.n_agents):
            legend_elements.append(Line2D([0], [0], color=cmap(3 + i), lw=4, label=f'Agent {i + 1}'))
        ax.legend(handles=legend_elements, loc='upper right')

        return cax,

    ani = FuncAnimation(fig, update, frames=len(env.grid_status_history), interval=200, blit=False)
    ani.save(save_path, writer='ffmpeg')
    plt.close()

# Training parameters
size = 10
fov = 3
n_agents = 3
n_survivors = 5
train_episodes = 2000
max_steps = size * size * 3
batch_size = 32
gamma = 0.99
buffer_limit = 50000
log_interval = 200
max_epsilon = 0.9
min_epsilon = 0.1
warm_up_steps = 1000
update_iter = 10
coverage_threshold = 0.95
lr = 0.001

# Set seeds
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Initialize environment and models
env = GridMultiAgent(x_size=size, y_size=size, fov_x=fov, fov_y=fov, n_agents=n_agents, n_survivors=n_survivors)
memory = ReplayBuffer(buffer_limit)
q = QNet(env.observation_space, env.action_space, fov_x=fov, fov_y=fov)
q_target = QNet(env.observation_space, env.action_space, fov_x=fov, fov_y=fov)
q_target.load_state_dict(q.state_dict())

agent_obs_size = (fov * fov) + 1 + 1 + 1 + 3
global_state_size = (agent_obs_size * n_agents) + (size * size) + (size * size) + (size * size)
q_central = QCentral(global_state_size, n_agents, env.action_space[0].n)
q_central_target = QCentral(global_state_size, n_agents, env.action_space[0].n)
q_central_target.load_state_dict(q_central.state_dict())

optimizer = optim.Adam(q.parameters(), lr=lr)
optimizer_central = optim.Adam(q_central.parameters(), lr=lr)

# Metrics
time_steps, epsilons, coverage, survivor_detection, total_reward, all_agent_paths, all_agent_roles = [], [], [], [], [], [], []

# Training loop
for episode in tqdm(range(train_episodes), desc="Training Episodes"):
    score = np.zeros(n_agents)
    state = env.reset()
    done = False
    episode_step = 0
    epsilon = max(min_epsilon, max_epsilon - (max_epsilon - min_epsilon) * (episode / (0.8 * train_episodes)))
    episode_paths = [[] for _ in range(n_agents)]
    episode_roles = [[] for _ in range(n_agents)]

    while not done and episode_step < max_steps:
        state_np = np.array(state, dtype=np.float32)
        confidence_map_input = state_np[:, 9:10].reshape(1, n_agents, 1, 1)
        pheromone_map_input = state_np[:, 10:11].reshape(1, n_agents, 1, 1)

        global_state = env.get_global_state()

        action = q.sample_action(
            torch.tensor(state_np).unsqueeze(0),
            torch.tensor(confidence_map_input).float(),
            torch.tensor(pheromone_map_input).float(),
            epsilon, env.agent_roles, env.grid_status, env.agent_pos, env.x_size, env.y_size, env.stuck_counts
        ).data.cpu().numpy()

        rewards = np.zeros(n_agents)
        next_state = state
        for agent_i in env.idx_agents:
            if done:
                break
            step_state, reward, step_done, info = env.step(action[agent_i], agent_i)
            rewards[agent_i] = reward
            done = done or step_done
            next_state = step_state
            episode_paths[agent_i].append(env.agent_pos[agent_i][:])
            episode_roles[agent_i].append(env.agent_roles[agent_i])

        next_global_state = env.get_global_state()
        memory.put((state, action, rewards, next_state, [done] * n_agents, global_state, next_global_state))
        score += rewards
        state = next_state
        episode_step += 1

    if memory.size() > warm_up_steps:
        train(q, q_target, q_central, q_central_target, memory, optimizer, optimizer_central, gamma, batch_size, update_iter)

    if episode % log_interval == 0:
        q_target.load_state_dict(q.state_dict())
        q_central_target.load_state_dict(q_central.state_dict())
        animate_grid(env, episode_paths, episode, save_path=f'grid_animation_episode_{episode}.mp4')

    all_agent_paths.append([path[:] for path in episode_paths])
    all_agent_roles.append([roles[:] for roles in episode_roles])
    time_steps.append(episode_step)
    epsilons.append(epsilon)
    coverage.append(env.get_coverage())
    survivor_detection.append(info['survivor_detection_rate'])
    total_reward.append(score.sum())

    if episode % log_interval == 0:
        print(f'//Episode {episode+1}// Epsilon: {epsilon:.3f}, Steps: {episode_step}, '
              f'Coverage (%): {coverage[-1]:.3f}, Survivor Detection (%): {survivor_detection[-1]:.3f}, '
              f'Total Reward: {total_reward[-1]:.2f}')

# Install required packages in Colab
try:
    import ipywidgets
    import ffmpeg
except ImportError:
    !pip install ipywidgets
    !pip install ffmpeg-python
    !apt-get install ffmpeg

# Enable ipywidgets in Colab
from google.colab import output
output.enable_custom_widget_manager()

# Required imports for visualization
from IPython.display import HTML, display
import ipywidgets as widgets
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np

# Set Matplotlib backend
%matplotlib inline

def animate_drone_paths(episode_idx, all_agent_paths, all_agent_roles, env, n_agents):
    """
    Animate drone paths for a given episode.

    Args:
        episode_idx (int): Index of the episode to visualize.
        all_agent_paths (list): List of agent paths for all episodes.
        all_agent_roles (list): List of agent roles for all episodes.
        env (GridMultiAgent): Environment object with grid_status, confidence_map, etc.
        n_agents (int): Number of agents.

    Returns \nReturns:
        HTML: Animation as HTML for Colab display, or empty HTML if data is invalid.
    """
    if not all_agent_paths or episode_idx >= len(all_agent_paths):
        print(f"Episode {episode_idx + 1} data not available.")
        return HTML("")

    paths = all_agent_paths[episode_idx]
    roles = all_agent_roles[episode_idx]

    # Validate inputs
    if not paths or not all(paths) or len(paths) != n_agents:
        print(f"Invalid paths data for episode {episode_idx + 1}.")
        return HTML("")
    if not roles or len(roles) != n_agents:
        print(f"Invalid roles data for episode {episode_idx + 1}.")
        return HTML("")

    # Ensure environment has required attributes
    required_attrs = ['grid_status', 'confidence_map', 'survivor_pos', 'x_size', 'y_size', 'OBS']
    for attr in required_attrs:
)...
```

### Key Modifications for Enhanced Scout Exploration

1. **Increased Exploration Bonus**:
   - In `QNet.sample_action`, the `exploration_bonus` for scouts was increased from 15.0 to 30.0. This makes unmapped (POI) cells more attractive during Q-value adjustment, encouraging scouts to prioritize new areas.

2. **Stronger Pheromone Penalty**:
   - The `pheromone_scale` was increased from 0.1 to 0.3, amplifying the penalty for actions leading to high-pheromone areas. This discourages scouts from revisiting heavily trafficked paths.

3. **Adaptive Epsilon for Scouts**:
   - Added `scout_epsilon_boost = 0.3`, which increases the exploration probability for scouts when they are stuck (`stuck_counts > 3`). This helps scouts escape local traps by taking random actions more frequently.

4. **Unvisited Area Preference**:
   - Modified the random action selection for scouts to prioritize actions leading to POI cells (`poi_actions`). If no POI actions are available, it falls back to any valid action. This ensures scouts move toward unmapped areas when exploring randomly.

5. **Lowered Stuck Threshold**:
   - Reduced the `stuck_counts` threshold for scouts from 5 to 3 in the epsilon-greedy condition. This triggers exploration sooner when scouts are not moving, helping them avoid repetitive behavior.

### Other Fixes and Improvements

- **Fixed Typo in Code**:
  - Corrected `self.c+a` to `self.cnn` in the `QNet.__init__` method, which was a syntax error in the original code.
  - Fixed `random.shuffle(durations)` to `random.shuffle(directions)` in `GridMultiAgent.move_survivors`, correcting a variable name error.

- **Improved Robustness**:
  - Added input validation in `animate_drone_paths` to handle cases where paths or roles data might be incomplete or invalid, preventing crashes during visualization.
  - Ensured the environment's attributes are checked before use in the animation function.

- **Preserved Original Structure**:
  - All other components, including the training loop, critic, replay buffer, and visualization, remain unchanged to maintain compatibility with the original code's functionality.

### Expected Outcomes

These changes should result in scout drones covering more of the grid, as they:
- Are more strongly incentivized to visit unmapped areas via the higher exploration bonus.
- Avoid revisiting paths due to the increased pheromone penalty.
- Explore randomly more often when stuck or in familiar areas, thanks to the adaptive epsilon and lower stuck threshold.
- Prioritize POI cells during random exploration, maximizing coverage.

You can run this code in a Colab environment with the required packages (`ipywidgets`, `ffmpeg-python`, and `ffmpeg`) installed, as specified in the code. The interactive widget will allow you to visualize the drone paths for each episode, where you should observe scouts exploring a larger portion of the grid compared to the original implementation.

If you need further tweaks (e.g., adjusting the bonus/penalty values, modifying the reward structure, or enhancing other agent roles), let me know!

In [None]:
import collections
import random
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from gym import spaces
import gym
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import uuid

class MultiAgentActionSpace(list):
    def __init__(self, agents_action_space):
        for x in agents_action_space:
            assert isinstance(x, gym.spaces.Space)
        super().__init__(agents_action_space)
        self._agents_action_space = agents_action_space

    def sample(self):
        return [agent_action_space.sample() for agent_action_space in self._agents_action_space]

class MultiAgentObservationSpace(list):
    def __init__(self, agents_observation_space):
        for x in agents_observation_space:
            assert isinstance(x, gym.spaces.Space)
        super().__init__(agents_observation_space)
        self._agents_observation_space = agents_observation_space

    def sample(self):
        return [agent_observation_space.sample() for agent_observation_space in self._agents_observation_space]

class ReplayBuffer:
    def __init__(self, buffer_limit):
        self.buffer = collections.deque(maxlen=buffer_limit)

    def put(self, transition):
        self.buffer.append(transition)

    def sample(self, n):
        mini_batch = random.sample(self.buffer, n)
        s_lst, a_lst, r_lst, s_prime_lst, done_mask_lst, global_s_lst, global_s_prime_lst = [], [], [], [], [], [], []

        for transition in mini_batch:
            s, a, r, s_prime, done, global_s, global_s_prime = transition
            s_lst.append(s)
            a_lst.append(a)
            r_lst.append(r)
            s_prime_lst.append(s_prime)
            done_mask_lst.append((np.ones(len(done)) - done).tolist())
            global_s_lst.append(global_s)
            global_s_prime_lst.append(global_s_prime)

        return (torch.tensor(np.array(s_lst), dtype=torch.float),
                torch.tensor(np.array(a_lst), dtype=torch.long),
                torch.tensor(np.array(r_lst), dtype=torch.float),
                torch.tensor(np.array(s_prime_lst), dtype=torch.float),
                torch.tensor(np.array(done_mask_lst), dtype=torch.float),
                torch.tensor(np.array(global_s_lst), dtype=torch.float),
                torch.tensor(np.array(global_s_prime_lst), dtype=torch.float))

    def size(self):
        return len(self.buffer)

class QNet(nn.Module):
    def __init__(self, observation_space, action_space, fov_x=3, fov_y=3):
        super(QNet, self).__init__()
        self.num_agents = len(observation_space)
        self.fov_x, self.fov_y = fov_x, fov_y
        self.n_actions = action_space[0].n
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * fov_x * fov_y, 64),
            nn.ReLU()
        )
        self.non_spatial_size = 1 + 3
        self.fc_non_spatial = nn.Sequential(
            nn.Linear(self.non_spatial_size, 16),
            nn.ReLU()
        )
        self.fc_shared = nn.Sequential(
            nn.Linear(64 + 16, 64),
            nn.ReLU()
        )
        self.heads = nn.ModuleList([nn.Linear(64, self.n_actions) for _ in range(self.num_agents)])

    def forward(self, obs):
        batch_size = obs.shape[0]
        spatial_obs = torch.zeros(batch_size, self.num_agents, 1, self.fov_x, self.fov_y, device=obs.device)
        for agent_i in range(self.num_agents):
            grid_status = obs[:, agent_i, :self.fov_x * self.fov_y].reshape(-1, self.fov_x, self.fov_y)
            spatial_obs[:, agent_i, 0] = grid_status
        spatial_out = torch.zeros(batch_size, self.num_agents, 64, device=obs.device)
        for agent_i in range(self.num_agents):
            spatial_out[:, agent_i] = self.cnn(spatial_obs[:, agent_i])
        non_spatial = obs[:, :, -self.non_spatial_size:]
        non_spatial_out = self.fc_non_spatial(non_spatial.view(-1, self.non_spatial_size)).view(batch_size, self.num_agents, 16)
        combined = torch.cat([spatial_out, non_spatial_out], dim=2)
        shared_out = self.fc_shared(combined)
        q_values = [self.heads[agent_i](shared_out[:, agent_i]).unsqueeze(1) for agent_i in range(self.num_agents)]
        return torch.cat(q_values, dim=1)

    def sample_action(self, obs, epsilon, roles, grid_status, agent_pos, x_size, y_size, stuck_counts, visited_counts):
        out = self.forward(obs)
        actions = torch.zeros((out.shape[1],), dtype=torch.long)
        exploration_bonus = 50.0
        scout_epsilon_boost = 0.4
        for agent_i in range(out.shape[1]):
            agent_epsilon = epsilon
            if roles[agent_i] == 'scout' and stuck_counts[agent_i] > 2:
                agent_epsilon = min(1.0, epsilon + scout_epsilon_boost)
            if random.random() < agent_epsilon:
                if roles[agent_i] == 'relay':
                    actions[agent_i] = 8
                else:
                    valid_actions = []
                    poi_actions = []
                    low_visit_actions = []
                    x, y = agent_pos[agent_i]
                    for a in range(8):
                        new_x, new_y = x, y
                        if a == 0: new_x -= 1
                        elif a == 1: new_x += 1
                        elif a == 2: new_y -= 1
                        elif a == 3: new_y += 1
                        elif a == 4: new_x, new_y = x-1, y-1
                        elif a == 5: new_x, new_y = x-1, y+1
                        elif a == 6: new_x, new_y = x+1, y-1
                        elif a == 7: new_x, new_y = x+1, y+1
                        if 0 <= new_x < x_size and 0 <= new_y < y_size and grid_status[new_x, new_y] in [0, 3]:
                            valid_actions.append(a)
                            if grid_status[new_x, new_y] == 0:
                                poi_actions.append(a)
                            if visited_counts[new_x, new_y] < 2:
                                low_visit_actions.append(a)
                    actions[agent_i] = random.choice(low_visit_actions or poi_actions or valid_actions) if valid_actions else 8
            else:
                q_values = out[0, agent_i].clone()
                x, y = agent_pos[agent_i]
                for a in range(8):
                    new_x, new_y = x, y
                    if a == 0: new_x -= 1
                    elif a == 1: new_x += 1
                    elif a == 2: new_y -= 1
                    elif a == 3: new_y += 1
                    elif a == 4: new_x, new_y = x-1, y-1
                    elif a == 5: new_x, new_y = x-1, y+1
                    elif a == 6: new_x, new_y = x+1, y-1
                    elif a == 7: new_x, new_y = x+1, y+1
                    if 0 <= new_x < x_size and 0 <= new_y < y_size:
                        if grid_status[new_x, new_y] == 0:
                            q_values[a] += exploration_bonus
                        elif grid_status[new_x, new_y] == 3:
                            q_values[a] += exploration_bonus * 2
                        if visited_counts[new_x, new_y] < 2:
                            q_values[a] += exploration_bonus * 0.5
                actions[agent_i] = q_values.argmax().item()
        return actions

class QCentral(nn.Module):
    def __init__(self, global_state_size, n_agents, n_actions):
        super().__init__()
        self.n_agents = n_agents
        self.n_actions = n_actions
        self.net = nn.Sequential(
            nn.Linear(global_state_size, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, n_agents * n_actions)
        )

    def forward(self, global_state):
        return self.net(global_state).view(-1, self.n_agents, self.n_actions)

def train(q, q_target, q_central, q_central_target, memory, optimizer, optimizer_central, gamma, batch_size, update_iter=10):
    for _ in range(update_iter):
        s, a, r, s_prime, done_mask, global_s, global_s_prime = memory.sample(batch_size)
        q_out = q(s)
        a = a.view(-1, q_out.shape[1], 1)
        q_a = q_out.gather(2, a).squeeze(-1)
        with torch.no_grad():
            next_actions = q(s_prime).argmax(dim=2, keepdim=True)
            max_q_prime = q_central_target(global_s_prime).gather(2, next_actions).squeeze(-1)
            target = r + gamma * max_q_prime * done_mask
        loss_actor = F.smooth_l1_loss(q_a, target.detach())
        optimizer.zero_grad()
        loss_actor.backward()
        optimizer.step()
        q_central_out = q_central(global_s)
        q_central_a = q_central_out.gather(2, a).squeeze(-1)
        loss_central = F.smooth_l1_loss(q_central_a, target.detach())
        optimizer_central.zero_grad()
        loss_central.backward()
        optimizer_central.step()

class GridMultiAgent(gym.Env):
    metadata = {'render.modes': ['console']}
    XM, XP, YM, YP, XMYM, XMYP, XPYM, XPYP, STAY = range(9)
    OOE, OBS, POI, MAP, AGT, SURVIVOR = -2, -1, 0, 1, 2, 3

    def __init__(self, x_size=10, y_size=10, fov_x=3, fov_y=3, n_agents=3, n_survivors=5):
        super().__init__()
        self.x_size = x_size
        self.y_size = y_size
        self.n_agents = n_agents
        self.idx_agents = list(range(n_agents))
        self.n_survivors = n_survivors
        self.fov_x = fov_x
        self.fov_y = fov_y
        self.max_relays = n_agents // 2
        self.agent_roles = ['scout'] * n_agents
        self.stuck_counts = [0] * n_agents
        self.visited_counts = np.zeros((x_size, y_size))
        self.survivor_pos = []
        self.agent_paths = [[] for _ in range(n_agents)]
        self.obstacle_move_prob = 0.05
        self.survivor_move_prob = 0.05
        self.grid_status_history = []
        n_actions = 9
        self.action_space = MultiAgentActionSpace([spaces.Discrete(n_actions) for _ in range(n_agents)])
        obs_size = (fov_x * fov_y) + 1 + 3
        self.obs_low = np.concatenate([np.ones(fov_x * fov_y) * self.OOE, [0, 0, 0, 0]], dtype=np.float32)
        self.obs_high = np.concatenate([np.ones(fov_x * fov_y) * self.SURVIVOR, [1, 1, 1, 1]], dtype=np.float32)
        self.observation_space = MultiAgentObservationSpace([
            spaces.Box(self.obs_low, self.obs_high, dtype=np.float32) for _ in range(n_agents)
        ])
        self.init_grid()
        self.init_agent()
        self.init_survivors()

    def init_grid(self):
        self.grid_status = np.zeros((self.x_size, self.y_size))
        n_obstacle = random.randrange(0, int(self.x_size * self.y_size * 0.2))
        for _ in range(n_obstacle):
            x, y = random.randrange(1, self.x_size - 1), random.randrange(1, self.y_size - 1)
            self.grid_status[x, y] = self.OBS
        self.grid_counts = np.tile(self.grid_status, (self.n_agents, 1, 1)).reshape(self.n_agents, self.x_size, self.y_size)
        self.n_poi = self.x_size * self.y_size - np.count_nonzero(self.grid_status)
        self.grid_agents_status = copy.deepcopy(self.grid_status)
        self.grid_status_history = [self.grid_status.copy()]
        self.visited_counts = np.zeros((self.x_size, self.y_size))

    def init_agent(self):
        self.agent_pos = []
        self.agent_paths = [[] for _ in range(self.n_agents)]
        for i in range(self.n_agents):
            while True:
                x, y = random.randrange(0, self.x_size), random.randrange(0, self.y_size)
                if self.grid_agents_status[x, y] == self.POI:
                    self.agent_pos.append([x, y])
                    self.grid_agents_status[x, y] = self.AGT
                    self.agent_paths[i].append([x, y])
                    self.visited_counts[x, y] += 1
                    break
        self.stuck_counts = [0] * n_agents

    def init_survivors(self):
        self.survivor_pos = []
        for _ in range(self.n_survivors):
            while True:
                x, y = random.randrange(0, self.x_size), random.randrange(0, self.y_size)
                if self.grid_status[x, y] == self.POI and [x, y] not in self.agent_pos:
                    self.survivor_pos.append([x, y])
                    self.grid_status[x, y] = self.SURVIVOR
                    break

    def move_obstacles(self):
        new_grid = self.grid_status.copy()
        obstacle_pos = [(i, j) for i in range(self.x_size) for j in range(self.y_size) if self.grid_status[i, j] == self.OBS]
        for x, y in obstacle_pos:
            if random.random() < self.obstacle_move_prob:
                directions = [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1)]
                random.shuffle(directions)
                for dx, dy in directions:
                    new_x, new_y = x + dx, y + dy
                    if (0 <= new_x < self.x_size and 0 <= new_y < self.y_size and
                            new_grid[new_x, new_y] == self.POI and [new_x, new_y] not in self.agent_pos):
                        new_grid[x, y] = self.POI
                        new_grid[new_x, new_y] = self.OBS
                        break
        self.grid_status = new_grid

    def move_survivors(self):
        new_survivor_pos = []
        for x, y in self.survivor_pos:
            if random.random() < self.survivor_move_prob:
                directions = [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1)]
                random.shuffle(directions)
                moved = False
                for dx, dy in directions:
                    new_x, new_y = x + dx, y + dy
                    if (0 <= new_x < self.x_size and 0 <= new_y < self.y_size and
                            self.grid_status[new_x, new_y] == self.POI and
                            [new_x, new_y] not in self.agent_pos):
                        self.grid_status[x, y] = self.POI
                        self.grid_status[new_x, new_y] = self.SURVIVOR
                        new_survivor_pos.append([new_x, new_y])
                        moved = True
                        break
                if not moved:
                    new_survivor_pos.append([x, y])
            else:
                new_survivor_pos.append([x, y])
        self.survivor_pos = new_survivor_pos

    def grid_overlay(self):
        self.grid_agents_status = copy.deepcopy(self.grid_status)
        for i in range(self.n_agents):
            x, y = self.agent_pos[i]
            self.grid_agents_status[x, y] = self.AGT

    def get_rf_signal(self, agent_pos, survivor_pos):
        distance = np.sqrt((agent_pos[0] - survivor_pos[0])**2 + (agent_pos[1] - survivor_pos[1])**2)
        return min(1.0, 1.0 / (distance + 0.5))  # Adjusted for better detection

    def get_global_state(self):
        agent_obs = self.get_agent_obs()
        global_state = np.concatenate([
            np.array(agent_obs).flatten(),
            self.grid_status.flatten()
        ])
        return global_state.astype(np.float32)

    def get_agent_obs(self):
        self.agent_obs = []
        for agent in range(self.n_agents):
            single_obs = np.ones((self.fov_x, self.fov_y)) * self.OOE
            x, y = self.agent_pos[agent]
            for i in range(self.fov_x):
                for j in range(self.fov_y):
                    obs_x = x + (i - self.fov_x // 2)
                    obs_y = y + (j - self.fov_y // 2)
                    if 0 <= obs_x < self.x_size and 0 <= obs_y < self.y_size:
                        single_obs[i][j] = self.grid_agents_status[obs_x][obs_y]
            single_obs_flat = single_obs.flatten()
            rf_signals = [self.get_rf_signal(self.agent_pos[agent], s) for s in self.survivor_pos]
            rf_signal = max(rf_signals, default=0.0)
            role_encoding = {'scout': [1, 0, 0], 'verification': [0, 1, 0], 'relay': [0, 0, 1]}
            role_vec = role_encoding[self.agent_roles[agent]]
            obs = np.concatenate([single_obs_flat, [rf_signal], role_vec])
            self.agent_obs.append(obs.astype(np.float32))
        return self.agent_obs

    def update_roles(self):
        relay_count = sum(1 for r in self.agent_roles if r == 'relay')
        detected_survivors = sum(1 for s in self.survivor_pos if
                                 max(self.get_rf_signal([ax, ay], s) for ax, ay in self.agent_pos) > 0.6)
        for i in range(self.n_agents):
            rf_signals = [self.get_rf_signal(self.agent_pos[i], s) for s in self.survivor_pos]
            max_rf = max(rf_signals, default=0.0)
            if max_rf > 0.6 and self.agent_roles[i] != 'verification':
                self.agent_roles[i] = 'verification'
            elif detected_survivors < self.n_survivors and self.agent_roles[i] != 'relay' and relay_count < self.max_relays:
                self.agent_roles[i] = 'relay'
                relay_count += 1
            elif detected_survivors < self.n_survivors and self.agent_roles[i] != 'scout':
                self.agent_roles[i] = 'scout'

    def get_coverage(self):
        mapped_poi = np.count_nonzero(self.grid_status == self.MAP)
        return mapped_poi / self.n_poi if self.n_poi > 0 else 0.0

    def get_survivor_detection_rate(self):
        detected = 0
        for s in self.survivor_pos:
            for ax, ay in self.agent_pos:
                if self.get_rf_signal([ax, ay], s) > 0.6:
                    detected += 1
                    break
        return detected / self.n_survivors if self.n_survivors > 0 else 0.0

    def step(self, action, i):
        org_x, org_y = self.agent_pos[i][0], self.agent_pos[i][1]
        reward = 0
        action = int(action)
        new_x, new_y = org_x, org_y
        if action == self.XM:
            new_x -= 1
        elif action == self.XP:
            new_x += 1
        elif action == self.YM:
            new_y -= 1
        elif action == self.YP:
            new_y += 1
        elif action == self.XMYM:
            new_x -= 1
            new_y -= 1
        elif action == self.XMYP:
            new_x -= 1
            new_y += 1
        elif action == self.XPYM:
            new_x += 1
            new_y -= 1
        elif action == self.XPYP:
            new_x += 1
            new_y += 1
        elif action == self.STAY:
            rf_signals = [self.get_rf_signal(self.agent_pos[i], s) for s in self.survivor_pos]
            max_rf = max(rf_signals, default=0.0)
            reward = 5 if self.agent_roles[i] == 'relay' and max_rf > 0.5 else -2
        else:
            raise ValueError(f"Invalid action={action}")
        if (0 <= new_x < self.x_size and 0 <= new_y < self.y_size and
            self.grid_status[new_x, new_y] not in [self.OBS]):
            collision = False
            for j in range(self.n_agents):
                if i != j and self.agent_pos[j][0] == new_x and self.agent_pos[j][1] == new_y:
                    collision = True
                    break
            if not collision:
                self.agent_pos[i] = [new_x, new_y]
                self.visited_counts[new_x, new_y] += 1
                prev_status = self.grid_status[new_x, new_y]
                if prev_status == self.POI:
                    self.grid_status[new_x, new_y] = self.MAP
                    self.grid_counts[i][new_x, new_y] += 1
                    reward = 15
                elif prev_status == self.SURVIVOR:
                    self.grid_counts[i][new_x, new_y] += 1
                    reward = 100
                    detected = self.get_survivor_detection_rate() * self.n_survivors
                    reward += 50 * detected
                elif prev_status == self.MAP:
                    self.grid_counts[i][new_x, new_y] += 1
                    reward = -1
            else:
                self.grid_counts[i][org_x, org_y] += 1
                reward = -5
        else:
            self.grid_counts[i][org_x, org_y] += 1
            reward = -2
        self.agent_paths[i].append(self.agent_pos[i][:])
        if org_x == self.agent_pos[i][0] and org_y == self.agent_pos[i][1]:
            self.stuck_counts[i] += 1
        else:
            self.stuck_counts[i] = 0
        self.move_obstacles()
        self.move_survivors()
        self.update_roles()
        self.grid_overlay()
        self.grid_status_history.append(self.grid_status.copy())
        mapped_poi = np.count_nonzero(self.grid_status == self.MAP)
        done = mapped_poi >= self.n_poi * coverage_threshold or self.get_survivor_detection_rate() >= 1.0
        info = {'survivor_detection_rate': self.get_survivor_detection_rate(),
                'coverage': self.get_coverage()}
        return self.get_agent_obs(), reward, done, info

    def reset(self):
        self.init_grid()
        self.init_agent()
        self.init_survivors()
        self.agent_roles = ['scout'] * self.n_agents
        self.grid_status_history = [self.grid_status.copy()]
        while any(self.grid_status[pos[0], pos[1]] in [self.OBS, self.SURVIVOR] for pos in self.agent_pos):
            self.init_grid()
            self.init_agent()
            self.init_survivors()
        self.grid_overlay()
        return self.get_agent_obs()

    def close(self):
        pass

def animate_grid(env, agent_paths, episode, save_path='grid_animation_rf_only.mp4'):
    fig, ax = plt.subplots(figsize=(8, 8))
    cmap = plt.colormaps['tab10']
    def update(frame):
        ax.clear()
        grid = env.grid_status_history[frame]
        agent_pos = [path[frame] if frame < len(path) else path[-1] for path in agent_paths]
        display_grid = np.zeros((env.x_size, env.y_size))
        display_grid[grid == env.OBS] = -1
        display_grid[grid == env.MAP] = 1
        display_grid[grid == env.SURVIVOR] = 2
        for i, (x, y) in enumerate(agent_pos):
            display_grid[x, y] = 3 + i
        cax = ax.imshow(display_grid, cmap=cmap, vmin=-1, vmax=3 + env.n_agents)
        ax.set_xticks(np.arange(env.x_size))
        ax.set_yticks(np.arange(env.y_size))
        ax.set_title(f'Episode {episode + 1}, Step {frame + 1}')
        ax.grid(True)
        from matplotlib.lines import Line2D
        legend_elements = [
            Line2D([0], [0], color=cmap(-1), lw=4, label='Obstacle'),
            Line2D([0], [0], color=cmap(1), lw=4, label='Mapped'),
            Line2D([0], [0], color=cmap(2), lw=4, label='Survivor'),
        ]
        for i in range(env.n_agents):
            legend_elements.append(Line2D([0], [0], color=cmap(3 + i), lw=4, label=f'Agent {i + 1}'))
        ax.legend(handles=legend_elements, loc='upper right')
        return cax,
    ani = FuncAnimation(fig, update, frames=len(env.grid_status_history), interval=200, blit=False)
    ani.save(save_path, writer='ffmpeg')
    plt.close()

size = 10
fov = 3
n_agents = 3
n_survivors = 5
train_episodes = 2000
max_steps = size * size * 3
batch_size = 32
gamma = 0.99
buffer_limit = 50000
log_interval = 200
max_epsilon = 0.9
min_epsilon = 0.1
warm_up_steps = 1000
update_iter = 10
coverage_threshold = 0.95
lr = 0.001

random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

env = GridMultiAgent(x_size=size, y_size=size, fov_x=fov, fov_y=fov, n_agents=n_agents, n_survivors=n_survivors)
memory = ReplayBuffer(buffer_limit)
q = QNet(env.observation_space, env.action_space, fov_x=fov, fov_y=fov)
q_target = QNet(env.observation_space, env.action_space, fov_x=fov, fov_y=fov)
q_target.load_state_dict(q.state_dict())
agent_obs_size = (fov * fov) + 1 + 3
global_state_size = (agent_obs_size * n_agents) + (size * size)
q_central = QCentral(global_state_size, n_agents, env.action_space[0].n)
q_central_target = QCentral(global_state_size, n_agents, env.action_space[0].n)
q_central_target.load_state_dict(q_central.state_dict())
optimizer = optim.Adam(q.parameters(), lr=lr)
optimizer_central = optim.Adam(q_central.parameters(), lr=lr)
time_steps, epsilons, coverage, survivor_detection, total_reward, all_agent_paths, all_agent_roles = [], [], [], [], [], [], []

for episode in tqdm(range(train_episodes), desc="Training Episodes"):
    score = np.zeros(n_agents)
    state = env.reset()
    done = False
    episode_step = 0
    epsilon = max(min_epsilon, max_epsilon - (max_epsilon - min_epsilon) * (episode / (0.8 * train_episodes)))
    episode_paths = [[] for _ in range(n_agents)]
    episode_roles = [[] for _ in range(n_agents)]
    while not done and episode_step < max_steps:
        state_np = np.array(state, dtype=np.float32)
        global_state = env.get_global_state()
        action = q.sample_action(
            torch.tensor(state_np).unsqueeze(0),
            epsilon, env.agent_roles, env.grid_status, env.agent_pos, env.x_size, env.y_size, env.stuck_counts, env.visited_counts
        ).data.cpu().numpy()
        rewards = np.zeros(n_agents)
        next_state = state
        for agent_i in env.idx_agents:
            if done:
                break
            step_state, reward, step_done, info = env.step(action[agent_i], agent_i)
            rewards[agent_i] = reward
            done = done or step_done
            next_state = step_state
            episode_paths[agent_i].append(env.agent_pos[agent_i][:])
            episode_roles[agent_i].append(env.agent_roles[agent_i])
        next_global_state = env.get_global_state()
        memory.put((state, action, rewards, next_state, [done] * n_agents, global_state, next_global_state))
        score += rewards
        state = next_state
        episode_step += 1
    if memory.size() > warm_up_steps:
        train(q, q_target, q_central, q_central_target, memory, optimizer, optimizer_central, gamma, batch_size, update_iter)
    if episode % log_interval == 0:
        q_target.load_state_dict(q.state_dict())
        q_central_target.load_state_dict(q_central.state_dict())
        animate_grid(env, episode_paths, episode, save_path=f'grid_animation_rf_only_episode_{episode}.mp4')
    all_agent_paths.append([path[:] for path in episode_paths])
    all_agent_roles.append([roles[:] for roles in episode_roles])
    time_steps.append(episode_step)
    epsilons.append(epsilon)
    coverage.append(env.get_coverage())
    survivor_detection.append(info['survivor_detection_rate'])
    total_reward.append(score.sum())
    if episode % log_interval == 0:
        print(f'//Episode {episode+1}// Epsilon: {epsilon:.3f}, Steps: {episode_step}, '
              f'Coverage (%): {coverage[-1]:.3f}, Survivor Detection (%): {survivor_detection[-1]:.3f}, '
              f'Total Reward: {total_reward[-1]:.2f}')

plt.figure(figsize=(12, 8))
plt.subplot(3, 1, 1)
plt.plot(coverage)
plt.title('Coverage Over Episodes (RF Only)')
plt.xlabel('Episode')
plt.ylabel('Coverage (%)')
plt.grid(True)
plt.subplot(3, 1, 2)
plt.plot(survivor_detection)
plt.title('Survivor Detection Rate Over Episodes (RF Only)')
plt.xlabel('Episode')
plt.ylabel('Detection Rate (%)')
plt.grid(True)
plt.subplot(3, 1, 3)
plt.plot(total_reward)
plt.title('Total Reward Over Episodes (RF Only)')
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.grid(True)
plt.tight_layout()
plt.savefig('rf281_performance.png')
plt.close()

Training Episodes:   0%|          | 1/2000 [02:51<95:19:47, 171.68s/it]

//Episode 1// Epsilon: 0.900, Steps: 300, Coverage (%): 0.785, Survivor Detection (%): 0.000, Total Reward: 6893.00


Training Episodes:  10%|█         | 201/2000 [08:27<24:39:11, 49.33s/it]

//Episode 201// Epsilon: 0.800, Steps: 300, Coverage (%): 0.742, Survivor Detection (%): 0.200, Total Reward: 42566.00


Training Episodes:  20%|██        | 401/2000 [14:03<21:52:07, 49.24s/it]

//Episode 401// Epsilon: 0.700, Steps: 300, Coverage (%): 0.750, Survivor Detection (%): 0.600, Total Reward: 189465.00


Training Episodes:  30%|███       | 601/2000 [19:39<19:07:50, 49.23s/it]

//Episode 601// Epsilon: 0.600, Steps: 300, Coverage (%): 0.611, Survivor Detection (%): 0.600, Total Reward: 205804.00


Training Episodes:  38%|███▊      | 755/2000 [21:56<21:31,  1.04s/it]