In [1]:
import os
import sys

from y0.dsl import (
    One,
    P,
    A,
    B,
    C,
    D,
    Q,
    R,
    S,
    T,
    W,
    W1,
    W2,
    W3,
    W4,
    W5,
    X,
    Y,
    Z,
    Zero,
    Sum,
    Variable,
    Product,
    PP,
    Pi1,
    Pi2,
    Pi3,
    PopulationProbability,
    Fraction,
)
from y0.graph import NxMixedGraph
from y0.algorithm.counterfactual_transportability import (
    _any_variables_with_inconsistent_values,
    _compute_ancestral_components_from_ancestral_sets,
    _counterfactual_factor_is_inconsistent,
    _get_ancestral_components,
    _get_ancestral_set_after_intervening_on_conditioned_variables,
    _get_conditioned_variables_in_ancestral_set,
    _inconsistent_counterfactual_factor_variable_and_intervention_values,
    _inconsistent_counterfactual_factor_variable_intervention_values,
    _no_intervention_variables_in_domain,
    _no_transportability_nodes_in_domain,
    _reduce_reflexive_counterfactual_variables_to_interventions,
    _remove_repeated_variables_and_values,
    _remove_transportability_vertices,
    _split_event_by_reflexivity,
    _transport_unconditional_counterfactual_query_line_2,
    convert_to_counterfactual_factor_form,
    counterfactual_factors_are_transportable,
    do_counterfactual_factor_factorization,
    get_ancestors_of_counterfactual,
    get_counterfactual_factors,
    is_counterfactual_factor_form,
    make_selection_diagram,
    minimize,
    minimize_event,
    same_district,
    simplify,
    transport_conditional_counterfactual_query,
    transport_district_intervening_on_parents,
    transport_unconditional_counterfactual_query,
)
from y0.algorithm.transport import transport_variable
from y0.dsl import (
    PP,
    TARGET_DOMAIN,
    W1,
    W2,
    W3,
    W4,
    W5,
    X1,
    X2,
    CounterfactualVariable,
    Fraction,
    Intervention,
    P,
    Pi1,
    Pi2,
    Product,
    R,
    Sum,
    Variable,
    W,
    X,
    Y,
    Z,
)
from y0.graph import NxMixedGraph

ImportError: cannot import name '_inconsistent_counterfactual_factor_variable_and_intervention_values' from 'y0.algorithm.counterfactual_transportability' (/Users/cthoyt/dev/y0/src/y0/algorithm/counterfactual_transportability.py)

In [None]:
# [correa22a]_, Figure 1, without the transportability node.
# (This graph represents the target domain, so there is no
# transportability node. Figure 1 may include a transportability
# node because at that point in the paper, the notion of target
# and source domains had not been introduced.)
figure_1_graph_no_transportability_nodes = NxMixedGraph.from_edges(
    directed=[
        (X, Z),
        (Z, Y),
        (X, Y),
    ],
    undirected=[(Z, X)],
)
figure_1_graph_no_transportability_nodes_topo = list(
    figure_1_graph_no_transportability_nodes.topological_sort()
)

# The graph for Domain 1 as described by the text of Example 1.1 and
# Figure 1 of [correa22a]_. The graph isn't in any figures in the
# paper, but a reader can infer it.
figure_1_graph_domain_1_with_interventions = NxMixedGraph.from_edges(
    directed=[(X, Z), (X, Y), (transport_variable(Y), Y)],
    undirected=[],
)
figure_1_graph_domain_1_with_interventions_topo = list(
    figure_1_graph_domain_1_with_interventions.topological_sort()
)

In [None]:
## Example 4.5, get_conditioned_variables_in_ancestral_set
expected_result_1 = frozenset({Z})
result_1 = _get_conditioned_variables_in_ancestral_set(
    conditioned_variables={Z @ -X, X},
    ancestral_set_root_variable=Y @ -X,
    graph=figure_1_graph_no_transportability_nodes,
)
print(str(expected_result_1))
print(str(result_1))

In [None]:
## Get ancestral components
"""First test of a function to compute ancestral components for a graph.

Source: Example 4.5 and Figure 6 of [correa22a]_.
"""
expected_result_1 = frozenset({frozenset({Y @ -X}), frozenset({Z @ -X, X})})
result_1 = _get_ancestral_components(
    conditioned_variables=frozenset({X, Z @ -X}),
    root_variables=frozenset({Y @ -X, Z @ -X, X}),
    graph=figure_1_graph_no_transportability_nodes,
)
print(str(expected_result_1))
print(str(result_1))

In [None]:
## Transport a conditional counterfactual query! Example 4.5.

outcomes = [(Y @ -X, -Y)]
conditions = [(Z @ -X, -Z), (X, +X)]
target_domain_graph = figure_1_graph_no_transportability_nodes
domain_graphs = [
    (
        figure_1_graph_no_transportability_nodes,
        figure_1_graph_no_transportability_nodes_topo,
    ),
    (
        figure_1_graph_domain_1_with_interventions,
        figure_1_graph_domain_1_with_interventions_topo,
    ),
]
domain_data = [(set(), PP[TARGET_DOMAIN](X, Y, Z)), ({X}, PP[Pi1](X, Y, Z))]
# DSL isn't smart enough to replace the denominator with 1
expected_result_expr, expected_result_event = (
    Fraction(PP[TARGET_DOMAIN](Y | X, Z), Sum.safe(PP[TARGET_DOMAIN](Y | X, Z), {Y})),
    [(Y, -Y), (X, +X), (Z, -Z)],
)

result_expr, result_event = transport_conditional_counterfactual_query(
    outcomes=outcomes,
    conditions=conditions,
    target_domain_graph=target_domain_graph,
    domain_graphs=domain_graphs,
    domain_data=domain_data,
)
print("expected_result_expr = " + expected_result_expr.to_latex())
print("Result_expr = " + result_expr.to_latex())
print("Result_event = " + str(result_event))
print("expected_result_event = " + str(expected_result_event))