Skip to content

Commit

Permalink
feat: Add wrapper to GIN algorithm
Browse files Browse the repository at this point in the history
Signed-off-by: Robert Osazuwa Ness <robertness@gmail.com>
Co-Authored-By: Adam Li <adam2392@gmail.com>
  • Loading branch information
robertness and adam2392 committed Jan 13, 2023
1 parent c84ccab commit b327864
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 71 deletions.
15 changes: 15 additions & 0 deletions doc/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,18 @@ @book{Spirtes1993
publisher = {The MIT Press}
}

@article{xie2020generalized,
title = {Generalized independent noise condition for estimating latent variable causal graphs},
author = {Xie, Feng and Cai, Ruichu and Huang, Biwei and Glymour, Clark and Hao, Zhifeng and Zhang, Kun},
journal = {Advances in Neural Information Processing Systems},
volume = {33},
pages = {14891--14902},
year = {2020}
}

@article{dai2022independence,
title = {Independence Testing-Based Approach to Causal Discovery under Measurement Error and Linear Non-Gaussian Models},
author = {Dai, Haoyue and Spirtes, Peter and Zhang, Kun},
journal = {arXiv preprint arXiv:2210.11021},
year = {2022}
}
122 changes: 52 additions & 70 deletions dodiscover/replearning/gin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,91 +3,74 @@
"""

from pandas import DataFrame

from causallearn.search.HiddenCausal.GIN.GIN import GIN as GIN_
from pywhy_graphs import CPDAG
from pywhy_graphs.array.export import clearn_arr_to_graph


class GIN:
"""Wrapper for GIN in the causal-learn package.
The GIN algorithm is a causal discovery algorithm that learns latent
variable structure. We can view it as a causal representation learning
algorithm.
Given an observed set of variables, the GIN algorithm tries to learn a set
of latent parents that d-separate subsets of the observed variables. GIN
will also learn undirected structure between the latent parents.
In GIN, the latent variables are always assumed to be parents of the
observed. Further, it will not learn direct causal edges between
the observed variables. In that sense, we can view it as a causal
representation learning algorithm that learns latent high-level variables
and structure between them from low-level observed variables.
The GIN algorithm assumes a linear non-Gaussian latent variable model
of the observed variables given the latent variables. One should not
expect it to work if the true relationship is Gaussian.
GIN stands for "generalized independent noise" (GIN) condition. Roughly,
the GIN condition is used to divide observed variables into subsets that
are d-separated given the latent variables.
See :footcite:`xie2020generalized` and :footcite:`dai2022independence`
for full details on the algorithm. See https://causal-learn.readthedocs.io
for the causal-learn documentation.
Parameters
----------
indep_test_method : str
The method to use for testing independence, by default "kci"
The method to use for testing independence. The default argument is
"kci" for kernel conditional independence testing. Another option is
"hsic" for the Hilbert Schmidt Independence Criterion. This is a
wrapper for causal-learn's GIN implementation and the causal-learn devs
may or may not add other options in the future.
alpha : float
The significance level for independence tests, by default 0.05
Attributes
----------
graph_ : CPDAG
The estimated causal graph.
causal_learn_graph : CausalGraph
The causal graph object from causal-learn.
causal_ordering : list of str
The causal ordering of the variables.
causal_learn_graph_ : CausalGraph
The causal graph object from causal-learn. Internally, we convert this
to a network-like graph object that supports CPDAGs. This is stored in
the ``graph_`` fitted attribute.
References
----------
.. footbibliography::
"""
def __init__(self, indep_test_method: str = "kci", alpha: float = 0.05):

def __init__(self, ci_estimator_method: str = "kci", alpha: float = 0.05):
"""Initialize GIN object with specified parameters."""

self.graph_ = None
self.graph = None

# GIN default parameters.
self.indep_test_method = indep_test_method
self.ci_estimator_method = ci_estimator_method
self.alpha = alpha
# The following objects are specific to causal-learn; perhaps they should
# go in a base class too.
self.causal_learn_graph = None
self.causal_ordering = None

def _causal_learn_to_pdag(self, cl_graph):
    """Convert a causal-learn graph to a CPDAG object.

    Parameters
    ----------
    cl_graph : CausalGraph
        The causal-learn graph to be converted. Its ``graph`` attribute is
        an adjacency matrix where, for a pair (i, j), ``1`` at [i][j] encodes
        a directed edge j -> i, and ``-1`` at both [i][j] and [j][i] encodes
        an undirected edge i -- j (causal-learn's encoding; confirm against
        the causal-learn docs for other endpoint marks).

    Returns
    -------
    pdag : CPDAG
        The equivalent CPDAG object.
    """

    def _extract_edgelists(adj_mat, names):
        """Extract directed and undirected edges from an adjacency matrix.

        Parameters
        ----------
        adj_mat : numpy array
            The adjacency matrix of the graph.
        names : list of str
            The names of the nodes in the graph.

        Returns
        -------
        directed_edges : list of tuples
            The directed edges of the graph, as (parent, child) pairs.
        undirected_edges : list of tuples
            The undirected edges of the graph, each listed exactly once
            as (names[i], names[j]) with i < j.
        """
        directed_edges = []
        undirected_edges = []
        for i, row in enumerate(adj_mat):
            for j, item in enumerate(row):
                if item == 1.0:
                    # causal-learn stores a directed edge j -> i as a 1 in
                    # row i, column j.
                    directed_edges.append((names[j], names[i]))
                elif item == -1.0 and adj_mat[j][i] == -1.0 and i < j:
                    # An undirected edge is marked -1 in both cells; visit
                    # each unordered pair once (i < j) so every edge is
                    # emitted exactly once with a deterministic orientation,
                    # instead of collecting both mirrored cells and
                    # deduplicating through ``tuple(set(...))`` (whose
                    # element order is hash-dependent).
                    undirected_edges.append((names[i], names[j]))
        return directed_edges, undirected_edges

    names = [n.name for n in cl_graph.nodes]
    adj_mat = cl_graph.graph
    directed_edges, undirected_edges = _extract_edgelists(adj_mat, names)
    pdag = CPDAG(directed_edges, undirected_edges)
    return pdag
self.causal_learn_graph_ = None

def fit(self, data: DataFrame, context: DataFrame):
"""Fit the GIN model to data.
Expand All @@ -105,11 +88,10 @@ def fit(self, data: DataFrame, context: DataFrame):
self : GIN
The fitted GIN object.
"""
causal_learn_graph, ordering = GIN_(
data.to_numpy(),
self.indep_test_method,
self.alpha
)
self.causal_learn_graph = causal_learn_graph
self.causal_ordering = ordering
self.graph_ = self._causal_learn_to_pdag(causal_learn_graph)
from causallearn.search.HiddenCausal.GIN.GIN import GIN as GIN_

causal_learn_graph, _ = GIN_(data.to_numpy(), self.ci_estimator_method, self.alpha)
self.causal_learn_graph_ = causal_learn_graph
names = [n.name for n in causal_learn_graph.nodes]
adj_mat = causal_learn_graph.graph
self.graph = clearn_arr_to_graph(adj_mat, names, "cpdag")
2 changes: 1 addition & 1 deletion tests/unit_tests/replearning/test_gin.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_estimate_gin_testdata():
context = make_context().variables(data=data).build()
gin = GIN()
gin.fit(data, context)
pdag = gin.graph_
pdag = gin.graph

assert nx.is_isomorphic(pdag.sub_undirected_graph(), g_answer.sub_undirected_graph())
assert nx.is_isomorphic(pdag.sub_directed_graph(), g_answer.sub_directed_graph())

0 comments on commit b327864

Please sign in to comment.