From efa463bec178210ac8e2b50ec7d970551afa80ec Mon Sep 17 00:00:00 2001 From: Haoyue Dai Date: Thu, 23 Jun 2022 00:57:42 +0800 Subject: [PATCH 1/2] Fix cit.chisq big memory bug --- causallearn/utils/cit.py | 41 +++++++++++++++++++++++++++++++++++----- 1 file changed, 36 insertions(+), 5 deletions(-) diff --git a/causallearn/utils/cit.py b/causallearn/utils/cit.py index 701b38cc..325846a7 100644 --- a/causallearn/utils/cit.py +++ b/causallearn/utils/cit.py @@ -197,7 +197,6 @@ def chisq_or_gsq_test(dataSXY, cardSXY, G_sq=False): cardSXY: cardinalities of each row (each variable) G_sq: True if use G-sq, otherwise (False by default), use Chi_sq """ - def _Fill2DCountTable(dataXY, cardXY): """ e.g. dataXY: the observed dataset contains 5 samples, on variable x and y they're @@ -224,16 +223,14 @@ def _Fill2DCountTable(dataXY, cardXY): yMarginalCounts = np.sum(xyJointCounts, axis=0) return xyJointCounts, xMarginalCounts, yMarginalCounts - def _Fill3DCountTable(dataSXY, cardSXY): + def _Fill3DCountTable_by_bincount(dataSXY, cardSXY): cardX, cardY = cardSXY[-2:] cardS = np.prod(cardSXY[:-2]) - cardCumProd = np.ones_like(cardSXY) cardCumProd[:-1] = np.cumprod(cardSXY[1:][::-1])[::-1] SxyIndexed = np.dot(cardCumProd[None], dataSXY)[0] SxyJointCounts = np.bincount(SxyIndexed, minlength=cardS * cardX * cardY).reshape((cardS, cardX, cardY)) - SMarginalCounts = np.sum(SxyJointCounts, axis=(1, 2)) SMarginalCountsNonZero = SMarginalCounts != 0 SMarginalCounts = SMarginalCounts[SMarginalCountsNonZero] @@ -243,6 +240,40 @@ def _Fill3DCountTable(dataSXY, cardSXY): SyJointCounts = np.sum(SxyJointCounts, axis=1) return SxyJointCounts, SMarginalCounts, SxJointCounts, SyJointCounts + def _Fill3DCountTable_by_unique(dataSXY, cardSXY): + # Sometimes when the conditioning set contains many variables and each variable's cardinality is large + # e.g. 
consider an extreme case where + # S contains 7 variables and each's cardinality=20, then cardS = np.prod(cardSXY[:-2]) would be 1280000000 + # i.e., there are 1280000000 different possible combinations of S, + # so the SxyJointCounts array would be of size 1280000000 * cardX * cardY * np.int64, + # i.e., ~3.73TB memory! (suppose cardX, cardY are also 20) + # However, samplesize is usually in 1k-100k scale, far less than cardS, + # i.e., not all (and actually only a very small portion of combinations of S appeared in data) + # i.e., SMarginalCountsNonZero in _Fill3DCountTable_by_bincount is a very sparse array + # So when cardSXY is large, we first re-index S (skip the absent combinations) and then count XY table for each. + # See https://github.com/cmu-phil/causal-learn/pull/37. + cardX, cardY = cardSXY[-2:] + cardSs = cardSXY[:-2] + + cardSsCumProd = np.ones_like(cardSs) + cardSsCumProd[:-1] = np.cumprod(cardSs[1:][::-1])[::-1] + SIndexed = np.dot(cardSsCumProd[None], dataSXY[:-2])[0] + + uniqSIndices, inverseSIndices, SMarginalCounts = np.unique(SIndexed, return_counts=True, return_inverse=True) + cardS_reduced = len(uniqSIndices) + SxyIndexed = inverseSIndices * cardX * cardY + dataSXY[-2] * cardY + dataSXY[-1] + SxyJointCounts = np.bincount(SxyIndexed, minlength=cardS_reduced * cardX * cardY).reshape((cardS_reduced, cardX, cardY)) + + SxJointCounts = np.sum(SxyJointCounts, axis=2) + SyJointCounts = np.sum(SxyJointCounts, axis=1) + return SxyJointCounts, SMarginalCounts, SxJointCounts, SyJointCounts + + def _Fill3DCountTable(dataSXY, cardSXY): + # about the threshold 1e5, see a rough performance example at: + # https://gist.github.com/MarkDana/e7d9663a26091585eb6882170108485e#file-count-unique-in-array-performance-md + if np.prod(cardSXY) < 1e5: return _Fill3DCountTable_by_bincount(dataSXY, cardSXY) + return _Fill3DCountTable_by_unique(dataSXY, cardSXY) + def _CalculatePValue(cTables, eTables): """ calculate the rareness (pValue) of an observation from a given 
distribution with certain sample size. @@ -280,10 +311,10 @@ def _CalculatePValue(cTables, eTables): xyJointCounts, xMarginalCounts, yMarginalCounts = _Fill2DCountTable(dataSXY, cardSXY) xyExpectedCounts = np.outer(xMarginalCounts, yMarginalCounts) / dataSXY.shape[1] # divide by sample size return _CalculatePValue(xyJointCounts[None], xyExpectedCounts[None]) + # else, S is not empty: conditioning SxyJointCounts, SMarginalCounts, SxJointCounts, SyJointCounts = _Fill3DCountTable(dataSXY, cardSXY) SxyExpectedCounts = SxJointCounts[:, :, None] * SyJointCounts[:, None, :] / SMarginalCounts[:, None, None] - return _CalculatePValue(SxyJointCounts, SxyExpectedCounts) From d4c3916ac1dd8e52ab54f43aef61571b245f4364 Mon Sep 17 00:00:00 2001 From: Haoyue Dai Date: Thu, 23 Jun 2022 22:35:55 +0800 Subject: [PATCH 2/2] Two nits fixed --- causallearn/utils/cit.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/causallearn/utils/cit.py b/causallearn/utils/cit.py index 325846a7..e1671f40 100644 --- a/causallearn/utils/cit.py +++ b/causallearn/utils/cit.py @@ -6,6 +6,8 @@ from causallearn.utils.KCI.KCI import KCI_CInd, KCI_UInd from causallearn.utils.PCUtils import Helper +CONST_BINCOUNT_UNIQUE_THRESHOLD = 1e5 + def kci(data, X, Y, condition_set=None, kernelX='Gaussian', kernelY='Gaussian', kernelZ='Gaussian', est_width='empirical', polyd=2, kwidthx=None, kwidthy=None, kwidthz=None): @@ -223,7 +225,7 @@ def _Fill2DCountTable(dataXY, cardXY): yMarginalCounts = np.sum(xyJointCounts, axis=0) return xyJointCounts, xMarginalCounts, yMarginalCounts - def _Fill3DCountTable_by_bincount(dataSXY, cardSXY): + def _Fill3DCountTableByBincount(dataSXY, cardSXY): cardX, cardY = cardSXY[-2:] cardS = np.prod(cardSXY[:-2]) cardCumProd = np.ones_like(cardSXY) @@ -240,7 +242,7 @@ def _Fill3DCountTable_by_bincount(dataSXY, cardSXY): SyJointCounts = np.sum(SxyJointCounts, axis=1) return SxyJointCounts, SMarginalCounts, SxJointCounts, SyJointCounts - def 
_Fill3DCountTable_by_unique(dataSXY, cardSXY): + def _Fill3DCountTableByUnique(dataSXY, cardSXY): # Sometimes when the conditioning set contains many variables and each variable's cardinality is large # e.g. consider an extreme case where # S contains 7 variables and each's cardinality=20, then cardS = np.prod(cardSXY[:-2]) would be 1280000000 @@ -271,8 +273,8 @@ def _Fill3DCountTable_by_unique(dataSXY, cardSXY): def _Fill3DCountTable(dataSXY, cardSXY): # about the threshold 1e5, see a rough performance example at: # https://gist.github.com/MarkDana/e7d9663a26091585eb6882170108485e#file-count-unique-in-array-performance-md - if np.prod(cardSXY) < 1e5: return _Fill3DCountTable_by_bincount(dataSXY, cardSXY) - return _Fill3DCountTable_by_unique(dataSXY, cardSXY) + if np.prod(cardSXY) < CONST_BINCOUNT_UNIQUE_THRESHOLD: return _Fill3DCountTableByBincount(dataSXY, cardSXY) + return _Fill3DCountTableByUnique(dataSXY, cardSXY) def _CalculatePValue(cTables, eTables): """