# Motivating example: Figure 4

In [13]:
#from eliater.network_validation import conditional_independence_test_summary

# from src.eliater.frontdoor_backdoor_discrete import (
#     single_mediator_with_multiple_confounders_nuisances_discrete_example,
# )

from y0.graph import NxMixedGraph
from y0.dsl import Variable

This is the motivating example in Figure 4 (a) in this paper: *Eliater: an analytical workflow and open source implementation for causal query
estimation in biomolecular networks*. 

In [14]:
graph = NxMixedGraph.from_str_edges(
    directed=[
        ("Z1", "X"),
        ("Z1", "Z2"),
        ("Z2", "Z3"),
        ("Z3", "Y"),
        ("X", "M1"),
        ("M1", "Y"),
        ("M1", "R1"),
        ("R1", "R2"),
        ("R2", "R3"),
        ("Y", "R3"),
    ],
    undirected=[
        # ("Z1", "X"),
    ],
)

In [15]:
#remove after CI branch is merged
import numpy as np
import pandas as pd

from y0.algorithm.identify import Query
from y0.dsl import Z1, Z2, Z3, Variable, X, Y
from y0.examples import Example

M1 = Variable("M1")
R1 = Variable("R1")
R2 = Variable("R2")
R3 = Variable("R3")

__all__ = [
    "single_mediator_with_multiple_confounders_nuisances_discrete_example",
]


def _r_exp(x):
    return 1 / (1 + np.exp(x))


def generate(
    num_samples: int = 1000,
    treatments: dict[Variable, float] | None = None,
    *,
    seed: int | None = None,
) -> pd.DataFrame:
    """Generate discrete testing data for the multiple_mediators_with_multiple_confounders_nuisances_discrete case study.

    :param num_samples: The number of samples to generate. Try 1000.
    :param treatments: An optional dictionary of the values to fix each variable to.
    :param seed: An optional random seed for reproducibility purposes
    :returns: A pandas Dataframe with columns corresponding
        to the variable names in the multiple_mediators_with_multiple_confounders_nuisances_discrete example
    """
    if treatments is None:
        treatments = {}
    generator = np.random.default_rng(seed)

    values_z1 = [0, 1]
    probs_z1 = [0.4, 0.6]

    if Z1 in treatments:
        z1 = np.full(num_samples, treatments[Z1])
    else:
        z1 = generator.choice(values_z1, num_samples, p=probs_z1)

    beta0_z2 = 1
    beta_z1_to_z2 = 0.3

    if Z2 in treatments:
        z2 = np.full(num_samples, treatments[Z2])
    else:
        probs_z2 = _r_exp(-beta0_z2 - z1 * beta_z1_to_z2)
        z2 = generator.binomial(n=1, p=probs_z2, size=num_samples)

    beta0_z3 = 1.2
    beta_z2_to_z3 = 0.6

    if Z3 in treatments:
        z3 = np.full(num_samples, treatments[Z3])
    else:
        probs_z3 = _r_exp(-beta0_z3 - z2 * beta_z2_to_z3)
        z3 = generator.binomial(n=1, p=probs_z3, size=num_samples)

    beta0_x = 1
    beta_z1_to_x = 0.6

    if X in treatments:
        x = np.full(num_samples, treatments[X])
    else:
        probs_x = _r_exp(-beta0_x - z1 * beta_z1_to_x)
        x = generator.binomial(n=1, p=probs_x, size=num_samples)

    beta0_m1 = 1
    beta_x_to_m1 = 0.7

    if M1 in treatments:
        m1 = np.full(num_samples, treatments[M1])
    else:
        probs_m1 = _r_exp(-beta0_m1 - x * beta_x_to_m1)
        m1 = generator.binomial(n=1, p=probs_m1, size=num_samples)

    beta0_y = 1.8
    beta_z3_to_y = 0.5
    beta_m1_to_y = 0.7
    if Y in treatments:
        y = np.full(num_samples, treatments[Y])
    else:
        probs_y = _r_exp(-beta0_y - z3 * beta_z3_to_y - m1 * beta_m1_to_y)
        y = generator.binomial(n=1, p=probs_y, size=num_samples)

    beta0_r1 = 1.5
    beta_m1_to_r1 = 0.7

    if R1 in treatments:
        r1 = np.full(num_samples, treatments[R1])
    else:
        probs_r1 = _r_exp(-beta0_r1 - m1 * beta_m1_to_r1)
        r1 = generator.binomial(n=1, p=probs_r1, size=num_samples)

    beta0_r2 = 1.4
    beta_r1_to_r2 = 0.4

    if R2 in treatments:
        r2 = np.full(num_samples, treatments[R2])
    else:
        probs_r2 = _r_exp(-beta0_r2 - r1 * beta_r1_to_r2)
        r2 = generator.binomial(n=1, p=probs_r2, size=num_samples)

    beta0_r3 = 1.1
    beta_r2_to_r3 = 0.3
    beta_y_to_r3 = 0.3

    if R3 in treatments:
        r3 = np.full(num_samples, treatments[R3])
    else:
        probs_r3 = _r_exp(-beta0_r3 - r2 * beta_r2_to_r3 - y * beta_y_to_r3)
        r3 = generator.binomial(n=1, p=probs_r3, size=num_samples)

    return pd.DataFrame(
        {
            X.name: x,
            M1.name: m1,
            Z1.name: z1,
            Z2.name: z2,
            Z3.name: z3,
            R1.name: r1,
            R2.name: r2,
            R3.name: r3,
            Y.name: y,
        }
    )


