In [None]:
import phasic
from phasic.state_indexing import Property, StateSpace
import numpy as np
phasic.plot.get_theme()

In [11]:
def two_locus_arg(s=None, N=None, R=None):
    """
    Build two-locus ancestral recombination graph (ARG).

    Parameters
    ----------
    s : int
        Sample size
    N : float
        Effective population size (diploid)
    R : float
        Recombination rate between loci

    Returns
    -------
    phasic.Graph
        Phase-type graph representing the ARG
    """

    # Define state space for two-locus model
    state_space = StateSpace([
        Property('descendants_l1', max_value=s),
        Property('descendants_l2', max_value=s)
    ])

    # State vector length
    n = state_space.size

    # Initialize graph
    graph = phasic.Graph(n)

    # Create initial state: s lineages with (1,1) descendants at both loci
    initial_state = np.zeros(n, dtype=int)
    initial_idx = state_space.props_to_index(descendants_l1=1, descendants_l2=1)
    initial_state[initial_idx] = s
    print(initial_state)

    first_vertex = graph.find_or_create_vertex(initial_state)
    graph.starting_vertex().add_edge(first_vertex, 1.0)

    # Process vertices iteratively (graph grows during iteration)
    index = 1
    while index < graph.vertices_length():
        vertex = graph.vertex_at(index)
        state = vertex.state()

        # Count total lineages
        total_lineages = int(np.sum(state))

        if total_lineages <= 1:
            # Only one lineage remaining, stop
            index += 1
            continue

        # Iterate over all lineage configurations
        for i in range(n):
            if state[i] == 0:
                continue

            # Get properties for configuration i
            conf_i = state_space.index_to_props(i)

            # COALESCENCE: Try coalescing with all configurations j >= i
            for j in range(i, n):
                if state[j] == 0:
                    continue

                conf_j = state_space.index_to_props(j)

                # Calculate coalescence rate
                if i == j:
                    if state[i] < 2:
                        continue
                    rate = state[i] * (state[i] - 1) / 2 / N
                else:
                    if state[i] < 1 or state[j] < 1:
                        continue
                    rate = state[i] * state[j] / N

                # Create child state after coalescence
                child_state = state.copy()

                # Remove coalescing lineages
                child_state[i] -= 1
                child_state[j] -= 1

                # Add coalesced lineage (descendants sum at both loci)
                new_l1 = conf_i['descendants_l1'] + conf_j['descendants_l1']
                new_l2 = conf_i['descendants_l2'] + conf_j['descendants_l2']

                # Verify descendants don't exceed sample size
                if new_l1 <= s and new_l2 <= s:
                    k = state_space.props_to_index(descendants_l1=new_l1, descendants_l2=new_l2)
                    child_state[k] += 1

                    child_vertex = graph.find_or_create_vertex(child_state)
                    vertex.add_edge(child_vertex, rate)

            # RECOMBINATION: Can only recombine if lineage has descendants at both loci
            if state[i] > 0 and conf_i['descendants_l1'] > 0 and conf_i['descendants_l2'] > 0:
                rate = R

                # Create child state after recombination
                child_state = state.copy()

                # Remove recombining lineage
                child_state[i] -= 1

                # Add two recombinant lineages (one for each locus)
                k = state_space.props_to_index(descendants_l1=conf_i['descendants_l1'], descendants_l2=0)
                l = state_space.props_to_index(descendants_l1=0, descendants_l2=conf_i['descendants_l2'])
                child_state[k] += 1
                child_state[l] += 1

                child_vertex = graph.find_or_create_vertex(child_state)
                vertex.add_edge(child_vertex, rate)

        index += 1

    return graph

In [12]:
graph = two_locus_arg(s=3, N=1000, R=1)
#graph.plot(max_nodes=200)
graph.vertices_length()


[0 0 0 0 0 3 0 0 0 0 0 0 0 0 0 0]


32

In [None]:
for i in range(3, 20):
    print(i, two_locus_arg(s=i, N=1000, R=1).vertices_length())

3 32
4 110
5 340
6 1044
7 2999
8 8407
9 22653
10 59522
11 151959
12 379694
13 927623


KeyboardInterrupt: 