# Case study 1: The T cell signaling pathway

In [4]:
# from eliater.network_validation import conditional_independence_test_summary
import pandas as pd
from y0.graph import NxMixedGraph
from y0.dsl import Variable

This is case study 1, shown in Figure 5 (a) of the paper: Eliater: an analytical workflow and open source implementation for causal query estimation in biomolecular networks.

In [7]:
# Directed edges of the signaling network used in this case study, written as
# an adjacency mapping from each parent node to the list of its children.
graph1 = NxMixedGraph.from_str_adj(
    directed={
        "PKA": ["Raf", "Mek", "Erk", "Akt", "Jnk", "P38"],
        "PKC": ["Mek", "Raf", "PKA", "Jnk", "P38"],
        "Raf": ["Mek"],
        "Mek": ["Erk"],
        "Erk": ["Akt"],
        "Plcg": ["PKC", "PIP2", "PIP3"],
        "PIP3": ["PIP2", "Akt"],
        "PIP2": ["PKC"],
    }
)

In [9]:
# Get the data
# NOTE(review): the path points into one user's home-directory checkout;
# adjust it when running the notebook elsewhere.
# NOTE(review): the preview below shows column labels such as "0", "0.1" and
# "1.1", which looks like a data row was consumed as the header -- confirm
# that the CSV starts with a header row holding the protein names.
data = pd.read_csv(
    "~/Github/eliater/src/eliater/data/sachs_discretized_2bin.csv",
    index_col=False,
)
data

Unnamed: 0,0,0.1,0.2,0.3,1,1.1,1.2,0.4,1.3,1.4,1.5,1.6
0,1,0,1,1,1,1,1,1,1,1,1,1
1,2,1,0,0,1,1,1,1,1,1,0,0
2,3,1,0,0,1,1,1,1,1,1,0,0
3,4,0,0,1,1,0,1,1,1,1,1,1
4,5,0,0,0,0,1,1,1,1,1,1,1
...,...,...,...,...,...,...,...,...,...,...,...,...
847,848,1,0,0,0,1,1,1,0,1,0,0
848,849,0,0,0,1,1,1,1,1,1,1,0
849,850,1,1,0,1,1,1,0,1,1,0,0
850,851,1,0,0,1,1,1,1,0,1,1,0


## Step 1: Verify correctness of the network structure

In [7]:
#remove after branch is merged
import logging
from typing import Dict, Literal, Optional
import pandas as pd

from y0.algorithm.falsification import get_graph_falsifications
from y0.dsl import Variable
from y0.graph import NxMixedGraph
from y0.struct import get_conditional_independence_tests

# Configure the root logger to emit bare messages at DEBUG level and above;
# the summary function below reports its results through logging.info.
logging.basicConfig(format="%(message)s", level=logging.DEBUG)


# Public names of the network validation helpers defined in this cell.
__all__ = [
    "conditional_independence_test_summary",
    "validate_test",
    "get_state_space_map",
    "is_data_discrete",
    "is_data_continuous",
    "CITest",
    "choose_default_test",
]

# NOTE(review): TESTS is not referenced by any code visible in this notebook;
# validate_test calls get_conditional_independence_tests() itself.
TESTS = get_conditional_independence_tests()


def get_state_space_map(
    data: pd.DataFrame, threshold: Optional[int] = 10
) -> Dict[Variable, Literal["discrete", "continuous"]]:
    """Get a dictionary from each variable to its type.

    A column is labelled discrete when its number of unique values does not
    exceed ``threshold``; otherwise it is labelled continuous.

    :param data: the observed data
    :param threshold: The largest number of unique values a column may have
        and still be labelled discrete. If None, defaults to 10.
    :return: the mapping from each column name, wrapped as a variable, to its type
    """
    if threshold is None:
        # The annotation allows None, but comparing a count to None with ``<=``
        # raises a TypeError, so fall back to the documented default.
        threshold = 10
    return {
        Variable(column): "discrete" if data[column].nunique() <= threshold else "continuous"
        for column in data.columns
    }


