In [2]:
from ete3 import Tree

def get_pairwise_relationships(tree):
    """Collect the pairs of leaf names that sit together under an internal node.

    Every non-leaf node yielded by ``tree.traverse()`` contributes one
    ``(name_a, name_b)`` tuple for each two leaves returned by its
    ``get_leaves()``. Inside a tuple the names keep the order in which
    ``get_leaves()`` listed them.

    Args:
        tree: Object exposing ``traverse()``; its nodes must provide
            ``is_leaf()`` and ``get_leaves()``, and leaves need a ``name``.

    Returns:
        set: Ordered 2-tuples of leaf names.
    """
    related = set()
    for clade in tree.traverse():
        if clade.is_leaf():
            continue
        tips = clade.get_leaves()
        # NOTE(review): tuples are orientation-sensitive, so ('a', 'b') and
        # ('b', 'a') are distinct entries. If traverse() also yields the root,
        # every leaf pair of the tree ends up in the result -- confirm that
        # this is the intended notion of "relationship".
        for pos, first in enumerate(tips):
            for second in tips[pos + 1:]:
                related.add((first.name, second.name))
    return related

def calculate_metrics(true_pairs, predicted_pairs, universe=None):
    """Score a predicted set of leaf pairs against a reference set.

    Pairs are treated as unordered: ``('a', 'b')`` and ``('b', 'a')`` denote
    the same relationship, so every pair is normalised before comparison.

    Args:
        true_pairs: Iterable of 2-element pairs taken from the reference tree.
        predicted_pairs: Iterable of 2-element pairs taken from the predicted
            tree.
        universe: Optional iterable with every possible pair; it is needed to
            count true negatives. When omitted, a module-level ``all_pairs``
            is used if one exists (the behaviour callers previously relied
            on). Otherwise the union of both inputs is used, which yields
            zero true negatives.

    Returns:
        tuple: ``(accuracy, precision, recall, f1_score)``. Any ratio whose
        denominator is zero is reported as 0.
    """
    if universe is None:
        # Backward compatibility: older callers defined a global `all_pairs`
        # before calling instead of passing it in.
        universe = globals().get('all_pairs')

    true_set = {frozenset(pair) for pair in true_pairs}
    predicted_set = {frozenset(pair) for pair in predicted_pairs}
    if universe is None:
        universe_set = true_set | predicted_set
    else:
        universe_set = {frozenset(pair) for pair in universe}

    TP = len(true_set & predicted_set)
    FP = len(predicted_set - true_set)
    TN = len((universe_set - true_set) - predicted_set)
    FN = len(true_set - predicted_set)

    total = TP + FP + TN + FN
    # Guard every denominator so empty inputs do not raise ZeroDivisionError.
    accuracy = (TP + TN) / total if total != 0 else 0
    precision = TP / (TP + FP) if (TP + FP) != 0 else 0
    recall = TP / (TP + FN) if (TP + FN) != 0 else 0
    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) != 0 else 0

    return accuracy, precision, recall, f1_score

# Example usage: a reference ("true") tree and a predicted tree, both given as
# Newick strings over leaves labelled 1-20.
# NOTE(review): the reference string ends with ';;' -- confirm the doubled
# terminator is intended.
true_tree_newick = "((6:1.24505,((8:1.0197,7:0.914115)1:0.426286,(((((20:0.538245,19:0.512534)1:0.506217,(18:0.760201,17:0.708199)1:0.20372)1:0.468766,(16:0.898558,15:0.654545)1:0.487851)1:0.513499,(12:0.948739,(14:0.934243,13:0.760605)1:0.359892)1:0.384672)1:0.450774,(9:0.939145,(11:0.863112,10:0.691522)1:0.41104)1:0.426002)1:0.335438)1:0.417831)1:0.365596,(4:0.922081,(3:0.833797,(2:0.666439,1:0.681804)1:0.391529)1:0.357839)1:0.287725,5:1.04341);;"  # Example true tree in Newick format
predicted_tree_newick = "(((((20:0.640199,19:0.639362)1:0.545529,(18:0.927762,17:0.818653)1:0.231335)1:0.620454,((16:0.837419,15:0.57274)1:0.559628,(14:0.665455,13:0.511096)1:0.337877)1:0.341629)1:1.32391,(7:0.890723,(8:0.994145,(9:0.961732,(10:0.79211,(12:0.746754,11:0.558187)1:0.531298)1:0.405829)1:0.300937)1:0.271281)1:0.300668)1:0.431401,(5:0.745692,(4:0.734082,(3:0.798103,(2:0.691664,1:0.680828)1:0.461073)1:0.310279)1:0.349786)1:0.211194,6:0.892796);"  # Example predicted tree in Newick format

