# Tree completion methods comparison
## k-NCL vs RF(+)
### Heatmaps and boxplots for 4 datasets and 3 distance metrics

## k-NCL

In [1]:
# Please use the k-ncl script available on GitHub

"""
This script implements the k-NCL algorithm for completing phylogenetic trees
that are defined on different but overlapping taxon sets. The implementation
uses the ete3 library to work with phylogenetic trees.

Requires:  ete3  (pip install ete3)
"""

## BSD(+)
BSD(k-NCL): the BSD(+) distance, computed on the two trees after both have been completed with k-NCL.

In [2]:
import numpy as np
import math
from itertools import combinations

def squared_distance_sum(t1, t2, leaves):
    """
    Sum the squared differences of pairwise leaf distances between two trees.

    Parameters:
    - t1, t2: objects exposing get_distance(a, b) (ete3.Tree in this notebook)
    - leaves: iterable of leaf identifiers accepted by get_distance

    Returns:
    - the sum over every unordered leaf pair of (d_t1 - d_t2) ** 2;
      0 when fewer than two leaves are given
    """
    total = 0
    for first, second in combinations(leaves, 2):
        gap = t1.get_distance(first, second) - t2.get_distance(first, second)
        # Accumulate term by term (not via sum()) so the rounding order is fixed.
        total += gap ** 2
    return total

def BSD(T1, T2, k=None):
    """
    Compute the BSD distance between two trees after k-NCL completion.

    Parameters:
    - T1, T2: ete3.Tree objects (incomplete trees on overlapping leaf sets)
    - k: int or None, forwarded unchanged to kNCL

    Returns:
    - float: square root of the summed squared differences of pairwise leaf
      distances, taken over the leaf set of the completed T1
    - None: when the input trees share fewer than 3 leaves (no completion is
      attempted in that case)
    """
    # Leaf names of the original input trees (before completion)
    leaves1 = {leaf.name for leaf in T1.get_leaves()}
    leaves2 = {leaf.name for leaf in T2.get_leaves()}

    # A single None keeps the return type consistent with the scalar that is
    # returned on success, so callers can test `result is None`.
    if len(leaves1 & leaves2) < 3:
        return None

    # Complete copies, as the other kNCL call sites in this file do, so the
    # caller's trees stay untouched whether or not kNCL modifies its arguments.
    T1_completed, T2_completed = kNCL(T1.copy(), T2.copy(), k)

    # Sorted, de-duplicated names give a fixed summation order, which makes the
    # floating-point result reproducible from run to run.
    leaves_completed = sorted({leaf.name for leaf in T1_completed.get_leaves()})

    return math.sqrt(squared_distance_sum(T1_completed, T2_completed, leaves_completed))

## RF on k-NCL

In [3]:
from ete3 import Tree

def rf_kncl(T1, T2, k=None):
    """
    Computes the RF distance between two k-NCL completed trees.

    The rooted comparison is tried first; if it raises, the unrooted
    comparison is used instead.

    Parameters:
    - T1, T2: ete3.Tree objects (incomplete trees); copies are completed, the
      arguments themselves are not passed to kNCL
    - k: int, number of neighbors for completion (forwarded to kNCL)

    Returns:
    - rf_kncl_dist: Robinson-Foulds distance (first element of the
      robinson_foulds result), or None if completion or comparison failed
      (the error is printed)
    """
    try:
        # Complete the trees using k-NCL algorithm
        T1_completed, T2_completed = kNCL(T1.copy(), T2.copy(), k)

        try:
            rf_kncl_dist = T1_completed.robinson_foulds(T2_completed, unrooted_trees=False)[0]
        except Exception:
            # Rooted comparison was rejected; retry treating both trees as unrooted.
            rf_kncl_dist = T1_completed.robinson_foulds(T2_completed, unrooted_trees=True)[0]

        return rf_kncl_dist

    except Exception as e:
        print(f"[ERROR] RF computation after k-NCL failed: {e}")
        # One None, matching the single value returned on success.
        return None

## RF(+)

In [4]:
# Please use the RF+ scripts available here: https://github.com/kty1/RFplus
# Source: https://compbio.engr.uconn.edu/software/rf_plus/


## Heatmaps with boxplots

In [None]:
import os
import math
import pickle
from itertools import combinations
from pathlib import Path

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from ete3 import Tree

