In [1]:
%load_ext autoreload
%autoreload 2

In [36]:
import sys

import boolean
import numpy as np
import scipy as sp
import scipy.sparse

### Building STG table

In [3]:
def readBNet(filename):
    """Read a .bnet file into a dict of simplified boolean expressions.

    Every rule line is expected to look like ``<node>, <formula>``. Blank
    lines and lines starting with ``#`` are ignored. A line whose node is
    ``target`` or whose formula is ``factors`` (the column-header line) is
    skipped.

    Parameters
    ----------
    filename : str
        Path of the .bnet file to read.

    Returns
    -------
    dict
        Maps each node name to its parsed and simplified ``boolean``
        expression, in the order the rules appear in the file.

    Raises
    ------
    ValueError
        If a non-blank, non-comment line contains no comma.
    """
    bnet_model = {}
    algebra = boolean.BooleanAlgebra()
    with open(filename, 'r') as f:
        for line_number, line in enumerate(f, start=1):
            stripped = line.strip()
            # Blank lines and comments carry no rule; unpacking them would fail.
            if not stripped or stripped.startswith("#"):
                continue
            if "," not in stripped:
                raise ValueError(
                    "%s, line %d: expected '<node>, <formula>', got %r"
                    % (filename, line_number, stripped)
                )
            # Split on the first comma only: node name on the left, formula on the right.
            species, formula = [value.strip() for value in stripped.split(",", 1)]
            if species != "target" and formula != "factors":
                bnet_model[species] = algebra.parse(formula).simplify()

    return bnet_model

##### First, loading the model as a minibn

In [4]:
# Names of the available models; the selected one is read from "model_files/<name>.bnet".
model_name_list = [
    'mammalian_cc', 
    'krasmodel15vars',
    'breast_cancer_zanudo2017',
    'EMT_cohen_ModNet',
    'sahin_breast_cancer_refined'
]
# Position in model_name_list of the model used in the rest of the notebook
# (1 selects 'krasmodel15vars').
model_index = 1

In [5]:
# Parse the selected .bnet file into a dict: node name -> simplified boolean expression.
model = readBNet(("model_files/%s.bnet" % model_name_list[model_index]))

In [6]:
# Node names in the order they were inserted into the model dict.
# The position of a node in this list is the column used for it in the state tables below.
nodes = list(model.keys())

In [7]:
# Number of nodes in the model; the state space has 2**n states.
n=len(nodes)
n

15

##### Then, actually building the table

In [8]:
def fcn_gen_node_update(formula, list_binary_states, nodes):
    """Evaluate one boolean expression on every state at once.

    Parameters
    ----------
    formula : boolean expression
        A ``boolean`` (boolean.py) expression: a Symbol, NOT, OR, AND, or one
        of the constants TRUE / FALSE.
    list_binary_states : numpy.ndarray
        2-D boolean array with one row per state; column ``k`` holds the value
        of ``nodes[k]`` in that state.
    nodes : list of str
        Node names; a Symbol is resolved to its column via ``nodes.index``.

    Returns
    -------
    numpy.ndarray
        1-D boolean array with the value of ``formula`` in each state.

    Raises
    ------
    TypeError
        If ``formula`` is of an unsupported expression type. (Previously a
        message was printed and None returned, which silently corrupted the
        result of the enclosing expression.)
    ValueError
        If a Symbol does not name an entry of ``nodes``.
    """
    if isinstance(formula, boolean.boolean.Symbol):
        return list_binary_states[:, nodes.index(str(formula))]

    if isinstance(formula, boolean.boolean.NOT):
        return np.logical_not(
            fcn_gen_node_update(formula.args[0], list_binary_states, nodes)
        )

    if isinstance(formula, (boolean.boolean.OR, boolean.boolean.AND)):
        if isinstance(formula, boolean.boolean.OR):
            combine = np.logical_or
        else:
            combine = np.logical_and
        # Fold the n-ary operator over its arguments, left to right.
        ret = fcn_gen_node_update(formula.args[0], list_binary_states, nodes)
        for arg in formula.args[1:]:
            ret = combine(ret, fcn_gen_node_update(arg, list_binary_states, nodes))
        return ret

    if isinstance(formula, boolean.boolean.BaseElement):
        # Constant TRUE / FALSE (e.g. a rule written as "1"/"0", or one that
        # simplifies to a constant): the same value in every state.
        return np.full(list_binary_states.shape[0], bool(formula))

    raise TypeError("Unknown boolean operator : %s" % type(formula))