# Parse both Newick strings into tree objects.
true_tree = Tree(true_tree_newick)
predicted_tree = Tree(predicted_tree_newick)

# Leaf-pair sets for each tree; the reference set is echoed for inspection.
true_pairs = get_pairwise_relationships(true_tree)
print(true_pairs)
predicted_pairs = get_pairwise_relationships(predicted_tree)

# Every leaf pair of the reference tree, oriented by leaf listing order.
# calculate_metrics reads this module-level name, so it must exist before the
# call below.
all_leaves = true_tree.get_leaf_names()
all_pairs = {
    (first, second)
    for idx, first in enumerate(all_leaves)
    for second in all_leaves[idx + 1:]
}

accuracy, precision, recall, f1_score = calculate_metrics(true_pairs, predicted_pairs)
print(f"Accuracy: {accuracy}, Precision: {precision}, Recall: {recall}, F1 Score: {f1_score}")

{('20', '4'), ('19', '3'), ('20', '2'), ('6', '18'), ('20', '9'), ('4', '5'), ('14', '13'), ('18', '3'), ('8', '17'), ('4', '2'), ('8', '10'), ('9', '1'), ('19', '4'), ('6', '13'), ('19', '2'), ('8', '20'), ('19', '9'), ('7', '18'), ('13', '3'), ('6', '15'), ('18', '9'), ('10', '1'), ('7', '13'), ('13', '5'), ('12', '13'), ('13', '4'), ('15', '1'), ('13', '2'), ('6', '1'), ('17', '13'), ('20', '19'), ('13', '9'), ('7', '15'), ('20', '14'), ('14', '10'), ('6', '17'), ('6', '10'), ('7', '1'), ('12', '1'), ('19', '14'), ('6', '20'), ('16', '13'), ('14', '11'), ('17', '16'), ('14', '5'), ('7', '17'), ('7', '10'), ('6', '11'), ('12', '10'), ('16', '15'), ('17', '10'), ('7', '20'), ('8', '16'), ('8', '7'), ('8', '12'), ('16', '1'), ('7', '11'), ('15', '13'), ('12', '11'), ('18', '13'), ('17', '11'), ('17', '5'), ('3', '5'), ('8', '11'), ('18', '15'), ('9', '10'), ('8', '5'), ('19', '1'), ('8', '4'), ('8', '2'), ('2', '1'), ('18', '1'), ('6', '16'), ('9', '11'), ('19', '17'), ('11', '1'), ('9

In [9]:
import os
from secrets import token_hex
import subprocess

# Simulation settings forwarded to the SimBac executable below.
n_trees = 500
n_leaves = 20
seq_len = 4900
r = 0.001
t = 0.001
# NOTE(review): the '500' prefix is a literal and is not derived from n_trees.
dataset_id = f'500-{n_leaves}-simbac-{seq_len}-{r}-{t}'

# exist_ok avoids the check-then-create race of a separate os.path.exists test.
trees = os.path.join('./datasets/trees', dataset_id)
os.makedirs(trees, exist_ok=True)

alignments = os.path.join('./datasets/alignments', dataset_id)
os.makedirs(alignments, exist_ok=True)


for i in range(n_trees):
    # A random 8-hex-digit id names the alignment/tree pair of this run.
    generation_id = token_hex(4)
    print(f'{i+1}/{n_trees}:{generation_id}', end='\r')
    # Pass an argument list instead of a shell string: no shell is spawned and
    # paths containing spaces or shell metacharacters are forwarded verbatim.
    command = [
        './sequence_generators/SimBac',
        '-N', str(n_leaves),
        '-B', str(seq_len),
        '-R', str(r),
        '-T', str(t),
        '-o', f'{alignments}/{generation_id}.fasta',
        '-c', f'{trees}/{generation_id}.nwk',
    ]
    # check=True raises CalledProcessError on a non-zero exit status; the
    # captured stdout/stderr are otherwise discarded.
    _ = subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

5000/5000:3ca76ae6

In [22]:
import os
from secrets import token_hex
import subprocess

# Simulation settings forwarded to the SimBac executable below.
n_trees = 500
n_leaves = 20
seq_len = 700000
r = 0.001
t = 0.001
# NOTE(review): the '500' prefix is a literal and is not derived from n_trees.
dataset_id = f'500-{n_leaves}-simbac-{seq_len}-{r}-{t}'

# exist_ok avoids the check-then-create race of a separate os.path.exists test.
trees = os.path.join('./datasets/trees', dataset_id)
os.makedirs(trees, exist_ok=True)

alignments = os.path.join('./datasets/alignments', dataset_id)
os.makedirs(alignments, exist_ok=True)


for i in range(n_trees):
    # A random 8-hex-digit id names the alignment/tree pair of this run.
    generation_id = token_hex(4)
    print(f'{i+1}/{n_trees}:{generation_id}', end='\r')
    # Pass an argument list instead of a shell string: no shell is spawned and
    # paths containing spaces or shell metacharacters are forwarded verbatim.
    command = [
        './sequence_generators/SimBac',
        '-N', str(n_leaves),
        '-B', str(seq_len),
        '-R', str(r),
        '-T', str(t),
        '-o', f'{alignments}/{generation_id}.fasta',
        '-c', f'{trees}/{generation_id}.nwk',
    ]
    # check=True raises CalledProcessError on a non-zero exit status; the
    # captured stdout/stderr are otherwise discarded.
    _ = subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

500/500:a597c96e

In [3]:
from phyloformer.data import TensorDataset

# Alternative dataset path kept for reference (sequence data):
# ds = TensorDataset('datasets/sequence_data/5k-20-20-un-un-seqgen-pam-200-aa')
# Build a TensorDataset over the SimBac typing data directory.
ds = TensorDataset('datasets/typing_data/5k-20-simbac-4900-0.001-0.001')

In [4]:
# Shape of the first element of the first dataset item
# (the recorded output is torch.Size([32, 7, 20])).
ds[0][0].shape

torch.Size([32, 7, 20])

In [19]:
# Inspect index 23 along the first axis of that tensor
# (the recorded output is a 7 x 20 block of ones).
ds[0][0][23]

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])

