In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import networkx as nx
import numpy as np
import boolean
import scipy.sparse as sparse
import sys


### Building STG table

In [4]:
def readBNet(filename):
    """Read a ``.bnet`` file into a dict mapping node name -> simplified boolean formula.

    Each non-empty line is expected to look like ``<node>, <formula>``. The
    formula is parsed with ``boolean.BooleanAlgebra().parse`` and simplified.

    Skipped lines:
      * blank lines,
      * comment lines starting with ``#``,
      * the header line (``targets, factors`` or ``target, factors``).

    Parameters
    ----------
    filename : str
        Path of the .bnet file to read.

    Returns
    -------
    dict
        ``{node_name: parsed_and_simplified_formula}`` in file order.

    Raises
    ------
    ValueError
        If a non-empty, non-comment line contains no comma separator.
    """
    bnet_model = {}
    algebra = boolean.BooleanAlgebra()
    with open(filename, 'r') as f:
        for line_number, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            if "," not in line:
                raise ValueError(
                    "%s:%d: expected '<node>, <formula>', got %r" % (filename, line_number, line)
                )

            # Split on the first comma only: everything after it is the formula.
            species, formula = [value.strip() for value in line.split(",", 1)]

            # Only the real header is skipped; a node that merely happens to be
            # called "target" must still be read.
            if species in ("target", "targets") and formula == "factors":
                continue

            bnet_model[species] = algebra.parse(formula).simplify()

    return bnet_model

##### First, loading the model as a minibn

In [5]:
# Names of the available models; each one is loaded from "model_files/<name>.bnet".
# The order must match the per-model lists of initial conditions defined further down.
model_name_list = [
    'mammalian_cc', 
    'krasmodel15vars',
    'breast_cancer_zanudo2017',
    'EMT_cohen_ModNet',
    'sahin_breast_cancer_refined',
    'toy'
]
# Index into model_name_list of the model analysed in this notebook (1 -> 'krasmodel15vars').
model_index = 1

In [6]:
model = readBNet(("model_files/%s.bnet" % model_name_list[model_index]))

In [7]:
nodes = list(model.keys())

In [8]:
n=len(nodes)
n

15

##### Then, actually building the table

In [22]:
from exastolog.StateTransitionTable import StateTransitionTable
stg_table = StateTransitionTable(model, nodes).stg_table
stg_table.shape

(188416, 4)

In [23]:
print("%.2fMbytes" % (stg_table.shape[0] * stg_table.shape[1] * stg_table.itemsize / (1024*1024)))

5.75Mbytes


##### Building transition rates table

In [24]:
# Individual transition rates can be given a value different from the one assigned by the
# chosen distribution. Rates are named 'u_<nodename>' (up rate) or 'd_<nodename>' (down rate), e.g.:
# chosen_rates=['u_ERBB1','u_ERBB2','u_ERBB3']; chosen_rates_vals=np.zeros(len(chosen_rates));
# OR leave both lists empty to override nothing:
chosen_rates = []
chosen_rates_vals = []

In [25]:
# Then we generate the table of transition rates: the first row holds the 'up' rates and the
# second row the 'down' rates, with columns in the order of 'nodes'.
# The arguments of the table are defined in the next cell.

In [26]:
# Available rate distributions: 'uniform' assigns a value of 1 to all rates, 'random' draws them at random.
distr_type = ['uniform', 'random'] # <uniform> assigns a value of 1 to all params. other option: <random>
# Only used when 'random' is chosen: mean and standard deviation of the normal distribution.
meanval = 0.5 # if 'random' is chosen, the mean and standard dev of a normal distrib has to be defined
sd_val = 0.1

In [45]:
from exastolog.TransRateTable import TransRateTable
transition_rates_table = TransRateTable(nodes, distr_type[0], meanval, sd_val, chosen_rates, chosen_rates_vals).table

(2, 15)
(2, 15)


##### Building the (sparse) transition matrix

