# Comparison between tree completion (by k-NCL) and tree pruning

BSD(k-NCL) vs. BSD(-)

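Three scenarios are evaluated, based on how each distance ranks the similarity of two input trees, $T_1$ and $T_2$, to a reference tree $T^*$ (the constructed supertree). BSD(k-NCL) compares the trees after k-NCL has completed them onto a common leaf set, whereas BSD(-) compares them on their shared leaves only (tree pruning).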

* **Scenario 1 — disagreement in ordering:**

  $$
  \begin{aligned}
  \text{BSD}(k\text{-NCL})(T_1, T^*) &< \text{BSD}(k\text{-NCL})(T_2, T^*) \\
  \text{BSD}(-)(T_1, T^*) &> \text{BSD}(-)(T_2, T^*)
  \end{aligned}
  \quad \text{or} \quad
  \begin{aligned}
  \text{BSD}(k\text{-NCL})(T_2, T^*) &< \text{BSD}(k\text{-NCL})(T_1, T^*) \\
  \text{BSD}(-)(T_2, T^*) &> \text{BSD}(-)(T_1, T^*)
  \end{aligned}
  $$

* **Scenario 2 — different $k$-NCL distances, same BSD(-):**

  $$
  \begin{aligned}
  \text{BSD}(k\text{-NCL})(T_1, T^*) &\neq \text{BSD}(k\text{-NCL})(T_2, T^*) \\
  \text{BSD}(-)(T_1, T^*) &= \text{BSD}(-)(T_2, T^*)
  \end{aligned}
  $$

* **Scenario 3 — same $k$-NCL distance, different BSD(-):**

  $$
  \begin{aligned}
  \text{BSD}(k\text{-NCL})(T_1, T^*) &= \text{BSD}(k\text{-NCL})(T_2, T^*) \\
  \text{BSD}(-)(T_1, T^*) &\neq \text{BSD}(-)(T_2, T^*)
  \end{aligned}
  $$

The results include a table, line graphs, and violin charts showing the proportion of conflicting tree pairs at each overlap level, where the overlap of two trees is the Jaccard index of their leaf sets.

### k-NCL

In [None]:
# Please use the k-ncl script available on GitHub

"""
This script implements the k-NCL algorithm for completing phylogenetic trees
that are defined on different but overlapping taxon sets. The implementation
uses the ete3 library to work with phylogenetic trees.

Requires:  ete3  (pip install ete3)
"""

#from kncl import kNCL

### BSD(k-NCL) / BSD(+) and BSD(-)

In [None]:
import math
from itertools import combinations

def squared_distance_sum_fast_ete3(t1, t2, leaf_names):
    """Return the sum of squared differences of pairwise leaf distances in two trees.

    For every unordered pair {a, b} of names in *leaf_names*, the path length
    d(a, b) = d(a, root) + d(b, root) - 2 * d(LCA(a, b), root) is computed in
    both trees, and (d_t1(a, b) - d_t2(a, b)) ** 2 is accumulated.

    Parameters
    ----------
    t1, t2 : ete3.Tree
        Trees that must both contain a leaf for every name in *leaf_names*.
    leaf_names : iterable of str
        Leaf names to compare; every unordered pair is visited exactly once.

    Returns
    -------
    number
        The accumulated sum (0 when fewer than two names are given).
    """
    def _root_distances(tree):
        # Leaf-to-root distances are cached once per tree instead of per pair.
        return {leaf.name: leaf.get_distance(tree) for leaf in tree.iter_leaves()}

    def _pair_distance(tree, root_dist, a, b):
        lca = tree.get_common_ancestor(a, b)
        return root_dist[a] + root_dist[b] - 2 * lca.get_distance(tree)

    root_dist_1 = _root_distances(t1)
    root_dist_2 = _root_distances(t2)

    total = 0
    # combinations() yields the same i<j pairs, in the same order, as a
    # manual nested index loop.
    for a, b in combinations(list(leaf_names), 2):
        d1 = _pair_distance(t1, root_dist_1, a, b)
        d2 = _pair_distance(t2, root_dist_2, a, b)
        total += (d1 - d2) ** 2

    return total

def BSD(T1, T2, k=None):
    """Return ``(BSD(+), BSD(-))`` for two trees after k-NCL completion.

    Both trees are first completed with ``kNCL(T1, T2, k)``. BSD(+) (the
    BSD(k-NCL) distance) is the square root of the squared pairwise-distance
    differences over all leaves of the completed first tree; BSD(-) is the same
    quantity restricted to the leaves shared by the original T1 and T2.

    Returns ``(None, None)`` when the two trees share fewer than three leaves
    (kNCL is not called in that case).
    """
    shared = {leaf.name for leaf in T1.iter_leaves()} & {leaf.name for leaf in T2.iter_leaves()}
    if len(shared) < 3:
        return None, None

    completed_1, completed_2 = kNCL(T1, T2, k)
    all_names = {leaf.name for leaf in completed_1.iter_leaves()}

    def _root_of_sum(names):
        return math.sqrt(squared_distance_sum_fast_ete3(completed_1, completed_2, names))

    # Tuple elements evaluate left to right: BSD(+) first, then BSD(-).
    return _root_of_sum(all_names), _root_of_sum(shared)


## Tree completion versus tree pruning comparison

In [None]:
import os
import sys
import pickle
import time
import tempfile
import signal
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from math import comb
from ete3 import Tree

# CONFIG
CHECKPOINT_EVERY = 1     # write a resume checkpoint after every N outer-loop rows (trees)
USE_PARALLEL = True      # set False to disable joblib parallelism
N_JOBS = max((os.cpu_count() or 2) - 1, 1)  # BSD precompute workers: all cores but one, at least 1
REPORT_INTERVAL = 2000   # print progress every N visited tree pairs (0 disables)

# Resume checkpoints live in ./_checkpoints (created on import).
CHECKPOINT_DIR = os.path.join(os.getcwd(), "_checkpoints")
os.makedirs(CHECKPOINT_DIR, exist_ok=True)


# Helpers

def save_atomic(path, obj, retries=10, delay=0.1):
    """
    Pickle *obj* to *path* via a temp file and an atomic rename, with Windows-friendly retries.

    - Pickles into a temp file (``.tmp_ckpt_*.pkl``) in the destination
      directory, flushes and fsyncs it, then moves it over *path* with
      ``os.replace``.
    - ``os.replace`` is retried up to *retries* times on PermissionError (file
      lock), sleeping ``delay * 2**attempt`` seconds between attempts; there is
      no sleep after the last attempt.
    - As a last resort, the existing destination is removed and the move is
      tried once more. If that final move fails, the temp file is kept, because
      it then holds the only copy of the data. Otherwise the temp file is always
      cleaned up on failure, including on SystemExit/KeyboardInterrupt.

    Raises whatever ``pickle.dump`` or the final ``os.remove``/``os.replace`` raise.
    """
    dest = os.path.abspath(path)
    dest_dir = os.path.dirname(dest) or "."
    os.makedirs(dest_dir, exist_ok=True)

    fd, tmp = tempfile.mkstemp(dir=dest_dir, prefix=".tmp_ckpt_", suffix=".pkl")
    keep_tmp = False  # becomes True once the old destination is gone and tmp is the only copy
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)
            f.flush()
            os.fsync(f.fileno())

        for attempt in range(retries):
            try:
                os.replace(tmp, dest)
                return  # success
            except PermissionError:
                if attempt < retries - 1:
                    time.sleep(delay * (2 ** attempt))  # exponential backoff

        # Last resort: remove the destination, then move the temp file into place.
        if os.path.exists(dest):
            os.remove(dest)
            keep_tmp = True
        os.replace(tmp, dest)
    finally:
        # After a successful replace tmp no longer exists. On failure, remove it
        # unless it now holds the only copy of the data.
        if not keep_tmp and os.path.exists(tmp):
            try:
                os.remove(tmp)
            except OSError:
                pass

def load_pickle(path):
    """Return the object stored in the pickle file at *path*.

    Only use on trusted files: unpickling can execute arbitrary code.
    """
    with open(path, mode="rb") as handle:
        loaded = pickle.load(handle)
    return loaded

def load_if_exists(path):
    """Return the unpickled contents of *path*, or None when no such file exists."""
    if not os.path.exists(path):
        return None
    return load_pickle(path)


# Data loading
def load_supertrees(dataset_names):
    """Read one supertree per dataset name, in the order of *dataset_names*.

    Each name must be a key of the name -> file mapping below (otherwise
    KeyError). Only the first line of each file is parsed, as Newick with ete3
    ``format=1``.
    """
    files_by_name = {
        "(a) Amphibians": "supertree_amphibians.txt",
        "(b) Birds": "supertree_birds.txt",
        "(c) Mammals": "supertree_mammals.txt",
        "(d) Sharks": "supertree_sharks.txt"
    }

    def _first_tree(filename):
        with open(filename, 'r') as handle:
            newick = handle.readline().strip()
        return Tree(newick, format=1)

    return [_first_tree(files_by_name[name]) for name in dataset_names]

def load_datasets(filenames):
    """Read each file as a list of ete3 trees, one Newick string per line.

    Blank or whitespace-only lines are skipped. An empty string is not a valid
    Newick tree, so a stray empty line (e.g. a double newline at the end of a
    file) would otherwise abort loading.

    Returns one list of trees per filename, in input order.
    """
    datasets = []
    for filename in filenames:
        with open(filename, 'r') as f:
            newicks = [line.strip() for line in f]
        datasets.append([Tree(nw, format=1) for nw in newicks if nw])
    return datasets


# Core analysis

_MIN_OVERLAP = 0.05  # pairs with leaf-set Jaccard overlap below this are skipped
_MAX_OVERLAP = 0.95  # pairs with overlap at or above this are skipped


def _jaccard(a, b):
    """Return the Jaccard index |a & b| / |a | b| of two sets (0.0 when both are empty)."""
    union = len(a | b)
    return (len(a & b) / union) if union else 0.0


def _count_out_of_range_pairs(leaf_sets, low=_MIN_OVERLAP, high=_MAX_OVERLAP):
    """Return ``(n_low, n_high)``: numbers of i<j pairs with overlap < *low* and >= *high*."""
    from itertools import combinations

    n_low = n_high = 0
    for a, b in combinations(leaf_sets, 2):
        overlap = _jaccard(a, b)
        if overlap < low:
            n_low += 1
        elif overlap >= high:
            n_high += 1
    return n_low, n_high


def _compute_bin_stats(bin_data):
    """Summarise each bin's list of 0/1 conflict indicators.

    Returns a dict of per-bin lists: mean, std (ddof=1; 0 for a single value),
    median, q1/q3 (25th/75th percentiles) and ci_lower/ci_upper (2.5th/97.5th
    percentiles). Empty bins get 0 for every statistic.
    """
    out = {'mean': [], 'std': [], 'median': [], 'q1': [], 'q3': [], 'ci_lower': [], 'ci_upper': []}
    for data in bin_data:
        if data:
            arr = np.array(data)
            out['mean'].append(np.mean(arr))
            out['std'].append(np.std(arr, ddof=1) if len(arr) > 1 else 0)
            out['median'].append(np.median(arr))
            out['q1'].append(np.percentile(arr, 25))
            out['q3'].append(np.percentile(arr, 75))
            ci = np.percentile(arr, [2.5, 97.5])
            out['ci_lower'].append(ci[0])
            out['ci_upper'].append(ci[1])
        else:
            for key in out:
                out[key].append(0)
    return out


def _install_signal_handlers(handler):
    """Install *handler* for SIGINT/SIGTERM where available; return ``{signal: previous_handler}``."""
    previous = {}
    for sig_name in ("SIGINT", "SIGTERM"):
        sig = getattr(signal, sig_name, None)
        if sig is None:
            continue
        try:
            previous[sig] = signal.signal(sig, handler)
        except ValueError:
            # signal.signal() only works in the main thread; run without handlers.
            pass
    return previous


def _restore_signal_handlers(previous):
    """Reinstall the handlers returned by _install_signal_handlers (skipping unknown/None ones)."""
    for sig, old in previous.items():
        if old is None:
            continue
        try:
            signal.signal(sig, old)
        except (ValueError, TypeError):
            pass


def analyze_conflicting_pairs_per_dataset(datasets, dataset_names, supertrees,
                                          checkpoint_every=CHECKPOINT_EVERY,
                                          use_parallel=USE_PARALLEL,
                                          n_jobs=N_JOBS,
                                          report_interval=REPORT_INTERVAL):
    """
    Count BSD(k-NCL) vs. BSD(-) disagreements (scenarios 1-3) for all tree pairs, binned by overlap.

    For each dataset, BSD(tree, supertree) is computed once per tree, optionally
    in parallel with joblib. Every pair i<j whose leaf-set Jaccard overlap lies in
    [0.05, 0.95) is assigned to one of nine overlap bins (centres 0.1 ... 0.9).
    The pair is then tested against the three scenarios, using both trees'
    BSD(+) (the k-NCL distance) and BSD(-) to the supertree.

    Results for each dataset are cached in ``conflict_data_<name>.pkl`` in the
    working directory and reused on later calls. While a dataset is processed, a
    resume checkpoint is written to CHECKPOINT_DIR every *checkpoint_every* rows.
    On SIGINT/SIGTERM:
      - the current row is finished, a consistent checkpoint is written, and the
        process exits with status 1;
      - a second signal exits immediately, and the last row-boundary checkpoint
        stays valid.
    Previously installed signal handlers are restored afterwards.

    Parameters
    ----------
    datasets : list of lists of ete3.Tree
    dataset_names : list of str, parallel to *datasets*
    supertrees : list of ete3.Tree, parallel to *datasets*
    checkpoint_every : int, rows between periodic checkpoints
    use_parallel : bool, use joblib for the BSD precompute
    n_jobs : int, joblib worker count
    report_interval : int, print progress every N visited pairs (0 disables)

    Returns
    -------
    tuple
        (bin_centers, conflict_counts, all_conflict_data, all_stats_scenario1,
         all_stats_scenario2, all_stats_scenario3, all_bin_pair_counts).
        Each list has one entry per dataset.
    """
    # bin setup (linspace makes the number of edges/centres explicit)
    bin_edges   = np.round(np.linspace(_MIN_OVERLAP, _MAX_OVERLAP, 10), 2)  # [0.05, 0.15, ..., 0.95]
    bin_centers = np.round(np.linspace(0.1, 0.9, 9), 2)                    # [0.1, 0.2, ..., 0.9]
    num_bins    = len(bin_edges) - 1

    all_conflict_data = []
    all_stats_scenario1, all_stats_scenario2, all_stats_scenario3 = [], [], []
    conflict_counts = []
    all_bin_pair_counts = []

    def _collect(conflict_data, stats1, stats2, stats3, summary, bin_pair_counts):
        # Append one dataset's results to the per-dataset output lists.
        all_conflict_data.append(conflict_data)
        all_stats_scenario1.append(stats1)
        all_stats_scenario2.append(stats2)
        all_stats_scenario3.append(stats3)
        conflict_counts.append(summary)
        all_bin_pair_counts.append(bin_pair_counts)

    for idx, dataset in enumerate(datasets):
        dataset_name = dataset_names[idx]
        safe_name = dataset_name.replace(" ", "_").replace("(", "").replace(")", "").replace(".", "")
        dataset_file    = f"conflict_data_{safe_name}.pkl"         # final per-dataset results
        checkpoint_file = os.path.join(CHECKPOINT_DIR, f"conflict_data_{safe_name}.checkpoint")

        # If finished before, just load and move on
        if os.path.exists(dataset_file):
            print(f"[{dataset_name}] Loaded saved results from {dataset_file}")
            (conflict_data, stats1, stats2, stats3, summary, bin_pair_counts) = load_pickle(dataset_file)
            _collect(conflict_data, stats1, stats2, stats3, summary, bin_pair_counts)
            continue

        print(f"\n[{dataset_name}] Processing dataset...")

        n = len(dataset)
        total_pairs_expected = n * (n - 1) // 2
        base_tree = supertrees[idx].copy()

        # Precompute leaf sets once per tree for fast Jaccard coefficient calculation
        leaf_sets = [set(t.get_leaf_names()) for t in dataset]

        # Returns (value, status, error): value is None or (int(BSD+), int(BSD-)),
        # status is 'ok' | 'none' | 'exc', error is repr(exception) or None.
        def _bsd_from_newick(tree_newick, base_newick):
            try:
                t = Tree(tree_newick, format=1)
                base = Tree(base_newick, format=1)
                bplus, bminus = BSD(t, base)
                if None in (bplus, bminus):
                    return (None, 'none', None)
                # NOTE(review): int() truncates the BSD values (e.g. 2.2 and 2.8 both
                # become 2), which changes which pairs count as ties in scenarios 2
                # and 3. Kept as-is to preserve existing results; confirm it is intended.
                return ((int(bplus), int(bminus)), 'ok', None)
            except Exception as e:
                return (None, 'exc', repr(e))

        base_newick = base_tree.write(format=1)
        tree_newicks = [t.write(format=1) for t in dataset]

        if use_parallel:
            try:
                from joblib import Parallel, delayed
                print(f"[{dataset_name}] Precomputing BSD for {n} trees using {n_jobs} workers...")
                pairs = Parallel(n_jobs=n_jobs, prefer="processes")(
                    delayed(_bsd_from_newick)(nw, base_newick) for nw in tree_newicks
                )
            except Exception as e:
                print(f"[{dataset_name}] Parallel BSD precompute unavailable ({e}). Falling back to serial.")
                pairs = [_bsd_from_newick(nw, base_newick) for nw in tree_newicks]
        else:
            print(f"[{dataset_name}] Precomputing BSD serially for {n} trees...")
            pairs = [_bsd_from_newick(nw, base_newick) for nw in tree_newicks]

        bsd_cache = [p[0] for p in pairs]     # None or (bplus, bminus)
        statuses  = [p[1] for p in pairs]     # 'ok' | 'none' | 'exc'

        # Surface BSD failures instead of silently skipping every affected pair
        # (e.g. kNCL not imported makes every tree fail).
        failures = [p[2] for p in pairs if p[1] == 'exc']
        if failures:
            print(f"[{dataset_name}] WARNING: BSD raised for {len(failures)}/{n} trees; "
                  f"pairs involving them are skipped. First error: {failures[0]}")

        # Setup accumulators (resume if checkpoint exists)
        ckpt = load_if_exists(checkpoint_file)
        if ckpt is None:
            bin_counts_scenario1 = [[] for _ in range(num_bins)]
            bin_counts_scenario2 = [[] for _ in range(num_bins)]
            bin_counts_scenario3 = [[] for _ in range(num_bins)]
            bin_pair_counts = [0 for _ in range(num_bins)]
            stats = {
                's1': 0, 's2': 0, 's3': 0,
                'bsd_exception_pairs': 0,
                'bsd_none_pairs': 0,
            }
            total_pairs = 0      # number of valid (binned) pairs processed
            progress_count = 0   # count of all i<j pairs visited (for progress display)
            start_i = 0
        else:
            (bin_counts_scenario1, bin_counts_scenario2, bin_counts_scenario3, bin_pair_counts,
             stats, total_pairs, progress_count, start_i) = ckpt

        # The handler only records the request. Writing state here would save
        # mid-row accumulators together with a resume index that replays the
        # current row, double counting its pairs.
        stop = {'requested': False}

        def _on_signal(signum, frame):
            if stop['requested']:
                sys.exit(1)  # second signal: give up now; last checkpoint is consistent
            stop['requested'] = True

        previous_handlers = _install_signal_handlers(_on_signal)
        try:
            # Main pairwise loop
            for i in range(start_i, n):
                leaves_i = leaf_sets[i]
                bsd_i = bsd_cache[i]
                status_i = statuses[i]

                for j in range(i + 1, n):
                    progress_count += 1
                    if report_interval and (progress_count % report_interval == 0):
                        percent = (progress_count / total_pairs_expected) * 100
                        print(f"[{dataset_name}] {progress_count}/{total_pairs_expected} pairs visited ({percent:.1f}%)")

                    overlap = _jaccard(leaves_i, leaf_sets[j])

                    # Apply the overlap filter before considering BSD outcomes
                    if (overlap < _MIN_OVERLAP) or (overlap >= _MAX_OVERLAP):
                        continue

                    bin_index = np.digitize([overlap], bin_edges)[0] - 1
                    if bin_index < 0 or bin_index >= num_bins:
                        continue  # safety

                    # Per-pair failure bookkeeping (only for pairs that passed the overlap filter)
                    status_j = statuses[j]
                    if status_i == 'exc' or status_j == 'exc':
                        stats['bsd_exception_pairs'] += 1
                        continue
                    if status_i == 'none' or status_j == 'none':
                        stats['bsd_none_pairs'] += 1
                        continue

                    plus_i, minus_i = bsd_i
                    plus_j, minus_j = bsd_cache[j]

                    total_pairs += 1
                    bin_pair_counts[bin_index] += 1

                    # Scenario 1: the two distances rank trees i and j in opposite order.
                    hit1 = (minus_i < minus_j and plus_i > plus_j) or (minus_j < minus_i and plus_j > plus_i)
                    # Scenario 2: same BSD(-), different BSD(k-NCL).
                    hit2 = minus_i == minus_j and plus_i != plus_j
                    # Scenario 3: same BSD(k-NCL), different BSD(-).
                    hit3 = minus_i != minus_j and plus_i == plus_j

                    for key, hit, per_bin in (('s1', hit1, bin_counts_scenario1),
                                              ('s2', hit2, bin_counts_scenario2),
                                              ('s3', hit3, bin_counts_scenario3)):
                        if hit:
                            stats[key] += 1
                        per_bin[bin_index].append(1 if hit else 0)

                # Row i is complete: the accumulators are consistent with resume index i + 1.
                rows_done = i + 1
                if rows_done % checkpoint_every == 0 or stop['requested']:
                    save_atomic(checkpoint_file,
                                (bin_counts_scenario1, bin_counts_scenario2, bin_counts_scenario3,
                                 bin_pair_counts, stats, total_pairs, progress_count, rows_done))
                    print(f"[{dataset_name}] Checkpoint saved at i={rows_done}/{n}")
                if stop['requested']:
                    print(f"[{dataset_name}] Interrupted; re-run to resume from {checkpoint_file}")
                    sys.exit(1)
        finally:
            _restore_signal_handlers(previous_handlers)

        print(f"[{dataset_name}] Finished processing {total_pairs} valid tree pairs.")

        # Summaries
        stats1 = _compute_bin_stats(bin_counts_scenario1)
        stats2 = _compute_bin_stats(bin_counts_scenario2)
        stats3 = _compute_bin_stats(bin_counts_scenario3)

        conflict_data = {
            'fractions_scenario1': [np.mean(b) if b else 0 for b in bin_counts_scenario1],
            'fractions_scenario2': [np.mean(b) if b else 0 for b in bin_counts_scenario2],
            'fractions_scenario3': [np.mean(b) if b else 0 for b in bin_counts_scenario3]
        }

        # Pairs excluded by the overlap filter, recomputed from the leaf sets so the
        # counts are also correct after resuming from an older checkpoint.
        low_overlap_pairs, high_overlap_pairs = _count_out_of_range_pairs(leaf_sets)

        summary = {
            'Scenario 1': stats['s1'],
            'Scenario 2': stats['s2'],
            'Scenario 3': stats['s3'],
            'Total Pairs Processed': total_pairs,
            # Key names kept for compatibility: overlap < 0.05 and overlap >= 0.95.
            'Skipped Pairs (Zero Overlap)': low_overlap_pairs,
            'Skipped Pairs (Full Overlap)': high_overlap_pairs,
            'Skipped Pairs (BSD None)': stats['bsd_none_pairs'],
            'Skipped Pairs (BSD Exceptions)': stats['bsd_exception_pairs'],
        }

        # Save final per-dataset result and remove checkpoint
        save_atomic(dataset_file, (conflict_data, stats1, stats2, stats3, summary, bin_pair_counts))
        if os.path.exists(checkpoint_file):
            try:
                os.remove(checkpoint_file)
            except OSError:
                pass
        print(f"[{dataset_name}] Results saved to {dataset_file}\n")

        _collect(conflict_data, stats1, stats2, stats3, summary, bin_pair_counts)

    return bin_centers, conflict_counts, all_conflict_data, all_stats_scenario1, all_stats_scenario2, all_stats_scenario3, all_bin_pair_counts


# Reporting 
def create_conflict_dataframe(conflict_counts, dataset_names):
    """Build a DataFrame (one row per dataset) from the summary dicts, print it transposed, and return it."""
    frame = pd.DataFrame(data=conflict_counts, index=dataset_names)
    print("\nConflict Summary:")
    print(frame.transpose())
    return frame

def plot_conflicts_by_scenario(overlap_ranges, all_conflict_data, all_stats_scenario1, all_stats_scenario2, all_stats_scenario3, dataset_names, all_bin_pair_counts):
    """
    Plot conflict fractions per overlap bin, one line per dataset (group).

    Produces one figure per scenario (``scenario_<k>_conflicts_by_group``) and one
    for the sum of the three scenarios (``all_scenarios_conflicts_by_group2``).
    Each is saved as .svg and .pdf and shown. Each point is annotated with the
    number of pairs in that bin.

    The ``all_stats_scenario*`` arguments are accepted for interface
    compatibility but are not used by these plots.
    """
    group_colors = ['tab:blue', 'tab:green', 'tab:orange', 'tab:red']

    def _color(idx):
        # Cycle the palette so more than four datasets do not raise IndexError.
        return group_colors[idx % len(group_colors)]

    def _draw(series_per_group, ylabel, title, stem):
        # One line per dataset, annotated with the per-bin pair counts.
        plt.figure(figsize=(10, 7))
        for idx, group in enumerate(dataset_names):
            y = series_per_group[idx]
            color = _color(idx)
            plt.plot(overlap_ranges, y, label=group, marker='o', color=color)
            for x, val, cnt in zip(overlap_ranges, y, all_bin_pair_counts[idx]):
                plt.text(x, val, f"{cnt}", color=color, fontsize=9, ha='center', va='bottom')
        plt.xlabel('Overlap ratio', fontsize=14)
        plt.ylabel(ylabel, fontsize=14)
        plt.title(title, fontsize=16)
        plt.xlim(0, 1)
        plt.ylim(0, 1)
        plt.xticks(np.arange(0, 1.1, 0.1))
        plt.grid(True)
        plt.legend(title='Group', fontsize=12)
        plt.tight_layout()
        plt.savefig(f'{stem}.svg')
        plt.savefig(f'{stem}.pdf')
        plt.show()

    scenarios = [
        ('Scenario 1', 'fractions_scenario1'),
        ('Scenario 2', 'fractions_scenario2'),
        ('Scenario 3', 'fractions_scenario3'),
    ]
    for sname, frac_key in scenarios:
        _draw([d[frac_key] for d in all_conflict_data],
              'Fraction of conflicts',
              f"{sname}: Conflict Fractions by Group",
              f'{sname.replace(" ", "_").lower()}_conflicts_by_group')

    totals = [np.array(d['fractions_scenario1']) +
              np.array(d['fractions_scenario2']) +
              np.array(d['fractions_scenario3'])
              for d in all_conflict_data]
    _draw(totals,
          'Total fraction of conflicts',
          'All Scenarios: Total Conflict Fractions by Group',
          'all_scenarios_conflicts_by_group2')


# Main
if __name__ == "__main__":
    # Reuse the aggregated results when present; otherwise run the full analysis
    # (which itself reuses per-dataset results and checkpoints) and cache them.
    data_file = 'conflict_data_per_dataset.pkl'
    dataset_names = ["(a) Amphibians", "(b) Birds", "(c) Mammals", "(d) Sharks"]
    raw_dataset_names = ["amphibians170.txt", "birds100.txt", "mammals140.txt", "sharks100.txt"]

    cached = load_if_exists(data_file)
    if cached is not None:
        (overlap_bins, conflict_counts, all_conflict_data,
         all_stats_scenario1, all_stats_scenario2, all_stats_scenario3, all_bin_pair_counts) = cached
        print(f"Loaded aggregated results from {data_file}")
    else:
        datasets = load_datasets(raw_dataset_names)
        supertrees = load_supertrees(dataset_names)
        (overlap_bins, conflict_counts, all_conflict_data,
         all_stats_scenario1, all_stats_scenario2, all_stats_scenario3, all_bin_pair_counts) = \
            analyze_conflicting_pairs_per_dataset(datasets, dataset_names, supertrees)
        save_atomic(data_file, (overlap_bins, conflict_counts, all_conflict_data,
                                all_stats_scenario1, all_stats_scenario2, all_stats_scenario3, all_bin_pair_counts))
        print(f"Aggregated results saved to {data_file}")

    create_conflict_dataframe(conflict_counts, dataset_names)
    plot_conflicts_by_scenario(
        overlap_bins, all_conflict_data,
        all_stats_scenario1, all_stats_scenario2, all_stats_scenario3,
        dataset_names, all_bin_pair_counts
    )


## Violin charts

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Plot the total conflict fraction (sum of scenarios 1-3) per overlap level:
  (a) one line per dataset, annotated with the number of pairs per bin;
  (b) one violin per overlap bin, built from the datasets' totals in that bin.

Inputs expected in the current working directory:
  - Preferred: conflict_data_per_dataset.pkl
    Tuple: (overlap_bins, conflict_counts, all_conflict_data,
            all_stats_scenario1, all_stats_scenario2, all_stats_scenario3,
            all_bin_pair_counts)
  - Fallback (if the file above is absent) — four per-dataset pickles:
    - conflict_data_a_Amphibians.pkl
    - conflict_data_b_Birds.pkl
    - conflict_data_c_Mammals.pkl
    - conflict_data_d_Sharks.pkl
    Each is a 6-tuple: (conflict_data, stats1, stats2, stats3, summary, bin_pair_counts)
    where conflict_data has keys:
      'fractions_scenario1', 'fractions_scenario2', 'fractions_scenario3'

Outputs (written to the current working directory):
  - all_scenarios_line_plus_violins_median.svg
  - all_scenarios_line_plus_violins_median.pdf
  - and individual panels:
      conflicts_line_top_median.(svg/pdf)
      conflicts_violins_bottom_median.(svg/pdf)
"""

import os
import re
import sys
import pickle
import numpy as np
import matplotlib.pyplot as plt


def _install_numpy_core_aliases():
    """Make NumPy 1.x compatible with pickles saved under NumPy 2.x (numpy._core...)."""
    try:
        import numpy as _np
        if 'numpy._core' not in sys.modules:
            sys.modules['numpy._core'] = _np.core
        for sub in ['multiarray', 'numeric', 'overrides', 'umath']:
            key = f'numpy._core.{sub}'
            if key not in sys.modules and hasattr(_np.core, sub):
                sys.modules[key] = getattr(_np.core, sub)
    except Exception:
        pass

def compat_pickle_load(path):
    """Unpickle *path*, retrying once with NumPy 1.x aliases if ``numpy._core`` is missing.

    Only a ModuleNotFoundError that mentions ``numpy._core`` triggers the retry;
    every other error propagates unchanged.
    """
    def _read():
        with open(path, "rb") as handle:
            return pickle.load(handle)

    try:
        return _read()
    except ModuleNotFoundError as err:
        if 'numpy._core' not in str(err):
            raise
        _install_numpy_core_aliases()
        return _read()

# Data loading
def load_data():
    """
    Load the per-bin conflict fractions produced by the analysis cell.

    Reads ``conflict_data_per_dataset.pkl`` when present; otherwise falls back
    to the four per-dataset pickles. The existence of every fallback file is
    checked before any is loaded, so a missing file fails fast.

    Returns:
      overlap_bins            : np.ndarray of bin centers (0.1..0.9)
      dataset_names           : list of strings in the original label order
      all_conflict_data       : list of dicts (one per dataset) with scenario fraction arrays
      all_bin_pair_counts     : list of lists, per dataset -> per bin counts

    Raises:
      FileNotFoundError if the aggregate file and at least one per-dataset file are missing.
    """
    agg = "conflict_data_per_dataset.pkl"
    dataset_names = ["(a) Amphibians", "(b) Birds", "(c) Mammals", "(d) Sharks"]

    if os.path.exists(agg):
        (overlap_bins, _conflict_counts, all_conflict_data,
         _s1, _s2, _s3, all_bin_pair_counts) = compat_pickle_load(agg)
        return np.asarray(overlap_bins), dataset_names, all_conflict_data, all_bin_pair_counts

    # Fallback to per-dataset files
    parts = [
        "conflict_data_a_Amphibians.pkl",
        "conflict_data_b_Birds.pkl",
        "conflict_data_c_Mammals.pkl",
        "conflict_data_d_Sharks.pkl",
    ]
    missing = next((p for p in parts if not os.path.exists(p)), None)
    if missing is not None:
        raise FileNotFoundError(
            f"Missing '{agg}' and per-dataset file '{missing}'. "
            "Place the pickles in this directory or update the paths."
        )

    all_conflict_data = []
    all_bin_pair_counts = []
    for p in parts:
        # 6-tuple: (conflict_data, stats1, stats2, stats3, summary, bin_pair_counts)
        conflict_data, _stats1, _stats2, _stats3, _summary, bin_pair_counts = compat_pickle_load(p)
        all_conflict_data.append(conflict_data)
        all_bin_pair_counts.append(bin_pair_counts)

    # The per-dataset files do not store the bin centres; use the analysis bins 0.1..0.9.
    overlap_bins = np.round(np.linspace(0.1, 0.9, 9), 2)
    return overlap_bins, dataset_names, all_conflict_data, all_bin_pair_counts

# Helpers

def clean_group_name(name: str) -> str:
    """Strip a leading single-letter panel label such as '(a) ' from a dataset name.

    Any letter is accepted (case-insensitive), so the function also works
    beyond the four datasets (a)-(d). Names without such a prefix are returned
    unchanged.
    """
    return re.sub(r"^\([a-z]\)\s*", "", name, flags=re.IGNORECASE)

def compute_totals(all_conflict_data):
    """Return, per dataset, the element-wise sum of its three scenario fraction arrays.

    The result is a list with one numpy array per dataset, each of length n_bins.
    """
    def _total(record):
        first, second, third = (
            np.array(record[key])
            for key in ('fractions_scenario1', 'fractions_scenario2', 'fractions_scenario3')
        )
        return np.asarray(first + second + third)

    return [_total(record) for record in all_conflict_data]

# Plotting

def plot_top_line(ax, overlap_bins, dataset_names, totals_by_group, all_bin_pair_counts, panel_label="(a)"):
    """
    Draw one line per dataset with the per-bin pair counts overlaid as text.

    Each count is coloured like its line. Dataset names are shown without their
    '(a) ' prefix in a legend placed under the plot.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    overlap_bins : bin centres (x values).
    dataset_names : one name per line.
    totals_by_group : one y-sequence per dataset.
    all_bin_pair_counts : one per-bin count sequence per dataset.
    panel_label : text drawn as the left-aligned panel title (previously ignored).
    """
    for name, y, bin_counts in zip(dataset_names, totals_by_group, all_bin_pair_counts):
        (line,) = ax.plot(overlap_bins, y, marker='o', label=clean_group_name(name))
        colour = line.get_color()
        for x, val, cnt in zip(overlap_bins, y, bin_counts):
            ax.text(x, val, f"{cnt}", color=colour, fontsize=9, ha='center', va='bottom')

    ax.set_xlabel(r"Overlap level $p$")
    ax.set_ylabel("Total fraction of conflicts")
    ax.set_title(panel_label, loc="left", pad=6)

    # NOTE: the y-range is fixed to [0, 0.2]; larger totals fall outside the axes.
    ax.set_xlim(0.0, 1.0)
    ax.set_ylim(0.0, 0.2)
    ax.set_xticks(np.round(np.arange(0.1, 1.0, 0.1), 2))
    ax.set_yticks([0.0, 0.1, 0.2])
    ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.6)

    # Legend under the plot
    ax.legend(
        loc='upper center',
        bbox_to_anchor=(0.5, -0.18),
        ncol=len(dataset_names),
        frameon=False
    )

def plot_bottom_violins(ax, overlap_bins, totals_by_group, panel_label="(b)",
                        jitter=True, show_means=False, show_medians=True,
                        median_color="black", median_linewidth=1.6):
    """
    Draw one violin per overlap bin from the groups' total conflict fractions.

    For each overlap bin k, the values ``totals_by_group[g][k]`` of all groups g
    form a small distribution, drawn as a violin at x = ``overlap_bins[k]``.
    Only the median line is shown by default. The individual group values are
    overlaid as points. Group g uses property-cycle colour "C<g>" in every bin;
    with the default cycle on a fresh axes, this matches the line colours of
    plot_top_line.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    overlap_bins : bin centres (violin positions).
    totals_by_group : one sequence per group with (at least) one value per bin.
    panel_label : text drawn as the left-aligned panel title (previously ignored).
    jitter : accepted for backward compatibility; no horizontal jitter is applied.
    show_means, show_medians : forwarded to ``Axes.violinplot``.
    median_color, median_linewidth : styling of the median line when shown.
    """
    n_bins = len(overlap_bins)
    data_per_bin = [[tg[k] for tg in totals_by_group] for k in range(n_bins)]

    parts = ax.violinplot(
        dataset=data_per_bin,
        positions=overlap_bins,
        showmeans=show_means,
        showmedians=show_medians,
        showextrema=False,
        widths=0.07
    )

    # Style the median line for clarity
    if show_medians and isinstance(parts, dict) and parts.get('cmedians') is not None:
        parts['cmedians'].set_color(median_color)
        parts['cmedians'].set_linewidth(median_linewidth)

    # Scatter each group in one call with a fixed colour. Calling scatter per point
    # without a colour takes the next colour from the cycle every time, so a
    # group's points would change colour from bin to bin.
    for g, tg in enumerate(totals_by_group):
        ax.scatter(overlap_bins, [tg[k] for k in range(n_bins)], s=18, color=f"C{g}")

    ax.set_xlabel(r"Overlap level $p$")
    ax.set_ylabel("Total fraction of conflicts")
    ax.set_title(panel_label, loc="left", pad=6)

    ax.set_xlim(0.0, 1.0)
    ax.set_ylim(0.0, 0.2)
    ax.set_xticks(np.round(np.arange(0.1, 1.0, 0.1), 2))
    ax.set_yticks([0.0, 0.1, 0.2])
    ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.6)

# Main

def main():
    """Load the saved conflict data and save the line + violin figures (combined and per panel).

    Writes, in the working directory:
      - all_scenarios_line_plus_violins_median.(svg/pdf)
      - conflicts_line_top_median.(svg/pdf)
      - conflicts_violins_bottom_median.(svg/pdf)
    """
    overlap_bins, dataset_names, all_conflict_data, all_bin_pair_counts = load_data()
    totals_by_group = compute_totals(all_conflict_data)

    # Median-only violins, shared by the combined figure and the standalone panel.
    violin_opts = dict(
        panel_label="(b)", jitter=True,
        show_means=False, show_medians=True,
        median_color="black", median_linewidth=1.6,
    )

    # Create a single figure with two vertical panels
    fig = plt.figure(figsize=(10, 9))
    gs = fig.add_gridspec(nrows=2, ncols=1, height_ratios=[1, 1.1])

    ax_top = fig.add_subplot(gs[0, 0])
    plot_top_line(ax_top, overlap_bins, dataset_names, totals_by_group, all_bin_pair_counts, panel_label="(a)")

    ax_bottom = fig.add_subplot(gs[1, 0])
    plot_bottom_violins(ax_bottom, overlap_bins, totals_by_group, **violin_opts)

    # Extra bottom margin for the legend placed under the top panel
    fig.subplots_adjust(hspace=0.35, bottom=0.14, top=0.95)

    # Save combined outputs
    out_svg = "all_scenarios_line_plus_violins_median.svg"
    out_pdf = "all_scenarios_line_plus_violins_median.pdf"
    fig.savefig(out_svg, bbox_inches="tight")
    fig.savefig(out_pdf, bbox_inches="tight")

    # Top panel only
    top_stem = "conflicts_line_top_median"
    fig_top, ax_t = plt.subplots(figsize=(10, 4.8))
    plot_top_line(ax_t, overlap_bins, dataset_names, totals_by_group, all_bin_pair_counts, panel_label="(a)")
    fig_top.tight_layout()
    for ext in ("svg", "pdf"):
        fig_top.savefig(f"{top_stem}.{ext}", bbox_inches="tight")
    plt.close(fig_top)

    # Bottom panel only
    bottom_stem = "conflicts_violins_bottom_median"
    fig_bot, ax_b = plt.subplots(figsize=(10, 5.2))
    plot_bottom_violins(ax_b, overlap_bins, totals_by_group, **violin_opts)
    fig_bot.tight_layout()
    for ext in ("svg", "pdf"):
        fig_bot.savefig(f"{bottom_stem}.{ext}", bbox_inches="tight")
    plt.close(fig_bot)

    plt.close(fig)
    # Report the file names that were actually written.
    print(f"Saved:\n  {out_svg}\n  {out_pdf}\n"
          f"  {top_stem}.(svg/pdf)\n  {bottom_stem}.(svg/pdf)")


if __name__ == "__main__":
    main()