In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import networkx as nx
import numpy as np
import boolean
import scipy.sparse as sparse
import sys


### Building STG table

In [3]:
def readBNet(filename):
    bnet_model = {}
    algebra = boolean.BooleanAlgebra()
    with open(filename, 'r') as f:
        lines = f.readlines()
        for line in lines:
            species, formula = [value.strip() for value in line.split(",")]
            if species != "target" and formula != "factors":
                b_formula = algebra.parse(formula).simplify()
                bnet_model.update({species:b_formula})
                
    return bnet_model

##### First, loading the model as a minibn

In [4]:
model_name_list = [
    'mammalian_cc', 
    'krasmodel15vars',
    'breast_cancer_zanudo2017',
    'EMT_cohen_ModNet',
    'sahin_breast_cancer_refined',
    'toy',
    'toy2',
    'toy3'
]
model_index = 4

In [5]:
model = readBNet(("model_files/%s.bnet" % model_name_list[model_index]))

In [6]:
nodes = list(model.keys())

In [7]:
n=len(nodes)
n

20

##### Then, actually building the table

In [8]:
def fcn_gen_node_update(formula, list_binary_states, nodes):

    if isinstance(formula, boolean.boolean.Symbol):
        return list_binary_states[:, nodes.index(str(formula))]
    
    elif isinstance(formula, boolean.boolean.NOT):
        return np.logical_not(
            fcn_gen_node_update(formula.args[0], list_binary_states, nodes)
        )
    
    elif isinstance(formula, boolean.boolean.OR):
        ret = fcn_gen_node_update(formula.args[0], list_binary_states, nodes)
        for i in range(1, len(formula.args)):
            ret = np.logical_or(ret, 
                fcn_gen_node_update(formula.args[i], list_binary_states, nodes)
            )
        return ret
    
    elif isinstance(formula, boolean.boolean.AND):
        ret = fcn_gen_node_update(formula.args[0], list_binary_states, nodes)
        for i in range(1, len(formula.args)):
            ret = np.logical_and(ret, 
                fcn_gen_node_update(formula.args[i], list_binary_states, nodes)
            )
        return ret
    
    else:
        print("Unknown boolean operator : %s" % type(formula))

In [9]:
def fcn_build_update_table(model, list_binary_states, nodes):
    update_matrix = np.array(
        [
            fcn_gen_node_update(model[node], list_binary_states, nodes) 
            for node in nodes
        ]
    ).transpose()
    
    return update_matrix

In [10]:
def fcn_states_inds(yes_no, n_series_exp, n_isl_exp):
    
    n_series_exp = n_series_exp - 1
    yes_no = yes_no - 1
    
    f_mat = np.array(
        range(
            1, 
            pow(2, (n_series_exp-n_isl_exp))+1
        )
    ) + yes_no

    t_repmat = np.array([f_mat]*int(pow(2, n_isl_exp)))
        
    t_reshaped = np.reshape(t_repmat, (1, int(pow(2, n_series_exp))), order='F')
    
    t_mult = t_reshaped*pow(2, n_isl_exp)
    t_last = np.array(
        range(
            1, 
            pow(2, n_series_exp)+1
        )
    )
    
    return np.sum([t_last, t_mult])-1

In [11]:
def fcn_build_stg_table(model, nodes):
    list_binary_states = np.remainder(
        np.floor(
            np.multiply(
                np.array([range(pow(2, n))]).transpose(), 
                np.array([np.power([2.0]*n, np.array(range(0, -n, -1)))])
            )
        ), 2
    ).astype(bool)
    
    update_table = fcn_build_update_table(model, list_binary_states, nodes)
    
    up_trans_source = [
        np.intersect1d(
            np.nonzero(update_table[:, x])[0],
            fcn_states_inds(0, n, x)[0, :]
        ) 
        for x in range(n)
    ]
        
    down_trans_source = [
        np.intersect1d(
            np.nonzero(np.logical_not(update_table[:, x]))[0],
            fcn_states_inds(1, n, x)[0, :]
        ) 
        for x in range(n)
    ]
    
    down_trans_target = [
        np.concatenate(
            (
                np.array([down_trans_source[x]-pow(2, x)]).transpose(), 
                np.repeat(np.array([[x,1]]), len(down_trans_source[x]), axis=0)
            ), axis=1
        )
        for x in range(len(down_trans_source))
    ]
    
    up_trans_target = [
        np.concatenate(
            (
                np.array([up_trans_source[x]+pow(2, x)]).transpose(), 
                np.repeat(np.array([[x,0]]), len(up_trans_source[x]), axis=0)
            ), axis=1
        )
        for x in range(len(up_trans_source))
    ]
    
    source = np.concatenate([
        np.concatenate(down_trans_source, axis=0),
        np.concatenate(up_trans_source, axis=0)
    ])

    target = np.concatenate([
        np.concatenate(down_trans_target, axis=0),
        np.concatenate(up_trans_target, axis=0)
    ])
    
    stg_table = np.concatenate((np.array([source]).transpose(), target), axis=1)
    
    return stg_table