def is_data_discrete(data: pd.DataFrame) -> bool:
    """Tell whether every column of the dataframe holds discrete data.

    :param data: observational data.
    :return: True when there is at least one column and every column is
        labelled discrete by :func:`get_state_space_map`, False otherwise
    """
    column_types = list(get_state_space_map(data=data).values())
    # A dataframe without columns is neither discrete nor continuous
    return bool(column_types) and all(kind == "discrete" for kind in column_types)


def is_data_continuous(data: pd.DataFrame) -> bool:
    """Tell whether every column of the dataframe holds continuous data.

    :param data: observational data.
    :return: True when there is at least one column and every column is
        labelled continuous by :func:`get_state_space_map`, False otherwise
    """
    column_types = list(get_state_space_map(data).values())
    # A dataframe without columns is neither discrete nor continuous
    return bool(column_types) and all(kind == "continuous" for kind in column_types)


# TODO replace with y0.struct.CITest
# Type alias listing conditional independence test names; it annotates the
# ``test`` parameters and return values of the helpers in this cell.
CITest = Literal[
    "pearson",
    "chi-square",
    "cressie_read",
    "freeman_tuckey",
    "g_sq",
    "log_likelihood",
    "modified_log_likelihood",
    "power_divergence",
    "neyman",
]


def choose_default_test(data: pd.DataFrame) -> CITest:
    """Pick the conditional independence test that matches the type of the data.

    :param data: observational data.
    :return: ``"chi-square"`` when all columns are discrete and ``"pearson"``
        when all columns are continuous
    :raises NotImplementedError: if the data is neither entirely discrete nor
        entirely continuous (e.g., it mixes discrete and continuous columns)
    """
    if is_data_discrete(data):
        default_test = "chi-square"
    elif is_data_continuous(data):
        default_test = "pearson"
    else:
        raise NotImplementedError(
            "Mixed data types are not allowed. Either all of the columns of data should be discrete / continuous."
        )
    return default_test


def validate_test(
    data: pd.DataFrame,
    test: Optional[CITest],
) -> None:
    """Check that the requested conditional independency test suits the data.

    :param data: observational data.
    :param test: the conditional independency test passed by the user.
    :raises ValueError: if the test is not among the supported tests, if a test
        other than pearson is requested for continuous data, or if pearson is
        requested for discrete data
    """
    supported = get_conditional_independence_tests()
    if test not in supported:
        raise ValueError(f"`{test}` is invalid. Supported CI tests are: {sorted(supported)}")

    wants_pearson = test == "pearson"

    # Continuous data can only be tested with pearson
    if not wants_pearson and is_data_continuous(data):
        raise ValueError(
            "The data is continuous. Either discretize and use chi-square or use the pearson."
        )

    # Pearson is rejected for discrete data
    if wants_pearson and is_data_discrete(data):
        raise ValueError("Cannot run pearson on discrete data. Use chi-square instead.")