In [9]:
def fcn_build_update_table(model, list_binary_states, nodes):
    """Evaluate the update rule of every node on every state.

    ``model[node]`` is evaluated with ``fcn_gen_node_update`` for each entry of
    ``nodes``; the per-node results are stacked and transposed, so that the
    entry at (state, k) is the value the rule of ``nodes[k]`` takes in that
    state.
    """
    per_node_updates = []
    for node in nodes:
        per_node_updates.append(
            fcn_gen_node_update(model[node], list_binary_states, nodes)
        )

    return np.array(per_node_updates).T

In [10]:
def fcn_states_inds(yes_no, n_series_exp, n_isl_exp):
    """Return the indices of the states in which one bit has a given value.

    States are numbered 0 .. 2**n_series_exp - 1. For ``yes_no`` equal to 0 or
    1 the result lists, in ascending order, every state index whose bit
    ``n_isl_exp`` (bit 0 = least significant) equals ``yes_no``.

    Parameters
    ----------
    yes_no : int
        Requested bit value (0 or 1).
    n_series_exp : int
        Number of bits (nodes) in a state.
    n_isl_exp : int
        Position of the bit of interest, ``0 <= n_isl_exp < n_series_exp``.

    Returns
    -------
    numpy.ndarray
        Integer array of shape ``(1, 2**(n_series_exp - 1))``.

    Raises
    ------
    ValueError
        If ``n_isl_exp`` is outside ``[0, n_series_exp)``.
    """
    n_bits = int(n_series_exp)
    bit = int(n_isl_exp)
    if not 0 <= bit < n_bits:
        raise ValueError(
            "n_isl_exp must satisfy 0 <= n_isl_exp < n_series_exp, got %r and %r"
            % (n_isl_exp, n_series_exp)
        )

    block_len = 2 ** bit                    # run length of equal values of the bit
    n_blocks = 2 ** (n_bits - 1 - bit)      # number of such runs among the selected states

    # The j-th selected state lies in run number j // block_len; each run is
    # shifted by (run number + yes_no) * block_len relative to j.
    run_ids = np.arange(n_blocks) + yes_no
    offsets = np.repeat(run_ids, block_len) * block_len
    base = np.arange(2 ** (n_bits - 1))

    # Plain broadcasting addition: np.sum over a list of differently shaped
    # arrays builds a ragged array, which recent NumPy versions reject.
    return (base + offsets)[np.newaxis, :]

In [11]:
def fcn_build_stg_table(model, nodes):
    """Build the table of all single-node transitions of the model.

    A state is numbered by the integer whose bit ``k`` is the value of
    ``nodes[k]``. For every state and every node whose update rule disagrees
    with the node's current value, one transition that flips this node is
    emitted.

    Parameters
    ----------
    model : dict
        Maps node name -> boolean expression (as returned by ``readBNet``).
    nodes : list of str
        Node names; their order defines the bit position of each node.

    Returns
    -------
    numpy.ndarray
        Integer array with one row per transition and the columns
        ``[source state, target state, node index, flag]``; ``flag`` is 1 for
        a 'down' (1 -> 0) flip and 0 for an 'up' (0 -> 1) flip. All 'down'
        transitions come first (grouped by node index), then all 'up'
        transitions.
    """
    # Use the size of the given node list, not a notebook-level global.
    n = len(nodes)

    # Row i holds the bits of i, least significant bit in column 0.
    state_ids = np.arange(2 ** n)
    list_binary_states = ((state_ids[:, np.newaxis] >> np.arange(n)) & 1).astype(bool)

    update_table = np.asarray(
        fcn_build_update_table(model, list_binary_states, nodes)
    ).astype(bool)

    down_blocks = []
    up_blocks = []
    for x in range(n):
        is_on = list_binary_states[:, x]
        target_on = update_table[:, x]

        # Node is ON but its rule evaluates to OFF: flipping clears bit x.
        down_source = np.flatnonzero(is_on & ~target_on)
        down_blocks.append(
            np.column_stack((
                down_source,
                down_source - 2 ** x,
                np.full_like(down_source, x),
                np.ones_like(down_source),
            ))
        )

        # Node is OFF but its rule evaluates to ON: flipping sets bit x.
        up_source = np.flatnonzero(~is_on & target_on)
        up_blocks.append(
            np.column_stack((
                up_source,
                up_source + 2 ** x,
                np.full_like(up_source, x),
                np.zeros_like(up_source),
            ))
        )

    stg_table = np.concatenate(down_blocks + up_blocks, axis=0)

    return stg_table