In [12]:
stg_table = fcn_build_stg_table(model, nodes)
stg_table.shape

(9961472, 4)

In [13]:
print("%.2fMbytes" % (stg_table.shape[0] * stg_table.shape[1] * stg_table.itemsize / (1024*1024)))

304.00Mbytes


##### Building transition rates table

In [14]:
# to define transition rates, we can select given rates to have different values than 1, or from randomly chosen
# name of rates: 'u_nodename' or 'd_nodename'
# chosen_rates=['u_ERBB1','u_ERBB2','u_ERBB3']; chosen_rates_vals=np.zeros(len(chosen_rates));
# OR leave them empty: 
chosen_rates = []
chosen_rates_vals = []

# chosen_rates = ["u_CHEK1", "d_CHEK1"]
# chosen_rates_vals = np.ones(len(chosen_rates))

In [15]:
# then we generate the table of transition rates: first row is the 'up'rates, second row 'down' rates, 
# in the order of 'nodes'
# ARGUMENTS

In [16]:
distr_type = ['uniform', 'random'] # <uniform> assigns a value of 1 to all params. other option: <random>
meanval = 0.5 # if 'random' is chosen, the mean and standard dev of a normal distrib has to be defined
sd_val = 0.1

In [17]:
def fcn_trans_rates_table(nodes, uniform_or_rand, meanval, sd_val, chosen_rates, chosen_rates_vals):
    n = len(nodes)

    if uniform_or_rand == "uniform":
        rate_vals_num = np.ones((1, 2*n)).astype(np.int64)

    elif uniform_or_rand == "random":
        rate_vals_num = np.random.normal(meanval, sd_val, (1, 2*n))
        if np.any(rate_vals_num < 0):
            neg_cnt = 0
            while np.any(rate_vals_num < 0):
                rate_vals_num = np.random.normal(meanval, sd_val, (1, 2*n))
                neg_cnt += 1
                if neg_cnt > 100:
                    break
                    
    else:
        print("Choose 'uniform' or 'random' to generate transition rates", file=sys.stderr)
        return

    for k, chosen_rate in enumerate(chosen_rates):
        split_rate = chosen_rate.split("_")
        if len(split_rate) > 2:
            node_mod_ind = "_".join(split_rate[1:])
        else:
            node_mod_ind = split_rate[1]

        if split_rate[0] == "d":
            rate_vals_num[:, nodes.index(node_mod_ind)+n] = chosen_rates_vals[k]
        elif split_rate[0] == "u":
            rate_vals_num[:, nodes.index(node_mod_ind)] = chosen_rates_vals[k]
        else:
            print("Wrong name for transition rate, has to be 'u_nodename' or 'd_nodename'", file=sys.stderr)
            return
    
    return np.reshape(rate_vals_num, (n, 2)).transpose()


In [18]:
transition_rates_table = fcn_trans_rates_table(nodes, distr_type[0], meanval, sd_val, chosen_rates, chosen_rates_vals)

##### Building the (sparse) transition matrix

In [19]:
def fcn_build_trans_matr(stg_table, transition_rates_table, kin_matr_flag=""):

    dim_matr = pow(2, transition_rates_table.shape[1])
    
    rate_inds = ((stg_table[:, 2])*2)+stg_table[:, 3]

    # Here we reshape the transition_rates_table to a list
    reshaped_trt = np.reshape(transition_rates_table, (1, np.product(transition_rates_table.shape)), order="F")[0, :]

    B = sparse.csr_matrix(
        (
            reshaped_trt[rate_inds]/np.sum(transition_rates_table),
            (stg_table[:, 0], 
            stg_table[:, 1])
        ),
        shape=(dim_matr, dim_matr)
    )

    A_sparse = B + (sparse.eye(B.shape[0]) - sparse.diags(np.array(sparse.csr_matrix.sum(B, axis=1).transpose())[0]))

    if len(kin_matr_flag) > 0:
        K_sparse = (A_sparse.transpose() - sparse.eye(A_sparse.shape[0]))*np.sum(transition_rates_table)

    else:
        K_sparse = []

    return A_sparse, K_sparse

In [20]:
A_sparse, _ = fcn_build_trans_matr(stg_table, transition_rates_table, 'yes')

