In [5]:
import msprime
import tskit
import numpy as np
import networkx as nx
from collections import defaultdict
from itertools import chain
import pandas as pd
import matplotlib.pyplot as plt


def ts_to_nx(ts, connect_recombination_nodes=False, recomb_nodes=None):
    """
    Converts tskit tree sequence to networkx graph.

    Walks every local tree of ``ts`` and records one child -> parent entry for
    each distinct (child, parent) pair found in ``tree.parent_dict``; the
    collected adjacency lists are returned as a ``networkx.MultiDiGraph``.

    Parameters
    ----------
    ts : tskit.TreeSequence
        Tree sequence to convert.
    connect_recombination_nodes : bool, default False
        If True, every node that sits at an odd position of ``recomb_nodes``
        is replaced by ``node - 1`` before its edge is recorded, so the two
        members of each pair collapse onto one graph node.
        NOTE(review): this assumes paired nodes have consecutive ids — confirm.
    recomb_nodes : sequence of int, optional
        Ordered recombination node ids. When None (default) or empty and
        ``connect_recombination_nodes`` is True, the ids of nodes whose flags
        are exactly 131072 are used instead.

    Returns
    -------
    networkx.MultiDiGraph
        Graph whose edges point from child to parent.
    """
    if connect_recombination_nodes and (recomb_nodes is None or len(recomb_nodes) == 0):
        # 131072 == 1 << 17, presumably msprime's NODE_IS_RE_EVENT — confirm.
        # Exact equality: nodes carrying additional flag bits are not matched.
        recomb_nodes = list(np.where(ts.tables.nodes.flags == 131072)[0])

    # node id -> position of its first occurrence in recomb_nodes, giving O(1)
    # lookups instead of an `in` test plus `list.index` scan per edge endpoint.
    position = {}
    if connect_recombination_nodes:
        for idx, node in enumerate(recomb_nodes):
            position.setdefault(node, idx)

    def _merge(node):
        # Odd-position recombination nodes are folded onto the id just below.
        idx = position.get(node)
        return node - 1 if idx is not None and idx % 2 == 1 else node

    topology = defaultdict(list)
    for tree in ts.trees():
        for child, parent in tree.parent_dict.items():
            if connect_recombination_nodes:
                child, parent = _merge(child), _merge(parent)
            # Keep each child -> parent pair only once across all trees.
            if parent not in topology[child]:
                topology[child].append(parent)
    return nx.MultiDiGraph(topology)

def ts_to_nx_updated(ts):
    """
    Build a networkx MultiDiGraph straight from the edge table of ``ts``.

    Every row of the edge table contributes one parent -> child entry, so a
    parent/child pair stored in several rows is listed once per row in the
    adjacency lists handed to ``nx.MultiDiGraph``.
    """
    children_of = defaultdict(list)
    edge_table = ts.tables.edges
    for row in edge_table:
        children_of[row.parent].append(row.child)
    return nx.MultiDiGraph(children_of)

def _contraction_endpoints(g, node):
    """Return (upstream, downstream) neighbours if `node` can be contracted, else None.

    A node is contractible when it is a pure pass-through: exactly one incoming
    and one outgoing edge in a directed graph, or exactly two incident edges in
    an undirected one, and neither of those edges is a self-loop.
    """
    if g.is_directed():
        if g.in_degree(node) != 1 or g.out_degree(node) != 1:
            return None
        (upstream, _), = g.in_edges(node)
        (_, downstream), = g.out_edges(node)
    else:
        if g.degree(node) != 2:
            return None
        incident = list(g.edges(node))
        # A self-loop counts twice towards the degree but is listed only once.
        if len(incident) != 2:
            return None
        (a0, b0), (a1, b1) = incident
        upstream = a0 if a0 != node else b0
        downstream = a1 if a1 != node else b1
    if node in (upstream, downstream):
        # Contracting across a self-loop would just re-create `node`.
        return None
    return upstream, downstream


def simplify_graph(G, root=-1):
    ''' Return a copy of G in which every pass-through node (other than `root`) has
    been removed and its two incident edges fused into one edge.

    In a directed graph a pass-through node has exactly one in-edge and one
    out-edge; in an undirected graph it has exactly two incident edges. Degree-2
    nodes that are not pass-throughs (e.g. two in-edges and no out-edge) and the
    root are kept unchanged, so the loop always terminates. `G` is not modified.

    Adapted from https://stackoverflow.com/questions/53353335/networkx-remove-node-and-reconnect-edges
    '''
    g = G.copy()
    changed = True
    while changed:
        changed = False
        # Snapshot the node list: nodes are removed from `g` during the pass.
        for node in list(g.nodes):
            if node == root:
                continue
            ends = _contraction_endpoints(g, node)
            if ends is None:
                continue
            g.remove_node(node)
            g.add_edge(*ends)
            changed = True
    return g





# Load the SLiM tree sequence and keep a reproducible random subset of samples.
TREES_PATH = "run1/slim_0.25rep3sigma.trees"
N_KEEP = 20
SEED = 1