# ------------------------------------------------------------------
# 0.  User-configurable globals
# ------------------------------------------------------------------
K                = None          # k for k-NCL completion; None for the default value
M                = 25            # number of trees for the smaller figure (more compact version)
TREES_DIR        = Path("Trees") # put your files with tree clusters here

DATASETS         = {             # filename  ->  figure-label
    "amphibians_clusters5.txt": "Amphibians",
    "birds_clusters5.txt"     : "Birds",
    "mammals_clusters5.txt"   : "Mammals",
    "sharks_clusters5.txt"    : "Sharks",
}

# Pickle file holding the computed distance matrices
CACHE_PKL        = Path("distance_matrices_clusters.pkl")

# Heatmap-only figure filenames
FIG_FULL_PDF     = Path("heatmaps_full_proper_clusters.pdf")
FIG_FULL_SVG     = Path("heatmaps_full_proper_clusters.svg")
FIG_M_PDF        = Path(f"heatmaps_first{M}_proper_clusters.pdf")
FIG_M_SVG        = Path(f"heatmaps_first{M}_proper_clusters.svg")

# Heatmaps + boxplots figure filenames
FIG_FULL_BOX_PDF = Path("heatmaps_full_with_boxplots_proper_clusters.pdf")
FIG_FULL_BOX_SVG = Path("heatmaps_full_with_boxplots_proper_clusters.svg")
FIG_M_BOX_PDF    = Path(f"heatmaps_first{M}_with_boxplots_proper_clusters.pdf")
FIG_M_BOX_SVG    = Path(f"heatmaps_first{M}_with_boxplots_proper_clusters.svg")

# colormaps (one per metric row)
# NOTE: these keys are the metric identifiers used for dict lookups when the
# matrices are filled, so they must be typed with a plain ASCII hyphen ("-").
# A typographic hyphen looks the same on screen but is a different key and
# makes the "RF(k-NCL)" lookup fail with KeyError.
CMAPS = {
    "RF(+)"      : "Blues",
    "RF(k-NCL)"  : "Greens",
    "BSD(+)"     : "Oranges",
}

# Display labels for the metric rows; keys must match CMAPS exactly
latex_labels = {
    "RF(+)"     : "RF(+)",
    "RF(k-NCL)" : r"RF($k$-NCL)",
    "BSD(+)"    : r"BSD($k$-NCL)",
}

COLORBAR_TICK_FONTSIZE = 6

# 1. Load trees from a data-file

def load_trees(file_path):
    """
    Parse a text file holding one tree per line.

    Lines are stripped of surrounding whitespace and blank lines are skipped;
    every remaining line is handed to ete3.Tree with format=1.

    Returns a list[ete3.Tree] in file order.
    """
    with open(file_path) as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [Tree(text, format=1) for text in stripped_lines if text]


# 2.  Compute or load the distance matrices

def BSD_from_completed(T1_completed, T2_completed):
    """
    BSD between two already-completed trees.

    The leaf names of T1_completed define the pairs that are compared; the
    result is the square root of squared_distance_sum over those leaves.
    """
    leaf_names = []
    for node in T1_completed.get_leaves():
        leaf_names.append(node.name)
    squared_total = squared_distance_sum(T1_completed, T2_completed, leaf_names)
    return math.sqrt(squared_total)

from math import comb
try:
    from tqdm import tqdm
    _progress = tqdm
except ImportError:
    # tqdm is optional; without it, fall back to a plain-text progress report.
    def _progress(iterable, *, total=None, desc=""):
        """
        Yield the items of `iterable` unchanged, printing a progress line on
        every 50th item and on the item numbered `total`.

        Nothing is printed when `total` is falsy (None or 0).
        """
        count = 0
        for item in iterable:
            count += 1
            at_checkpoint = count % 50 == 0 or count == total
            if total and at_checkpoint:
                print(f"{desc:10s}: {count}/{total} pairs")
            yield item

