# In [None]:
import json
import math
import os
import warnings
from collections import deque
from itertools import combinations

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import networkx as nx

from config import abs_file_path
from resource_allocation import allocate_resource_uniformly
# Number of end nodes in a layout -> number of repeater rows/columns in the grid.
endnum_grid_mapping = {
    25: 13,
    50: 15,
    100: 15,
    200: 20,
    400: 21,
}

# Drawing colour per node/edge 'type' value, looked up by graph_plot.
color_map = {
    'endnode': '#10739E',
    'repeater': '#B40504',
}
def euclidean_distance(pos1, pos2):
    """Return the straight-line (L2) distance between two 2-D points.

    Only the first two coordinates of each point are used; the result is a
    NumPy float.
    """
    dx = pos1[0] - pos2[0]
    dy = pos1[1] - pos2[1]
    return np.sqrt(dx ** 2 + dy ** 2)
def get_nearest_neighbors(graph, node, n_neighbors=2):
    """Return up to *n_neighbors* other nodes of *graph* closest to *node*.

    Closeness is the Euclidean distance between the nodes' ``'pos'``
    attributes. Every other node is a candidate regardless of its type or of
    existing edges; ties keep the graph's node iteration order.
    """
    origin = graph.nodes[node]['pos']
    others = [other for other in graph.nodes if other != node]
    ranked = sorted(
        others,
        key=lambda other: euclidean_distance(origin, graph.nodes[other]['pos']),
    )
    return ranked[:n_neighbors]
def read_endnodes_init_grid_graph_without_edges(endnodes_graph_file, grid_size=15, step_size=(1000.0/15)):
    """Build an edgeless graph of end nodes (read from JSON) plus a repeater grid.

    Repeaters are placed at the centres of a ``grid_size`` x ``grid_size``
    lattice: grid cell ``(i, j)`` becomes a repeater at
    ``((i + 0.5) * step_size, (j + 0.5) * step_size)``.

    Args:
        endnodes_graph_file: path to a JSON file with a top-level ``"nodes"``
            list. Entries whose ``"type"`` is ``"endnode"`` must carry
            ``"id"``, ``"pos"`` and ``"num_qubits"``; other entries are ignored.
        grid_size: number of repeater rows (and columns).
        step_size: spacing between neighbouring repeaters. Note the default
            (1000/15) does not follow ``grid_size``.

    Returns:
        ``(G, endnodes)``: an ``nx.Graph`` with no edges, whose nodes carry
        ``pos`` and ``type`` (end nodes also ``num_qubits``), and the list of
        end-node ids.
    """
    G = nx.Graph()

    with open(endnodes_graph_file, 'r') as f:
        nodes = json.load(f)['nodes']

    for node in nodes:
        if node['type'] == 'endnode':
            G.add_node(node['id'], pos=node['pos'], num_qubits=node['num_qubits'], type='endnode')
    endnodes = [node for node in G.nodes if G.nodes[node]['type'] == 'endnode']

    # Repeater ids continue after the end nodes. Ids already in use are skipped
    # so a non-contiguous end-node id is never overwritten by a repeater
    # (for end-node ids 0..n-1 this is exactly n, n+1, ...).
    grid = nx.grid_2d_graph(grid_size, grid_size)
    next_id = len(endnodes)
    for i, j in grid.nodes:
        while next_id in G:
            next_id += 1
        G.add_node(next_id, pos=((i + 0.5) * step_size, (j + 0.5) * step_size), type='repeater')
        next_id += 1

    return G, endnodes
def graph_plot(G):
    """Draw *G* with matplotlib, then call ``plt.show()``.

    Node positions come from the ``'pos'`` node attribute. Node and edge
    colours are looked up in ``color_map`` by each node's / edge's ``'type'``
    attribute, so every node and edge must have one. Only end nodes are
    labelled, with their node id.
    """
    pos = nx.get_node_attributes(G, 'pos')
    node_colors = [color_map[G.nodes[n]['type']] for n in G.nodes()]
    edge_colors = [color_map[G.edges[e]['type']] for e in G.edges()]
    # Labelling every repeater would clutter the plot; label end nodes only.
    label_endnodes = {n: n for n in G.nodes() if G.nodes[n]['type'] == 'endnode'}
    nx.draw(G, pos, with_labels=True, labels=label_endnodes, node_color=node_colors,
            edge_color=edge_colors, width=0.5, node_size=10)
    plt.show()
