# Multi CARNIVAL on PANACEA

This notebook shows how the results can be generated. To generate the results for the manuscript, we ran `script.py` on an HPC cluster for convenience. Multi-condition methods require a high-performance MILP solver; we used Gurobi for all experiments.

In [None]:
import corneto as cn
import pandas as pd
import numpy as np

# Show information about the installed CORNETO package
# (presumably version and available solver backends — confirm).
cn.info()

In [None]:
# Drug and cell lines analysed in this notebook. Both are matched
# case-insensitively against the `drug` / `cell` columns of the PANACEA table.
selected_drug = "PONATINIB"
selected_cells = [
    "H1793",
    "LNCAP",
    "KRJ1",
    "HCC1143",
    "EFO21",
    "PANC1",
    "HF2597",
]

In [None]:
# Load the PANACEA table and derive three helper columns from it:
#   drug - text between the first "_" of `obs_id` and the next "_v"
#   cell - everything in `obs_id` before the first "_"
#   sign - sign (-1, 0 or 1) of the activity score `act`
df_panacea = pd.read_csv("GSE186341-PANACEA.tsv.xz", sep="\t").assign(
    drug=lambda frame: frame["obs_id"].str.extract(r"_(.*?)_v", expand=False),
    cell=lambda frame: frame["obs_id"].str.extract(r"^([^_]*)", expand=False),
    sign=lambda frame: np.sign(frame["act"]),
)
df_panacea

In [None]:
def filter_df(
    cells,
    drugs,
    df=None,
    resource="dorothea",
    pipeline="NA+deseq2",
    statparam="stat",
    status="unfiltered",
    padj=0.05,
):
    """Select the rows of a PANACEA-style table for the given cells and drugs.

    Args:
        cells: Iterable of cell line names (compared case-insensitively).
        drugs: Iterable of drug names (compared case-insensitively).
        df: Table to filter. It must provide the columns ``cell``, ``drug``,
            ``resource``, ``pipeline``, ``statparam``, ``status`` and ``padj``.
            When omitted, the notebook-level ``df_panacea`` is used.
        resource: Required value of the ``resource`` column.
        pipeline: Required value of the ``pipeline`` column.
        statparam: Required value of the ``statparam`` column.
        status: Required value of the ``status`` column.
        padj: Upper bound (inclusive) on the ``padj`` column.

    Returns:
        The subset of ``df`` whose rows satisfy every condition.
    """
    if df is None:
        # Look the table up at call time instead of freezing it as a default
        # argument, so re-running the loading cell is picked up without
        # having to redefine this function.
        df = df_panacea
    wanted_cells = [name.upper() for name in cells]
    wanted_drugs = [name.upper() for name in drugs]
    mask = (
        df.cell.str.upper().isin(wanted_cells)
        & df.drug.str.upper().isin(wanted_drugs)
        & (df.resource == resource)
        & (df.pipeline == pipeline)
        & (df.statparam == statparam)
        & (df.status == status)
        & (df.padj <= padj)
    )
    return df[mask]

# Rows of the PANACEA table for the selected drug across the selected cells.
df_conditions = filter_df(cells=selected_cells, drugs=[selected_drug])
df_conditions

In [None]:
def get_measurements(
    cell,
    drug,
    df=None,
    resource="dorothea",
    pipeline="NA+deseq2",
    statparam="stat",
    padj=0.05,
    as_dict=True,
):
    """Return the significant measurements for one cell line / drug pair.

    Args:
        cell: Cell line name (compared case-insensitively).
        drug: Drug name (compared case-insensitively).
        df: Table to query. It must provide the columns ``cell``, ``drug``,
            ``resource``, ``pipeline``, ``statparam`` and ``padj`` (plus
            ``items`` and ``sign`` when ``as_dict`` is true). When omitted,
            the notebook-level ``df_panacea`` is used.
        resource: Required value of the ``resource`` column.
        pipeline: Required value of the ``pipeline`` column.
        statparam: Required value of the ``statparam`` column.
        padj: Upper bound (inclusive) on the ``padj`` column.
        as_dict: If true, return a ``{items: sign}`` dictionary; otherwise
            return the filtered rows themselves.

    Returns:
        A dictionary mapping ``items`` to ``sign``, or the filtered DataFrame.

    Note:
        Unlike ``filter_df``, rows are not filtered on the ``status`` column.
    """
    if df is None:
        # Look the table up at call time instead of freezing it as a default
        # argument, so re-running the loading cell is picked up without
        # having to redefine this function.
        df = df_panacea
    mask = (
        (df.drug.str.upper() == drug.upper())
        & (df.cell.str.upper() == cell.upper())
        & (df.resource == resource)
        & (df.pipeline == pipeline)
        & (df.statparam == statparam)
        & (df.padj <= padj)
    )
    df_r = df[mask]
    if as_dict:
        return df_r[["items", "sign"]].set_index("items").to_dict()["sign"]
    return df_r