In [29]:
def fcn_build_trans_matr(stg_table, transition_rates_table, kin_matr_flag=""):
    """Build the sparse transition matrix of the state transition graph (STG).

    Parameters
    ----------
    stg_table : integer ndarray, shape (n_transitions, 4)
        One transition per row. Columns 0 and 1 are used as the row and column
        index of the matrix entry. Columns 2 and 3 select the rate: entry
        ``2 * col2 + col3`` of the column-major flattening of
        ``transition_rates_table`` (presumably col2 is the node index and col3
        is 0 for an 'up' and 1 for a 'down' transition -- confirm against
        StateTransitionTable).
    transition_rates_table : ndarray, shape (2, n_nodes)
        Transition rates; the number of columns fixes the matrix dimension
        ``2 ** n_nodes``.
    kin_matr_flag : str, optional
        If non-empty, the kinetic matrix is computed as well.

    Returns
    -------
    A_sparse : scipy sparse matrix, shape (2**n_nodes, 2**n_nodes)
        Off-diagonal entries are the rates divided by the sum of all rates;
        the diagonal is then adjusted so that each row sums to 1.
    K_sparse : scipy sparse matrix or list
        ``(A_sparse.T - I) * sum(rates)`` when ``kin_matr_flag`` is non-empty,
        otherwise an empty list.
    """
    dim_matr = 2 ** transition_rates_table.shape[1]
    total_rate = np.sum(transition_rates_table)

    rate_inds = stg_table[:, 2] * 2 + stg_table[:, 3]

    # Column-major flattening, so that index 2*j + r addresses row r of column j.
    # (np.product, used previously to build the reshape target, was removed in NumPy 2.0.)
    flat_rates = np.asarray(transition_rates_table).ravel(order="F")

    B = sparse.csr_matrix(
        (
            flat_rates[rate_inds] / total_rate,
            (stg_table[:, 0], stg_table[:, 1])
        ),
        shape=(dim_matr, dim_matr)
    )

    # Put on the diagonal whatever is needed for every row to sum to one.
    row_sums = np.asarray(B.sum(axis=1)).ravel()
    A_sparse = B + (sparse.eye(dim_matr) - sparse.diags(row_sums))

    if len(kin_matr_flag) > 0:
        K_sparse = (A_sparse.transpose() - sparse.eye(dim_matr)) * total_rate
    else:
        K_sparse = []

    return A_sparse, K_sparse

In [30]:
# Only the transition matrix is needed here; the kinetic matrix is rebuilt per
# subgraph later on, so it is not requested (no kin_matr_flag).
A_sparse, _ = fcn_build_trans_matr(stg_table, transition_rates_table)

In [31]:
print("%.2f Mbytes" % (A_sparse.data.nbytes/(1024*1024)))

1.69 Mbytes


##### Setting up initial values

In [32]:
# For each model (same order as model_name_list): the nodes whose initial value is fixed.
initial_fixed_nodes_list = [
    ['CycE','CycA','CycB','Cdh1','Rb_b1','Rb_b2','p27_b1','p27_b2'], # mammalian_cc
    ['cc','KRAS','DSB','cell_death'], #krasmodel15vars
    ['Alpelisib', 'Everolimus','PIM','Proliferation','Apoptosis'], # breast_cancer_zanudo2017
    ['ECMicroenv','DNAdamage','Metastasis','Migration','Invasion','EMT','Apoptosis','Notch_pthw','p53'], # EMT_cohen_ModNet 
    ['EGF','ERBB1','ERBB2','ERBB3','p21','p27'], # sahin_breast_cancer_refined
    ['A','C','D'] # toy model
] 

# For each model: the fixed values (0/1), one per node of the matching entry above.
initial_fixed_nodes_vals_list = [
    [0, 0, 0, 1, 1, 1, 1, 1], # mammalian_cc
    [1, 1, 1, 0], # krasmodel15vars: [1 1] is cell cycle ON, KRAS mutation ON
    [0, 1, 0] + [0]*2, # breast_cancer_zanudo2017
    [1, 1] + [0]*5 + [1, 0], # EMT-Cohen model: [0/1 0/1 zeros(1,5)]
    [1, 0, 0, 0, 1, 1], # sahin_breast_cancer_refined
    [0, 0, 0] # toy model
] 

# Select the entries of the model chosen through model_index.
initial_fixed_nodes = initial_fixed_nodes_list[model_index]
initial_fixed_nodes_vals = initial_fixed_nodes_vals_list[model_index]