def read_endnodes_init_grid_graph_with_grid_edges(endnodes_graph_file, grid_size=15, step_size=(1000.0/15)):
    """Build a graph of end nodes (read from JSON) plus a grid-connected repeater lattice.

    Repeaters are placed at the centres of a ``grid_size`` x ``grid_size``
    lattice: grid cell ``(i, j)`` becomes a repeater at
    ``((i + 0.5) * step_size, (j + 0.5) * step_size)``. Repeaters that are
    4-neighbours in the lattice are joined by an edge with ``type='repeater'``
    and ``dis`` equal to their Euclidean distance. End nodes get no edges.

    Args:
        endnodes_graph_file: path to a JSON file with a top-level ``"nodes"``
            list. Entries whose ``"type"`` is ``"endnode"`` must carry
            ``"id"``, ``"pos"`` and ``"num_qubits"``; other entries are ignored.
        grid_size: number of repeater rows (and columns).
        step_size: spacing between neighbouring repeaters. Note the default
            (1000/15) does not follow ``grid_size``.

    Returns:
        ``(G, endnodes)``: the ``nx.Graph`` (nodes carry ``pos``, ``type``,
        ``xcoord``, ``ycoord``; end nodes also ``num_qubits``) and the list of
        end-node ids.
    """
    G = nx.Graph()

    with open(endnodes_graph_file, 'r') as f:
        nodes = json.load(f)['nodes']

    for node in nodes:
        if node['type'] == 'endnode':
            pos = node['pos']
            G.add_node(node['id'], pos=pos, num_qubits=node['num_qubits'], type='endnode',
                       xcoord=pos[0], ycoord=pos[1])
    endnodes = [node for node in G.nodes if G.nodes[node]['type'] == 'endnode']

    # Repeater ids continue after the end nodes. Ids already in use are skipped
    # so a non-contiguous end-node id is never overwritten by a repeater
    # (for end-node ids 0..n-1 this is exactly n, n+1, ...).
    grid = nx.grid_2d_graph(grid_size, grid_size)
    grid_to_g = {}
    next_id = len(endnodes)
    for cell in grid.nodes:
        while next_id in G:
            next_id += 1
        x = (cell[0] + 0.5) * step_size
        y = (cell[1] + 0.5) * step_size
        grid_to_g[cell] = next_id
        G.add_node(next_id, pos=(x, y), type='repeater', xcoord=x, ycoord=y)
        next_id += 1

    # Mirror the lattice's 4-neighbour edges between the corresponding repeaters.
    for cell_a, cell_b in grid.edges:
        u, v = grid_to_g[cell_a], grid_to_g[cell_b]
        pu, pv = G.nodes[u]['pos'], G.nodes[v]['pos']
        dis = ((pu[0] - pv[0]) ** 2 + (pu[1] - pv[1]) ** 2) ** 0.5
        G.add_edge(u, v, type='repeater', dis=dis)

    return G, endnodes

