# Isolation-Migration Model

## Setup

In [None]:
# Always import phasic first to set jax backend correctly
import phasic
import numpy as np
# Seed NumPy's global RNG so any random draws in this notebook are reproducible.
np.random.seed(42)
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib_inline.backend_inline import set_matplotlib_formats
# Render inline figures as high-resolution ('retina') PNGs.
set_matplotlib_formats('retina', 'png')
import matplotlib
# Default figure size and seaborn styling for every plot below.
matplotlib.rcParams['figure.figsize'] = (7, 5)
sns.set_context('paper', font_scale=0.9)
sns.set_style('ticks')
# phasic's own theme; presumably affects phasic's built-in plotting output.
phasic.set_theme('dark')

In [None]:
from phasic import Graph
from graphviz import Digraph

## State Space

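Each lineage is classified by the following properties, and a state records how many lineages of each type exist:
1. Number of its sampled descendants from population 1 (`i`)
2. Number of its sampled descendants from population 2 (`j`)
3. The population the lineage is currently in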

The state vector is organized as a flattened matrix with dimensions:
- `matrix_size = max(n1, n2) + 1`
- `state_length = matrix_size * matrix_size * 2` (for 2 populations)

Index calculation: `index = (matrix_size^2 * population) + i + j * matrix_size`

In [None]:
def get_matrix_index(i, j, population, matrix_size):
    """Map a lineage type ``(i, j, population)`` to its slot in the state vector.

    The state vector stacks one ``matrix_size x matrix_size`` block per
    population.  Inside a block, ``i`` varies fastest and ``j`` selects a
    run of ``matrix_size`` consecutive slots.

    Parameters
    ----------
    i, j : int
        Sampled descendants from population 1 and population 2.
    population : int
        Population (0 or 1) the lineage currently resides in.
    matrix_size : int
        Side length of each per-population block.

    Returns
    -------
    int
        ``matrix_size**2 * population + i + j * matrix_size``.
    """
    block_offset = matrix_size * matrix_size * population
    row_offset = j * matrix_size
    return block_offset + i + row_offset

def construct_im_graph(n1, n2, p1_size, p2_size, m12, m21):
    """
    Construct the phase-type graph for the isolation-with-migration (IM) stage.

    Each lineage is typed by ``(i, j, population)``: ``i`` and ``j`` are its
    numbers of sampled descendants from population 1 and population 2, and
    ``population`` (0 or 1) is where it currently is.  A state counts the
    lineages of each type, laid out as described by ``get_matrix_index``.

    Parameters
    ----------
    n1 : int
        Number of samples from population 1 (non-negative).
    n2 : int
        Number of samples from population 2 (non-negative).
    p1_size : float
        Size of population 1 (in coalescent units); must be positive.
    p2_size : float
        Size of population 2 (in coalescent units); must be positive.
    m12 : float
        Migration rate from population 1 to 2; must be non-negative.  In the
        graph, a lineage currently in population 2 moves to population 1 at
        rate ``m12`` per lineage (consistent with a forward-time reading).
    m21 : float
        Migration rate from population 2 to 1; must be non-negative.  A
        lineage currently in population 1 moves to population 2 at rate
        ``m21`` per lineage.

    Returns
    -------
    graph : Graph
        Phase-type graph for the IM stage.
    matrix_size : int
        ``max(n1, n2) + 1``; needed to index the state vector.

    Raises
    ------
    ValueError
        If a sample size is negative or both are zero, a population size is
        not positive, or a migration rate is negative.
    """
    if n1 < 0 or n2 < 0 or n1 + n2 == 0:
        raise ValueError("n1 and n2 must be non-negative and not both zero")
    if p1_size <= 0 or p2_size <= 0:
        raise ValueError("population sizes must be positive")
    if m12 < 0 or m21 < 0:
        raise ValueError("migration rates must be non-negative")

    matrix_size = max(n1, n2) + 1
    state_length = matrix_size * matrix_size * 2
    pop_sizes = (p1_size, p2_size)
    # Per-lineage rate at which a lineage in population p moves to 1 - p.
    migration_rates = (m21, m12)

    def im_callback(state):
        # A single remaining lineage is the MRCA: no outgoing transitions.
        if np.sum(state) <= 1:
            return []

        edges = []

        # Migration: move one lineage of type (i, j) to the other population.
        for i in range(matrix_size):
            for j in range(matrix_size):
                for population in range(2):
                    src = get_matrix_index(i, j, population, matrix_size)
                    count = state[src]
                    if count == 0:
                        continue
                    dst = get_matrix_index(i, j, 1 - population, matrix_size)
                    child_state = state.copy()
                    child_state[src] -= 1
                    child_state[dst] += 1
                    edges.append((child_state, count * migration_rates[population]))

        # Coalescence of two lineages within the same population.
        for population in range(2):
            pop_size = pop_sizes[population]

            for i in range(matrix_size):
                for j in range(matrix_size):
                    index1 = get_matrix_index(i, j, population, matrix_size)
                    entry1 = state[index1]
                    if entry1 == 0:
                        continue

                    for i2 in range(matrix_size):
                        for j2 in range(matrix_size):
                            # Visit each unordered pair of types once: only
                            # pair (i, j) with types at or below it in linear
                            # order.
                            if i + j * matrix_size < i2 + j2 * matrix_size:
                                continue

                            index2 = get_matrix_index(i2, j2, population, matrix_size)
                            entry2 = state[index2]
                            if entry2 == 0:
                                continue

                            if index1 == index2:
                                # Two lineages of the same type: C(entry1, 2) pairs.
                                if entry1 == 1:
                                    continue
                                weight = entry1 * (entry1 - 1) / 2 / pop_size
                            else:
                                weight = entry1 * entry2 / pop_size

                            child_state = state.copy()
                            child_state[index1] -= 1
                            child_state[index2] -= 1
                            merged = get_matrix_index(i + i2, j + j2, population, matrix_size)
                            child_state[merged] += 1
                            edges.append((child_state, weight))

        return edges

    # Initial state: n1 singleton lineages (1, 0) in population 1 and
    # n2 singleton lineages (0, 1) in population 2.
    initial_state = np.zeros(state_length, dtype=int)
    initial_state[get_matrix_index(1, 0, 0, matrix_size)] = n1
    initial_state[get_matrix_index(0, 1, 1, matrix_size)] = n2

    graph = Graph(
        state_length=state_length,
        callback=im_callback,
        initial_state=initial_state
    )

    return graph, matrix_size

In [None]:
# Build IM graph with sample parameters: 4 samples per population, equal
# population sizes and equal migration rates in both directions.
n1, n2 = 4, 4
p1_size, p2_size = 1.0, 1.0
m12, m21 = 0.1, 0.1

# construct_im_graph also returns matrix_size, which is needed to index states.
im_graph, matrix_size = construct_im_graph(n1, n2, p1_size, p2_size, m12, m21)
print(f"Number of vertices: {im_graph.vertices_length()}")

## Visualize Graph Structure

In [None]:
def plot_graph(graph, constrained=True, size='10'):
    """
    Plot graph structure using Graphviz.

    Vertex 0 (the start vertex) is drawn as the node ``'S'``.  Every other
    vertex becomes a node labelled with its state vector.  Edges leaving the
    start vertex are drawn unlabelled; all other edges are labelled with
    their weight rounded to two decimals.

    Parameters
    ----------
    graph : Graph
        Phase-type graph to draw.
    constrained : bool, optional
        Value of Graphviz's ``constraint`` attribute on the weighted edges;
        ``False`` lets those edges ignore the rank layout.
    size : str, optional
        Graphviz ``size`` graph attribute.

    Returns
    -------
    Digraph
        The diagram; a notebook renders it when it is a cell's last expression.
    """
    states = []
    start_children = []
    weighted_edges = []

    # Extract graph structure, keeping the start vertex's edges separate.
    for i in range(graph.vertices_length()):
        vertex = graph.vertex_at(i)
        states.append(vertex.state)
        for edge in vertex.edges:
            if i == 0:
                start_children.append(edge.child.index)
            else:
                weighted_edges.append((i, edge.child.index, edge.weight))

    constrained_str = 'true' if constrained else 'false'

    dot = Digraph()
    dot.node('S', 'S')

    for i in range(1, len(states)):
        dot.node(str(i), str(states[i]))

    # Connect S to the start vertex's actual children rather than assuming
    # vertex 1; keep the old S -> 1 edge if the start vertex reports none.
    for child in start_children or [1]:
        dot.edge('S', str(child))

    for p, c, r in weighted_edges:
        dot.edge(str(p), str(c), constraint=constrained_str, label=f"{r:.2f}")

    dot.graph_attr['size'] = size
    return dot

# Visualize a small graph (2 samples per population). The Digraph returned
# by the last expression is rendered as the cell output.
small_graph, _ = construct_im_graph(2, 2, 1.0, 1.0, 0.1, 0.1)
plot_graph(small_graph)

## Numerical Accuracy Check

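Compute expected branch lengths for an "infinite" IM stage (no split time) with the graph's `expectation` algorithm; these values can be compared against expectations obtained by other methods.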

In [None]:
def compute_rewards_at(state_length, matrix_size, i, j, n1, n2):
    """
    Build a state-vector weight that selects lineages of type ``(i, j)``.

    The returned vector has one entry per position of the state vector (not
    per graph vertex).  It is 1 at the two positions that count lineages with
    ``i`` sampled descendants from population 1 and ``j`` sampled descendants
    from population 2 (one position per population), and 0 elsewhere.  The
    dot product of a vertex's state with this vector is therefore the number
    of such lineages in that vertex.

    Parameters
    ----------
    state_length : int
        Length of the state vector.
    matrix_size : int
        Side length of each per-population block of the state vector.
    i, j : int
        Descendant counts; each must lie in ``[0, matrix_size)``.
    n1, n2 : int
        Sample sizes; currently unused, kept for signature compatibility.

    Returns
    -------
    numpy.ndarray
        Float vector of length ``state_length``.

    Raises
    ------
    ValueError
        If ``i`` or ``j`` is outside ``[0, matrix_size)``; otherwise the
        flattened index would silently alias a different lineage type.
    """
    if not (0 <= i < matrix_size and 0 <= j < matrix_size):
        raise ValueError(
            f"(i, j) = ({i}, {j}) out of range for matrix_size={matrix_size}"
        )

    rewards = np.zeros(state_length)
    for population in range(2):
        rewards[get_matrix_index(i, j, population, matrix_size)] = 1

    return rewards

# Compute expectations using graph algorithms.
n1, n2 = 4, 4
im_graph, matrix_size = construct_im_graph(n1, n2, 1.0, 1.0, 0.1, 0.1)
state_length = im_graph.state_length()

# Graph rewards are one per vertex, whereas compute_rewards_at returns a
# weight over state-vector positions.  Multiplying the matrix of vertex
# states by that weight gives the number of (i, j) lineages in each vertex.
im_vertex_states = np.array(
    [im_graph.vertex_at(k).state for k in range(im_graph.vertices_length())]
)

algorithm_expectation = np.zeros((n1+1, n2+1))

for i in range(n1+1):
    for j in range(n2+1):
        weights = compute_rewards_at(state_length, matrix_size, i, j, n1, n2)
        rewards = im_vertex_states @ weights
        algorithm_expectation[i, j] = im_graph.expectation(rewards)

print("Computed expectations for all (i,j) configurations")

In [None]:
plt.figure(figsize=(7, 5))
ax = sns.heatmap(algorithm_expectation, cmap="viridis_r", annot=True, fmt=".3f")
# Rows index i (population-1 descendants); flip so i = 0 is at the bottom.
ax.invert_yaxis()
plt.xlabel('Descendants among population 2 samples')
plt.ylabel('Descendants among population 1 samples')
plt.title('Expected branch length for IM stage')
sns.despine()
plt.tight_layout()

## Complete IM Model with Split Time and Ancestral Population

The full isolation-migration model consists of:
1. An IM stage with two populations, migration, and a finite duration (split time)
2. An ancestral panmictic population stage

We need to compute:
- Expected branch lengths during the IM stage (up to split time)
- Probability distribution of states at the split time
- Expected branch lengths in the ancestral population

In [None]:
def construct_ancestral_graph(n1, n2, pa_size):
    """
    Construct the phase-type graph for the panmictic ancestral population.

    States use the same two-population layout as ``construct_im_graph``
    (length ``2 * matrix_size**2``), so IM states can be mapped onto them.
    Only the population-0 block is ever occupied, and the only events are
    pairwise coalescences.

    Parameters
    ----------
    n1 : int
        Number of samples from population 1 (non-negative).
    n2 : int
        Number of samples from population 2 (non-negative).
    pa_size : float
        Size of ancestral population (in coalescent units); must be positive.

    Returns
    -------
    graph : Graph
        Phase-type graph for the ancestral population.
    matrix_size : int
        ``max(n1, n2) + 1``; needed to index the state vector.

    Raises
    ------
    ValueError
        If a sample size is negative or both are zero, or ``pa_size`` is
        not positive.
    """
    if n1 < 0 or n2 < 0 or n1 + n2 == 0:
        raise ValueError("n1 and n2 must be non-negative and not both zero")
    if pa_size <= 0:
        raise ValueError("ancestral population size must be positive")

    matrix_size = max(n1, n2) + 1
    state_length = matrix_size * matrix_size * 2

    def ancestral_callback(state):
        # A single remaining lineage is the MRCA: no outgoing transitions.
        if np.sum(state) <= 1:
            return []

        edges = []

        # Only coalescence, within the single (ancestral) population 0.
        for i in range(matrix_size):
            for j in range(matrix_size):
                index1 = get_matrix_index(i, j, 0, matrix_size)
                entry1 = state[index1]
                if entry1 == 0:
                    continue

                for i2 in range(matrix_size):
                    for j2 in range(matrix_size):
                        # Visit each unordered pair of types once.
                        if i + j * matrix_size < i2 + j2 * matrix_size:
                            continue

                        index2 = get_matrix_index(i2, j2, 0, matrix_size)
                        entry2 = state[index2]
                        if entry2 == 0:
                            continue

                        if index1 == index2:
                            # Two lineages of the same type: C(entry1, 2) pairs.
                            if entry1 == 1:
                                continue
                            weight = entry1 * (entry1 - 1) / 2 / pa_size
                        else:
                            weight = entry1 * entry2 / pa_size

                        child_state = state.copy()
                        child_state[index1] -= 1
                        child_state[index2] -= 1
                        merged = get_matrix_index(i + i2, j + j2, 0, matrix_size)
                        child_state[merged] += 1
                        edges.append((child_state, weight))

        return edges

    # Initial state has all lineages in population 0 (ancestral).
    initial_state = np.zeros(state_length, dtype=int)
    initial_state[get_matrix_index(1, 0, 0, matrix_size)] = n1
    initial_state[get_matrix_index(0, 1, 0, matrix_size)] = n2

    graph = Graph(
        state_length=state_length,
        callback=ancestral_callback,
        initial_state=initial_state
    )

    return graph, matrix_size

# Build a small ancestral graph (4 samples per population, unit population
# size) and report its number of vertices.
a_graph, a_matrix_size = construct_ancestral_graph(4, 4, 1.0)
print(f"Ancestral graph vertices: {a_graph.vertices_length()}")

In [None]:
# Full IM model parameters
n1, n2 = 10, 10
# Strongly asymmetric migration rates (see construct_im_graph for the
# direction in which each rate moves lineages).
m12, m21 = 0.005, 2.0
# Sizes of population 1, population 2 and the ancestral population
# (coalescent units).
p1, p2 = 2.0, 1.0
pa = 4.0
split_t = 3.0  # Time in coalescent units

print(f"Building IM graph with n1={n1}, n2={n2}...")
im_graph, matrix_size = construct_im_graph(n1, n2, p1, p2, m12, m21)
print(f"IM graph vertices: {im_graph.vertices_length()}")

# Compute accumulated visiting time up to split_t.
# NOTE(review): presumably the expected time spent in each IM vertex during
# [0, split_t], one entry per vertex. These are durations, not the
# distribution of states at split_t.
print(f"\nComputing accumulated visiting time up to t={split_t}...")
im_expected_visits = im_graph.accumulated_visiting_time(split_t)

print(f"\nBuilding ancestral graph...")
a_graph, a_matrix_size = construct_ancestral_graph(n1, n2, pa)
print(f"Ancestral graph vertices: {a_graph.vertices_length()}")

In [None]:
def compute_start_prob_from_im(a_graph, im_graph, im_stopping, matrix_size):
    """
    Compute starting probabilities for the ancestral graph from the
    distribution of IM states at the split time.

    Each IM vertex's state is collapsed onto the ancestral layout by adding
    the population-1 block of the state vector onto the population-0 block,
    so that lineages of the same type ``(i, j)`` merge into one ancestral
    count.  The value of every IM vertex is then added to the ancestral
    vertex with that collapsed state.

    Parameters
    ----------
    a_graph : Graph
        Ancestral graph (see ``construct_ancestral_graph``).
    im_graph : Graph
        IM graph (see ``construct_im_graph``).
    im_stopping : sequence of float
        One value per IM vertex, indexed by vertex index.  This should be
        the probability of being in each vertex at the split time.
    matrix_size : int
        Side length of each per-population block of the state vector.

    Returns
    -------
    numpy.ndarray
        One starting probability per ancestral vertex.

    Warns
    -----
    UserWarning
        If IM vertices carrying non-zero mass have no matching ancestral
        vertex.  Their mass is dropped, so the result sums to less than the
        input.
    """
    import warnings

    state_length = im_graph.state_length()
    block = matrix_size * matrix_size
    start_prob = np.zeros(a_graph.vertices_length())

    # Map each ancestral state vector to its vertex index.
    a_state_to_index = {
        tuple(a_graph.vertex_at(idx).state): idx
        for idx in range(a_graph.vertices_length())
    }

    unmatched_mass = 0.0
    # Vertex 0 is skipped (drawn as the start vertex 'S' elsewhere in this file).
    for k in range(1, im_graph.vertices_length()):
        im_state = np.asarray(im_graph.vertex_at(k).state)

        # get_matrix_index(i, j, 1, m) == block + get_matrix_index(i, j, 0, m),
        # so collapsing populations is an element-wise sum of the two blocks.
        a_state = np.zeros(state_length, dtype=int)
        a_state[:block] = im_state[:block] + im_state[block:2 * block]

        a_index = a_state_to_index.get(tuple(a_state))
        if a_index is None:
            unmatched_mass += im_stopping[k]
        else:
            start_prob[a_index] += im_stopping[k]

    if unmatched_mass > 0:
        warnings.warn(
            f"{unmatched_mass:.6g} probability mass from IM states has no "
            "matching ancestral vertex and was dropped"
        )

    return start_prob

# Compute starting probabilities for ancestral stage.
# The ancestral stage starts from the distribution of IM states *at* the
# split time, i.e. the probability of being in each IM vertex at split_t.
# im_expected_visits holds accumulated visiting times, which are expected
# durations rather than probabilities, so it must not be used here.
print("Computing starting probabilities for ancestral population...")
im_stopping = im_graph.stop_probability(split_t)
start_prob = compute_start_prob_from_im(a_graph, im_graph, im_stopping, matrix_size)
print(f"Starting probability mass: {np.sum(start_prob):.6f}")

In [None]:
# Compute expectations for both stages.
# Graph quantities are one per vertex, whereas compute_rewards_at returns a
# weight over state-vector positions.  Multiplying the matrix of vertex states
# by that weight gives the number of (i, j) lineages in each vertex.
print("Computing expectations for IM stage...")
im_expectation = np.zeros((n1+1, n2+1))
state_length = im_graph.state_length()
im_vertex_states = np.array(
    [im_graph.vertex_at(k).state for k in range(im_graph.vertices_length())]
)
im_visits = np.asarray(im_expected_visits)

for i in range(n1+1):
    for j in range(n2+1):
        weights = compute_rewards_at(state_length, matrix_size, i, j, n1, n2)
        rewards = im_vertex_states @ weights
        # Reward-weighted time spent in the IM stage up to split_t.
        im_expectation[i, j] = np.sum(im_visits * rewards)

print("Computing expectations for ancestral stage...")
a_expectation = np.zeros((n1+1, n2+1))
# The ancestral graph uses the same state length as the IM graph.
a_vertex_states = np.array(
    [a_graph.vertex_at(k).state for k in range(a_graph.vertices_length())]
)

for i in range(n1+1):
    for j in range(n2+1):
        weights = compute_rewards_at(state_length, a_matrix_size, i, j, n1, n2)
        rewards = a_vertex_states @ weights
        # Expected reward-weighted time to absorption from each ancestral
        # vertex, averaged over the starting distribution.
        expected_waiting = a_graph.expected_waiting_time(rewards)
        a_expectation[i, j] = np.sum(start_prob * np.asarray(expected_waiting))

print("Done!")

## Expectation for the IM Stage (up to the split time)

In [None]:
plt.figure(figsize=(7, 5))
ax = sns.heatmap(im_expectation, cmap="viridis_r")
# Rows index i (population-1 descendants); flip so i = 0 is at the bottom.
ax.invert_yaxis()
plt.xlabel('Descendants among population 2 samples')
plt.ylabel('Descendants among population 1 samples')
plt.title('Expected branch length: IM stage')
sns.despine()
plt.tight_layout()
plt.savefig('im_stage.pdf', bbox_inches='tight')

## Expectation for Panmictic Ancestral Population

In [None]:
plt.figure(figsize=(7, 5))
ax = sns.heatmap(a_expectation, cmap="viridis_r")
# Rows index i (population-1 descendants); flip so i = 0 is at the bottom.
ax.invert_yaxis()
plt.xlabel('Descendants among population 2 samples')
plt.ylabel('Descendants among population 1 samples')
plt.title('Expected branch length: Ancestral stage')
sns.despine()
plt.tight_layout()
plt.savefig('a_stage.pdf', bbox_inches='tight')

## Expectation for Combined IM Model with Split Time

In [None]:
plt.figure(figsize=(7, 5))
# Total expected branch length = IM stage (up to split_t) + ancestral stage.
ax = sns.heatmap(im_expectation + a_expectation, cmap="viridis_r",
                 linewidths=1, linecolor='white')
ax.invert_yaxis()
plt.xlabel('Descendants among population 2 samples')
plt.ylabel('Descendants among population 1 samples')
plt.title(f'Combined IM model (m12={m12}, m21={m21}, split_t={split_t})')
sns.despine()
plt.tight_layout()
plt.savefig(f"im_{n1}_{n2}_{m12}_{m21}_{split_t}.pdf", bbox_inches='tight')
plt.savefig(f"im_{n1}_{n2}_{m12}_{m21}_{split_t}.png", bbox_inches='tight', dpi=300)

In [None]:
plt.figure(figsize=(13, 10))
# Same combined heatmap as above, with each cell's value printed.
ax = sns.heatmap(im_expectation + a_expectation, cmap="viridis_r",
                 annot=True, fmt=".2f", linewidths=1, linecolor='white')
ax.invert_yaxis()
plt.xlabel('Descendants among population 2 samples')
plt.ylabel('Descendants among population 1 samples')
plt.title('Combined IM model with annotations')
sns.despine()
plt.tight_layout()
plt.savefig(f"im_{n1}_{n2}_{m12}_{m21}_{split_t}_annot.pdf", bbox_inches='tight')
plt.savefig(f"im_{n1}_{n2}_{m12}_{m21}_{split_t}_annot.png", bbox_inches='tight', dpi=300)

## Summary

This notebook demonstrates the isolation-migration (IM) model using phase-type distributions:

1. **State space**: Each state counts lineages by their numbers of sampled descendants from population 1 and population 2 and by the population they are currently in
2. **IM stage**: Models coalescence within populations and migration between populations
3. **Ancestral stage**: Models panmictic ancestral population after split time
4. **Combined model**: Computes total expected branch lengths across both stages

The graph-based algorithms efficiently handle the complex state space and provide exact expectations for arbitrary parameter combinations.