In [None]:
# Always import phasic first to set jax backend correctly
import phasic
import numpy as np
np.random.seed(42)
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib_inline.backend_inline import set_matplotlib_formats
set_matplotlib_formats('retina', 'png')
import matplotlib
matplotlib.rcParams['figure.figsize'] = (5, 3.7)
sns.set_context('paper', font_scale=0.9)
# import warnings
# warnings.filterwarnings(action='ignore', category=Warning, module='seaborn')
phasic.set_theme('dark')

# Isolation-Migration Model

This notebook demonstrates the isolation-migration (IM) model using phasic.

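The IM model describes two populations that:
1. Exchange migrants at rates m12 and m21 (backward in time, a lineage moves from population 1 to population 2 at rate m12, and from population 2 to population 1 at rate m21)
2. Experience coalescence within each population
3. May merge into an ancestral population at a split time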

**Note**: This notebook builds a simplified IM model. The split time is not modelled: the two populations exchange migrants indefinitely, and the ancestral population is built as a separate, stand-alone graph. For a complete reference implementation, see `docs/pages/popgen/python/im_model.ipynb` in the phasic repository, which includes:
- State space construction for two populations
- Migration and coalescence graph building
- Ancestral population construction
- Reward computation for expectations

In [None]:
def matrix_index(i, j, pop, n1, n2):
    """
    Map a lineage type (i, j) in population `pop` to its flat state index.

    The state vector holds one (n1 + 1) x (n2 + 1) block per population.
    Each block is stored row-major, and the blocks are placed back to back:
    population 0 first, then population 1.

    Parameters
    ----------
    i : int
        Row within the block (first lineage-type coordinate, 0..n1).
    j : int
        Column within the block (second lineage-type coordinate, 0..n2).
    pop : int
        Population (0 or 1), selecting which block is addressed.
    n1 : int
        Sample size from population 1 (largest valid ``i``).
    n2 : int
        Sample size from population 2 (largest valid ``j``).

    Returns
    -------
    int
        Linear index in the state vector.

    Notes
    -----
    No bounds checking is performed. Out-of-range arguments yield an index
    in the wrong block or past the end of the vector.
    """
    columns = n2 + 1
    block_start = pop * (n1 + 1) * columns
    return block_start + columns * i + j


def construct_im_graph(n1, n2, m12, m21, p1_size=1.0, p2_size=1.0):
    """
    Construct an isolation-migration (structured coalescent) graph.

    A state counts, for each population, how many lineages of each type
    (i, j) exist. A lineage of type (i, j) is ancestral to i of the n1
    samples from population 1 and j of the n2 samples from population 2.
    Counts are stored at ``matrix_index(i, j, pop, n1, n2)``.

    Parameters
    ----------
    n1 : int
        Sample size from population 1
    n2 : int
        Sample size from population 2
    m12 : float
        Migration rate from population 1 to 2 (per lineage, backward in time)
    m21 : float
        Migration rate from population 2 to 1 (per lineage, backward in time)
    p1_size : float
        Relative size of population 1 (default 1.0)
    p2_size : float
        Relative size of population 2 (default 1.0)

    Returns
    -------
    phasic.Graph
        IM model graph. The state holding a single lineage (the MRCA) gets
        no outgoing edges and is therefore absorbing.

    Raises
    ------
    ValueError
        If a sample size is negative, both sample sizes are zero, a migration
        rate is negative, or a population size is not positive.

    Notes
    -----
    Each pair of lineages in the same population coalesces at rate
    ``1 / pop_size``. The merged lineage is ancestral to the union of both
    lineages' samples. With ``m12 == m21 == 0``, a state with lineages in
    both populations can never coalesce and therefore gets no outgoing edges.
    """
    if n1 < 0 or n2 < 0 or n1 + n2 == 0:
        raise ValueError("n1 and n2 must be non-negative and not both zero")
    if m12 < 0 or m21 < 0:
        raise ValueError("migration rates must be non-negative")
    if p1_size <= 0 or p2_size <= 0:
        raise ValueError("population sizes must be positive")

    matrix_size = (n1 + 1) * (n2 + 1)
    state_length = matrix_size * 2  # Two populations
    # Every lineage type that can carry lineages; (0, 0) never does.
    lineage_types = [(i, j) for i in range(n1 + 1) for j in range(n2 + 1)
                     if i + j > 0]
    pop_sizes = (p1_size, p2_size)
    # Per-lineage rate of leaving population `pop` (index 0 -> m12, 1 -> m21).
    out_migration = (m12, m21)

    graph = phasic.Graph(state_length)

    def add_transition(vertex, state, changes, rate):
        """Add an edge to `state` modified by the (index, delta) `changes`."""
        if rate <= 0:
            return  # impossible event (e.g. zero migration): no edge
        child_state = state.copy()
        for k, delta in changes:
            child_state[k] += delta
        vertex.add_edge(graph.find_or_create_vertex(child_state), float(rate))

    # Initial state: n1 singleton lineages of type (1, 0) in population 1
    # and n2 singleton lineages of type (0, 1) in population 2.
    initial_state = np.zeros(state_length, dtype=int)
    if n1 > 0:
        initial_state[matrix_index(1, 0, 0, n1, n2)] = n1
    if n2 > 0:
        initial_state[matrix_index(0, 1, 1, n1, n2)] = n2

    first_vertex = graph.find_or_create_vertex(initial_state)
    graph.starting_vertex().add_edge(first_vertex, 1.0)

    # Expand states in discovery order; index 0 is the starting vertex.
    index = 1
    while index < graph.vertices_length():
        vertex = graph.vertex_at(index)
        index += 1
        state = np.asarray(vertex.state(), dtype=int)
        if state.sum() <= 1:
            continue  # only the MRCA lineage remains: absorbing state

        for pop in range(2):
            other = 1 - pop
            # (i, j, state index, count) for every type present in `pop`.
            present = []
            for i, j in lineage_types:
                idx = matrix_index(i, j, pop, n1, n2)
                if state[idx] > 0:
                    present.append((i, j, idx, state[idx]))

            # Migration: any one of the `count` lineages of a type may move.
            for i, j, idx, count in present:
                target = matrix_index(i, j, other, n1, n2)
                add_transition(vertex, state, [(idx, -1), (target, +1)],
                               count * out_migration[pop])

            # Coalescence within `pop`, over all unordered pairs of lineages.
            for a, (ia, ja, idx_a, ca) in enumerate(present):
                if ca >= 2:
                    # Two lineages of the same type merge.
                    merged = matrix_index(2 * ia, 2 * ja, pop, n1, n2)
                    add_transition(vertex, state, [(idx_a, -2), (merged, +1)],
                                   ca * (ca - 1) / 2 / pop_sizes[pop])
                for ib, jb, idx_b, cb in present[a + 1:]:
                    # Lineages of two different types merge.
                    merged = matrix_index(ia + ib, ja + jb, pop, n1, n2)
                    add_transition(vertex, state,
                                   [(idx_a, -1), (idx_b, -1), (merged, +1)],
                                   ca * cb / pop_sizes[pop])

    return graph


def construct_ancestral_graph(n1, n2):
    """
    Construct ancestral (panmictic) graph.

    All n1 + n2 sampled lineages start in a single ancestral population and
    only coalesce (no migration). State entry ``i * (n2 + 1) + j`` counts the
    lineages ancestral to i samples from population 1 and j samples from
    population 2. Each pair of lineages coalesces at rate 1.

    Parameters
    ----------
    n1 : int
        Sample size from population 1
    n2 : int
        Sample size from population 2

    Returns
    -------
    phasic.Graph
        Coalescent graph. The single-lineage (MRCA) state has no outgoing
        edges and is therefore absorbing.

    Raises
    ------
    ValueError
        If a sample size is negative or both sample sizes are zero.
    """
    if n1 < 0 or n2 < 0 or n1 + n2 == 0:
        raise ValueError("n1 and n2 must be non-negative and not both zero")

    row = n2 + 1
    state_length = (n1 + 1) * row  # Only one population
    # Every lineage type that can carry lineages; (0, 0) never does.
    lineage_types = [(i, j) for i in range(n1 + 1) for j in range(n2 + 1)
                     if i + j > 0]

    graph = phasic.Graph(state_length)

    def add_transition(vertex, state, changes, rate):
        """Add an edge to `state` modified by the (index, delta) `changes`."""
        child_state = state.copy()
        for k, delta in changes:
            child_state[k] += delta
        vertex.add_edge(graph.find_or_create_vertex(child_state), float(rate))

    # Initial state: n1 lineages of type (1, 0) and n2 of type (0, 1).
    # (The original never created an initial state, so the graph held only
    # the starting vertex.)
    initial_state = np.zeros(state_length, dtype=int)
    if n1 > 0:
        initial_state[1 * row + 0] = n1
    if n2 > 0:
        initial_state[0 * row + 1] = n2
    first_vertex = graph.find_or_create_vertex(initial_state)
    graph.starting_vertex().add_edge(first_vertex, 1.0)

    # Expand states in discovery order; index 0 is the starting vertex.
    index = 1
    while index < graph.vertices_length():
        vertex = graph.vertex_at(index)
        index += 1
        state = np.asarray(vertex.state(), dtype=int)
        if state.sum() <= 1:
            continue  # only the MRCA lineage remains: absorbing state

        # (i, j, state index, count) for every lineage type present.
        present = [(i, j, i * row + j, state[i * row + j])
                   for i, j in lineage_types if state[i * row + j] > 0]

        # Coalescence events (same as IM but single population, size 1).
        for a, (ia, ja, idx_a, ca) in enumerate(present):
            if ca >= 2:
                merged = (2 * ia) * row + 2 * ja
                add_transition(vertex, state, [(idx_a, -2), (merged, +1)],
                               ca * (ca - 1) / 2)
            for ib, jb, idx_b, cb in present[a + 1:]:
                merged = (ia + ib) * row + (ja + jb)
                add_transition(vertex, state,
                               [(idx_a, -1), (idx_b, -1), (merged, +1)],
                               ca * cb)

    return graph


def rewards_at(graph, i, j, n1, n2):
    """
    Build the reward vector for lineage type (i, j).

    Each vertex is rewarded with the number of lineages of type (i, j) in
    its state. States of length ``2 * (n1 + 1) * (n2 + 1)`` are treated as
    two-population (IM) states, and the counts from both populations are
    added. Any other state is treated as a single-population (ancestral)
    state indexed as ``i * (n2 + 1) + j``.

    Parameters
    ----------
    graph : phasic.Graph
        Graph whose vertices are rewarded.
    i, j : int
        Lineage type to count.
    n1, n2 : int
        Sample sizes that define the state layout.

    Returns
    -------
    numpy.ndarray
        Float array with one reward per vertex, in vertex order.
    """
    two_pop_length = 2 * (n1 + 1) * (n2 + 1)
    rewards = np.zeros(graph.vertices_length())

    for vertex_number in range(len(rewards)):
        state = graph.vertex_at(vertex_number).state()
        if len(state) != two_pop_length:
            # Ancestral graph: a single (n1 + 1) x (n2 + 1) block.
            rewards[vertex_number] = state[i * (n2 + 1) + j]
        else:
            # IM graph: sum the (i, j) counts of both populations.
            in_pop0 = state[matrix_index(i, j, 0, n1, n2)]
            in_pop1 = state[matrix_index(i, j, 1, n1, n2)]
            rewards[vertex_number] = in_pop0 + in_pop1

    return rewards

## Example: IM Model with Expectations

Construct the IM graph for n1 = n2 = 4 samples with symmetric migration (m12 = m21 = 0.1), together with the stand-alone ancestral graph for the same samples.

In [None]:
import time

n1 = 4
n2 = 4

# time.perf_counter is monotonic and high-resolution, so it is the right
# clock for measuring elapsed durations (time.time can jump).
print("Constructing IM graph...")
start = time.perf_counter()
im_graph = construct_im_graph(n1, n2, m12=0.1, m21=0.1)
print(f"Construction took {time.perf_counter() - start:.2f} seconds")
print(f"Graph has {im_graph.vertices_length()} vertices")

print("\nConstructing ancestral graph...")
start = time.perf_counter()
ancestral_graph = construct_ancestral_graph(n1, n2)
print(f"Construction took {time.perf_counter() - start:.2f} seconds")
print(f"Ancestral graph has {ancestral_graph.vertices_length()} vertices")

## Computing Expectations

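Compute the expected branch length for each lineage type (i, j) by reward transformation. Each state is rewarded with its number of (i, j) lineages, summed over both populations. The first moment of the reward-transformed graph is then the expected total length of such branches.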

In [None]:
# Compute expectations for IM model: entry (i, j) is the first moment of the
# graph after reward-transforming it by the number of (i, j) lineages.
print("Computing expectations...")
start = time.perf_counter()

im_expectation = np.zeros((n1 + 1, n2 + 1))

for i in range(n1 + 1):
    for j in range(n2 + 1):
        rewards = rewards_at(im_graph, i, j, n1, n2)
        if not rewards.any():
            # No vertex carries a lineage of this type, so the expectation is
            # 0; skip the degenerate all-zero reward transform.
            continue
        reward_graph = im_graph.reward_transform(rewards)
        im_expectation[i, j] = reward_graph.phase_type_moment(1)

print(f"Expectations took {time.perf_counter() - start:.2f} seconds")
print("\nExpectations matrix:")
print(im_expectation)

## Visualization

Visualize the expectations as a heatmap.

In [None]:
# Heatmap of the expectation matrix: rows are i, columns are j.
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
    im_expectation,
    annot=True,
    fmt='.3f',
    cmap='viridis',
    xticklabels=range(n2 + 1),
    yticklabels=range(n1 + 1),
    ax=ax,
)
ax.set_xlabel('Lineages in sample 2 (j)')
ax.set_ylabel('Lineages in sample 1 (i)')
ax.set_title('Expected Branch Lengths in IM Model')
fig.tight_layout()
plt.show()

## Summary

This notebook demonstrates:
- Construction of isolation-migration (IM) graphs
- Migration between two populations
- Coalescence within populations
- Ancestral population construction
- Reward-based expectation computation
- Visualization of branch length expectations

For more detailed IM model analysis, including:
- Finite split times
- Stopping probabilities
- Combined IM + ancestral expectations

see `docs/pages/popgen/python/im_model.ipynb` in the phasic repository.