# Case study 3: The Escherichia coli K-12 transcriptional motif

In [1]:
# from eliater.network_validation import conditional_independence_test_summary
import pandas as pd
from y0.dsl import Variable
from y0.graph import NxMixedGraph


This is case study 3 in Figure 7 in this paper: Eliater: an analytical workflow and open source implementation for causal query estimation in biomolecular networks.

In [23]:
graph = NxMixedGraph.from_str_adj(
    directed={
        "appY": ["appA", "appB", "appX", "hyaA", "hyaB", "hyaF"],
        "arcA": [
            "rpoS",
            "fnr",
            "dpiA",
            "aceE",
            "appY",
            "dpiB",
            "cydD",
            "gcvB",
            "hyaA",
            "hyaB",
            "hyaF",
            "mdh",
            "lrp",
            "ydeO",
            "oxyR",  
        ],
        "btsR": ["mdh"],
        "cra": ["cyoA"],
        "crp": [
            "dpiA",
            "cirA",
            "dcuR",
            "oxyR",
            "fis",
            "fur",
            "aceE",
            "dpiB",
            "cyoA",
            "exuT",
            "gadX",
            "mdh",
            "gutM"
        ],
        "cspA": ["hns"],
        "dcuR": ["dpiA", "dpiB"],
        "dpiA": ["appY", "citC", "citD", "dpiB", "exuT", "mdh"],
        "fis": ["cyoA", "gadX", "hns", "hyaA", "hyaB", "hyaF"],
        "fnr": [
            "dcuR",
            "dpiA",
            "narL",
            "aceE",
            "amtB",
            "aspC",
            "dpiB",
            "cydD",
            "cyoA",
            "gadX",
            "hcp",
        ],
        "fur": ["fnr", "amtB", "aspC", "cirA", "cyoA"],
        "gadX": ["amtB", "hns"],
        "gcvB": ["lrp", "oxyR", "ydeO"],
        "hns": ["appY", "ydeO", "gutM"],
        "ihfA": ["crp", "fnr", "ihfB"],
        "ihfB": ["fnr"],
        "iscR": ["hyaA", "hyaB", "appX"],
        "lrp": ["soxS", "aspC"],
        "modE": ["narL"],
        "narL": ["dpiB", "cydD", "hcp", "hyaA", "hyaB", "hyaF", "dcuR", "dpiA"],
        "narP": ["hyaA", "hyaB", "hyaF"],
        "oxyR": ["fur", "hcp"],
        "phoB": ["cra"],
        "rpoD": [
            "arcA",
            "cirA",
            "crp",
            "dcuR",
            "fis",
            "fnr",
            "fur",
            "ihfB",
            "lrp",
            "narL",
            "oxyR",
            "phoB",
            "rpoS",
            "soxS",
            "aceE",
            "ydeO",
            "hns",
        ],
        "rpoH": ["cra"],
        "rpoS": ["aceE", "appY", "hyaA", "hyaB", "hyaF", "ihfA", "ihfB", "oxyR"],
        "soxS": ["fur"],
        "ydeO": ["hyaA", "hyaF", "hyaB"],
    },
    undirected={
        "dpiA": ["dpiB"],
        "hns": ["rpoS", "lrp"],
        "rpoS": ["lrp"],
        "lrp": ["ydeO", "oxyR"],
        "oxyR": ["ydeO"]
    },
)

In [4]:
# Get the data
data = pd.read_csv(
    "~/Github/eliater/src/eliater/data/EColi_obs_data.csv",
    index_col=False,
)

## Step 1: Verify correctness of the network structure

In [11]:
#remove after branch is merged
import logging
from typing import Dict, Literal, Optional
import pandas as pd

from y0.algorithm.falsification import get_graph_falsifications
from y0.dsl import Variable
from y0.graph import NxMixedGraph
from y0.struct import get_conditional_independence_tests

logging.basicConfig(format="%(message)s", level=logging.DEBUG)


__all__ = [
    "conditional_independence_test_summary",
    "validate_test",
    "get_state_space_map",
    "is_data_discrete",
    "is_data_continuous",
    "CITest",
    "choose_default_test",
]

TESTS = get_conditional_independence_tests()