single_mediator_with_multiple_confounders_nuisances_discrete_example = Example(
    name="frontdoor with multiple mediators, confounders and nuisance variables",
    reference="Causal workflow paper, figure 4 (a).",
    description="This is an extension of the frontdoor_backdoor example from y0 module"
    " but with more variables directly connecting the treatment to outcome (mediators)"
    "and several additional variables that are a direct cause of both the treatment and outcome"
    "(confounders), and several nuisance variables. The nuisance variables are R1, R2, R3. "
    "They should not be part of query estimation because they are downstream of the outcome."
    " In the data generation process, all the variables are discrete. This "
    "example is designed to check if the conditional independencies implied by the graph are"
    " aligned with the ones implied by the data via the X-square test.",
    graph=graph,
    generate_data=generate,
    example_queries=[Query.from_str(treatments="X", outcomes="Y")],
)

In [16]:
data = single_mediator_with_multiple_confounders_nuisances_discrete_example.generate_data(num_samples=500, seed=1)
data

Unnamed: 0,X,M1,Z1,Z2,Z3,R1,R2,R3,Y
0,1,1,1,1,1,1,1,1,1
1,0,1,1,1,0,1,0,1,1
2,1,1,0,1,1,1,1,1,1
3,1,1,1,1,1,1,1,0,1
4,1,1,0,1,1,0,0,1,1
...,...,...,...,...,...,...,...,...,...
495,1,1,1,1,0,1,1,1,1
496,1,1,1,0,0,1,1,1,0
497,1,1,0,1,1,1,1,1,1
498,1,1,1,0,1,1,1,1,1


## Step 1: Verify correctness of the network structure

In [9]:
#remove after branch is merged
import logging
from typing import Dict, Literal, Optional
import pandas as pd

from y0.algorithm.falsification import get_graph_falsifications
from y0.dsl import Variable
from y0.graph import NxMixedGraph
from y0.struct import get_conditional_independence_tests

logging.basicConfig(format="%(message)s", level=logging.DEBUG)


__all__ = [
    "conditional_independence_test_summary",
    "validate_test",
    "get_state_space_map",
    "is_data_discrete",
    "is_data_continuous",
    "CITest",
    "choose_default_test",
]

TESTS = get_conditional_independence_tests()


def get_state_space_map(
    data: pd.DataFrame, threshold: Optional[int] = 10
) -> Dict[Variable, Literal["discrete", "continuous"]]:
    """Get a dictionary from each variable to its type.

    :param data: the observed data
    :param threshold: The threshold for determining a column as discrete
        based on the number of unique values
    :return: the mapping from column name to its type
    """
    column_values_unique_count = {
        column_name: data[column_name].nunique() for column_name in data.columns
    }
    return {
        Variable(column): "discrete"
        if column_values_unique_count[column] <= threshold
        else "continuous"
        for column in data.columns
    }


def is_data_discrete(data: pd.DataFrame) -> bool:
    """Check if all the columns in the dataframe has discrete data.

    :param data: observational data.
    :return: True, if all the columns have discrete data, False, otherwise
    """
    variable_types = set(get_state_space_map(data=data).values())
    return variable_types == {"discrete"}