# Total probability assigned to the states that match the fixed values
# (e.g. dom_prob=0.8 means 80%); the remainder is spread over all other states.
dom_prob = 1
# Intended to request a bar plot of the initial values when non-empty.
# NOTE: the plotting code in fcn_define_initial_states is commented out, so this flag currently has no effect.
plot_flag = ''

In [33]:
def fcn_define_initial_states(initial_fixed_nodes,initial_fixed_nodes_vals,dom_prob,nodes,distrib_type,plot_flag):
    """Build the initial probability vector over all 2**len(nodes) states.

    State ``i`` assigns to ``nodes[j]`` the value of bit ``j`` of ``i``
    (``nodes[0]`` is the least significant bit). The states in which every
    node of ``initial_fixed_nodes`` has its value from
    ``initial_fixed_nodes_vals`` share the probability ``dom_prob``; all other
    states share ``1 - dom_prob``.

    Each value is paired with its node by position in ``initial_fixed_nodes``,
    so that list may be in any order relative to ``nodes``.

    Parameters
    ----------
    initial_fixed_nodes : list of str
        Names of the nodes whose initial value is fixed.
    initial_fixed_nodes_vals : list of int
        Fixed value (0 or 1) of each node in ``initial_fixed_nodes``.
    dom_prob : float
        Total probability of the states compatible with the fixed values.
    nodes : list of str
        All nodes of the model.
    distrib_type : str
        'uniform' spreads the probability evenly within each group of states;
        'random' draws the weights from a uniform distribution on [0, 1)
        (``np.random``; seed it beforehand for reproducibility).
    plot_flag : str
        Currently unused (plotting is not implemented); kept so that existing
        calls keep working.

    Returns
    -------
    ndarray, shape (2**len(nodes), 1)
        The initial probability of every state.

    Raises
    ------
    ValueError
        If ``distrib_type`` is neither 'uniform' nor 'random', if the two
        ``initial_fixed_*`` lists differ in length, or if a fixed node is not
        in ``nodes``.
    """
    if distrib_type not in ("uniform", "random"):
        raise ValueError("distrib type should be 'uniform' or 'random'")
    if len(initial_fixed_nodes) != len(initial_fixed_nodes_vals):
        raise ValueError("initial_fixed_nodes and initial_fixed_nodes_vals must have the same length")

    node_position = {node: position for position, node in enumerate(nodes)}
    unknown_nodes = [node for node in initial_fixed_nodes if node not in node_position]
    if unknown_nodes:
        raise ValueError("initial fixed nodes not found in nodes: %s" % unknown_nodes)

    n_states = 2 ** len(nodes)
    state_ids = np.arange(n_states)

    # A state is selected when each fixed node's bit equals the requested value.
    inds_condition = np.ones(n_states, dtype=bool)
    for node, value in zip(initial_fixed_nodes, initial_fixed_nodes_vals):
        node_bit = (state_ids >> node_position[node]) & 1
        inds_condition &= (node_bit == int(value))

    inds_rest = np.logical_not(inds_condition)
    n_selected = int(np.count_nonzero(inds_condition))
    n_rest = n_states - n_selected

    x0 = np.zeros((n_states, 1))

    if distrib_type == "uniform":
        if n_selected > 0:
            x0[inds_condition] = dom_prob / n_selected
        if n_rest > 0:
            x0[inds_rest] = (1 - dom_prob) / n_rest

    else:  # "random"
        if n_selected > 0:
            x0[inds_condition] = np.random.uniform(0, 1, (n_selected, 1))
            x0 = dom_prob * x0 / x0.sum()
        if n_rest > 0:
            rest_weights = np.random.uniform(0, 1, (n_rest, 1))
            x0[inds_rest] = (1 - dom_prob) * rest_weights / rest_weights.sum()

    # Sanity check: the probabilities should add up to one (up to rounding).
    n_prec = 3
    if round(float(x0.sum()), n_prec) == 1:
        print('sum(x0)=1, OK.')

    else:
        print('sum(x0)~=1, something wrong!')

    return x0

