
Conversation

@MarkDana
Collaborator

@MarkDana MarkDana commented Jul 19, 2022

Updated files:

  • cit.py: Last time we rewrote all CIT functions into a single CIT class, with all methods in that one class. This time we further separate each test method into its own subclass inherited from a base class, CIT_Base.

How to use the new class(es):

  • Code logic is consistent with the current version (ac3f4e7).
  • For users, any algorithm can be run end-to-end in the same way as before, e.g. (see here):
from causallearn.search.ConstraintBased.PC import pc
from causallearn.utils.cit import fisherz

cg = pc(data, 0.05, fisherz)
  • For developers, declaring a CIT instance also works the same as before (so no other files are changed):
from causallearn.utils.cit import CIT

fisherz_obj = CIT(data, "fisherz") # construct a CIT instance with data and method name
pValue = fisherz_obj(X, Y, S) # a simple call is ok. no need to consider cache/corr_mat etc. by yourself.

Previously, though, CIT was a class, whereas now CIT is a function API that returns an instance of the respective class. So an alternative way of writing the code above is:

from causallearn.utils.cit import FisherZ
fisherz_obj = FisherZ(data)
pValue = fisherz_obj(X, Y, S)
  • The issue with MVPC's inaccurate fisherz result is solved. It was due to a change to the sample size (my fault). Code logic is the same as before.

  • Functions for CIT's resume-from-breakpoint are added in the CIT_Base class. I will create a new PR for reference.
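
The relationship between the CIT function and the per-test subclasses described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the causal-learn source: the registry `_METHODS`, the stub `__call__` bodies, and the constant p-values are placeholders invented for the example.

```python
class CIT_Base:
    """Stand-in for the shared base class; it only stores its inputs."""

    def __init__(self, data, **kwargs):
        self.data = data
        self.kwargs = kwargs

    def __call__(self, X, Y, condition_set=None):
        raise NotImplementedError


class FisherZ(CIT_Base):
    def __call__(self, X, Y, condition_set=None):
        return 1.0  # placeholder p-value; a real test computes it from data


class KCI(CIT_Base):
    def __call__(self, X, Y, condition_set=None):
        return 1.0  # placeholder p-value


# Hypothetical registry mapping method names to subclasses.
_METHODS = {"fisherz": FisherZ, "kci": KCI}


def CIT(data, method="fisherz", **kwargs):
    """Function entry point: return an instance of the matching subclass."""
    if method not in _METHODS:
        raise ValueError(f"Unknown method: {method!r}")
    return _METHODS[method](data, **kwargs)


data = [[0.1, 0.2], [0.3, 0.4]]
via_function = CIT(data, "fisherz")
via_class = FisherZ(data)
print(type(via_function) is type(via_class))
```

Both spellings yield an object of the same class, which is why existing callers of `CIT(data, "fisherz")` would not need to change.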

Test plan:

Same as #46 (Rewrite CITests as a class && re-use covariance matrix for fisherz):

python -m unittest TestPC    # should pass
python -m unittest TestFCI    # should pass
python -m unittest TestCDNOD    # should pass
python -m unittest TestMVPC    # should pass
python -m unittest TestMVPC_mv_fisherz_test.py    # should pass
python -m unittest TestMVPC_mv_fisherz.py    # should pass except for test_pc_with_mv_fisherz_MCAR_data_assertion, which cannot pass in the original code.

Contributor

@tofuwen tofuwen left a comment

Thanks for the awesome work!

The code looks much better!!

Only some nit comments :)

class CIT(object):
def __init__(self, data, method='fisherz', **kwargs):

def CIT(data, method='fisherz', **kwargs):
Contributor

hmmm, I don't like this kwargs design, because users don't know what to input?

I think the current way is fine (for backward compatibility), and later, maybe we can change it to
cit = FisherZ(data, args) for all callers?

Collaborator Author

I see your point! And I am also confused about this...

Now **kwargs is placed here to support some user-defined parameters (just at the algorithm call level, e.g., pc, so that users don't need to edit the code inside pc). E.g.,

  • pc(data, 0.05, kci, est_width='median') if the user wants to use pc+kci with another kernel width (which is not supported in the old func kci, where all parameters are set by default.)
  • pc(data, 0.05, kci, True, 0, -1, cache_path='/my/path/to/cache.json') to save&load citest cache.

I use kwargs for now because the additional arguments that different methods (e.g., FisherZ, KCI) can take are different, and we have to use the function CIT as the entrance for all callers (just for backward compatibility).

args won't work because users still don't know what to input (and even the order)? A perfect way would be to declare cit outside the algorithm call:

kci_obj = KCI(data, kernelZ='Polynomial', est_width='median', cache_path='/my/path/to/cache.json')
pc(data, 0.05, kci_obj, True, 0, -1) # only take the algorithm-related parameters

But this is not backward compatible (for user input). So a compromise might be to still use the function CIT as the entrance, and just add more instructions (on which parameters are allowed) to CIT's docstring and to the documentation of CIT and pc?
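
One way to make this compromise less opaque to users would be for the entry point to reject keyword arguments that the chosen test does not accept. The sketch below is only an assumption for illustration: `ALLOWED_KWARGS` and `check_kwargs` are hypothetical, and while the names `kernelZ`, `est_width`, and `cache_path` come from the examples in this thread, the per-method sets shown are not taken from the code.

```python
# Hypothetical table of keyword arguments each method accepts.
ALLOWED_KWARGS = {
    "fisherz": {"cache_path"},
    "kci": {"cache_path", "kernelZ", "est_width"},
}


def check_kwargs(method, **kwargs):
    """Raise a descriptive error for keyword arguments the method does not take."""
    allowed = ALLOWED_KWARGS[method]
    unknown = sorted(set(kwargs) - allowed)
    if unknown:
        raise TypeError(
            f"{method} got unexpected argument(s) {unknown}; "
            f"allowed: {sorted(allowed)}"
        )
    return kwargs


print(check_kwargs("kci", est_width="median"))
try:
    check_kwargs("fisherz", est_width="median")
except TypeError as err:
    print(err)
```

With a check like this, a mistyped or unsupported parameter fails at the call site with a message listing the accepted names, instead of being silently ignored.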

Contributor

Yeah, I totally agree. The final example is the perfect solution we should pursue eventually.

I think the design is fine for now, but maybe later we may want to change the design and make it better.

I think in order to make our package, we have to do some backward incompatible things --- the current input / output for each algorithm is not even consistent, which is bad... So we need to change things anyway.

How about we add a todo here to remind us what the good design looks like, so that later, when we do a huge refactor, we can implement it? cc @kunwuz if you'd like to share some feedback.

Collaborator Author

ic. Cool!

Contributor

@MarkDana how about adding a todo here to remind us that we later want to remove the kwargs argument?

assert isinstance(data, np.ndarray), "Input data must be a numpy array."
self.data = data
self.data_hash = hash(str(data))
self.data_hash = hashlib.md5(str(data).encode('utf-8')).hexdigest()
Contributor

When data is huge, this will be slow?

And when path is None, we don't need to compute data_hash?

Collaborator Author

It will be fast (and used only once). np.ndarray.__str__ only returns a preview:

In [41]: data = np.random.randn(10000, 100)

In [42]: str(data)
Out[42]: '[[-1.39552534  1.37974053 -1.1619043  ...  0.13616104 -0.12120668\n  -1.00001339]\n [-0.25197878 -2.00971912  0.63008704 ... -0.97997436 -1.21297862\n   1.42272323]\n [-1.22421999  0.90022162 -1.33748472 ...  1.32908047 -1.37618144\n  -0.28312766]\n ...\n [ 1.71461535  0.10882434  0.08604805 ...  1.34678215 -2.30936746\n   0.76045509]\n [ 0.55727436  0.2203048   0.41242777 ...  0.95881301  0.58538315\n   1.26002782]\n [-0.77753666  0.53018912  0.70592259 ...  0.14847539 -0.60861808\n  -0.36093896]]'

In [43]: len(str(data))
Out[43]: 491

In [44]: timeit hashlib.md5(str(data).encode('utf-8')).hexdigest()
125 µs ± 392 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
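
A caveat follows from the same observation: because `str(data)` is only a preview for large arrays, two arrays that differ only in the elided region produce the same string, and therefore the same hash. The sketch below illustrates this under NumPy's default print options. As an assumption that is not part of this PR, it also contrasts the preview hash with hashing the full buffer via `ndarray.tobytes()`, which covers every element but has to read the whole array.

```python
import hashlib

import numpy as np


def preview_hash(arr):
    # Hash of the printed preview, as in the diff above.
    return hashlib.md5(str(arr).encode("utf-8")).hexdigest()


def full_hash(arr):
    # Hash of the whole buffer (an alternative, not what the PR does).
    return hashlib.md5(arr.tobytes()).hexdigest()


a = np.zeros((10000, 100))
b = a.copy()
b[5000, 50] = 1.0  # change one element that the preview elides

print(preview_hash(a) == preview_hash(b))  # the previews are identical
print(full_hash(a) == full_hash(b))        # the full buffers are not
```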

Contributor

cool

----------
X: int, or np.*int*
Y: int, or np.*int*
condition_set: Iterable<int | np.*int*>
Contributor

why not force it to be List?

self.SAVE_CACHE_CYCLE_SECONDS = 30
self.last_time_cache_saved = time.time()
self.pvalue_cache = {'data_hash': self.data_hash}
if not cache_path is None:
Contributor

Suggested change
if not cache_path is None:
if cache_path is not None:

check https://stackoverflow.com/questions/2710940/python-if-x-is-not-none-or-if-not-x-is-none

Collaborator Author

@MarkDana MarkDana Jul 22, 2022

Wow, thanks for this!!! Marked!

I always get lost when writing something like this...

self.save_to_local_cache()

METHODS_SUPPORTING_MULTIDIM_DATA = ["kci"]
if condition_set is None: condition_set = []
Contributor

might be cleaner to force condition_set to never be None?

If the user doesn't want condition_set, just use []?

Collaborator Author

Oh I missed this.

Sometimes users might want to input FisherZ(X, Y) to test just unconditional independence.

From the users' end, I feel that this looks better than FisherZ(X, Y, []), or FisherZ(X, Y, ()) ... (which of course also works).

Another reason is that I'm not sure whether usages like FisherZ(X, Y) exist in the current code. lol
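
The optional-argument behaviour discussed here can be sketched with a toy callable. `ToyTest` is a hypothetical stand-in that only echoes its normalized arguments; it is not one of the classes in cit.py.

```python
class ToyTest:
    """Stand-in for a CIT object; it only returns the arguments it received."""

    def __call__(self, X, Y, condition_set=None):
        # A None default avoids a shared mutable default such as `condition_set=[]`.
        if condition_set is None:
            condition_set = []
        return X, Y, list(condition_set)


test = ToyTest()
print(test(0, 1))          # unconditional test, no third argument needed
print(test(0, 1, []))      # explicit empty conditioning set
print(test(0, 1, (2, 3)))  # any iterable of indices
```

Callers who omit the conditioning set and callers who pass an empty one end up on the same code path.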

return [X], [Y], condition_set, _stringize([X], [Y], condition_set)

# also to support multi-dimensional unconditional X, Y (usually in kernel-based tests)
Xs = sorted(set(map(int, X))) if isinstance(X, Iterable) else [int(X)] # sorted for comparison
Contributor

why not force X to be List type?

Personally I always prefer to make variables typed --- it can remove lots of potential bugs and make the code much cleaner.

Collaborator Author

In all of our constraint-based methods, X and Y are assumed to be integers.

If we force X to be of List type, all the code related to cit calls in pc, fci, ... will need to be changed.

Overall, integers X and Y should always be first-class citizens in CITests. Multi-dim X and Y only work in KCI, and are not used in constraint-based methods but elsewhere (e.g., GIN, you name it).
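
The normalization in the diff line above, which accepts either a single index or an iterable of indices, can be sketched in isolation. The helper name `as_sorted_int_list` is an assumption for this example, and `Iterable` is imported from `collections.abc` here, which may differ from the import used in cit.py.

```python
from collections.abc import Iterable


def as_sorted_int_list(value):
    """Accept one index or an iterable of indices; return sorted unique ints."""
    if isinstance(value, Iterable):
        return sorted(set(map(int, value)))  # sorted so that results compare equal
    return [int(value)]


print(as_sorted_int_list(3))          # the common case: a single integer index
print(as_sorted_int_list([2, 0, 2]))  # multi-dimensional case with a duplicate
```

Integer callers such as pc keep passing plain indices, while tests that support multi-dimensional inputs can pass a list without a separate code path.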

Contributor

ic, sounds good.

var = Xs + Ys + condition_set
sub_corr_matrix = self.correlation_matrix[np.ix_(var, var)]
try:
inv = np.linalg.inv(sub_corr_matrix)
Contributor

hmmm, curious, what if an exception is thrown here? You didn't catch it?

Collaborator Author

Oh yes I didn't catch exceptions. But what would be the possible exceptions here? If it's about the type of X, Y, condition_set, a built-in error message seems to be informative enough.

Contributor

why do you "try" then?

Collaborator Author

Oh you mean which lines? L155?

This is from the original code at Fixed fisherz test (#58).

Contributor

ha, ic. I missed the later "except".... nvm
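
The `except` branch itself is not shown in this excerpt. As a hedged sketch of the general pattern, `np.linalg.inv` raises `np.linalg.LinAlgError` for a singular matrix, so a `try` around it is typically paired with a handler for that exception. The helper name `safe_inverse` and the error message are assumptions for illustration, not the text used in cit.py.

```python
import numpy as np


def safe_inverse(matrix):
    """Invert a matrix, turning NumPy's error into a clearer message."""
    try:
        return np.linalg.inv(matrix)
    except np.linalg.LinAlgError as err:
        raise ValueError("Correlation matrix is singular; cannot invert it.") from err


print(safe_inverse(np.array([[2.0, 0.0], [0.0, 4.0]])))

try:
    # Two perfectly correlated variables give a singular correlation matrix.
    safe_inverse(np.array([[1.0, 1.0], [1.0, 1.0]]))
except ValueError as err:
    print(err)
```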

class FisherZ(CIT_Base):
def __init__(self, data, **kwargs):
super().__init__(data, **kwargs)
self.check_cache_method_consistent('fisherz', -1) # -1: no parameters can be specified for fisherz
Contributor

hmm, -1 here looks ugly. Maybe we can have a better design, e.g. use None?

And in the code below, it seems you never write "parameters_hash" to json? Why does the assertion not fail?

Collaborator Author

Yeah, -1 is ugly, lol. A message string ("NO SPECIFIED PARAMETERS") might also be ok.

"parameters_hash" is written to cache (and json). See check_cache_method_consistent.

Contributor

yeah, a (const) string is much better.

try to avoid using magic numbers like -1 in your code. :)
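
A minimal sketch of the named-constant idea, combined with a consistency check in the spirit of the one discussed above. Everything here is hypothetical: the constant name, `parameters_hash`, and `check_consistent` are invented for the example and do not reproduce `check_cache_method_consistent`.

```python
import hashlib

# A named constant reads better than a bare -1 at every call site.
NO_SPECIFIED_PARAMETERS_MSG = "NO SPECIFIED PARAMETERS"


def parameters_hash(parameters=None):
    """Return the sentinel when there are no parameters, else a stable digest."""
    if not parameters:
        return NO_SPECIFIED_PARAMETERS_MSG
    encoded = repr(sorted(parameters.items())).encode("utf-8")
    return hashlib.md5(encoded).hexdigest()


def check_consistent(cache, method_name, params_hash):
    """Fill in an empty cache, or assert that it matches the current setup."""
    cache.setdefault("method_name", method_name)
    cache.setdefault("parameters_hash", params_hash)
    assert cache["method_name"] == method_name, "Cache was made by another method."
    assert cache["parameters_hash"] == params_hash, "Cache was made with other parameters."


cache = {}
check_consistent(cache, "fisherz", parameters_hash())
print(cache)
```

Because the check writes the keys when they are missing, a freshly created cache passes the assertions, and a cache loaded from a different configuration fails them.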

return np.unique(column, return_inverse=True)[1]
assert method_name in ['chisq', 'gsq']
super().__init__(np.apply_along_axis(_unique, 0, data).astype(np.int64), **kwargs)
self.check_cache_method_consistent(method_name, -1) # -1: no parameters can be specified for chisq/gsq
Contributor

same

Comment on lines 562 to 565
# result = [[]]
# for pool in lists:
# result = [x + [y] for x in result for y in pool]
# return result
Contributor

remove this to make the code cleaner
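
For context, the commented-out lines compute a Cartesian product of several lists. Whether the surrounding function already relies on `itertools.product` is not visible in this excerpt; the sketch below only shows that the standard-library function gives the same result as the manual loop. Both helper names are assumptions for this example.

```python
from itertools import product


def cartesian_lists(lists):
    """Cartesian product of several lists, each combination as a list."""
    return [list(combo) for combo in product(*lists)]


def cartesian_manual(lists):
    # The loop from the commented-out lines, kept here only for comparison.
    result = [[]]
    for pool in lists:
        result = [x + [y] for x in result for y in pool]
    return result


pools = [[0, 1], [0, 1, 2]]
print(cartesian_lists(pools) == cartesian_manual(pools))
```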

Contributor

@tofuwen tofuwen left a comment

great, I think this PR is ready to be merged! cc @kunwuz

@kunwuz kunwuz merged commit 89e1b78 into py-why:main Jul 25, 2022