In [None]:
pip install ete3

In [2]:
from ete3 import Tree
from collections import defaultdict
import networkx as nx
import matplotlib.pyplot as plt

In [15]:
vertices = []
vertices_no = 0
graph = []
edge_frequencies = defaultdict(int)
all_internal_nodes = []
in_degrees = []

def identify_nodes(tree):
    leaf_nodes = []
    for node in tree.traverse("postorder"):
        if node.is_leaf():
            leaf_nodes.append(node.name)
    return leaf_nodes


def add_vertex(v):
    global graph
    global vertices_no
    global vertices
    global in_degrees
    if v not in vertices:
      vertices_no = vertices_no + 1
      vertices.append(v)
      in_degrees.append(0)
      if vertices_no > 1:
          for vertex in graph:
              vertex.append(0)
      graph.append([0] * vertices_no)


def add_edge(v1, v2, e):
    global graph
    global vertices_no
    global vertices
    global in_degrees
    edge_frequencies[(v1, v2)] += 1
    edge_frequencies[(v2, v1)] += 1  # Assuming an undirected graph
    index1 = vertices.index(v1)
    index2 = vertices.index(v2)
    # Check if an edge already exists between v1 and v2
    if graph[index1][index2] != 0:
        graph[index1][index2] = graph[index1][index2] + e
        in_degrees[index2] += 1
        # print(v1," to ", v2, " more than once\n\n")
    else:
        graph[index1][index2] = e
        in_degrees[index2] += 1
        # print(v1," to ", v2, "\n\n")
    # print("graph is now as: ",graph)


def update_internal_node_names(tree, leaf_nodes):
    internal_nodes = []
    def generate_unique_name(name):
        # Generate a unique name by adding a suffix
        suffix = 1
        while name + str(suffix) in leaf_nodes:
            suffix += 1
        return name + str(suffix)

    def update_internal_node(node):
        if not node.is_leaf():
            children_names = sorted([child.name for child in node.children])
            new_name = ''.join(children_names)
            new_name = ''.join(sorted(new_name))
            if new_name in leaf_nodes:
                new_name = generate_unique_name(new_name)
            node.name = new_name
            internal_nodes.append(node.name)
            if node.name not in all_internal_nodes:
              all_internal_nodes.append(node.name)

    for node in tree.traverse("postorder"):
        update_internal_node(node)
    return internal_nodes


def check_direct_connection(tree, v1, v2):
    # Find the nodes by name
    node1 = tree.search_nodes(name=v1)
    node2 = tree.search_nodes(name=v2)
    # print("\n", v1, node1,node1[0], " and ",v2,node2,"\n")
    if not node1 or not node2:
        print("they does not exits")
        return False  # One of the nodes does not exist

    node1 = node1[0]
    node2 = node2[0]
    #print("\ncheck_direct_connection :", node1, "-", node1[0],"\n")
    # Check if node1 is the parent of node2
    if node2.up == node1: # or node1.up == node2
        # print()
        # print(node1, "is parent of ",node2)
        # print()
        return True
    return False


def calculate_distances(tree):
    distances = {}
    all_nodes = list(tree.traverse())  # Collect all nodes
    for i in range(len(all_nodes)):
        for j in range(i + 1, len(all_nodes)):
            node1 = all_nodes[i]
            node2 = all_nodes[j]
            distance = node1.get_distance(node2)  # Calculate distance using the tree library function
            distance = round(distance, 2)
            distances[(node1.name, node2.name)] = distance
            distances[(node2.name, node1.name)] = distance  # Optionally, for undirected graphs
    return distances