In [34]:
x0 = fcn_define_initial_states(initial_fixed_nodes, initial_fixed_nodes_vals, dom_prob, nodes, "uniform", plot_flag)

sum(x0)=1, OK.


In [35]:
x0.shape

(32768, 1)

In [36]:
import networkx as nx

def fcn_metagraph_scc(A_sparse_sub):
    """Placeholder for the ordering of a subgraph that contains cycles.

    NOT IMPLEMENTED YET: the directed graph of ``A_sparse_sub`` is built and
    its self-loops are removed, but no ordering is computed.

    Parameters
    ----------
    A_sparse_sub : scipy sparse matrix
        Transition matrix restricted to one subgraph.

    Returns
    -------
    tuple
        Five ``None`` placeholders; the caller unpacks them as
        (vert_topol_sort, term_cycles_ind, <unused>, scc_cell, term_cycle_bounds).
    """
    # from_scipy_sparse_matrix was removed in NetworkX 3.0; prefer its
    # replacement when available and fall back for older versions.
    from_sparse = getattr(nx, "from_scipy_sparse_array", None) or nx.from_scipy_sparse_matrix

    g_sub = from_sparse(A_sparse_sub, create_using=nx.DiGraph())
    # Materialise the self-loops first so the graph is not modified while iterating.
    g_sub.remove_edges_from(list(nx.selfloop_edges(g_sub)))

    return None, None, None, None, None


def fcn_scc_subgraphs(A_sparse, x0):
    """Split the STG into disconnected subgraphs and order the populated ones.

    Parameters
    ----------
    A_sparse : scipy sparse matrix (row-indexable, e.g. CSR)
        Transition matrix of the whole state transition graph.
    x0 : ndarray, shape (n_states, 1)
        Initial probability of every state.

    Returns
    -------
    tuple ``(subnetws, scc_submats, nonempty_subgraphs, sorted_vertices, cyclic_sorted_subgraphs)``
        subnetws : list of list of int
            Vertices of each weakly connected component (global indices).
        scc_submats : list of list of list of int
            For each component, its strongly connected components, expressed
            in indices local to the component.
        nonempty_subgraphs : list of int
            Indices of the components with non-zero initial probability.
        sorted_vertices : list
            One entry per item of ``nonempty_subgraphs`` (same order):
            a topological ordering (local indices) when the component has no
            cycle; the list of its SCCs when the whole component is a single
            SCC; ``None`` when the component mixes cycles and other states
            (not implemented yet).
        cyclic_sorted_subgraphs : list
            Currently always empty.
    """
    # from_scipy_sparse_matrix was removed in NetworkX 3.0; prefer its
    # replacement when available and fall back for older versions.
    from_sparse = getattr(nx, "from_scipy_sparse_array", None) or nx.from_scipy_sparse_matrix

    def _digraph_without_selfloops(matrix):
        # Directed graph of `matrix`, self-loops removed (listed first so the
        # graph is not modified while iterating over it).
        graph = from_sparse(matrix, create_using=nx.DiGraph())
        graph.remove_edges_from(list(nx.selfloop_edges(graph)))
        return graph

    print("Indentifying SCCs")
    G = _digraph_without_selfloops(A_sparse)
    subnetws = [list(component) for component in nx.weakly_connected_components(G)]

    scc_submats = []
    nonempty_subgraphs = []
    # Graphs of the populated components, kept so that they are not rebuilt below.
    populated_graphs = {}

    print("Identifying SCCs in subgraphs")
    for i, subnet in enumerate(subnetws):
        # Slice rows first (cheapest for CSR), then columns.
        sub_graph = _digraph_without_selfloops(A_sparse[subnet, :][:, subnet])

        scc_submats.append([list(scc) for scc in nx.strongly_connected_components(sub_graph)])

        print(len(scc_submats[i]))
        if np.sum(x0[subnet]) > 0:
            nonempty_subgraphs.append(i)
            populated_graphs[i] = sub_graph

    sorted_vertices = []
    cyclic_sorted_subgraphs = []
    for i in nonempty_subgraphs:
        n_sccs = len(scc_submats[i])

        if n_sccs == len(subnetws[i]):
            # Every SCC is a single vertex: the component is acyclic.
            sorted_vertices.append(list(nx.topological_sort(populated_graphs[i])))
        else:
            print("Cycles in STG")

            # If the entire component is a single SCC, no re-ordering is needed
            if n_sccs == 1:
                sorted_vertices.append(scc_submats[i])
            else:
                print("NOT IMPLEMENTED YET")
                ## THIS IS NOT IMPLEMENTED YET, FOCUSING ON FINISHING THE FIRST EXAMPLE
                fcn_metagraph_scc(A_sparse[subnetws[i], :][:, subnetws[i]])
                # Placeholder keeping sorted_vertices aligned with nonempty_subgraphs.
                sorted_vertices.append(None)

    return (subnetws,scc_submats,nonempty_subgraphs,sorted_vertices,cyclic_sorted_subgraphs)
                   
                   