In [None]:
# python generate_typings.py -a 500-20-simbac-4900-0.001-0.001 -o 500-20-simbac-4900-0.001-0.001 -b 7 -e 600 -i 100
# python typing_predicted_trees.py -y 500-20-simbac-4900-0.001-0.001 -o tcmegf1q-5rinc72q -m train/checkpoints/tcmegf1q/5rinc72q.best_checkpoint.pt -d cuda:0 -s
# python evaluate.py -t ../../datasets/trees/500-20-simbac-4900-0.001-0.001 -p tcmegf1q-5rinc72q

In [None]:
# python typing_true_trees.py -y 500-20-simbac-4900-0.001-0.001 -o 500-20-simbac-4900-0.001-0.001
# python evaluate.py -t 500-20-simbac-4900-0.001-0.001 -p tcmegf1q-5rinc72q

In [None]:
# python evaluate.py -t ../../datasets/trees/500-20-simbac-700000-0.001-0.001 -p vdllc6f9-kwz53xe1

In [2]:
import numpy as np
import os
from biotite.sequence.phylo import upgma

def read_distance_matrix(file_path):
    """Load a square distance matrix from a text file.

    Layout, as parsed here: the first line holds the matrix size ``n``; each
    following non-blank line is a 1-based integer row label followed by the
    distances of that row.

    Args:
        file_path: Path of the distance-matrix file.

    Returns:
        numpy.ndarray: ``(n, n)`` float array. Rows that never appear in the
        file stay zero.

    Raises:
        ValueError: If a row label is smaller than 1. Such a label used to
            wrap around as a negative index and silently overwrite a row
            counted from the end of the matrix.
    """
    with open(file_path, 'r') as f:
        lines = f.readlines()

    size = int(lines[0].strip())
    matrix = np.zeros((size, size))

    for line in lines[1:]:
        parts = line.split()
        if not parts:
            # Blank lines (e.g. trailing newlines) carry no row.
            continue
        row = int(parts[0]) - 1  # Convert to zero-based index
        if row < 0:
            raise ValueError(f"row label {parts[0]!r} in {file_path!r} must be >= 1")
        matrix[row, :] = [float(x) for x in parts[1:]]

    return matrix