def tree_to_graph(tree):
    # Get a list of all unique nodes (leaves and internal nodes)
    # print(tree)
    unique_nodes = set(tree.iter_leaves())
    unique_nodes.update(tree.traverse("preorder"))

    # Create a matrix to represent the distances
    node_names = [node.name for node in unique_nodes]  # Define node_names here
    node_names.sort(key=lambda x: (tree & x).get_distance(tree))
    print("node_names : ",node_names)

    distances = calculate_distances(tree)

    # add vertex
    for node in node_names:
        add_vertex(node)

    # add edge
    for i in range(0, len(node_names)):
      for j in range(i + 1, len(node_names)):
        v1 = node_names[i]
        v2 = node_names[j]
        if check_direct_connection(tree, v1,v2):
          distance = round(distances.get((v1, v2), 0), 2)
          #print("\n\nTree nodes conection : ",v1,"-",v2,"-",distance,"\n")
          add_edge(v1, v2, distance)


input_file = '/content/12trees23.txt'
with open(input_file, 'r') as file:
    for idx, line in enumerate(file, 1):
        tree = Tree(line.strip(), format=1)
        leaf_nodes = identify_nodes(tree)
        # print(leaf_nodes)
        internal_nodes = update_internal_node_names(tree,leaf_nodes)
        # print(internal_nodes)
        tree_to_graph(tree)


print("Combined Graph In-degrees:", in_degrees)

# print('\nnumber of internal nodes:',len(internal_nodes))
print('\n number of all internal nodes:',len(all_internal_nodes))
print('\n all internal nodes:',all_internal_nodes)
print('\n all vertices:',vertices)
print('\n number of vertices:',vertices_no)

for i in range(vertices_no):
  print(vertices[i],end="-")
  print(in_degrees[i])

# Extract unique node names
nodes = sorted(set(node for edge in edge_frequencies for node in edge))
num_nodes = len(nodes)
print('\n nodes:',nodes)
print('\n number of nodes:',len(nodes))

# Populate the frequency matrix with the values from the dictionary
frequency_matrix = [[0] * num_nodes for _ in range(num_nodes)]
for i in range(num_nodes):
    for j in range(num_nodes):
        node_i, node_j = nodes[i], nodes[j]
        if (node_i, node_j) in edge_frequencies:
            frequency_matrix[i][j] = edge_frequencies[(node_i, node_j)]
        elif (node_j, node_i) in edge_frequencies:
            frequency_matrix[i][j] = edge_frequencies[(node_j, node_i)]  # Assuming an undirected graph

for i in range(vertices_no):
  for j in range(vertices_no):
    if graph[i][j] != 0 and frequency_matrix[i][j] != 0:
       graph[i][j] = round(graph[i][j] / frequency_matrix[i][j], 2)
       #print("Hi ",graph[i][j])

print("frequency_matrix : ",frequency_matrix)
for i in graph:
  print(i)