In [12]:
# Each row of stg_table is [source state, target state, node index, flag],
# where flag is 1 for a 'down' (1 -> 0) flip and 0 for an 'up' (0 -> 1) flip.
stg_table = fcn_build_stg_table(model, nodes)
stg_table.shape

(188416, 4)

In [13]:
# Memory taken by the elements of the (dense) STG table: rows * columns * bytes per element, in units of 1024*1024 bytes.
print("%.2fMbytes" % (stg_table.shape[0] * stg_table.shape[1] * stg_table.itemsize / (1024*1024)))

5.75Mbytes


##### Building transition rates table

In [24]:
# Individual transition rates can be overridden here; every rate that is not listed keeps the value
# produced by the selected distribution ('uniform' or 'random').
# Rate names have the form 'u_nodename' (up rate) or 'd_nodename' (down rate).
# Example: chosen_rates=['u_ERBB1','u_ERBB2','u_ERBB3']; chosen_rates_vals=np.zeros(len(chosen_rates));
# To override nothing, leave both lists empty:
chosen_rates = []
chosen_rates_vals = []

# Another example:
# chosen_rates = ["u_CHEK1", "d_CHEK1"]
# chosen_rates_vals = np.ones(len(chosen_rates))

In [25]:
# Next we generate the table of transition rates: the first row holds the 'up' rates and the second row
# the 'down' rates, both in the order of 'nodes'.
# ARGUMENTS (defined in the next cell)

In [26]:
distr_type = ['uniform', 'random'] # available options: <uniform> assigns a value of 1 to all rates, <random> draws them from a normal distribution
meanval = 0.5 # mean of the normal distribution (only used when 'random' is chosen)
sd_val = 0.1 # standard deviation of the normal distribution (only used when 'random' is chosen)

In [27]:
def fcn_trans_rates_table(nodes, uniform_or_rand, meanval, sd_val, chosen_rates, chosen_rates_vals):
    """Build the table of 'up' and 'down' transition rates of all nodes.

    Parameters
    ----------
    nodes : list of str
        Node names.
    uniform_or_rand : str
        ``"uniform"`` sets every rate to 1; ``"random"`` draws every rate from
        a normal distribution with mean ``meanval`` and standard deviation
        ``sd_val`` (the whole set is redrawn, up to 101 times, while it
        contains a negative value).
    meanval, sd_val : float
        Parameters of the normal distribution (used by ``"random"`` only).
    chosen_rates : list of str
        Rates to override, named ``'u_<nodename>'`` or ``'d_<nodename>'``.
        Node names may themselves contain underscores.
    chosen_rates_vals : sequence of float
        Values of the overridden rates, parallel to ``chosen_rates``.

    Returns
    -------
    numpy.ndarray or None
        Float array of shape ``(2, len(nodes))``: row 0 holds the 'up' rates
        and row 1 the 'down' rates, both in the order of ``nodes``. ``None``
        (after a message on stderr) if ``uniform_or_rand`` or a rate name is
        invalid.

    Raises
    ------
    ValueError
        If an overridden rate names a node that is not in ``nodes``.
    IndexError
        If ``chosen_rates_vals`` is shorter than ``chosen_rates``.
    """
    n = len(nodes)

    # Flat layout: entries [0, n) are the 'up' rates, entries [n, 2n) the 'down' rates.
    if uniform_or_rand == "uniform":
        # Float dtype, so that fractional overrides are not truncated.
        rate_vals_num = np.ones((1, 2*n))

    elif uniform_or_rand == "random":
        rate_vals_num = np.random.normal(meanval, sd_val, (1, 2*n))
        neg_cnt = 0
        while np.any(rate_vals_num < 0):
            rate_vals_num = np.random.normal(meanval, sd_val, (1, 2*n))
            neg_cnt += 1
            if neg_cnt > 100:
                break
        if np.any(rate_vals_num < 0):
            print("Warning: random transition rates still contain negative values after %d redraws" % neg_cnt, file=sys.stderr)

    else:
        print("Choose 'uniform' or 'random' to generate transition rates", file=sys.stderr)
        return

    for k, chosen_rate in enumerate(chosen_rates):
        # Split on the first underscore only; the rest is the node name.
        direction, _, node_mod_ind = chosen_rate.partition("_")

        if direction == "d":
            rate_vals_num[0, nodes.index(node_mod_ind)+n] = chosen_rates_vals[k]
        elif direction == "u":
            rate_vals_num[0, nodes.index(node_mod_ind)] = chosen_rates_vals[k]
        else:
            print("Wrong name for transition rate, has to be 'u_nodename' or 'd_nodename'", file=sys.stderr)
            return

    # Row-major reshape puts the first n entries ('up') in row 0 and the last n ('down') in row 1.
    return np.reshape(rate_vals_num, (2, n))