In [21]:
print("%.2f Mbytes" % (A_sparse.data.nbytes/(1024*1024)))

84.00 Mbytes


##### Setting up initial values

In [22]:
initial_fixed_nodes_list = [
    ['CycE','CycA','CycB','Cdh1','Rb_b1','Rb_b2','p27_b1','p27_b2'], # mammalian_cc
    ['cc','KRAS','DSB','cell_death'], #krasmodel15vars
    ['Alpelisib', 'Everolimus','PIM','Proliferation','Apoptosis'], # breast_cancer_zanudo2017
    ['ECMicroenv','DNAdamage','Metastasis','Migration','Invasion','EMT','Apoptosis','Notch_pthw','p53'], # EMT_cohen_ModNet 
    ['EGF','ERBB1','ERBB2','ERBB3','p21','p27'], # sahin_breast_cancer_refined
    ['A','C','D'], # toy model with fork in stg
    ['A', 'B', 'C'], # toy model with cycle in stg
    ['A', 'B'] # smaller toy model with cycle in stg, one connected component
] 

initial_fixed_nodes_vals_list = [
    [0, 0, 0, 1, 1, 1, 1, 1], # mammalian_cc
    [1, 1, 1, 0], # krasmodel15vars: [1 1] is cell cycle ON, KRAS mutation ON
    [0, 1, 0] + [0]*2, # breast_cancer_zanudo2017
    [1, 1] + [0]*5 + [1, 0], # EMT-Cohen model: [0/1 0/1 zeros(1,5)]
    [1, 0, 0, 0, 1, 1],
    [0, 0, 0],
    [0, 0, 0],
    [0, 0]
] 

initial_fixed_nodes = initial_fixed_nodes_list[model_index]
initial_fixed_nodes_vals = initial_fixed_nodes_vals_list[model_index]


# what is the probability of this state, (eg. dom_prob=0.8, ie. 80% probability)
dom_prob = 1
# if plot_flag non-empty, we get a bar plot of initial values
plot_flag = ''

In [23]:
def fcn_define_initial_states(initial_fixed_nodes,initial_fixed_nodes_vals,dom_prob,nodes,distrib_type,plot_flag):
    
    n_nodes = len(nodes)
    
    truth_table_inputs = np.remainder(
        np.floor(
            np.multiply(
                np.array([range(pow(2, n_nodes))]).transpose(), 
                np.array([np.power([2.0]*n_nodes, np.array(range(0, -n_nodes, -1)))])
            )
        ), 2
    ).astype(bool)
    
    # define initial values
    x0 = np.zeros((int(pow(2, n_nodes)), 1))
    
    # defining a dominant initial state (eg. dom_prob=0.8, ie. 80% probability
    initial_on_nodes_inds = [node in initial_fixed_nodes for node in nodes]                                  

    statespace_decim = np.sum(
        truth_table_inputs[:, initial_on_nodes_inds]*np.power(
            2, 
            np.array(
                list(reversed(range(np.sum(initial_on_nodes_inds))))
            )
        ), axis=1
    )

    initial_fixed_nodes_vals_decim = np.sum(
        initial_fixed_nodes_vals*np.power(
            2, 
            np.array(
                list(reversed(range(len(initial_fixed_nodes_vals))))
            )
        )
    )

    inds_condition = np.isin(statespace_decim, initial_fixed_nodes_vals_decim)

    if distrib_type == "uniform":
        x0[inds_condition] = np.array([[dom_prob/sum(inds_condition)]*sum(inds_condition)]).transpose()
        x0[np.logical_not(inds_condition)] = np.array([[(1-dom_prob)/(len(x0)-sum(inds_condition))]*(len(x0)-sum(inds_condition))]).transpose()
    
    elif distrib_type == "random":
        x0[inds_condition] = np.random.uniform(0, 1, (sum(inds_condition), 1))
        x0 = dom_prob*x0/sum(x0)

        x0[np.logical_not(inds_condition)] = np.random.uniform(0, 1, (len(x0)-sum(inds_condition), 1))
        x0[np.logical_not(inds_condition)] = (1-dom_prob)*x0[np.logical_not(inds_condition)]/sum(x0[np.logical_not(inds_condition)])
    
    else:
        print("distrib type should be 'uniform' or 'random'", file=sys.stderr)
    
    # rounding precision
    n_prec=3;
    if round(sum(x0)[0],n_prec) == 1:
        print('sum(x0)=1, OK.')
    
    else:
        print('sum(x0)~=1, something wrong!')