dms_path = './test/predicted_trees/tcmegf1q-5rinc72q'
# Keep only files with a real '.dm' extension (a bare 'dm' suffix would also
# match names such as 'random'). Sort them because os.listdir returns entries
# in arbitrary order, which would make the surviving `tree` nondeterministic.
dms_files = sorted(os.path.join(dms_path, dm) for dm in os.listdir(dms_path) if dm.endswith('.dm'))

# Each iteration overwrites `tree`; only the tree of the last file remains
# available after the loop.
for dm_file in dms_files:
    distance_matrix = read_distance_matrix(dm_file)
    tree = upgma(distance_matrix)

(((((19:0.808489978313446,13:0.808489978313446):0.15585380792617798,4:0.964343786239624):0.03056102991104126,((18:0.6875519752502441,8:0.6875519752502441):0.21757829189300537,14:0.9051302671432495):0.08977454900741577):0.05065423250198364,(17:0.7952049970626831,3:0.7952049970626831):0.2503540515899658):0.02659451961517334,((((16:1.003638744354248,(15:0.8265569806098938,2:0.8265569806098938):0.17708176374435425):0.0044966936111450195,(12:0.7738329768180847,7:0.7738329768180847):0.23430246114730835):0.02954733371734619,((11:0.6767295002937317,10:0.6767295002937317):0.15339046716690063,(9:0.6699420213699341,5:0.6699420213699341):0.16017794609069824):0.20756280422210693):0.026340723037719727,(6:0.9316967725753784,(1:0.8217840194702148,0:0.8217840194702148):0.10991275310516357):0.13232672214508057):0.008130073547363281):0.0;


In [25]:
import numpy as np
import os
from biotite.sequence.phylo import upgma
from fastcluster import linkage
from scipy.cluster.hierarchy import to_tree, ClusterNode, dendrogram
from string import ascii_lowercase


def read_distance_matrix(file_path):
    """Load a square distance matrix from a text file.

    Layout, as parsed here: the first line holds the matrix size ``n``; each
    following non-blank line is a 1-based integer row label followed by the
    distances of that row.

    Args:
        file_path: Path of the distance-matrix file.

    Returns:
        numpy.ndarray: ``(n, n)`` float array. Rows that never appear in the
        file stay zero.

    Raises:
        ValueError: If a row label is smaller than 1. Such a label used to
            wrap around as a negative index and silently overwrite a row
            counted from the end of the matrix.
    """
    with open(file_path, 'r') as f:
        lines = f.readlines()

    size = int(lines[0].strip())
    matrix = np.zeros((size, size))

    for line in lines[1:]:
        parts = line.split()
        if not parts:
            # Blank lines (e.g. trailing newlines) carry no row.
            continue
        row = int(parts[0]) - 1  # Convert to zero-based index
        if row < 0:
            raise ValueError(f"row label {parts[0]!r} in {file_path!r} must be >= 1")
        matrix[row, :] = [float(x) for x in parts[1:]]

    return matrix