In [37]:
%time stg_sorting_cell = fcn_scc_subgraphs(A_sparse, x0)

Indentifying SCCs
Identifying SCCs in subgraphs
8192
8192
8192
8192
CPU times: user 6.15 s, sys: 68.8 ms, total: 6.22 s
Wall time: 6.21 s


In [38]:
def fcn_block_inversion(K_sp_sub_reord, sorted_vertices_terminal_bottom, x0, submatrix_inds):
    """Stationary solution of an acyclic subgraph whose terminal states come last.

    With the kinetic matrix ordered as non-terminal states first and terminal
    states (zero diagonal) last,

        K = [[K_nn, 0],
             [K_tn, 0]],

    the right kernel is ``R0 = [0; I]`` and the left kernel is
    ``L0 = [X, I]`` with ``X @ K_nn = -K_tn``. The stationary solution
    ``R0 @ L0 @ x0`` is therefore zero on the non-terminal states and
    ``x0_t - K_tn @ K_nn^{-1} @ x0_n`` on the terminal ones. It is computed
    directly in double precision, with a single linear solve.

    Parameters
    ----------
    K_sp_sub_reord : scipy sparse matrix, shape (n, n)
        Kinetic matrix of the subgraph, reordered as described above.
    sorted_vertices_terminal_bottom : sequence of int
        Vertices of the subgraph (indices local to it) in the order of the
        rows of ``K_sp_sub_reord``.
    x0 : ndarray, shape (n_states, 1)
        Initial probabilities of all states of the full STG.
    submatrix_inds : ndarray of int
        Global state index of each local vertex of the subgraph.

    Returns
    -------
    ndarray, shape (n, 1)
        Stationary probabilities, in the order of ``sorted_vertices_terminal_bottom``.
    """
    dim_matr = K_sp_sub_reord.shape[0]
    # Terminal states have no outgoing transition, hence a zero diagonal entry.
    dim_kernel = int(np.count_nonzero(K_sp_sub_reord.diagonal() == 0))
    n_nonterm = dim_matr - dim_kernel

    x0_sub = np.asarray(x0[submatrix_inds[sorted_vertices_terminal_bottom]], dtype=np.float64)

    # Non-terminal states end up empty; terminal states keep their initial mass ...
    stat_sol_submatr_blocks = np.zeros_like(x0_sub)
    stat_sol_submatr_blocks[n_nonterm:] = x0_sub[n_nonterm:]

    # ... and receive the mass flowing in from the non-terminal states.
    if n_nonterm > 0 and dim_kernel > 0:
        # Dense solve: found to be faster than the sparse solver on the kras example.
        K_nonterm = K_sp_sub_reord[:n_nonterm, :n_nonterm].toarray()
        K_term_nonterm = K_sp_sub_reord[n_nonterm:, :n_nonterm]
        flow = np.linalg.solve(K_nonterm, x0_sub[:n_nonterm])
        stat_sol_submatr_blocks[n_nonterm:] -= K_term_nonterm @ flow

    return stat_sol_submatr_blocks



