
Conversation

@MarkDana (Collaborator)

Updated functions:

chisq_or_gsq_test in utils/cit.py.

What is improved:

  • When the conditioning set and cardinalities are large: both speed and memory usage improve.
  • When the conditioning set and cardinalities are small: same as before.

Why this update is needed:

Consider a chi-squared CI test over discrete variables X and Y given a conditioning set S. We need to count the joint probability table of (X, Y) for each configuration of S that occurs in the data. In pull request #6, we parallelized this cell-counting process and gained a huge speedup. However, that implementation assumes that the cardinalities of the variables are usually small (e.g. <5), and that the size of S is usually small (e.g. <5).
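As a rough illustration of the dense counting scheme described above (a hypothetical sketch, not the exact PR #6 code; the function name is illustrative): each (S, X, Y) sample is flattened into one mixed-radix integer, and a single np.bincount fills the whole cardS * cardX * cardY grid at once.

```python
import numpy as np

def joint_xy_counts_dense(dataSXY, cardSXY):
    # dataSXY: one row per variable (S..., X, Y), one column per sample
    flat = np.zeros(dataSXY.shape[1], dtype=np.int64)
    for row, card in zip(dataSXY, cardSXY):
        flat = flat * card + row  # mixed-radix encoding of each sample
    # one bincount over the full dense grid of np.prod(cardSXY) cells
    counts = np.bincount(flat, minlength=int(np.prod(cardSXY)))
    return counts.reshape(cardSXY)  # shape (cardS_1, ..., cardX, cardY)
```

This is fast when np.prod(cardSXY) is small, but the `minlength` grid is allocated in full, which is exactly what blows up for large conditioning sets.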

Yet sometimes the conditioning set may contain many variables, each with a large cardinality. Consider a case where S contains 7 variables, each with cardinality 20: then cardS = np.prod(cardSXY[:-2]) would be 1280000000, i.e., there are 1,280,000,000 different possible configurations of S, so the SxyJointCounts array (at the linked line) would hold 1280000000 * cardX * cardY entries of np.int64, i.e., ~3.73TiB of memory (supposing cardX and cardY are also 20)!
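A back-of-envelope check of the figures above (plain arithmetic, not library code):

```python
card = 20                            # cardinality of every variable
cardS = card ** 7                    # 7 conditioning variables -> possible configs of S
table_entries = cardS * card * card  # cardS * cardX * cardY cells
table_bytes = table_entries * 8      # np.int64 -> 8 bytes per cell
print(cardS)                         # 1280000000 configurations
print(round(table_bytes / 2 ** 40, 2))  # ~3.73 (TiB)
```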

However, the sample size is usually on the 1k-100k scale, which is far less than cardS. Only a very small portion of the configurations of S actually appear in the data, i.e., SMarginalCountsNonZero (at the linked line) is a very sparse array.

Hence, when np.prod(cardSXY) is large, we first re-index S (skipping the absent configurations) and only then count the joint XY table for each observed configuration. Specifically, two functions, _Fill3DCountTable_by_bincount and _Fill3DCountTable_by_unique, are used at different scales of cardSXY.
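The re-indexing idea can be sketched as follows (a hypothetical illustration, not the exact causallearn implementation; the function name is made up): np.unique over the S-columns assigns a compact id 0..k-1 to each observed configuration, so the count table is only as large as the number of configurations actually seen.

```python
import numpy as np

def joint_xy_counts_sparse(dataSXY, cardX, cardY):
    # dataSXY: one row per variable (S..., X, Y), one column per sample
    S_part = dataSXY[:-2]              # rows: conditioning variables S
    x, y = dataSXY[-2], dataSXY[-1]    # last two rows: X and Y samples
    # Re-index: distinct observed S-columns get compact ids 0..num_seen-1.
    _, s_ids = np.unique(S_part, axis=1, return_inverse=True)
    num_seen = s_ids.max() + 1 if s_ids.size else 0
    # Count (s_id, x, y) triples; only observed S configurations take space.
    flat = (s_ids * cardX + x) * cardY + y
    counts = np.bincount(flat, minlength=num_seen * cardX * cardY)
    return counts.reshape(num_seen, cardX, cardY)
```

The allocated table is num_seen * cardX * cardY, bounded by the sample size rather than by np.prod(cardSXY[:-2]).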

Testing:

  • Code:
import numpy as np
from causallearn.utils.cit_new import chisq as chisq_new
from causallearn.utils.cit_old import chisq as chisq_old
import time, tracemalloc

data = np.random.randint(0, 20, (8, 20000))
X, Y = 0, 1
S = tuple(range(2, 8))
# S contains 6 variables. Each has a cardinality of 20.

tracemalloc.start()
tic = time.time()
p_new = chisq_new(data, X, Y, S)
tac = time.time()
_, peak_new = tracemalloc.get_traced_memory()
tracemalloc.stop()
print(f'new: used {tac - tic}s, peak memory {peak_new / 1024 ** 2} MiB')

tracemalloc.start()
tic = time.time()
p_old = chisq_old(data, X, Y, S)
tac = time.time()
_, peak_old = tracemalloc.get_traced_memory()
tracemalloc.stop()
print(f'old: used {tac - tic}s, peak memory {peak_old / 1024 ** 2} MiB')

assert p_old == p_new
  • Results (on Apple M1 Max):
new: used 0.0003800392150878906s, peak memory 0.307403564453125 MiB
old: used 11.38077425956726s, peak memory 70793.66288280487 MiB

Empirical threshold: how to choose between np.bincount and np.unique?

Refer here:

[Benchmark plots: time and memory for np.bincount vs. np.unique at different scales of np.prod(cardSXY)]

return xyJointCounts, xMarginalCounts, yMarginalCounts

def _Fill3DCountTable(dataSXY, cardSXY):
def _Fill3DCountTable_by_bincount(dataSXY, cardSXY):
Contributor (review comment):

nit: usually we use one consistent style for naming:

either _Fill3DCountTableByBincount()
or _fill_3D_count_table_by_bincount()

@tofuwen (Contributor) left a comment:

The code looks great, thanks for the great work!

When submitting a PR, you can add reviewers, so reviewers will get automatic notifications.

def _Fill3DCountTable(dataSXY, cardSXY):
# about the threshold 1e5, see a rough performance example at:
# https://gist.github.com/MarkDana/e7d9663a26091585eb6882170108485e#file-count-unique-in-array-performance-md
if np.prod(cardSXY) < 1e5: return _Fill3DCountTable_by_bincount(dataSXY, cardSXY)
Contributor (review comment):

nit: It would be better to define 1e5 as a constant (with a meaningful name).

Numbers like this are called "magic numbers" in code, and generally we should avoid them.

https://codeburst.io/software-anti-patterns-magic-numbers-7bc484f40544

@MarkDana (Collaborator, Author)

@tofuwen Hi Yewen, the two nits are fixed. Thanks for your comments :))

@kunwuz kunwuz merged commit c4e9299 into py-why:main Jun 24, 2022