def is_data_continuous(data: pd.DataFrame) -> bool:
    """Check if all the columns in the dataframe has continuous data.

    :param data: observational.
    :return: True, if all the columns have continuous data, False, otherwise
    """
    variable_types = set(get_state_space_map(data).values())
    return variable_types == {"continuous"}


# TODO replace with y0.struct.CITest
CITest = Literal[
    "pearson",
    "chi-square",
    "cressie_read",
    "freeman_tuckey",
    "g_sq",
    "log_likelihood",
    "modified_log_likelihood",
    "power_divergence",
    "neyman",
]


def choose_default_test(data: pd.DataFrame) -> CITest:
    """Choose the default statistical test for testing conditional independencies based on the data.

    :param data: observational data.
    :return: the default test based on data
    :raises NotImplementedError: if data is of mixed type (contains both discrete and continuous columns)
    """
    if is_data_discrete(data):
        return "chi-square"
    if is_data_continuous(data):
        return "pearson"
    raise NotImplementedError(
        "Mixed data types are not allowed. Either all of the columns of data should be discrete / continuous."
    )


def validate_test(
    data: pd.DataFrame,
    test: Optional[CITest],
) -> None:
    """Validate the conditional independency test passed by the user.

    :param data: observational data.
    :param test: the conditional independency test passed by the user.
    :raises ValueError: if the passed test is invalid / unsupported, pearson is used for discrete data or
        chi-square is used for continuous data
    """
    tests = get_conditional_independence_tests()
    if test not in tests:
        raise ValueError(f"`{test}` is invalid. Supported CI tests are: {sorted(tests)}")

    if is_data_continuous(data) and test != "pearson":
        raise ValueError(
            "The data is continuous. Either discretize and use chi-square or use the pearson."
        )

    if is_data_discrete(data) and test == "pearson":
        raise ValueError("Cannot run pearson on discrete data. Use chi-square instead.")


def conditional_independence_test_summary(
    graph: NxMixedGraph,
    data: pd.DataFrame,
    test: Optional[CITest] = None,
    max_given: Optional[int] = 5,
    significance_level: Optional[float] = None,
    verbose: Optional[bool] = False,
) -> None:
    """Print the summary of conditional independency test results.

    Prints the summary to the console, which includes the total number of conditional independence tests,
    the number and percentage of failed tests, and statistical information about each test such as p-values,
    and test results.

    :param graph: an NxMixedGraph
    :param data: observational data corresponding to the graph
    :param test: the conditional independency test to use. If None, defaults to ``pearson`` for continuous data
        and ``chi-square`` for discrete data.
    :param max_given: The maximum set size in the power set of the vertices minus the d-separable pairs
    :param significance_level: The statistical tests employ this value for
        comparison with the p-value of the test to determine the independence of
        the tested variables. If none, defaults to 0.01.
    :param verbose: If `False`, only print the details of failed tests.
        If 'True', print the details of all the conditional independency results. Defaults to `False`
    :raises NotImplementedError: if data is of mixed type (contains both discrete and continuous columns)
    """
    if significance_level is None:
        significance_level = 0.01
    if not test:
        test = choose_default_test(data)
    else:
        # Validate test and data
        validate_test(data=data, test=test)
        if len(set(get_state_space_map(data).values())) > 1:
            raise NotImplementedError(
                "Mixed data types are not allowed. Either all of the columns of data should be discrete / continuous."
            )
    test_results = get_graph_falsifications(
        graph=graph,
        df=data,
        method=test,
        significance_level=significance_level,
        max_given=max_given,
    ).evidence
    # Find the result based on p-value
    test_results["result"] = test_results["p"].apply(
        lambda p_value: "fail" if p_value < significance_level else "pass"
    )
    # Selecting columns of interest
    test_results = test_results[["left", "right", "given", "p", "result"]]
    # Sorting the rows by index
    test_results = test_results.sort_index()
    test_results = test_results.rename(columns={"p": "p-value"})
    failed_tests = test_results[test_results["result"] == "fail"]
    total_no_of_tests = len(test_results)
    total_no_of_failed_tests = len(failed_tests)
    percentage_of_failed_tests = total_no_of_failed_tests / total_no_of_tests
    logging.info(f"Total number of conditional independencies: {total_no_of_tests:,}")
    logging.info(f"Total number of failed tests: {total_no_of_failed_tests:,}")
    logging.info(f"Percentage of failed tests: {percentage_of_failed_tests:.2%}")
    if verbose:
        logging.info(test_results.to_string(index=False))
    else:
        logging.info(failed_tests.to_string(index=False))


