
Fix NaN problem while normalizing the data #55

Merged
merged 8 commits into from
Jul 12, 2022
Changes from 7 commits
29 changes: 24 additions & 5 deletions causallearn/utils/KCI/KCI.py
@@ -145,8 +145,16 @@ def kernel_matrix(self, data_x, data_y):
else:
raise Exception('Undefined kernel function')

data_x = stats.zscore(data_x, axis=0)
data_y = stats.zscore(data_y, axis=0)
if np.var(data_x) == 0:
data_x -= np.average(data_x)
else:
data_x = stats.zscore(data_x, axis=0)
@MarkDana (Collaborator) commented on Jul 12, 2022:

  • The current np.var(data_x) check does not support multi-dimensional data_x where some dims are constant (while others are not), or where all dims are constant but with different values:
In [7]: data_x = np.random.randn(100, 1)

In [8]: data_y = np.ones_like(data_x) # constant

In [9]: data_xy = np.hstack([data_x, data_y])

In [10]: stats.zscore(data_xy, axis=0)
Out[10]: 
array([[ 4.15513535e-01,             nan],
       [-1.71903423e+00,             nan],
       [ 7.59493517e-01,             nan],
       [-1.34182046e+00,             nan],
       ...
  • For numpy floating-point data, how do we reliably identify a constant array? var == 0 is not enough:
In [16]: arr1 = np.array([1., 1., 1.])

In [17]: np.var(arr1)
Out[17]: 0.0

In [18]: stats.zscore(arr1)
Out[18]: array([nan, nan, nan])

#########

In [19]: arr2 = np.array([-0.087, -0.087, -0.087])

In [20]: np.var(arr2)
Out[20]: 1.925929944387236e-34

In [21]: np.var(arr2) == 0
Out[21]: False

In [22]: stats.zscore(arr2)
Out[22]: array([nan, nan, nan])  # though np.var != 0, this still runs stats.zscore and returns nan
  • Based on the above, can we just mask NaN values to zero after stats.zscore?
data_x = stats.zscore(data_x, axis=0)
data_x[np.isnan(data_x)] = 0.

This operation is safe: we expect to check any raw data_x (before KCI) so it does not contain NaN values, meaning any NaN in the normalized array is due only to constant columns (not to original NaN values).
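The masking approach suggested above can be sketched as a small helper (the name zscore_safe is hypothetical, not part of the PR):

```python
import numpy as np
from scipy import stats

def zscore_safe(data):
    # z-score each column; constant columns produce NaN under stats.zscore
    # (zero variance), so mask those NaNs back to zero afterwards
    data = stats.zscore(np.asarray(data, dtype=float), axis=0)
    data[np.isnan(data)] = 0.0
    return data
```

With data_xy = np.hstack([np.random.randn(100, 1), np.ones((100, 1))]), the constant second column comes back as all zeros instead of all NaN, and the non-constant first column is normalized as usual.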

Contributor:

> This operation is safe: we expect to check any raw data_x (before KCI) so it does not contain NaN values, meaning any NaN in the normalized array is due only to constant columns (not to original NaN values).

Do we ensure this in code, i.e., verify that raw data_x contains no NaN values?

Collaborator:

> Do we ensure this in code, i.e., verify that raw data_x contains no NaN values?

Not yet. I'll do this soon after this PR is merged.
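Such a pre-check is not yet in the codebase (per the reply above); a hypothetical sketch of what it could look like:

```python
import numpy as np

def assert_no_nan(data, name="data"):
    # reject raw input containing NaN before running KCI, so that any NaN
    # appearing after normalization can only come from constant columns
    data = np.asarray(data, dtype=float)
    if np.isnan(data).any():
        raise ValueError(f"{name} contains NaN values; clean it before running KCI.")
```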

Collaborator:

Thanks so much for your awesome contributions, @cogito233, @MarkDana and @tofuwen! @cogito233, please let us know if you think the current PR is ready to go, so we will solve the remaining issues in a new PR. Or we could include these in this PR if you would like to. :-)

Contributor Author:

The problem is already fixed. I guess we can merge now~

Many thanks to all of you (@kunwuz, @MarkDana, @tofuwen) for the help~ This is my first time contributing to a community codebase, so I lack experience.


Collaborator:

Thanks so much! @tofuwen @MarkDana When you have time, could you please make a final check on the current PR and let me know when it's ready to be merged? Many thanks!


if np.var(data_y) == 0:
data_y -= np.average(data_y)
else:
data_y = stats.zscore(data_y, axis=0)

Kx = kernelX.kernel(data_x)
Ky = kernelY.kernel(data_y)
return Kx, Ky
@@ -322,9 +330,20 @@ def kernel_matrix(self, data_x, data_y, data_z):
kzy: centering kernel matrix for data_y (nxn)
"""
# normalize the data
data_x = stats.zscore(data_x, axis=0)
data_y = stats.zscore(data_y, axis=0)
data_z = stats.zscore(data_z, axis=0)
if np.var(data_x) == 0:
data_x -= np.average(data_x)
else:
data_x = stats.zscore(data_x, axis=0)

if np.var(data_y) == 0:
data_y -= np.average(data_y)
else:
data_y = stats.zscore(data_y, axis=0)

if np.var(data_z) == 0:
data_z -= np.average(data_z)
else:
data_z = stats.zscore(data_z, axis=0)

# concatenate x and z
data_x = np.concatenate((data_x, 0.5 * data_z), axis=1)
32 changes: 29 additions & 3 deletions causallearn/utils/cit.py
@@ -1,4 +1,5 @@
from math import log, sqrt
from collections.abc import Iterable

import numpy as np
from scipy.stats import chi2, norm
@@ -59,9 +60,18 @@ def _unique(column):
}

def kci(self, X, Y, condition_set):
Contributor:

@MarkDana

Currently X and Y here can be int / Iterable? This doesn't sound like good design --- if possible, we should have every variable type-checked.

Why not enforce X and Y here to be a list of int? That is the most general form, right? Then we wouldn't need these lines just for type-checking --- it's usually a good idea to keep code concise.

But I think this can be changed in your later PR, @MarkDana

Collaborator:

Yes, let me handle this in the later PR, or in the new KCI subclass.

if type(X) == int:
X = [X]
Comment on lines +63 to +64
Collaborator:

Better with an elif to ensure X is some iterable, and then X = list(X); otherwise self.data[:, X] does not support X as, e.g., a set, or a tuple with only one element.
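The suggested coercion can be sketched as a small standalone helper (the name as_index_list is hypothetical, not in the PR):

```python
def as_index_list(v):
    # coerce an int or any iterable of ints to a plain list, so that
    # self.data[:, v] works for int, list, set, and one-element tuples alike
    if isinstance(v, int):
        return [v]
    return list(v)
```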

elif type(X) != list:
X = list(X)
if type(Y) == int:
Y = [Y]
elif type(Y) != list:
Y = list(Y)

if len(condition_set) == 0:
return self.kci_ui.compute_pvalue(self.data[:, [X]], self.data[:, [Y]])[0]
return self.kci_ci.compute_pvalue(self.data[:, [X]], self.data[:, [Y]], self.data[:, list(condition_set)])[0]
return self.kci_ui.compute_pvalue(self.data[:, X], self.data[:, Y])[0]
Contributor:

Great, thanks!

QQ:

  1. What's the reason that the old code only supports one variable? Is there some special design? @MarkDana

Collaborator:

@tofuwen There is no special design in the old code for supporting only one variable. I just aligned it with the other tests (used in constraint-based methods) and forgot that KCI can take in multivariate unconditional variables.

I was just about to fix it. So thanks so much for your work @cogito233 !


Collaborator:

Also, for the cache key in https://github.com/cmu-phil/causal-learn/blob/ffe75f95c4003fa7e9d7d5f3bbec4ace90ed3a41/causallearn/utils/cit.py#L339, we'll need to handle X < Y for iterable X, Y, and also frozenset(i) as a hashable key.
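One way to build such a key, sketched under the assumption that X and Y may be ints or iterables of column indices (the helper name is hypothetical, not part of the PR):

```python
from collections.abc import Iterable

def kci_cache_key(X, Y, condition_set):
    # turn each endpoint into a frozenset so it is hashable, then order the
    # pair canonically so (X, Y) and (Y, X) map to the same cache entry
    i = frozenset(X) if isinstance(X, Iterable) else frozenset([X])
    j = frozenset(Y) if isinstance(Y, Iterable) else frozenset([Y])
    i, j = sorted((i, j), key=sorted)
    return (i, j, frozenset(condition_set))
```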

Contributor Author:

For https://github.com/cmu-phil/causal-learn/blob/ffe75f95c4003fa7e9d7d5f3bbec4ace90ed3a41/causallearn/utils/cit.py#L338, we'll need to handle the case of iterable X, Y.

I have no idea how to change the assert in the case of kernel CIT (concerning the whole or part of X being in the conditioning set). The other problems are already fixed.

Maybe $X \bot Y | X$ is also a valid expression?

Collaborator:

Wow, cool. You're so productive! Thanks for all of this!

I see your point. This line was intended for correctness checks in constraint-based methods.

As for the CI test itself, is X;Y|X valid? I don't know, actually - but the result is always expected to be "independent" (consider X|X as a degenerate constant).

How about we force X not in condition_set for X: int, and len(set(condition_set).intersection(X)) == 0 for X: Iterable?
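The check proposed above can be written once for both cases; a minimal sketch (hypothetical helper name, not in the PR):

```python
from collections.abc import Iterable

def check_not_conditioned_on(v, condition_set, name):
    # ensure the variable (an int or an iterable of ints) shares no
    # index with the conditioning set
    vs = set(v) if isinstance(v, Iterable) else {v}
    assert vs.isdisjoint(condition_set), f"{name} cannot be in condition_set."
```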

Contributor Author:

Done.

return self.kci_ci.compute_pvalue(self.data[:, X], self.data[:, Y], self.data[:, list(condition_set)])[0]

def fisherz(self, X, Y, condition_set):
"""
@@ -326,7 +336,23 @@ def __call__(self, X, Y, condition_set=None, *args):
else:
assert len(args) == 2, "Arguments other than skel and prt_m are provided for mc_fisherz."
if condition_set is None: condition_set = tuple()
assert X not in condition_set and Y not in condition_set, "X, Y cannot be in condition_set."

if type(X) == int and type(Y) == int:
assert X not in condition_set and Y not in condition_set, "X, Y cannot be in condition_set."
else:
if isinstance(X, Iterable):
assert len(set(condition_set).intersection(X)) == 0, "X cannot be in condition_set."
elif isinstance(X, int):
assert X not in condition_set, "X cannot be in condition_set."
else:
raise Exception("Undefined type of X, X should be int or Iterable")
if isinstance(Y, Iterable):
assert len(set(condition_set).intersection(Y)) == 0, "Y cannot be in condition_set."
elif isinstance(Y, int):
assert Y not in condition_set, "Y cannot be in condition_set."
else:
raise Exception("Undefined type of Y, Y should be int or Iterable")

i, j = (X, Y) if (X < Y) else (Y, X)
cache_key = (i, j, frozenset(condition_set))