def conditional_independence_test_summary(
    graph: NxMixedGraph,
    data: pd.DataFrame,
    test: Optional[CITest] = None,
    max_given: Optional[int] = 5,
    significance_level: Optional[float] = None,
    verbose: Optional[bool] = False,
) -> None:
    """Log the summary of conditional independency test results.

    The summary is emitted through ``logging.info`` and includes the total number of
    conditional independence tests, the number and percentage of failed tests, and
    statistical information about each test such as p-values and test results.
    A test is reported as failed when its p-value is below the significance level.

    :param graph: an NxMixedGraph
    :param data: observational data corresponding to the graph
    :param test: the conditional independency test to use. If None, defaults to ``pearson`` for continuous data
        and ``chi-square`` for discrete data.
    :param max_given: The maximum set size in the power set of the vertices minus the d-separable pairs
    :param significance_level: The statistical tests employ this value for
        comparison with the p-value of the test to determine the independence of
        the tested variables. If none, defaults to 0.01.
    :param verbose: If `False`, only log the details of failed tests.
        If `True`, log the details of all the conditional independency results. Defaults to `False`
    :raises NotImplementedError: if data is of mixed type (contains both discrete and continuous columns)
    """
    if significance_level is None:
        significance_level = 0.01
    if not test:
        test = choose_default_test(data)
    else:
        # Validate test and data
        validate_test(data=data, test=test)
        if len(set(get_state_space_map(data).values())) > 1:
            raise NotImplementedError(
                "Mixed data types are not allowed. Either all of the columns of data should be discrete / continuous."
            )
    test_results = get_graph_falsifications(
        graph=graph,
        df=data,
        method=test,
        significance_level=significance_level,
        max_given=max_given,
    ).evidence
    # Find the result based on p-value
    test_results["result"] = test_results["p"].apply(
        lambda p_value: "fail" if p_value < significance_level else "pass"
    )
    # Selecting columns of interest
    test_results = test_results[["left", "right", "given", "p", "result"]]
    # Sorting the rows by index
    test_results = test_results.sort_index()
    test_results = test_results.rename(columns={"p": "p-value"})
    failed_tests = test_results[test_results["result"] == "fail"]
    total_no_of_tests = len(test_results)
    total_no_of_failed_tests = len(failed_tests)
    # A graph that implies no conditional independencies yields zero tests;
    # report 0% failures instead of dividing by zero.
    percentage_of_failed_tests = (
        total_no_of_failed_tests / total_no_of_tests if total_no_of_tests else 0.0
    )
    logging.info(f"Total number of conditional independencies: {total_no_of_tests:,}")
    logging.info(f"Total number of failed tests: {total_no_of_failed_tests:,}")
    logging.info(f"Percentage of failed tests: {percentage_of_failed_tests:.2%}")
    if verbose:
        logging.info(test_results.to_string(index=False))
    else:
        logging.info(failed_tests.to_string(index=False))

In [79]:
# Validate graph1 (the signaling network defined above) against the data
conditional_independence_test_summary(graph1, data, verbose=True)

Total number of conditional independencies: 35
Total number of failed tests: 1
Percentage of failed tests: 2.86%
left right        given  p-value result
PIP3   PKC    PIP2|Plcg 0.416708   pass
Plcg   Raf          PKC 0.315851   pass
 Erk   P38      PKA|PKC 0.550582   pass
PIP2   PKA          PKC 0.605099   pass
 Akt   P38      PKA|PKC 0.233727   pass
 P38   Raf      PKA|PKC 0.534014   pass
 Akt  Plcg     PIP3|PKC 0.069127   pass
 Erk  Plcg          PKC 0.075253   pass
PIP3   PKA          PKC 0.604881   pass
 Akt   Raf  Erk|PKA|PKC 0.989789   pass
 Mek  PIP2          PKC 1.000000   pass
 Akt   PKC Erk|PIP3|PKA 0.951124   pass
 Jnk  PIP2          PKC 0.773824   pass
 Erk  PIP3          PKC 0.926823   pass
 Jnk   P38      PKA|PKC 0.153375   pass
 Jnk  Plcg          PKC 0.806956   pass
 Akt   Mek  Erk|PKA|PKC 0.777743   pass
 PKA  Plcg          PKC 0.855297   pass
 Mek   P38      PKA|PKC 0.421722   pass
 P38  PIP2          PKC 0.004250   fail
 Jnk  PIP3          PKC 0.072930   pass
 Mek  P

Out of the 35 d-separations implied by the network, only one failed. As the percentage of failed tests is below 30 percent, its effect on the estimation of the causal query is minor. Hence, we proceed to the next step.

## Step 2: Check query identifiability

In [8]:
# Two-node example graph: a directed edge X -> Y plus an undirected edge
# between X and Y.
# NOTE(review): presumably the undirected edge stands for latent confounding
# between X and Y -- confirm against the NxMixedGraph documentation.
graph2 = NxMixedGraph.from_str_adj(
    directed={
        "X": ["Y"]
    },
    undirected={"X":["Y"]}
)