def get_state_space_map(
    data: pd.DataFrame, threshold: Optional[int] = 10
) -> Dict[Variable, Literal["discrete", "continuous"]]:
    """Get a dictionary from each variable to its type.

    :param data: the observed data
    :param threshold: The threshold for determining a column as discrete
        based on the number of unique values
    :return: the mapping from column name to its type
    """
    column_values_unique_count = {
        column_name: data[column_name].nunique() for column_name in data.columns
    }
    return {
        Variable(column): "discrete"
        if column_values_unique_count[column] <= threshold
        else "continuous"
        for column in data.columns
    }


def is_data_discrete(data: pd.DataFrame) -> bool:
    """Check if all the columns in the dataframe has discrete data.

    :param data: observational data.
    :return: True, if all the columns have discrete data, False, otherwise
    """
    variable_types = set(get_state_space_map(data=data).values())
    return variable_types == {"discrete"}


def is_data_continuous(data: pd.DataFrame) -> bool:
    """Check if all the columns in the dataframe has continuous data.

    :param data: observational.
    :return: True, if all the columns have continuous data, False, otherwise
    """
    variable_types = set(get_state_space_map(data).values())
    return variable_types == {"continuous"}


# TODO replace with y0.struct.CITest
CITest = Literal[
    "pearson",
    "chi-square",
    "cressie_read",
    "freeman_tuckey",
    "g_sq",
    "log_likelihood",
    "modified_log_likelihood",
    "power_divergence",
    "neyman",
]


def choose_default_test(data: pd.DataFrame) -> CITest:
    """Choose the default statistical test for testing conditional independencies based on the data.

    :param data: observational data.
    :return: the default test based on data
    :raises NotImplementedError: if data is of mixed type (contains both discrete and continuous columns)
    """
    if is_data_discrete(data):
        return "chi-square"
    if is_data_continuous(data):
        return "pearson"
    raise NotImplementedError(
        "Mixed data types are not allowed. Either all of the columns of data should be discrete / continuous."
    )


def validate_test(
    data: pd.DataFrame,
    test: Optional[CITest],
) -> None:
    """Validate the conditional independency test passed by the user.

    :param data: observational data.
    :param test: the conditional independency test passed by the user.
    :raises ValueError: if the passed test is invalid / unsupported, pearson is used for discrete data or
        chi-square is used for continuous data
    """
    tests = get_conditional_independence_tests()
    if test not in tests:
        raise ValueError(f"`{test}` is invalid. Supported CI tests are: {sorted(tests)}")

    if is_data_continuous(data) and test != "pearson":
        raise ValueError(
            "The data is continuous. Either discretize and use chi-square or use the pearson."
        )

    if is_data_discrete(data) and test == "pearson":
        raise ValueError("Cannot run pearson on discrete data. Use chi-square instead.")


def conditional_independence_test_summary(
    graph: NxMixedGraph,
    data: pd.DataFrame,
    test: Optional[CITest] = None,
    significance_level: Optional[float] = None,
    verbose: Optional[bool] = False,
) -> None:
    """Print the summary of conditional independency test results.

    Prints the summary to the console, which includes the total number of conditional independence tests,
    the number and percentage of failed tests, and statistical information about each test such as p-values,
    and test results.

    :param graph: an NxMixedGraph
    :param data: observational data corresponding to the graph
    :param test: the conditional independency test to use. If None, defaults to ``pearson`` for continuous data
        and ``chi-square`` for discrete data.
    :param significance_level: The statistical tests employ this value for
        comparison with the p-value of the test to determine the independence of
        the tested variables. If none, defaults to 0.01.
    :param verbose: If `False`, only print the details of failed tests.
        If 'True', print the details of all the conditional independency results. Defaults to `False`
    :raises NotImplementedError: if data is of mixed type (contains both discrete and continuous columns)
    """
    if significance_level is None:
        significance_level = 0.01
    if not test:
        test = choose_default_test(data)
    else:
        # Validate test and data
        validate_test(data=data, test=test)
        if len(set(get_state_space_map(data).values())) > 1:
            raise NotImplementedError(
                "Mixed data types are not allowed. Either all of the columns of data should be discrete / continuous."
            )
    test_results = get_graph_falsifications(
        graph=graph, df=data, method=test, significance_level=significance_level
    ).evidence
    # Find the result based on p-value
    test_results["result"] = test_results["p"].apply(
        lambda p_value: "fail" if p_value < significance_level else "pass"
    )
    # Selecting columns of interest
    test_results = test_results[["left", "right", "given", "p", "result"]]
    # Sorting the rows by index
    test_results = test_results.sort_index()
    test_results = test_results.rename(columns={"p": "p-value"})
    failed_tests = test_results[test_results["result"] == "fail"]
    total_no_of_tests = len(test_results)
    total_no_of_failed_tests = len(failed_tests)
    percentage_of_failed_tests = total_no_of_failed_tests / total_no_of_tests
    logging.info(f"Total number of conditional independencies: {total_no_of_tests:,}")
    logging.info(f"Total number of failed tests: {total_no_of_failed_tests:,}")
    logging.info(f"Percentage of failed tests: {percentage_of_failed_tests:.2%}")
    if verbose:
        logging.info(test_results.to_string(index=False))
    else:
        logging.info(failed_tests.to_string(index=False))

