[Done] Add various kernels for GP module (#805)
* add various kernels

* add various kernel

* make lint

* fix small bug in isotropy

* lint

* add tests

* fix NaN bug at sqrt
fehiepsi authored and eb8680 committed Feb 28, 2018
1 parent a72f9ac commit 98752a9
Showing 10 changed files with 606 additions and 85 deletions.
40 changes: 38 additions & 2 deletions docs/source/contrib.gp.rst
@@ -18,7 +18,31 @@ Kernels
    :show-inheritance:
    :member-order: bysource

.. automodule:: pyro.contrib.gp.kernels.rbf
.. automodule:: pyro.contrib.gp.kernels.brownian
    :members:
    :undoc-members:
    :show-inheritance:
    :member-order: bysource

.. automodule:: pyro.contrib.gp.kernels.dot_product
    :members:
    :undoc-members:
    :show-inheritance:
    :member-order: bysource

.. automodule:: pyro.contrib.gp.kernels.isotropic
    :members:
    :undoc-members:
    :show-inheritance:
    :member-order: bysource

.. automodule:: pyro.contrib.gp.kernels.periodic
    :members:
    :undoc-members:
    :show-inheritance:
    :member-order: bysource

.. automodule:: pyro.contrib.gp.kernels.static
    :members:
    :undoc-members:
    :show-inheritance:
@@ -53,9 +77,21 @@ Models
    :undoc-members:
    :show-inheritance:
    :member-order: bysource

.. automodule:: pyro.contrib.gp.models.sgpr
    :members:
    :undoc-members:
    :show-inheritance:
    :member-order: bysource

.. automodule:: pyro.contrib.gp.models.vgp
    :members:
    :undoc-members:
    :show-inheritance:
    :member-order: bysource

.. automodule:: pyro.contrib.gp.models.svgp
    :members:
    :undoc-members:
    :show-inheritance:
    :member-order: bysource
7 changes: 6 additions & 1 deletion pyro/contrib/gp/kernels/__init__.py
@@ -1,6 +1,11 @@
from __future__ import absolute_import, division, print_function

from .brownian import Brownian
from .dot_product import DotProduct, Linear, Polynomial
from .isotropic import (Exponential, Isotropy, Matern12, Matern32, Matern52,
                        RationalQuadratic, RBF, SquaredExponential)
from .kernel import Kernel
from .rbf import RBF
from .periodic import Cosine, Periodic, SineSquaredExponential
from .static import Bias, Constant, WhiteNoise

# flake8: noqa
50 changes: 50 additions & 0 deletions pyro/contrib/gp/kernels/brownian.py
@@ -0,0 +1,50 @@
from __future__ import absolute_import, division, print_function

import torch
from torch.autograd import Variable
from torch.distributions import constraints
from torch.nn import Parameter

from .kernel import Kernel


class Brownian(Kernel):
    """
    This kernel corresponds to a two-sided Brownian motion (Wiener process):
    ``k(x, z) = min(|x|, |z|)`` if ``x.z >= 0`` and ``k(x, z) = 0`` otherwise.

    Note that the input dimension of this kernel must be 1.

    References:

    [1] `Theory and Statistical Applications of Stochastic Processes`,
    Yuliya Mishura, Georgiy Shevchenko
    """

    def __init__(self, input_dim, variance=None, active_dims=None, name="Brownian"):
        if input_dim != 1:
            raise ValueError("Input dimension for Brownian kernel must be 1.")
        super(Brownian, self).__init__(input_dim, active_dims, name)

        if variance is None:
            variance = torch.ones(1)
        self.variance = Parameter(variance)
        self.set_constraint("variance", constraints.positive)

    def forward(self, X, Z=None, diag=False):
        variance = self.get_param("variance")

        if Z is None:
            Z = X
        X = self._slice_input(X)
        if diag:
            return variance * X.abs().squeeze(1)

        Z = self._slice_input(Z)
        if X.size(1) != Z.size(1):
            raise ValueError("Inputs must have the same number of features.")

        Zt = Z.t()
        return torch.where(X.sign() == Zt.sign(),
                           variance * torch.min(X.abs(), Zt.abs()),
                           Variable(X.data.new(X.size(0), Z.size(0)).zero_()))
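
As a quick illustration (not part of this commit), the ``Brownian`` kernel above could be exercised as in the sketch below. The input values are made up, and it assumes kernel instances are callable because the ``Kernel`` base class is an ``nn.Module``; it is written against the ``Variable``-based API this commit targets.

import torch
from torch.autograd import Variable

from pyro.contrib.gp.kernels import Brownian

# Four one-dimensional inputs; the Brownian kernel requires input_dim == 1.
X = Variable(torch.Tensor([[-2.0], [-0.5], [1.0], [3.0]]))

k = Brownian(input_dim=1, variance=torch.ones(1))
K = k(X)                  # full 4 x 4 covariance matrix
K_diag = k(X, diag=True)  # diagonal only, equal to variance * |x|

# Entries for inputs on opposite sides of the origin are zero,
# e.g. K[0, 2] is 0 because -2.0 and 1.0 have different signs.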
82 changes: 82 additions & 0 deletions pyro/contrib/gp/kernels/dot_product.py
@@ -0,0 +1,82 @@
from __future__ import absolute_import, division, print_function

import torch
from torch.distributions import constraints
from torch.nn import Parameter

from .kernel import Kernel


class DotProduct(Kernel):
    """
    Base kernel for Linear and Polynomial kernels.

    :param torch.Tensor variance: Variance parameter which plays the role of scaling.
    """

    def __init__(self, input_dim, variance=None, active_dims=None, name=None):
        super(DotProduct, self).__init__(input_dim, active_dims, name)

        if variance is None:
            variance = torch.ones(1)
        self.variance = Parameter(variance)
        self.set_constraint("variance", constraints.positive)

    def _dot_product(self, X, Z=None, diag=False):
        """
        Returns ``X.Z``.
        """
        if Z is None:
            Z = X
        X = self._slice_input(X)
        if diag:
            return (X ** 2).sum(-1)

        Z = self._slice_input(Z)
        if X.size(1) != Z.size(1):
            raise ValueError("Inputs must have the same number of features.")

        return X.matmul(Z.t())


class Linear(DotProduct):
    """
    Implementation of the Linear kernel. Gaussian Process regression with a linear kernel
    is equivalent to linear regression.

    Note that here we implement the homogeneous version. To use the inhomogeneous version,
    consider using the Polynomial kernel with ``degree=1`` or combining with a Bias kernel.
    """

    def __init__(self, input_dim, variance=None, active_dims=None, name="Linear"):
        super(Linear, self).__init__(input_dim, variance, active_dims, name)

    def forward(self, X, Z=None, diag=False):
        variance = self.get_param("variance")
        return variance * self._dot_product(X, Z, diag)


class Polynomial(DotProduct):
    """
    Implementation of the Polynomial kernel: ``k(x, z) = (bias + x.z)^d``.

    :param torch.Tensor bias: Bias parameter for this kernel. Should be positive.
    :param int degree: Degree of this polynomial.
    """

    def __init__(self, input_dim, variance=None, bias=None, degree=1, active_dims=None, name="Polynomial"):
        super(Polynomial, self).__init__(input_dim, variance, active_dims, name)

        if bias is None:
            bias = torch.ones(1)
        self.bias = Parameter(bias)
        self.set_constraint("bias", constraints.positive)

        if degree < 1:
            raise ValueError("Degree for Polynomial kernel should be a positive integer.")
        self.degree = degree

    def forward(self, X, Z=None, diag=False):
        variance = self.get_param("variance")
        bias = self.get_param("bias")
        return variance * ((bias + self._dot_product(X, Z, diag)) ** self.degree)
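
A minimal usage sketch (not in the diff) for the dot-product kernels above, relying only on the default parameters ``variance = bias = 1`` and on kernel instances being callable ``nn.Module``s:

import torch
from torch.autograd import Variable

from pyro.contrib.gp.kernels import Linear, Polynomial

X = Variable(torch.Tensor([[1.0, 2.0],
                           [3.0, 4.0]]))

# Homogeneous linear kernel: K = variance * X.X^T
linear = Linear(input_dim=2)
K_linear = linear(X)

# Degree-2 polynomial kernel: K = variance * (bias + X.X^T)^2
poly = Polynomial(input_dim=2, degree=2)
K_poly = poly(X)

# diag=True returns only the diagonal of the covariance matrix.
K_poly_diag = poly(X, diag=True)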
191 changes: 191 additions & 0 deletions pyro/contrib/gp/kernels/isotropic.py
@@ -0,0 +1,191 @@
from __future__ import absolute_import, division, print_function

import torch
from torch.distributions import constraints
from torch.nn import Parameter

from .kernel import Kernel


def _torch_sqrt(x, eps=1e-18):
    """
    A convenient function to avoid the NaN gradient issue of ``torch.sqrt`` at 0.
    """
    # Ref: https://github.com/pytorch/pytorch/issues/2421
    return (x + eps).sqrt()


class Isotropy(Kernel):
    """
    Base kernel for a family of isotropic covariance functions, each of which is a
    function of the distance ``r = |x-z|``.

    By default, the parameter ``lengthscale`` has size 1. To use the
    anisotropic version (a different lengthscale for each dimension),
    make sure that ``lengthscale`` has size equal to ``input_dim``.

    :param torch.Tensor variance: Variance parameter of this kernel.
    :param torch.Tensor lengthscale: Length scale parameter of this kernel.
    """

    def __init__(self, input_dim, variance=None, lengthscale=None, active_dims=None, name=None):
        super(Isotropy, self).__init__(input_dim, active_dims, name)

        if variance is None:
            variance = torch.ones(1)
        self.variance = Parameter(variance)
        self.set_constraint("variance", constraints.positive)

        if lengthscale is None:
            lengthscale = torch.ones(1)
        self.lengthscale = Parameter(lengthscale)
        self.set_constraint("lengthscale", constraints.positive)

    def _square_scaled_dist(self, X, Z=None):
        """
        Returns ``||(X-Z)/lengthscale||^2``.
        """
        if Z is None:
            Z = X
        X = self._slice_input(X)
        Z = self._slice_input(Z)
        if X.size(1) != Z.size(1):
            raise ValueError("Inputs must have the same number of features.")

        lengthscale = self.get_param("lengthscale")
        scaled_X = X / lengthscale
        scaled_Z = Z / lengthscale
        X2 = (scaled_X ** 2).sum(1, keepdim=True)
        Z2 = (scaled_Z ** 2).sum(1, keepdim=True)
        XZ = scaled_X.matmul(scaled_Z.t())
        r2 = X2 - 2 * XZ + Z2.t()
        return r2

    def _scaled_dist(self, X, Z=None):
        """
        Returns ``||(X-Z)/lengthscale||``.
        """
        return _torch_sqrt(self._square_scaled_dist(X, Z))

    def _diag(self, X):
        """
        Calculates the diagonal part of the covariance matrix on the active dimensions.
        """
        variance = self.get_param("variance")
        return variance.expand(X.size(0))


class RBF(Isotropy):
    """
    Implementation of the Radial Basis Function kernel: ``exp(-0.5 * r^2 / l^2)``.
    """

    def __init__(self, input_dim, variance=None, lengthscale=None, active_dims=None, name="RBF"):
        super(RBF, self).__init__(input_dim, variance, lengthscale, active_dims, name)

    def forward(self, X, Z=None, diag=False):
        if diag:
            return self._diag(X)

        variance = self.get_param("variance")
        r2 = self._square_scaled_dist(X, Z)
        return variance * torch.exp(-0.5 * r2)


class SquaredExponential(RBF):
    """
    Squared Exponential is another name for the RBF kernel.
    """

    def __init__(self, input_dim, variance=None, lengthscale=None, active_dims=None,
                 name="SquaredExponential"):
        super(SquaredExponential, self).__init__(input_dim, variance, lengthscale, active_dims, name)


class RationalQuadratic(Isotropy):
    """
    Implementation of the Rational Quadratic kernel:
    ``(1 + 0.5 * r^2 / (alpha * l^2))^(-alpha)``.

    :param torch.Tensor scale_mixture: Scale mixture (alpha) parameter of this kernel.
        Should have size 1.
    """

    def __init__(self, input_dim, variance=None, lengthscale=None, scale_mixture=None, active_dims=None,
                 name="RationalQuadratic"):
        super(RationalQuadratic, self).__init__(input_dim, variance, lengthscale, active_dims, name)

        if scale_mixture is None:
            scale_mixture = torch.ones(1)
        self.scale_mixture = Parameter(scale_mixture)
        self.set_constraint("scale_mixture", constraints.positive)

    def forward(self, X, Z=None, diag=False):
        if diag:
            return self._diag(X)

        variance = self.get_param("variance")
        scale_mixture = self.get_param("scale_mixture")
        r2 = self._square_scaled_dist(X, Z)
        return variance * (1 + (0.5 / scale_mixture) * r2).pow(-scale_mixture)


class Exponential(Isotropy):
    """
    Implementation of the Exponential kernel: ``exp(-r/l)``.
    """

    def __init__(self, input_dim, variance=None, lengthscale=None, active_dims=None, name="Exponential"):
        super(Exponential, self).__init__(input_dim, variance, lengthscale, active_dims, name)

    def forward(self, X, Z=None, diag=False):
        if diag:
            return self._diag(X)

        variance = self.get_param("variance")
        r = self._scaled_dist(X, Z)
        return variance * torch.exp(-r)


class Matern12(Exponential):
    """
    Another name for the Exponential kernel.
    """
    def __init__(self, input_dim, variance=None, lengthscale=None, active_dims=None, name="Matern12"):
        super(Matern12, self).__init__(input_dim, variance, lengthscale, active_dims, name)


class Matern32(Isotropy):
    """
    Implementation of the Matern32 kernel: ``(1 + sqrt(3) * r/l) * exp(-sqrt(3) * r/l)``.
    """

    def __init__(self, input_dim, variance=None, lengthscale=None, active_dims=None, name="Matern32"):
        super(Matern32, self).__init__(input_dim, variance, lengthscale, active_dims, name)

    def forward(self, X, Z=None, diag=False):
        if diag:
            return self._diag(X)

        variance = self.get_param("variance")
        r = self._scaled_dist(X, Z)
        sqrt3_r = 3**0.5 * r
        return variance * (1 + sqrt3_r) * torch.exp(-sqrt3_r)


class Matern52(Isotropy):
    """
    Implementation of the Matern52 kernel:
    ``(1 + sqrt(5) * r/l + 5/3 * r^2 / l^2) * exp(-sqrt(5) * r/l)``.
    """

    def __init__(self, input_dim, variance=None, lengthscale=None, active_dims=None, name="Matern52"):
        super(Matern52, self).__init__(input_dim, variance, lengthscale, active_dims, name)

    def forward(self, X, Z=None, diag=False):
        if diag:
            return self._diag(X)

        variance = self.get_param("variance")
        r2 = self._square_scaled_dist(X, Z)
        r = _torch_sqrt(r2)
        sqrt5_r = 5**0.5 * r
        return variance * (1 + sqrt5_r + (5/3) * r2) * torch.exp(-sqrt5_r)
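
The ``_torch_sqrt`` helper above exists because the gradient of ``torch.sqrt`` diverges at 0, which happens on the diagonal of the distance matrix (``r = 0``) and is what the "fix NaN bug at sqrt" item in the commit message refers to. A standalone sketch of the problem (not part of the diff, written against the ``Variable`` API used here):

import torch
from torch.autograd import Variable

# On the diagonal of a kernel matrix the squared distance is exactly zero.
r2 = Variable(torch.zeros(1), requires_grad=True)

# Plain sqrt: this Matern32-shaped expression back-propagates a NaN, because
# dk/dr is 0 at r = 0 while dr/dr2 is infinite, and 0 * inf gives NaN.
r = torch.sqrt(r2)
k = (1 + 3 ** 0.5 * r) * torch.exp(-3 ** 0.5 * r)
k.backward()
print(r2.grad)  # nan

# With the small epsilon added by _torch_sqrt the gradient stays finite.
r2.grad.data.zero_()
r = torch.sqrt(r2 + 1e-18)
k = (1 + 3 ** 0.5 * r) * torch.exp(-3 ** 0.5 * r)
k.backward()
print(r2.grad)  # finite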

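Not part of the diff: a short sketch of how the isotropic kernels above might be exercised together. The input data is random, the parameter values are the defaults, and kernel instances are assumed to be callable ``nn.Module``s.

import torch
from torch.autograd import Variable

from pyro.contrib.gp.kernels import RBF, Matern52, RationalQuadratic

X = Variable(torch.randn(5, 3))

# One lengthscale per input dimension makes the kernel anisotropic.
rbf = RBF(input_dim=3, variance=torch.ones(1), lengthscale=torch.ones(3))
matern = Matern52(input_dim=3)
rq = RationalQuadratic(input_dim=3, scale_mixture=torch.ones(1))

for kernel in (rbf, matern, rq):
    K = kernel(X)                  # 5 x 5 covariance matrix
    K_diag = kernel(X, diag=True)  # length-5 vector, equal to the variance
    assert K.size() == (5, 5) and K_diag.size() == (5,)
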