class QuantumRepeaterDeployment:
    """Choose repeater "centers" and intermediate repeaters on a typed graph.

    Nodes are read through a ``'type'`` attribute (``'endnode'`` or
    ``'repeater'``). Edge lengths are read through the ``'dis'`` attribute;
    networkx treats an edge without ``'dis'`` as length 1.

    Args:
        nx_graph: networkx graph to deploy on.
        Lmax: distance threshold used both for coverage and for spacing
            intermediate repeaters.
        leaf_nodes: nodes to treat as leaves. When omitted or empty, every
            degree-1 node of ``nx_graph`` is used.
    """

    def __init__(self, nx_graph, Lmax, leaf_nodes=None):
        self.nx_graph = nx_graph
        self.Lmax = Lmax
        # Truthiness check: an empty list also falls back to the degree-1 default.
        self.leaf_nodes = leaf_nodes if leaf_nodes else [n for n in nx_graph.nodes if nx_graph.degree[n] == 1]

    def _nodes_of_type(self, node_type):
        """Return the nodes whose ``'type'`` attribute equals *node_type*, in graph order."""
        return [n for n in self.nx_graph.nodes if self.nx_graph.nodes[n]['type'] == node_type]

    def choose_centers(self):
        """Greedily pick repeater centers until every end node is covered.

        Phase 1 makes a center of every repeater that has two leaf neighbours
        whose shortest-path distance exceeds ``Lmax``. Phase 2 repeatedly takes
        an uncovered end node and adds the adjacent repeater whose coverage
        (``get_coverage_new``) contains the most still-uncovered end nodes.

        If no adjacent repeater of an end node covers any uncovered end node,
        that end node is skipped with a ``RuntimeWarning`` instead of adding
        ``None`` as a center.

        Returns:
            set of chosen center nodes (all of type ``'repeater'``).
        """
        graph = self.nx_graph
        leaves = set(self.leaf_nodes)
        leaves.update(n for n in graph.nodes if graph.degree[n] == 1)

        centers = set()
        covered = set()

        # Phase 1: repeaters bridging leaves that are too far apart.
        for v in self._nodes_of_type('repeater'):
            v_leaves = [n for n in graph.neighbors(v) if n in leaves]
            if any(self.get_distance(a, b) > self.Lmax for a in v_leaves for b in v_leaves):
                centers.add(v)
                covered.update(self.get_coverage_new(v))

        # Phase 2: greedy cover of the end nodes that are still uncovered.
        endnodes = set(self._nodes_of_type('endnode'))
        skipped = set()
        remaining = endnodes - covered
        while remaining:
            endnode = next(iter(remaining))
            best, best_gain = None, 0
            for candidate in graph.neighbors(endnode):
                if graph.nodes[candidate]['type'] != 'repeater':
                    continue
                gain = len(self.get_coverage_new(candidate) & remaining)
                if gain > best_gain:
                    best, best_gain = candidate, gain

            if best is None:
                # Nothing adjacent can help; skip instead of looping/crashing.
                warnings.warn(
                    f"end node {endnode!r} has no adjacent repeater covering an "
                    f"uncovered end node within Lmax={self.Lmax}; skipping it",
                    RuntimeWarning,
                )
                skipped.add(endnode)
            else:
                centers.add(best)
                covered.update(self.get_coverage_new(best))

            remaining = endnodes - covered
            remaining -= skipped

        return centers

    def get_distance(self, node1, node2):
        """Return the shortest-path length between two nodes, weighted by ``'dis'``.

        networkx raises ``NetworkXNoPath`` if the nodes are not connected.
        """
        return nx.shortest_path_length(self.nx_graph, source=node1, target=node2, weight='dis')

    def _within_lmax(self, node):
        """Map every node within shortest-path distance ``Lmax`` of *node* to that distance."""
        if self.Lmax < 0:
            # Even the source itself (distance 0) is out of range.
            return {}
        return nx.single_source_dijkstra_path_length(self.nx_graph, node, cutoff=self.Lmax, weight='dis')

    def get_coverage_old(self, node):
        """Return all nodes (any type) within shortest-path distance ``Lmax`` of *node*.

        Unreachable nodes are simply not included.
        """
        return set(self._within_lmax(node))

    def get_coverage_new(self, node):
        """Return the end nodes within shortest-path distance ``Lmax`` of *node*.

        A Dijkstra search with a cutoff is used, so a node first reached by a
        long path is still included when a shorter path keeps it within range.
        *node* itself is included if it is an end node.
        """
        return {n for n in self._within_lmax(node) if self.nx_graph.nodes[n]['type'] == 'endnode'}

    def find_intermediate_nodes(self, centers):
        """Return repeaters to add along the MST paths between every pair of centers.

        For each ordered pair of distinct centers, walk the path between them
        in the graph's minimum spanning tree (weighted by ``'dis'``). Keep an
        anchor node. When the graph distance from the anchor to the current
        node exceeds ``Lmax``, the previous path node becomes the new anchor.
        That previous node is recorded as an intermediate if it is a repeater
        and not already a center. The MST must connect all centers, otherwise
        networkx raises ``NetworkXNoPath``.
        """
        mst = nx.minimum_spanning_tree(self.nx_graph, weight='dis')
        intermediates = set()
        for c1 in centers:
            for c2 in centers:
                if c1 == c2:
                    continue
                path = nx.shortest_path(mst, source=c1, target=c2)
                anchor = path[0]
                for prev, current in zip(path, path[1:]):
                    if self.get_distance(anchor, current) > self.Lmax:
                        if prev not in centers and self.nx_graph.nodes[prev]['type'] == 'repeater':
                            intermediates.add(prev)
                        anchor = prev
        return intermediates

    def minimum_spanning_tree(self, centers):
        """Return the edges of an MST (by ``'dis'``) of the subgraph induced by *centers*."""
        subgraph = self.nx_graph.subgraph(centers)
        mst = nx.minimum_spanning_tree(subgraph, weight='dis')
        return list(mst.edges())

    def minimum_spanning_tree_with_intermediates(self, centers):
        """Return MST edges (by ``'dis'``) over centers linked to every other node.

        Each center gets a direct edge to every non-center node, with ``dis``
        set to their shortest-path distance in the full graph.
        """
        complete_subgraph = self.nx_graph.subgraph(centers).copy()
        for center in centers:
            for node in self.nx_graph.nodes:
                if node not in centers:
                    # Stored as 'dis' so the MST below actually uses it.
                    complete_subgraph.add_edge(center, node, dis=self.get_distance(center, node))

        mst = nx.minimum_spanning_tree(complete_subgraph, weight='dis')
        return list(mst.edges())

    def minimum_spanning_tree_with_intermediates_new(self, centers):
        """Return all edges on the MST paths between every ordered pair of centers.

        Edges are listed once per pair whose path uses them (duplicates kept).
        """
        mst = nx.minimum_spanning_tree(self.nx_graph, weight='dis')
        edges_to_include = []
        for c1 in centers:
            for c2 in centers:
                if c1 != c2:
                    path = nx.shortest_path(mst, source=c1, target=c2)
                    edges_to_include.extend(zip(path, path[1:]))
        return edges_to_include

    def get_nodes_on_edge(self, edge):
        """Return the endpoints of *edge* as a list."""
        return list(edge)
