# In [None]:  (Jupyter cell prompt from the notebook export)
import numpy as np
import random
import math
import matplotlib.pyplot as plt
import json
import os
from datetime import datetime
from sklearn.cluster import KMeans
from IPython.display import Image, display

# ==========================
# HÀM TÍNH TOÁN TỐC ĐỘ & THỜI GIAN (speed & travel-time functions)
# ==========================

def compute_vs(p1, p2, v_f, v_AUV):
    """Return the resultant AUV speed ``v_s`` along the leg from ``p1`` to ``p2``.

    ``p1`` and ``p2`` are (x, y, z) triples. ``beta`` is the angle between the
    leg and the z-axis, and ``v_f`` (current speed) enters through
    ``acos(v_f * cos(beta) / v_AUV)``. A zero-length leg simply returns
    ``v_AUV``.

    NOTE(review): for a (near-)horizontal leg cos(beta) is forced to 1e-6, so
    v_s becomes roughly v_AUV / 1e-6, i.e. horizontal legs are almost free.
    With v_f=1.2, v_AUV=3.0 a purely vertical leg (up or down) yields
    v_s ~= 1.2. Confirm both against the source model.
    """
    (ax, ay, az), (bx, by, bz) = p1, p2
    dx, dy, dz = bx - ax, by - ay, bz - az
    length = math.sqrt(dx**2 + dy**2 + dz**2)
    if length == 0:
        return v_AUV

    # Clamp into [-1, 1] so acos never sees a rounding-error value.
    cos_b = np.clip(dz / length, -1, 1)
    beta = math.acos(cos_b)
    # Avoid dividing by ~0 below (beta is taken before this adjustment).
    if abs(cos_b) < 1e-6:
        cos_b = 1e-6

    ratio = np.clip((v_f * cos_b) / v_AUV, -1, 1)
    return abs(math.cos(beta + math.acos(ratio)) * v_AUV / cos_b)


def travel_time(path, coords, v_f, v_AUV):
    """Return the total time to traverse ``path`` as a closed tour.

    Each leg ``coords[a] -> coords[b]`` costs ``distance / v_s``. ``v_s``
    comes from ``compute_vs`` and is floored at 1e-9 to avoid division by
    zero. The last stop is joined back to the first, so the tour always
    returns to its start.

    Args:
        path: sequence of indices into ``coords`` (visit order).
        coords: indexable collection of (x, y, z) positions.
        v_f: current speed passed to ``compute_vs``.
        v_AUV: AUV speed passed to ``compute_vs``.

    Returns:
        Total travel time. An empty ``path`` has no legs and returns 0.0.
    """
    stops = list(path)
    if not stops:
        return 0.0

    total_time = 0.0
    # Pair every stop with its successor, wrapping the last one back to the
    # first. Legs are summed in the same order as before: forward legs, then
    # the return leg.
    for a, b in zip(stops, stops[1:] + stops[:1]):
        p1, p2 = coords[a], coords[b]
        d = np.linalg.norm(np.array(p2) - np.array(p1))
        v_s = compute_vs(tuple(p1), tuple(p2), v_f, v_AUV)
        total_time += d / max(v_s, 1e-9)
    return total_time


# ==========================
# LỚP PSO CHO TSP (PSO class for the TSP)
# ==========================