def compute_distance_matrices():
    """
    Build (or load) a nested dict: matrices[metric][dataset] -> np.ndarray

    The cache file is rewritten after each dataset so that partial work is
    kept. On the next call, datasets already present in the cache for every
    metric are reused and only the missing ones are computed; if nothing is
    missing, no tree file is read and no distance is computed.

    Returns:
    - dict keyed by the metric names of CMAPS, each mapping a dataset label
      (the values of DATASETS) to a symmetric n×n matrix with a zero diagonal
    """
    # Metric names are compared character by character. Map the plain-ASCII
    # spelling of every CMAPS key to the key as actually written, so a
    # typographic hyphen in a key cannot turn the lookups below into KeyErrors.
    dash_variants = dict.fromkeys(map(ord, "\u2010\u2011\u2012\u2013\u2212"), "-")

    def ascii_key(name):
        return name.translate(dash_variants)

    metric_key = {ascii_key(metric): metric for metric in CMAPS}

    matrices = {metric: {} for metric in CMAPS}     # 3 × 4 nested dict

    # ---------------------------------------------------------------
    # 0. cached results
    # ---------------------------------------------------------------
    if CACHE_PKL.exists():
        # NOTE(security): unpickling can execute arbitrary code. Only use a
        # cache file that this script produced on a machine you trust.
        with open(CACHE_PKL, "rb") as fh:
            cached = pickle.load(fh)
        print("[cache]  Using cached results from", CACHE_PKL)

        if isinstance(cached, dict):
            for name, per_dataset in cached.items():
                metric = metric_key.get(ascii_key(name)) if isinstance(name, str) else None
                if metric is not None and isinstance(per_dataset, dict):
                    matrices[metric].update(per_dataset)

    # A dataset counts as done only if every metric has a matrix for it.
    pending = {fname: label for fname, label in DATASETS.items()
               if not all(label in matrices[metric] for metric in CMAPS)}
    if not pending:
        return matrices

    # 1. compute whatever the cache does not cover

    for fname, label in pending.items():
        trees = load_trees(TREES_DIR / fname)
        n     = len(trees)
        print(f"[info]  {label}: {n} trees")

        # empty n×n matrices for this dataset
        mats = {metric: np.zeros((n, n)) for metric in CMAPS}

        for i, j in _progress(combinations(range(n), 2),
                              total=comb(n, 2),
                              desc=label):

            T1, T2 = trees[i], trees[j]

            # one k-NCL completion reused by both RF(k-NCL) and BSD(k-NCL)
            T1_c, T2_c = kNCL(T1.copy(), T2.copy(), K)

            d_rf_plus  = rf_plus_distance(T1, T2)

            try:
                d_rf_kncl  = T1_c.robinson_foulds(T2_c, unrooted_trees=False)[0]
            except Exception:
                # Rooted comparison was rejected; retry as unrooted trees.
                d_rf_kncl  = T1_c.robinson_foulds(T2_c, unrooted_trees=True)[0]

            d_bsd      = BSD_from_completed(T1_c, T2_c)

            # fill symmetric matrices
            for name, value in (("RF(+)"     , d_rf_plus),
                                ("RF(k-NCL)", d_rf_kncl),
                                ("BSD(+)"   , d_bsd)):
                metric = metric_key[name]
                mats[metric][i, j] = mats[metric][j, i] = value

        # store matrices for this dataset
        for metric in CMAPS:
            matrices[metric][label] = mats[metric]

        # incremental cache (the write after the last dataset is the final one)
        with open(CACHE_PKL, "wb") as fh:
            pickle.dump(matrices, fh)
        print(f"[cache]  Saved partial results after {label}")

    print("[cache]  All datasets complete – results saved to", CACHE_PKL)

    return matrices



# 3A.  Heatmap drawing