# Preview the measurement dictionary of the first selected cell line.
get_measurements(cell=selected_cells[0], drug=selected_drug)

In [None]:
# Load the network from a SIF file that has a header row.
# NOTE(review): column_order=[0, 2, 1] presumably maps the file columns to
# (source, interaction, target) — confirm against the CORNETO documentation.
G_pkn = cn.Graph.from_sif(
    "network_collectri.sif",
    has_header=True,
    column_order=[0, 2, 1],
)
G_pkn.shape

In [None]:
# Ponatinib interactions as listed by DrugBank:
# BCR, ABL, VEGFR, PDGFR, FGFR, EPH receptors and SRC families of kinases,
# and KIT, RET, TIE2, and FLT3.
ponatinib_targets = [
    "BCR",
    "ABL",
    "VEGFR",
    "PDGFRA_PDGFRB",
    "FGFR1",
    "FGFR2",
    "FGFR3",
    "FGFR4",
    "EPH",
    "SRC",
    "KIT",
    "RET",
    "TIE2",
    "FLT3",
]

In [None]:
# Keep only the Ponatinib targets that are vertices of the network and give
# each the value -1 (presumably "inhibited" — confirm with the method docs).
targets = dict.fromkeys(
    (v for v in G_pkn.V if v in ponatinib_targets),
    -1,
)
targets

In [None]:
from corneto.methods.carnival import milp_carnival
from corneto.methods.carnival import preprocess_graph, milp_carnival
from corneto.methods.signaling import create_flow_graph, signflow
from corneto.methods.carnival import runCARNIVAL_AcyclicFlow, runCARNIVAL_Flow_Acyclic, create_flow_carnival, create_flow_carnival_v2, create_flow_carnival_v3, create_flow_carnival_v4
from corneto.methods import expand_graph_for_flows


In [None]:
def postprocess(G, edge_var, inputs, outputs):
    """Build the graph induced by the selected edges of a solution.

    An edge counts as selected when its entry in ``edge_var`` is non-zero.
    Selected edges are discarded when their source or target set is empty, or
    when the first source or target vertex name starts with ``"_"``. The
    remaining edges form a subgraph of ``G`` that is passed through
    ``preprocess_graph`` together with ``inputs`` and ``outputs``.

    Returns:
        The graph returned by ``preprocess_graph`` (its first result).
    """

    def _is_discarded(edge):
        # Same test as above: empty side, or "_"-prefixed first vertex.
        src, tgt = (list(side) for side in edge)
        if not src or not tgt:
            return True
        return src[0].startswith("_") or tgt[0].startswith("_")

    active = set(np.flatnonzero(np.abs(edge_var) > 0))
    graph_edges = G.E
    dropped = {idx for idx in active if _is_discarded(graph_edges[idx])}
    G_sol = G.edge_subgraph(list(active.difference(dropped)))
    G_solp, _, _ = preprocess_graph(G_sol, inputs, outputs)
    return G_solp
    
def score(G, edge_var, vertex_var, inputs, outputs):
    """Score a solution by hand, to avoid implementation-specific differences.

    Args:
        G: Graph the solution was computed on.
        edge_var: Per-edge solution values, indexed like ``G.E``.
        vertex_var: Per-vertex solution values, indexed like ``G.V``.
        inputs: Input vertices, forwarded to ``postprocess``.
        outputs: Mapping from output vertex to its expected value.

    Returns:
        A tuple ``(error, n_edges)``. ``error`` is the summed absolute
        difference between expected and obtained output values, divided by
        the summed absolute expected values; an output missing from the
        post-processed solution graph contributes its full absolute expected
        value. ``n_edges`` is the second entry of that graph's ``shape``.

    Note:
        The division fails when ``outputs`` is empty, since the denominator
        is then zero.
    """
    G_sol = postprocess(G, edge_var, inputs, outputs)
    # Build the lookups once: scanning a list for every output is quadratic.
    solution_vertices = set(G_sol.V)
    vertex_index = {}
    for idx, vertex in enumerate(G.V):
        # setdefault keeps the first position, like list.index would.
        vertex_index.setdefault(vertex, idx)

    err_outputs = 0
    total = 0
    for k, v in outputs.items():
        if k in solution_vertices:
            err_outputs += abs(vertex_var[vertex_index[k]] - v)
        else:
            err_outputs += abs(v)
        total += abs(v)
    return err_outputs / total, G_sol.shape[1]