class QuantumRepeaterDeployment_1:
    """Center-selection heuristic whose coverage is limited to adjacent end nodes.

    Center selection uses ``get_coverage_new_2``: a repeater covers only the
    end nodes directly adjacent to it whose shortest-path distance is within
    ``Lmax``. Nodes are read through a ``'type'`` attribute (``'endnode'`` or
    ``'repeater'``); edge lengths through ``'dis'`` (missing -> length 1).

    Args:
        nx_graph: networkx graph to deploy on.
        Lmax: distance threshold used for coverage and repeater spacing.
        leaf_nodes: nodes to treat as leaves. When omitted or empty, every
            degree-1 node of ``nx_graph`` is used.
    """

    def __init__(self, nx_graph, Lmax, leaf_nodes=None):
        self.nx_graph = nx_graph
        self.Lmax = Lmax
        # Truthiness check: an empty list also falls back to the degree-1 default.
        self.leaf_nodes = leaf_nodes if leaf_nodes else [n for n in nx_graph.nodes if nx_graph.degree[n] == 1]

    def _nodes_of_type(self, node_type):
        """Return the nodes whose ``'type'`` attribute equals *node_type*, in graph order."""
        return [n for n in self.nx_graph.nodes if self.nx_graph.nodes[n]['type'] == node_type]

    def choose_centers(self):
        """Greedily pick repeater centers until every end node is covered.

        Phase 1 makes a center of every repeater that has two leaf neighbours
        whose shortest-path distance exceeds ``Lmax``. Phase 2 repeatedly takes
        an uncovered end node and adds the adjacent repeater whose coverage
        (``get_coverage_new_2``) contains the most still-uncovered end nodes.

        If no adjacent repeater of an end node covers any uncovered end node,
        that end node is skipped with a ``RuntimeWarning`` instead of adding
        ``None`` as a center.

        Returns:
            set of chosen center nodes (all of type ``'repeater'``).
        """
        graph = self.nx_graph
        leaves = set(self.leaf_nodes)
        leaves.update(n for n in graph.nodes if graph.degree[n] == 1)

        centers = set()
        covered = set()

        # Phase 1: repeaters bridging leaves that are too far apart.
        for v in self._nodes_of_type('repeater'):
            v_leaves = [n for n in graph.neighbors(v) if n in leaves]
            if any(self.get_distance(a, b) > self.Lmax for a in v_leaves for b in v_leaves):
                centers.add(v)
                covered.update(self.get_coverage_new_2(v))

        # Phase 2: greedy cover of the end nodes that are still uncovered.
        endnodes = set(self._nodes_of_type('endnode'))
        skipped = set()
        remaining = endnodes - covered
        while remaining:
            endnode = next(iter(remaining))
            best, best_gain = None, 0
            for candidate in graph.neighbors(endnode):
                if graph.nodes[candidate]['type'] != 'repeater':
                    continue
                gain = len(self.get_coverage_new_2(candidate) & remaining)
                if gain > best_gain:
                    best, best_gain = candidate, gain

            if best is None:
                # Nothing adjacent can help; skip instead of looping/crashing.
                warnings.warn(
                    f"end node {endnode!r} has no adjacent repeater covering an "
                    f"uncovered end node within Lmax={self.Lmax}; skipping it",
                    RuntimeWarning,
                )
                skipped.add(endnode)
            else:
                centers.add(best)
                covered.update(self.get_coverage_new_2(best))

            remaining = endnodes - covered
            remaining -= skipped

        return centers

    def get_distance(self, node1, node2):
        """Return the shortest-path length between two nodes, weighted by ``'dis'``.

        networkx raises ``NetworkXNoPath`` if the nodes are not connected.
        """
        return nx.shortest_path_length(self.nx_graph, source=node1, target=node2, weight='dis')

    def _within_lmax(self, node):
        """Map every node within shortest-path distance ``Lmax`` of *node* to that distance."""
        if self.Lmax < 0:
            # Even the source itself (distance 0) is out of range.
            return {}
        return nx.single_source_dijkstra_path_length(self.nx_graph, node, cutoff=self.Lmax, weight='dis')

    def get_coverage_old(self, node):
        """Return all nodes (any type) within shortest-path distance ``Lmax`` of *node*.

        Unreachable nodes are simply not included.
        """
        return set(self._within_lmax(node))

    def get_coverage_new(self, node):
        """Return the end nodes within shortest-path distance ``Lmax`` of *node*.

        A Dijkstra search with a cutoff is used, so a node first reached by a
        long path is still included when a shorter path keeps it within range.
        """
        return {n for n in self._within_lmax(node) if self.nx_graph.nodes[n]['type'] == 'endnode'}

    def get_coverage_new_2(self, node):
        """Return the end nodes adjacent to *node* whose shortest-path distance is within ``Lmax``.

        One cutoff Dijkstra from *node* replaces a separate shortest-path
        query per neighbour; the set returned is the same.
        """
        in_range = self._within_lmax(node)
        return {
            neighbor
            for neighbor in self.nx_graph.neighbors(node)
            if self.nx_graph.nodes[neighbor]['type'] == 'endnode' and neighbor in in_range
        }

    def find_intermediate_nodes(self, centers):
        """Return repeaters to add along the MST paths between every pair of centers.

        For each ordered pair of distinct centers, walk the path between them
        in the graph's minimum spanning tree (weighted by ``'dis'``). Keep an
        anchor node. When the graph distance from the anchor to the current
        node exceeds ``Lmax``, the previous path node becomes the new anchor.
        That previous node is recorded as an intermediate if it is a repeater
        and not already a center. The MST must connect all centers, otherwise
        networkx raises ``NetworkXNoPath``.
        """
        mst = nx.minimum_spanning_tree(self.nx_graph, weight='dis')
        intermediates = set()
        for c1 in centers:
            for c2 in centers:
                if c1 == c2:
                    continue
                path = nx.shortest_path(mst, source=c1, target=c2)
                anchor = path[0]
                for prev, current in zip(path, path[1:]):
                    if self.get_distance(anchor, current) > self.Lmax:
                        if prev not in centers and self.nx_graph.nodes[prev]['type'] == 'repeater':
                            intermediates.add(prev)
                        anchor = prev
        return intermediates

    def minimum_spanning_tree(self, centers):
        """Return the edges of an MST (by ``'dis'``) of the subgraph induced by *centers*."""
        subgraph = self.nx_graph.subgraph(centers)
        mst = nx.minimum_spanning_tree(subgraph, weight='dis')
        return list(mst.edges())

    def minimum_spanning_tree_with_intermediates(self, centers):
        """Return MST edges (by ``'dis'``) over centers linked to every other node.

        Each center gets a direct edge to every non-center node, with ``dis``
        set to their shortest-path distance in the full graph.
        """
        complete_subgraph = self.nx_graph.subgraph(centers).copy()
        for center in centers:
            for node in self.nx_graph.nodes:
                if node not in centers:
                    # Stored as 'dis' so the MST below actually uses it.
                    complete_subgraph.add_edge(center, node, dis=self.get_distance(center, node))

        mst = nx.minimum_spanning_tree(complete_subgraph, weight='dis')
        return list(mst.edges())

    def minimum_spanning_tree_with_intermediates_new(self, centers):
        """Return all edges on the MST paths between every ordered pair of centers.

        Edges are listed once per pair whose path uses them (duplicates kept).
        """
        mst = nx.minimum_spanning_tree(self.nx_graph, weight='dis')
        edges_to_include = []
        for c1 in centers:
            for c2 in centers:
                if c1 != c2:
                    path = nx.shortest_path(mst, source=c1, target=c2)
                    edges_to_include.extend(zip(path, path[1:]))
        return edges_to_include

    def get_nodes_on_edge(self, edge):
        """Return the endpoints of *edge* as a list."""
        return list(edge)





