diff --git a/docs/source/example_notebooks/dowhy_efficient_backdoor_example.ipynb b/docs/source/example_notebooks/dowhy_efficient_backdoor_example.ipynb index 3c68b81908..e59598a498 100644 --- a/docs/source/example_notebooks/dowhy_efficient_backdoor_example.ipynb +++ b/docs/source/example_notebooks/dowhy_efficient_backdoor_example.ipynb @@ -57,7 +57,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -84,7 +84,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -149,7 +149,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -165,7 +165,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -174,13 +174,25 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "__init__() got an unexpected keyword argument 'method_name'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/home/andresmor/Projects/dowhy/docs/source/example_notebooks/dowhy_efficient_backdoor_example.ipynb Cell 13\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m ident_eff \u001b[39m=\u001b[39m BackdoorIdentifier(\n\u001b[1;32m 2\u001b[0m estimand_type\u001b[39m=\u001b[39;49mCausalIdentifierEstimandType\u001b[39m.\u001b[39;49mNONPARAMETRIC_ATE,\n\u001b[1;32m 3\u001b[0m method_name\u001b[39m=\u001b[39;49mBackdoorAdjustmentMethod\u001b[39m.\u001b[39;49mBACKDOOR_EFFICIENT,\n\u001b[1;32m 4\u001b[0m )\n\u001b[1;32m 5\u001b[0m \u001b[39mprint\u001b[39m(\n\u001b[1;32m 6\u001b[0m ident_eff\u001b[39m.\u001b[39midentify_effect(\n\u001b[1;32m 7\u001b[0m graph\u001b[39m=\u001b[39mG, treatment_name\u001b[39m=\u001b[39mtreatment_name, outcome_name\u001b[39m=\u001b[39moutcome_name, conditional_node_names\u001b[39m=\u001b[39mconditional_node_names\n\u001b[1;32m 8\u001b[0m )\n\u001b[1;32m 9\u001b[0m )\n", + "\u001b[0;31mTypeError\u001b[0m: __init__() got an unexpected keyword argument 'method_name'" + ] + } + ], "source": [ "ident_eff = BackdoorIdentifier(\n", " estimand_type=CausalIdentifierEstimandType.NONPARAMETRIC_ATE,\n", - " method_name=BackdoorAdjustmentMethod.BACKDOOR_EFFICIENT,\n", + " backdoor_adjustment=BackdoorAdjustmentMethod.BACKDOOR_EFFICIENT,\n", ")\n", "print(\n", " ident_eff.identify_effect(\n", @@ -211,7 +223,7 @@ "source": [ "ident_minimal_eff = BackdoorIdentifier(\n", " estimand_type=CausalIdentifierEstimandType.NONPARAMETRIC_ATE,\n", - " method_name=BackdoorAdjustmentMethod.BACKDOOR_MIN_EFFICIENT,\n", + " backdoor_adjustment=BackdoorAdjustmentMethod.BACKDOOR_MIN_EFFICIENT,\n", ")\n", "print(\n", " ident_minimal_eff.identify_effect(\n", @@ -235,7 +247,7 @@ "source": [ "ident_mincost_eff = BackdoorIdentifier(\n", " estimand_type=CausalIdentifierEstimandType.NONPARAMETRIC_ATE,\n", - " method_name=BackdoorAdjustmentMethod.BACKDOOR_MINCOST_EFFICIENT,\n", + " backdoor_adjustment=BackdoorAdjustmentMethod.BACKDOOR_MINCOST_EFFICIENT,\n", ")\n", "print(\n", " ident_mincost_eff.identify_effect(\n", @@ -314,7 +326,7 @@ "source": [ "ident_eff = BackdoorIdentifier(\n", " estimand_type=CausalIdentifierEstimandType.NONPARAMETRIC_ATE,\n", - " method_name=BackdoorAdjustmentMethod.BACKDOOR_EFFICIENT,\n", + " backdoor_adjustment=BackdoorAdjustmentMethod.BACKDOOR_EFFICIENT,\n", ")\n", "try:\n", " results_eff = ident_eff.identify_effect(graph=G, treatment_name=treatment_name, outcome_name=outcome_name)\n", @@ -330,7 +342,7 @@ "source": [ "ident_eff = BackdoorIdentifier(\n", " estimand_type=CausalIdentifierEstimandType.NONPARAMETRIC_ATE,\n", - " method_name=BackdoorAdjustmentMethod.BACKDOOR_MIN_EFFICIENT,\n", + " backdoor_adjustment=BackdoorAdjustmentMethod.BACKDOOR_MIN_EFFICIENT,\n", ")\n", "print(\n", " ident_minimal_eff.identify_effect(\n", @@ -349,7 +361,7 @@ "source": [ "ident_eff = BackdoorIdentifier(\n", " estimand_type=CausalIdentifierEstimandType.NONPARAMETRIC_ATE,\n", - " method_name=BackdoorAdjustmentMethod.BACKDOOR_MINCOST_EFFICIENT,\n", + " backdoor_adjustment=BackdoorAdjustmentMethod.BACKDOOR_MINCOST_EFFICIENT,\n", ")\n", "print(\n", " ident_mincost_eff.identify_effect(\n", @@ -404,7 +416,7 @@ "source": [ "ident_eff = BackdoorIdentifier(\n", " estimand_type=CausalIdentifierEstimandType.NONPARAMETRIC_ATE,\n", - " method_name=BackdoorAdjustmentMethod.BACKDOOR_EFFICIENT,\n", + " backdoor_adjustment=BackdoorAdjustmentMethod.BACKDOOR_EFFICIENT,\n", ")\n", "try:\n", " results_eff = ident_eff.identify_effect(\n", @@ -499,7 +511,7 @@ "source": [ "ident_eff = BackdoorIdentifier(\n", " estimand_type=CausalIdentifierEstimandType.NONPARAMETRIC_ATE,\n", - " method_name=BackdoorAdjustmentMethod.BACKDOOR_MINCOST_EFFICIENT,\n", + " backdoor_adjustment=BackdoorAdjustmentMethod.BACKDOOR_MINCOST_EFFICIENT,\n", " costs=costs,\n", ")\n", "print(\n", @@ -524,7 +536,7 @@ "source": [ "ident_eff = BackdoorIdentifier(\n", " estimand_type=CausalIdentifierEstimandType.NONPARAMETRIC_ATE,\n", - " method_name=BackdoorAdjustmentMethod.BACKDOOR_MIN_EFFICIENT,\n", + " backdoor_adjustment=BackdoorAdjustmentMethod.BACKDOOR_MIN_EFFICIENT,\n", ")\n", "print(\n", " ident_minimal_eff.identify_effect(\n", diff --git a/docs/source/example_notebooks/functional_api.ipynb b/docs/source/example_notebooks/functional_api.ipynb index b09fa79068..ee606fbb44 100644 --- a/docs/source/example_notebooks/functional_api.ipynb +++ b/docs/source/example_notebooks/functional_api.ipynb @@ -6,8 +6,9 @@ "metadata": {}, "outputs": [], "source": [ - "from dowhy import identify_effect, CausalModel\n", + "from dowhy import CausalModel\n", "from dowhy.causal_identifier import (\n", + " identify_effect,\n", " BackdoorIdentifier,\n", " BackdoorAdjustmentMethod,\n", " IDIdentifier,\n", @@ -61,16 +62,13 @@ "metadata": {}, "outputs": [], "source": [ - "# identify_effect method returns a tuple: identifier (CausalIdentifier instance) and an IdentifiedEstimand|IDExpression instance\n", - "# the identifier is returned for backwards compatibility with old api\n", "# New functional API\n", "\n", - "_, identified_estimand = identify_effect(\n", + "identified_estimand = identify_effect(\n", " graph=graph,\n", " treatment=treatment_name,\n", " outcome=outcome_name,\n", - " method=BackdoorIdentifier,\n", - " estimand_type=CausalIdentifierEstimandType.NONPARAMETRIC_ATE,\n", + " method=BackdoorIdentifier(estimand_type=CausalIdentifierEstimandType.NONPARAMETRIC_ATE),\n", ")\n", "print(identified_estimand)" ] @@ -100,7 +98,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Advanced Configuration check \n", + "# Another way of executing the identify effect by directly calling the object\n", "\n", "identifier = BackdoorIdentifier(\n", " estimand_type=CausalIdentifierEstimandType.NONPARAMETRIC_ATE, backdoor_adjustment=BackdoorAdjustmentMethod.BACKDOOR_DEFAULT\n", @@ -122,12 +120,11 @@ "outputs": [], "source": [ "# New functional API (IDIdentifier)\n", - "_, identified_estimand = identify_effect(\n", + "identified_estimand = identify_effect(\n", " graph=graph,\n", " treatment=treatment_name,\n", " outcome=outcome_name,\n", - " estimand_type=CausalIdentifierEstimandType.NONPARAMETRIC_ATE,\n", - " method=IDIdentifier,\n", + " method=IDIdentifier(estimand_type=CausalIdentifierEstimandType.NONPARAMETRIC_ATE),\n", ")\n", "print(identified_estimand)" ] @@ -157,7 +154,7 @@ "metadata": {}, "outputs": [], "source": [ - "# Advanced Configuration (IDIdentifier)\n", + "# Another way of executing the identify effect by directly calling the object\n", "\n", "identifier = IDIdentifier(estimand_type=CausalIdentifierEstimandType.NONPARAMETRIC_ATE)\n", "\n", @@ -169,6 +166,13 @@ "\n", "print(identified_estimand)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/dowhy/__init__.py b/dowhy/__init__.py index 6996ef01fa..a22b01b3e7 100755 --- a/dowhy/__init__.py +++ b/dowhy/__init__.py @@ -1,6 +1,6 @@ import logging -from dowhy.causal_model import CausalModel, identify_effect +from dowhy.causal_model import CausalModel logging.getLogger(__name__).addHandler(logging.NullHandler()) diff --git a/dowhy/causal_estimators/two_stage_regression_estimator.py b/dowhy/causal_estimators/two_stage_regression_estimator.py index eda58c93d1..bd00cf1ffe 100644 --- a/dowhy/causal_estimators/two_stage_regression_estimator.py +++ b/dowhy/causal_estimators/two_stage_regression_estimator.py @@ -6,7 +6,7 @@ from dowhy.causal_estimator import CausalEstimate, CausalEstimator from dowhy.causal_estimators.linear_regression_estimator import LinearRegressionEstimator -from dowhy.causal_identifier.causal_identifier import CausalIdentifierEstimandType +from dowhy.causal_identifier.identify_effect import CausalIdentifierEstimandType from dowhy.utils.api import parse_state diff --git a/dowhy/causal_identifier/__init__.py b/dowhy/causal_identifier/__init__.py index 5124ff4693..06af3125c5 100644 --- a/dowhy/causal_identifier/__init__.py +++ b/dowhy/causal_identifier/__init__.py @@ -1,3 +1,14 @@ from dowhy.causal_identifier.backdoor_identifier import BackdoorIdentifier, BackdoorAdjustmentMethod -from dowhy.causal_identifier.causal_identifier import CausalIdentifierEstimandType, IdentifiedEstimand +from dowhy.causal_identifier.identify_effect import CausalIdentifierEstimandType, IdentifiedEstimand from dowhy.causal_identifier.id_identifier import IDIdentifier +from dowhy.causal_identifier.identify_effect import identify_effect + + +__all__ = [ + "BackdoorIdentifier", + "BackdoorAdjustmentMethod", + "CausalIdentifierEstimandType", + "IdentifiedEstimand", + "IDIdentifier", + "identify_effect", +] diff --git a/dowhy/causal_identifier/backdoor_identifier.py b/dowhy/causal_identifier/backdoor_identifier.py index d0affa7f67..c951ba6a69 100644 --- a/dowhy/causal_identifier/backdoor_identifier.py +++ b/dowhy/causal_identifier/backdoor_identifier.py @@ -7,7 +7,7 @@ import sympy.stats as spstats from dowhy.causal_graph import CausalGraph -from dowhy.causal_identifier.causal_identifier import CausalIdentifierEstimandType, IdentifiedEstimand +from dowhy.causal_identifier.identify_effect import CausalIdentifierEstimandType, IdentifiedEstimand from dowhy.causal_identifier.efficient_backdoor import EfficientBackdoor from dowhy.utils.api import parse_state @@ -56,7 +56,6 @@ def __init__( proceed_when_unidentifiable: bool = False, optimize_backdoor: bool = False, costs: Optional[List] = None, - **kwargs, ): self.estimand_type = estimand_type self.backdoor_adjustment = backdoor_adjustment @@ -71,6 +70,7 @@ def identify_effect( treatment_name: Union[str, List[str]], outcome_name: Union[str, List[str]], conditional_node_names: List[str] = None, + **kwargs, ): """Main method that returns an identified estimand (if one exists). diff --git a/dowhy/causal_identifier/id_identifier.py b/dowhy/causal_identifier/id_identifier.py index a855b22590..4aa66aa5e0 100644 --- a/dowhy/causal_identifier/id_identifier.py +++ b/dowhy/causal_identifier/id_identifier.py @@ -1,7 +1,7 @@ import networkx as nx from dowhy.causal_graph import CausalGraph -from dowhy.causal_identifier.causal_identifier import CausalIdentifierEstimandType +from dowhy.causal_identifier.identify_effect import CausalIdentifierEstimandType from dowhy.utils.api import parse_state from dowhy.utils.graph_operations import find_ancestor, find_c_components, induced_graph from dowhy.utils.ordered_set import OrderedSet @@ -87,11 +87,7 @@ def __str__(self): class IDIdentifier: - def __init__( - self, - estimand_type=CausalIdentifierEstimandType.NONPARAMETRIC_ATE, - **kwargs, - ): + def __init__(self, estimand_type=CausalIdentifierEstimandType.NONPARAMETRIC_ATE): """ Class to perform identification using the ID algorithm. @@ -103,14 +99,7 @@ def __init__( if estimand_type != CausalIdentifierEstimandType.NONPARAMETRIC_ATE: raise Exception("The estimand type should be 'non-parametric ate' for the ID method type.") - # - def identify_effect( - self, - graph: CausalGraph, - treatment_name, - outcome_name, - conditional_node_names=None, - ): + def identify_effect(self, graph: CausalGraph, treatment_name, outcome_name, node_names=None, **kwargs): """ Implementation of the ID algorithm. Link - https://ftp.cs.ucla.edu/pub/stat_ser/shpitser-thesis.pdf @@ -118,14 +107,13 @@ def identify_effect( :param treatment_names: OrderedSet comprising names of treatment variables. :param outcome_names:OrderedSet comprising names of outcome variables. - :param adjacency_matrix: Graph adjacency matrix. :param node_names: OrderedSet comprising names of all nodes in the graph :returns: target estimand, an instance of the IDExpression class. """ - if conditional_node_names is None: - conditional_node_names = OrderedSet(graph._graph.nodes) + if node_names is None: + node_names = OrderedSet(graph._graph.nodes) adjacency_matrix = graph.get_adjacency_matrix() @@ -139,7 +127,7 @@ def identify_effect( OrderedSet(parse_state(treatment_name)), OrderedSet(parse_state(outcome_name)), tsort_node_names, - conditional_node_names, + node_names, ) def _adjacency_matrix_identify_effect( @@ -148,10 +136,10 @@ def _adjacency_matrix_identify_effect( treatment_name, outcome_name, tsort_node_names, - conditional_node_names=None, + node_names=None, ): - node2idx, idx2node = self.__idx_node_mapping(conditional_node_names) + node2idx, idx2node = self.__idx_node_mapping(node_names) # Estimators list for returning after identification estimators = IDExpression() @@ -161,31 +149,29 @@ def _adjacency_matrix_identify_effect( if len(treatment_name) == 0: identifier = IDExpression() estimator = {} - estimator["outcome_vars"] = conditional_node_names + estimator["outcome_vars"] = node_names estimator["condition_vars"] = OrderedSet() identifier.add_product(estimator) - identifier.add_sum(conditional_node_names.difference(outcome_name)) + identifier.add_sum(node_names.difference(outcome_name)) estimators.add_product(identifier) return estimators # Line 2 # If we are interested in the effect on Y, it is sufficient to restrict our attention on the parts of the model ancestral to Y. - ancestors = find_ancestor(outcome_name, conditional_node_names, adjacency_matrix, node2idx, idx2node) + ancestors = find_ancestor(outcome_name, node_names, adjacency_matrix, node2idx, idx2node) if ( - len(conditional_node_names.difference(ancestors)) != 0 + len(node_names.difference(ancestors)) != 0 ): # If there are elements which are not the ancestor of the outcome variables # Modify list of valid nodes treatment_name = treatment_name.intersection(ancestors) - conditional_node_names = conditional_node_names.intersection(ancestors) - adjacency_matrix = induced_graph( - node_set=conditional_node_names, adjacency_matrix=adjacency_matrix, node2idx=node2idx - ) + node_names = node_names.intersection(ancestors) + adjacency_matrix = induced_graph(node_set=node_names, adjacency_matrix=adjacency_matrix, node2idx=node2idx) return self._adjacency_matrix_identify_effect( treatment_name=treatment_name, outcome_name=outcome_name, adjacency_matrix=adjacency_matrix, tsort_node_names=tsort_node_names, - conditional_node_names=conditional_node_names, + node_names=node_names, ) # Line 3 - forces an action on any node where such an action would have no effect on Y – assuming we already acted on X. @@ -193,23 +179,23 @@ def _adjacency_matrix_identify_effect( adjacency_matrix_do_x = adjacency_matrix.copy() for x in treatment_name: x_idx = node2idx[x] - for i in range(len(conditional_node_names)): + for i in range(len(node_names)): adjacency_matrix_do_x[i, x_idx] = 0 - ancestors = find_ancestor(outcome_name, conditional_node_names, adjacency_matrix_do_x, node2idx, idx2node) - W = conditional_node_names.difference(treatment_name).difference(ancestors) + ancestors = find_ancestor(outcome_name, node_names, adjacency_matrix_do_x, node2idx, idx2node) + W = node_names.difference(treatment_name).difference(ancestors) if len(W) != 0: return self._adjacency_matrix_identify_effect( treatment_name=treatment_name.union(W), outcome_name=outcome_name, adjacency_matrix=adjacency_matrix, tsort_node_names=tsort_node_names, - conditional_node_names=conditional_node_names, + node_names=node_names, ) # Line 4 - Decomposes the problem into a set of smaller problems using the key property of C-component factorization of causal models. # If the entire graph is a single C-component already, further problem decomposition is impossible, and we must provide base cases. # Modify adjacency matrix to remove treatment variables - node_names_minus_x = conditional_node_names.difference(treatment_name) + node_names_minus_x = node_names.difference(treatment_name) node2idx_minus_x, idx2node_minus_x = self.__idx_node_mapping(node_names_minus_x) adjacency_matrix_minus_x = induced_graph( node_set=node_names_minus_x, adjacency_matrix=adjacency_matrix, node2idx=node2idx @@ -219,14 +205,14 @@ def _adjacency_matrix_identify_effect( ) if len(c_components) > 1: identifier = IDExpression() - sum_over_set = conditional_node_names.difference(outcome_name.union(treatment_name)) + sum_over_set = node_names.difference(outcome_name.union(treatment_name)) for component in c_components: expressions = self._adjacency_matrix_identify_effect( - treatment_name=conditional_node_names.difference(component), + treatment_name=node_names.difference(component), outcome_name=OrderedSet(list(component)), adjacency_matrix=adjacency_matrix, tsort_node_names=tsort_node_names, - conditional_node_names=conditional_node_names, + node_names=node_names, ) for expression in expressions.get_val(return_type="prod"): identifier.add_product(expression) @@ -236,10 +222,8 @@ def _adjacency_matrix_identify_effect( # Line 5 - The algorithms fails due to the presence of a hedge - the graph G, and a subgraph S that does not contain any X nodes. S = c_components[0] - c_components_G = find_c_components( - adjacency_matrix=adjacency_matrix, node_set=conditional_node_names, idx2node=idx2node - ) - if len(c_components_G) == 1 and c_components_G[0] == conditional_node_names: + c_components_G = find_c_components(adjacency_matrix=adjacency_matrix, node_set=node_names, idx2node=idx2node) + if len(c_components_G) == 1 and c_components_G[0] == node_names: return None # Line 6 - If there are no bidirected arcs from X to the other nodes in the current subproblem under consideration, then we can replace acting on X by conditioning, and thus solve the subproblem. @@ -269,7 +253,7 @@ def _adjacency_matrix_identify_effect( node_set=component, adjacency_matrix=adjacency_matrix, node2idx=node2idx ), tsort_node_names=tsort_node_names, - conditional_node_names=conditional_node_names, + node_names=node_names, ) def __idx_node_mapping(self, node_names): diff --git a/dowhy/causal_identifier/causal_identifier.py b/dowhy/causal_identifier/identify_effect.py similarity index 83% rename from dowhy/causal_identifier/causal_identifier.py rename to dowhy/causal_identifier/identify_effect.py index 5f7d503916..1793f9f975 100755 --- a/dowhy/causal_identifier/causal_identifier.py +++ b/dowhy/causal_identifier/identify_effect.py @@ -20,9 +20,7 @@ class CausalIdentifierEstimandType(Enum): class CausalIdentifier(Protocol): - def identify_effect( - graph: CausalGraph, treatment_name: List[str], outcome_name: List[str], conditional_node_names: List[str] = None - ): + def identify_effect(self, graph: CausalGraph, treatment_name: List[str], outcome_name: List[str], **kwargs): ... @@ -151,3 +149,28 @@ def __str__(self, only_target_estimand=False, show_all_backdoor_sets=False): j += 1 i += 1 return s + + +def identify_effect( + graph: CausalGraph, + treatment: List[str], + outcome: List[str], + method: CausalIdentifier, + node_names=None, + conditional_node_names=None, +): + """Identify the causal effect to be estimated based on a CausalGraph + + :param graph: CausalGraph to be analyzed + :param treatment: name of the treatment + :param outcome: name of the outcome + :param method: CausalIdentifier instance to use to identify effects + :param node_names: OrderedSet comprising names of all nodes in the graph (Used for IDIdentifier only) + :param conditional_node_names: variables that are used to determine treatment. If none are + provided, it is assumed that the intervention is static (Used for BackdoorIdentifier only). + :returns: a probability expression (estimand) for the causal effect if identified, else NULL + """ + identified_estimand = method.identify_effect( + graph, treatment, outcome, node_names=node_names, conditional_node_names=conditional_node_names + ) + return identified_estimand diff --git a/dowhy/causal_model.py b/dowhy/causal_model.py index 28d85b2c7c..f0b7997dae 100755 --- a/dowhy/causal_model.py +++ b/dowhy/causal_model.py @@ -3,18 +3,18 @@ """ import logging from itertools import combinations -from typing import List, Type from sympy import init_printing import dowhy.causal_estimators as causal_estimators from dowhy.causal_identifier import BackdoorIdentifier, IDIdentifier, BackdoorAdjustmentMethod +from dowhy.causal_identifier import identify_effect import dowhy.causal_refuters as causal_refuters import dowhy.graph_learners as graph_learners import dowhy.utils.cli_helpers as cli from dowhy.causal_estimator import CausalEstimate from dowhy.causal_graph import CausalGraph -from dowhy.causal_identifier.causal_identifier import CausalIdentifier, CausalIdentifierEstimandType +from dowhy.causal_identifier.identify_effect import CausalIdentifier, CausalIdentifierEstimandType from dowhy.causal_refuters.graph_refuter import GraphRefuter from dowhy.utils.api import parse_state @@ -197,19 +197,20 @@ def identify_effect( estimand_type = CausalIdentifierEstimandType(estimand_type) if method_name == "id-algorithm": - identifier_class = IDIdentifier + identifier = IDIdentifier(estimand_type=estimand_type) else: - identifier_class = BackdoorIdentifier + identifier = BackdoorIdentifier( + estimand_type=estimand_type, + backdoor_adjustment=BackdoorAdjustmentMethod(method_name), + proceed_when_unidentifiable=proceed_when_unidentifiable, + optimize_backdoor=optimize_backdoor, + ) - identifier, identified_estimand = identify_effect( + identified_estimand = identify_effect( graph=self._graph, treatment=self._treatment, outcome=self._outcome, - estimand_type=estimand_type, - method=identifier_class, - proceed_when_unidentifiable=proceed_when_unidentifiable, - optimize_backdoor=optimize_backdoor, - backdoor_adjustment=BackdoorAdjustmentMethod(method_name) if method_name != "id-algorithm" else None, + method=identifier, ) self.identifier = identifier @@ -543,35 +544,3 @@ def refute_graph(self, k=1, independence_test=None, independence_constraints=Non self.logger.info(refuter._refutation_passed) return res - - -def identify_effect( - graph: CausalGraph, - treatment: List[str], - outcome: List[str], - estimand_type: CausalIdentifierEstimandType, - method: Type[CausalIdentifier], - proceed_when_unidentifiable: bool = False, - optimize_backdoor: bool = False, - backdoor_adjustment: BackdoorAdjustmentMethod = BackdoorAdjustmentMethod.BACKDOOR_DEFAULT, -): - """ - Identify the causal effect to be estimated, using properties of the causal graph. - :param graph: the causal graph to use for identification - :param treatment: treatment variable - :param outcome: outcome variable - :param estimand_type: the type of estimand requested (one of CausalIdentifierEstimandType) - :param backdoor_adjustment: method name for identification algorithm (one of CausalIdentifierMethodName) - :param proceed_when_unidentifiable: does the identification proceed by ignoring potential unobserved confounders. Binary flag. - :param optimize_backdoor: if True, uses an optimised algorithm to compute the backdoor sets (Ignored for method_name = ID_ALGORITHM) - """ - identifier = method( - estimand_type=estimand_type, - proceed_when_unidentifiable=proceed_when_unidentifiable, - optimize_backdoor=optimize_backdoor, - backdoor_adjustment=backdoor_adjustment, - ) - - identified_estimand = identifier.identify_effect(graph, treatment, outcome) - - return identifier, identified_estimand diff --git a/tests/causal_identifiers/test_backdoor_identifier.py b/tests/causal_identifiers/test_backdoor_identifier.py index 4aab584175..e074c0b905 100644 --- a/tests/causal_identifiers/test_backdoor_identifier.py +++ b/tests/causal_identifiers/test_backdoor_identifier.py @@ -1,7 +1,7 @@ import pytest from dowhy.causal_graph import CausalGraph -from dowhy.causal_identifier.causal_identifier import CausalIdentifierEstimandType +from dowhy.causal_identifier.identify_effect import CausalIdentifierEstimandType from dowhy.causal_identifier import BackdoorIdentifier, BackdoorAdjustmentMethod from .base import IdentificationTestGraphSolution, example_graph_solution diff --git a/tests/causal_identifiers/test_efficient_backdoor_identifier.py b/tests/causal_identifiers/test_efficient_backdoor_identifier.py index 42f7def086..3cafc837f0 100644 --- a/tests/causal_identifiers/test_efficient_backdoor_identifier.py +++ b/tests/causal_identifiers/test_efficient_backdoor_identifier.py @@ -3,7 +3,7 @@ import pytest from dowhy.causal_graph import CausalGraph -from dowhy.causal_identifier.causal_identifier import CausalIdentifierEstimandType +from dowhy.causal_identifier.identify_effect import CausalIdentifierEstimandType from dowhy.causal_identifier import BackdoorIdentifier, BackdoorAdjustmentMethod from tests.causal_identifiers.example_graphs_efficient import TEST_EFFICIENT_BD_SOLUTIONS diff --git a/tests/test_causal_refuter.py b/tests/test_causal_refuter.py index e87a7124b9..aeb52f820d 100644 --- a/tests/test_causal_refuter.py +++ b/tests/test_causal_refuter.py @@ -3,7 +3,7 @@ from flaky import flaky from dowhy.causal_estimator import CausalEstimate -from dowhy.causal_identifier.causal_identifier import IdentifiedEstimand +from dowhy.causal_identifier.identify_effect import IdentifiedEstimand from dowhy.causal_refuter import CausalRefuter