class ClusterTSP_PSO:
    """Swap-sequence PSO that orders cluster centers into a closed AUV tour.

    Position 0 of every particle is a fixed depot at (0.0, 0.0, 0.0).
    Positions 1..n-1 hold a permutation of the cluster indices. Cluster
    index k >= 1 refers to the k-th key of ``clusters`` in ascending
    ``int(key)`` order. A tour's cost is ``travel_time`` over
    ``self.cluster_centers``, which includes the return leg to the depot.
    """

    def __init__(self, clusters, pso_params=None):
        """Build the depot + cluster-center table and merge the PSO parameters.

        Args:
            clusters: mapping of int-convertible key -> dict with a 'center'
                (x, y, z) and an optional 'cluster_head'.
            pso_params: optional dict whose entries override the defaults
                below (keys not in the defaults are kept as well).
        """
        self.clusters = clusters
        sorted_keys = sorted(clusters.keys(), key=lambda x: int(x))
        # Index 0 is the depot: it has no cluster head and sits at the origin.
        self.index_to_ch = [None]
        self.cluster_centers = [(0.0, 0.0, 0.0)]
        
        for k in sorted_keys:
            c = clusters[k]['center']
            self.cluster_centers.append(tuple(c))
            self.index_to_ch.append(clusters[k].get('cluster_head', None))

        # Number of tour stops, depot included.
        self.n = len(self.cluster_centers)

        defaults = {
            'n_particles': 40,
            'max_iter': 200,
            'w_start': 0.9,
            'w_end': 0.4,
            'c1': 1.5,
            'c2': 1.5,
            'v_f': 1.2,
            'v_AUV': 3.0,
            'verbose': False
        }
        if pso_params:
            defaults.update(pso_params)
        self.params = defaults

    def create_particle(self):
        """Return a random tour: depot 0 followed by a shuffled permutation of 1..n-1."""
        seq = list(range(1, self.n))
        random.shuffle(seq)
        return [0] + seq

    def get_swap_sequence(self, A, B):
        """Return the (i, j) swaps that, applied in order to a copy of A, turn it into B.

        Only indices 1..len(A)-1 are fixed up, so index 0 (the depot) is never
        swapped. Assumes A and B hold the same elements; otherwise
        ``list.index`` raises ValueError.
        """
        seq = []
        temp = A.copy()
        for i in range(1, len(A)):
            if temp[i] != B[i]:
                # Bring the element B wants at i into place and record the swap.
                j = temp.index(B[i])
                seq.append((i, j))
                temp[i], temp[j] = temp[j], temp[i]
        return seq

    def apply_velocity(self, position, velocity):
        """Return a copy of ``position`` with the swaps in ``velocity`` applied in order.

        Swaps touching index 0 are skipped so the depot always stays first.
        """
        new_pos = position.copy()
        for (i, j) in velocity:
            if i == 0 or j == 0:
                continue
            new_pos[i], new_pos[j] = new_pos[j], new_pos[i]
        return new_pos

    def fitness(self, particle):
        """Return 1 / travel time of ``particle`` (higher is better).

        Not called by ``evolve``, which minimizes ``travel_time`` directly.
        """
        t = travel_time(particle, self.cluster_centers, 
                       self.params['v_f'], self.params['v_AUV'])
        return 1.0 / (t + 1e-9)

    def evolve(self):
        """Run the PSO and return ``(gbest, mapped_path, gbest_cost)``.

        ``gbest`` is the best index tour found. ``mapped_path`` is the same
        tour with index 0 shown as 'O' and every other index replaced by its
        cluster head. ``gbest_cost`` is that tour's travel time.

        NOTE(review): with the default c1 = c2 = 1.5, ``random.random() < c``
        is always true and ``int(c * len(seq)) >= len(seq)``. Every particle
        therefore receives the full swap sequence toward pbest and gbest each
        iteration, shuffled by ``random.sample``. Both sequences are computed
        from ``xi``, but they are applied one after another. Confirm this is
        the intended update rule.
        """
        n_particles = self.params['n_particles']
        max_iter = self.params['max_iter']
        w_start = self.params['w_start']
        w_end = self.params['w_end']
        c1 = self.params['c1']
        c2 = self.params['c2']

        # Initialize the swarm with random tours and empty velocities.
        swarm = [self.create_particle() for _ in range(n_particles)]
        velocities = [[] for _ in range(n_particles)]

        # Initial costs (travel times, lower is better).
        costs = [travel_time(p, self.cluster_centers, 
                           self.params['v_f'], self.params['v_AUV']) 
                for p in swarm]
        
        # Personal best
        pbest = [p.copy() for p in swarm]
        pbest_cost = list(costs)

        # Global best (np.argmin raises ValueError if n_particles is 0)
        gbest_idx = np.argmin(pbest_cost)
        gbest = pbest[gbest_idx].copy()
        gbest_cost = pbest_cost[gbest_idx]

        # Main loop
        for t in range(max_iter):
            # Inertia weight decays linearly from w_start toward w_end.
            w = w_start - (w_start - w_end) * (t / max_iter)

            for i in range(n_particles):
                xi = swarm[i]
                vi = velocities[i]

                # Build the new velocity (a list of swaps).
                v_new = []
                
                # Inertia: keep the first w-fraction of the previous swaps.
                n_keep = int(w * len(vi))
                v_new.extend(vi[:n_keep])

                # Cognitive part (toward pbest)
                if random.random() < c1:
                    seq_pb = self.get_swap_sequence(xi, pbest[i])
                    if seq_pb:
                        n_select = max(1, int(c1 * len(seq_pb)))
                        v_new.extend(random.sample(seq_pb, k=min(len(seq_pb), n_select)))

                # Social part (toward gbest)
                if random.random() < c2:
                    seq_gb = self.get_swap_sequence(xi, gbest)
                    if seq_gb:
                        n_select = max(1, int(c2 * len(seq_gb)))
                        v_new.extend(random.sample(seq_gb, k=min(len(seq_gb), n_select)))

                # Update velocity and position
                velocities[i] = v_new
                new_x = self.apply_velocity(xi, v_new)
                swarm[i] = new_x

                # Cost of the new position
                new_cost = travel_time(new_x, self.cluster_centers, 
                                     self.params['v_f'], self.params['v_AUV'])

                # Update pbest
                if new_cost < pbest_cost[i]:
                    pbest[i] = new_x.copy()
                    pbest_cost[i] = new_cost

                    # Update gbest. Nesting is sufficient because
                    # gbest_cost <= pbest_cost[i] always holds.
                    if new_cost < gbest_cost:
                        gbest = new_x.copy()
                        gbest_cost = new_cost

            if self.params['verbose'] and t % 20 == 0:
                print(f"  Iteration {t}: Best time = {gbest_cost:.4f}s")

        # Map indices to cluster-head IDs ('O' marks the depot).
        mapped_path = ['O' if idx == 0 else self.index_to_ch[idx] for idx in gbest]
        
        return gbest, mapped_path, gbest_cost


# ==========================
# HÀM NĂNG LƯỢNG (energy model)
# ==========================

def compute_energy(best_time, *, G=100, L=1024, n=4,
                   P_t=1.6e-3, P_r=0.8e-3, P_idle=0.1e-3,
                   DR=4000, DR_i=1e6):
    """Return the per-cycle energy spent by member, target (cluster-head) and solo nodes.

    Each role spends transmit/receive energy plus idle energy for the rest of
    the cycle (``best_time``):

    * Member: transmits ``G * L`` at rate ``DR``; idle for the remainder.
    * Target: receives ``G * L * n`` at ``DR`` and transmits it at ``DR_i``.
    * Solo:   transmits ``G * L * n`` at ``DR_i``.

    The idle durations are clamped at 0. A cycle shorter than the active
    time would otherwise give negative idle energy, which lowers the totals
    and can credit energy back to a node.

    Args:
        best_time: duration of one AUV cycle (the tour's travel time).
        G, L, n, P_t, P_r, P_idle, DR, DR_i: model constants. The defaults
            are the values previously hard-coded here. Presumably G = packets,
            L = bits/packet, n = members per cluster, P_* = powers and DR/DR_i
            = data rates; confirm the units against the source model.

    Returns:
        Dict with keys "Member", "Target" and "Solo", each mapping to its
        energy components and "E_total".
    """
    E_tx_MN = G * P_t * L / DR
    E_idle_MN = max(best_time - G * L / DR, 0.0) * P_idle
    E_total_MN = E_tx_MN + E_idle_MN
    E_rx_TN = G * P_r * L * n / DR
    E_tx_TN = G * P_t * L * n / DR_i
    E_idle_TN = max(best_time - (G * L * n / DR) - (G * L * n / DR_i), 0.0) * P_idle
    E_total_TN = E_rx_TN + E_tx_TN + E_idle_TN
    E_idle_solo = max(best_time - (G * L * n / DR_i), 0.0) * P_idle
    E_total_SOLO = E_tx_TN + E_idle_solo
    return {
        "Member": {"E_tx": E_tx_MN, "E_idle": E_idle_MN, "E_total": E_total_MN},
        "Target": {"E_rx": E_rx_TN, "E_tx": E_tx_TN, "E_idle": E_idle_TN, "E_total": E_total_TN},
        "Solo": {"E_tx": E_tx_TN, "E_idle": E_idle_solo, "E_total": E_total_SOLO}
    }