#     if ~isempty(plot_flag)
#     bar(x0); set(gca,'yscale','log'); xlim([1 2^n_nodes]); % ylim([(1-dom_prob)/2^n_nodes 1])
#     % subplot(2,1,2); x0=fcn_define_initial_states(initial_on_nodes,dom_prob,nodes,'broad'); 
#     % bar(x0); xlim([1 2^13]);set(gca,'yscale','log'); ylim([(1-dom_prob)/2^n_nodes 1])
#     end

    return x0

In [24]:
x0 = fcn_define_initial_states(initial_fixed_nodes, initial_fixed_nodes_vals, dom_prob, nodes, "uniform", plot_flag)

sum(x0)=1, OK.


In [25]:
x0.shape

(1048576, 1)

In [26]:
import networkx as nx

def fcn_metagraph_scc(A_sparse_sub):
    
    matr_size = A_sparse_sub.shape[0]

    g_sub = nx.from_scipy_sparse_matrix(A_sparse_sub, create_using=nx.DiGraph())
    g_sub.remove_edges_from(nx.selfloop_edges(g_sub))
    
    # Here we reverse it only for debugging purpose
    # The order shouldn't matter, but it's nice to have the same as matlab
    scc_list = list(reversed(list(nx.strongly_connected_components(g_sub))))
    print("%d connected components" % len(scc_list))

    
    num_verts_per_scc = []
    scc_memb_per_vert = np.zeros((matr_size, 1))

    for i, scc in enumerate(scc_list):
        num_verts_per_scc.append(len(scc))
        scc_memb_per_vert[list(scc),:] = i;
        
    # row, col = np.where((A_sparse_sub - np.diag(A_sparse_sub.diagonal())) > 0)
    # Yet another trick to get the exact same results as matlab
    # The difference is returning the list from parsing via columns or via rows, hopefully nothing critical
    col, row = np.where((A_sparse_sub - np.diag(A_sparse_sub.diagonal())).transpose() > 0)

    diff = scc_memb_per_vert[row] != scc_memb_per_vert[col]
    
    row_sel = row[np.where(diff[:, 0])]
    col_sel = col[np.where(diff[:, 0])]

    A_metagraph = sparse.csr_matrix(
        (np.array(A_sparse_sub[row_sel, col_sel]).flatten(), 
        (scc_memb_per_vert[row_sel][:, 0], scc_memb_per_vert[col_sel][:, 0])),
        shape=(len(num_verts_per_scc), len(num_verts_per_scc))
    )

    metagraph = nx.from_scipy_sparse_matrix(A_metagraph, create_using=nx.DiGraph())
    metagraph_ordering=np.array(list(nx.topological_sort(metagraph)))
    
    terminal_scc_ind, _ = np.where(A_metagraph.sum(axis=1) == 0)
    terminal_scc_pos = np.isin(metagraph_ordering, terminal_scc_ind)
    
    nonterm_scc_num = len(num_verts_per_scc) - len(terminal_scc_ind)

    scc_sup1 = [i for i, scc in enumerate(scc_list) if len(scc) > 1]
    
    term_cycles_ind = set(scc_sup1).intersection(set(terminal_scc_ind))
    where_terminal_scc_pos, = np.where(terminal_scc_pos)

    if np.sum(np.logical_not(where_terminal_scc_pos>(nonterm_scc_num-1))) > 0:
        nonterm_scc_inds = np.logical_not(np.isin(metagraph_ordering, terminal_scc_ind))
        metagraph_ordering_terminal_bottom = np.concatenate([
            metagraph_ordering[nonterm_scc_inds],
            metagraph_ordering[terminal_scc_pos]
        ])

    else:
        metagraph_ordering_terminal_bottom = metagraph_ordering


    if len(term_cycles_ind) > 0:
        
        scc_cell_reordered = [scc_list[i] for i in metagraph_ordering_terminal_bottom]
        # index of cells containing term cycles after reordering
        term_cycles_ind, = np.where(np.isin(metagraph_ordering_terminal_bottom, np.array(list(term_cycles_ind))))

        # we need a cell of the indices of certices withing whese
        scc_cell_reordered_lengths = np.array([len(scc) for scc in scc_cell_reordered])
        scc_cell_reordered_cumsum = np.cumsum(scc_cell_reordered_lengths)
        
        cycle_first_verts = scc_cell_reordered_cumsum[term_cycles_ind] - scc_cell_reordered_lengths[term_cycles_ind];
        cycle_last_verts = scc_cell_reordered_cumsum[term_cycles_ind] - 1
        
        term_cycles_bounds = [np.concatenate([cycle_first_verts, cycle_last_verts])]
        
    else:
        term_cycles_ind = []
        term_cycles_bounds = []
        

    # reordered original vertices
    vert_topol_sort = np.concatenate([list(scc_list[i]) for i in metagraph_ordering_terminal_bottom])
    
    return vert_topol_sort, term_cycles_ind, A_metagraph, scc_list, term_cycles_bounds