def split_calc_inverse(A_sparse, stg_sorting_cell, transition_rates_table, x0):
    """Compute the stationary solution subgraph by subgraph.

    Only the subgraphs carrying initial probability are processed, and only
    acyclic ones are supported so far (for the others a message is printed
    and their part of the solution is left at zero).

    Parameters
    ----------
    A_sparse : scipy sparse matrix (row-indexable, e.g. CSR)
        Transition matrix of the whole STG.
    stg_sorting_cell : tuple
        Output of ``fcn_scc_subgraphs``:
        (subnetws, scc_submats, nonempty_subgraphs, sorted_vertices, cyclic_sorted_subgraphs).
    transition_rates_table : ndarray
        Transition rates; only their total is used here.
    x0 : ndarray, shape (n_states, 1)
        Initial probability of every state.

    Returns
    -------
    stat_sol_blocks : scipy.sparse.lil_matrix, shape (n_states, 1)
        Stationary probability of every state.
    term_verts : list of set
        For each processed acyclic subgraph, the states of that subgraph with
        non-zero stationary probability.
    cell_subgraphs : list of ndarray
        Global state indices of each populated subgraph.
    """
    (subnetws,scc_submats,nonempty_subgraphs,sorted_vertices,cyclic_sorted_subgraphs) = stg_sorting_cell

    stat_sol_blocks = sparse.lil_matrix((x0.shape[0], 1))
    num_subnets = len(subnetws)
    total_rate = np.sum(transition_rates_table)
    term_verts = []
    cell_subgraphs = []

    if num_subnets > 1:
        print('STG has multiple subgraphs')

    # sorted_vertices is indexed by position within nonempty_subgraphs.
    for counter_subgraphs, i in enumerate(nonempty_subgraphs):

        submatrix_inds = np.array(subnetws[i])
        cell_subgraphs.append(submatrix_inds)

        if num_subnets > 1:
            print("Calculating subgraph #%d of %d" % (i+1, num_subnets))

        A_sparse_sub = A_sparse[subnetws[i], :][:, subnetws[i]]
        dim_matr = A_sparse_sub.shape[0]

        # IF all SCCs are single vertices (ie. no cycles)
        if len({tuple(scc) for scc in scc_submats[i]}) == dim_matr:

            # Topological order of the subgraph (local indices); terminal
            # states are not necessarily at the end of it ...
            topological_order = np.array(sorted_vertices[counter_subgraphs])
            # Terminal states are those with a diagonal entry of 1 in the transition matrix.
            is_terminal = A_sparse_sub.diagonal()[topological_order] == 1

            # ... so move them to the bottom, keeping the relative order otherwise.
            sorted_vertices_terminal_bottom = (
                list(topological_order[np.logical_not(is_terminal)])
                + list(topological_order[is_terminal])
            )

            A_sparse_sub_reordered_terminal = (
                A_sparse_sub[sorted_vertices_terminal_bottom, :][:, sorted_vertices_terminal_bottom]
            )

            K_sp_sub_reord = (A_sparse_sub_reordered_terminal.transpose() - sparse.eye(dim_matr)) * total_rate

            stat_sol_submatr_blocks = fcn_block_inversion(K_sp_sub_reord, sorted_vertices_terminal_bottom, x0, submatrix_inds)

            stat_sol_blocks[submatrix_inds[sorted_vertices_terminal_bottom]] = stat_sol_submatr_blocks
            term_verts.append(set(stat_sol_blocks.nonzero()[0]).intersection(set(submatrix_inds)))

        else:
            # Subgraphs containing cycles are not supported yet
            print("Not implemented yet !")

    return stat_sol_blocks, term_verts, cell_subgraphs

In [39]:
%time stat_sol,term_verts_cell,cell_subgraphs=split_calc_inverse(A_sparse,stg_sorting_cell,transition_rates_table,x0)


STG has multiple subgraphs
Calculating subgraph #4 of 4
CPU times: user 14.5 s, sys: 545 ms, total: 15 s
Wall time: 3.96 s


In [40]:
print(stat_sol)

  (35, 0)	0.16982301660209487
  (67, 0)	0.01069429074391337
  (291, 0)	0.25817666968396225
  (323, 0)	0.01918748778666668
  (16159, 0)	0.3131778050428693
  (16387, 0)	0.09025865387229715
  (16643, 0)	0.1386820789120975


In [41]:
stat_sol.sum()

1.000000002643901