In [6]:
conditional_independence_test_summary(graph, data, verbose=True)

  for z_state, df in data.groupby(Z):
  for z_state, df in data.groupby(Z):
  for z_state, df in data.groupby(Z):
  for z_state, df in data.groupby(Z):
  for z_state, df in data.groupby(Z):
  for z_state, df in data.groupby(Z):
  for z_state, df in data.groupby(Z):
  for z_state, df in data.groupby(Z):
  for z_state, df in data.groupby(Z):
  for z_state, df in data.groupby(Z):
  for z_state, df in data.groupby(Z):
  for z_state, df in data.groupby(Z):
  for z_state, df in data.groupby(Z):
  for z_state, df in data.groupby(Z):
  for z_state, df in data.groupby(Z):
  for z_state, df in data.groupby(Z):
  for z_state, df in data.groupby(Z):
Total number of conditional independencies: 26
Total number of failed tests: 0
Percentage of failed tests: 0.00%
left right given  p-value result
  R2    Z2    R1 0.925150   pass
   X     Y M1|Z1 0.915287   pass
  M1    Z2     X 0.238941   pass
  M1    R2    R1 0.184011   pass
  R1    Z2    M1 0.962174   pass
  R3     X  R1|Y 0.415925   pass
   X    Z2

All the d-separations implied by the network are validated by the data. No test failed. Hence, we can proceed to step 2.

## Step 2: Check query identifiability

In [7]:
from y0.algorithm.identify import Identification
from y0.dsl import P
id_in = Identification.from_expression(
    query=P(Variable('Y') @ Variable('X')),
    graph=graph,
)
id_in