# ==========================
# HÀM PHÂN CỤM (clustering functions)
# ==========================

def calculate_objective_function(nodes, labels, centers):
    """Score a 2-way split: total point-to-own-center distance / center separation.

    Lower is better (tight clusters that are far apart). Clusters with no
    points contribute nothing. Returns ``inf`` when the two centers coincide.
    """
    spread = 0
    for label in (0, 1):
        members = nodes[labels == label]
        if len(members) == 0:
            continue
        spread += np.sum(np.linalg.norm(members - centers[label], axis=1))

    separation = np.linalg.norm(centers[0] - centers[1])
    return float('inf') if separation == 0 else spread / separation


def check_subgroup_threshold(nodes, r_sen):
    """Return True when every pair of points lies within ``r_sen`` of each other.

    The maximum pairwise Euclidean distance (the set's diameter) is computed
    with one vectorized NumPy broadcast instead of a Python double loop. That
    uses O(n^2) memory, which is fine for the cluster sizes used here.

    Args:
        nodes: points as an (n, d) array or nested sequence. A 1-D input is
            treated as n points on a line.
        r_sen: distance threshold.

    Returns:
        True for zero or one point; otherwise whether diameter <= r_sen.
    """
    pts = np.asarray(nodes, dtype=float)
    if len(pts) <= 1:
        return True
    if pts.ndim == 1:
        pts = pts[:, None]
    # (n, n, d) tensor of pairwise coordinate differences.
    diffs = pts[:, None, :] - pts[None, :, :]
    max_distance = np.linalg.norm(diffs, axis=-1).max()
    return bool(max_distance <= r_sen)


def kmeans_with_best_T(nodes, N=30):
    """Run 2-means ``N`` times and keep the split with the lowest objective T.

    T is ``calculate_objective_function``. Each run uses a single random
    initialization. Returns ``(labels, centers, T)`` for the best run, or
    ``(None, None, inf)`` if no run scored below ``inf``.
    """
    best = (None, None, float('inf'))
    for _attempt in range(N):
        model = KMeans(n_clusters=2, n_init=1)
        labels = model.fit_predict(nodes)
        centers = model.cluster_centers_
        score = calculate_objective_function(nodes, labels, centers)
        if score < best[2]:
            best = (labels.copy(), centers.copy(), score)
    return best


def cluster_split(nodes, node_ids, node_data=None, r_sen=60, R=20, N=30, max_depth=10, depth=0):
    """Recursively bisect a point set with 2-means until every part is small and tight.

    A part becomes a leaf when either:

    * it has at most ``R`` points and its largest pairwise distance is at
      most ``r_sen`` (``check_subgroup_threshold``), or
    * ``depth`` has reached ``max_depth``.

    Otherwise the best of ``N`` 2-means runs (``kmeans_with_best_T``) splits
    it and each half is processed recursively. A part is also kept whole as a
    leaf in two cases:

    * no usable split exists (no finite objective);
    * the split leaves one side empty.

    Either case would previously crash or emit an empty leaf.

    Args:
        nodes: array of point coordinates, one row per node.
        node_ids: ids aligned with the rows of ``nodes``.
        node_data: optional mapping node_id -> per-node info. Each leaf
            receives the subset for its own ids.
        r_sen, R, N, max_depth: thresholds described above.
        depth: current recursion depth (internal).

    Returns:
        List of leaf dicts with keys 'node_ids', 'nodes', 'center' (mean of
        the leaf's points) and 'node_data'.
    """
    def make_leaf():
        # The whole current set becomes one output cluster.
        return [{
            "node_ids": node_ids,
            "nodes": nodes,
            "center": np.mean(nodes, axis=0),
            "node_data": node_data if node_data else {}
        }]

    size_ok = len(nodes) <= R
    distance_ok = check_subgroup_threshold(nodes, r_sen)
    if (size_ok and distance_ok) or depth >= max_depth:
        return make_leaf()

    labels, _centers, _best_T = kmeans_with_best_T(nodes, N)
    if labels is None:
        # Every run scored inf (its two centers coincided): nothing to split on.
        return make_leaf()
    masks = [labels == i for i in range(2)]
    if not all(np.any(mask) for mask in masks):
        # One side is empty: recursing would repeat this same set and
        # eventually produce an empty leaf (mean of no points, no ids).
        return make_leaf()

    clusters = []
    for mask in masks:
        sub_nodes = nodes[mask]
        sub_ids = [nid for nid, keep in zip(node_ids, mask) if keep]
        sub_node_data = (
            {nid: node_data[nid] for nid in sub_ids if nid in node_data}
            if node_data else {}
        )
        clusters += cluster_split(sub_nodes, sub_ids, sub_node_data, r_sen, R, N, max_depth, depth + 1)
    return clusters


def choose_cluster_head(cluster_info, node_data=None):
    """Pick the cluster-head id for one cluster.

    With a non-empty ``node_data``, the head is the member with the highest
    'residual_energy' (the first one wins ties). Only energies above -1 are
    considered. If no member qualifies, the first id is used. Without
    ``node_data``, the head is the member closest to the cluster center.
    """
    points, centroid, node_ids = (cluster_info[key] for key in ("nodes", "center", "node_ids"))

    if not node_data:
        gaps = np.linalg.norm(points - centroid, axis=1)
        return node_ids[np.argmin(gaps)]

    if not node_ids:
        return node_ids[0]  # preserves the IndexError for an empty cluster

    eligible = [
        (nid, node_data[nid]["residual_energy"])
        for nid in node_ids
        if nid in node_data
        and "residual_energy" in node_data[nid]
        and node_data[nid]["residual_energy"] > -1
    ]
    if not eligible:
        return node_ids[0]
    # max() keeps the first maximal element, matching a strict '>' scan.
    return max(eligible, key=lambda pair: pair[1])[0]


