
API for Global Causal Discovery

Robert Osazuwa Ness edited this page Sep 1, 2022 · 6 revisions

Global causal discovery refers to learning causal structure between a set of variables observed in data. Global causal discovery now lives in the dodiscover repo.

This is a running document. We will be updating the API as we write issues and start writing code. Please add your suggestions and comments on the discussion page.

Bias towards "context" over data

A key differentiator of PyWhy causal discovery should be a departure from "data-first causal discovery," where users provide data as the primary input to a discovery algorithm. The problem with this approach is that it encourages novice users to see the algorithm as a philosopher's stone that converts data into causal relationships. With this mindset, users tend to hand over to the algorithm the task of providing the domain-specific assumptions that enable identifiability. In contrast, DoWhy's key strength is how it guides users to specify domain assumptions up front (in the form of a DAG) before the data is added, and then addresses identifiability given those assumptions and the data. Global causal discovery should follow the same pattern.

Base case

We propose a Context data structure that stores assumptions, domain knowledge, include/exclude lists, priors, and any other context that can constrain causal discovery. It is passed in alongside the data, and this Context object is what we pass to a discovery algorithm. In the basic "assumption-free" case, the context is empty.

```python
context = Context()  # assumption-free: no constraints beyond the data
model = learn_graph(context, data)
```
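To make the proposal concrete, here is a minimal sketch of what such a `Context` could hold, assuming edges are `(cause, effect)` tuples of variable names. The field names (`included_edges`, `excluded_edges`, `priors`, `assumptions`, `data`) and the `is_empty()` helper are hypothetical illustrations, not a settled dodiscover API.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, FrozenSet, Optional, Tuple

Edge = Tuple[str, str]  # (cause, effect); hypothetical representation


@dataclass
class Context:
    """Hypothetical container for knowledge that constrains discovery."""

    included_edges: FrozenSet[Edge] = frozenset()  # edges that must be present
    excluded_edges: FrozenSet[Edge] = frozenset()  # edges that must be absent
    priors: Dict[str, Any] = field(default_factory=dict)  # e.g. edgewise priors
    assumptions: FrozenSet[str] = frozenset()  # e.g. {"causal_sufficiency"}
    data: Optional[Any] = None  # the observed dataset, if attached here

    def __post_init__(self) -> None:
        # Freeze whatever iterables were passed in and reject contradictions.
        self.included_edges = frozenset(self.included_edges)
        self.excluded_edges = frozenset(self.excluded_edges)
        self.assumptions = frozenset(self.assumptions)
        conflict = self.included_edges & self.excluded_edges
        if conflict:
            raise ValueError(f"edges both included and excluded: {sorted(conflict)}")

    def is_empty(self) -> bool:
        """True in the assumption-free base case: nothing beyond the data."""
        return not (self.included_edges or self.excluded_edges
                    or self.priors or self.assumptions)


context = Context()
print(context.is_empty())  # True
```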

Include and exclude list

Include and exclude lists are highly effective ways to elicit expert knowledge and reduce the size of the search space.

```python
context = Context(
    included_edges=included_df,
    excluded_edges=excluded_df,
    data=df,
)
```

Internally, these lists are converted to networkx graphs so that their edges are easy to parse.
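The effect on the search space can be sketched with plain Python sets instead of networkx, so the example runs without dependencies. The hypothetical helper `free_edges` returns the directed edges a search still has to decide, assuming the include/exclude lists are given as sets of `(cause, effect)` tuples rather than DataFrames.

```python
from itertools import permutations
from typing import Iterable, Set, Tuple

Edge = Tuple[str, str]  # (cause, effect)


def free_edges(variables: Iterable[str], included: Set[Edge],
               excluded: Set[Edge]) -> Set[Edge]:
    """Directed edges whose presence a search algorithm still has to decide."""
    candidates = set(permutations(variables, 2))  # every ordered pair of distinct variables
    # Included edges are fixed as present and excluded ones as absent; the reverse
    # of an included edge is also ruled out, since a DAG cannot hold A -> B and B -> A.
    reversed_included = {(b, a) for a, b in included}
    return candidates - set(included) - set(excluded) - reversed_included


undecided = free_edges("ABCD", included={("A", "B")},
                       excluded={("C", "A"), ("D", "A")})
print(len(undecided))  # 8
```

With four variables there are 12 ordered pairs; one included edge and two excluded edges leave 8 undecided, because the included edge also rules out its reverse.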

Functional

The `method` argument selects which discovery algorithm `learn_graph` runs:

```python
cpdag = learn_graph(context, method=varlingam, ...)
cpdag = learn_graph(context, method=postnonlinear, ...)
cpdag = learn_graph(context, method=anm, ...)
```
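One way a single `learn_graph` entry point could dispatch on `method` is a small registry that accepts either a registered name or the algorithm callable itself. This is a minimal sketch; `register`, `_METHODS`, and the toy `no_edges` algorithm are hypothetical names standing in for real methods such as varlingam or anm.

```python
from typing import Any, Callable, Dict

# Hypothetical registry: method name -> discovery function(context, **kwargs).
_METHODS: Dict[str, Callable[..., Any]] = {}


def register(name: str) -> Callable:
    """Decorator that makes a discovery function selectable by name."""
    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        _METHODS[name] = fn
        return fn
    return decorator


def learn_graph(context: Any, method: Any, **kwargs: Any) -> Any:
    """Single entry point: `method` is a registered name or the algorithm itself."""
    if callable(method):
        return method(context, **kwargs)
    if method not in _METHODS:
        raise ValueError(f"unknown method {method!r}; known: {sorted(_METHODS)}")
    return _METHODS[method](context, **kwargs)


@register("no_edges")
def no_edges(context: Any, **kwargs: Any) -> set:
    """Toy stand-in for a real algorithm; always returns an empty edge set."""
    return set()
```

Accepting callables keeps the `method=varlingam` style shown above, while names make it easy to pick a method from a config file.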


# Advanced Usage of Algorithms

Besides the `learn_graph` entrypoint for high-level users, more experienced users can work with the algorithms themselves following a scikit-learn-like API with `fit()`.

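```python
# Constraint-based: PC with a chi-squared independence test at alpha = 0.05
discoverer1 = PC(test="chi-squared", alpha=.05)
discoverer1.fit(context)
cpdag = discoverer1.graph_                # learned equivalence class
separating_sets = discoverer1.sep_sets_   # separating sets found by the tests

# Score-based: search guided by the BIC score
discoverer2 = BIC()
discoverer2.fit(context)
cpdag2 = discoverer2.graph_
```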
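The scikit-learn conventions assumed here are that hyperparameters are set in `__init__`, `fit()` stores learned results in attributes with a trailing underscore (as with `graph_` and `sep_sets_` above), and `fit()` returns `self`. A toy sketch of that pattern follows; `BaseDiscoverer`, `ForcedEdgesDiscoverer`, and `is_fitted` are hypothetical, and the "algorithm" simply echoes the context's included edges.

```python
from types import SimpleNamespace


class BaseDiscoverer:
    """Hypothetical base class following scikit-learn estimator conventions."""

    def fit(self, context):
        raise NotImplementedError

    @property
    def is_fitted(self) -> bool:
        # Learned state only exists after fit(), so its presence marks fitting.
        return hasattr(self, "graph_")


class ForcedEdgesDiscoverer(BaseDiscoverer):
    """Toy discoverer whose 'learned' graph is just the context's included edges."""

    def __init__(self, keep_included: bool = True):
        self.keep_included = keep_included  # hyperparameter, not learned state

    def fit(self, context):
        self.graph_ = set(context.included_edges) if self.keep_included else set()
        return self  # enables chaining: ForcedEdgesDiscoverer().fit(ctx).graph_


context = SimpleNamespace(included_edges={("A", "B")})  # stand-in for a Context
discoverer = ForcedEdgesDiscoverer().fit(context)
print(discoverer.graph_)  # {('A', 'B')}
```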

Equivalence-class representations by default

Discovery algorithms should return equivalence-class graph representations (CPDAGs, PAGs) by default. For example, a score-based algorithm typically returns a single DAG, but that DAG should be converted to a CPDAG before being output to the user. This explicitly exposes identifiability to the user: edge directions that the data and assumptions cannot determine remain undirected.

```python
cpdag = learn_graph(
    context,
    method=score_discoverer,
)
```
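The DAG-to-CPDAG step can be sketched with the textbook construction: keep the skeleton, orient only the edges of v-structures, then apply Meek's orientation rules R1 to R3 until nothing changes. This is a minimal, unoptimized sketch assuming the DAG is a set of `(parent, child)` tuples; `dag_to_cpdag` is a hypothetical name, not a PyWhy function.

```python
from itertools import combinations
from typing import Dict, FrozenSet, Set, Tuple

Edge = Tuple[str, str]  # (parent, child)


def dag_to_cpdag(dag: Set[Edge]) -> Tuple[Set[Edge], Set[FrozenSet[str]]]:
    """Return (directed, undirected) edges of the CPDAG representing `dag`."""
    nodes = {n for edge in dag for n in edge}
    adj: Dict[str, Set[str]] = {n: set() for n in nodes}
    parents: Dict[str, Set[str]] = {n: set() for n in nodes}
    for a, b in dag:
        adj[a].add(b)
        adj[b].add(a)
        parents[b].add(a)

    # Step 1: orient only the edges of v-structures a -> c <- b (a, b nonadjacent).
    directed: Set[Edge] = set()
    for c in nodes:
        for a, b in combinations(sorted(parents[c]), 2):
            if b not in adj[a]:
                directed |= {(a, c), (b, c)}
    undirected = {frozenset(edge) for edge in dag if edge not in directed}

    def should_orient(a: str, b: str) -> bool:
        """Does Meek rule R1, R2 or R3 force the undirected edge a - b to a -> b?"""
        # R1: c -> a - b with c, b nonadjacent.
        if any((c, a) in directed and b not in adj[c] for c in adj[a]):
            return True
        # R2: a -> c -> b (orienting b -> a would create a cycle).
        if any((a, c) in directed and (c, b) in directed for c in adj[a]):
            return True
        # R3: a - c -> b and a - d -> b with c, d nonadjacent.
        middles = [c for c in adj[a]
                   if frozenset((a, c)) in undirected and (c, b) in directed]
        return any(d not in adj[c] for c, d in combinations(middles, 2))

    # Step 2: apply the rules until no undirected edge can be oriented.
    changed = True
    while changed:
        changed = False
        for edge in list(undirected):
            x, y = sorted(edge)
            for a, b in ((x, y), (y, x)):
                if should_orient(a, b):
                    undirected.discard(edge)
                    directed.add((a, b))
                    changed = True
                    break
    return directed, undirected
```

The single-edge DAGs `A -> B` and `B -> A` map to the same CPDAG with an undirected `A - B`, which is exactly the non-identifiability the CPDAG is meant to expose.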

If the user wants to work with a DAG member of the equivalence class, they should convert the CPDAG or PAG to a DAG in a separate step.

```python
dag = cextend(cpdag)
```
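`cextend` is not defined here. One standard way to build such a consistent extension is the Dor and Tarsi (1992) procedure: repeatedly remove a sink whose undirected neighbors are adjacent to all of its other neighbors, orienting those undirected edges into it. The sketch below assumes the PDAG is passed as a `(directed, undirected)` pair of edge sets; `consistent_extension` is a hypothetical name.

```python
from typing import FrozenSet, Set, Tuple

Edge = Tuple[str, str]  # (parent, child)


def consistent_extension(directed: Set[Edge],
                         undirected: Set[FrozenSet[str]]) -> Set[Edge]:
    """Orient every undirected edge without creating cycles or new v-structures."""
    d = set(directed)      # directed edges of the graph still being processed
    u = set(undirected)    # undirected edges of the graph still being processed
    remaining = {n for e in d for n in e} | {n for e in u for n in e}
    dag = set(directed)    # result: original directed edges plus new orientations

    def undirected_nbrs(x: str) -> Set[str]:
        return {n for e in u if x in e for n in e if n != x}

    def adjacent(a: str, b: str) -> bool:
        return (a, b) in d or (b, a) in d or frozenset((a, b)) in u

    while remaining:
        # Find a sink x whose undirected neighbors are adjacent to all of x's
        # other neighbors; orienting y -> x then adds no new v-structure.
        for x in sorted(remaining):
            is_sink = not any(a == x for a, _ in d)
            nbrs = undirected_nbrs(x) | {a for a, b in d if b == x}
            if is_sink and all(adjacent(y, z)
                               for y in undirected_nbrs(x) for z in nbrs if z != y):
                break
        else:
            raise ValueError("this PDAG has no consistent DAG extension")
        dag |= {(y, x) for y in undirected_nbrs(x)}
        # Remove x and every edge touching it, then repeat on the smaller graph.
        d = {(a, b) for a, b in d if x not in (a, b)}
        u = {e for e in u if x not in e}
        remaining.discard(x)
    return dag
```

An undirected four-cycle `A - B - C - D - A` has no such extension, so the function raises `ValueError`; a CPDAG produced from a real DAG always has one, since that DAG is itself an extension.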

Ensemble discovery, idiomatic python, and functional programming

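Ensembles of graphs address the problem of uncertainty in causal discovery: instead of committing to one learned graph, the user works with many plausible graphs and can see which edges are stable across them.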

For example, a user should be able to do causal discovery via a bootstrapping procedure. Given a list of bootstrapped datasets:

```python
from functools import partial

get_context = partial(BayesianContext, prior={
    "pseudocounts": 10,
    "edgewise": prior_df,
})
contexts = [get_context(boot_df) for boot_df in bootstrap_dfs]

def get_pdags(context):
    return learn_graph(
        context,
        method=score_discoverer,
    )

# map(function, iterable); list() materializes the lazy map object
pdags = list(map(get_pdags, contexts))
```
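The snippet above assumes `bootstrap_dfs` already exists. Bootstrapping resamples the rows of the dataset with replacement while keeping the original size; with pandas this could be done with `df.sample(frac=1, replace=True)`. A dependency-free sketch over a list of rows, where `bootstrap_samples` is a hypothetical helper:

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def bootstrap_samples(rows: Sequence[T], n_boot: int, seed: int = 0) -> List[List[T]]:
    """Return n_boot resamples of `rows`, each drawn with replacement at full size."""
    rng = random.Random(seed)  # a private, seeded generator makes the ensemble reproducible
    return [rng.choices(rows, k=len(rows)) for _ in range(n_boot)]


rows = [{"A": 0, "B": 1}, {"A": 1, "B": 1}, {"A": 1, "B": 0}]
bootstrap_rows = bootstrap_samples(rows, n_boot=100)
```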

Ensemble averaging can reduce a list of graphs to a single consensus graph.

```python
pdag = consensus_graph(pdags)
```
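`consensus_graph` is left abstract. A minimal sketch of one common choice, edge-frequency voting, keeps every edge that appears in more than a given fraction of the ensemble; the `threshold` parameter and the set-of-edges representation are assumptions.

```python
from collections import Counter
from typing import Hashable, Iterable, Set


def consensus_graph(graphs: Iterable[Set[Hashable]], threshold: float = 0.5) -> Set[Hashable]:
    """Keep every edge that appears in more than `threshold` of the graphs."""
    graphs = list(graphs)  # materialize, e.g. if given a map() iterator
    if not graphs:
        raise ValueError("need at least one graph")
    # Count each edge at most once per graph.
    counts = Counter(edge for g in graphs for edge in set(g))
    return {edge for edge, c in counts.items() if c / len(graphs) > threshold}
```

Majority-voted directed edges can form a cycle even when every input graph is acyclic, so a fuller implementation would need to detect and resolve cycles in the result.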

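Beyond a consensus graph, the user should be able to run an end-to-end workflow over the whole ensemble: learn a graph from each context, estimate the effect on each, and then aggregate the estimates.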

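```python
from statistics import mean

import dowhy

def get_effect(context):
    pdag = learn_graph(
        context,
        method=score_discoverer
    )
    dag = cextend(pdag)  # choose one DAG from the equivalence class
    estimand = dowhy.identify_effect(dag,
                                     action_node="A",
                                     outcome_node="B",
                                     observed_nodes=...)
    estimator = dowhy.LinearRegressionEstimator(estimand)
    estimator.fit(context.data)
    estimate = estimator.estimate_effect(action_value=..., control_value=...)
    # Return the numeric effect (assumes a .value attribute, as on DoWhy's
    # CausalEstimate) so the estimates can be averaged.
    return estimate.value


# map(function, iterable); list() materializes the lazy map object
estimates = list(map(get_effect, contexts))
mean(estimates)
```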
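`mean(estimates)` reports only a point estimate, while the spread of the estimates across the ensemble is what reflects the uncertainty of the whole learn-then-estimate pipeline. A minimal sketch that adds a simple percentile interval, assuming the estimates are plain numbers; `summarize_estimates` is a hypothetical helper:

```python
import statistics
from typing import Iterable, Tuple


def summarize_estimates(estimates: Iterable[float]) -> Tuple[float, float, float]:
    """Return (mean, 2.5th percentile, 97.5th percentile) of ensemble estimates."""
    values = list(estimates)  # accepts lists or lazy iterators such as map()
    if len(values) < 2:
        raise ValueError("need at least two estimates to form an interval")
    # n=40 yields 39 cut points in 2.5% steps: the first is the 2.5th percentile
    # and the last the 97.5th; "inclusive" treats the values as the whole population.
    cuts = statistics.quantiles(values, n=40, method="inclusive")
    return statistics.mean(values), cuts[0], cuts[-1]
```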