In [None]:
import corneto as cn
import numpy as np

# Report details about the installed corneto package
# (presumably version and available backends/solvers — see corneto docs).
cn.info()

In [None]:
# Toy prior-knowledge network written as SIF triples (source, sign, target),
# where sign is +1 or -1. The order of the triples is kept as authored.
PKN_INTERACTIONS = [
    ("rec1", 1, "a"),
    ("rec1", -1, "b"),
    ("rec1", 1, "f"),
    ("rec1", -1, "c"),
    ("rec2", 1, "b"),
    ("rec2", 1, "tf2"),
    ("b", 1, "g"),
    ("g", -1, "d"),
    ("rec2", -1, "d"),
    ("a", 1, "c"),
    ("a", -1, "d"),
    ("c", 1, "d"),
    ("c", -1, "e"),
    ("c", 1, "tf3"),
    ("e", 1, "a"),
    ("d", -1, "c"),
    ("e", 1, "tf1"),
    ("a", -1, "tf1"),
    ("d", 1, "tf2"),
    ("c", -1, "tf2"),
    ("tf1", 1, "tf2"),
    ("tf1", -1, "rec2"),
    ("tf2", 1, "rec1"),
    ("tf1", 1, "f"),
]

G_pkn = cn.Graph.from_sif_tuples(PKN_INTERACTIONS)
G_pkn.plot()

In [None]:
def prune_graph(G, conditions, inputs_dict_key="input", outputs_dict_key="output"):
    """Prune ``G`` per condition and then once more for all conditions together.

    For each condition, the input/output vertices that exist in ``G`` are passed
    to ``G.prune``. Only the inputs/outputs still present in that per-condition
    subgraph are kept. Finally ``G`` is pruned using the union of every kept
    input and output across all conditions.

    Vertex order follows the order in which vertices appear in ``conditions``.
    This keeps the result and the arguments handed to ``G.prune`` reproducible
    across interpreter runs. Iterating a ``set`` of strings would not, because
    string hashing is randomized per process.

    Args:
        G: Graph exposing an iterable ``V`` of vertices and a
            ``prune(inputs, outputs)`` method that returns a graph with ``V``.
        conditions (dict): ``{condition: {inputs_dict_key: {vertex: value},
            outputs_dict_key: {vertex: value}}}``.
        inputs_dict_key (str): Key of the input mapping inside each condition.
        outputs_dict_key (str): Key of the output mapping inside each condition.

    Returns:
        tuple: ``(pruned_conditions, pruned_graph)``.
        ``pruned_conditions`` has the same structure as ``conditions``,
        restricted to vertices that survive each per-condition prune.
        ``pruned_graph`` is ``G`` pruned with all surviving inputs/outputs.

    Raises:
        KeyError: If a condition lacks ``inputs_dict_key`` or ``outputs_dict_key``.
    """
    graph_vertices = set(G.V)
    pruned_conditions = {}
    # Dicts used as insertion-ordered sets (only the keys matter).
    all_input_vertices = {}
    all_output_vertices = {}

    for cond_name, cond_data in conditions.items():
        condition_inputs = cond_data[inputs_dict_key]
        condition_outputs = cond_data[outputs_dict_key]

        # Keep only vertices that exist in G, preserving their given order.
        relevant_inputs = [v for v in condition_inputs if v in graph_vertices]
        relevant_outputs = [v for v in condition_outputs if v in graph_vertices]

        sub_graph = G.prune(relevant_inputs, relevant_outputs)
        sub_vertices = set(sub_graph.V)

        # Every kept vertex is a key of its mapping, so plain indexing is safe.
        pruned_inputs = {
            v: condition_inputs[v] for v in relevant_inputs if v in sub_vertices
        }
        pruned_outputs = {
            v: condition_outputs[v] for v in relevant_outputs if v in sub_vertices
        }

        pruned_conditions[cond_name] = {
            inputs_dict_key: pruned_inputs,
            outputs_dict_key: pruned_outputs,
        }

        all_input_vertices.update(dict.fromkeys(pruned_inputs))
        all_output_vertices.update(dict.fromkeys(pruned_outputs))

    # Prune the original graph once with everything any condition still needs.
    pruned_graph = G.prune(list(all_input_vertices), list(all_output_vertices))

    return pruned_conditions, pruned_graph


In [None]:
from corneto.methods.carnival import preprocess_graph
from corneto.methods.signaling import create_flow_graph
#from corneto.methods.signalling.carnival import create_carnival_problem
from corneto.methods.future.carnival import Carnival
from corneto.methods import expand_graph_for_flows
from corneto.backend import PicosBackend, CvxpyBackend

# Per-condition measurements in "pivoted" form:
#   {condition: {"input": {node: value}, "output": {node: value}}}
# NOTE(review): tf3 in c2 has magnitude 3 while every other value is +/-1 — confirm intended.
conditions = {
    "c1": {
        "input": {"rec2": 1},
        "output": {"tf1": -1, "tf2": 1},
    },
    "c2": {
        "input": {"rec1": 1, "rec2": -1},
        "output": {"tf1": 1, "tf2": -1, "tf3": 3},
    },
}

