Skip to content

Commit

Permalink
fixed conjunction variable ordering issue, fixed one incorrect nonide…
Browse files Browse the repository at this point in the history
…ntifiable case, more tests
  • Loading branch information
santikka committed Sep 27, 2023
1 parent d6598de commit f76f959
Show file tree
Hide file tree
Showing 16 changed files with 231 additions and 68 deletions.
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: cfid
Type: Package
Title: Identification of Counterfactual Queries in Causal Models
Version: 0.1.5
Version: 0.1.6
Authors@R: c(
person(given = "Santtu",
family = "Tikka",
Expand All @@ -16,12 +16,12 @@ Description: Facilitates the identification of counterfactual queries in
Provides a simple interface for defining causal diagrams and counterfactual
conjunctions. Construction of parallel worlds graphs and counterfactual graphs
is carried out automatically based on the counterfactual query and the causal
diagram.
diagram. See Tikka, S. (2022) for a tutorial of the package <arXiv:2210.14745>.
License: GPL (>= 3)
Encoding: UTF-8
URL: https://github.com/santikka/cfid
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.1
RoxygenNote: 7.2.3
Suggests:
covr,
dagitty,
Expand Down
17 changes: 12 additions & 5 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
# cfid 0.1.6

* Summation variables are now properly distinguished from query variables in the output formulas of `identifiable()`.
* Inputs for `causal_effects()` are now validated.
* Fixed a rare issue that could result in duplicate variables in the output distributions.
* Fixed an issue where some identifiable but inconsistent queries were reported as nonidentifiable.

# cfid 0.1.5

* Added missing value assignments to the query for `causal_effect`.
* Added missing value assignments to the query for `causal_effect()`.

# cfid 0.1.4

* Fixed an issue related to fixed variables in the counterfactual graph.
* Fixed an issue related to counterfactual graph generation and equivalence of random variables.
* Now uses lower case and snake case for classes.
* Added `id` and `idc` algorithms for full identification pipeline. These algorithms can also be used directly via `causal_effect`.
* The syntax used by `dag` is now more flexible, allowing edges within subgraph definitions and nested subgraphs.
* Added the ID and IDC algorithms for a full identification pipeline. These algorithms can be used directly via `causal_effect()`.
* The syntax used by `dag()` is now more flexible, allowing edges within subgraph definitions and nested subgraphs.
* Dropped dependency on R version 4.1.0.
* Improved the package documentation.
* Changed the default value of `var_sep` argument.
Expand All @@ -17,11 +24,11 @@

* An identifiable conditional counterfactual of the form P(gamma)/P(gamma) will have value 1 instead of the formula.
* Fixed an error when identifying conditional counterfactuals that had common counterfactual variables.
* Inputs for `identifiable` are now checked more thoroughly.
* Inputs for `identifiable()` are now checked more thoroughly.
* Added additional package tests.
* Further refined documentation in general.
* Fixed an internal indexing issue when no bidirected edges were present in a DAG.
* Added documentation for `format.probability` with examples.
* Added documentation for `format.probability()` with examples.
* Added citation to Shpitser and Pearl (2007). "What counterfactuals can be tested".
* The package now correctly lists dependency on R >= 4.1.0.
* Added NEWS.
Expand Down
14 changes: 7 additions & 7 deletions R/algorithms.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#' functional as a `functional` or a `probability` object, respectively.
#' @noRd
id_star <- function(g, gamma) {
# ID* algorithm line order changed here to avoid unnecesary recursion
# ID* algorithm line order changed here to avoid unnecessary recursion
# Line 2
if (is_inconsistent(gamma)) {
return(list(id = TRUE, formula = probability(val = 0L)))
Expand All @@ -30,6 +30,7 @@ id_star <- function(g, gamma) {
vars(lab_prime[!attr(g_prime, "latent") & !assigned(lab_prime)])
)
gamma_prime <- tmp$conjunction
gamma_prime[duplicated(gamma_prime)] <- NULL
gamma_var <- vars(gamma)
gamma_obs <- obs(gamma)
gamma_obs_var <- vars(gamma_obs)
Expand All @@ -54,29 +55,28 @@ id_star <- function(g, gamma) {
obs_ix <- which(gamma_obs_var %in% s_sub_j)
if (length(obs_ix) > 0) {
s_val <- unlist(evs(gamma_obs)[obs_ix])
s_ix <- which(s_sub_j %in% gamma_obs_var)
sub_new[s_ix] <- s_val
sub_new[names(s_val)] <- s_val
}
comp[[i]][[j]]$sub <- c(comp[[i]][[j]]$sub, sub_new)
}
}
#s_conj <- do.call(counterfactual_conjunction, comp[[i]])
s_conj <- try(
do.call(counterfactual_conjunction, comp[[i]]), silent = TRUE
)
if (inherits(s_conj, "try-error")) {
return(list(id = TRUE, formula = probability(val = 0L)))
}
c_factors[[i]] <- id_star(g, s_conj)
if (!c_factors[[i]]$id) {
return(list(id = FALSE, formula = NULL))
}
if (is.probability(c_factors[[i]]$formula) &&
length(c_factors[[i]]$formula$val) > 0L &&
c_factors[[i]]$formula$val == 0L) {
return(list(id = TRUE, formula = probability(val = 0L)))
}
}
nonid_factors <- !vapply(c_factors, "[[", logical(1L), "id")
if (any(nonid_factors)) {
return(list(id = FALSE, formula = NULL))
}
sumset <- setdiff(v_g, gamma_var)
form_terms <- lapply(c_factors, "[[", "formula")
if (length(sumset) > 0L) {
Expand Down
68 changes: 31 additions & 37 deletions R/causal_effect.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,39 @@ causal_effect <- function(g, y, x = character(0),
"Argument `v` must have names."
)
}
v_g <- attr(g, "labels")[!attr(g, "latent")]
stopifnot_(
all(y %in% v_g),
paste0(
"Argument `y` contains variables that are not present in `g`: ",
comma_sep(setdiff(y, v_g))
)
)
stopifnot_(
all(x %in% v_g),
paste0(
"Argument `x` contains variables that are not present in `g`: ",
comma_sep(setdiff(x, v_g))
)
)
stopifnot_(
length(z) == 0L || all(z %in% v_g),
paste0(
"Argument `z` contains variables that are not present in `g`: ",
comma_sep(setdiff(z, v_g))
)
)
stopifnot_(
n_v == 0 || all(v_names %in% v_g),
paste0(
"Argument `v` has names that are not present in `g`: ",
comma_sep(setdiff(v_names, v_g))
)
)
n_obs <- sum(!attr(g, "latent"))
if (n_v != n_obs) {
v_temp <- v
v <- set_names(integer(n_obs), attr(g, "labels")[!attr(g, "latent")])
v <- set_names(integer(n_obs), v_g)
v[v_names] <- v_temp
v_names <- names(v)
}
Expand All @@ -66,7 +95,7 @@ causal_effect <- function(g, y, x = character(0),
names(bound) <- v_names
xyz <- c(x, y, z)
bound[xyz] <- bound[xyz] + 1L
out$formula <- assign_values(out$formula, v, v_names, bound)
out$formula <- assign_values(out$formula, bound, v)
}
out$counterfactual <- FALSE
out$causaleffect <- probability(
Expand All @@ -81,38 +110,3 @@ causal_effect <- function(g, y, x = character(0),
class = "query"
)
}

#' Set Value Assignment Levels for a Probability Distribution
#'
#' @param x A `functional` object.
#' @param v A named `integer` vector of values to assign
#' @param v_names A `character` vector of the names of `v`
#' @param bound An `integer` vector counting the number of times specific
#' variables have been bound by summation.
#' @noRd
assign_values <- function(x, v, v_names, bound) {
sumset_vars <- vars(x$sumset)
now_bound <- v_names %in% sumset_vars
sumset_bound <- match(sumset_vars, v_names)
bound[now_bound] <- bound[now_bound] + 1L
for (i in seq_along(x$sumset)) {
x$sumset[[i]]$obs <- -bound[sumset_bound[i]]
}
if (length(x$terms) > 0) {
for (i in seq_along(x$terms)) {
x$terms[[i]] <- assign_values(x$terms[[i]], v, v_names, bound)
}
} else if (length(x$numerator) > 0) {
x$numerator <- assign_values(x$numerator, v, v_names, bound)
x$denominator <- assign_values(x$denominator, v, v_names, bound)
} else {
v[bound > 0] <- -bound[bound > 0]
var <- vars(x$var)
cond <- vars(x$cond)
x <- probability(
var = .mapply(function(a, b) cf(a, b), list(var, v[var]), list()),
cond = .mapply(function(a, b) cf(a, b), list(cond, v[cond]), list())
)
}
x
}
10 changes: 7 additions & 3 deletions R/cf_conjunction.R
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,11 @@ is.counterfactual_conjunction <- function(x) {
#' [cfid::format.counterfactual_variable()].
#' @export
format.counterfactual_conjunction <- function(x, var_sep = " /\\ ", ...) {
cf <- sapply(x, function(y) format.counterfactual_variable(y, ...))
cf <- vapply(
x,
function(y) format.counterfactual_variable(y, ...),
character(1L)
)
paste0(cf, collapse = var_sep)
}

Expand Down Expand Up @@ -222,7 +226,7 @@ trivial_conflicts <- function(cf_list) {
y_cf <- list(cfvar(y))
z <- which(x %in% y_cf)
if (length(z) > 1L) {
x_vals <- sapply(cf_list[z], "[[", "obs")
x_vals <- vapply(cf_list[z], "[[", integer(1L), "obs")
if (length(unique(x_vals)) > 1L) {
out <- c(out, y_cf)
}
Expand All @@ -243,7 +247,7 @@ trivial_conflict <- function(y, gamma) {
x <- cfvars(gamma)
z <- which(x %in% y_cf)
if (length(z) > 0L) {
xy_vals <- c(sapply(gamma[z], "[[", "obs"), y$obs)
xy_vals <- c(vapply(gamma[z], "[[", integer(1L), "obs"), y$obs)
if (length(unique(xy_vals)) > 1L) {
return(y_cf)
}
Expand Down
14 changes: 11 additions & 3 deletions R/cf_variable.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,19 @@ format.counterfactual_variable <- function(x, use_primes = TRUE, ...) {
} else {
form$var <- x$var
}
if (length(x$sub)) {
if (length(x$sub) > 0L) {
if (use_primes) {
super_sub <- sapply(x$sub, function(y) rep_char("'", y))
super_sub <- vapply(
x$sub,
function(y) rep_char("'", y),
character(1L)
)
} else {
super_sub <- sapply(x$sub, function(y) paste0("^{(", y, ")}"))
super_sub <- vapply(
x$sub,
function(y) paste0("^{(", y, ")}"),
character(1L)
)
super_sub[x$sub == 0L] <- ""
}
form$sub <- paste0(
Expand Down
6 changes: 5 additions & 1 deletion R/cfid-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
#' \eqn{P_*} is the set of all interventional distributions
#' in causal models inducing \eqn{G}. Identification is carried
#' out by the ID* and IDC* algorithms by Shpitser and Pearl (2008) which aim to
#' convenrt the input counterfactual probability into an expression which can
#' convert the input counterfactual probability into an expression which can
#' be represented solely in terms of interventional distributions. These
#' algorithms are sound and complete, meaning that their output is always
#' correct, and in the case of a non-identifiable counterfactual, one can
Expand Down Expand Up @@ -108,4 +108,8 @@
#' Makhlouf, K., Zhioua, S. and Palamidessi, C. (2021).
#' Survey on causal-based machine learning fairness notions.
#' *arXiv:2010.09553*
#'
#' Tikka, S. (2022).
#' Identifying counterfactual queries with the R package cfid.
#' *arXiv:2210.14745*
NULL
17 changes: 16 additions & 1 deletion R/extractors.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,21 @@ vars <- function(gamma) {
vapply(gamma, function(x) x$var, character(1L))
}

#' Determines both variables and interventional variables in a
#' counterfactual conjunction
#'
#' @param gamma A `counterfactual_conjunction` object..
#' @return A `character` vector of variable names.
#' @noRd
all_vars <- function(gamma) {
unique(
c(
vars(gamma),
unlist(lapply(subs(gamma), names))
)
)
}

#' Get the counterfactual variables present in a counterfactual conjunction.
#'
#' @param gamma A `counterfactual_conjunction` object.
Expand Down Expand Up @@ -111,7 +126,7 @@ obs <- function(gamma) {
#'
#' @param x A `counterfactual_variable` object.
#' @param gamma A `counterfactual_conjunction` object.
#' @return An `integer` correspodning to the value assignment if present
#' @return An `integer` corresponding to the value assignment if present.
#' or `NULL`.
#' @noRd
val <- function(x, gamma) {
Expand Down
46 changes: 45 additions & 1 deletion R/functional.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#'
#' Identifying functionals are more complicated probabilistic expressions
#' that cannot be expressed as simple observational or interventional
#* probability using [cfid::probability()].
#' probabilities using [cfid::probability()].
#'
#' When formatted via `print` or `format`, the arguments are
#' prioritized in the following order if conflicting definitions are given:
Expand Down Expand Up @@ -139,4 +139,48 @@ format.functional <- function(x, use_primes = TRUE, use_do = FALSE, ...) {
#' @export
print.functional <- function(x, ...) {
cat(format(x, ...), "\n")
}

#' Set Value Assignment Levels for a Functional
#'
#' @param x A `functional` or a `probability` object.
#' @param v A named `integer` vector of values to assign.
#' @param bound An `integer` vector counting the number of times specific
#' variables have been bound by summation.
#' @noRd
assign_values <- function(x, bound, v, termwise = FALSE) {
sumset_vars <- vars(x$sumset)
v_names <- names(bound)
now_bound <- v_names %in% sumset_vars
sumset_bound <- match(sumset_vars, v_names)
bound[now_bound] <- bound[now_bound] + 1L
for (i in seq_along(x$sumset)) {
x$sumset[[i]]$obs <- -bound[sumset_bound[i]]
}
if (length(x$terms) > 0) {
for (i in seq_along(x$terms)) {
x$terms[[i]] <- assign_values(x$terms[[i]], bound, v, termwise)
}
} else if (length(x$numerator) > 0) {
x$numerator <- assign_values(x$numerator, bound, v, termwise)
x$denominator <- assign_values(x$denominator, bound, v, termwise)
} else {
if (!is.null(x$val)) {
return(x)
}
if (termwise) {
v_term <- unlist(c(evs(x$var), evs(x$cond), evs(x$do)))
v[names(v_term)] <- v_term
}
v[bound > 0] <- -bound[bound > 0]
var <- vars(x$var)
cond <- vars(x$cond)
do <- vars(x$do)
x <- probability(
var = .mapply(function(a, b) cf(a, b), list(var, v[var]), list()),
cond = .mapply(function(a, b) cf(a, b), list(cond, v[cond]), list()),
do = .mapply(function(a, b) cf(a, b), list(do, v[do]), list())
)
}
x
}
5 changes: 3 additions & 2 deletions R/graphs.R
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,8 @@ pwg <- function(g, gamma) {
lat <- attr(g, "latent")
ord <- attr(g, "order")
sub_lst <- unique(subs(gamma))
sub_var <- lapply(sub_lst, function(i) which(lab %in% names(i)))
# sub_var <- lapply(sub_lst, function(i) which(lab %in% names(i)))
sub_var <- lapply(sub_lst, function(i) match(names(i), lab))
n_worlds <- length(sub_lst)
n <- length(lab)
n_unobs <- sum(lat)
Expand Down Expand Up @@ -790,7 +791,7 @@ export_graph <- function(g, type = c("dagitty", "causaleffect", "dosearch"),
out <- NULL
type <- match.arg(type)
lab <- attr(g, "labels")
lab_form <- sapply(lab, format)
lab_form <- vapply(lab, format, character(1L))
lat <- attr(g, "latent")
lat_ix <- which(lat)
e_ix <- which(g > 0L, arr.ind = TRUE)
Expand Down
Loading

0 comments on commit f76f959

Please sign in to comment.