In [None]:
pip install ete3

In [3]:
from ete3 import Tree
from collections import defaultdict
import networkx as nx
import matplotlib.pyplot as plt

In [None]:
# Global adjacency-matrix state shared by add_vertex / add_edge:
# vertices[i] is the name of row/column i of `graph`, vertices_no tracks
# len(vertices), and graph[i][j] holds the weight recorded for the ordered pair
# (vertices[i], vertices[j]); 0 means "no edge".
vertices = []
vertices_no = 0
graph = []
# Number of add_edge calls per ordered name pair (add_edge bumps both orientations).
edge_frequencies = defaultdict(int)
# Names reserved for leaves ('1' .. '14'); an internal node whose derived name
# equals one of these gets a numeric suffix in update_internal_node_names2.
leaf_node_names = [str(label) for label in range(1, 15)]

def add_vertex(v):
    """Register vertex ``v`` in the global adjacency matrix if it is not there yet.

    A new name is appended to ``vertices`` and ``vertices_no`` is incremented.
    Every existing row of ``graph`` gains a trailing 0 column, and a new all-zero
    row is appended, so ``graph`` stays square. Adding a name that is already
    present changes nothing.
    """
    global vertices_no
    if v in vertices:
        return
    vertices_no += 1
    vertices.append(v)
    if vertices_no > 1:
        # Widen the existing rows by one column for the new vertex.
        for existing_row in graph:
            existing_row.append(0)
    graph.append([0] * vertices_no)

# Running (sum, count) of the weights observed for each ordered vertex pair since
# its graph cell was last empty; lets add_edge store an exact mean rather than a
# pairwise running average that over-weights the most recent observation.
_edge_weight_totals = {}


def add_edge(v1, v2, e):
    """Record one observation of weight ``e`` for the directed cell (v1, v2).

    Both ``edge_frequencies[(v1, v2)]`` and ``edge_frequencies[(v2, v1)]`` are
    incremented. ``graph[i][j]`` (i, j = positions of v1, v2 in ``vertices``) is
    set to ``e`` on the first observation. After that it holds the mean of all
    observations for this ordered pair, rounded to 2 decimals. Only that single
    cell is written; the reverse direction needs its own call.

    Raises:
        ValueError: if v1 or v2 was never registered with add_vertex. Nothing is
            modified in that case.
    """
    # Resolve both indices first so a missing vertex cannot leave a partial update.
    index1 = vertices.index(v1)
    index2 = vertices.index(v2)
    key = (v1, v2)
    current = graph[index1][index2]
    if current == 0:
        # 0 means "no edge yet" (add_vertex initialises cells to 0): start a fresh mean.
        total, count = 0, 0
    else:
        # If the cell was filled without going through here, treat its value as
        # a single prior observation.
        total, count = _edge_weight_totals.get(key, (current, 1))
    total += e
    count += 1
    _edge_weight_totals[key] = (total, count)
    graph[index1][index2] = e if count == 1 else round(total / count, 2)
    # Both orientations are counted on every call, so a caller that adds both
    # (a, b) and (b, a) counts that pair twice. This applies to every pair alike,
    # so relative frequencies are unaffected.
    edge_frequencies[(v1, v2)] += 1
    edge_frequencies[(v2, v1)] += 1


def update_internal_node_names2(tree):
    """Name every internal node of ``tree`` after the characters of its children's names.

    Nodes are visited in postorder, so children are renamed before their parent.
    For an internal node, the child names are concatenated and the resulting
    characters are sorted (children "3" and "12" give "123"). If that label is
    one of ``leaf_node_names``, the smallest integer suffix 1, 2, ... that makes
    it differ from every entry of ``leaf_node_names`` is appended. Leaves keep
    their names.

    NOTE(review): because characters, not whole names, are sorted, different
    clades can receive the same name when leaf labels have several digits
    (children "1" + "23" and "12" + "3" both give "123"). Also, the suffix check
    consults only ``leaf_node_names``, not names already given to other
    internal nodes.
    """
    for node in tree.traverse("postorder"):
        if node.is_leaf():
            continue
        label = ''.join(sorted(''.join(child.name for child in node.children)))
        if label in leaf_node_names:
            n = 1
            while label + str(n) in leaf_node_names:
                n += 1
            label = label + str(n)
        node.name = label