In [12]:
conditional_independence_test_summary(graph, data, verbose=True)

  for z_state, df in data.groupby(Z):
  for z_state, df in data.groupby(Z):
  for z_state, df in data.groupby(Z):
Total number of conditional independencies: 5
Total number of failed tests: 0
Percentage of failed tests: 0.00%
left right given  p-value result
  M2     X    M1      NaN   pass
  M1     U     X      NaN   pass
   X     Y  M2|U      NaN   pass
  M2     U     X      NaN   pass
  M1     Y  M2|U      NaN   pass


## Step 2: Check query identifiability

In [24]:
from y0.algorithm.identify import Identification
from y0.dsl import P
id_in = Identification.from_expression(
    query=P(Variable('dpiA') @ Variable('fur')),
    graph=graph,
)
id_in

Identification(outcomes="{dpiA}, treatments="{fur}",conditions="set()",  graph="NxMixedGraph(directed=<networkx.classes.digraph.DiGraph object at 0x127bfb1d0>, undirected=<networkx.classes.graph.Graph object at 0x127bfba90>)", estimand="P(aceE, amtB, appA, appB, appX, appY, arcA, aspC, btsR, cirA, citC, citD, citX, cra, crp, cspA, cydD, cyoA, dcuR, dpiA, dpiB, exuT, fis, fnr, fur, gadX, gcvB, gutM, hcp, hns, hyaA, hyaB, hyaF, ihfA, ihfB, iscR, lrp, mdh, modE, narL, narP, oxyR, phoB, rpoD, rpoH, rpoS, soxS, ydeO)")

The query is identifiable. Hence, we can proceed to the next step.

## Step 3: Find nuisance variables and mark them as latent

In [25]:
#remove after merging simplify branch to main
import itertools
from typing import Iterable, Optional, Set, Union

import networkx as nx

from y0.algorithm.simplify_latent import simplify_latent_dag
from y0.dsl import Variable
from y0.graph import DEFAULT_TAG, NxMixedGraph

__all__ = [
    "remove_latent_variables",
    "mark_nuisance_variables_as_latent",
    "find_all_nodes_in_causal_paths",
    "find_nuisance_variables",
]


def remove_latent_variables(
    graph: NxMixedGraph,
    treatments: Union[Variable, Set[Variable]],
    outcomes: Union[Variable, Set[Variable]],
    tag: Optional[str] = None,
) -> NxMixedGraph:
    """Find all nuissance variables and remove them based on Evans' simplification rules.

    :param graph: an NxMixedGraph
    :param treatments: a list of treatments
    :param outcomes: a list of outcomes
    :param tag: The tag for which variables are latent
    :return: the new graph after simplification
    """
    lv_dag = mark_nuisance_variables_as_latent(
        graph=graph, treatments=treatments, outcomes=outcomes, tag=tag
    )
    simplified_latent_dag = simplify_latent_dag(lv_dag, tag=tag)
    return NxMixedGraph.from_latent_variable_dag(simplified_latent_dag.graph, tag=tag)


def mark_nuisance_variables_as_latent(
    graph: NxMixedGraph,
    treatments: Union[Variable, Set[Variable]],
    outcomes: Union[Variable, Set[Variable]],
    tag: Optional[str] = None,
) -> nx.DiGraph:
    """Find all the nuisance variables and mark them as latent.

    Mark nuisance variables as latent by first identifying them, then creating a new graph where these
    nodes are marked as latent. Nuisance variables are the descendants of nodes in all proper causal paths
    that are not ancestors of the outcome variables nodes. A proper causal path is a directed path from
    treatments to the outcome. Nuisance variables should not be included in the estimation of the causal
    effect as they increase the variance.

    :param graph: an NxMixedGraph
    :param treatments: a list of treatments
    :param outcomes: a list of outcomes
    :param tag: The tag for which variables are latent
    :return: the modified graph after simplification, in place
    """
    if tag is None:
        tag = DEFAULT_TAG
    nuisance_variables = find_nuisance_variables(graph, treatments=treatments, outcomes=outcomes)
    lv_dag = NxMixedGraph.to_latent_variable_dag(graph, tag=tag)
    # Set nuisance variables as latent
    for node, data in lv_dag.nodes(data=True):
        if Variable(node) in nuisance_variables:
            data[tag] = True
    return lv_dag


def find_all_nodes_in_causal_paths(
    graph: NxMixedGraph,
    treatments: Union[Variable, Set[Variable]],
    outcomes: Union[Variable, Set[Variable]],
) -> Set[Variable]:
    """Find all the nodes in proper causal paths from treatments to outcomes.

    A proper causal path is a directed path from treatments to the outcome.

    :param graph: an NxMixedGraph
    :param treatments: a list of treatments
    :param outcomes: a list of outcomes
    :return: the nodes on all causal paths from treatments to outcomes.
    """
    if isinstance(treatments, Variable):
        treatments = {treatments}
    if isinstance(outcomes, Variable):
        outcomes = {outcomes}

    return {
        node
        for treatment, outcome in itertools.product(treatments, outcomes)
        for causal_path in nx.all_simple_paths(graph.directed, treatment, outcome)
        for node in causal_path
    }


def find_nuisance_variables(
    graph: NxMixedGraph,
    treatments: Union[Variable, Set[Variable]],
    outcomes: Union[Variable, Set[Variable]],
) -> Iterable[Variable]:
    """Find the nuisance variables in the graph.

    Nuisance variables are the descendants of nodes in all proper causal paths that are
    not ancestors of the outcome variables' nodes. A proper causal path is a directed path
    from treatments to the outcome. Nuisance variables should not be included in the estimation
    of the causal effect as they increase the variance.

    :param graph: an NxMixedGraph
    :param treatments: a list of treatments
    :param outcomes: a list of outcomes
    :returns: The nuisance variables.
    """
    if isinstance(treatments, Variable):
        treatments = {treatments}
    if isinstance(outcomes, Variable):
        outcomes = {outcomes}

    # Find the nodes on all causal paths
    nodes_on_causal_paths = find_all_nodes_in_causal_paths(
        graph=graph, treatments=treatments, outcomes=outcomes
    )

    # Find the descendants of interest
    descendants_of_nodes_on_causal_paths = graph.descendants_inclusive(nodes_on_causal_paths)

    # Find the ancestors of outcome variables
    ancestors_of_outcomes = graph.ancestors_inclusive(outcomes)

    descendants_not_ancestors = descendants_of_nodes_on_causal_paths.difference(
        ancestors_of_outcomes
    )

    nuisance_variables = descendants_not_ancestors.difference(treatments.union(outcomes))
    return nuisance_variables

This function finds the nuisance variables for the input graph.

In [27]:
nuisance_variables = find_nuisance_variables(graph, treatments=Variable('fur'), outcomes=Variable('dpiA'))
nuisance_variables

24

## Step 4: Simplify the network

The following function find the nuisance variable (step 3), marks them as latent and then applies Evan's simplification rules to remove the nuisance variables.

In [None]:
new_graph = remove_latent_variables(graph, treatments=Variable("fur"), outcomes=Variable("dpiA"))

## Step 5: Estimate the query