def _scipy_tree_to_newick_list(node, newick, parentdist, leaf_names):
    """Build the reversed list of Newick tokens for a ClusterNode subtree.

    Recursive helper for `to_newick`. Tokens are accumulated in reverse
    document order: an internal node first appends its closing token, then
    the tokens of its left child, a comma, the tokens of its right child and
    finally the opening parenthesis. The caller reverses the list and joins
    it to obtain the Newick string, so the right child is written before the
    left child in the final text.

    Args:
        node (scipy.cluster.hierarchy.ClusterNode): Subtree to serialise; the
            initial call passes the root produced by
            scipy.cluster.hierarchy.to_tree.
        newick (list of string): Accumulator of tokens emitted so far. An
            empty list marks the root call, which emits ');' instead of a
            branch length.
        parentdist (float): `dist` of the parent of `node`; the branch length
            is written as `parentdist - node.dist`.
        leaf_names (list of string): Leaf names, indexed by `node.id`.

    Returns:
        (list of string): Token list in reverse order. For an internal node
        this is a list that was extended with `append`; for a leaf it is a
        new list and the input list is left untouched.
    """
    if node.is_leaf():
        # Leaf: return a copy extended by 'name:branch_length'.
        return newick + [f'{leaf_names[node.id]}:{parentdist - node.dist}']
        # return newick + [f'{node.id}:{parentdist - node.dist}']

    # Closing token goes in first because the list is reversed at the end.
    if len(newick) > 0:
        newick.append(f'):{parentdist - node.dist}')
    else:
        # Root call: terminate the tree without a branch length.
        newick.append(');')
    newick = _scipy_tree_to_newick_list(node.get_left(), newick, node.dist, leaf_names)
    newick.append(',')
    newick = _scipy_tree_to_newick_list(node.get_right(), newick, node.dist, leaf_names)
    newick.append('(')
    return newick

def to_newick(tree: ClusterNode, leaf_names) -> str:
    """Serialise a SciPy hierarchical-clustering tree as a Newick string.

    The root ClusterNode is expected to come from
    scipy.cluster.hierarchy.to_tree applied to a linkage matrix.

    Args:
        tree (scipy.cluster.hierarchy.ClusterNode): Root node of the tree.
        leaf_names (list of string): Leaf node names, indexed by node id.

    Returns:
        (string): Newick representation of `tree`.
    """
    # The helper emits its tokens back to front, so walk them in reverse.
    reversed_tokens = _scipy_tree_to_newick_list(tree, [], tree.dist, leaf_names)
    return ''.join(reversed(reversed_tokens))


# Directory holding the predicted distance matrices of one run; the
# commented alternatives select other runs.
dms_path = './test/predicted_trees/tcmegf1q-5rinc72q'
# dms_path = './test/predicted_trees/u7nim7dq-e4sd89o0'
# dms_path = './test/predicted_trees/1c6tm5ng-x2te5n0m'
# dms_path = './test/predicted_trees/vdllc6f9-kwz53xe1'
# NOTE(review): every run writes into the same output directory, so files of
# different runs can overwrite each other -- confirm this is intended.
nwk_output_path = './test/predicted_trees/nwk_files'
os.makedirs(nwk_output_path, exist_ok=True)

# Keep only files with a real '.dm' extension (a bare 'dm' suffix would also
# match names such as 'random'). Sort them because os.listdir returns entries
# in arbitrary order.
dms_files = sorted(os.path.join(dms_path, dm) for dm in os.listdir(dms_path) if dm.endswith('.dm'))

for dm_file in dms_files:
    distance_matrix = read_distance_matrix(dm_file)
    tree = upgma(distance_matrix)
    newick_str = tree.to_newick(include_distance=True)
    # Alternative kept for reference: single-linkage tree through SciPy.
    # Z = linkage(distance_matrix)
    # T = to_tree(Z, rd=False)
    # T = to_newick(T, ascii_lowercase)
    # T = T.replace('):', ')1:')

    # Swap only the trailing extension; str.replace would also rewrite any
    # earlier '.dm' occurrence inside the file name.
    base_name = os.path.splitext(os.path.basename(dm_file))[0] + '.nwk'
    nwk_file_path = os.path.join(nwk_output_path, base_name)
    with open(nwk_file_path, 'w') as nwk_file:
        nwk_file.write(newick_str)
        # nwk_file.write(T)

    print(f"Saved tree to {nwk_file_path}")

