Add drop_first option to HSGPPeriodic #7115

Draft · wants to merge 8 commits into base: main
2 changes: 1 addition & 1 deletion pymc/gp/cov.py
@@ -185,7 +185,7 @@ def n_dims(self) -> int:
def _slice(self, X, Xs=None):
xdims = X.shape[-1]
if isinstance(xdims, Variable):
[xdims] = constant_fold([xdims])
[xdims] = constant_fold([xdims], raise_not_constant=False)
if self.input_dim != xdims:
warnings.warn(
f"Only {self.input_dim} column(s) out of {xdims} are"
57 changes: 34 additions & 23 deletions pymc/gp/hsgp_approx.py
@@ -70,18 +70,17 @@
Xs: TensorLike,
period: TensorLike,
m: int,
tl: ModuleType = np,
):
"""
Calculate basis vectors for the cosine series expansion of the periodic covariance function.
These are derived from the Taylor series representation of the covariance.
"""
w0 = (2 * np.pi) / period # angular frequency defining the periodicity
m1 = tl.tile(w0 * Xs, m)
m2 = tl.diag(tl.arange(0, m, 1))
w0 = (2 * pt.pi) / period # angular frequency defining the periodicity
m1 = pt.tile(w0 * Xs, m)
m2 = pt.diag(pt.arange(0, m, 1))

mw0x = m1 @ m2
phi_cos = tl.cos(mw0x)
phi_sin = tl.sin(mw0x)
phi_cos = pt.cos(mw0x)
phi_sin = pt.sin(mw0x)

return phi_cos, phi_sin
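
As a rough sanity check of what this helper returns, here is a small sketch (not part of the PR; it assumes the pt-based version of calc_basis_periodic shown above):

```python
import numpy as np

from pymc.gp.hsgp_approx import calc_basis_periodic

Xs = np.linspace(0, 10, 100)[:, None]  # (100, 1) column of (centered) inputs
phi_cos, phi_sin = calc_basis_periodic(Xs, period=1.0, m=20)
# Column j of each basis is cos(j * w0 * Xs) / sin(j * w0 * Xs), so both are (100, 20).
print(phi_cos.eval().shape, phi_sin.eval().shape)  # (100, 20) (100, 20)
```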


@@ -473,11 +472,15 @@
self,
m: int,
scale: Optional[Union[float, TensorLike]] = 1.0,
drop_first=False,
Contributor

Remember to add it to the docstrings ;) Also add the type annotation: `drop_first: bool = False`.

*,
mean_func: Mean = Zero(),
cov_func: Periodic,
):
arg_err_msg = "`m` must be a positive integer as the `Periodic` kernel approximation is only implemented for 1-dimensional case."
arg_err_msg = (

"`m` must be a positive integer as the `Periodic` kernel approximation is "
"only implemented for 1-dimensional case."
)

if not isinstance(m, int):
raise ValueError(arg_err_msg)
@@ -487,7 +490,8 @@

if not isinstance(cov_func, Periodic):
raise ValueError(
"`cov_func` must be an instance of a `Periodic` kernel only. Use the `scale` parameter to control the variance."
"`cov_func` must be an instance of a `Periodic` kernel only. Use the `scale` "
"parameter to control the variance."
)

if cov_func.n_dims > 1:
@@ -497,6 +501,7 @@

self._m = m
self.scale = scale
self.drop_first = drop_first


super().__init__(mean_func=mean_func, cov_func=cov_func)

@@ -576,8 +581,7 @@
ppc = pm.sample_posterior_predictive(idata, var_names=["f"])
"""
Xs, _ = self.cov_func._slice(Xs)

phi_cos, phi_sin = calc_basis_periodic(Xs, self.cov_func.period, self._m, tl=pt)
phi_cos, phi_sin = calc_basis_periodic(Xs, self.cov_func.period, self._m)

J = pt.arange(0, self._m, 1)
# rescale basis coefficients by the sqrt variance term
psd = self.scale * self.cov_func.power_spectral_density_approx(J)
@@ -602,21 +606,27 @@
(phi_cos, phi_sin), psd = self.prior_linearized(X - self._X_mean)

m = self._m
self._beta = pm.Normal(f"{name}_hsgp_coeffs_", size=(m * 2 - 1))
# The first eigenfunction for the sine component is zero
# and so does not contribute to the approximation.
f = (
self.mean_func(X)
+ phi_cos @ (psd * self._beta[:m]) # type: ignore
+ phi_sin[..., 1:] @ (psd[1:] * self._beta[m:]) # type: ignore
)

self.f = pm.Deterministic(name, f, dims=dims)
if self.drop_first:

Contributor

I wonder if we would like to inform the user when drop_first=False and the first basis vectors are zeros or ones.

Contributor Author

Agree, this would be very cool. The basis vectors are made in pytensor. I tried using constant_fold to automatically check whether the first sine term is all zeros and drop it automatically, but since X can be mutable it didn't work.
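
A rough sketch of the kind of check being described (hypothetical helper, not from the PR; the fold only succeeds when Xs reduces to a constant, which is exactly the limitation mentioned above):

```python
import numpy as np
import pytensor.tensor as pt

from pymc.pytensorf import constant_fold


def first_sin_basis_is_zero(Xs, period, m):
    """Try to prove symbolically that the first sine basis column is identically zero."""
    w0 = (2 * pt.pi) / period
    mw0x = pt.tile(w0 * Xs, m) @ pt.diag(pt.arange(0, m, 1))
    [first_col] = constant_fold([pt.sin(mw0x)[:, 0]], raise_not_constant=False)
    # With mutable/symbolic X the fold returns a symbolic variable, so the check
    # is inconclusive and we conservatively report False.
    return isinstance(first_col, np.ndarray) and np.allclose(first_col, 0.0)
```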

# Drop first sine term (all zeros), drop first cos term (all ones)
beta = pm.Normal(f"{name}_hsgp_coeffs_", size=(m * 2) - 2)
Contributor

Which one is the old case with size=(m * 2 - 1)? With m * 2, you might be accidentally using the same beta twice, or have some redundancy.

Contributor Author (@bwengals, Jan 24, 2024)

Neither: now drop_first=True removes both the first sine and first cosine terms, and drop_first=False keeps all terms. Whenever the first sine term is all zeros, the first cosine term is all ones. My reasoning was that wanting to keep the quasi-intercept cosine term while dropping the all-zeros sine term was more of an "experts only" move, so someone can do that with prior_linearized. Having all three options as an argument felt kind of cumbersome? What do you think?

> With m*2, you might be accidentally using the same beta twice, or have some redundancy

True, need to double check this is all good, especially if someone uses an odd number m.
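
On the odd-m concern specifically, here is a tiny check of the drop_first=True slicing in the diff below (a sketch, using an arange in place of the Normal coefficients):

```python
import pytensor.tensor as pt

m = 5  # deliberately odd
beta = pt.arange(2 * m - 2)  # stand-in for the (2m - 2)-long coefficient vector
beta_cos = pt.concatenate(([0.0], beta[: m - 1]))
beta_sin = pt.concatenate(([0.0], beta[m - 1 :]))
print(beta_cos.eval())  # [0. 0. 1. 2. 3.]
print(beta_sin.eval())  # [0. 4. 5. 6. 7.] -> no coefficient is reused, even for odd m
```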

Contributor

I'm okay with the drop_first=True case, but why would anyone want the drop_first=False case with the first sine term when the first eigenfunction is zero? Doesn't that always make one of the betas redundant? I would just have the drop_first=True case and the original case as drop_first=False. I was just following the sums in the maths:

[Screenshot: the cosine-series sums from appendix B of the original paper]

If you go against the maths, please add some comments to explain why. Otherwise pymc devs in 2034 might get confused when they look at appendix B in the original paper.
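
Roughly, the expansion being referenced has the form below (paraphrased; the exact definition of the spectral coefficients, which power_spectral_density_approx computes, is not reproduced here):

```latex
f(x) \approx \sum_{j=0}^{m-1} \tilde{q}_j
    \left[ \beta_j^{(1)} \cos(j\, w_0 x) + \beta_j^{(2)} \sin(j\, w_0 x) \right],
\qquad w_0 = \frac{2\pi}{T},
\qquad \beta_j^{(1)}, \beta_j^{(2)} \sim \mathcal{N}(0, 1).
```

At j = 0 the sine basis is identically zero, so its coefficient is redundant, and the cosine basis is identically one, which is the quasi-intercept being discussed.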

Member

> Otherwise pymc devs in 2034

Love the optimism!

Contributor Author (@bwengals, Jan 24, 2024)

> but why would anyone want the drop_first=False case with the first sine term when the first eigenfunction is zero? Doesn't that always make one of the betas redundant?

Yup, it does. OK, I made a dumb mistake. When sin(x) = 0, cos(x) is either 1 or -1 (it's been a while since trig class), not cos(x) always equals 1 like I wrote. So it only sometimes happens that the cosine term is all ones; it could be all -1's or a mix of 1's and -1's. So having drop_first=True always drop both the first sine and cosine terms is a bad idea.

To see what I'm trying to explain (sorry kind of poorly), try running the code below, but change period from 1 to 2 to 3.

  • period = 1: sine term all zeros (gotta drop it), cosine term all ones (should probably drop it, or at least be warned? If you already have an intercept in your model the sampler will bog down and you should remove one of them.)
  • period = 2: sine term all zeros (gotta drop it), cosine term alternates -1 and 1 (gotta keep it)
  • period = 3: both terms oscillate (keep both... right? though you point out @theorashid that the math says drop the sine term, so I'm not sure)

Samples of the prior predictive for the gp look good in all cases.

import pymc as pm
import numpy as np
import matplotlib.pyplot as plt
import pytensor
import pytensor.tensor as pt

X = np.linspace(0, 10, 1000)[:, None]

with pm.Model() as model:
    period = 1
    scale = pm.HalfNormal("scale", 10)
    cov_func = pm.gp.cov.Periodic(1, period=period, ls=1.0)

    m = 200
    gp = pm.gp.HSGPPeriodic(m=m, scale=scale, cov_func=cov_func)
    X_mean = np.mean(X, axis=0)
    X = pm.MutableData("X", X)
    Xs = X - X_mean
    
    # FROM pm.gp.hsgp_approx.calc_basis_periodic(Xs, cov_func.period, m=m)
    w0 = (2 * pt.pi) / period  # angular frequency defining the periodicity
    m1 = pt.tile(w0 * Xs, m)
    m2 = pt.diag(pt.arange(0, m, 1))
    mw0x = m1 @ m2

    (phi_cos, phi_sin), psd = gp.prior_linearized(Xs=Xs)

    plt.plot(phi_cos[0, :].eval())
    plt.plot(phi_sin[0, :].eval())

   beta = pm.Normal(f"beta", size=m * 2)
    beta_cos = beta[:m]
    beta_sin = beta[m:]

    cos_term = phi_cos @ (psd * beta_cos)
    sin_term = phi_sin @ (psd * beta_sin)
    f = pm.Deterministic("f", cos_term + sin_term)

    #prior = pm.sample_prior_predictive()
    #import arviz as az
    #f = az.extract(prior.prior, var_names="f")
    #plt.figure(figsize=(10, 4))
    #plt.plot(X.eval().flatten(), f.mean(dim="sample").data)
    #plt.xticks(np.arange(10));

Contributor

Hmm, I stared at the equation again and I can't see how sin(0 * w0 * x) can be anything other than 0. Equally, cos(0 * w0 * x) is always 1, so there's always an intercept term, I suppose.

mw0x there has shape (1000, 200) (i.e. (X.size, m)), so phi_cos and phi_sin both have shape (1000, 200). By plotting plt.plot(phi_cos[0, :].eval()) you are plotting something of shape 200, i.e. m: the first entry of every basis vector, rather than plt.plot(phi_cos[:, 0].eval()), which is the contribution of the first (0th) basis across X. I find the latter to be 0 for sine and 1 for cosine throughout.
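
A minimal NumPy check of that claim (a standalone sketch mirroring the basis construction, not part of the PR):

```python
import numpy as np

period, m = 1.0, 200
Xs = np.linspace(0, 10, 1000)[:, None] - 5.0  # centered inputs, shape (1000, 1)

w0 = 2 * np.pi / period
mw0x = np.tile(w0 * Xs, m) @ np.diag(np.arange(m))  # (1000, m); column j is j * w0 * Xs
phi_cos, phi_sin = np.cos(mw0x), np.sin(mw0x)

print(np.allclose(phi_sin[:, 0], 0.0))  # True: the 0th sine basis is identically zero
print(np.allclose(phi_cos[:, 0], 1.0))  # True: the 0th cosine basis is identically one
```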

Contributor Author

> you are plotting something of shape 200

Yuuuuuup, you're right. I plotted it wrong, so sorry for spreading my confusion around! It looks like I had found that the cosine term is sometimes all ones, so the issue with the extra intercept is real, but then I plotted it wrong and proceeded to draw lots of incorrect conclusions from there.

What the code should do:

  • Sine term is always all zeros, always drop (as the code originally did)
  • Cosine term is sometimes all ones. Allow user to optionally drop it.

Is that right @theorashid?

Contributor

Sounds right to me. Maybe remove the DeprecationWarning from the HSGP class too, if you want to keep drop_first consistent after all.
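
For what it's worth, a rough sketch of what that proposal could look like (hypothetical names; not what the diff below implements, which drops or keeps both first terms together):

```python
import pymc as pm
import pytensor.tensor as pt


def periodic_coeffs(name, m, drop_intercept=False):
    """Always drop the all-zeros sine_0 term; optionally drop the all-ones cos_0 term."""
    n_cos = m - 1 if drop_intercept else m
    beta = pm.Normal(f"{name}_hsgp_coeffs_", size=n_cos + (m - 1))
    if drop_intercept:
        beta_cos = pt.concatenate(([0.0], beta[: m - 1]))
    else:
        beta_cos = beta[:n_cos]
    beta_sin = pt.concatenate(([0.0], beta[n_cos:]))  # sine_0 never contributes
    return beta_cos, beta_sin


with pm.Model():
    beta_cos, beta_sin = periodic_coeffs("f", m=10, drop_intercept=True)
```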

self._beta_cos = pt.concatenate(([0.0], beta[: m - 1]))
self._beta_sin = pt.concatenate(([0.0], beta[m - 1 :]))


else:
# Keep all terms
beta = pm.Normal(f"{name}_hsgp_coeffs_", size=m * 2)
self._beta_cos = beta[:m]
self._beta_sin = beta[m:]


cos_term = phi_cos @ (psd * self._beta_cos)
sin_term = phi_sin @ (psd * self._beta_sin)
self.f = pm.Deterministic(name, self.mean_func(X) + cos_term + sin_term, dims=dims)

return self.f

def _build_conditional(self, Xnew):
try:
beta, X_mean = self._beta, self._X_mean
X_mean = self._X_mean


except AttributeError:
raise ValueError(
@@ -625,14 +635,15 @@

Xnew, _ = self.cov_func._slice(Xnew)

phi_cos, phi_sin = calc_basis_periodic(Xnew - X_mean, self.cov_func.period, self._m, tl=pt)
phi_cos, phi_sin = calc_basis_periodic(Xnew - X_mean, self.cov_func.period, self._m)

m = self._m
J = pt.arange(0, m, 1)
# rescale basis coefficients by the sqrt variance term
psd = self.scale * self.cov_func.power_spectral_density_approx(J)

phi = phi_cos @ (psd * beta[:m]) + phi_sin[..., 1:] @ (psd[1:] * beta[m:])
return self.mean_func(Xnew) + phi
cos_term = phi_cos @ (psd * self._beta_cos)
sin_term = phi_sin @ (psd * self._beta_sin)
return self.mean_func(Xnew) + cos_term + sin_term


def conditional(self, name: str, Xnew: TensorLike, dims: Optional[str] = None): # type: ignore
R"""
55 changes: 33 additions & 22 deletions tests/gp/test_hsgp_approx.py
@@ -229,62 +229,73 @@ def test_conditional(self, model, cov_func, X1, parameterization):

class TestHSGPPeriodic(_BaseFixtures):
def test_parametrization(self):
Contributor

To add: some drop_first tests, e.g. along the lines of the sketch below.
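
One possible shape for such a test, meant to sit inside TestHSGPPeriodic (a sketch only; the fixtures follow this file, the expected sizes follow the branch logic in the diff above, and pm.draw is used just to read off the coefficient vector's length):

```python
    @pytest.mark.parametrize("drop_first", [True, False])
    def test_prior_drop_first_coeff_size(self, model, X1, drop_first):
        m = 17  # deliberately odd to exercise the slicing
        cov_func = pm.gp.cov.Periodic(1, period=1, ls=1.0)
        with model:
            gp = pm.gp.HSGPPeriodic(m=m, scale=1.0, drop_first=drop_first, cov_func=cov_func)
            gp.prior("f", X=X1)
        # drop_first=True drops the first sine and cosine coefficients: 2m - 2 instead of 2m.
        n_coeffs = pm.draw(model["f_hsgp_coeffs_"]).shape[0]
        assert n_coeffs == (2 * m - 2 if drop_first else 2 * m)
```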

err_msg = "`m` must be a positive integer as the `Periodic` kernel approximation is only implemented for 1-dimensional case."
err_msg = (
"`m` must be a positive integer as the `Periodic` kernel approximation is only "
"implemented for 1-dimensional case."
)

with pytest.raises(ValueError, match=err_msg):
# `m` must be a positive integer, not a list
cov_func = pm.gp.cov.Periodic(1, period=1, ls=0.1)
cov_func = pm.gp.cov.Periodic(1, period=1, ls=1.0)
pm.gp.HSGPPeriodic(m=[500], cov_func=cov_func)

with pytest.raises(ValueError, match=err_msg):
# `m` must be a positive integer
cov_func = pm.gp.cov.Periodic(1, period=1, ls=0.1)
cov_func = pm.gp.cov.Periodic(1, period=1, ls=1.0)
pm.gp.HSGPPeriodic(m=-1, cov_func=cov_func)

with pytest.raises(
ValueError,
match="`cov_func` must be an instance of a `Periodic` kernel only. Use the `scale` parameter to control the variance.",
match=(
"`cov_func` must be an instance of a `Periodic` kernel only. Use the `scale` "
"parameter to control the variance."
),
):
# `cov_func` must be `Periodic` only
cov_func = 5.0 * pm.gp.cov.Periodic(1, period=1, ls=0.1)
cov_func = 5.0 * pm.gp.cov.Periodic(1, period=1, ls=1.0)
pm.gp.HSGPPeriodic(m=500, cov_func=cov_func)

with pytest.raises(
ValueError,
match="HSGP approximation for `Periodic` kernel only implemented for 1-dimensional case.",
match=(
"HSGP approximation for `Periodic` kernel only implemented for 1-dimensional case."
),
):
cov_func = pm.gp.cov.Periodic(2, period=1, ls=[1, 2])
pm.gp.HSGPPeriodic(m=500, scale=0.5, cov_func=cov_func)

@pytest.mark.parametrize("cov_func", [pm.gp.cov.Periodic(1, period=1, ls=1)])
@pytest.mark.parametrize("eta", [100.0])
@pytest.mark.xfail(
reason="For `pm.gp.cov.Periodic`, this test does not pass.\
The mmd is around `0.0468`.\
The test passes more often when subtracting the mean from the samples.\
It might be that the period is slightly off for the approximate power spectral density.\
See https://github.com/pymc-devs/pymc/pull/6877/ for the full discussion."
)
@pytest.mark.parametrize("eta", [2.0])
# @pytest.mark.xfail(
# reason="For `pm.gp.cov.Periodic`, this test does not pass.\
# The mmd is around `0.0468`.\
# The test passes more often when subtracting the mean from the samples.\
# It might be that the period is slightly off for the approximate power spectral density.\
# See https://github.com/pymc-devs/pymc/pull/6877/ for the full discussion."
# )
def test_prior(self, model, cov_func, eta, X1, rng):
"""Compare HSGPPeriodic prior to unapproximated GP prior, pm.gp.Latent. Draw samples from the
prior and compare them using MMD two sample test.
"""Compare HSGPPeriodic prior to unapproximated GP prior, pm.gp.Latent. Draw samples from
the prior and compare them using MMD two sample test.
"""
with model:
hsgp = pm.gp.HSGPPeriodic(m=200, scale=eta, cov_func=cov_func)
hsgp = pm.gp.HSGPPeriodic(m=200, scale=eta, drop_first=False, cov_func=cov_func)
f1 = hsgp.prior("f1", X=X1)

gp = pm.gp.Latent(cov_func=eta**2 * cov_func)
f2 = gp.prior("f2", X=X1)

idata = pm.sample_prior_predictive(samples=1000, random_seed=rng)

samples1 = az.extract(idata.prior["f1"])["f1"].values.T
samples2 = az.extract(idata.prior["f2"])["f2"].values.T
samples1 = az.extract(idata.prior, var_names="f1").values.T
samples2 = az.extract(idata.prior, var_names="f2").values.T

h0, mmd, critical_value, reject = two_sample_test(
samples1, samples2, n_sims=500, alpha=0.01
)
assert not reject, f"H0 was rejected, {mmd} even though HSGP and GP priors should match."
assert not reject, (
f"H0 was rejected, MMD {mmd:.3f} > {critical_value:.3f} even though HSGP and GP priors "
"should match."
)

@pytest.mark.parametrize("cov_func", [pm.gp.cov.Periodic(1, period=1, ls=1)])
def test_conditional_periodic(self, model, cov_func, X1):
Expand All @@ -299,8 +310,8 @@ def test_conditional_periodic(self, model, cov_func, X1):

idata = pm.sample_prior_predictive(samples=1000)

samples1 = az.extract(idata.prior["f"])["f"].values.T
samples2 = az.extract(idata.prior["fc"])["fc"].values.T
samples1 = az.extract(idata.prior, var_names="f").values.T
samples2 = az.extract(idata.prior, var_names="fc").values.T

h0, mmd, critical_value, reject = two_sample_test(
samples1, samples2, n_sims=500, alpha=0.01
Expand Down