# ==========================
# HÀM QUẢN LÝ NĂNG LƯỢNG (energy management)
# ==========================

def update_energy(all_nodes, clusters, energy_report):
    """Deduct one cycle's energy from every alive node, according to its role.

    Roles:

    * a node alone in its cluster pays ``energy_report['Solo']['E_total']``;
    * the cluster head pays ``['Target']['E_total']``;
    * every other member pays ``['Member']['E_total']``.

    Ids missing from ``all_nodes`` are skipped. Residual energy is clamped at
    0.0. ``all_nodes`` is mutated in place and nothing is returned.
    """
    for cinfo in clusters.values():
        members = cinfo.get('nodes', [])
        head = cinfo.get('cluster_head')
        alone = len(members) == 1

        for nid in members:
            if nid not in all_nodes:
                continue
            if alone:
                role = 'Solo'
            elif nid == head:
                role = 'Target'
            else:
                role = 'Member'
            node = all_nodes[nid]
            node['residual_energy'] -= energy_report[role]['E_total']
            if node['residual_energy'] < 0:
                node['residual_energy'] = 0.0


def remove_dead_nodes(all_nodes, clusters):
    """Drop nodes whose residual energy is <= 0 and prune them from the clusters.

    ``all_nodes`` is mutated in place. Clusters are copied (shallowly), never
    mutated. A cluster left with no alive members is dropped.

    Returns:
        ``(new_clusters, dead)``: the pruned cluster dict and the list of
        removed ids, in ``all_nodes`` iteration order.
    """
    dead = [nid for nid, info in all_nodes.items() if info['residual_energy'] <= 0]
    for nid in dead:
        all_nodes.pop(nid)

    survivors = {}
    for cid, cinfo in clusters.items():
        alive = [nid for nid in cinfo.get('nodes', []) if nid in all_nodes]
        if not alive:
            continue
        survivors[cid] = {**cinfo, 'nodes': alive}
    return survivors, dead


def recluster(all_nodes, node_positions, r_sen=50, R=20):
    """Re-cluster the surviving nodes and pick a new head for every cluster.

    Node ids are sorted, their positions are looked up in ``node_positions``,
    and the points are split with ``cluster_split``. Heads are then chosen
    with ``choose_cluster_head``, using ``all_nodes`` as the energy source.

    Returns:
        ``{index: {'nodes': ids, 'center': [x, y, z], 'cluster_head': id}}``,
        or ``{}`` when no node is alive.
    """
    ids = sorted(all_nodes)
    if not ids:
        return {}
    coords = np.array([node_positions[nid] for nid in ids])
    pieces = cluster_split(coords, ids, all_nodes, r_sen=r_sen, R=R)
    return {
        idx: {
            'nodes': piece['node_ids'],
            'center': piece['center'].tolist(),
            'cluster_head': choose_cluster_head(piece, all_nodes),
        }
        for idx, piece in enumerate(pieces)
    }


# ==========================
# HÀM VẼ BIỂU ĐỒ REAL-TIME (real-time plotting)
# ==========================