node_names :  ['1', '23', '22', '21', '20', '00111111111111112222222333445566778899', '01111111111112233445566778899', '16', '15', '111111111123456789', '1156', '14', '19', '1189', '18', '01123456789', '01456789', '4', '111123', '11', '12', '13', '231', '2', '3', '5678', '5', '6', '7', '8', '17', '019', '9', '10']
node_names :  ['18', '19', '00111111111111112222222333445566778899', '0011111111111122222223334455667789', '17', '00111111111112222222333445566789', '111111123456', '12', '11', '1112', '111123', '13', '111456', '14', '1156', '16', '15', '00111122222233456789', '4', '8', '5678', '5', '7', '6', '567', '2', '3', '231', '10', '9', '019', '20', '011222223', '21', '23', '22', '1122223', '1']
node_names :  ['00111111111111112222222333445566778899', '111111111123456789', '11111111234789', '00111122222233456789', '11111234', '11', '12', '1112', '111123', '14', '13', '19', '1189', '18', '17', '01112222223345678', '1156', '019', '15', '16', '10', '9', '12345678', '45678', '231', '3', '2

In [16]:
def modified_prims(graph, frequencies, internal_nodes, degrees):
    num_nodes = len(graph)
    print("graph len: ", num_nodes)
    mst = []
    in_mst = [False] * num_nodes

    # Create a mapping from node names to indices
    node_name_to_index = {name: index for index, name in enumerate(vertices)}

    # Debugging: Print the mappings and input lists
    print("Vertices:", vertices)
    print("Internal Nodes:", internal_nodes)
    print("Node Name to Index Mapping:", node_name_to_index)

    # Start from the first node of interest.
    start_node = internal_nodes[3]
    print("start_node", start_node)
    # Handle case where start_node might not be in the mapping
    if start_node not in node_name_to_index:
        print(f"Start node '{start_node}' not found in node_name_to_index mapping.")
        return []  # or handle this case as needed

    # Convert start_node from a name to an index
    start_node_index = node_name_to_index[start_node]
    print("start_node_index", start_node_index)
    # Use this index in your MST algorithm
    in_mst[start_node_index] = True

    while len(mst) < len(internal_nodes) - 1:
        max_degree = float('-inf')
        max_frequency = float('-inf')
        min_branch_length = float('inf')
        chosen_edge = (None, None, float('inf'))  # Initialize with placeholder values

        for node_name in internal_nodes:
            node_index = node_name_to_index[node_name]
            if in_mst[node_index]:
                for neighbor_index, edge_weight in enumerate(graph[node_index]):
                    neighbor_name = vertices[neighbor_index]  # Get the name of the neighbor
                    if neighbor_name in internal_nodes and not in_mst[neighbor_index]:
                        frequency = frequencies[node_index][neighbor_index]
                        degree = degrees[neighbor_index]
                        # branch_length = graph [node_index][neighbor_index]

                        if frequency > max_frequency:
                            max_frequency = frequency
                            chosen_edge = (node_index, neighbor_index, edge_weight)
                        elif frequency == max_frequency and degree > max_degree:
                            max_degree = degree
                            max_frequency = frequency
                            chosen_edge = (node_index, neighbor_index, edge_weight)
                        elif frequency == max_frequency and degree == max_degree and edge_weight < min_branch_length:
                            max_degree = degree
                            max_frequency = frequency
                            min_branch_length = edge_weight
                            chosen_edge = (node_index, neighbor_index, edge_weight)



        # After the loop, update in_mst and mst using indices
        print("\n\nChosen edge: ", chosen_edge)
        # print("Type of chosen_edge[1]:", type(chosen_edge[1]))
        if chosen_edge[0] is not None and chosen_edge[1] is not None:
            # Use index for in_mst
            in_mst[chosen_edge[1]] = True
            # Convert indices back to names for mst
            # mst.append((vertices[chosen_edge[0]], vertices[chosen_edge[1]]))
            mst.append((vertices[chosen_edge[0]], vertices[chosen_edge[1]], chosen_edge[2]))

    return mst



def draw_mst(mst_G):
    # Draw the graph
    pos = nx.spring_layout(mst_G)  # Layout algorithm (you can choose another one)

    # Increase the figure size (adjust the width and height as needed)
    plt.figure(figsize=(15, 15))

    # Customize node and edge visual properties for clarity
    nx.draw(mst_G, pos, with_labels=True, node_size=30, node_color='skyblue', font_size=8, font_color='black', edge_color='gray', width=0.8)

    # Display the graph
    plt.axis('off')
    plt.show()



# Call the modified_prims function with the specified nodes
mst_result = modified_prims(graph, frequency_matrix, all_internal_nodes, in_degrees)
print('\nMST :',mst_result)

mst_G = nx.Graph()
for edge in mst_result:
    node1, node2, distance = edge
    mst_G.add_edge(node1, node2, weight=distance)
    # mst_G.add_edge(edge[0], edge[1])
    # print(node1, node2,'->', distance)
# draw_mst(mst_G)
print("\n\nmst_G: ", mst_G)


graph len:  83
Vertices: ['1', '23', '22', '21', '20', '00111111111111112222222333445566778899', '01111111111112233445566778899', '16', '15', '111111111123456789', '1156', '14', '19', '1189', '18', '01123456789', '01456789', '4', '111123', '11', '12', '13', '231', '2', '3', '5678', '5', '6', '7', '8', '17', '019', '9', '10', '0011111111111122222223334455667789', '00111111111112222222333445566789', '111111123456', '1112', '111456', '00111122222233456789', '567', '011222223', '1122223', '11111111234789', '11111234', '01112222223345678', '12345678', '45678', '0112222', '01111111111111222222233344556677889', '56', '78', '11111111234567', '01222223', '0222', '111567', '0011111111111122222223334456778899', '111789', '123', '0011111111111112222222333445566778899', '0011111111112222222333445566789', '00111112222223345566789', '0011122222233456789', '0123456789', '013456789', '568', '0149', '011111111111112222222333456789', '011111111111122222233456789', '1111112389', '0011111111111111222222233

In [17]:
# Assuming result is already obtained from modified_prims function
G = nx.Graph()
for edge in mst_result:
    #G.add_edge(edge[0], edge[1])
    node1, node2, distance = edge
    G.add_edge(node1, node2, weight=distance)
# Create a mapping from node names to indices
node_name_to_index = {name: index for index, name in enumerate(vertices)}

print("G: ", G)
for i in graph:
  print(i)

# Assuming dist_matrix is a 2D array-like structure with distances between nodes
# and node_name_to_index is a dictionary mapping node names to their indices in dist_matrix
print("leaf_nodes : ",leaf_nodes)
for leaf in leaf_nodes:
    leaf_index = node_name_to_index[leaf]
    nearest_internal_node = None
    min_distance = float('inf')
    max_freq = -1
    for internal_node in all_internal_nodes:
        internal_node_index = node_name_to_index[internal_node]
        distance = graph[leaf_index][internal_node_index]
        freq = frequency_matrix[leaf_index][internal_node_index]

        if freq > max_freq:
            nearest_internal_node = internal_node
        elif freq == max_freq and distance < min_distance:
              min_distance = distance
              nearest_internal_node = internal_node

    if nearest_internal_node is not None:
        G.add_edge(leaf, nearest_internal_node,weight=distance)
        print("New added : ",leaf, nearest_internal_node, min_distance)

# Now G contains the original MST with leaf nodes connected to the nearest internal nodes

# Create a set of leaf nodes for faster lookup
leaf_node_set = set(leaf_nodes)

print("with leaves G: ", G)
for i in graph:
  print(i)

# # Iterate over a copy of the internal nodes list to avoid modification during iteration
# for internal_node in all_internal_nodes[:]:
#     # Check the neighbors of the internal node
#     neighbors = list(G.neighbors(internal_node))

#     # Check if the internal node is connected only to one other internal node and no leaf nodes
#     if len(neighbors) == 1 and neighbors[0] in all_internal_nodes and not any(neighbor in leaf_node_set for neighbor in neighbors):
#         G.remove_node(internal_node)
#         all_internal_nodes.remove(internal_node)
#     # New condition: if the internal node is connected to exactly one leaf node and one internal node
#     elif len(neighbors) == 2 and any(neighbor in leaf_node_set for neighbor in neighbors) and any(neighbor in all_internal_nodes for neighbor in neighbors):
#         # Identify the leaf node and the other internal node
#         leaf_node = next(neighbor for neighbor in neighbors if neighbor in leaf_node_set)
#         other_internal_node = next(neighbor for neighbor in neighbors if neighbor in all_internal_nodes)

#         # Calculate the new branch length as the sum of distances from the distance matrix
#         new_branch_length = graph[node_name_to_index[internal_node]][node_name_to_index[leaf_node]] + graph[node_name_to_index[internal_node]][node_name_to_index[other_internal_node]]

#         # Remove the internal node
#         G.remove_node(internal_node)

#         # Add an edge between the leaf node and the other internal node with the new branch length
#         # Note: You might need to adjust this part if your graph G does not store weights
#         G.add_edge(leaf_node, other_internal_node, weight=new_branch_length)

#         all_internal_nodes.remove(internal_node)  # Optional: update the internal nodes list
print("After removing G: ", G)
for i in graph:
  print(i)
# Now G contains only the internal nodes that are connected to leaf nodes or more than one internal node

G:  Graph with 60 nodes and 59 edges
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [18]:
# Print all edges and their weights
for edge in G.edges(data=True):
    node1, node2, weight = edge
    print(f"Edge: {node1} - {node2}, Weight: {weight['weight']}")

Edge: 01456789 - 111123, Weight: 0
Edge: 01456789 - 45678, Weight: 0.02
Edge: 45678 - 78, Weight: 0
Edge: 45678 - 0011111111112222222333445566789, Weight: 0
Edge: 45678 - 11111111234789, Weight: 0
Edge: 78 - 00111112222223345566789, Weight: 0
Edge: 78 - 0011122222233456789, Weight: 0
Edge: 78 - 0011111111111122222223334455667789, Weight: 0
Edge: 0011111111111122222223334455667789 - 56, Weight: 0
Edge: 0011111111111122222223334455667789 - 111789, Weight: 0
Edge: 0011111111111122222223334455667789 - 1189, Weight: 0
Edge: 0011111111111122222223334455667789 - 011222223, Weight: 0
Edge: 56 - 123, Weight: 0
Edge: 56 - 0011111111111112222222333445566778899, Weight: 0
Edge: 111789 - 567, Weight: 0
Edge: 567 - 1122223, Weight: 0
Edge: 567 - 01123456789, Weight: 0
Edge: 567 - 00111111111112222222333445566789, Weight: 0
Edge: 567 - 111111123456, Weight: 0
Edge: 1122223 - 0011111111111122222223334456778899, Weight: 0
Edge: 1122223 - 11111234, Weight: 0
Edge: 1122223 - 11111111234567, Weight: 0
Edg

In [20]:
def networkx_to_ete3(G, root_node):
    if root_node not in G:
        raise ValueError(f"Root node '{root_node}' not found in the graph.")

    visited = set()  # To keep track of visited nodes

    def add_children(node, ete3_node):
        visited.add(node)  # Mark the current node as visited
        for child in G.neighbors(node):
            if child not in visited:
                # Create the child node
                child_node = ete3_node.add_child(name=str(child))

                # Set the branch length (edge weight)
                branch_length = G[node][child].get('weight', 1.0)  # Default to 1.0 if no weight is found
                child_node.dist = branch_length

                # Recursively add children of this node
                add_children(child, child_node)

    root = Tree(name=str(root_node))
    add_children(root_node, root)
    return root

# Example usage
root_node = '0011111111111122222223334455667789'
ete3_tree = networkx_to_ete3(G, root_node)
print(ete3_tree.write(format=5))

((((111123:0):0.02,0011111111112222222333445566789:0,11111111234789:0):0,00111112222223345566789:0,0011122222233456789:0):0,(123:0,0011111111111112222222333445566778899:0):0,(((0011111111111122222223334456778899:0,11111234:0,11111111234567:0,(((00111111111111112222222333445566778899:0,(((1111112389:0,((1156:0,(((01222223:0):0,0149:0):0,(568:0):0,013456789:0):0):0):0):0,(0011111111111111222222233344567899:0,((1111112356:0,01122222345678:0):0,(5:0,7:0,6:0,8:0,3:0,2:0,4:0,23:0,1:0,21:0,22:0,20:0,11:0,12:0,13:0,14:0,10:0,9:0,15:0,16:0,18:0,19:0,17:0):0,(11222:0):0,(0111569:0):0,0122223:0):0):0,111567:0):0,1235678:0):0):0.01):0,01112222223345678:0):0,111456:0,00111122222233456789:0):0,01123456789:0,(12345678:0):0,111111123456:0):0):0,1189:0,011222223:0);