def prune(c_data, G_pkn):
    """Prune the graph to the inputs and outputs used across all conditions.

    Args:
        c_data: Mapping ``condition -> {vertex: (kind, value)}``. Vertices
            with kind ``'P'`` are collected as inputs and every other kind as
            outputs. Entries whose key is not a string are printed and
            otherwise ignored.
        G_pkn: Graph to prune.

    Returns:
        A tuple ``(Gp, c_inputs, c_outputs)``: the result of ``G_pkn.prune``
        and the sets of inputs and outputs that are vertices of ``G_pkn``.
    """
    all_inputs, all_outputs = set(), set()
    for condition in c_data.values():
        for name, (kind, value) in condition.items():
            if not isinstance(name, str):
                print(name, kind, value)
                continue
            bucket = all_inputs if kind == 'P' else all_outputs
            bucket.add(name)

    graph_vertices = set(G_pkn.vertices)
    c_inputs = graph_vertices.intersection(all_inputs)
    c_outputs = graph_vertices.intersection(all_outputs)
    print(f"{len(c_inputs)}/{len(all_inputs)} inputs mapped to the graph")
    print(f"{len(c_outputs)}/{len(all_outputs)} outputs mapped to the graph")
    print(f"Pruning the graph with size: V x E = {G_pkn.shape}...")
    Gp = G_pkn.prune(list(c_inputs), list(c_outputs))
    print(f"Finished. Final size: V x E = {Gp.shape}.")
    return Gp, c_inputs, c_outputs

def convert_input_dict(dataset):
    """Convert ``{cond: {"input": {...}, "output": {...}}}`` to tagged form.

    Every input vertex becomes ``vertex: ('P', value)`` and every output
    vertex ``vertex: ('M', value)``. A vertex listed in both keeps the
    output entry.

    Returns:
        A new dictionary ``{cond: {vertex: (tag, value)}}``.
    """

    def _tagged(values, tag):
        return {name: (tag, val) for name, val in values.items()}

    return {
        cond: {**_tagged(exp["input"], 'P'), **_tagged(exp["output"], 'M')}
        for cond, exp in dataset.items()
    }

def single_carnival(G, dataset, beta=0.25, max_time=600, norel=0, seed=0):
    """Solve one CARNIVAL problem per condition, each independently.

    Args:
        G: Graph shared by all conditions.
        dataset: Mapping ``condition -> {"input": {...}, "output": {...}}``.
        beta: Passed to ``milp_carnival`` as ``beta_weight``.
        max_time: Passed to the solver as ``TimeLimit``.
        norel: Passed to the solver as ``NoRelHeurTime``.
        seed: Passed to the solver as ``Seed``.

    Returns:
        A tuple ``(problems, scores, all_edges, selected_edges_per_sample)``:
        the solved problems, the ``score`` result of each condition, a vector
        over the edges of ``G`` counting in how many conditions each edge was
        selected, and one 0/1 selection vector per condition.
    """
    n_edges = G.shape[1]
    graph_edges = list(G.E)
    all_edges = np.zeros(n_edges)
    problems, scores, selected_edges_per_sample = [], [], []

    for k, exp in dataset.items():
        Gp, cp_inputs, cp_outputs = preprocess_graph(G, exp["input"], exp["output"])
        print(k, Gp.shape, len(cp_inputs), len(cp_outputs))
        P = milp_carnival(Gp, cp_inputs, cp_outputs, beta_weight=beta)
        P.solve(solver="GUROBI", IntegralityFocus=1, TimeLimit=max_time, NoRelHeurTime=norel, Seed=seed, verbosity=0)
        for o in P.objectives:
            print(o.value)
        scores.append(
            score(Gp, P.expr.edge_values.value, P.expr.vertex_values.value, cp_inputs, cp_outputs)
        )
        problems.append(P)

        # The solution indexes the edges of the preprocessed graph; translate
        # the selected ones back to positions in the edge list of G.
        pruned_edges = Gp.E
        chosen = [
            pruned_edges[idx]
            for idx in np.flatnonzero(np.abs(P.expr.edge_values.value) > 0)
        ]
        hits = [pos for pos, edge in enumerate(graph_edges) if edge in chosen]
        sol_edges = np.zeros(n_edges)
        sol_edges[hits] = 1
        all_edges[hits] += 1
        selected_edges_per_sample.append(sol_edges)

    return problems, scores, all_edges, selected_edges_per_sample


