# **SciSTree2 Tutorial**

This tutorial includes two examples to help you get started with **SciSTree2**:

- **Example I**: Running SciSTree2 with your own **probabilistic genotype matrix**.
- **Example II**: Running SciSTree2 with **raw read data** as input.


### **Importing Required Packages**

Before running SciSTree2, make sure to import the required libraries:

In [1]:
# import packages
import scistree2 as s2
import numpy as np 
import pandas as pd

### 🧬 **Example I: Toy Genotype Probability Matrix**

In this example, we run **SciSTree2** with **SPR (Subtree Prune and Regraft)** local search on a small toy dataset of 5 cells and 6 SNPs, supplied as a matrix of genotype probabilities.

This format is suitable when you already have probabilistic genotypes derived from upstream processing.

The input should be a `numpy.ndarray` where:
- **Rows** represent **SNPs**
- **Columns** represent **cells**
- Each entry contains the **probability of being wild type** (i.e., the reference allele)


|             | cell1 | cell2 | cell3 | cell4 | cell5 |
|:-----------:|:-----:|:-----:|:-----:|:-----:|:-----:|
| **snp1**    | 0.01  | 0.60  | 0.08  | 0.80  | 0.70  |
| **snp2**    | 0.80  | 0.02  | 0.70  | 0.01  | 0.30  |
| **snp3**    | 0.02  | 0.80  | 0.02  | 0.80  | 0.90  |
| **snp4**    | 0.90  | 0.90  | 0.80  | 0.80  | 0.02  |
| **snp5**    | 0.01  | 0.80  | 0.01  | 0.80  | 0.90  |
| **snp6**    | 0.05  | 0.02  | 0.70  | 0.05  | 0.90  |


In [2]:
prob = np.array([[0.01, 0.6, 0.08, 0.8, 0.7],
                 [0.8, 0.02, 0.7, 0.01, 0.3],
                 [0.02, 0.8, 0.02, 0.8, 0.9],
                 [0.9, 0.9, 0.8, 0.8, 0.02],
                 [0.01, 0.8, 0.01, 0.8, 0.9],
                 [0.05, 0.02, 0.7, 0.05, 0.9]]) 
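
Because every entry must be a wild-type probability, a quick sanity check before inference can catch a matrix that is transposed or that stores mutation probabilities instead. The helper below is a hypothetical sketch using only NumPy (`check_prob_matrix` is not part of SciSTree2, and the values are made up). If an upstream tool reports the probability of the *mutant* allele, taking the complement gives the orientation described above.

```python
import numpy as np

def check_prob_matrix(prob, n_snps, n_cells):
    """Hypothetical helper (not part of SciSTree2): validate a wild-type probability matrix."""
    prob = np.asarray(prob, dtype=float)
    if prob.shape != (n_snps, n_cells):
        raise ValueError(f"expected shape {(n_snps, n_cells)}, got {prob.shape}")
    if np.isnan(prob).any() or (prob < 0).any() or (prob > 1).any():
        raise ValueError("entries must be probabilities in [0, 1]")
    return prob

# Made-up example: an upstream tool reports P(mutant) for 2 SNPs x 3 cells.
p_mutant = np.array([[0.99, 0.40, 0.92],
                     [0.20, 0.98, 0.30]])

# SciSTree2 expects P(wild type), so take the complement before validating.
p_wild_type = check_prob_matrix(1.0 - p_mutant, n_snps=2, n_cells=3)
print(p_wild_type)
```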

Next, we initialize a **SciSTree2** caller with **SPR local search** enabled and set the number of threads to 8.

After calling the `infer` method, SciSTree2 returns:
- The **imputed genotype** (binary matrix)
- The **inferred tree** in **Newick format**
- The corresponding **log-likelihood** of the tree


In [3]:
caller = s2.ScisTree2(threads=8) # use 8 threads
imputed_genotype, tree, likelihood = caller.infer(prob) # run Scistree2 inference
print('Imputed genotype from SPR: \n', imputed_genotype)
print('Newick of the SPR tree: ', tree)
print('Likelihood of the SPR tree: ', likelihood)

Imputed genotype from SPR: 
 [[1 0 1 0 0]
 [0 1 0 1 0]
 [1 0 1 0 0]
 [0 0 0 0 1]
 [1 0 1 0 0]
 [1 1 1 1 0]]
Newick of the SPR tree:  (((1,3),(2,4)),5);
Likelihood of the SPR tree:  -6.27126
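
To see what tree-based imputation adds, it can help to compare the result with a naive call that thresholds each probability independently. The sketch below uses only NumPy. The 0.5 cutoff is an illustrative assumption, not something SciSTree2 does, and `imputed` is copied from the output above.

```python
import numpy as np

prob = np.array([[0.01, 0.6, 0.08, 0.8, 0.7],
                 [0.8, 0.02, 0.7, 0.01, 0.3],
                 [0.02, 0.8, 0.02, 0.8, 0.9],
                 [0.9, 0.9, 0.8, 0.8, 0.02],
                 [0.01, 0.8, 0.01, 0.8, 0.9],
                 [0.05, 0.02, 0.7, 0.05, 0.9]])

# Imputed genotype printed by SciSTree2 above (1 = mutant, 0 = wild type).
imputed = np.array([[1, 0, 1, 0, 0],
                    [0, 1, 0, 1, 0],
                    [1, 0, 1, 0, 0],
                    [0, 0, 0, 0, 1],
                    [1, 0, 1, 0, 0],
                    [1, 1, 1, 1, 0]])

# Naive call (assumption): mutant whenever P(wild type) < 0.5.
naive = (prob < 0.5).astype(int)

# List the entries where the two calls disagree.
diff = np.argwhere(naive != imputed)
for snp, cell in diff:
    print(f"snp{snp + 1}, cell{cell + 1}: P(wild type)={prob[snp, cell]}, "
          f"naive={naive[snp, cell]}, tree-based={imputed[snp, cell]}")
print("Fraction of agreeing entries:", (naive == imputed).mean())
```

On this toy matrix the two calls disagree in two entries, snp2/cell5 (0.3) and snp6/cell3 (0.7), where the per-entry evidence is weak and the tree-based imputation chooses the other genotype.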


We can also replace **SPR** (Subtree Prune and Regraft) local search with **NNI** (Nearest Neighbor Interchange) by setting `nni=True`.

> ℹ️ **Note:** **NNI** typically runs faster than **SPR** but may produce a less accurate tree. It is recommended when a faster approximation is needed, especially for **large datasets**.


In [4]:
caller_nni = s2.ScisTree2(threads=8, nni=True)
imputed_genotype_nni, tree_nni, likelihood_nni = caller_nni.infer(prob)
print('Imputed genotype from NNI: \n', imputed_genotype_nni)
print('Newick of the NNI tree: ', tree_nni)
print('Likelihood of the NNI tree: ', likelihood_nni)

Imputed genotype from NNI: 
 [[1 0 1 0 0]
 [0 1 0 1 0]
 [1 0 1 0 0]
 [0 0 0 0 1]
 [1 0 1 0 0]
 [1 1 1 1 0]]
Newick of the NNI tree:  (((1,3),(2,4)),5);
Likelihood of the NNI tree:  -6.27126


We may also invoke **SciSTree2** with **Neighbor Joining (NJ)** by setting `nj=True` to obtain **only the initial tree**.

In **NJ mode**, no further optimization steps (like SPR or NNI) are performed after constructing the tree.  
As a result, there are **no outputs for the imputed genotype or likelihood** by default.

However, we can still **evaluate the likelihood** and **obtain the imputed genotype** by calling the `evaluate` method with the NJ tree and genotype matrix.

> ℹ️ **Note:** In this toy example, **Neighbor Joining** already recovers the same tree as the SPR and NNI searches.


In [5]:
caller_nj = s2.ScisTree2(threads=8, nj=True)
tree_nj = caller_nj.infer(prob)
imputed_genotype_nj, likelihood_nj = caller_nj.evaluate(prob, tree_nj) # evaluate the NJ tree
print('Imputed genotype from NJ: \n', imputed_genotype_nj)
print('Newick of the NJ tree: ', tree_nj)
print('Likelihood of the NJ tree: ', likelihood_nj)

Imputed genotype from NJ: 
 [[1 0 1 0 0]
 [0 1 0 1 0]
 [1 0 1 0 0]
 [0 0 0 0 1]
 [1 0 1 0 0]
 [1 1 1 1 0]]
Newick of the NJ tree:  (((1,3),(2,4)),5);
Likelihood of the NJ tree:  -6.271255186813891


We can also **evaluate any other tree** against the same genotype probability matrix.
Here we evaluate an arbitrary alternative topology on the same toy data.

As expected, the **log-likelihood** of this alternative tree (−10.84) is **lower** than that of the optimal tree (−6.27).

In [6]:
random_tree = '((((1,2),3),4),5);'
imputed_genotype_random, likelihood_random = caller.evaluate(prob, random_tree)
print('Imputed genotype from the random tree: \n', imputed_genotype_random)
print('Newick of the random tree: ', random_tree)
print('Likelihood of the random tree: ', likelihood_random)

Imputed genotype from the random tree: 
 [[1 1 1 0 0]
 [1 1 1 1 1]
 [1 1 1 0 0]
 [0 0 0 0 1]
 [1 1 1 0 0]
 [1 1 1 1 0]]
Newick of the random tree:  ((((1,2),3),4),5);
Likelihood of the random tree:  -10.835603378281727
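
The two log-likelihoods printed so far can be reproduced by hand from the probability matrix and the imputed genotypes. The sketch below assumes each entry contributes `log(p)` when the imputed genotype is 0 (wild type) and `log(1 - p)` when it is 1 (mutant). This is an interpretation checked against the printed values on this toy example, not a description of SciSTree2's internals, and the helper name `genotype_log_likelihood` is hypothetical.

```python
import numpy as np

prob = np.array([[0.01, 0.6, 0.08, 0.8, 0.7],
                 [0.8, 0.02, 0.7, 0.01, 0.3],
                 [0.02, 0.8, 0.02, 0.8, 0.9],
                 [0.9, 0.9, 0.8, 0.8, 0.02],
                 [0.01, 0.8, 0.01, 0.8, 0.9],
                 [0.05, 0.02, 0.7, 0.05, 0.9]])

# Imputed genotypes copied from the outputs above.
genotype_best = np.array([[1, 0, 1, 0, 0],
                          [0, 1, 0, 1, 0],
                          [1, 0, 1, 0, 0],
                          [0, 0, 0, 0, 1],
                          [1, 0, 1, 0, 0],
                          [1, 1, 1, 1, 0]])
genotype_random = np.array([[1, 1, 1, 0, 0],
                            [1, 1, 1, 1, 1],
                            [1, 1, 1, 0, 0],
                            [0, 0, 0, 0, 1],
                            [1, 1, 1, 0, 0],
                            [1, 1, 1, 1, 0]])

def genotype_log_likelihood(prob, genotype):
    """Hypothetical helper: sum of per-entry natural-log probabilities of a binary genotype."""
    per_entry = np.where(genotype == 0, prob, 1.0 - prob)
    return float(np.log(per_entry).sum())

ll_best = genotype_log_likelihood(prob, genotype_best)
ll_random = genotype_log_likelihood(prob, genotype_random)
print("Tree (((1,3),(2,4)),5):", ll_best)
print("Tree ((((1,2),3),4),5):", ll_random)
```

Both sums agree with the values reported by `infer` and `evaluate` to the printed precision, which suggests that forcing the genotypes to fit a worse tree is what lowers the likelihood.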


SciSTree2 also provides functionality for **tree visualization**.

> ℹ️ **Note:** SciSTree2 includes built-in support for visualizing **moderate-size** trees. For **very large trees**, it is recommended to use specialized external tools such as **FigTree**, **iTOL**, or **ETE Toolkit**.


In [7]:
t = s2.util.from_newick(tree)
t.draw()

           ┌5
 38585c439f┤
           │                     ┌4
           │          ┌1cdf0d681d┤
           │          │          └2
           └3355c1757f┤
                      │          ┌3
                      └158aa854aa┤
                                 └1


### 🧬 **Example II: Toy Raw Reads Data**

In this dataset, the input is still indexed by `(num_sites, num_cells)`.
However, instead of using **precomputed genotype probabilities**, **each entry is a tuple** representing raw sequencing read counts, so the loaded array has shape `(num_sites, num_cells, 2)`.

Each tuple has the form: **`(ref_count, alt_count)`**, where:
- `ref_count` is the number of reads supporting the **reference (wild type)** allele
- `alt_count` is the number of reads supporting the **mutation (alternative)** allele

This format is suitable when you start from **raw read counts** rather than inferred genotype probabilities, enabling SciSTree2 to perform **probabilistic genotype modeling** internally before tree inference.


This dataset, located in the `data` folder, contains **50 cells** and **100 SNPs**.  


| SNP / Cell | Cell 1 | Cell 2 | Cell 3 | Cell 4 | Cell 5 | Cell 6 | Cell 7 | Cell 8 | Cell 9 | Cell 10 |
|------------|--------|--------|--------|--------|--------|--------|--------|--------|--------|---------|
| SNP 1      | (4,1)  | (4,0)  | (5,1)  | (4,0)  | (1,0)  | (4,0)  | (6,0)  | (6,0)  | (5,0)  | (5,0)   |
| SNP 2      | (4,0)  | (3,0)  | (7,1)  | (1,1)  | (3,0)  | (2,0)  | (10,0) | (11,0) | (1,0)  | (9,0)   |
| SNP 3      | (4,5)  | (11,0) | (4,0)  | (3,0)  | (4,0)  | (3,0)  | (7,0)  | (0,0)  | (2,0)  | (4,0)   |
| SNP 4      | (8,0)  | (2,0)  | (6,0)  | (1,0)  | (2,2)  | (3,5)  | (5,3)  | (1,3)  | (2,4)  | (8,0)   |
| SNP 5      | (5,0)  | (9,0)  | (4,0)  | (3,0)  | (4,0)  | (3,0)  | (7,0)  | (4,1)  | (4,0)  | (5,0)   |

> ℹ️ **Note**: Only the first 10 cells and 5 SNPs are shown here for illustration. The full dataset contains 100 SNPs × 50 cells.


In [8]:
reads = np.load('data/toy_raw_reads.npy', allow_pickle=True)
true_genotype = np.loadtxt('data/true_genotype.txt')
with open('data/true_tree.nwk', 'r') as f:
    true_tree = f.readline().strip()
print('Data preview:', reads.shape)
print('Newick of true tree', true_tree)
print('True genotype\n', true_genotype)

Data preview: (100, 50, 2)
Newick of true tree (((((((((15,27),28),12),((17,48),3)),18),31),((((37,7),8),49),2)),(((((11,19),(13,38)),36),10),(((1,35),46),34))),(((((((16,30),45),9),((4,50),29)),(((14,39),(21,22)),43)),(((((24,32),23),42),(20,26)),(((25,44),40),((47,6),41)))),(33,5)));
True genotype
 [[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [1. 1. 1. ... 1. 1. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]


Next, we convert the raw read counts into posterior genotype probabilities with `genotype_probability`.


In [9]:
prob = s2.probability.genotype_probability(reads, posterior=True)
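
For intuition about what such a conversion involves, here is a deliberately simplified model written with NumPy only. It is an assumption for illustration and not the model used by `genotype_probability`: a wild-type cell shows an alternative read only through sequencing error (`err`), a mutant cell is treated as heterozygous so each read is alternative with probability 0.5, and `prior_mut` is a made-up prior. The binomial coefficient is omitted because it cancels in the posterior. The read counts are the SNP 3 row from the preview table above.

```python
import numpy as np

def wild_type_posterior(reads, err=0.01, prior_mut=0.5):
    """Toy model (an assumption, not SciSTree2's): posterior P(wild type | reads).

    reads has shape (..., 2) holding (ref_count, alt_count) pairs.
    """
    reads = np.asarray(reads, dtype=float)
    ref, alt = reads[..., 0], reads[..., 1]
    lik_wt = (1.0 - err) ** ref * err ** alt   # alt reads arise only from errors
    lik_mut = 0.5 ** (ref + alt)               # heterozygous mutant: 50/50 reads
    weighted_wt = lik_wt * (1.0 - prior_mut)
    return weighted_wt / (weighted_wt + lik_mut * prior_mut)

# SNP 3 from the preview table, as an array of shape (1 site, 10 cells, 2 counts).
snp3 = np.array([[(4, 5), (11, 0), (4, 0), (3, 0), (4, 0),
                  (3, 0), (7, 0), (0, 0), (2, 0), (4, 0)]])
posterior = wild_type_posterior(snp3)
print(posterior.shape)
print(np.round(posterior, 4))
```

Under these assumptions, cell 1 with reads `(4,5)` gets a wild-type probability near 0, cells with only reference reads get values close to 1, and cell 8 with no coverage `(0,0)` simply returns the prior.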

We then run inference in three modes: SPR local search, NNI local search, and Neighbor Joining only.


In [10]:
# SPR local search
caller_spr = s2.ScisTree2(threads=8)
imputed_genotype_spr, tree_spr, likelihood_spr = caller_spr.infer(prob)
# NNI local search
caller_nni = s2.ScisTree2(nni=True, threads=8)
imputed_genotype_nni, tree_nni, likelihood_nni = caller_nni.infer(prob)
# NJ
caller_nj = s2.ScisTree2(nj=True)
tree_nj = caller_nj.infer(prob)
imputed_genotype_nj, likelihood_nj = caller_nj.evaluate(prob, tree_nj)

SciSTree2 also provides several metrics for evaluating the results against the ground truth. These metrics include:

- **Genotype accuracy**
- **Tree accuracy** (defined as 1 minus the normalized Robinson-Foulds distance)
- **Ancestor-Descendant Error**
- **Different Lineage Error**


In [11]:
gacc_spr = s2.metric.genotype_accuarcy(true_genotype, imputed_genotype_spr)
gacc_nni = s2.metric.genotype_accuarcy(true_genotype, imputed_genotype_nni)
gacc_nj = s2.metric.genotype_accuarcy(true_genotype, imputed_genotype_nj)

tacc_spr = s2.metric.tree_accuracy(s2.util.from_newick(true_tree), s2.util.from_newick(tree_spr))
tacc_nni = s2.metric.tree_accuracy(s2.util.from_newick(true_tree), s2.util.from_newick(tree_nni))
tacc_nj = s2.metric.tree_accuracy(s2.util.from_newick(true_tree), s2.util.from_newick(tree_nj))

mutation_true = s2.metric.get_ancestor_descendant_pairs(true_genotype)
mutations_spr = s2.metric.get_ancestor_descendant_pairs(imputed_genotype_spr)
mutations_nni = s2.metric.get_ancestor_descendant_pairs(imputed_genotype_nni)
mutations_nj = s2.metric.get_ancestor_descendant_pairs(imputed_genotype_nj)
ad_err_spr = s2.metric.ancestor_descendant_error(mutation_true, mutations_spr)
ad_err_nni = s2.metric.ancestor_descendant_error(mutation_true, mutations_nni)
ad_err_nj = s2.metric.ancestor_descendant_error(mutation_true, mutations_nj)
dl_err_spr = s2.metric.different_lineage_error(mutation_true, mutations_spr)
dl_err_nni = s2.metric.different_lineage_error(mutation_true, mutations_nni)
dl_err_nj = s2.metric.different_lineage_error(mutation_true, mutations_nj)

metrics = {
    "Method": ["SPR", "NNI", "NJ"],
    "Genotype Accuracy": [gacc_spr, gacc_nni, gacc_nj],
    "Tree Accuracy": [tacc_spr, tacc_nni, tacc_nj],
    "Ancestor-Descendant Error": [ad_err_spr, ad_err_nni, ad_err_nj],
    "Different Lineage Error": [dl_err_spr, dl_err_nni, dl_err_nj]
}

# Convert to DataFrame
df_metrics = pd.DataFrame(metrics)
df_metrics

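|   | Method | Genotype Accuracy | Tree Accuracy | Ancestor-Descendant Error | Different Lineage Error |
|:-:|:------:|:-----------------:|:-------------:|:-------------------------:|:-----------------------:|
| 0 | SPR    | 0.9822            | 0.212766      | 0.476311                  | 0.024925                |
| 1 | NNI    | 0.9816            | 0.191489      | 0.478338                  | 0.025922                |
| 2 | NJ     | 0.9794            | 0.170213      | 0.50038                   | 0.025922                |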