def draw_heatmaps(matrices, tree_limit=None, outfile_pdf=None, outfile_svg=None):
    """
    Draw a grid of heatmaps: one row per metric in CMAPS, one column per
    dataset in DATASETS, with one colorbar per row next to the last column.

    matrices : dict[metric][dataset] -> ndarray
    tree_limit : int or None – use only the first 'tree_limit' trees
    outfile_pdf, outfile_svg : path or None – where to save the figure; a
        falsy value skips that format

    Returns nothing. The figure is closed before returning and a "[fig]"
    line naming both output arguments is printed in every case.
    """
    import numpy as np

    n_rows = len(CMAPS)
    n_cols = len(DATASETS)
    figsize = (2.8 * n_cols, 2.6 * n_rows)

    fig, axes = plt.subplots(n_rows, n_cols,
                             figsize=figsize,
                             gridspec_kw={"wspace": 0.01, "hspace": 0.2})

    metric_list = list(CMAPS.keys())
    dataset_list = list(DATASETS.items())

    # Create aligned colorbars on the right of each row
    # (each one copies y0/height from the last heatmap axis of its row, as
    # positioned at this point, i.e. before tight_layout is applied below)
    cbar_axes = []
    for row in range(n_rows):
        last_ax = axes[row, -1]
        pos = last_ax.get_position()
        cbar_ax = fig.add_axes([
            pos.x1 + 0.001,
            pos.y0,
            0.012,
            pos.height
        ])
        cbar_axes.append(cbar_ax)

    # Running panel letter: (a), (b), ... assigned row by row, left to right
    letter = ord("a")

    for row, metric in enumerate(metric_list):
        cmap = CMAPS[metric]
        for col, (fname, label) in enumerate(dataset_list):
            ax = axes[row, col]
            mat = matrices[metric][label]
            if tree_limit is not None:
                # keep the top-left tree_limit × tree_limit block
                mat = mat[:tree_limit, :tree_limit]

            # Only the last heatmap of a row draws the colorbar for that row
            show_cbar = (col == n_cols - 1)

            heatmap = sns.heatmap(
                mat,
                ax=ax,
                cmap=cmap,
                cbar=show_cbar,
                cbar_ax=(cbar_axes[row] if show_cbar else None),
                square=True,
                xticklabels=False,
                yticklabels=False
            )

            # Set manual tick labels
            # (1-based tree indices, placed at cell centres: index + 0.5)
            tick_labels = [str(i) for i in range(1, mat.shape[0] + 1)]
            ticks = np.arange(mat.shape[0]) + 0.5

            ax.set_xticks(ticks)
            ax.set_xticklabels(tick_labels, fontsize=3)
            ax.set_yticks(ticks)
            ax.set_yticklabels(tick_labels, fontsize=3)

            ax.tick_params(axis="both", pad=0)

            ax.set_xlabel("Tree index", fontsize=4, labelpad=2)
            ax.set_ylabel("Tree index", fontsize=4, labelpad=2)

            if show_cbar:
                cbar = heatmap.collections[0].colorbar
                cbar.ax.tick_params(labelsize=COLORBAR_TICK_FONTSIZE)

            # Subplot labels: (a), (b), ...
            ax.set_title(f"({chr(letter)})", loc="left", pad=-10, fontsize=10)
            letter += 1

            # Column title (dataset)
            if row == 0:
                ax.set_title(label, loc="center", pad=8, fontsize=11)

            # Metric label on the left side
            if col == 0:
                # NOTE(review): this second set_ylabel omits labelpad, unlike
                # the call above — confirm whether that difference is intended.
                ax.set_ylabel("Tree index", fontsize=4)
                ax.annotate(latex_labels[metric],
                            xy=(0, 0.5),
                            xycoords="axes fraction",
                            rotation=90,
                            va="center",
                            ha="right",
                            fontsize=11,
                            xytext=(-20, 0),
                            textcoords="offset points")

    # Apply tight layout to minimize excess space
    fig.tight_layout(rect=[0, 0, 0.98, 1])

    if outfile_pdf:
        fig.savefig(outfile_pdf, bbox_inches="tight")
    if outfile_svg:
        fig.savefig(outfile_svg, bbox_inches="tight")
    plt.close(fig)
    # Printed unconditionally, so it shows None for an output that was skipped
    print("[fig]  Saved", outfile_pdf, "and", outfile_svg)



# 3B.  Heatmaps + boxplots drawing

from matplotlib import gridspec
from matplotlib.patches import Patch

def _within_between_from_matrix(mat: np.ndarray, cluster_size: int = 5):
    """
    Given an n×n symmetric distance matrix with n ≈ (#clusters * cluster_size),
    return two 1D arrays: distances within clusters and between clusters.
    Clusters are assumed to be contiguous blocks of indices of length cluster_size.
    """
    n = mat.shape[0]
    idx = np.arange(n)
    clusters = [idx[i:i+cluster_size] for i in range(0, n, cluster_size)]

    within = []
    for cl in clusters:
        m = len(cl)
        for a in range(m):
            for b in range(a+1, m):
                within.append(mat[cl[a], cl[b]])

    between = []
    for i in range(len(clusters)):
        for j in range(i+1, len(clusters)):
            for a in clusters[i]:
                for b in clusters[j]:
                    between.append(mat[a, b])

    return np.asarray(within), np.asarray(between)