def multi_carnival(G, dataset, lambd=0.25, norel=0, max_time=600, seed=0):
    """Solve all conditions jointly with ``signflow`` on a flow graph.

    Args:
        G: Graph shared by all conditions.
        dataset: Mapping ``condition -> {"input": {...}, "output": {...}}``.
        lambd: Passed to ``signflow`` as ``l0_penalty_edges``.
        norel: Passed to the solver as ``NoRelHeurTime``.
        max_time: Passed to the solver as ``TimeLimit``.
        seed: Passed to the solver as ``Seed``.

    Returns:
        A tuple ``(P, G_flow, scores, edge_counts)``: the solved problem, the
        graph returned by ``create_flow_graph``, the ``score`` result of each
        condition, and the number of conditions selecting each edge.
        ``edge_counts`` only covers edges with exactly one source and one
        target vertex, neither starting with ``"_"``, listed in increasing
        edge-index order.
    """
    d = convert_input_dict(dataset)
    G_multi, input_multi, output_multi = prune(d, G)
    print(G_multi.shape, len(input_multi), len(output_multi))
    all_v = input_multi.union(output_multi)
    # Remove non reachable so error is 0 if reachable are fit
    # as in carnival single
    d = {
        k: {key: value for key, value in v.items() if key in all_v}
        for k, v in d.items()
    }
    G_multi, input_multi, output_multi = prune(d, G)
    # Named G_flow rather than G_pkn to avoid shadowing the notebook global.
    G_flow = create_flow_graph(G_multi, d)
    P = signflow(
        G_flow,
        d,
        l0_penalty_edges=lambd
    )
    P.solve(solver="GUROBI", IntegralityFocus=1, NoRelHeurTime=norel, Seed=seed, TimeLimit=max_time, verbosity=1)

    # Collect the indices in a list (not a set) so the returned vector
    # follows the edge order of the graph instead of set iteration order.
    valid_edges = []
    for i, (s, t) in enumerate(G_flow.E):
        s, t = list(s), list(t)
        if len(s) == 1 and len(t) == 1 and not (s[0].startswith("_") or t[0].startswith("_")):
            valid_edges.append(i)

    all_edges_multi = np.zeros(G_flow.shape[1])
    scores = []
    for k, v in dataset.items():
        _, cp_inputs, cp_outputs = preprocess_graph(G_flow, v["input"], v["output"])
        edge_values = P.expr[f"edge_values_{k}"].value
        s = score(G_flow, edge_values, P.expr[f"vertex_values_{k}"].value, cp_inputs, cp_outputs)
        scores.append(s)
        all_edges_multi[np.flatnonzero(np.abs(edge_values) > 0)] += 1
    return P, G_flow, scores, all_edges_multi[valid_edges]


