# Discrete phase-type distribution of mutations

In [None]:
# Always import phasic first to set jax backend correctly
import phasic
import numpy as np
np.random.seed(42)
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib_inline.backend_inline import set_matplotlib_formats
set_matplotlib_formats('retina', 'png')
import matplotlib
matplotlib.rcParams['figure.figsize'] = (5, 3.7)
sns.set_context('paper', font_scale=0.9)
phasic.set_theme('dark')

In [None]:
from phasic import Graph
from phasic.state_indexing import Property, StateSpace

# State space for two-locus, two-island model

Sample size:

In [None]:
s: int = 4  # sample size

Define state space for two loci, two populations:

In [None]:
# Three properties: the lineage counts at each locus (max value s) and the
# population label.
state_space = StateSpace([
    Property('descendants_l1', max_value=s),
    Property('descendants_l2', max_value=s),
    Property('population', max_value=2, offset=1),  # populations 1, 2
])

# Total number of encodable states
state_size = state_space.size
print("State size:", state_size)

# Build continuous two-locus, two-island graph

In [None]:
def construct_twolocus_island_graph(s, N1, N2, M1, M2, R):
    """
    Construct a two-locus, two-island coalescent graph.

    A state is a single integer (``state_length=1``) that indexes a
    ``StateSpace`` with the properties ``descendants_l1`` and
    ``descendants_l2`` (both ``max_value=s``) and ``population`` (1 or 2).
    Transitions with a zero rate are not added. For example, with
    ``M1 = M2 = 0`` no population-2 state is ever created.

    Parameters:
    -----------
    s : int
        Sample size
    N1, N2 : float
        Effective population sizes for populations 1 and 2 (coalescence
        rates are divided by these)
    M1, M2 : float
        Migration rates from pop 1 and pop 2
    R : float
        Recombination rate

    Returns:
    --------
    graph : Graph
        Continuous-time graph grown from the starting state via ``callback``
    state_space : StateSpace
        State space used to encode and decode vertex states
    """

    # Define state space
    state_space = StateSpace([
        Property('descendants_l1', max_value=s),
        Property('descendants_l2', max_value=s),
        Property('population', max_value=2, offset=1)
    ])

    def to_state(a, b, p):
        """Encode (a, b, p) as the one-element state vector used by the graph."""
        props = {'descendants_l1': a, 'descendants_l2': b, 'population': p}
        return np.array([state_space.props_to_index(props)])

    def callback(state):
        """Return the (next_state, rate) transitions out of ``state``."""
        props = state_space.index_to_props_dict(state[0])
        a = props['descendants_l1']
        b = props['descendants_l2']
        p = props['population']

        if a + b <= 1:
            return []  # Absorbing state

        N = N1 if p == 1 else N2
        M = M1 if p == 1 else M2
        other_pop = 2 if p == 1 else 1

        transitions = []

        # Coalescence within locus 1
        if a >= 2:
            transitions.append((to_state(a - 1, b, p), a * (a - 1) / 2 / N))

        # Coalescence within locus 2
        if b >= 2:
            transitions.append((to_state(a, b - 1, p), b * (b - 1) / 2 / N))

        # Coalescence between loci
        if a >= 1 and b >= 1:
            transitions.append((to_state(a - 1, b - 1, p), a * b / N))

        # Recombination. It raises both counts by one, so it is only allowed
        # while both stay within max_value=s. Zero-rate edges are skipped.
        if R > 0 and 0 < a < s and 0 < b < s:
            transitions.append((to_state(a + 1, b + 1, p), R))

        # Migration. Here a + b >= 2, so there is always a lineage that can move.
        if M > 0:
            transitions.append((to_state(a, b, other_pop), M))

        return transitions

    # Create starting state (all lineages at locus 1 in population 1).
    # NOTE(review): only recombination increases descendants_l2, and it
    # requires descendants_l2 > 0. From this start, no state with
    # descendants_l2 > 0 is reachable. Confirm the intended starting
    # configuration.
    starting_state = to_state(s, 0, 1)

    graph = Graph(
        state_length=1,
        callback=callback,
        starting_state=starting_state
    )

    return graph, state_space

In [None]:
# Continuous graph for sample size s with N1 = N2 = 1, no migration and
# no recombination.
graph, state_space = construct_twolocus_island_graph(s, N1=1, N2=1, M1=0, M2=0, R=0)
print("Number of vertices:", graph.vertices_length())

# Turn into a discrete version

In [None]:
def make_discrete(graph, state_space, mutation_rate, reward_properties=None):
    """
    Takes a graph for a continuous distribution and turns
    it into a discrete one (in-place). Returns a matrix of
    rewards for computing marginal moments.

    For every non-absorbing vertex except the starting vertex (index 0),
    and for every emitting property whose value is positive, a NEW
    auxiliary vertex is created. The vertex gets an edge to it with
    weight ``mutation_rate * value``, and the auxiliary vertex gets an
    edge back with weight 1. The graph is then normalized.

    Parameters:
    -----------
    graph : Graph
        The continuous-time graph (modified in place)
    state_space : StateSpace
        State space for the model
    mutation_rate : float
        Mutation rate (self-transition rate)
    reward_properties : iterable of str, optional
        Names of the properties that emit mutations. Defaults to every
        property of ``state_space``. With the default, 'population' also
        emits, at a rate proportional to the population label. Pass
        ``['descendants_l1', 'descendants_l2']`` to restrict emission to
        lineage counts.

    Returns:
    --------
    rewards : ndarray
        Reward matrix (n_vertices x n_properties), where n_vertices is the
        vertex count after augmentation. Each auxiliary vertex's row has a 1
        in the column of the property that created it. All other rows are
        zero.
    """
    # Current number of vertices (auxiliary vertices are appended after these)
    vlength = graph.vertices_length()

    property_names = list(state_space.property_names())
    n_properties = len(state_space.properties)
    prop_name_to_idx = {name: k for k, name in enumerate(property_names)}
    emitting = set(property_names if reward_properties is None else reward_properties)

    # Loop-invariant: every auxiliary vertex carries the all-zero state.
    zero_state = np.array([state_space.props_to_index(
        {name: 0 for name in property_names}
    )])

    # Auxiliary vertex index -> column of the property it rewards
    rewarded_columns = {}

    # Loop through all vertices (skip starting vertex at index 0)
    for i in range(1, vlength):
        vertex = graph.vertex_at(i)
        if vertex.rate() <= 0:
            continue  # Absorbing

        props = state_space.index_to_props_dict(int(vertex.state()[0]))
        for prop_name, prop_value in props.items():
            if prop_value <= 0 or prop_name not in emitting:
                continue
            # Every auxiliary vertex shares zero_state. find_or_create_vertex
            # would therefore return one shared vertex for all of them, so a
            # fresh vertex must be created for each (vertex, property) pair.
            mutation_vertex = graph.create_vertex(zero_state)
            mutation_vertex.add_edge(vertex, 1.0)
            vertex.add_edge(mutation_vertex, mutation_rate * prop_value)
            rewarded_columns[mutation_vertex.index()] = prop_name_to_idx[prop_name]

    # Normalize graph
    graph.normalize()

    # Build reward matrix
    rewards = np.zeros((graph.vertices_length(), n_properties))
    for vertex_idx, prop_idx in rewarded_columns.items():
        rewards[vertex_idx, prop_idx] = 1

    return rewards

In [None]:
mutation_rate = 1  # rate of the auxiliary "mutation" self-transitions

# Work on a copy so the continuous graph is left untouched; make_discrete
# adds the auxiliary states, normalizes and returns the reward matrix.
mutation_graph = graph.copy()
rewards = make_discrete(mutation_graph, state_space, mutation_rate)

print("Number of vertices after making discrete:", mutation_graph.vertices_length())
print("Reward matrix shape:", rewards.shape)

# Sanity checks

In [None]:
# Expected total number of mutations. Only the lineage-count columns are
# summed: the 'population' column rewards auxiliary visits whose rate scales
# with the population label, not with a number of lineages.
prop_names = state_space.property_names()
descendant_cols = [prop_names.index('descendants_l1'), prop_names.index('descendants_l2')]
total_rewards = np.sum(rewards[:, descendant_cols], axis=1)
exp_value = mutation_graph.expectation(total_rewards)
print(f"Expectation: {exp_value}")

# Helper functions for computing expectations and covariances

In [None]:
def broadcast_props_to_index(state_space, s, a, b, p):
    """
    Return the state index for descendants_l1=a, descendants_l2=b, population=p.

    ``s`` is not used. It is kept so the argument list mirrors the R
    function this helper mimics.
    """
    return state_space.props_to_index(
        dict(descendants_l1=a, descendants_l2=b, population=p)
    )

In [None]:
def disc_two_locus_expectation(s, N1, N2, M1, M2, R):
    """
    Compute expected ARG branches with i and j descendants at each locus.
    Discrete version.

    Returns an (s+1) x (s+1) matrix. Entry [i, j] is the expectation of a
    reward vector that is non-zero only on vertices whose state is
    (i, j, p) for p in {1, 2}. Each such vertex is rewarded with the sum of
    its locus-1 and locus-2 columns from ``make_discrete``.

    NOTE(review): ``make_discrete`` gives non-zero rewards only to its
    auxiliary vertices, which carry the all-zero state. The vertices
    matched here by (i, j, p) therefore likely carry zero reward for most
    (i, j). Confirm the intended attribution, e.g. rewarding each auxiliary
    vertex by the state of the vertex it returns to.
    """
    # Build continuous graph
    graph, state_space = construct_twolocus_island_graph(s, N1, N2, M1, M2, R)

    # Make discrete
    mutation_graph = graph.copy()
    rewards = make_discrete(mutation_graph, state_space, 1.0)

    # Per-vertex reward: locus-1 plus locus-2 columns
    prop_names = state_space.property_names()
    l1_idx = prop_names.index('descendants_l1')
    l2_idx = prop_names.index('descendants_l2')
    vertex_rewards = rewards[:, l1_idx] + rewards[:, l2_idx]

    # Index vertices by state once, instead of rescanning the whole graph
    # for every (i, j, p) combination.
    n_vertices = mutation_graph.vertices_length()
    vertices_by_state = {}
    for v_idx in range(n_vertices):
        state_idx = int(mutation_graph.vertex_at(v_idx).state()[0])
        vertices_by_state.setdefault(state_idx, []).append(v_idx)

    # Compute expectation matrix
    exp_mat = np.zeros((s + 1, s + 1))
    for i in range(s + 1):
        for j in range(s + 1):
            # Sum over both populations
            reward_vec = np.zeros(n_vertices)
            for p in (1, 2):
                state_idx = broadcast_props_to_index(state_space, s, i, j, p)
                matches = vertices_by_state.get(state_idx, [])
                reward_vec[matches] = vertex_rewards[matches]
            exp_mat[i, j] = mutation_graph.expectation(reward_vec)

    return exp_mat

In [None]:
def disc_ton_covariance_between_loci(s, N1, N2, M1, M2, R):
    """
    Compute covariance between branch lengths at different loci.
    Discrete version.

    Entry [i-1, j-1] of the returned (s-1) x (s-1) matrix is
    ``covariance(r1(i), r2(j))``:
    - r1(i) rewards vertices in state (i, j', p) for j' in 1..s, p in {1, 2};
    - r2(j) rewards vertices in state (i', j, p) for i' in 1..s, p in {1, 2}.
    Each matched vertex gets the sum of its locus-1 and locus-2 columns
    from ``make_discrete``.

    NOTE(review): ``make_discrete`` gives non-zero rewards only to its
    auxiliary vertices (all-zero state). The vertices matched here likely
    carry zero reward. Confirm the intended attribution.
    """
    # Build continuous graph
    graph, state_space = construct_twolocus_island_graph(s, N1, N2, M1, M2, R)

    # Make discrete
    mutation_graph = graph.copy()
    rewards = make_discrete(mutation_graph, state_space, 1.0)

    prop_names = state_space.property_names()
    l1_idx = prop_names.index('descendants_l1')
    l2_idx = prop_names.index('descendants_l2')
    vertex_rewards = rewards[:, l1_idx] + rewards[:, l2_idx]

    # Index vertices by state once instead of scanning the graph per state.
    n_vertices = mutation_graph.vertices_length()
    vertices_by_state = {}
    for v_idx in range(n_vertices):
        state_idx = int(mutation_graph.vertex_at(v_idx).state()[0])
        vertices_by_state.setdefault(state_idx, []).append(v_idx)

    def summed_rewards(states):
        """Add vertex_rewards for every vertex whose state is in ``states``."""
        reward_vec = np.zeros(n_vertices)
        for state_idx in states:
            matches = vertices_by_state.get(state_idx, [])
            reward_vec[matches] += vertex_rewards[matches]
        return reward_vec

    def locus1_rewards(i):
        """Get rewards for locus 1 with i descendants."""
        return summed_rewards(
            broadcast_props_to_index(state_space, s, i, j, p)
            for j in range(1, s + 1) for p in (1, 2)
        )

    def locus2_rewards(j):
        """Get rewards for locus 2 with j descendants."""
        return summed_rewards(
            broadcast_props_to_index(state_space, s, i, j, p)
            for i in range(1, s + 1) for p in (1, 2)
        )

    # Each reward vector depends on a single index, so build them once
    # rather than once per (i, j) pair.
    r1_by_i = [locus1_rewards(i) for i in range(1, s)]
    r2_by_j = [locus2_rewards(j) for j in range(1, s)]

    # Compute covariance matrix
    cov_mat = np.zeros((s - 1, s - 1))
    for row, r1 in enumerate(r1_by_i):
        for col, r2 in enumerate(r2_by_j):
            cov_mat[row, col] = mutation_graph.covariance(r1, r2)

    return cov_mat

# Test covariance computation

No recombination, no migration:

In [None]:
# No recombination (R = 0) and no migration (M1 = M2 = 0), as the variable
# name says. The sample size is the shared `s`, so the heatmap tick labels
# built from `s` match the matrix size.
disc_cov_mat_no_rec_no_mig = disc_ton_covariance_between_loci(s, 1, 1, 0, 0, 0)

In [None]:
fig, ax = plt.subplots(figsize=(7, 5))
ticks = list(range(1, s))
sns.heatmap(
    disc_cov_mat_no_rec_no_mig,
    ax=ax,
    cmap="PiYG",
    center=0,
    annot=True,
    xticklabels=ticks,
    yticklabels=ticks,
)
ax.invert_yaxis()  # put the row for 1 descendant at the bottom
plt.tight_layout()

# Compare to the continuous version

In [None]:
def ton_covariance_between_loci(s, N, M, R):
    """
    Compute covariance between branch lengths at different loci.
    Continuous version.

    Uses a symmetric model (N1 = N2 = N, M1 = M2 = M). Entry [i-1, j-1] of
    the returned (s-1) x (s-1) matrix is ``graph.covariance(r1(i), r2(j))``:
    - r1(i) rewards vertices in state (i, j', p) for j' in 1..s, p in {1, 2};
    - r2(j) rewards vertices in state (i', j, p) for i' in 1..s, p in {1, 2}.
    A matched vertex is rewarded with its total lineage count,
    descendants_l1 + descendants_l2.
    """
    graph, state_space = construct_twolocus_island_graph(s, N, N, M, M, R)

    # Index vertices by state once instead of scanning the graph per state.
    n_vertices = graph.vertices_length()
    vertices_by_state = {}
    for v_idx in range(n_vertices):
        state_idx = int(graph.vertex_at(v_idx).state()[0])
        vertices_by_state.setdefault(state_idx, []).append(v_idx)

    def rewards_for(states):
        """Reward each vertex whose state is in ``states`` by its lineage count."""
        reward_vec = np.zeros(n_vertices)
        for state_idx in states:
            matches = vertices_by_state.get(state_idx)
            if not matches:
                continue
            props = state_space.index_to_props_dict(state_idx)
            reward_vec[matches] = props['descendants_l1'] + props['descendants_l2']
        return reward_vec

    def locus1_rewards(i):
        """Get rewards for locus 1 with i descendants."""
        return rewards_for(
            broadcast_props_to_index(state_space, s, i, j, p)
            for j in range(1, s + 1) for p in (1, 2)
        )

    def locus2_rewards(j):
        """Get rewards for locus 2 with j descendants."""
        return rewards_for(
            broadcast_props_to_index(state_space, s, i, j, p)
            for i in range(1, s + 1) for p in (1, 2)
        )

    # Build each reward vector once rather than once per (i, j) pair.
    r1_by_i = [locus1_rewards(i) for i in range(1, s)]
    r2_by_j = [locus2_rewards(j) for j in range(1, s)]

    # Compute covariance matrix
    cov_mat = np.zeros((s - 1, s - 1))
    for row, r1 in enumerate(r1_by_i):
        for col, r2 in enumerate(r2_by_j):
            cov_mat[row, col] = graph.covariance(r1, r2)

    return cov_mat

In [None]:
# N = 1, no migration (M = 0) and no recombination (R = 0), as the variable
# name says. The shared sample size `s` keeps the tick labels consistent.
cov_mat_no_rec_no_mig = ton_covariance_between_loci(s, 1, 0, 0)

In [None]:
fig, ax = plt.subplots(figsize=(7, 5))
ticks = list(range(1, s))
sns.heatmap(
    cov_mat_no_rec_no_mig,
    ax=ax,
    cmap="PiYG",
    center=0,
    annot=True,
    xticklabels=ticks,
    yticklabels=ticks,
)
ax.invert_yaxis()  # put the row for 1 descendant at the bottom
plt.tight_layout()

# Fit parameters to expectations and covariance

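This assumes the empirical SFS is known exactly: the "observed" data below are computed from the model itself with N1 = N2 = 1, M1 = M2 = 0.5 and R = 1.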

In [None]:
# "Observed" summaries, computed from the model itself with
# N1 = N2 = 1, M1 = M2 = 0.5 and R = 1 for a sample of size s.
s = 4
expected_exp_mat = disc_two_locus_expectation(s, 1, 1, 0.5, 0.5, 1)
expected_cov_mat = disc_ton_covariance_between_loci(s, 1, 1, 0.5, 0.5, 1)

In [None]:
def fit(N, M, R=1):
    """
    Compute fit of model parameters to observed data.

    The model uses N1 = N2 = N, M1 = M2 = M and recombination rate R, with
    sample size taken from the global ``s``. It is compared against the
    global ``expected_exp_mat`` and ``expected_cov_mat``.

    Returns
    -------
    (exp_fit, cov_fit)
        Sum of absolute relative deviations for the expectation and the
        covariance matrices. Entries whose relative deviation is NaN (0/0)
        are ignored. A zero reference entry with a non-zero model value
        gives inf.
    """
    observed_cov_mat = disc_ton_covariance_between_loci(s, N, N, M, M, R)
    observed_exp_mat = disc_two_locus_expectation(s, N, N, M, M, R)

    # Take |.| per entry BEFORE summing. Otherwise over- and under-estimates
    # cancel, and a bad fit can score near zero.
    with np.errstate(divide='ignore', invalid='ignore'):
        cov_fit = np.nansum(np.abs((observed_cov_mat - expected_cov_mat) / expected_cov_mat))
        exp_fit = np.nansum(np.abs((observed_exp_mat - expected_exp_mat) / expected_exp_mat))

    return exp_fit, cov_fit

In [None]:
# Grid search: rows index N, columns index M
n_vals = np.arange(0.5, 2.55, 0.05)
m_vals = np.arange(0.01, 2.06, 0.05)

fit_mat_exp = np.zeros((n_vals.size, m_vals.size))
fit_mat_cov = np.zeros_like(fit_mat_exp)

# np.ndindex walks row-major: every M for the first N, then the next N.
for i, j in np.ndindex(fit_mat_exp.shape):
    fit_mat_exp[i, j], fit_mat_cov[i, j] = fit(n_vals[i], m_vals[j])

# Combined objective with equal weight on both summaries
fit_mat = fit_mat_exp + fit_mat_cov

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
grid_labels = dict(columns=np.round(m_vals, 2), index=np.round(n_vals, 2))

# Left: expectation fit (linear colour scale)
sns.heatmap(pd.DataFrame(fit_mat_exp, **grid_labels), cmap="viridis", ax=ax1)
ax1.set_title('Expectation fit')

# Right: covariance fit (log10 colour scale)
plot_df = pd.DataFrame(fit_mat_cov, **grid_labels)
sns.heatmap(np.log10(plot_df), cmap="viridis", ax=ax2)
ax2.set_title('Covariance fit (log10)')

# Put the smallest N at the bottom of both panels
for panel in (ax1, ax2):
    panel.invert_yaxis()

plt.tight_layout()

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

plot_df = pd.DataFrame(fit_mat, columns=np.round(m_vals, 2), index=np.round(n_vals, 2))

# The same combined fit on a linear (left) and a log10 (right) colour scale
panels = (
    (ax1, plot_df, 'Combined fit'),
    (ax2, np.log10(plot_df), 'Combined fit (log10)'),
)
for panel_ax, panel_data, panel_title in panels:
    sns.heatmap(panel_data, cmap="viridis", ax=panel_ax)
    panel_ax.invert_yaxis()
    panel_ax.set_title(panel_title)

plt.tight_layout()

## Notes

Only a single N and a single M can be estimated (the fit above uses N1 = N2 = N and M1 = M2 = M, with R = 1). Using a recombination rate of 0.5 seems to work.