causallearn/utils/cit.py (43 changes: 38 additions & 5 deletions)
@@ -6,6 +6,8 @@
from causallearn.utils.KCI.KCI import KCI_CInd, KCI_UInd
from causallearn.utils.PCUtils import Helper

CONST_BINCOUNT_UNIQUE_THRESHOLD = 1e5


def kci(data, X, Y, condition_set=None, kernelX='Gaussian', kernelY='Gaussian', kernelZ='Gaussian',
        est_width='empirical', polyd=2, kwidthx=None, kwidthy=None, kwidthz=None):
@@ -197,7 +199,6 @@ def chisq_or_gsq_test(dataSXY, cardSXY, G_sq=False):
    cardSXY: cardinalities of each row (each variable)
    G_sq: True to use the G-squared test; otherwise (False by default), use the chi-squared test
    """
    def _Fill2DCountTable(dataXY, cardXY):
        """
        e.g. dataXY: the observed dataset contains 5 samples, on variable x and y they're
@@ -224,16 +225,14 @@ def _Fill2DCountTable(dataXY, cardXY):
        yMarginalCounts = np.sum(xyJointCounts, axis=0)
        return xyJointCounts, xMarginalCounts, yMarginalCounts
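
For intuition, here is a minimal, self-contained sketch of the 2D counting idea on hypothetical toy data (the body of _Fill2DCountTable is collapsed in this diff; the flatten-then-bincount pattern mirrors the 3D helpers below):

    import numpy as np

    # Hypothetical toy data: 5 samples of x (cardinality 2) and y (cardinality 3).
    xs = np.array([0, 1, 1, 0, 1])
    ys = np.array([2, 0, 2, 1, 0])
    cardX, cardY = 2, 3

    # Flatten each (x, y) pair into a single index and count all pairs with one bincount call.
    xyJointCounts = np.bincount(xs * cardY + ys, minlength=cardX * cardY).reshape((cardX, cardY))
    xMarginalCounts = np.sum(xyJointCounts, axis=1)  # [2, 3]
    yMarginalCounts = np.sum(xyJointCounts, axis=0)  # [2, 1, 2]
    # xyJointCounts:
    # [[0 1 1]
    #  [2 0 1]]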

    def _Fill3DCountTableByBincount(dataSXY, cardSXY):
        cardX, cardY = cardSXY[-2:]
        cardS = np.prod(cardSXY[:-2])

        # Flatten each sample's (s_1, ..., s_k, x, y) column into one mixed-radix index,
        # so that all joint counts can be obtained with a single bincount call.
        cardCumProd = np.ones_like(cardSXY)
        cardCumProd[:-1] = np.cumprod(cardSXY[1:][::-1])[::-1]
        SxyIndexed = np.dot(cardCumProd[None], dataSXY)[0]

        SxyJointCounts = np.bincount(SxyIndexed, minlength=cardS * cardX * cardY).reshape((cardS, cardX, cardY))

        # Drop the strata of S that never occur in the data.
        SMarginalCounts = np.sum(SxyJointCounts, axis=(1, 2))
        SMarginalCountsNonZero = SMarginalCounts != 0
        SMarginalCounts = SMarginalCounts[SMarginalCountsNonZero]
@@ -243,6 +242,40 @@
        SyJointCounts = np.sum(SxyJointCounts, axis=1)
        return SxyJointCounts, SMarginalCounts, SxJointCounts, SyJointCounts
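
To see what the mixed-radix flattening above computes, here is a small worked example on hypothetical numbers (one conditioning variable of cardinality 3, binary X and Y):

    import numpy as np

    cardSXY = np.array([3, 2, 2])
    cardCumProd = np.ones_like(cardSXY)
    cardCumProd[:-1] = np.cumprod(cardSXY[1:][::-1])[::-1]
    # cardCumProd == [4, 2, 1]: a column (s, x, y) maps to s*4 + x*2 + y,
    # a unique index in [0, 12) for each of the 3*2*2 value combinations.
    dataSXY = np.array([[0, 2, 2, 1],   # s values of 4 samples
                        [1, 0, 1, 1],   # x values
                        [0, 1, 1, 0]])  # y values
    SxyIndexed = np.dot(cardCumProd[None], dataSXY)[0]
    # SxyIndexed == [2, 9, 11, 6]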

    def _Fill3DCountTableByUnique(dataSXY, cardSXY):
        # Sometimes the conditioning set contains many variables, each with large cardinality.
        # Consider an extreme case where S contains 7 variables and each has cardinality 20:
        # then cardS = np.prod(cardSXY[:-2]) would be 20 ** 7 = 1280000000, i.e., there are
        # 1280000000 possible value combinations of S, so the SxyJointCounts array would have
        # size 1280000000 * cardX * cardY in np.int64, i.e., ~3.73TB of memory
        # (supposing cardX and cardY are also 20)!
        # However, the sample size is usually on the 1k-100k scale, far less than cardS,
        # so only a very small portion of the possible combinations of S actually appear in the
        # data, i.e., almost all entries of SMarginalCounts in _Fill3DCountTableByBincount are zero.
        # So when cardSXY is large, we first re-index S (skipping the absent combinations)
        # and then count the XY table for each present combination.
        # See https://github.com/cmu-phil/causal-learn/pull/37.
        cardX, cardY = cardSXY[-2:]
        cardSs = cardSXY[:-2]

        cardSsCumProd = np.ones_like(cardSs)
        cardSsCumProd[:-1] = np.cumprod(cardSs[1:][::-1])[::-1]
        SIndexed = np.dot(cardSsCumProd[None], dataSXY[:-2])[0]

        # Re-index: map each S combination that actually occurs to a compact id in
        # [0, cardS_reduced), so the count table covers only the observed strata.
        uniqSIndices, inverseSIndices, SMarginalCounts = np.unique(SIndexed, return_counts=True, return_inverse=True)
        cardS_reduced = len(uniqSIndices)
        SxyIndexed = inverseSIndices * cardX * cardY + dataSXY[-2] * cardY + dataSXY[-1]
        SxyJointCounts = np.bincount(SxyIndexed, minlength=cardS_reduced * cardX * cardY).reshape((cardS_reduced, cardX, cardY))

        SxJointCounts = np.sum(SxyJointCounts, axis=2)
        SyJointCounts = np.sum(SxyJointCounts, axis=1)
        return SxyJointCounts, SMarginalCounts, SxJointCounts, SyJointCounts
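
A quick sanity check one could run if both helpers were exposed: on data small enough for either path, the two fills should return identical tables, since the bincount variant also drops the all-zero strata of S (per the SMarginalCountsNonZero filtering above) and both order strata by ascending S index. A sketch, assuming synthetic random data:

    import numpy as np

    rng = np.random.default_rng(0)
    cardSXY = np.array([4, 3, 2, 2])  # two conditioning variables S1, S2, then X, Y
    dataSXY = np.stack([rng.integers(0, card, 1000) for card in cardSXY])

    tablesA = _Fill3DCountTableByBincount(dataSXY, cardSXY)
    tablesB = _Fill3DCountTableByUnique(dataSXY, cardSXY)
    assert all(np.array_equal(a, b) for a, b in zip(tablesA, tablesB))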

    def _Fill3DCountTable(dataSXY, cardSXY):
        # For the choice of the threshold 1e5, see a rough performance comparison at:
        # https://gist.github.com/MarkDana/e7d9663a26091585eb6882170108485e#file-count-unique-in-array-performance-md
        if np.prod(cardSXY) < CONST_BINCOUNT_UNIQUE_THRESHOLD: return _Fill3DCountTableByBincount(dataSXY, cardSXY)
        return _Fill3DCountTableByUnique(dataSXY, cardSXY)
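
Worked through on the numbers above: for the extreme case of 7 conditioning variables of cardinality 20 plus X and Y, np.prod(cardSXY) = 20 ** 9 = 512000000000, far above the 1e5 threshold, so the unique-based path is taken; a couple of binary conditioning variables stay well below it:

    import numpy as np

    print(np.prod(np.array([20] * 9)))      # 512000000000 -> _Fill3DCountTableByUnique
    print(np.prod(np.array([2, 2, 2, 2])))  # 16 -> _Fill3DCountTableByBincount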

    def _CalculatePValue(cTables, eTables):
        """
        Calculate the rareness (p-value) of an observation from a given distribution with a certain sample size.
@@ -280,10 +313,10 @@ def _CalculatePValue(cTables, eTables):
    xyJointCounts, xMarginalCounts, yMarginalCounts = _Fill2DCountTable(dataSXY, cardSXY)
    xyExpectedCounts = np.outer(xMarginalCounts, yMarginalCounts) / dataSXY.shape[1]  # divide by sample size
    return _CalculatePValue(xyJointCounts[None], xyExpectedCounts[None])

    # else, S is not empty: condition on each observed stratum of S
    SxyJointCounts, SMarginalCounts, SxJointCounts, SyJointCounts = _Fill3DCountTable(dataSXY, cardSXY)
    SxyExpectedCounts = SxJointCounts[:, :, None] * SyJointCounts[:, None, :] / SMarginalCounts[:, None, None]

    return _CalculatePValue(SxyJointCounts, SxyExpectedCounts)
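
The body of _CalculatePValue is collapsed in this diff. For orientation only, a minimal sketch of a standard chi-squared p-value over such stacked (strata, cardX, cardY) observed/expected tables might look as follows; this is an illustrative assumption, not the repository's exact implementation (which also supports the G-squared variant):

    import numpy as np
    from scipy.stats import chi2

    def _chisq_pvalue_sketch(cTables, eTables):
        # cTables, eTables: (numStrata, cardX, cardY) observed and expected counts.
        # Chi-squared statistic summed over all strata; zero-expectation cells are skipped.
        nonZero = eTables != 0
        chisqStat = np.sum((cTables[nonZero] - eTables[nonZero]) ** 2 / eTables[nonZero])
        # Degrees of freedom: (cardX - 1) * (cardY - 1) per stratum of S.
        df = cTables.shape[0] * (cTables.shape[1] - 1) * (cTables.shape[2] - 1)
        return chi2.sf(chisqStat, df)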