# Batch driver: for every end-node layout file in
# <abs_file_path>/dist/endnodes/map_size/, build a repeater topology with
# QuantumRepeaterDeployment and save it as node-link JSON under
# <abs_file_path>/dist/topos/map_size/mca-<...>.

ENDNODE_LINK_RANGE = 100   # end node -> one of its 5 nearest nodes, only when closer than this
REPEATER_LINK_RANGE = 200  # repeater <-> repeater links are added only below this length
MAX_REPEATER_DEGREE = 10   # ...and only while both endpoints have fewer links than this


def _planar_distance(pos_a, pos_b):
    """Return the straight-line distance between two (x, y) positions."""
    return math.sqrt((pos_a[0] - pos_b[0]) ** 2 + (pos_a[1] - pos_b[1]) ** 2)


def _closest(candidates, origin_pos, positions, k):
    """Return the *k* candidates whose ``positions`` entry is nearest to *origin_pos*.

    Ties keep the order of *candidates* (``sorted`` is stable).
    """
    return sorted(candidates, key=lambda c: _planar_distance(origin_pos, positions[c]))[:k]


folder_path = abs_file_path + "/dist/endnodes/map_size/"
# Keep regular files that have an extension (sub-folders are dropped), in a
# stable order; os.listdir order is arbitrary.
files_list = sorted(
    name for name in os.listdir(folder_path)
    if '.' in name and os.path.isfile(os.path.join(folder_path, name))
)
print(files_list, len(files_list))