def pivoted_to_standard(pivoted_dict, metadata_key):
    """
    Converts a 'pivoted' dict of the form:
        {
          condition: {
            meta_value: {
               feature_name: feature_value
            },
            ...
          },
          ...
        }
    into a 'standard' dict of the form:
        {
          condition: {
            feature_name: {
              "value": feature_value,
              <metadata_key>: meta_value
            },
            ...
          },
          ...
        }

    The standard form holds a single entry per feature within a condition.
    A feature that appears under two different meta values in the same
    condition therefore cannot be represented. It is reported instead of
    being silently overwritten by whichever group comes last.

    Args:
        pivoted_dict (dict): The pivoted dictionary.
        metadata_key (str): The name of the metadata field to inject (e.g. "type").

    Returns:
        dict: The converted dictionary in standard format.

    Raises:
        ValueError: If a feature name occurs in more than one meta group of the
            same condition.
    """
    standard_dict = {}

    for condition, meta_groups in pivoted_dict.items():
        condition_features = {}
        for meta_val, features in meta_groups.items():
            for feature_name, feature_value in features.items():
                if feature_name in condition_features:
                    previous_meta = condition_features[feature_name][metadata_key]
                    raise ValueError(
                        f"Feature {feature_name!r} in condition {condition!r} appears "
                        f"under both {metadata_key}={previous_meta!r} and "
                        f"{metadata_key}={meta_val!r}; the standard format can hold "
                        "only one entry per feature."
                    )
                condition_features[feature_name] = {
                    "value": feature_value,
                    metadata_key: meta_val
                }
        standard_dict[condition] = condition_features
    return standard_dict

# Display `conditions` converted to the "standard" form: one entry per feature,
# carrying its value and a "type" field holding the original group name.
pivoted_to_standard(conditions, "type")

In [None]:
# Prune the PKN per condition.
# `cond` keeps only the inputs/outputs that survive pruning.
# `Gp` is the PKN pruned with the union of those inputs/outputs.
cond, Gp = prune_graph(G_pkn, conditions)
cond

In [None]:
# Compare the size of the pruned graph with the original PKN.
Gp.shape, G_pkn.shape

In [None]:
# Expand the pruned graph for the flow-based formulation, using the pruned conditions.
# NOTE(review): what the expansion adds (e.g. extra source/sink edges) is defined
# inside corneto's expand_graph_for_flows and is not visible here.
G_flow = expand_graph_for_flows(Gp, cond)
G_flow.plot()

In [None]:
from corneto.methods.future.method import Dataset

In [None]:
#for o in P.objectives:
#    if hasattr(o, 'value'):
#        print(o.value)
#    else:
#        print(o)

In [None]:
# Round-trip check: build a Dataset from the standard form and convert it back to a dict.
Dataset.from_dict(pivoted_to_standard(conditions, "type")).to_dict()

In [None]:
# Build and solve the CARNIVAL problem for all conditions in the dataset.
# NOTE(review): the problem is built on the full `G_pkn`, while the plotting cells
# below index `G_flow` (built from the pruned `Gp`) with this problem's results.
# Confirm the edge indices actually line up.
# NOTE(review): lambd=0 is a magic value — presumably it disables the
# regularisation penalty; confirm and consider naming it as a constant.
# solver="GUROBI" needs a Gurobi installation and license available to CVXPY.
c = Carnival(backend=CvxpyBackend(), lambd=0)
P = c.build(G_pkn, Dataset.from_dict(pivoted_to_standard(conditions, "type")))
P.solve(solver="GUROBI", verbosity=1)

In [None]:
# Plot the G_flow edges whose `edge_has_signal` value in column 0 is non-zero.
# NOTE(review): assumes column 0 is the first condition (c1) and that the rows of
# edge_has_signal index G_flow's edges. P was built from G_pkn, not Gp, so confirm.
G_flow.edge_subgraph(np.flatnonzero(P.expr.edge_has_signal.value[:, 0])).plot()

In [None]:
# Plot the G_flow edges whose `edge_has_signal` value in column 1 is non-zero.
# NOTE(review): assumes column 1 is the second condition (c2) and that the rows of
# edge_has_signal index G_flow's edges. P was built from G_pkn, not Gp, so confirm.
G_flow.edge_subgraph(np.flatnonzero(P.expr.edge_has_signal.value[:, 1])).plot()

In [None]:
# Print the value of each objective term of the solved problem, one per line.
for objective in P.objectives:
    print(objective.value)

In [None]:
# Inspect the named expressions exposed by the built problem
# (the plotting cells above read `edge_has_signal` from here).
P.expr

In [None]:
from corneto.methods.future.method import Dataset

# The same measurements as `conditions`, written directly in the "standard" format:
#   {condition: {feature: {"value": ..., "type": "input" | "output"}}}
standard_conditions = {
    "c1": {
        "rec2": {"value": 1, "type": "input"},
        "tf1": {"value": -1, "type": "output"},
        "tf2": {"value": 1, "type": "output"},
    },
    "c2": {
        "rec1": {"value": 1, "type": "input"},
        "rec2": {"value": -1, "type": "input"},
        "tf1": {"value": 1, "type": "output"},
        "tf2": {"value": -1, "type": "output"},
        "tf3": {"value": 3, "type": "output"},
    },
}
d = Dataset.from_dict(standard_conditions)


In [None]:
# Convert the dataset grouped by its "type" field, keeping only the values.
# Presumably this yields the pivoted {condition: {type: {feature: value}}} form;
# confirm against Dataset.to_dict.
d.to_dict(key="type", return_value_only=True)

In [None]:
# Round trip on "standard"-format data:
#   standard -> grouped by "type" (values only) -> Dataset -> standard dict.
# NOTE(review): with return_value_only=True the extra "other" field on m1 is
# presumably dropped by the round trip — confirm that is what this cell shows.
standard_data = {
    "sample1": {
        "m1": {"value": 1, "type": "input", "other": "ask"},
        "m2": {"value": 2, "type": "output"},
    },
    "sample2": {
        "m3": {"value": 99, "type": "input"},
    },
}
ds_standard = Dataset.from_dict(standard_data)
values_by_type = ds_standard.to_dict(key="type", return_value_only=True)
Dataset.from_dict(values_by_type, key="type").to_dict()