ts = tskit.load(TREES_PATH)
np.random.seed(SEED)  # legacy global RNG, kept so the same samples are drawn
keep_nodes = list(np.random.choice(ts.samples(), N_KEEP, replace=False))

# Simplify to the chosen samples while retaining unary nodes and input roots.
subset_ts = ts.simplify(samples=keep_nodes, keep_input_roots=True, keep_unary=True)

nx_arg = ts_to_nx_updated(ts=subset_ts)
# The highest node id is protected from contraction as the root.
# NOTE(review): presumably the oldest node after simplify — confirm it is a root.
simple_arg = simplify_graph(G=nx_arg, root=subset_ts.num_nodes - 1)

In [9]:
# Number of nodes left in the contracted ARG.
print(simple_arg.number_of_nodes())

2228


In [8]:
# Number of edges (parallel edges counted separately) in the contracted ARG.
print(simple_arg.number_of_edges())

4465


In [11]:
# Candidate recombination nodes, heuristic 1. A node of the simplified graph is
# kept when, in the edge table of `subset_ts`:
#   * every edge where it is the parent points to one single distinct child,
#   * the edges where it is the child come from more than one distinct parent,
#   * its total number of edges (as parent + as child) is odd.
# NOTE(review): criteria reproduced from the original cell; rationale not
# documented here — confirm.
# The edge columns are read once: `subset_ts.tables` used to be re-read four
# times per node, and depending on the tskit version that may copy every table.
edge_table = subset_ts.tables.edges
edge_parents = edge_table.parent
edge_children = edge_table.child

recomb_nodes = []
for node in simple_arg.nodes:
    children = edge_children[edge_parents == node]
    parents = edge_parents[edge_children == node]
    n_edges = len(children) + len(parents)
    if len(set(children)) == 1 and len(set(parents)) > 1 and n_edges % 2 != 0:
        recomb_nodes.append(node)

In [12]:
# Show the nodes selected by the distinct-parent/child heuristic.
print(f"{recomb_nodes}")

[173, 226, 267, 281, 307, 321, 382, 479, 593, 623, 692, 774, 886, 909, 1180, 1928, 1980, 2008, 2040, 2068, 2383, 2405, 2476, 2595, 2803, 2871, 2904, 3155, 3194, 3278, 3693, 3893, 4027, 4272, 4512, 4650, 4691, 4867, 5206, 5408, 5419, 5740, 6088, 6482, 6509, 7152, 7440, 7539, 7947, 8375, 8580, 8669, 8857, 9298, 9361, 9392, 9821, 9937, 9942, 10076, 10517, 10639, 10950, 11106, 11344, 11854, 11937, 11981, 12032, 12308, 12724, 12760, 12981, 13036, 13087, 13124, 13904, 14020, 14096, 14481, 14485, 14603, 14866, 14948, 15465, 15800, 16120, 16644, 16762, 16810, 16840, 17211, 17269, 17475, 17765, 17766, 17833, 18463, 18506, 18757, 18906, 19072, 19094, 19111, 19343, 19410, 19437, 19588, 19956]


In [13]:
# Candidate recombination nodes, heuristic 2: nodes of the simplified graph that
# appear as the parent in more edge-table rows than as the child.
# NOTE(review): confirm this is the intended criterion — it compares edges
# hanging below a node with edges attaching it to its parents.
# Per-node counts are computed once with bincount instead of scanning the edge
# table (and re-reading `subset_ts.tables`) for every node.
edge_table = subset_ts.tables.edges
n_as_parent = np.bincount(edge_table.parent, minlength=subset_ts.num_nodes)
n_as_child = np.bincount(edge_table.child, minlength=subset_ts.num_nodes)

recomb_nodes = [node for node in simple_arg.nodes if n_as_parent[node] > n_as_child[node]]
print(recomb_nodes)

[36, 47, 78, 82, 130, 211, 272, 303, 339, 352, 397, 407, 412, 444, 472, 562, 660, 709, 726, 734, 772, 821, 932, 982, 1097, 1109, 1267, 1311, 1328, 1806, 1965, 2115, 2363, 2539, 2636, 2662, 2689, 3174, 3314, 3325, 3921, 4077, 4167, 4351, 4555, 4675, 4838, 4874, 5219, 5975, 5981, 6172, 6512, 6580, 7181, 7396, 7631, 8576, 8593, 8872, 9263, 9303, 9415, 9819, 9876, 10029, 10282, 10658, 10993, 11113, 11123, 11811, 11905, 11988, 12335, 12947, 13016, 13065, 13079, 13104, 13377, 13911, 14025, 14407, 14555, 14871, 14881, 14973, 15812, 16582, 16649, 16793, 17290, 17345, 17628, 17791, 17855, 17886, 18390, 18547, 18718, 18927, 19169, 19295, 19356, 19433, 19857, 19955, 19985, 19988]