In [28]:
# distr_type[0] is 'uniform': every rate is 1 except those listed in chosen_rates.
transition_rates_table = fcn_trans_rates_table(nodes, distr_type[0], meanval, sd_val, chosen_rates, chosen_rates_vals)

##### Building the (sparse) transition matrix

In [29]:
def fcn_build_trans_matr(stg_table, transition_rates_table, kin_matr_flag=""):
    """Build the sparse transition matrix (and optionally the kinetic matrix).

    Parameters
    ----------
    stg_table : numpy.ndarray
        Integer array with the columns
        ``[source state, target state, node index, flag]``; ``flag`` selects
        the row of ``transition_rates_table`` (0 or 1) used for the transition.
    transition_rates_table : numpy.ndarray
        Array of shape ``(2, n_nodes)`` with the transition rates.
    kin_matr_flag : str, optional
        If non-empty, the kinetic matrix is computed as well.

    Returns
    -------
    A_sparse : scipy sparse matrix
        ``2**n_nodes`` x ``2**n_nodes`` matrix. Off-diagonal entry
        (source, target) is the rate of that transition divided by the sum of
        all rates; each diagonal entry is 1 minus the sum of the off-diagonal
        entries of its row.
    K_sparse : scipy sparse matrix or list
        ``(A_sparse.T - I) * sum of all rates`` if ``kin_matr_flag`` is
        non-empty, otherwise an empty list.
    """
    dim_matr = pow(2, transition_rates_table.shape[1])
    total_rate = np.sum(transition_rates_table)

    # Column-major flattening: entry 2*node + flag is transition_rates_table[flag, node].
    reshaped_trt = np.reshape(transition_rates_table, -1, order="F")
    rate_inds = ((stg_table[:, 2])*2)+stg_table[:, 3]

    B = sp.sparse.csr_matrix(
        (
            reshaped_trt[rate_inds]/total_rate,
            (stg_table[:, 0],
            stg_table[:, 1])
        ),
        shape=(dim_matr, dim_matr)
    )

    # Put 1 - (row sum) on the diagonal.
    row_sums = np.asarray(B.sum(axis=1)).ravel()
    A_sparse = B + (sp.sparse.eye(dim_matr) - sp.sparse.diags(row_sums))

    if len(kin_matr_flag) > 0:
        K_sparse = (A_sparse.transpose() - sp.sparse.eye(dim_matr))*total_rate

    else:
        K_sparse = []

    return A_sparse, K_sparse

In [30]:
# The non-empty third argument makes the function also build the kinetic matrix; that second return value is discarded here.
A_sparse, _ = fcn_build_trans_matr(stg_table, transition_rates_table, 'yes')

In [31]:
# Size of the stored values of the sparse matrix only (A_sparse.data); its index arrays are not included.
print("%.2f Mbytes" % (A_sparse.data.nbytes/(1024*1024)))

1.69 Mbytes


##### Setting up initial values

In [32]:
# For each model (same order as model_name_list): the nodes whose initial value is fixed.
initial_fixed_nodes_list = [
    ['CycE','CycA','CycB','Cdh1','Rb_b1','Rb_b2','p27_b1','p27_b2'], # mammalian_cc
    ['cc','KRAS','DSB','cell_death'], # krasmodel15vars
    ['Alpelisib', 'Everolimus','PIM','Proliferation','Apoptosis'], # breast_cancer_zanudo2017
    ['ECMicroenv','DNAdamage','Metastasis','Migration','Invasion','EMT','Apoptosis','Notch_pthw','p53'], # EMT_cohen_ModNet 
    ['EGF','ERBB1','ERBB2','ERBB3','p21','p27'] # sahin_breast_cancer_refined
] 

# For each model: the initial values (0 or 1) of the nodes listed above, position by position.
initial_fixed_nodes_vals_list = [
    [0, 0, 0, 1, 1, 1, 1, 1], # mammalian_cc
    [1, 1, 1, 0], # krasmodel15vars: the first two values set cc (cell cycle) ON and KRAS (mutation) ON
    [0, 1, 0] + [0]*2, # breast_cancer_zanudo2017
    [1, 1] + [0]*5 + [1, 0], # EMT_cohen_ModNet: two input values, five zeros, then two more values
    [1, 0, 0, 0, 1, 1] # sahin_breast_cancer_refined
] 