def draw_heatmaps_with_boxplots(matrices,
                                tree_limit=None,
                                outfile_pdf=None,
                                outfile_svg=None,
                                cluster_size: int = 5,
                                showfliers: bool = False):
    """
    Draws, for every metric row:
      - 4 heatmaps per row (one per dataset/species group)
      - a slim colorbar column
      - a boxed panel with 8 boxplots (within/between for each species group)

    Parameters:
    - matrices: dict[metric][dataset] -> ndarray
    - tree_limit: int or None – use only the first 'tree_limit' trees, for
      the heatmaps and for the boxplot data alike
    - outfile_pdf, outfile_svg: path or None – where to save the figure; a
      falsy value skips that format
    - cluster_size: block length passed to _within_between_from_matrix
    - showfliers: passed to boxplot (whether outlier points are drawn)

    Returns nothing. The figure is closed before returning and a "[fig]"
    line naming both output arguments is printed in every case.
    """
    metric_list   = list(CMAPS.keys())           # rows: RF(+), RF(k-NCL), BSD(k-NCL)
    datasets_list = list(DATASETS.items())       # columns: Amphibians, Birds, Mammals, Sharks

    n_rows = len(metric_list)
    n_heatmap_cols = len(datasets_list)

    # Grid: 4 heatmaps + 1 slim colorbar + 1 spacer + 1 boxplot column
    width_ratios = [1]*n_heatmap_cols + [0.06, 0.15, 1.25]
    hspace = 0.25
    wspace = 0.03

    # Figure size scaled with extra room for colorbar + spacer + boxplots
    base_w = 2.8
    base_h = 2.6
    fig_w  = base_w * sum(width_ratios)
    fig_h  = base_h * n_rows

    fig = plt.figure(figsize=(fig_w, fig_h))
    gs  = gridspec.GridSpec(nrows=n_rows,
                            ncols=n_heatmap_cols + 3,
                            width_ratios=width_ratios,
                            hspace=hspace,
                            wspace=wspace)

    # Running panel letter; heatmaps and the boxplot panel share one sequence
    letter = ord("a")

    for row, metric in enumerate(metric_list):
        cmap_name = CMAPS[metric]
        # colormap object, used below to pick the two boxplot shades
        cmap      = plt.get_cmap(cmap_name)

        # Heatmaps (4 across) + colorbar
        for col, (fname, label) in enumerate(datasets_list):
            ax = fig.add_subplot(gs[row, col])

            mat = matrices[metric][label]
            if tree_limit is not None:
                mat = mat[:tree_limit, :tree_limit]

            # Show colorbar only on last heatmap
            show_cbar = (col == n_heatmap_cols - 1)
            cax = fig.add_subplot(gs[row, n_heatmap_cols]) if show_cbar else None

            hm = sns.heatmap(
                mat,
                ax=ax,
                cmap=cmap_name,
                cbar=show_cbar,
                cbar_ax=cax,
                square=True,
                xticklabels=False,
                yticklabels=False
            )
            if show_cbar:
                hm.collections[0].colorbar.ax.tick_params(labelsize=COLORBAR_TICK_FONTSIZE)

            # Manual tick labels: 1-based tree indices at cell centres
            tick_labels = [str(i) for i in range(1, mat.shape[0] + 1)]
            ticks = np.arange(mat.shape[0]) + 0.5
            ax.set_xticks(ticks)
            ax.set_xticklabels(tick_labels, fontsize=3)
            ax.set_yticks(ticks)
            ax.set_yticklabels(tick_labels, fontsize=3)
            ax.tick_params(axis="both", pad=0)

            ax.set_xlabel("Tree index", fontsize=4, labelpad=2)
            ax.set_ylabel("Tree index", fontsize=4, labelpad=2)

            # Subplot letter
            ax.set_title(f"({chr(letter)})", loc="left", pad=-10, fontsize=10)
            letter += 1

            # Column title on top row
            if row == 0:
                ax.set_title(label, loc="center", pad=8, fontsize=11)

            # Metric label along the far left
            if col == 0:
                ax.set_ylabel("Tree index", fontsize=4)
                ax.annotate(latex_labels[metric],
                            xy=(0, 0.5),
                            xycoords="axes fraction",
                            rotation=90,
                            va="center",
                            ha="right",
                            fontsize=11,
                            xytext=(-20, 0),
                            textcoords="offset points")

        # Empty, invisible axis that only separates colorbar and boxplots
        spacer_ax = fig.add_subplot(gs[row, n_heatmap_cols + 1])
        spacer_ax.axis("off")

        # Boxplot panel
        bx = fig.add_subplot(gs[row, n_heatmap_cols + 2])

        # Prepare data: [A_within, A_between, B_within, B_between, ...]
        box_data = []
        box_colors = []
        species_labels = []
        for _, label in datasets_list:
            mat = matrices[metric][label]
            if tree_limit is not None:
                mat = mat[:tree_limit, :tree_limit]

            within, between = _within_between_from_matrix(mat, cluster_size=cluster_size)
            box_data.extend([within, between])

            # Colors: lighter for 'within', darker for 'between'
            box_colors.extend([cmap(0.35), cmap(0.85)])
            species_labels.append(label)

        positions = np.arange(1, 2*len(species_labels) + 1)  # 1..8
        bp = bx.boxplot(box_data,
                        positions=positions,
                        widths=0.6,
                        patch_artist=True,
                        showfliers=showfliers)

        # Box fill follows box_colors; all outlines and statistic lines are black
        for patch, fc in zip(bp["boxes"], box_colors):
            patch.set_facecolor(fc)
            patch.set_edgecolor("black")
            patch.set_linewidth(0.8)
        for key in ("whiskers", "caps", "medians"):
            for artist in bp[key]:
                artist.set_color("black")
                artist.set_linewidth(0.8)

        # Vertical separators between species groups
        # (x = 2.5, 4.5, ... i.e. halfway between one pair of boxes and the next)
        for g in range(1, len(species_labels)):
            bx.axvline(x=2*g + 0.5, color="0.85", lw=0.8, zorder=0)

        # X-ticks at group centers; species labels under each pair
        centers = [2*i + 1.5 for i in range(len(species_labels))]
        bx.set_xticks(centers)
        bx.set_xticklabels(species_labels, fontsize=7)

        # Boxplot axis styling
        # (no y-label; tick marks and tick labels moved to the right-hand side)
        bx.set_ylabel("")
        bx.yaxis.set_ticks_position('right')
        bx.yaxis.set_label_position('right')
        bx.tick_params(axis='y',
                       labelleft=False,
                       labelright=True,
                       labelsize=COLORBAR_TICK_FONTSIZE)

        # Borders around the panel
        for spine in bx.spines.values():
            spine.set_visible(True)
            spine.set_linewidth(1.0)

        # Half a unit of margin on each side of the first and last box
        bx.set_xlim(0.5, positions[-1] + 0.5)

        # legend in the top-right corner
        leg_handles = [
            Patch(facecolor=cmap(0.35), edgecolor="black", label="within"),
            Patch(facecolor=cmap(0.85), edgecolor="black", label="between")
        ]
        bx.legend(handles=leg_handles,
                  loc="upper right",
                  frameon=True,
                  fontsize=7,
                  labelspacing=0.3,
                  handlelength=0.8,
                  handletextpad=0.4,
                  borderpad=0.3)

        # Subplot label for the boxplot panel
        bx.set_title(f"({chr(letter)})", loc="left", pad=-10, fontsize=10)
        letter += 1

    fig.tight_layout()

    if outfile_pdf:
        fig.savefig(outfile_pdf, bbox_inches="tight")
    if outfile_svg:
        fig.savefig(outfile_svg, bbox_inches="tight")
    plt.close(fig)
    # Printed unconditionally, so it shows None for an output that was skipped
    print("[fig]  Saved", outfile_pdf, "and", outfile_svg)



# 4.  Main

if __name__ == "__main__":
    matrices = compute_distance_matrices()

    # Heatmap-only figures: all trees first, then the first M trees
    for limit, pdf_path, svg_path in ((None, FIG_FULL_PDF, FIG_FULL_SVG),
                                      (M, FIG_M_PDF, FIG_M_SVG)):
        draw_heatmaps(matrices,
                      tree_limit=limit,
                      outfile_pdf=pdf_path,
                      outfile_svg=svg_path)

    # Heatmaps + boxplots: same two variants, clusters of 5, outliers hidden
    for limit, pdf_path, svg_path in ((None, FIG_FULL_BOX_PDF, FIG_FULL_BOX_SVG),
                                      (M, FIG_M_BOX_PDF, FIG_M_BOX_SVG)):
        draw_heatmaps_with_boxplots(matrices,
                                    tree_limit=limit,
                                    outfile_pdf=pdf_path,
                                    outfile_svg=svg_path,
                                    cluster_size=5,
                                    showfliers=False)