def calculate_distance_matrices(tree, output_file=None):
    """Add ``tree``'s nodes and pairwise distances to the global graph; return its matrix.

    Every node of ``tree`` (leaves and internal nodes) is registered via
    add_vertex. Then add_edge is called for every ordered pair of distinct node
    names, with their distance (ete3 ``get_distance``) rounded to 2 decimals.

    Args:
        tree: ete3 Tree whose node names are already set, e.g. by
            update_internal_node_names2.
        output_file: unused; accepted for backward compatibility.

    Returns:
        list[list[float]]: square matrix whose rows and columns follow the nodes
        sorted by distance to the root (ties keep preorder order). Entry [i][j]
        is the distance between the i-th and j-th names rounded to 2 decimals,
        or 0 when no distance was recorded for that name pair.
    """
    # traverse() already yields every node exactly once, leaves included. Unlike
    # iterating a set of node objects (hashed by identity), its order is
    # deterministic, so vertex indices are stable from run to run.
    nodes = list(tree.traverse("preorder"))
    # Sort on the node objects themselves: a by-name lookup (tree & name) is
    # O(n) per call and returns the wrong node when names repeat.
    nodes.sort(key=lambda node: node.get_distance(tree))
    node_names = [node.name for node in nodes]

    for name in node_names:
        add_vertex(name)

    # Keyed by name: if two nodes share a name, the later pair overwrites the earlier.
    distances = {}
    for node1 in nodes:
        for node2 in nodes:
            if node1 is not node2:
                distances[(node1.name, node2.name)] = node1.get_distance(node2)

    for name1 in node_names:
        for name2 in node_names:
            if name1 != name2:
                add_edge(name1, name2, round(distances.get((name1, name2), 0), 2))

    return [
        [round(distances.get((name1, name2), 0), 2) for name2 in node_names]
        for name1 in node_names
    ]

input_file = 'input_trees'
# calculate_distance_matrices does not use this argument. It is defined so the
# call below does not raise NameError.
output_file = None

# One newick tree per non-blank line. Each tree's nodes and pairwise distances
# are accumulated into the global graph / edge_frequencies.
with open(input_file, 'r', encoding='utf-8') as file:
    for line in file:
        newick = line.strip()
        if not newick:
            continue  # skip blank lines such as a trailing empty line
        tree = Tree(newick, format=1)
        update_internal_node_names2(tree)
        distance_matrix2 = calculate_distance_matrices(tree, output_file)

# Co-occurrence counts aligned with `graph`: row/column i of frequency_matrix
# refers to vertices[i], the same index used by graph and dist_matrix.
# modified_prims indexes both matrices with the same integers, so sorting the
# names alphabetically here would misalign them (and could make the matrices
# different sizes).
nodes = list(vertices)
num_nodes = len(nodes)

# edge_frequencies is treated as undirected: fall back to the reverse orientation.
frequency_matrix = [
    [
        edge_frequencies[(node_i, node_j)]
        if (node_i, node_j) in edge_frequencies
        else edge_frequencies.get((node_j, node_i), 0)
        for node_j in nodes
    ]
    for node_i in nodes
]

# Copy of the weight matrix that is later handed to modified_prims.
# NOTE(review): this snapshot is taken before remove_cycles / remove_duplicates
# run, so their in-place edits to `graph` never reach the MST step. Confirm
# that this is intended.
dist_matrix = [row[:vertices_no] for row in graph[:vertices_no]]



# Define a function to remove cycles from the graph
def remove_cycles(graph):
    """Clear, in place, every edge that lies on a cycle of the graph's cycle basis.

    An undirected networkx graph is built from the upper triangle of ``graph``,
    where non-zero entries are edges. For each cycle returned by
    ``nx.cycle_basis``, every consecutive node pair (u, v), including the
    closing pair, is removed by zeroing BOTH ``graph[u][v]`` and ``graph[v][u]``.
    Clearing only one direction left the upper-triangle entry alive whenever
    u > v.

    NOTE(review): removing every edge of every basis cycle can remove more edges
    than needed to make the graph acyclic (a triangle loses all three edges), so
    the graph may become disconnected.
    """
    G = nx.Graph()
    for i in range(len(graph)):
        for j in range(i + 1, len(graph[i])):
            weight = graph[i][j]
            if weight != 0:
                G.add_edge(i, j, weight=weight)

    for cycle in nx.cycle_basis(G):
        # Pair each node with its successor, wrapping around to close the cycle.
        for u, v in zip(cycle, cycle[1:] + cycle[:1]):
            graph[u][v] = 0
            graph[v][u] = 0

# Call the function to remove cycles from the graph (mutates `graph` in place).
# NOTE(review): `dist_matrix` was copied from `graph` before this call, and it is
# dist_matrix (not graph) that is passed to modified_prims further down, so this
# pruning does not influence the spanning-tree result. Confirm whether the copy
# should be taken after pruning.
remove_cycles(graph)

