Fix NaN problem while normalizing the data #55

Merged
merged 8 commits into from
Jul 12, 2022
9 changes: 9 additions & 0 deletions causallearn/utils/KCI/KCI.py
@@ -146,7 +146,11 @@ def kernel_matrix(self, data_x, data_y):
raise Exception('Undefined kernel function')

data_x = stats.zscore(data_x, axis=0)
data_x[np.isnan(data_x)] = 0.

data_y = stats.zscore(data_y, axis=0)
data_y[np.isnan(data_y)] = 0.

Kx = kernelX.kernel(data_x)
Ky = kernelY.kernel(data_y)
return Kx, Ky
@@ -323,8 +327,13 @@ def kernel_matrix(self, data_x, data_y, data_z):
"""
# normalize the data
data_x = stats.zscore(data_x, axis=0)
data_x[np.isnan(data_x)] = 0.

data_y = stats.zscore(data_y, axis=0)
data_y[np.isnan(data_y)] = 0.

data_z = stats.zscore(data_z, axis=0)
data_z[np.isnan(data_z)] = 0.

# concatenate x and z
data_x = np.concatenate((data_x, 0.5 * data_z), axis=1)
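
To illustrate why the NaN replacement is needed: a column with zero variance has a standard deviation of 0, so z-scoring computes 0/0 and yields NaN for every entry in that column. The sketch below mimics `stats.zscore(data, axis=0)` with plain NumPy and then applies the same masking as the diff. The helper name `zscore_nan_to_zero` is hypothetical and not part of causal-learn.

```python
import numpy as np

def zscore_nan_to_zero(data):
    """Column-wise z-score; constant columns become all zeros instead of NaN.

    Hypothetical helper for illustration only.
    """
    data = np.asarray(data, dtype=float)
    mean = data.mean(axis=0)
    std = data.std(axis=0)  # population std (ddof=0), matching the scipy.stats.zscore default
    with np.errstate(divide="ignore", invalid="ignore"):
        z = (data - mean) / std  # 0/0 gives NaN for a constant column
    z[np.isnan(z)] = 0.0  # same masking as in the diff
    return z

# Second column is constant, so its z-scores would be NaN without the masking.
data = np.array([[1.0, 5.0], [2.0, 5.0], [3.0, 5.0]])
z = zscore_nan_to_zero(data)
```

Setting the NaN entries to zero treats a constant column as carrying no information, which keeps the downstream kernel matrix finite.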
32 changes: 29 additions & 3 deletions causallearn/utils/cit.py
@@ -1,4 +1,5 @@
from math import log, sqrt
from collections.abc import Iterable

import numpy as np
from scipy.stats import chi2, norm
@@ -59,9 +60,18 @@ def _unique(column):
}

def kci(self, X, Y, condition_set):
Contributor

@MarkDana

Currently X and Y here can be int / Iterable? This doesn't sound like a good design. If possible, we'd better type-check every variable.

Why not enforce X and Y here to be a list of int? This is the most general one, right? Then we don't need those lines just to do type-checking; it's usually a good idea to keep code concise.

But I think this can be changed in your later PR, @MarkDana

Collaborator

Yes, let me handle this in a later PR, or in the new KCI subclass.

if type(X) == int:
X = [X]
Comment on lines +63 to +64
Collaborator

Better with an elif to ensure X is some iterable.

And then X = list(X) - otherwise self.data[:, X] does not support X as e.g., set, or tuple with only one element.
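
The normalization suggested above could look roughly like the sketch below. `as_index_list` is a hypothetical helper name, and sorting set inputs is an added assumption to make the column order deterministic; neither comes from the PR.

```python
import numpy as np

def as_index_list(idx):
    """Return idx as a plain list of column indices (hypothetical helper)."""
    if isinstance(idx, (int, np.integer)):
        return [int(idx)]
    if isinstance(idx, (set, frozenset)):
        return sorted(idx)  # assumption: sort so the column order is deterministic
    return list(idx)

data = np.arange(12.0).reshape(3, 4)
cols_int = data[:, as_index_list(2)]         # single column, kept 2-D
cols_tuple = data[:, as_index_list((1,))]    # one-element tuple
cols_set = data[:, as_index_list({0, 3})]    # set of two columns
```

Indexing with a list always returns a 2-D array, so the downstream kernel code sees the same shape whether one variable or several are passed.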

elif type(X) != list:
X = list(X)
if type(Y) == int:
Y = [Y]
elif type(Y) != list:
Y = list(Y)

if len(condition_set) == 0:
return self.kci_ui.compute_pvalue(self.data[:, [X]], self.data[:, [Y]])[0]
return self.kci_ci.compute_pvalue(self.data[:, [X]], self.data[:, [Y]], self.data[:, list(condition_set)])[0]
return self.kci_ui.compute_pvalue(self.data[:, X], self.data[:, Y])[0]
Contributor

great, thanks!

QQ:

  1. What's the reason that the old code only supports one variable? Is there some special design? @MarkDana

Collaborator

@tofuwen There is no special design in the old code for supporting only one variable. I just aligned with other tests (used in constraint-based methods), and forgot that KCI can take in multivariate unconditional variables.

I was just about to fix it. So thanks so much for your work @cogito233 !

Collaborator

Also for the cache key in https://github.com/cmu-phil/causal-learn/blob/ffe75f95c4003fa7e9d7d5f3bbec4ace90ed3a41/causallearn/utils/cit.py#L339,

we'll need to handle X < Y for iterable X, Y. And also frozenset(i) as hashable key.
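
One way to build such a key is sketched below. It assumes X and Y are either an int or an iterable of ints; `make_cache_key` is a hypothetical name, not the function used in `cit.py`. It uses sorted tuples rather than frozensets for X and Y so that the two sides can still be ordered against each other.

```python
def make_cache_key(X, Y, condition_set):
    """Order-independent, hashable cache key (hypothetical helper)."""
    def as_sorted_tuple(v):
        # An int becomes a 1-tuple; any iterable of ints becomes a sorted tuple.
        return (v,) if isinstance(v, int) else tuple(sorted(v))

    # Tuples compare lexicographically, so this generalizes the `X < Y` swap.
    i, j = sorted((as_sorted_tuple(X), as_sorted_tuple(Y)))
    return (i, j, frozenset(condition_set))

key_a = make_cache_key(1, 2, [3])
key_b = make_cache_key(2, 1, (3,))
key_c = make_cache_key([0, 1], [2], {3, 4})
key_d = make_cache_key((2,), {1, 0}, [4, 3])
```

With this shape, swapping X and Y or reordering the elements inside them maps to the same cache entry.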

Contributor Author

For https://github.com/cmu-phil/causal-learn/blob/ffe75f95c4003fa7e9d7d5f3bbec4ace90ed3a41/causallearn/utils/cit.py#L338, we'll need to handle the case of iterable X, Y.

I have no idea how to change the assert in the case of the kernel CIT (where all or part of X is in the condition set). The other problems are already fixed.

Maybe $X \bot Y | X$ is also a valid expression?

Collaborator

Wow cool. You're so productive! Thanks for all these!

I see your point. This line was intended for correctness checks in constraint-based methods.

As for the CI test itself, is X;Y|X valid? I don't actually know, but the result is expected to always be "independent" (consider X|X as a degenerate constant).

How about we force X not in condition_set for X: int, and len(set(condition_set).intersection(X)) == 0 for X: Iterable?

Contributor Author

Done.

return self.kci_ci.compute_pvalue(self.data[:, X], self.data[:, Y], self.data[:, list(condition_set)])[0]

def fisherz(self, X, Y, condition_set):
"""
@@ -326,7 +336,23 @@ def __call__(self, X, Y, condition_set=None, *args):
else:
assert len(args) == 2, "Arguments other than skel and prt_m are provided for mc_fisherz."
if condition_set is None: condition_set = tuple()
assert X not in condition_set and Y not in condition_set, "X, Y cannot be in condition_set."

if type(X) == int and type(Y) == int:
assert X not in condition_set and Y not in condition_set, "X, Y cannot be in condition_set."
else:
if isinstance(X, Iterable):
assert len(set(condition_set).intersection(X)) == 0, "X cannot be in condition_set."
elif isinstance(X, int):
assert X not in condition_set, "X cannot be in condition_set."
else:
raise Exception("Undefined type of X, X should be int or Iterable")
if isinstance(Y, Iterable):
assert len(set(condition_set).intersection(Y)) == 0, "Y cannot be in condition_set."
elif isinstance(Y, int):
assert Y not in condition_set, "Y cannot be in condition_set."
else:
raise Exception("Undefined type of Y, Y should be int or Iterable")

i, j = (X, Y) if (X < Y) else (Y, X)
cache_key = (i, j, frozenset(condition_set))
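
For comparison, the int / Iterable branching in the added assertions could be condensed by first normalizing each variable to a set. This is only a sketch of an alternative, not what the PR merged; `check_disjoint` is a hypothetical name.

```python
from collections.abc import Iterable

def check_disjoint(var, condition_set, name):
    """Assert that var (int or Iterable of ints) shares no element with condition_set."""
    if isinstance(var, int):
        members = {var}
    elif isinstance(var, Iterable):
        members = set(var)
    else:
        raise TypeError(f"Undefined type of {name}, {name} should be int or Iterable")
    # set.isdisjoint accepts any iterable, so condition_set can be a tuple or a list.
    assert members.isdisjoint(condition_set), f"{name} cannot be in condition_set."

check_disjoint(0, (1, 2), "X")       # passes silently
check_disjoint([0, 3], (1, 2), "Y")  # passes silently
```

Calling the same helper once for X and once for Y would cover both the scalar and the multivariate case with a single code path.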