def plot_realtime_analysis(outputs, total_nodes, INITIAL_ENERGY, filename, output_dir, cycle):
    """Draw the 2x2 live-monitoring dashboard for the cycles simulated so far.

    Panels:

    * alive nodes, with a 90% threshold line;
    * cumulative dead nodes, marking the first cycle that lost a node;
    * total remaining energy;
    * per-cycle AUV travel time, with its mean.

    Args:
        outputs: list of per-cycle records with keys 'cycle', 'alive_nodes',
            'total_energy_remaining', 'best_time' and 'dead_nodes'.
        total_nodes: node count at the start of the simulation.
        INITIAL_ENERGY: unused; kept for signature compatibility.
        filename: input file name, used in the title and the PNG name.
        output_dir: directory that receives the PNG.
        cycle: current cycle number shown in the title.

    Returns:
        Path of the saved PNG (overwritten every cycle), or None when
        ``outputs`` is empty.
    """
    if not outputs:
        return

    cycles = [rec['cycle'] for rec in outputs]
    alive = [rec['alive_nodes'] for rec in outputs]
    dead_total = [total_nodes - count for count in alive]
    energy_left = [rec['total_energy_remaining'] for rec in outputs]
    durations = [rec['best_time'] for rec in outputs]
    # Cycle of the first recorded death, or None if no node has died yet.
    first_death_cycle = next((rec['cycle'] for rec in outputs if rec['dead_nodes']), None)

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle(f'Real-time Monitoring (PSO) - {filename} (Chu k·ª≥ {cycle})', 
                 fontsize=16, fontweight='bold')
    (ax_alive, ax_dead), (ax_energy, ax_time) = axes

    def decorate(ax, ylabel, title):
        # Shared axis labels, title and grid styling for every panel.
        ax.set_xlabel('Chu k·ª≥', fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.grid(True, alpha=0.3)

    # Panel 1: alive nodes
    ax_alive.plot(cycles, alive, 'b-o', linewidth=2, markersize=6)
    ax_alive.axhline(y=total_nodes * 0.9, color='r', linestyle='--', 
                     label=f'Ng∆∞·ª°ng 90% ({int(total_nodes * 0.9)} nodes)')
    decorate(ax_alive, 'S·ªë node s·ªëng', 'S·ªë node s·ªëng theo th·ªùi gian')
    ax_alive.legend()
    ax_alive.set_ylim(bottom=0, top=total_nodes * 1.1)

    # Panel 2: cumulative dead nodes
    ax_dead.plot(cycles, dead_total, 'r-s', linewidth=2, markersize=6)
    decorate(ax_dead, 'S·ªë node ch·∫øt (t√≠ch l≈©y)', 'S·ªë node ch·∫øt t√≠ch l≈©y theo th·ªùi gian')
    ax_dead.set_ylim(bottom=0)
    if first_death_cycle:
        ax_dead.axvline(x=first_death_cycle, color='orange', linestyle='--', 
                        label=f'Chu k·ª≥ ƒë·∫ßu ti√™n: {first_death_cycle}')
        ax_dead.legend()

    # Panel 3: total network energy
    ax_energy.plot(cycles, energy_left, 'g-^', linewidth=2, markersize=6)
    decorate(ax_energy, 'T·ªïng nƒÉng l∆∞·ª£ng c√≤n l·∫°i (J)', 'NƒÉng l∆∞·ª£ng t·ªïng th·ªÉ m·∫°ng')
    ax_energy.set_ylim(bottom=0)

    # Panel 4: AUV travel time per cycle
    ax_time.plot(cycles, durations, 'm-d', linewidth=2, markersize=6)
    decorate(ax_time, 'Th·ªùi gian di chuy·ªÉn (s)', 'Th·ªùi gian chu k·ª≥ AUV')
    if durations:
        avg_time = np.mean(durations)
        ax_time.axhline(y=avg_time, color='r', linestyle='--', 
                        label=f'TB: {avg_time:.2f}s')
        ax_time.legend()

    plt.tight_layout()

    stem = os.path.splitext(filename)[0]
    latest_filename = os.path.join(output_dir, f"realtime_pso_latest_{stem}.png")
    plt.savefig(latest_filename, dpi=300, bbox_inches='tight')
    plt.close()
    return latest_filename


def save_realtime_results(outputs, meta_info, output_file):
    """Attach ``outputs`` to ``meta_info`` (in place) and write it as UTF-8 JSON.

    The file is overwritten on every call. Non-ASCII text is written as-is
    (``ensure_ascii=False``) with 4-space indentation.
    """
    meta_info['outputs'] = outputs
    with open(output_file, mode='w', encoding='utf-8') as fp:
        json.dump(meta_info, fp, ensure_ascii=False, indent=4)


def plot_final_analysis(outputs, total_nodes, INITIAL_ENERGY, filename, output_dir):
    """Draw the end-of-run 3x2 summary figure and save it as a PNG.

    Panels:

    * alive nodes;
    * cumulative dead nodes;
    * remaining energy;
    * AUV cycle time;
    * number of clusters per cycle;
    * a text box of summary statistics.

    Args:
        outputs: per-cycle records (same keys as for the real-time plot plus
            'num_clusters' and 'cumulative_time').
        total_nodes: node count at the start of the simulation.
        INITIAL_ENERGY: unused; kept for signature compatibility.
        filename: input file name, used in the title and the PNG name.
        output_dir: directory that receives the PNG.

    Returns:
        Path of the saved PNG, or None when ``outputs`` is empty.
    """
    if not outputs:
        return None

    cycles = [rec['cycle'] for rec in outputs]
    alive_nodes = [rec['alive_nodes'] for rec in outputs]
    cumulative_dead = [total_nodes - count for count in alive_nodes]
    total_energy_remaining = [rec['total_energy_remaining'] for rec in outputs]
    travel_times = [rec['best_time'] for rec in outputs]
    num_clusters = [rec['num_clusters'] for rec in outputs]
    # Cycle of the first recorded death, or None if none occurred.
    first_death_cycle = next((rec['cycle'] for rec in outputs if rec['dead_nodes']), None)

    fig = plt.figure(figsize=(20, 12))
    gs = fig.add_gridspec(3, 2, hspace=0.3, wspace=0.3)
    fig.suptitle(f'Ph√¢n t√≠ch t·ªïng h·ª£p (PSO) - {filename} ({len(outputs)} chu k·ª≥)', 
                 fontsize=18, fontweight='bold', y=0.995)

    def decorate(ax, ylabel, title):
        # Shared axis labels, title and grid styling for the data panels.
        ax.set_xlabel('Chu k·ª≥', fontsize=11)
        ax.set_ylabel(ylabel, fontsize=11)
        ax.set_title(title, fontsize=13, fontweight='bold')
        ax.grid(True, alpha=0.3)

    # Panel 1: alive nodes
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.plot(cycles, alive_nodes, 'b-o', linewidth=2, markersize=5)
    ax1.axhline(y=total_nodes * 0.9, color='r', linestyle='--', 
                label=f'Ng∆∞·ª°ng 90% ({int(total_nodes * 0.9)} nodes)')
    ax1.fill_between(cycles, 0, alive_nodes, alpha=0.3, color='blue')
    decorate(ax1, 'S·ªë node s·ªëng', 'S·ªë node s·ªëng theo th·ªùi gian')
    ax1.legend()
    ax1.set_ylim(bottom=0, top=total_nodes * 1.1)

    # Panel 2: cumulative dead nodes
    ax2 = fig.add_subplot(gs[0, 1])
    ax2.plot(cycles, cumulative_dead, 'r-s', linewidth=2, markersize=5)
    ax2.fill_between(cycles, 0, cumulative_dead, alpha=0.3, color='red')
    decorate(ax2, 'S·ªë node ch·∫øt (t√≠ch l≈©y)', 'S·ªë node ch·∫øt t√≠ch l≈©y theo th·ªùi gian')
    ax2.set_ylim(bottom=0)
    if first_death_cycle:
        ax2.axvline(x=first_death_cycle, color='orange', linestyle='--', linewidth=2,
                    label=f'Chu k·ª≥ ƒë·∫ßu ti√™n: {first_death_cycle}')
        ax2.legend()

    # Panel 3: total network energy
    ax3 = fig.add_subplot(gs[1, 0])
    ax3.plot(cycles, total_energy_remaining, 'g-^', linewidth=2, markersize=5)
    ax3.fill_between(cycles, 0, total_energy_remaining, alpha=0.3, color='green')
    decorate(ax3, 'T·ªïng nƒÉng l∆∞·ª£ng c√≤n l·∫°i (J)', 'NƒÉng l∆∞·ª£ng t·ªïng th·ªÉ m·∫°ng')
    ax3.set_ylim(bottom=0)

    # Panel 4: AUV travel time per cycle
    ax4 = fig.add_subplot(gs[1, 1])
    ax4.plot(cycles, travel_times, 'm-d', linewidth=2, markersize=5)
    decorate(ax4, 'Th·ªùi gian di chuy·ªÉn (s)', 'Th·ªùi gian chu k·ª≥ AUV')
    if travel_times:
        avg_time = np.mean(travel_times)
        ax4.axhline(y=avg_time, color='r', linestyle='--', linewidth=2,
                    label=f'Trung b√¨nh: {avg_time:.2f}s')
        ax4.legend()

    # Panel 5: number of clusters per cycle
    ax5 = fig.add_subplot(gs[2, 0])
    ax5.plot(cycles, num_clusters, 'c-o', linewidth=2, markersize=5)
    ax5.fill_between(cycles, 0, num_clusters, alpha=0.3, color='cyan')
    decorate(ax5, 'S·ªë c·ª•m', 'S·ªë c·ª•m theo chu k·ª≥')
    ax5.set_ylim(bottom=0)

    # Panel 6: summary statistics as text
    ax6 = fig.add_subplot(gs[2, 1])
    ax6.axis('off')

    total_cycles = len(outputs)
    total_time = outputs[-1]['cumulative_time']
    avg_cycle_time = np.mean(travel_times)
    total_deaths = total_nodes - alive_nodes[-1]
    survival_rate = (alive_nodes[-1] / total_nodes) * 100

    stats_text = f"""
    TH·ªêNG K√ä T·ªîNG H·ª¢P (PSO)
    {'‚îÄ' * 40}
    
    T·ªïng s·ªë chu k·ª≥: {total_cycles}
    T·ªïng th·ªùi gian: {total_time:.2f}s ({total_time/3600:.2f} gi·ªù)
    Th·ªùi gian TB/chu k·ª≥: {avg_cycle_time:.2f}s
    
    Node ban ƒë·∫ßu: {total_nodes}
    Node c√≤n s·ªëng: {alive_nodes[-1]}
    Node ƒë√£ ch·∫øt: {total_deaths}
    T·ª∑ l·ªá s·ªëng s√≥t: {survival_rate:.2f}%
    
    Chu k·ª≥ ch·∫øt ƒë·∫ßu ti√™n: {first_death_cycle if first_death_cycle else 'N/A'}
    
    S·ªë c·ª•m trung b√¨nh: {np.mean(num_clusters):.1f}
    S·ªë c·ª•m min/max: {np.min(num_clusters)}/{np.max(num_clusters)}
    
    Th·ªùi gian chu k·ª≥ min: {np.min(travel_times):.2f}s
    Th·ªùi gian chu k·ª≥ max: {np.max(travel_times):.2f}s
    ƒê·ªô l·ªách chu·∫©n: {np.std(travel_times):.2f}s
    """

    text_box = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
    ax6.text(0.1, 0.5, stats_text, fontsize=11, family='monospace',
             verticalalignment='center', bbox=text_box)

    stem = os.path.splitext(filename)[0]
    final_chart_filename = os.path.join(output_dir, f"final_analysis_pso_{stem}.png")
    plt.savefig(final_chart_filename, dpi=300, bbox_inches='tight')
    plt.close()
    return final_chart_filename


# ==========================
# MAIN với Real-time Monitoring (PSO) (main entry point with real-time monitoring)
# ==========================

def _survival_rate(alive, total):
    """Return ``alive / total`` as a percentage, or 0.0 when ``total`` is 0."""
    return (alive / total) * 100 if total else 0.0


def _load_node_positions(nodes_pos_file):
    """Load ``{node_id: (x, y, z)}`` from a JSON list of ``{'id','x','y','z'}`` records.

    Returns an empty dict when the file is missing or malformed. The mapping
    is built in a local dict first, so a bad record cannot leave a partially
    filled result behind.
    """
    if not os.path.exists(nodes_pos_file):
        return {}
    try:
        with open(nodes_pos_file, 'r', encoding='utf-8') as f:
            nodes_data = json.load(f)
        positions = {node['id']: (node['x'], node['y'], node['z']) for node in nodes_data}
    except (OSError, ValueError, KeyError, TypeError) as e:
        print(f"‚úó L·ªói khi ƒë·ªçc file v·ªã tr√≠: {e}")
        return {}
    print(f"‚úì ƒê√£ load {len(positions)} node positions t·ª´ {os.path.basename(nodes_pos_file)}")
    return positions


def _fill_missing_positions(clusters_in, node_positions):
    """Give every node without a known position an approximate one (in place).

    Members are scattered around their cluster's 'center' with Gaussian noise
    (scale 5.0). A cluster head without a position is placed on the center.

    Returns:
        True if any position was generated.
    """
    generated = False
    for v in clusters_in.values():
        center = tuple(v.get('center', (0.0, 0.0, 0.0)))
        for nid in v.get('nodes', []):
            if nid not in node_positions:
                offset = np.random.normal(scale=5.0, size=3)
                node_positions[nid] = tuple(np.array(center) + offset)
                generated = True
        ch = v.get('cluster_head')
        if ch is not None and ch not in node_positions:
            node_positions[ch] = center
            generated = True
    return generated


def main(input_dir="/kaggle/input/input-cluster3/output_data_kmeans",
         output_dir="/kaggle/working/output_pso_multicycle",
         nodes_pos_file="/kaggle/input/input-pos3/input_data_evenly_distributed/nodes_150.json",
         pso_params=None):
    """Run the multi-cycle PSO/AUV simulation for every cluster JSON in ``input_dir``.

    For each input file, every cycle:

    1. plans the AUV tour with ``ClusterTSP_PSO``;
    2. charges each node's energy;
    3. removes dead nodes;
    4. writes JSON + PNG snapshots to ``output_dir``;
    5. re-clusters the survivors.

    The simulation of a file stops when fewer than 90% of its nodes are alive
    or no cluster is left. A final JSON and a summary chart are then written.

    Args:
        input_dir: directory of cluster JSON files ({key: {'nodes', 'center',
            'cluster_head'}}).
        output_dir: directory for results (created if missing).
        nodes_pos_file: JSON list of node positions. Nodes it does not cover
            get approximate positions around their cluster center.
        pso_params: PSO parameter dict; the defaults below are used if None.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Sorted so files are processed in a reproducible order.
    files = sorted(f for f in os.listdir(input_dir) if f.endswith('.json'))

    # PSO parameters
    if pso_params is None:
        pso_params = {
            'n_particles': 40,
            'max_iter': 200,
            'w_start': 0.9,
            'w_end': 0.4,
            'c1': 1.5,
            'c2': 1.5,
            'v_f': 1.2,
            'v_AUV': 3.0,
            'verbose': False
        }

    INITIAL_ENERGY = 100.0

    for filename in files:
        input_path = os.path.join(input_dir, filename)
        print(f"\n{'='*70}")
        print(f"ƒêANG X·ª¨ L√ù FILE: {filename}")
        print(f"{'='*70}")

        with open(input_path, 'r', encoding='utf-8') as f:
            clusters_in = json.load(f)

        # Collect all node ids: members plus cluster heads (a head is counted
        # even if it is not listed among its cluster's members).
        all_node_ids = set()
        for v in clusters_in.values():
            all_node_ids.update(v.get('nodes', []))
            ch = v.get('cluster_head')
            if ch is not None:
                all_node_ids.add(ch)

        # Positions are reloaded per file because dead nodes are deleted
        # from this mapping during the simulation below.
        node_positions = _load_node_positions(nodes_pos_file)
        # Every node needs a position, or recluster() raises KeyError.
        if _fill_missing_positions(clusters_in, node_positions):
            print("‚ö† ƒê√£ t·∫°o v·ªã tr√≠ gi·∫£ l·∫≠p cho c√°c nodes")

        # Initialize energy
        all_nodes = {
            nid: {
                'initial_energy': INITIAL_ENERGY,
                'residual_energy': INITIAL_ENERGY
            }
            for nid in all_node_ids
        }

        total_nodes = len(all_nodes)
        print(f"‚úì T·ªïng s·ªë node ban ƒë·∫ßu: {total_nodes}")

        # Initialize clusters
        clusters = {
            int(k): {
                'nodes': v.get('nodes', []),
                'center': v.get('center', []),
                'cluster_head': v.get('cluster_head')
            }
            for k, v in clusters_in.items()
        }

        # Prepare output file
        out_filename = os.path.join(output_dir,
                                    f"multicycle_pso_result_{os.path.splitext(filename)[0]}.json")

        # Meta info
        meta_info = {
            'algorithm': 'PSO',
            'input_file': filename,
            'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
            'initial_total_nodes': total_nodes,
            'pso_params': pso_params,
            'cycles': 0,
            'total_operation_time': 0.0,
            'first_death_cycle': None,
            'first_death_time': None,
            'final_alive_nodes': 0,
            'survival_rate': 0.0
        }

        cycle = 0
        outputs = []
        cumulative_time = 0.0
        first_death_cycle = None
        first_death_time = None

        print(f"\n{'='*70}")
        print("B·∫ÆT ƒê·∫¶U M√î PH·ªéNG CHU K·ª≤ (PSO)")
        print(f"{'='*70}\n")

        # Main simulation loop. `cycle` is incremented before the stop checks,
        # so cycle - 1 cycles have completed when the loop exits.
        while True:
            cycle += 1
            print(f"\n{'‚îÄ'*70}")
            print(f"CHU K·ª≤ {cycle}")
            print(f"{'‚îÄ'*70}")

            alive_ratio = len(all_nodes) / total_nodes if total_nodes > 0 else 0
            print(f"T·ªâ l·ªá node s·ªëng: {alive_ratio*100:.2f}% ({len(all_nodes)}/{total_nodes})")

            if alive_ratio < 0.9:
                print("\n‚ö† D·ª™NG: T·ªâ l·ªá node s·ªëng < 90%")
                break

            if len(clusters) == 0:
                print("\n‚ö† D·ª™NG: Kh√¥ng c√≤n c·ª•m n√†o")
                break

            # Run PSO
            print(f"ƒêang ch·∫°y PSO v·ªõi {len(clusters)} c·ª•m...", end=" ")
            pso = ClusterTSP_PSO(clusters, pso_params)
            best_indices, best_mapped_path, best_time = pso.evolve()
            print("‚úì")

            cumulative_time += best_time
            energy_report = compute_energy(best_time)

            # Update energies
            update_energy(all_nodes, clusters, energy_report)

            # Remove dead nodes (and their positions)
            clusters, dead_nodes = remove_dead_nodes(all_nodes, clusters)
            for d in dead_nodes:
                node_positions.pop(d, None)

            # Calculate total energy
            total_energy_remaining = sum(node['residual_energy'] for node in all_nodes.values())

            # Track first death
            if dead_nodes and first_death_cycle is None:
                first_death_cycle = cycle
                first_death_time = cumulative_time

            # Log output
            output_entry = {
                'cycle': cycle,
                'num_clusters': len(clusters),
                'best_path_indices': best_indices,
                'best_path_node_ids': best_mapped_path,
                'best_time': best_time,
                'cumulative_time': cumulative_time,
                'dead_nodes': dead_nodes,
                'alive_nodes': len(all_nodes),
                'total_energy_remaining': total_energy_remaining
            }
            outputs.append(output_entry)

            # Print cycle summary
            print(f"‚îú‚îÄ S·ªë c·ª•m: {len(clusters)}")
            print(f"‚îú‚îÄ ƒê∆∞·ªùng ƒëi: {best_mapped_path}")
            print(f"‚îú‚îÄ Th·ªùi gian chu k·ª≥: {best_time:.2f}s")
            print(f"‚îú‚îÄ Th·ªùi gian t√≠ch l≈©y: {cumulative_time:.2f}s ({cumulative_time/3600:.2f}h)")
            print(f"‚îú‚îÄ NƒÉng l∆∞·ª£ng c√≤n l·∫°i: {total_energy_remaining:.2f}J")

            if dead_nodes:
                print(f"‚îî‚îÄ ‚ö† Node ch·∫øt: {len(dead_nodes)} node(s) - {dead_nodes}")
            else:
                print(f"‚îî‚îÄ ‚úì Kh√¥ng c√≥ node ch·∫øt")

            # ===== REAL-TIME UPDATES =====

            # 1. Save results to file
            meta_info.update({
                'cycles': cycle,
                'total_operation_time': cumulative_time,
                'first_death_cycle': first_death_cycle,
                'first_death_time': first_death_time,
                'final_alive_nodes': len(all_nodes),
                'survival_rate': _survival_rate(len(all_nodes), total_nodes)
            })
            save_realtime_results(outputs, meta_info, out_filename)

            # 2. Plot real-time chart
            last_filename = plot_realtime_analysis(
                outputs, total_nodes, INITIAL_ENERGY, filename, output_dir, cycle
            )
            display(Image(filename=last_filename))

            # Recluster if nodes still alive
            if len(all_nodes) > 0:
                clusters = recluster(all_nodes, node_positions)
                for k, v in clusters.items():
                    clusters[k]['center'] = [float(x) for x in v['center']]

        # ===== END OF SIMULATION =====
        completed_cycles = cycle - 1
        survival_rate = _survival_rate(len(all_nodes), total_nodes)

        print(f"\n{'='*70}")
        print("PH√ÇN T√çCH K·∫æT QU·∫¢ CU·ªêI C√ôNG (PSO)")
        print(f"{'='*70}")
        print(f"T·ªïng s·ªë chu k·ª≥: {completed_cycles}")
        print(f"T·ªïng th·ªùi gian ho·∫°t ƒë·ªông: {cumulative_time:.2f}s ({cumulative_time/3600:.2f} gi·ªù)")

        if first_death_cycle:
            death_share = (first_death_time / cumulative_time) * 100 if cumulative_time else 0.0
            print(f"\nChu k·ª≥ ƒë·∫ßu ti√™n c√≥ node ch·∫øt: {first_death_cycle}")
            print(f"Th·ªùi gian ƒë·∫øn l√∫c ch·∫øt ƒë·∫ßu ti√™n: {first_death_time:.2f}s ({first_death_time/3600:.2f} gi·ªù)")
            print(f"T·ª∑ l·ªá th·ªùi gian: {death_share:.2f}% t·ªïng th·ªùi gian")
        else:
            print("\nKh√¥ng c√≥ node n√†o ch·∫øt trong qu√° tr√¨nh m√¥ ph·ªèng")

        print(f"\nS·ªë node c√≤n s·ªëng cu·ªëi c√πng: {len(all_nodes)}/{total_nodes}")
        print(f"S·ªë node ƒë√£ ch·∫øt: {total_nodes - len(all_nodes)}")
        print(f"T·ª∑ l·ªá s·ªëng s√≥t: {survival_rate:.2f}%")

        # Cycle-time statistics
        if outputs:
            cycle_times = [o['best_time'] for o in outputs]
            print(f"\nTh·ªùi gian chu k·ª≥ trung b√¨nh: {np.mean(cycle_times):.2f}s")
            print(f"Th·ªùi gian chu k·ª≥ ng·∫Øn nh·∫•t: {np.min(cycle_times):.2f}s")
            print(f"Th·ªùi gian chu k·ª≥ d√†i nh·∫•t: {np.max(cycle_times):.2f}s")
            print(f"ƒê·ªô l·ªách chu·∫©n: {np.std(cycle_times):.2f}s")

        # Deaths-per-cycle statistics (only cycles that lost nodes)
        deaths_per_cycle = [len(o['dead_nodes']) for o in outputs if o['dead_nodes']]
        if deaths_per_cycle:
            print(f"\nS·ªë node ch·∫øt trung b√¨nh/chu k·ª≥ (khi c√≥ ch·∫øt): {np.mean(deaths_per_cycle):.2f}")
            print(f"S·ªë node ch·∫øt nhi·ªÅu nh·∫•t trong 1 chu k·ª≥: {np.max(deaths_per_cycle)}")

        print(f"{'='*70}")

        # Save final results with the same writer the per-cycle snapshots
        # use (it also stores `outputs` in meta_info).
        meta_info.update({
            'cycles': completed_cycles,
            'total_operation_time': cumulative_time,
            'first_death_cycle': first_death_cycle,
            'first_death_time': first_death_time,
            'final_alive_nodes': len(all_nodes),
            'survival_rate': survival_rate
        })
        save_realtime_results(outputs, meta_info, out_filename)

        print(f"\n‚úì K·∫øt qu·∫£ cu·ªëi c√πng ƒë√£ l∆∞u: {out_filename}")

        # Create final comprehensive chart
        if outputs:
            print("\nüìä ƒêang t·∫°o bi·ªÉu ƒë·ªì t·ªïng h·ª£p cu·ªëi c√πng...")
            final_chart = plot_final_analysis(outputs, total_nodes, INITIAL_ENERGY,
                                              filename, output_dir)
            print(f"‚úì ƒê√£ l∆∞u bi·ªÉu ƒë·ªì t·ªïng h·ª£p: {os.path.basename(final_chart)}")

        print(f"\n{'='*70}")
        print(f"HO√ÄN TH√ÄNH X·ª¨ L√ù FILE: {filename}")
        print(f"{'='*70}\n")


if __name__ == "__main__":
    # Script / notebook-cell entry point.
    main()

# Notebook cell output captured in the export: SyntaxError: incomplete input (3988386910.py, line 554)