def fcn_scc_subgraphs(A_sparse, x0):
    
    print("Indentifying SCCs")
    G = nx.from_scipy_sparse_matrix(A_sparse, create_using=nx.DiGraph())
    G.remove_edges_from(nx.selfloop_edges(G))
    
    # Here we get a generator. Do I really need to compute it now ?
    subnetws = [list(g) for g in nx.weakly_connected_components(G)]
    cell_subgraphs = []
    scc_submats = []
    nonempty_subgraphs = []
#     print(len(subnetws))
    print("Identifying SCCs in subgraphs")
    for i, subnet in enumerate(subnetws):
        cell_subgraphs.append(subnet)
        
        # Slicing done it two steps : First the rows, which is the most efficient for csr sparse matrix
        # then columns. I should probably dig deeper
        t_sparse = A_sparse[subnet, :][:, subnet]
        
        t_g = nx.from_scipy_sparse_matrix(t_sparse, create_using=nx.DiGraph())
        t_g.remove_edges_from(nx.selfloop_edges(t_g))
        
        # Again, do I really need to compute it ?
        scc_submats.append([list(g) for g in nx.strongly_connected_components(t_g)])

        if sum(x0[subnet]) > 0:
            nonempty_subgraphs.append(i)
    
    sorted_vertices = []
    cyclic_sorted_subgraphs = []
    counter = 0
    
    for nonempty_subgraph in nonempty_subgraphs:
        
        A_sparse_sub = A_sparse[subnetws[nonempty_subgraph], :][:, subnetws[nonempty_subgraph]]
    
        if A_sparse_sub.shape[0] == len(scc_submats[nonempty_subgraph]):
            t_g = nx.from_scipy_sparse_matrix(A_sparse_sub, create_using=nx.DiGraph())
            t_g.remove_edges_from(nx.selfloop_edges(t_g))
            sorted_vertices.append(list(nx.topological_sort(t_g)))
#             print("toposort results")
#             print(list(nx.topological_sort(t_g)))
        else:
            print("Cycles in STG")
            
            # If entire graph is only one connected component, no need for re-ordering
            if len(scc_submats[nonempty_subgraph]) == 1:
                sorted_vertices.append(scc_submats[nonempty_subgraph])
            else:
                vert_topol_sort,term_cycles_ind,_,scc_cell,term_cycle_bounds=fcn_metagraph_scc(A_sparse_sub)
                cycle_lengths = [len(scc) for scc in scc_cell]
                
                a = np.zeros((max(cycle_lengths)))
                for i in range(max(cycle_lengths)):
                    for j in cycle_lengths:
                        if j == i+1:
                            a[j-1] += 1
                    
                print('Cycles of lenth: %s (%s times)' % (set(cycle_lengths), a[np.where(a>0)]) )
                cyclic_sorted_subgraphs.append((vert_topol_sort, term_cycles_ind, term_cycle_bounds))

        counter += 1
               
    return (subnetws,scc_submats,nonempty_subgraphs,sorted_vertices,cyclic_sorted_subgraphs)
                   
                   

In [27]:
%time stg_sorting_cell = fcn_scc_subgraphs(A_sparse, x0)

Indentifying SCCs
Identifying SCCs in subgraphs
Cycles in STG
512036 connected components


MemoryError: Unable to allocate 2.00 TiB for an array with shape (524288, 524288) and data type float64