Identification(outcomes="{Y}, treatments="{X}",conditions="set()",  graph="NxMixedGraph(directed=<networkx.classes.digraph.DiGraph object at 0x118041e10>, undirected=<networkx.classes.graph.Graph object at 0x118040550>)", estimand="P(M1, R1, R2, R3, X, Y, Z1, Z2, Z3)")

The query is identifiable. Hence we can proceed to step 3.

## Step 3: Find nuisance variables and mark them as latent

In [8]:
#remove after merging simplify branch to main
import itertools
from typing import Iterable, Optional, Set, Union

import networkx as nx

from y0.algorithm.simplify_latent import simplify_latent_dag
from y0.dsl import Variable
from y0.graph import DEFAULT_TAG, NxMixedGraph

__all__ = [
    "remove_latent_variables",
    "mark_nuisance_variables_as_latent",
    "find_all_nodes_in_causal_paths",
    "find_nuisance_variables",
]


def remove_latent_variables(
    graph: NxMixedGraph,
    treatments: Union[Variable, Set[Variable]],
    outcomes: Union[Variable, Set[Variable]],
    tag: Optional[str] = None,
) -> NxMixedGraph:
    """Find all nuissance variables and remove them based on Evans' simplification rules.

    :param graph: an NxMixedGraph
    :param treatments: a list of treatments
    :param outcomes: a list of outcomes
    :param tag: The tag for which variables are latent
    :return: the new graph after simplification
    """
    lv_dag = mark_nuisance_variables_as_latent(
        graph=graph, treatments=treatments, outcomes=outcomes, tag=tag
    )
    simplified_latent_dag = simplify_latent_dag(lv_dag, tag=tag)
    return NxMixedGraph.from_latent_variable_dag(simplified_latent_dag.graph, tag=tag)


def mark_nuisance_variables_as_latent(
    graph: NxMixedGraph,
    treatments: Union[Variable, Set[Variable]],
    outcomes: Union[Variable, Set[Variable]],
    tag: Optional[str] = None,
) -> nx.DiGraph:
    """Find all the nuisance variables and mark them as latent.

    Mark nuisance variables as latent by first identifying them, then creating a new graph where these
    nodes are marked as latent. Nuisance variables are the descendants of nodes in all proper causal paths
    that are not ancestors of the outcome variables nodes. A proper causal path is a directed path from
    treatments to the outcome. Nuisance variables should not be included in the estimation of the causal
    effect as they increase the variance.

    :param graph: an NxMixedGraph
    :param treatments: a list of treatments
    :param outcomes: a list of outcomes
    :param tag: The tag for which variables are latent
    :return: the modified graph after simplification, in place
    """
    if tag is None:
        tag = DEFAULT_TAG
    nuisance_variables = find_nuisance_variables(graph, treatments=treatments, outcomes=outcomes)
    lv_dag = NxMixedGraph.to_latent_variable_dag(graph, tag=tag)
    # Set nuisance variables as latent
    for node, data in lv_dag.nodes(data=True):
        if Variable(node) in nuisance_variables:
            data[tag] = True
    return lv_dag


def find_all_nodes_in_causal_paths(
    graph: NxMixedGraph,
    treatments: Union[Variable, Set[Variable]],
    outcomes: Union[Variable, Set[Variable]],
) -> Set[Variable]:
    """Find all the nodes in proper causal paths from treatments to outcomes.

    A proper causal path is a directed path from treatments to the outcome.

    :param graph: an NxMixedGraph
    :param treatments: a list of treatments
    :param outcomes: a list of outcomes
    :return: the nodes on all causal paths from treatments to outcomes.
    """
    if isinstance(treatments, Variable):
        treatments = {treatments}
    if isinstance(outcomes, Variable):
        outcomes = {outcomes}

    return {
        node
        for treatment, outcome in itertools.product(treatments, outcomes)
        for causal_path in nx.all_simple_paths(graph.directed, treatment, outcome)
        for node in causal_path
    }


def find_nuisance_variables(
    graph: NxMixedGraph,
    treatments: Union[Variable, Set[Variable]],
    outcomes: Union[Variable, Set[Variable]],
) -> Iterable[Variable]:
    """Find the nuisance variables in the graph.

    Nuisance variables are the descendants of nodes in all proper causal paths that are
    not ancestors of the outcome variables' nodes. A proper causal path is a directed path
    from treatments to the outcome. Nuisance variables should not be included in the estimation
    of the causal effect as they increase the variance.

    :param graph: an NxMixedGraph
    :param treatments: a list of treatments
    :param outcomes: a list of outcomes
    :returns: The nuisance variables.
    """
    if isinstance(treatments, Variable):
        treatments = {treatments}
    if isinstance(outcomes, Variable):
        outcomes = {outcomes}

    # Find the nodes on all causal paths
    nodes_on_causal_paths = find_all_nodes_in_causal_paths(
        graph=graph, treatments=treatments, outcomes=outcomes
    )

    # Find the descendants of interest
    descendants_of_nodes_on_causal_paths = graph.descendants_inclusive(nodes_on_causal_paths)

    # Find the ancestors of outcome variables
    ancestors_of_outcomes = graph.ancestors_inclusive(outcomes)

    descendants_not_ancestors = descendants_of_nodes_on_causal_paths.difference(
        ancestors_of_outcomes
    )

    nuisance_variables = descendants_not_ancestors.difference(treatments.union(outcomes))
    return nuisance_variables

This function finds the nuisance variables for the input graph.

In [9]:
nuisance_variables = find_nuisance_variables(graph, treatments=Variable("X"), outcomes=Variable("Y"))
nuisance_variables

{R1, R2, R3}

The nuisance variables are $R_1$, $R_2$, and $R_3$.

## Step 4: Simplify the network

The following function find the nuisance variable (step 3), marks them as latent and then applies Evan's simplification rules to remove the nuisance variables. The new graph does not contain nuisance variables.

In [10]:
new_graph = remove_latent_variables(graph, treatments=Variable("X"), outcomes=Variable("Y"))

## Step 5: Estimate the query

In [10]:
from y0.algorithm.estimation import estimate_ace

In [12]:
ATE_value = estimate_ace(graph=graph,
                         treatments=Variable("Raf"),
                         outcomes=Variable("Erk"),
                         data=data)
ATE_value

SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
  for colname, colvalues in df.iteritems():
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG
SG
ADMG


-0.6949811379122186

In [11]:
#remove after estimation branch is merged to main
#pip install ananke-causal

In [16]:
#get the real value of ATE
intv_data1 = single_mediator_with_multiple_confounders_nuisances_discrete_example.generate_data(num_samples=500, treatments={Variable('X'):1}, seed=1)
intv_data0 = single_mediator_with_multiple_confounders_nuisances_discrete_example.generate_data(num_samples=500, treatments={Variable('X'):0}, seed=1)
np.mean(intv_data1['Y']) - np.mean(intv_data0['Y'])

0.0040000000000000036