[Done] Add various kernels for GP module (#805)
* add various kernels
* add various kernel
* make lint
* small bug at isotropy
* lint
* add tests
* fix bug NaN at sqrt
Showing 10 changed files with 606 additions and 85 deletions.
kernels/__init__.py (path inferred from the relative imports below):

@@ -1,6 +1,11 @@
 from __future__ import absolute_import, division, print_function

+from .brownian import Brownian
+from .dot_product import DotProduct, Linear, Polynomial
+from .isotropic import (Exponential, Isotropy, Matern12, Matern32, Matern52,
+                        RationalQuadratic, RBF, SquaredExponential)
 from .kernel import Kernel
-from .rbf import RBF
+from .periodic import Cosine, Periodic, SineSquaredExponential
+from .static import Bias, Constant, WhiteNoise

 # flake8: noqa
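With these re-exports in place, the new kernels become importable directly from the package. A usage sketch; the full package path ``pyro.contrib.gp.kernels`` is an assumption, since this truncated diff does not show it:

# Usage sketch -- the pyro.contrib.gp.kernels path is an assumption.
from pyro.contrib.gp.kernels import Brownian, Matern52, Polynomial, RBF

k = RBF(input_dim=2)                    # isotropic RBF over 2-dimensional inputs
k2 = Polynomial(input_dim=2, degree=3)  # inhomogeneous cubic kernel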
kernels/brownian.py (new file):

@@ -0,0 +1,50 @@
from __future__ import absolute_import, division, print_function

import torch
from torch.autograd import Variable
from torch.distributions import constraints
from torch.nn import Parameter

from .kernel import Kernel


class Brownian(Kernel):
    """
    This kernel corresponds to a two-sided Brownian motion (Wiener process):
    ``k(x, z) = min(|x|, |z|)`` if ``x.z >= 0`` and ``k(x, z) = 0`` otherwise.

    Note that the input dimension of this kernel must be 1.

    References:

    [1] `Theory and Statistical Applications of Stochastic Processes`,
    Yuliya Mishura, Georgiy Shevchenko
    """

    def __init__(self, input_dim, variance=None, active_dims=None, name="Brownian"):
        if input_dim != 1:
            raise ValueError("Input dimension for Brownian kernel must be 1.")
        super(Brownian, self).__init__(input_dim, active_dims, name)

        if variance is None:
            variance = torch.ones(1)
        self.variance = Parameter(variance)
        self.set_constraint("variance", constraints.positive)

    def forward(self, X, Z=None, diag=False):
        variance = self.get_param("variance")

        if Z is None:
            Z = X
        X = self._slice_input(X)
        if diag:
            return variance * X.abs().squeeze(1)

        Z = self._slice_input(Z)
        if X.size(1) != Z.size(1):
            raise ValueError("Inputs must have the same number of features.")

        Zt = Z.t()
        return torch.where(X.sign() == Zt.sign(),
                           variance * torch.min(X.abs(), Zt.abs()),
                           Variable(X.data.new(X.size(0), Z.size(0)).zero_()))
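As a quick standalone check (plain torch, independent of the Kernel base class), the ``torch.where`` expression above implements ``k(x, z) = min(|x|, |z|)`` when ``x`` and ``z`` share a sign, and 0 otherwise:

import torch

# Four 1-D inputs; K[i, j] = min(|x_i|, |x_j|) if signs match, else 0.
x = torch.tensor([[-2.], [-1.], [1.], [3.]])
zt = x.t()
K = torch.where(x.sign() == zt.sign(),
                torch.min(x.abs(), zt.abs()),
                torch.zeros(x.size(0), x.size(0)))
print(K)  # K[0, 1] == 1.0 (both negative); K[1, 2] == 0.0 (opposite signs)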
kernels/dot_product.py (new file):

@@ -0,0 +1,82 @@
from __future__ import absolute_import, division, print_function

import torch
from torch.distributions import constraints
from torch.nn import Parameter

from .kernel import Kernel


class DotProduct(Kernel):
    """
    Base kernel for Linear and Polynomial kernels.

    :param torch.Tensor variance: Variance parameter which plays the role of scaling.
    """

    def __init__(self, input_dim, variance=None, active_dims=None, name=None):
        super(DotProduct, self).__init__(input_dim, active_dims, name)

        if variance is None:
            variance = torch.ones(1)
        self.variance = Parameter(variance)
        self.set_constraint("variance", constraints.positive)

    def _dot_product(self, X, Z=None, diag=False):
        """
        Returns ``X.Z``.
        """
        if Z is None:
            Z = X
        X = self._slice_input(X)
        if diag:
            return (X ** 2).sum(-1)

        Z = self._slice_input(Z)
        if X.size(1) != Z.size(1):
            raise ValueError("Inputs must have the same number of features.")

        return X.matmul(Z.t())


class Linear(DotProduct):
    """
    Implementation of the Linear kernel. Doing Gaussian process regression with a
    linear kernel is equivalent to doing linear regression.

    Note that here we implement the homogeneous version. To use the inhomogeneous
    version, consider using the Polynomial kernel with ``degree=1`` or combining
    this kernel with a Bias kernel.
    """

    def __init__(self, input_dim, variance=None, active_dims=None, name="Linear"):
        super(Linear, self).__init__(input_dim, variance, active_dims, name)

    def forward(self, X, Z=None, diag=False):
        variance = self.get_param("variance")
        return variance * self._dot_product(X, Z, diag)


class Polynomial(DotProduct):
    """
    Implementation of the Polynomial kernel: ``k(x, z) = (bias + x.z)^d``.

    :param torch.Tensor bias: Bias parameter of this kernel. Should be positive.
    :param int degree: Degree of this polynomial.
    """

    def __init__(self, input_dim, variance=None, bias=None, degree=1, active_dims=None,
                 name="Polynomial"):
        super(Polynomial, self).__init__(input_dim, variance, active_dims, name)

        if bias is None:
            bias = torch.ones(1)
        self.bias = Parameter(bias)
        self.set_constraint("bias", constraints.positive)

        if degree < 1:
            raise ValueError("Degree for Polynomial kernel should be a positive integer.")
        self.degree = degree

    def forward(self, X, Z=None, diag=False):
        variance = self.get_param("variance")
        bias = self.get_param("bias")
        return variance * ((bias + self._dot_product(X, Z, diag)) ** self.degree)
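To make the polynomial formula concrete, a minimal standalone check of ``k(x, z) = (bias + x.z)^degree`` with ``variance = 1``, independent of the classes above:

import torch

X = torch.tensor([[1., 2.], [3., 4.]])
bias, degree = 1., 2
K = (bias + X.matmul(X.t())) ** degree
print(K[0, 1])  # (1 + 1*3 + 2*4)^2 == 144.0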
kernels/isotropic.py (new file):

@@ -0,0 +1,191 @@
from __future__ import absolute_import, division, print_function

import torch
from torch.distributions import constraints
from torch.nn import Parameter

from .kernel import Kernel


def _torch_sqrt(x, eps=1e-18):
    """
    A convenient function to avoid the NaN gradient issue of ``torch.sqrt`` at 0.
    """
    # Ref: https://github.com/pytorch/pytorch/issues/2421
    return (x + eps).sqrt()


class Isotropy(Kernel):
    """
    Base kernel for a family of isotropic covariance functions, each of which is
    a function of the distance ``r = |x - z|``.

    By default, the parameter ``lengthscale`` has size 1. To use the anisotropic
    version (a different lengthscale for each dimension), make sure that
    ``lengthscale`` has size equal to ``input_dim``.

    :param torch.Tensor variance: Variance parameter of this kernel.
    :param torch.Tensor lengthscale: Length scale parameter of this kernel.
    """

    def __init__(self, input_dim, variance=None, lengthscale=None, active_dims=None, name=None):
        super(Isotropy, self).__init__(input_dim, active_dims, name)

        if variance is None:
            variance = torch.ones(1)
        self.variance = Parameter(variance)
        self.set_constraint("variance", constraints.positive)

        if lengthscale is None:
            lengthscale = torch.ones(1)
        self.lengthscale = Parameter(lengthscale)
        self.set_constraint("lengthscale", constraints.positive)

    def _square_scaled_dist(self, X, Z=None):
        """
        Returns ``||(X-Z)/lengthscale||^2``.
        """
        if Z is None:
            Z = X
        X = self._slice_input(X)
        Z = self._slice_input(Z)
        if X.size(1) != Z.size(1):
            raise ValueError("Inputs must have the same number of features.")

        lengthscale = self.get_param("lengthscale")
        scaled_X = X / lengthscale
        scaled_Z = Z / lengthscale
        X2 = (scaled_X ** 2).sum(1, keepdim=True)
        Z2 = (scaled_Z ** 2).sum(1, keepdim=True)
        XZ = scaled_X.matmul(scaled_Z.t())
        r2 = X2 - 2 * XZ + Z2.t()
        return r2

    def _scaled_dist(self, X, Z=None):
        """
        Returns ``||(X-Z)/lengthscale||``.
        """
        return _torch_sqrt(self._square_scaled_dist(X, Z))

    def _diag(self, X):
        """
        Calculates the diagonal part of the covariance matrix on the active dimensions.
        """
        variance = self.get_param("variance")
        return variance.expand(X.size(0))
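As an aside (not part of the commit): ``_square_scaled_dist`` vectorizes the expansion ``||x - z||^2 = x.x - 2*x.z + z.z`` over all pairs at once. A standalone sanity check; the ``isotropic.py`` listing continues below:

import torch

X = torch.randn(4, 3)
Z = torch.randn(5, 3)
X2 = (X ** 2).sum(1, keepdim=True)
Z2 = (Z ** 2).sum(1, keepdim=True)
r2 = X2 - 2 * X.matmul(Z.t()) + Z2.t()                    # vectorized, shape (4, 5)
naive = ((X.unsqueeze(1) - Z.unsqueeze(0)) ** 2).sum(-1)  # explicit pairwise distances
assert torch.allclose(r2, naive, atol=1e-5)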
class RBF(Isotropy):
    """
    Implementation of the Radial Basis Function kernel: ``exp(-0.5 * r^2 / l^2)``.
    """

    def __init__(self, input_dim, variance=None, lengthscale=None, active_dims=None, name="RBF"):
        super(RBF, self).__init__(input_dim, variance, lengthscale, active_dims, name)

    def forward(self, X, Z=None, diag=False):
        if diag:
            return self._diag(X)

        variance = self.get_param("variance")
        r2 = self._square_scaled_dist(X, Z)
        return variance * torch.exp(-0.5 * r2)


class SquaredExponential(RBF):
    """
    Squared Exponential is another name for the RBF kernel.
    """

    def __init__(self, input_dim, variance=None, lengthscale=None, active_dims=None,
                 name="SquaredExponential"):
        super(SquaredExponential, self).__init__(input_dim, variance, lengthscale, active_dims, name)


class RationalQuadratic(Isotropy):
    """
    Implementation of the Rational Quadratic kernel:
    ``(1 + 0.5 * r^2 / (alpha * l^2))^(-alpha)``.

    :param torch.Tensor scale_mixture: Scale mixture (alpha) parameter of this kernel.
        Should have size 1.
    """

    def __init__(self, input_dim, variance=None, lengthscale=None, scale_mixture=None,
                 active_dims=None, name="RationalQuadratic"):
        super(RationalQuadratic, self).__init__(input_dim, variance, lengthscale, active_dims, name)

        if scale_mixture is None:
            scale_mixture = torch.ones(1)
        self.scale_mixture = Parameter(scale_mixture)
        self.set_constraint("scale_mixture", constraints.positive)

    def forward(self, X, Z=None, diag=False):
        if diag:
            return self._diag(X)

        variance = self.get_param("variance")
        scale_mixture = self.get_param("scale_mixture")
        r2 = self._square_scaled_dist(X, Z)
        return variance * (1 + (0.5 / scale_mixture) * r2).pow(-scale_mixture)


class Exponential(Isotropy):
    """
    Implementation of the Exponential kernel: ``exp(-r/l)``.
    """

    def __init__(self, input_dim, variance=None, lengthscale=None, active_dims=None, name="Exponential"):
        super(Exponential, self).__init__(input_dim, variance, lengthscale, active_dims, name)

    def forward(self, X, Z=None, diag=False):
        if diag:
            return self._diag(X)

        variance = self.get_param("variance")
        r = self._scaled_dist(X, Z)
        return variance * torch.exp(-r)


class Matern12(Exponential):
    """
    Another name for the Exponential kernel.
    """

    def __init__(self, input_dim, variance=None, lengthscale=None, active_dims=None, name="Matern12"):
        super(Matern12, self).__init__(input_dim, variance, lengthscale, active_dims, name)


class Matern32(Isotropy):
    """
    Implementation of the Matern32 kernel: ``(1 + sqrt(3) * r/l) * exp(-sqrt(3) * r/l)``.
    """

    def __init__(self, input_dim, variance=None, lengthscale=None, active_dims=None, name="Matern32"):
        super(Matern32, self).__init__(input_dim, variance, lengthscale, active_dims, name)

    def forward(self, X, Z=None, diag=False):
        if diag:
            return self._diag(X)

        variance = self.get_param("variance")
        r = self._scaled_dist(X, Z)
        sqrt3_r = 3**0.5 * r
        return variance * (1 + sqrt3_r) * torch.exp(-sqrt3_r)


class Matern52(Isotropy):
    """
    Implementation of the Matern52 kernel:
    ``(1 + sqrt(5) * r/l + 5/3 * r^2 / l^2) * exp(-sqrt(5) * r/l)``.
    """

    def __init__(self, input_dim, variance=None, lengthscale=None, active_dims=None, name="Matern52"):
        super(Matern52, self).__init__(input_dim, variance, lengthscale, active_dims, name)

    def forward(self, X, Z=None, diag=False):
        if diag:
            return self._diag(X)

        variance = self.get_param("variance")
        r2 = self._square_scaled_dist(X, Z)
        r = _torch_sqrt(r2)
        sqrt5_r = 5**0.5 * r
        return variance * (1 + sqrt5_r + (5/3) * r2) * torch.exp(-sqrt5_r)
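The commit message's "fix bug NaN at sqrt" refers to ``_torch_sqrt`` above: the gradient of ``torch.sqrt`` at exactly 0 (which occurs on the diagonal of the distance matrix, where ``X == Z``) is infinite, and typically becomes NaN once multiplied through the rest of the autograd graph. A standalone illustration:

import torch

# Gradient of sqrt at exactly 0 blows up.
r2 = torch.zeros(1, requires_grad=True)
r2.sqrt().backward()
print(r2.grad)  # tensor([inf]): 0.5 / sqrt(0)

# Adding a tiny eps before the sqrt keeps the gradient finite.
r2 = torch.zeros(1, requires_grad=True)
(r2 + 1e-18).sqrt().backward()
print(r2.grad)  # large but finite: 0.5 / sqrt(1e-18)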