# Entries for the model selected by model_index.
initial_fixed_nodes = initial_fixed_nodes_list[model_index]
initial_fixed_nodes_vals = initial_fixed_nodes_vals_list[model_index]


# Total probability assigned to the states that match the fixed values (e.g. dom_prob=0.8 means 80%).
dom_prob = 1
# Passed to fcn_define_initial_states; the plotting code there is commented out, so this currently has no effect.
plot_flag = ''

In [33]:
def fcn_define_initial_states(initial_fixed_nodes,initial_fixed_nodes_vals,dom_prob,nodes,distrib_type,plot_flag):
    """Build the initial probability distribution over all states.

    A state is numbered by the integer whose bit ``k`` is the value of
    ``nodes[k]``. The states in which every node of ``initial_fixed_nodes``
    has its value from ``initial_fixed_nodes_vals`` share a total probability
    of ``dom_prob``; all other states share ``1 - dom_prob``.

    Parameters
    ----------
    initial_fixed_nodes : list of str
        Names of the nodes with a fixed initial value (any order).
    initial_fixed_nodes_vals : sequence
        Initial value (0/1) of each fixed node, parallel to
        ``initial_fixed_nodes``.
    dom_prob : float
        Total probability of the states matching the fixed values.
    nodes : list of str
        All node names; their order defines the bit position of each node.
    distrib_type : str
        ``"uniform"`` spreads the probability evenly inside each of the two
        groups of states; ``"random"`` draws uniform random weights inside
        each group and rescales them to the group's total.
    plot_flag : str
        Accepted for compatibility; currently unused.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(2**len(nodes), 1)`` with the probability of each state.

    Raises
    ------
    ValueError
        If ``distrib_type`` is invalid, the two fixed-node lists differ in
        length, a fixed node is not in ``nodes``, or no state satisfies the
        fixed values.
    """
    if distrib_type not in ("uniform", "random"):
        raise ValueError("distrib type should be 'uniform' or 'random'")

    if len(initial_fixed_nodes) != len(initial_fixed_nodes_vals):
        raise ValueError(
            "initial_fixed_nodes and initial_fixed_nodes_vals must have the same length (%d != %d)"
            % (len(initial_fixed_nodes), len(initial_fixed_nodes_vals))
        )

    unknown_nodes = [node for node in initial_fixed_nodes if node not in nodes]
    if unknown_nodes:
        raise ValueError("initial_fixed_nodes not present in nodes: %s" % unknown_nodes)

    n_nodes = len(nodes)
    n_states = 2 ** n_nodes
    state_ids = np.arange(n_states)

    # Select the states that agree with every fixed value. Each fixed node is
    # looked up by name, so the order of initial_fixed_nodes does not have to
    # follow the order of nodes.
    inds_condition = np.ones(n_states, dtype=bool)
    for node, val in zip(initial_fixed_nodes, initial_fixed_nodes_vals):
        node_is_on = ((state_ids >> nodes.index(node)) & 1).astype(bool)
        inds_condition &= (node_is_on == bool(val))

    n_match = int(np.count_nonzero(inds_condition))
    n_rest = n_states - n_match
    if n_match == 0:
        raise ValueError("no state satisfies the fixed initial values")

    x0 = np.zeros((n_states, 1))
    inds_rest = np.logical_not(inds_condition)

    if distrib_type == "uniform":
        x0[inds_condition] = dom_prob / n_match
        if n_rest > 0:
            x0[inds_rest] = (1 - dom_prob) / n_rest

    else:
        dominant_weights = np.random.uniform(0, 1, (n_match, 1))
        x0[inds_condition] = dom_prob * dominant_weights / dominant_weights.sum()
        if n_rest > 0:
            rest_weights = np.random.uniform(0, 1, (n_rest, 1))
            x0[inds_rest] = (1 - dom_prob) * rest_weights / rest_weights.sum()

    # Sanity check: the probabilities must add up to 1 (up to rounding precision).
    n_prec = 3
    if round(float(x0.sum()), n_prec) == 1:
        print('sum(x0)=1, OK.')

    else:
        print('sum(x0)~=1, something wrong!')

    return x0

In [34]:
# Initial distribution: probability dom_prob spread uniformly over the states that match the fixed node values.
x0 = fcn_define_initial_states(initial_fixed_nodes, initial_fixed_nodes_vals, dom_prob, nodes, "uniform", plot_flag)

sum(x0)=1, OK.


In [35]:
# One row per state, i.e. 2**n rows in a single column.
x0.shape

(32768, 1)