# Define a function to remove duplicate edges
def remove_duplicates(graph):
    """Merge each pair of opposite directed entries into one upper-triangle edge, in place.

    For every i < j where both ``graph[i][j]`` and ``graph[j][i]`` are non-zero,
    the upper entry becomes their mean rounded to 2 decimals and the lower entry
    is cleared. This includes pairs with equal weights, which are true duplicates.
    Pairs where either entry is zero are left untouched.
    """
    for i, row in enumerate(graph):
        for j in range(i + 1, len(row)):
            upper, lower = row[j], graph[j][i]
            if upper == 0 or lower == 0:
                continue
            # Keep a single edge carrying the average of the two weights
            # (previously the ratio upper / lower was stored by mistake).
            row[j] = round((upper + lower) / 2, 2)
            graph[j][i] = 0

# Call the function to remove duplicate edges (mutates `graph` in place).
# NOTE(review): modified_prims below receives `dist_matrix`, which was copied
# from `graph` before this call, so this merge does not reach the
# spanning-tree step. Confirm whether that is intended.
remove_duplicates(graph)


def modified_prims(graph, frequencies, nodes_of_interest):
    """Grow a Prim-style spanning tree over ``nodes_of_interest``, preferring frequent edges.

    The tree starts from ``nodes_of_interest[0]``. Each step adds one edge
    (u, v), where u is already in the tree, v is a node of interest not yet in
    the tree, and ``graph[u][v] != 0`` (0 means "no edge", as elsewhere in this
    file). The candidate with the highest ``frequencies[u][v]`` wins. Ties go to
    the smallest ``graph[u][v]``, and any remaining tie to the first candidate
    scanned (tree nodes in ``nodes_of_interest`` order, neighbours by ascending
    index).

    Args:
        graph: square matrix of edge weights indexed by vertex index.
        frequencies: matrix with the same indexing as ``graph`` holding edge
            frequencies.
        nodes_of_interest: vertex indices to connect; duplicates are ignored.

    Returns:
        list of (u, v) index pairs in the order they were added. Empty when
        fewer than two distinct nodes are given.

    Raises:
        IndexError: if a node of interest is not a valid index into ``graph``.
        ValueError: if some node of interest cannot be reached through non-zero
            edges.
    """
    if not nodes_of_interest:
        return []
    num_nodes = len(graph)
    targets = set(nodes_of_interest)  # O(1) membership instead of a list scan
    out_of_range = sorted(n for n in targets if not 0 <= n < num_nodes)
    if out_of_range:
        raise IndexError(f"nodes_of_interest not in graph: {out_of_range}")

    in_mst = [False] * num_nodes
    in_mst[nodes_of_interest[0]] = True
    mst = []

    while len(mst) < len(targets) - 1:
        best = None  # (frequency, weight, u, v) of the best candidate so far
        for u in nodes_of_interest:
            if not in_mst[u]:
                continue
            for v, weight in enumerate(graph[u]):
                if v not in targets or in_mst[v] or weight == 0:
                    continue
                frequency = frequencies[u][v]
                # Lexicographic preference: higher frequency first, then lower weight.
                if (best is None or frequency > best[0]
                        or (frequency == best[0] and weight < best[1])):
                    best = (frequency, weight, u, v)
        if best is None:
            raise ValueError(
                "nodes_of_interest are not connected by non-zero edges in graph"
            )
        _, _, u, v = best
        in_mst[v] = True
        mst.append((u, v))

    return mst



# Vertex indices (positions in `vertices`) that the spanning tree must connect:
# 0, 3..13 and 26..99. NOTE(review): hard-coded, so it depends on the order in
# which vertices were registered.
nodes_of_interest = [0] + list(range(3, 14)) + list(range(26, 100))

# Build the spanning tree over the selected vertex indices.
result = modified_prims(dist_matrix, frequency_matrix, nodes_of_interest)

# Translate the tree's index pairs back into vertex names and collect them in an
# undirected networkx graph for plotting.
G = nx.Graph()
G.add_edges_from((vertices[u], vertices[v]) for u, v in result)

# Spring layout positions; a large figure keeps dense trees readable.
pos = nx.spring_layout(G)
plt.figure(figsize=(15, 15))
nx.draw(
    G,
    pos,
    with_labels=True,
    node_size=300,
    node_color='skyblue',
    font_size=8,
    font_color='black',
    edge_color='gray',
    width=1.0,
)

# Hide the axes and render.
plt.axis('off')
plt.show()