def multi_carnival_flow(G, dataset, acyclic_signal_version=True, acyclic_signal_name="v3", excl_vertex_value=False, slack_reg=False, lambd=0.25, norel=0, max_time=600, solver="GUROBI", seed=0):
    """Solve all conditions jointly with a flow-based CARNIVAL formulation.

    Args:
        G: Graph shared by all conditions.
        dataset: Mapping ``condition -> {"input": {...}, "output": {...}}``.
        acyclic_signal_version: If true, build the problem with the variant
            named by ``acyclic_signal_name``; otherwise use
            ``create_flow_carnival``.
        acyclic_signal_name: ``"v3"`` or ``"v4"``, selecting
            ``create_flow_carnival_v3`` or ``create_flow_carnival_v4``.
        excl_vertex_value: Passed to v3/v4 as ``exclusive_vertex_values``.
        slack_reg: Passed to v4 as ``slack_reg``.
        lambd: Passed to the problem builder as ``lambd``.
        norel: Passed to Gurobi as ``NoRelHeurTime``.
        max_time: Passed to Gurobi as ``TimeLimit``.
        solver: Solver name. The Gurobi-specific options are only passed
            when it equals ``"GUROBI"``.
        seed: Passed to Gurobi as ``Seed``.

    Returns:
        A tuple ``(P, G_exp_e, scores, edge_counts)``: the solved problem,
        the graph returned by ``expand_graph_for_flows``, the ``score``
        result of each condition, and the number of conditions selecting
        each edge. ``edge_counts`` only covers edges with exactly one source
        and one target vertex, neither starting with ``"_"``, listed in
        increasing edge-index order.

    Raises:
        ValueError: If ``acyclic_signal_version`` is true and
            ``acyclic_signal_name`` is neither ``"v3"`` nor ``"v4"``.
    """
    d = convert_input_dict(dataset)
    G_multi, input_multi, output_multi = prune(d, G)
    print(G_multi.shape, len(input_multi), len(output_multi))
    all_v = input_multi.union(output_multi)
    # Restrict every condition to the vertices that survived pruning.
    exp_list = dict()
    for k, v in dataset.items():
        filtered_in = {key: value for key, value in v["input"].items() if key in all_v}
        filtered_out = {key: value for key, value in v["output"].items() if key in all_v}
        exp_list[k] = {"input": filtered_in, "output": filtered_out}
        print(k, len(filtered_in), len(filtered_out))

    G_exp_e = expand_graph_for_flows(G_multi, exp_list)
    if acyclic_signal_version:
        # v3 has L0 reg
        if acyclic_signal_name == "v3":
            P = create_flow_carnival_v3(G_exp_e, exp_list, lambd=lambd, exclusive_vertex_values=excl_vertex_value)
        elif acyclic_signal_name == "v4":
            P = create_flow_carnival_v4(G_exp_e, exp_list, lambd=lambd, slack_reg=slack_reg, upper_bound_flow=10, exclusive_vertex_values=excl_vertex_value)
        else:
            raise ValueError(acyclic_signal_name)
    else:
        P = create_flow_carnival(G_exp_e, exp_list, lambd=lambd)
    if solver == "GUROBI":
        P.solve(solver=solver, IntegralityFocus=1, NoRelHeurTime=norel, Seed=seed, TimeLimit=max_time, verbosity=1)
    else:
        P.solve(solver=solver, verbosity=1)

    # Collect the indices in a list (not a set) so the returned vector
    # follows the edge order of the graph instead of set iteration order.
    valid_edges = []
    for i, (s, t) in enumerate(G_exp_e.E):
        s, t = list(s), list(t)
        if len(s) == 1 and len(t) == 1 and not (s[0].startswith("_") or t[0].startswith("_")):
            valid_edges.append(i)

    all_edges_multi = np.zeros(G_exp_e.shape[1])
    scores = []
    # Column i of the solution matrices holds the values of the i-th
    # condition, in the iteration order of `dataset`.
    for i, (k, v) in enumerate(dataset.items()):
        _, cp_inputs, cp_outputs = preprocess_graph(G_exp_e, v["input"], v["output"])
        edge_values = P.expr.edge_value.value[:, i]
        s = score(G_exp_e, edge_values, P.expr.vertex_value.value[:, i], cp_inputs, cp_outputs)
        scores.append(s)
        all_edges_multi[np.flatnonzero(np.abs(edge_values) > 0)] += 1
    return P, G_exp_e, scores, all_edges_multi[valid_edges]


In [None]:
# Create the input dict
input_data = dict()
for cell in selected_cells:
    d = dict()
    input_data[cell] = d
    d["input"] = targets
    d["output"] = get_measurements(cell, selected_drug)
#input_data

In [None]:
def summary(scores, edge_vector):
    """Aggregate per-condition scores and the selected-edge vector.

    Args:
        scores: Iterable of ``(error, num_edges)`` pairs.
        edge_vector: NumPy array with one entry per edge; an entry above
            zero marks a selected edge.

    Returns:
        A tuple ``(total_error, distinct_edges, ratio)``: the summed error,
        the number of positive entries in ``edge_vector``, and that number
        divided by the summed ``num_edges``.
    """
    pairs = list(scores)
    total = sum(err for err, _ in pairs)
    edges = sum(num_edges for _, num_edges in pairs)
    diff_edges = np.sum(edge_vector > 0)
    return total, diff_edges, diff_edges / edges

In [None]:
# Run the joint flow-based formulation ("v4" variant) on all conditions and
# print the value of every objective of the solved problem.
P_fa, G_fa, scores_fa, edges_fa = multi_carnival_flow(
    G_pkn,
    input_data,
    acyclic_signal_version=True,
    acyclic_signal_name="v4",
    excl_vertex_value=True,
    lambd=0.01,
    norel=300,
    solver="GUROBI",
)
for o in P_fa.objectives:
    print(o.value)