In [28]:
def fcn_block_inversion(K_sp_sub_reord, sorted_vertices_terminal_bottom, x0, submatrix_inds):
    """
        This function calculate kernels and stationary solution if all terminal
    """
    
    
    # Construct kernels from matrix blocks
    dim_kernel = sum(K_sp_sub_reord.diagonal() == 0)
    dim_matr = K_sp_sub_reord.shape[0]
    
    colnum_r_null_array = range(dim_kernel)
    term_block_inds = range(dim_matr - dim_kernel, dim_matr)
    nonterm_block_inds = range(dim_matr - dim_kernel)
    term_block = sparse.eye(dim_kernel)
  
    # Right kernel
    r0_blocks = sparse.lil_matrix((dim_matr, dim_kernel), dtype=np.float32)
    r0_blocks[np.ix_(term_block_inds, colnum_r_null_array)] = term_block    
    
    # Left kernel
    l0_blocks = sparse.lil_matrix((r0_blocks.shape[0], r0_blocks.shape[1]), dtype=np.float32).transpose()
    nonzeros = r0_blocks.nonzero()
    l0_blocks[(nonzeros[1], nonzeros[0])] = 1
    
    X_block = (
        -r0_blocks[np.ix_(term_block_inds, colnum_r_null_array)]
        *K_sp_sub_reord[np.ix_(term_block_inds, nonterm_block_inds)]
    )
    
    # Solution 6
    # https://stackoverflow.com/questions/1007442/mrdivide-function-in-matlab-what-is-it-doing-and-how-can-i-do-it-in-python
    #TL;DR: A/B = np.linalg.solve(B.conj().T, A.conj().T).conj().T
    # import time
    # import scipy
    # Here we have 3 solutions : scipy sparse, scipy dense, numpy
    # And numpy is faster on the kras example
    # Using sparse solve
    
    # t0 = time.time()
    # X_block = sparse.linalg.spsolve(
    #     K_sp_sub_reord[np.ix_(nonterm_block_inds,nonterm_block_inds)].tocsr().conj().transpose(),
    #     X_block.conj().transpose()
    # ).conj().transpose()
    
    # Using scipy solve
    # t1 = time.time()
    # X_block = scipy.linalg.solve(
    #     K_sp_sub_reord[np.ix_(nonterm_block_inds,nonterm_block_inds)].todense().conj().transpose(),
    #     X_block.todense().conj().transpose()
    # ).conj().transpose()
    
    # Using numpy's solve
    # t2 = time.time()
    X_block = np.linalg.solve(
        K_sp_sub_reord[np.ix_(nonterm_block_inds,nonterm_block_inds)].toarray().conj().transpose(),
        X_block.toarray().conj().transpose()
    ).conj().transpose()
    # print("1 : %.2gs, 2 : %.2gs, 3 : %.2gs" % (t1-t0, t2-t1, time.time()-t2))
    
    l0_blocks[np.ix_(colnum_r_null_array, nonterm_block_inds)] = X_block;

    stat_sol_submatr_blocks = r0_blocks * l0_blocks * x0[submatrix_inds[sorted_vertices_terminal_bottom]]
    
    return stat_sol_submatr_blocks

def fcn_adjug_matrix(A, col_arg):
    
    size_array = np.array(list(range(A.shape[0])))
    size_vect = len(size_array)
    adj_matrix = None
    if A.shape[0] == A.shape[1]:

        if len(col_arg) == 0:
            print("NOT IMPLEMENTED")

            #if a is a double

            #else if a is symbolic

            #end

            #for k in range(size_vect):
            #    for l in range(size_vect):
            #        adj_matrix = truc

            #if a is symbolic, we simplify

            pass

        else:
            import scipy
            
#             adj_matrix
            adj_matrix = []
            for k in size_array:
                adj_matrix.append(
                    pow(-1, k)*
                    scipy.linalg.det(
                        A[np.ix_(
                            range(1,size_vect), 
                            size_array[np.where(size_array != k)[0]]
                        )].todense()
                    )
                )

    else: #non square matrix
        adj_matrix = []
        print("non-square matrix")
        
    
    return adj_matrix

def fcn_left_kernel(K_sp_sub_reord, r0_blocks, dim_matr):
    
    print("Constructing left kernel")

    if len(r0_blocks.shape) > 1:
        dim_kernel = np.sum(np.logical_not(np.isin(np.sum(r0_blocks, axis=1), 0)))
        colnum_r_null_array = range(r0_blocks.shape[1])
        size_r0_blocks = r0_blocks.shape
    else:
        print("Here we have a problem, the r0_blocks is actually 1D. Trying... RESULTS NEED TO BE CHECKED !!!")
    
#         dim_kernel = np.sum(np.logical_not(np.isin(r0_blocks, 0)))
#         colnum_r_null_array = [0]
#         size_r0_blocks = [r0_blocks.shape[0], 1]

    term_block_inds = range(dim_matr -dim_kernel,dim_matr)
    nonterm_block_inds = range(dim_matr-dim_kernel)
    
    l0_blocks = sparse.lil_matrix((size_r0_blocks[0], size_r0_blocks[1])).transpose()
    t_inds = np.where(np.logical_not(np.isin(r0_blocks, 0)).transpose())
    
    l0_blocks[t_inds] = 1
    
    X_block = (
        -l0_blocks[np.ix_(colnum_r_null_array,term_block_inds)]
        *K_sp_sub_reord[np.ix_(term_block_inds, nonterm_block_inds)]
    )
    
    # Using numpy's solve
    X_block = np.linalg.solve(
        K_sp_sub_reord[np.ix_(nonterm_block_inds,nonterm_block_inds)].toarray().conj().transpose(),
        X_block.toarray().conj().transpose()
    ).conj().transpose()
    
    l0_blocks[np.ix_(colnum_r_null_array, nonterm_block_inds)] = X_block
    
    return l0_blocks