Saved tree to ./test/predicted_trees/nwk_files/2daeb89c.pf.nwk
Saved tree to ./test/predicted_trees/nwk_files/6b332c14.pf.nwk
Saved tree to ./test/predicted_trees/nwk_files/81915534.pf.nwk
Saved tree to ./test/predicted_trees/nwk_files/4d1431c7.pf.nwk
Saved tree to ./test/predicted_trees/nwk_files/facd8570.pf.nwk
Saved tree to ./test/predicted_trees/nwk_files/7f2d531f.pf.nwk
Saved tree to ./test/predicted_trees/nwk_files/08cb3a94.pf.nwk
Saved tree to ./test/predicted_trees/nwk_files/74ae7b87.pf.nwk
Saved tree to ./test/predicted_trees/nwk_files/507f4b41.pf.nwk
Saved tree to ./test/predicted_trees/nwk_files/9b6a763e.pf.nwk
Saved tree to ./test/predicted_trees/nwk_files/88db302c.pf.nwk
Saved tree to ./test/predicted_trees/nwk_files/6f555ab7.pf.nwk
Saved tree to ./test/predicted_trees/nwk_files/3471c581.pf.nwk
Saved tree to ./test/predicted_trees/nwk_files/43198632.pf.nwk
Saved tree to ./test/predicted_trees/nwk_files/ecaa9074.pf.nwk
Saved tree to ./test/predicted_trees/nwk_files/d8afd617

In [29]:
import numpy as np
import os
from biotite.sequence.phylo import upgma
from fastcluster import linkage
from scipy.cluster.hierarchy import to_tree, ClusterNode, dendrogram
from string import ascii_lowercase

def read_distance_matrix(file_path):
    """Load row labels and a square distance matrix from a text file.

    Layout, as parsed here: the first line holds the matrix size ``n``; each
    following non-blank line is a row label followed by the distances of that
    row. Rows are stored in the order they appear in the file.

    Args:
        file_path: Path of the distance-matrix file.

    Returns:
        tuple: ``(labels, matrix)`` where ``labels`` is the list of row labels
        (strings, in file order) and ``matrix`` is an ``(n, n)`` float
        ``numpy.ndarray``.
    """
    with open(file_path, 'r') as f:
        lines = f.readlines()

    size = int(lines[0].strip())
    labels = []
    matrix = np.zeros((size, size))

    # Drop blank lines before numbering the rows: a blank line used to raise
    # IndexError on parts[0] and would also have shifted later row indices.
    rows = [parts for parts in (line.split() for line in lines[1:]) if parts]
    for i, parts in enumerate(rows):
        labels.append(parts[0])  # Add the row label
        matrix[i, :] = [float(x) for x in parts[1:]]

    return labels, matrix

def _scipy_tree_to_newick_list(node, newick, parentdist, leaf_names):
    """Build the reversed list of Newick tokens for a ClusterNode subtree.

    Recursive helper for `to_newick`. Tokens are accumulated in reverse
    document order: an internal node first appends its closing token, then
    the tokens of its left child, a comma, the tokens of its right child and
    finally the opening parenthesis. The caller reverses the list and joins
    it to obtain the Newick string, so the right child is written before the
    left child in the final text.

    Args:
        node (scipy.cluster.hierarchy.ClusterNode): Subtree to serialise; the
            initial call passes the root produced by
            scipy.cluster.hierarchy.to_tree.
        newick (list of string): Accumulator of tokens emitted so far. An
            empty list marks the root call, which emits ');' instead of a
            branch length.
        parentdist (float): `dist` of the parent of `node`; the branch length
            is written as `parentdist - node.dist`.
        leaf_names (list of string): Leaf names, indexed by `node.id`.

    Returns:
        (list of string): Token list in reverse order. For an internal node
        this is a list that was extended with `append`; for a leaf it is a
        new list and the input list is left untouched.
    """
    if node.is_leaf():
        # Leaf: return a copy extended by 'name:branch_length'.
        return newick + [f'{leaf_names[node.id]}:{parentdist - node.dist}']

    # Closing token goes in first because the list is reversed at the end.
    if len(newick) > 0:
        newick.append(f'):{parentdist - node.dist}')
    else:
        # Root call: terminate the tree without a branch length.
        newick.append(');')
    newick = _scipy_tree_to_newick_list(node.get_left(), newick, node.dist, leaf_names)
    newick.append(',')
    newick = _scipy_tree_to_newick_list(node.get_right(), newick, node.dist, leaf_names)
    newick.append('(')
    return newick