In [9]:
from y0.algorithm.identify import Identification
from y0.dsl import P
# Build the identification input for Erk under an intervention on Raf in graph1
# (the repr printed below reports outcomes {Erk} and treatments {Raf}).
# NOTE(review): this cell only constructs the Identification object; presumably
# the identification algorithm still has to be run on it to establish
# identifiability -- confirm.
id_in = Identification.from_expression(
    #query=P(Variable('Erk') @ [Variable('Raf'), Variable('Mek')]),
    query=P(Variable('Erk') @ Variable('Raf')),
    graph=graph1,
)
id_in

Identification(outcomes="{Erk}, treatments="{Raf}",conditions="set()",  graph="NxMixedGraph(directed=<networkx.classes.digraph.DiGraph object at 0x114b75690>, undirected=<networkx.classes.graph.Graph object at 0x114b757d0>)", estimand="P(Akt, Erk, Jnk, Mek, P38, PIP2, PIP3, PKA, PKC, Plcg, Raf)")

In [10]:
from y0.algorithm.identify import Identification
from y0.dsl import P
# Build the identification input for Y under an intervention on X in graph2
# (the repr printed below reports outcomes {Y} and treatments {X}).
# NOTE(review): this cell only constructs the Identification object; presumably
# the identification algorithm still has to be run on it to establish
# identifiability -- confirm.
id_in = Identification.from_expression(
    #query=P(Variable('Erk') @ [Variable('Raf'), Variable('Mek')]),
    query=P(Variable('Y') @ Variable('X')),
    graph=graph2,
)
id_in

Identification(outcomes="{Y}, treatments="{X}",conditions="set()",  graph="NxMixedGraph(directed=<networkx.classes.digraph.DiGraph object at 0x114b77750>, undirected=<networkx.classes.graph.Graph object at 0x114b778d0>)", estimand="P(X, Y)")

The query is identifiable. Hence, we can proceed to the next step.

## Step 3: Find nuisance variables and mark them as latent

In [9]:
#remove after merging simplify branch to main
import itertools
from typing import Iterable, Optional, Set, Union

import networkx as nx

from y0.algorithm.simplify_latent import simplify_latent_dag
from y0.dsl import Variable
from y0.graph import DEFAULT_TAG, NxMixedGraph

__all__ = [
    "remove_latent_variables",
    "mark_nuisance_variables_as_latent",
    "find_all_nodes_in_causal_paths",
    "find_nuisance_variables",
]


def remove_latent_variables(
    graph: NxMixedGraph,
    treatments: Union[Variable, Set[Variable]],
    outcomes: Union[Variable, Set[Variable]],
    tag: Optional[str] = None,
) -> NxMixedGraph:
    """Remove the nuisance variables from the graph with Evans' simplification rules.

    The nuisance variables are first marked as latent, the resulting latent variable
    DAG is simplified, and the simplified DAG is converted back into a mixed graph.

    :param graph: an NxMixedGraph
    :param treatments: a treatment variable or a set of treatment variables
    :param outcomes: an outcome variable or a set of outcome variables
    :param tag: The tag for which variables are latent
    :return: the new graph after simplification
    """
    marked_dag = mark_nuisance_variables_as_latent(
        graph=graph,
        treatments=treatments,
        outcomes=outcomes,
        tag=tag,
    )
    simplification = simplify_latent_dag(marked_dag, tag=tag)
    return NxMixedGraph.from_latent_variable_dag(simplification.graph, tag=tag)


def mark_nuisance_variables_as_latent(
    graph: NxMixedGraph,
    treatments: Union[Variable, Set[Variable]],
    outcomes: Union[Variable, Set[Variable]],
    tag: Optional[str] = None,
) -> nx.DiGraph:
    """Build a latent variable DAG in which the nuisance variables are flagged as latent.

    The nuisance variables are identified first; the graph is then converted to a latent
    variable DAG and every node that is a nuisance variable gets its ``tag`` attribute set
    to True. Nuisance variables are the descendants of nodes in all proper causal paths
    that are not ancestors of the outcome variables nodes. A proper causal path is a
    directed path from treatments to the outcome. Nuisance variables should not be
    included in the estimation of the causal effect as they increase the variance.

    :param graph: an NxMixedGraph
    :param treatments: a treatment variable or a set of treatment variables
    :param outcomes: an outcome variable or a set of outcome variables
    :param tag: The tag for which variables are latent. If None, the default tag is used.
    :return: the latent variable DAG with the nuisance variables marked as latent
    """
    latent_tag = DEFAULT_TAG if tag is None else tag
    nuisance = find_nuisance_variables(graph, treatments=treatments, outcomes=outcomes)
    dag = NxMixedGraph.to_latent_variable_dag(graph, tag=latent_tag)
    # Flag each nuisance node through its attribute dictionary
    for node in dag.nodes:
        if Variable(node) in nuisance:
            dag.nodes[node][latent_tag] = True
    return dag


def find_all_nodes_in_causal_paths(
    graph: NxMixedGraph,
    treatments: Union[Variable, Set[Variable]],
    outcomes: Union[Variable, Set[Variable]],
) -> Set[Variable]:
    """Collect every node lying on a proper causal path from treatments to outcomes.

    A proper causal path is a directed path from treatments to the outcome.

    :param graph: an NxMixedGraph
    :param treatments: a treatment variable or a set of treatment variables
    :param outcomes: an outcome variable or a set of outcome variables
    :return: the nodes on all causal paths from treatments to outcomes.
    """
    # A single variable is treated as a one-element set
    sources = {treatments} if isinstance(treatments, Variable) else treatments
    targets = {outcomes} if isinstance(outcomes, Variable) else outcomes

    nodes_on_paths = set()
    for source, target in itertools.product(sources, targets):
        for path in nx.all_simple_paths(graph.directed, source, target):
            nodes_on_paths.update(path)
    return nodes_on_paths


def find_nuisance_variables(
    graph: NxMixedGraph,
    treatments: Union[Variable, Set[Variable]],
    outcomes: Union[Variable, Set[Variable]],
) -> Iterable[Variable]:
    """Find the nuisance variables in the graph.

    Nuisance variables are the descendants of nodes in all proper causal paths that are
    not ancestors of the outcome variables' nodes. A proper causal path is a directed path
    from treatments to the outcome. Nuisance variables should not be included in the estimation
    of the causal effect as they increase the variance.

    :param graph: an NxMixedGraph
    :param treatments: a treatment variable, or a set (or other iterable) of treatment variables
    :param outcomes: an outcome variable, or a set (or other iterable) of outcome variables
    :returns: The nuisance variables.
    """
    # Normalize both arguments to sets so that a single variable, a set, or a
    # list all work with the set operations below.
    treatments = {treatments} if isinstance(treatments, Variable) else set(treatments)
    outcomes = {outcomes} if isinstance(outcomes, Variable) else set(outcomes)

    # Find the nodes on all causal paths
    nodes_on_causal_paths = find_all_nodes_in_causal_paths(
        graph=graph, treatments=treatments, outcomes=outcomes
    )

    # Find the descendants of interest
    descendants_of_nodes_on_causal_paths = graph.descendants_inclusive(nodes_on_causal_paths)

    # Find the ancestors of outcome variables
    ancestors_of_outcomes = graph.ancestors_inclusive(outcomes)

    descendants_not_ancestors = descendants_of_nodes_on_causal_paths.difference(
        ancestors_of_outcomes
    )

    # Treatments and outcomes themselves are never reported as nuisance variables
    return descendants_not_ancestors.difference(treatments.union(outcomes))

This function finds the nuisance variables for the input graph.

In [10]:
# Use graph1, the signaling network defined at the top of this notebook
nuisance_variables = find_nuisance_variables(
    graph1,
    treatments=Variable("Raf"),
    outcomes=Variable("Erk"),
)
nuisance_variables

{Akt}

The nuisance variable is $Akt$.

In [11]:
# Use graph1, the signaling network defined at the top of this notebook
latent_variable_dag = mark_nuisance_variables_as_latent(
    graph1,
    treatments=Variable("Raf"),
    outcomes=Variable("Erk"),
)

<networkx.classes.digraph.DiGraph at 0x12d77e2d0>

## Step 4: Simplify the network

The following function finds the nuisance variables (step 3), marks them as latent, and then applies Evans' simplification rules to remove them. The new graph does not contain nuisance variables.

In [18]:
# Use graph1, the signaling network defined at the top of this notebook
new_graph = remove_latent_variables(
    graph1,
    treatments=Variable("Raf"),
    outcomes=Variable("Erk"),
)

## Step 5: Estimate the query

In [8]:
data

Unnamed: 0,0,0.1,0.2,0.3,1,1.1,1.2,0.4,1.3,1.4,1.5,1.6
0,1,0,1,1,1,1,1,1,1,1,1,1
1,2,1,0,0,1,1,1,1,1,1,0,0
2,3,1,0,0,1,1,1,1,1,1,0,0
3,4,0,0,1,1,0,1,1,1,1,1,1
4,5,0,0,0,0,1,1,1,1,1,1,1
...,...,...,...,...,...,...,...,...,...,...,...,...
847,848,1,0,0,0,1,1,1,0,1,0,0
848,849,0,0,0,1,1,1,1,1,1,1,0
849,850,1,1,0,1,1,1,0,1,1,0,0
850,851,1,0,0,1,1,1,1,0,1,1,0


In [19]:
#remove after estimation branch is merged to main and replace with eliater codes
from ananke.graphs import ADMG
from ananke.identification import OneLineID
from ananke.estimation import CausalEffect
from ananke.datasets import load_afixable_data
from ananke.estimation import AutomatedIF
from ananke import identification
import numpy as np
import pandas as pd
from ananke.models import LinearGaussianSEM
import warnings
import numpy
import random
# Suppress all warnings from this point on.
# NOTE(review): this also hides any warnings raised by the estimation cells below.
warnings.filterwarnings('ignore')

In [28]:
#remove after estimation branch is merged to main
# Rebuild the network by hand as an ananke ADMG. The vertices and directed
# edges are those of graph1 with Akt and its incoming edges left out.
# NOTE: this rebinds new_graph, replacing the graph computed in step 4.
vertices = ["Plcg", "PIP3", "PIP2", "PKC", "PKA", "Raf", "Mek", "Erk", "Jnk", "P38"]
di_edges = [("Plcg", "PIP3"), 
            ("Plcg", "PIP2"),
            ("Plcg", "PKC"),
            ("PIP3", "PIP2"),
            ("PIP2", "PKC"),
            ("PKC", "Mek"),
            ("PKC", "Raf"),
            ("PKC", "PKA"),
            ("PKC", "Jnk"),
            ("PKC", "P38"),
            ("PKA", "Raf"),
            ("PKA", "Mek"),
            ("PKA", "Erk"),
            ("PKA", "Jnk"),
            ("PKA", "P38"),
            ("Raf", "Mek"),
            ("Mek", "Erk")]
# No bidirected edges are passed to the ADMG
bi_edges = []
new_graph = ADMG(vertices, di_edges, bi_edges)

ate_obj = CausalEffect(graph=new_graph, treatment='Raf', outcome='Erk')  # setting up the CausalEffect object
ate_obj

SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG



 Treatment is a-fixable and graph is mb-shielded. 

 Available estimators are:
 
1. IPW (ipw)
2. Outcome regression (gformula)
3. Generalized AIPW (aipw)
4. Efficient Generalized AIPW (eff-aipw) 
 
Suggested estimator is Efficient Generalized AIPW 


<ananke.estimation.counterfactual_mean.CausalEffect at 0x12292bcd0>

In [76]:
#remove after estimation branch is merged to main and use eliater function
#Estimated ATE
# "eff-aipw" is the estimator key that the CausalEffect output above lists as
# the suggested one (Efficient Generalized AIPW).
ate_obj.compute_effect(data, "eff-aipw")

SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG


-0.010641706233553646