def split_calc_inverse(A_sparse, stg_sorting_cell, transition_rates_table, x0):
    (subnetws,scc_submats,nonempty_subgraphs,sorted_vertices,cyclic_sorted_subgraphs) = stg_sorting_cell
    
    # is the STG disconnected?
    stat_sol_blocks=sparse.lil_matrix((x0.shape[0], 1))
    # A_digraph=digraph(A_sparse,'omitselfloops'); 
    num_subnets = len(subnetws)
    # preallocate cell of term vertices and of subgraphs
    term_verts = []
    cell_subgraphs = []

    if num_subnets>1:
        print('STG has multiple subgraphs')

    counter_subgraphs=0
    
    for i in nonempty_subgraphs:
        
        submatrix_inds = np.array(subnetws[i])
        cell_subgraphs.append(submatrix_inds)

        if num_subnets > 1:
            print("Calculating subgraph #%d of %d" % (i+1, num_subnets))
            
        A_sparse_sub = A_sparse[subnetws[i], :][:, subnetws[i]]
        dim_matr = A_sparse_sub.shape[0]
        scc_submat = scc_submats[i]
        
        # IF all SCCs are single vertices (ie. no cycles)
        if len(set([tuple(t_submat) for t_submat in scc_submat])) == dim_matr:
            
            # function to reorder vertices and keep ordering
            terminal_nodes = np.where(A_sparse_sub.diagonal() == 1)
#             print(terminal_nodes)
            # this is a consistent ordering but terminals are not necessarily in lower right corner of matrix
            A_orig_reordered = A_sparse_sub[sorted_vertices[counter_subgraphs], :][:, sorted_vertices[counter_subgraphs]]

            
            # but we want to have terminal states acolnum_r_null_arrayt the bottom
            #print(sorted_vertices[counter_subgraphs])
            # This weird assignment syntax is because it returns a tuple of length one. This is valid, and it works
            terminal_indices, = np.where(np.isin(sorted_vertices[counter_subgraphs], terminal_nodes))
            terminal_rem_inds, = np.where(np.logical_not(np.isin(sorted_vertices[counter_subgraphs], terminal_nodes)))
            t_inds, = np.where(np.logical_not(np.isin(sorted_vertices[counter_subgraphs], terminal_nodes)))
            
            array_sorted_vertices = np.array(sorted_vertices[counter_subgraphs])

            sorted_vertices_terminal_bottom = (
                list(array_sorted_vertices[t_inds]) + list(array_sorted_vertices[terminal_indices])
#                 axis=1
            )
                  
            reordered_terminal_inds = list(terminal_rem_inds) + list(terminal_indices)
            
            A_sparse_sub_reordered_terminal = A_orig_reordered[reordered_terminal_inds, :][:, reordered_terminal_inds]
            
            K_sp_sub_reord = (A_sparse_sub_reordered_terminal.transpose() - sparse.eye(dim_matr)) * sum(transition_rates_table.flatten())

            stat_sol_submatr_blocks = fcn_block_inversion(K_sp_sub_reord, sorted_vertices_terminal_bottom, x0, submatrix_inds)

            stat_sol_blocks[submatrix_inds[sorted_vertices_terminal_bottom]] = stat_sol_submatr_blocks
            term_verts.append(set(stat_sol_blocks.nonzero()[0]).intersection(set(submatrix_inds)))
            
        else:
           
            print('cycles in STG')
            if len(scc_submat) == 1:
#             % if entire graph is one connected component, no reordering needed
                K_sp_sub_reord = (A_sparse_sub.transpose() - sparse.eye(dim_matr, dim_matr)) * sum(transition_rates_table.flatten())
                kernel_col = np.dot(pow(-1, (dim_matr-1)), fcn_adjug_matrix(K_sp_sub_reord, 'col'))
                # normalization
                r0_blocks = (kernel_col.transpose()/np.sum(kernel_col))
                if len(r0_blocks.shape) == 1:
                    r0_blocks = r0_blocks.reshape(r0_blocks.shape[0], 1)
                    
                l0_blocks = fcn_left_kernel(K_sp_sub_reord, r0_blocks, dim_matr)
                
                #stat sol
                stat_sol_submatr_blocks = np.dot(r0_blocks*l0_blocks,x0[submatrix_inds])
                
                stat_sol_blocks[submatrix_inds] = stat_sol_submatr_blocks
                term_verts.append(submatrix_inds)
                
                

            else:
                print("Not a unique connected component")
            
                vert_topol_sort = cyclic_sorted_subgraphs[counter_subgraphs][0]
                term_cycles_ind = cyclic_sorted_subgraphs[counter_subgraphs][1]
                term_cycle_bounds = cyclic_sorted_subgraphs[counter_subgraphs][2]
            
                A_sparse_sub_reordered_terminal = A_sparse_sub[vert_topol_sort,:][:, vert_topol_sort]
                K_sp_sub_reord = (A_sparse_sub_reordered_terminal.transpose() - sparse.eye(dim_matr, dim_matr))*sum(transition_rates_table.flatten())


                # if cycles are non-terminal, stat sol can be calculated by block inversion, sames as for acyclic graphs
                if len(term_cycles_ind) == 0:
                    print("Empty term cycles ind")
                    print(" NEEDS TO BE TESTED")
#                      % here make sure if 'vert_topol_sort' is the right ordering...
                    stat_sol_submatr_blocks = fcn_block_inversion(K_sp_sub_reord, vert_topol_sort, x0, submatrix_inds)
                    stat_sol_blocks[submatrix_inds[vert_topol_sort]] = stat_sol_submatr_blocks
                    term_verts_cell.append(submatrix_inds[vert_topol_sort[np.where(K_sp_sub_reord.diagonal() == 0)]])

                else:
                    print("Non empty term cycles ind")
                    
                    # if there are terminal cycles, stat sol calc a bit more complicated
                    # need to identify terminal cycles, for corresponding columns of
                    # kernel we'll need to calculate adjugate matrix

                    # probably we don't want it in symbolic form, but just in case
                    if K_sp_sub_reord.dtype == np.float64:
                        r_null_cycles = sparse.lil_matrix((dim_matr, len(term_cycle_bounds)))
                    
                    else:
                        print("NOT IMPLEMENTED")
                        return

                    for k, term_cycle_bound in enumerate(term_cycle_bounds):
                        cycle_inds = range(term_cycle_bound[0], term_cycle_bound[-1]+1)
                 
                        #calc kernel of scc
                        scc_cycle = K_sp_sub_reord[cycle_inds, :][:, cycle_inds]
                        # adjugate_matrix -> kernel
                        n = len(cycle_inds)
                        
                        kernel_col = np.dot(pow(-1, n-1) , fcn_adjug_matrix(scc_cycle, 'col'))
                        
                        r_null_cycles[cycle_inds,k] = kernel_col/sum(kernel_col)
                        
                    # if there are single-vertex terminal states too
                    if np.sum(np.isin(K_sp_sub_reord.diagonal(), 0)) > 0:
                        
                        print("single vertex terminal states")
                        print(" NOT IMPLEMENTED")
#                          n_terminal=find(ismember(diag(K_sp_sub_reord),0))'; 
#                         r_null_single_vert = sparse(dim_matr,numel(n_terminal)); 
#                         % (1:numel(n_terminal)-1)*2 + n_terminal
#                         r_null_single_vert( sub2ind(size(r_null_single_vert), n_terminal, 1:numel(n_terminal)) )=1;
#                         % does the order of columns in the kernel matter? I think not, if l0_blocks consistent w r0_blocks
#                         r0_blocks=[r_null_cycles r_null_single_vert];
                        return
                    else:
                        print("no single vertex terminal states")
                        r0_blocks = r_null_cycles
                        
                    # calculate kernel
                    l0_blocks = fcn_left_kernel(K_sp_sub_reord, r0_blocks, dim_matr)
                    
                    # stat sol
                    stat_sol_submatr_blocks = r0_blocks*l0_blocks*x0[submatrix_inds[vert_topol_sort]]
                    stat_sol_blocks[submatrix_inds[vert_topol_sort]] = stat_sol_submatr_blocks
                    row, col = r0_blocks.nonzero()
                    
                    pre_term_verts = []
                    
                    for k in range(len(set(col))):
                        pre_term_verts.append(
                            submatrix_inds[vert_topol_sort[row[np.where(col == k)]]]
                        )
                    term_verts.append(pre_term_verts)

        counter_subgraphs +=1
    return stat_sol_blocks, term_verts, cell_subgraphs

In [29]:
%time stat_sol,term_verts_cell,cell_subgraphs=split_calc_inverse(A_sparse,stg_sorting_cell,transition_rates_table,x0)

NameError: name 'stg_sorting_cell' is not defined

In [30]:
print(stat_sol)

NameError: name 'stat_sol' is not defined

In [None]:
stat_sol.sum()