def to_newick(tree: ClusterNode, leaf_names) -> str:
    """Serialise a SciPy hierarchical-clustering tree as a Newick string.

    The root ClusterNode is expected to come from
    scipy.cluster.hierarchy.to_tree applied to a linkage matrix.

    Args:
        tree (scipy.cluster.hierarchy.ClusterNode): Root node of the tree.
        leaf_names (list of string): Leaf node names, indexed by node id.

    Returns:
        (string): Newick representation of `tree`.
    """
    # The helper emits its tokens back to front, so walk them in reverse.
    reversed_tokens = _scipy_tree_to_newick_list(tree, [], tree.dist, leaf_names)
    return ''.join(reversed(reversed_tokens))

# Directory holding the predicted distance matrices of one run; the
# commented alternatives select other runs.
dms_path = './test/predicted_trees/bbfgxmxm-359g7rv1'
# dms_path = './test/predicted_trees/pkuyiofs-khmvtpeh'

# dms_path = './test/predicted_trees/l72144nu-s44a5myr'
# dms_path = './test/predicted_trees/f6v088s5-7rbm50jz'
# NOTE(review): every run writes into the same output directory, so files of
# different runs can overwrite each other -- confirm this is intended.
nwk_output_path = './test/predicted_trees/nwk_files'
os.makedirs(nwk_output_path, exist_ok=True)

# Keep only files with a real '.dm' extension (a bare 'dm' suffix would also
# match names such as 'random'). Sort them because os.listdir returns entries
# in arbitrary order.
dms_files = sorted(os.path.join(dms_path, dm) for dm in os.listdir(dms_path) if dm.endswith('.dm'))

for dm_file in dms_files:
    labels, distance_matrix = read_distance_matrix(dm_file)
    tree = upgma(distance_matrix)
    # Leaves are written with the row labels read from the matrix file.
    newick_str = tree.to_newick(include_distance=True, labels=labels)
    # Alternative kept for reference: single-linkage tree through SciPy.
    # Z = linkage(distance_matrix)
    # T = to_tree(Z, rd=False)
    # T = to_newick(T, labels)
    # T = T.replace('):', ')1:')

    # Swap only the trailing extension; str.replace would also rewrite any
    # earlier '.dm' occurrence inside the file name.
    base_name = os.path.splitext(os.path.basename(dm_file))[0] + '.nwk'
    nwk_file_path = os.path.join(nwk_output_path, base_name)
    with open(nwk_file_path, 'w') as nwk_file:
        nwk_file.write(newick_str)
        # nwk_file.write(T)

    print(f"Saved tree to {nwk_file_path}")

Saved tree to ./test/predicted_trees/nwk_files/1e2f8946-4_20.pf.nwk
Saved tree to ./test/predicted_trees/nwk_files/2b039f4e-17_20.pf.nwk
Saved tree to ./test/predicted_trees/nwk_files/82f13260-15_20.pf.nwk
Saved tree to ./test/predicted_trees/nwk_files/86ac609e-11_20.pf.nwk
Saved tree to ./test/predicted_trees/nwk_files/0fa10afb-15_20.pf.nwk
Saved tree to ./test/predicted_trees/nwk_files/1e2f8946-15_20.pf.nwk
Saved tree to ./test/predicted_trees/nwk_files/435fe6a0-3_20.pf.nwk
Saved tree to ./test/predicted_trees/nwk_files/1e2f8946-11_20.pf.nwk
Saved tree to ./test/predicted_trees/nwk_files/f79b81b8-6_20.pf.nwk
Saved tree to ./test/predicted_trees/nwk_files/1e2f8946-3_20.pf.nwk
Saved tree to ./test/predicted_trees/nwk_files/e360016a-10_20.pf.nwk
Saved tree to ./test/predicted_trees/nwk_files/c8b0ad58-9_20.pf.nwk
Saved tree to ./test/predicted_trees/nwk_files/2c6ad782-4_20.pf.nwk
Saved tree to ./test/predicted_trees/nwk_files/8892ac50-8_20.pf.nwk
Saved tree to ./test/predicted_trees/nwk_