for file in files_list:
    file_path = folder_path + file
    print(file_path)
    # The second '-'-separated field of the file name is the map size.
    map_size = int(file.split('-')[1])
    num_endnodes = 100  # hard-coded; selects the grid resolution below
    grid_size = endnum_grid_mapping[num_endnodes]
    step_size = map_size / grid_size

    g, endnodes = read_endnodes_init_grid_graph_with_grid_edges(file_path, grid_size=grid_size, step_size=step_size)

    # Link each end node to those of its 5 nearest nodes (of any type) that
    # are closer than ENDNODE_LINK_RANGE.
    nearest_neighbors = {endnode: get_nearest_neighbors(g, endnode, 5) for endnode in endnodes}
    for endnode in endnodes:
        p = g.nodes[endnode]['pos']
        for neighbor in nearest_neighbors[endnode]:
            q = g.nodes[neighbor]['pos']
            dis = ((p[0] - q[0]) ** 2 + (p[1] - q[1]) ** 2) ** 0.5
            if dis < ENDNODE_LINK_RANGE:
                g.add_edge(endnode, neighbor, type='endnode', dis=dis)
    endnodes = [node for node in g.nodes if g.nodes[node]['type'] == 'endnode']
    repeaters = [node for node in g.nodes if g.nodes[node]['type'] == 'repeater']

    # Densify the repeater layer. Each unordered pair is considered once, in
    # the order an all-ordered-pairs scan would first reach it; revisiting a
    # pair could never add an edge because degrees only grow.
    for node1, node2 in combinations(repeaters, 2):
        dis = _planar_distance(g.nodes[node1]['pos'], g.nodes[node2]['pos'])
        if (dis < REPEATER_LINK_RANGE
                and g.degree[node1] < MAX_REPEATER_DEGREE
                and g.degree[node2] < MAX_REPEATER_DEGREE):
            g.add_edge(node1, node2, dis=dis, type='repeater')

    Lmax = 200
    deployment = QuantumRepeaterDeployment(g, Lmax, endnodes)
    centers = deployment.choose_centers()
    inter_nodes = deployment.find_intermediate_nodes(centers)
    nodes = list(centers) + list(inter_nodes) + endnodes
    # Induced subgraph on the selected nodes, copied so it can be modified.
    subgraph = g.subgraph(nodes).copy()
    node_pos = nx.get_node_attributes(subgraph, 'pos')
    repeater_nodes = [n for n in subgraph.nodes if subgraph.nodes[n]['type'] == 'repeater']

    # Give every isolated or dangling node (degree 0 or 1) an extra link to
    # its nearest repeater.
    leaf_nodes = [n for n in subgraph.nodes if subgraph.degree[n] <= 1]
    for node in leaf_nodes:
        candidates = [r for r in repeater_nodes if r != node]
        if not candidates:
            continue
        nearest_repeater = _closest(candidates, node_pos[node], node_pos, 1)[0]
        subgraph.add_edge(node, nearest_repeater,
                          dis=_planar_distance(node_pos[node], node_pos[nearest_repeater]),
                          type='repeater')

    # Merge components: repeatedly join the smallest component to the closest
    # node outside it, using the pair that actually realises that distance.
    components = list(nx.connected_components(subgraph))
    while len(components) > 1:
        smallest = min(components, key=len)
        best = None  # (distance, node inside `smallest`, node outside it)
        for inside in smallest:
            for outside in subgraph.nodes:
                if outside in smallest:
                    continue
                dis = _planar_distance(node_pos[inside], node_pos[outside])
                if best is None or dis < best[0]:
                    best = (dis, inside, outside)
        print("connexed components")
        min_dis, node, nearest_node = best
        subgraph.add_edge(node, nearest_node, dis=min_dis, type='repeater')
        components = list(nx.connected_components(subgraph))

    # Random qubit budget per repeater; randint excludes the upper bound, so 10..14.
    for node in repeater_nodes:
        subgraph.nodes[node]['num_qubits'] = np.random.randint(10, 15)

    # Connect every end node to its 2 nearest repeaters.
    endnode_nodes = [n for n in subgraph.nodes if subgraph.nodes[n]['type'] == 'endnode']
    for node in endnode_nodes:
        for neighbor in _closest(repeater_nodes, node_pos[node], node_pos, 2):
            subgraph.add_edge(node, neighbor,
                              dis=_planar_distance(node_pos[node], node_pos[neighbor]),
                              type='endnode')
    assert all('num_qubits' in subgraph.nodes[node] for node in repeater_nodes)

    # Normalise edge attributes: repeater-repeater links are 'repeater',
    # anything touching an end node is 'endnode'; 'dis' is recomputed from positions.
    for u, v in subgraph.edges:
        both_repeaters = subgraph.nodes[u]['type'] == 'repeater' and subgraph.nodes[v]['type'] == 'repeater'
        subgraph.edges[u, v]['type'] = 'repeater' if both_repeaters else 'endnode'
        subgraph.edges[u, v]['dis'] = _planar_distance(node_pos[u], node_pos[v])

    # Save the graph as node-link JSON.
    endnode_file_name = file.split('/')[-1]
    useful_name_seg = endnode_file_name.split('-')[1] + '-' + endnode_file_name.split('-')[2]
    save_path = abs_file_path + "/dist/topos/map_size/mca-" + useful_name_seg
    with open(save_path, 'w') as f:
        json.dump(nx.node_link_data(subgraph), f)
    print("MCA graph saved for " + save_path)

    graph_plot(subgraph)




