Kernels, the machine learning ones (also, mean functions)
Contents:
- Installation
- Usage
- AutoGrad, TensorFlow, PyTorch, or JAX? Your Choice!
- Available Kernels
- Available Means
- Compositional Design
- Displaying Kernels and Means
- Properties of Kernels and Means
- Structured Matrix Types
- Implementing Your Own Kernel
TLDR:
>>> import numpy as np
>>> from mlkernels import EQ, Linear
>>> k1 = 2 * Linear() + 1
>>> k1
2 * Linear() + 1
>>> k1(np.linspace(0, 1, 3)) # Structured matrices enable efficiency.
<low-rank matrix: shape=3x3, dtype=float64, rank=2
left=[[0. 1. ]
[0.5 1. ]
[1. 1. ]]
middle=[[2. 0.]
[0. 1.]]
right=[[0. 1. ]
[0.5 1. ]
[1. 1. ]]>
>>> import lab as B
>>> B.dense(k1(np.linspace(0, 1, 3))) # Discard structure: get a regular NumPy array.
array([[1. , 1. , 1. ],
[1. , 1.5, 2. ],
[1. , 2. , 3. ]])
>>> k2 = 2 + EQ() * Linear()
>>> k2
2 * 1 + EQ() * Linear()
>>> k2(np.linspace(0, 1, 3))
<dense matrix: shape=3x3, dtype=float64
mat=[[2. 2. 2. ]
[2. 2.25 2.441]
[2. 2.441 3. ]]>
>>> B.dense(k2(np.linspace(0, 1, 3)))
array([[2. , 2. , 2. ],
[2. , 2.25 , 2.44124845],
[2. , 2.44124845, 3. ]])
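The following NumPy sketch (plain NumPy, not MLKernels code) shows why the low-rank form of k1 pays off. It rebuilds the dense matrix from the printed factors, assuming that the low-rank matrix stands for left @ middle @ right.T. It then computes a matrix-vector product through the factors, which costs O(n * rank) work instead of O(n^2).

```python
import numpy as np

x = np.linspace(0, 1, 3)

# Factors of 2 * Linear() + 1 evaluated at x, matching the printed low-rank matrix:
# column 0 carries the inputs (Linear()), column 1 the constant term.
left = np.stack([x, np.ones_like(x)], axis=1)
middle = np.diag([2.0, 1.0])  # weight 2 for Linear(), weight 1 for the constant
right = left

# Assumption: the low-rank matrix represents left @ middle @ right.T.
dense = left @ middle @ right.T

# Matrix-vector product through the factors: O(n * rank) instead of O(n^2) work.
v = np.array([1.0, -2.0, 0.5])
mv_structured = left @ (middle @ (right.T @ v))
mv_dense = dense @ v
```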
Installation
pip install mlkernels
See also the instructions here.
Usage
Let k be a kernel, e.g. k = EQ().
- k(x, y) constructs the kernel matrix for all pairs of points between x and y.
- k(x) is shorthand for k(x, x).
- k.elwise(x, y) constructs the kernel vector pairing the points in x and y element-wise, which will be a rank-2 column vector.
Example:
>>> k = EQ()
>>> k(np.linspace(0, 1, 3))
<dense matrix: shape=3x3, dtype=float64
mat=[[1. 0.882 0.607]
[0.882 1. 0.882]
[0.607 0.882 1. ]]>
>>> k.elwise(np.linspace(0, 1, 3), 0)
array([[1. ],
[0.8824969 ],
[0.60653066]])
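For intuition, this NumPy sketch reproduces what k(x) and k.elwise(x, y) compute for EQ(). The helpers eq_pairwise and eq_elwise are hypothetical and not part of MLKernels. Unlike MLKernels, which wraps the pairwise result in a structured matrix type, they return plain arrays.

```python
import numpy as np

def _as_matrix(x):
    # Scalars become a 1 x 1 matrix; vectors become a column (one input per row).
    x = np.asarray(x, dtype=float)
    if x.ndim == 0:
        return x.reshape(1, 1)
    return x[:, None] if x.ndim == 1 else x

def eq_pairwise(x, y=None):
    """Hypothetical analogue of EQ()(x, y): all pairs of rows; y defaults to x."""
    x = _as_matrix(x)
    y = x if y is None else _as_matrix(y)
    d2 = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return np.exp(-0.5 * d2)

def eq_elwise(x, y=None):
    """Hypothetical analogue of EQ().elwise(x, y): rows paired up, as a column vector."""
    x = _as_matrix(x)
    y = x if y is None else _as_matrix(y)
    # A single point y (a 1 x d matrix) is broadcast against every row of x.
    d2 = np.sum((x - y) ** 2, axis=-1, keepdims=True)
    return np.exp(-0.5 * d2)

x = np.linspace(0, 1, 3)
K = eq_pairwise(x)   # 3 x 3 matrix, like k(x)
v = eq_elwise(x, 0)  # 3 x 1 column vector, like k.elwise(x, 0)
```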
Let m be a mean, e.g. m = TensorProductMean(lambda x: x ** 2).
- m(x) constructs the mean vector for the points in x, which will be a rank-2 column vector.
Example:
>>> from mlkernels import TensorProductMean
>>> m = TensorProductMean(lambda x: x ** 2)
>>> m(np.linspace(0, 1, 3))
array([[0. ],
[0.25],
[1. ]])
Inputs to kernels and means are interpreted in the following way:
- If the input x is a rank-0 tensor, i.e. a scalar, then x refers to a single input location. For example, 0 simply refers to the sole input location 0.
- If the input x is a rank-1 tensor, i.e. a vector, then every element of x is interpreted as a separate input location. For example, np.linspace(0, 1, 10) generates 10 different input locations ranging from 0 to 1.
- If the input x is a rank-2 tensor, i.e. a matrix, then every row of x is interpreted as a separate input location. In this case inputs are multi-dimensional, and the columns correspond to the various input dimensions.
- If the input x is a tensor of rank 3 or higher, then the input is interpreted as a batch of matrices: the last two dimensions are the matrix dimensions, and all leading dimensions are batch dimensions.
Example:
>>> k = EQ()
>>> k(0) # One scalar input
<dense matrix: batch=(), shape=(1, 1), dtype=float64
mat=[[1.]]>
>>> k(np.linspace(0, 1, 3)) # Three single-dimensional inputs
<dense matrix: batch=(), shape=(3, 3), dtype=float64
mat=[[1. 0.882 0.607]
[0.882 1. 0.882]
[0.607 0.882 1. ]]>
>>> k(np.random.randn(3, 2)) # Three two-dimensional inputs
<dense matrix: batch=(), shape=(3, 3), dtype=float64
mat=[[1. 0.606 0.399]
[0.606 1. 0.931]
[0.399 0.931 1. ]]>
>>> k(np.random.randn(2, 3, 2)) # A batch of two sets of three two-dimensional inputs
<dense matrix: batch=(2,), shape=(3, 3), dtype=float64
mat=[[[1. 0.15 0.559]
[0.15 1. 0.678]
[0.559 0.678 1. ]]
[[1. 0.891 0.627]
[0.891 1. 0.638]
[0.627 0.638 1. ]]]>
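The input rules above can be mimicked in plain NumPy. In this hypothetical sketch, as_inputs and pairwise_sq_dists are not MLKernels functions. Each input is normalised to a matrix whose rows are input locations, and any leading batch dimensions are kept. Pairwise squared distances are then computed with broadcasting, and EQ follows as exp(-0.5 * d2).

```python
import numpy as np

def as_inputs(x):
    """Hypothetical helper: convert x into (..., n, d) form following the rules above."""
    x = np.asarray(x, dtype=float)
    if x.ndim == 0:
        return x.reshape(1, 1)  # a scalar is a single one-dimensional input
    if x.ndim == 1:
        return x[:, None]  # every element is a separate one-dimensional input
    return x  # rank 2: rows are inputs; rank >= 3: leading dimensions are batch dimensions

def pairwise_sq_dists(x, y):
    """Squared Euclidean distances between all rows of x and y, batched over leading dims."""
    x, y = as_inputs(x), as_inputs(y)
    diffs = x[..., :, None, :] - y[..., None, :, :]
    return np.sum(diffs ** 2, axis=-1)

def eq(x, y=None):
    """EQ kernel matrix; like k(x), y defaults to x."""
    return np.exp(-0.5 * pairwise_sq_dists(x, x if y is None else y))
```

The four calls eq(0), eq(np.linspace(0, 1, 3)), eq(np.random.randn(3, 2)) and eq(np.random.randn(2, 3, 2)) produce arrays of shape (1, 1), (3, 3), (3, 3) and (2, 3, 3). These shapes line up with the batch and shape fields printed above.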
Finally, if you are simultaneously computing means and kernel matrices, then speed-ups may be possible. To obtain these speed-ups, use mean_var instead of first calling m(x) and then k(x), and use mean_var_diag instead of first calling m(x) and then k.elwise(x).
Example:
>>> from mlkernels import mean_var, mean_var_diag
>>> m = TensorProductMean(lambda x: x ** 2)
>>> k = EQ()
>>> x = np.linspace(0, 1, 3)
>>> m(x), k(x) # Less efficient
(array([[0. ],
[0.25],
[1. ]]),
<dense matrix: batch=(), shape=(3, 3), dtype=float64
mat=[[1. 0.882 0.607]
[0.882 1. 0.882]
[0.607 0.882 1. ]]>)
>>> mean_var(m, k, x) # More efficient
(array([[0. ],
[0.25],
[1. ]]),
<dense matrix: batch=(), shape=(3, 3), dtype=float64
mat=[[1. 0.882 0.607]
[0.882 1. 0.882]
[0.607 0.882 1. ]]>)
>>> m(x), k.elwise(x) # Less efficient
(array([[0. ],
[0.25],
[1. ]]),
array([[1.],
[1.],
[1.]]))
>>> mean_var_diag(m, k, x) # More efficient
(array([[0. ],
[0.25],
[1. ]]),
array([[1.],
[1.],
[1.]]))
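AutoGrad, TensorFlow, PyTorch, or JAX? Your Choice!
Import from the submodule that matches your framework:
from mlkernels.autograd import EQ, Linear
from mlkernels.tensorflow import EQ, Linear
from mlkernels.torch import EQ, Linear
from mlkernels.jax import EQ, Linear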
Available Kernels
See here for a nicely rendered version of the list below.
Constants function as constant kernels. Besides that, the following kernels are available:
- EQ(), the exponentiated quadratic:
  $$ k(x, y) = \exp\left( -\frac{1}{2}|x - y|^2 \right); $$
- CEQ(alpha), the causal exponentiated quadratic:
  $$ k(x, y) = \left(1 - \operatorname{erf}\left( \frac{\alpha}{4}|x - y| \right)\right)\exp\left( -\frac{1}{2}|x - y|^2 \right); $$
- RQ(alpha), the rational quadratic:
  $$ k(x, y) = \left( 1 + \frac{|x - y|^2}{2 \alpha} \right)^{-\alpha}; $$
- Matern12() or Exp(), the Matern–1/2 kernel:
  $$ k(x, y) = \exp\left( -|x - y| \right); $$
- Matern32(), the Matern–3/2 kernel:
  $$ k(x, y) = \left( 1 + \sqrt{3}|x - y| \right)\exp\left(-\sqrt{3}|x - y|\right); $$
- Matern52(), the Matern–5/2 kernel:
  $$ k(x, y) = \left( 1 + \sqrt{5}|x - y| + \frac{5}{3} |x - y|^2 \right)\exp\left(-\sqrt{5}|x - y|\right); $$
- Linear(), the linear kernel:
  $$ k(x, y) = \langle x, y \rangle; $$
- Delta(epsilon=1e-6), the Kronecker delta kernel:
  $$ k(x, y) = \begin{cases} 1 & \text{if } |x - y| < \varepsilon, \\ 0 & \text{otherwise}; \end{cases} $$
- DecayingKernel(alpha, beta):
  $$ k(x, y) = \frac{|\beta|^\alpha}{|x + y + \beta|^\alpha}; $$
- LogKernel():
  $$ k(x, y) = \frac{\log(1 + |x - y|)}{|x - y|}; $$
- PosteriorKernel(k_ij, k_zi, k_zj, z, K_z):
  $$ k(x, y) = k_{ij}(x, y) - k_{iz}(x, z) K_z^{-1} k_{zj}(z, y); $$
- SubspaceKernel(k_zi, k_zj, z, A):
  $$ k(x, y) = k_{iz}(x, z) A^{-1} k_{zj}(z, y); $$
- TensorProductKernel(f):
  $$ k(x, y) = f(x)f(y). $$
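As a sanity check of the stationary formulas above, the NumPy sketch below writes them as functions of the distance r = |x - y|. It is a hypothetical illustration, not MLKernels' implementation. It checks that each kernel equals 1 at r = 0 and that RQ(alpha) approaches EQ() as alpha grows.

```python
import numpy as np

# Stationary kernels from the list above, as functions of r = |x - y|.
def eq(r):
    return np.exp(-0.5 * r ** 2)

def rq(r, alpha):
    return (1 + r ** 2 / (2 * alpha)) ** (-alpha)

def matern12(r):
    return np.exp(-r)

def matern32(r):
    s = np.sqrt(3) * r
    return (1 + s) * np.exp(-s)

def matern52(r):
    s = np.sqrt(5) * r
    # (sqrt(5) r)^2 / 3 equals (5 / 3) r^2, the quadratic term in the formula.
    return (1 + s + s ** 2 / 3) * np.exp(-s)

r = np.linspace(0, 2, 5)
```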
Adding or multiplying a FunctionType f to or with a kernel will automatically translate f to TensorProductKernel(f). For example, f * k will translate to TensorProductKernel(f) * k, and f + k will translate to TensorProductKernel(f) + k.
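At the matrix level, TensorProductKernel(f) evaluated at x is the outer product of f(x) with itself, so its matrix has rank one. Since a product of kernels multiplies their matrices element-wise, f * k scales the rows and columns of k's matrix by f(x). The NumPy sketch below checks both facts. tensor_product_kernel is a hypothetical helper, not part of MLKernels.

```python
import numpy as np

def tensor_product_kernel(f, x, y):
    """Hypothetical analogue of TensorProductKernel(f)(x, y) for 1-D inputs."""
    return np.outer(f(x), f(y))

x = np.linspace(0, 1, 4)
f = lambda z: z ** 2

K_f = tensor_product_kernel(f, x, x)
K_eq = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2)  # EQ() at x
K_prod = K_f * K_eq  # f * EQ(): product kernels multiply element-wise
```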
Available Means
Constants function as constant means. Besides that, the following means are available:
- TensorProductMean(f):
  $$ m(x) = f(x). $$
Adding or multiplying a FunctionType f to or with a mean will automatically translate f to TensorProductMean(f). For example, f * m will translate to TensorProductMean(f) * m, and f + m will translate to TensorProductMean(f) + m.
Compositional Design
- Add and subtract kernels and means.
  Example:
  >>> EQ() + Matern12()
  EQ() + Matern12()
  >>> EQ() + EQ()
  2 * EQ()
  >>> EQ() + 1
  EQ() + 1
  >>> EQ() + 0
  EQ()
  >>> EQ() - Matern12()
  EQ() - Matern12()
  >>> EQ() - EQ()
  0
- Multiply kernels and means.
  Example:
  >>> EQ() * Matern12()
  EQ() * Matern12()
  >>> 2 * EQ()
  2 * EQ()
  >>> 0 * EQ()
  0
- Shift kernels and means.
  Definition:
  k.shift(c)(x, y) == k(x - c, y - c)
  k.shift(c1, c2)(x, y) == k(x - c1, y - c2)
  Example:
  >>> Linear().shift(1)
  Linear() shift 1
  >>> EQ().shift(1, 2)
  EQ() shift (1, 2)
- Stretch kernels and means.
  Definition:
  k.stretch(c)(x, y) == k(x / c, y / c)
  k.stretch(c1, c2)(x, y) == k(x / c1, y / c2)
  Example:
  >>> EQ().stretch(2)
  EQ() > 2
  >>> EQ().stretch(2, 3)
  EQ() > (2, 3)
- Select particular input dimensions of kernels and means.
  Definition:
  k.select([0])(x, y) == k(x[:, 0], y[:, 0])
  k.select([0, 1])(x, y) == k(x[:, [0, 1]], y[:, [0, 1]])
  k.select([0], [1])(x, y) == k(x[:, 0], y[:, 1])
  k.select(None, [1])(x, y) == k(x, y[:, 1])
  Example:
  >>> EQ().select([0])
  EQ() : [0]
  >>> EQ().select([0, 1])
  EQ() : [0, 1]
  >>> EQ().select([0], [1])
  EQ() : ([0], [1])
  >>> EQ().select(None, [1])
  EQ() : (None, [1])
- Transform the inputs of kernels and means.
  Definition:
  k.transform(f)(x, y) == k(f(x), f(y))
  k.transform(f1, f2)(x, y) == k(f1(x), f2(y))
  k.transform(None, f)(x, y) == k(x, f(y))
  Example:
  >>> EQ().transform(f)
  EQ() transform f
  >>> EQ().transform(f1, f2)
  EQ() transform (f1, f2)
  >>> EQ().transform(None, f)
  EQ() transform (None, f)
- Numerically, but efficiently, take derivatives of kernels and means. This currently only works in TensorFlow and does not yet support batched inputs.
  Definition:
  k.diff(0)(x, y) == d/d(x[:, 0]) d/d(y[:, 0]) k(x, y)
  k.diff(0, 1)(x, y) == d/d(x[:, 0]) d/d(y[:, 1]) k(x, y)
  k.diff(None, 1)(x, y) == d/d(y[:, 1]) k(x, y)
  Example:
  >>> EQ().diff(0)
  d(0) EQ()
  >>> EQ().diff(0, 1)
  d(0, 1) EQ()
  >>> EQ().diff(None, 1)
  d(None, 1) EQ()
- Make kernels periodic. This is not implemented for means.
  Definition:
  k.periodic(2 pi / w)(x, y) == k((sin(w * x), cos(w * x)), (sin(w * y), cos(w * y)))
  Example:
  >>> EQ().periodic(1)
  EQ() per 1
- Reverse the arguments of kernels. This does not apply to means.
  Definition:
  reversed(k)(x, y) == k(y, x)
  Example:
  >>> reversed(Linear())
  Reversed(Linear())
- Extract terms and factors from sums and products respectively.
  Example:
  >>> (EQ() + RQ(0.1) + Linear()).term(1)
  RQ(0.1)
  >>> (2 * EQ() * Linear()).factor(0)
  2
  Kernels and means "wrapping" others can be "unwrapped" by indexing k[0] or m[0].
  Example:
  >>> reversed(Linear())
  Reversed(Linear())
  >>> reversed(Linear())[0]
  Linear()
  >>> EQ().periodic(1)
  EQ() per 1
  >>> EQ().periodic(1)[0]
  EQ()
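Several of the definitions above can be mimicked with plain functions. The sketch below uses hypothetical NumPy helpers named shift, stretch, periodic, and diff. They are not MLKernels' implementation, and they act on one-dimensional inputs. For diff, the sketch uses central finite differences. This is one possible numerical scheme and not necessarily the one MLKernels uses.

```python
import numpy as np

def _rows(x):
    # Treat a vector as a column of one-dimensional inputs; keep matrices as-is.
    x = np.asarray(x, dtype=float)
    return x[:, None] if x.ndim == 1 else x

def eq(x, y):
    """EQ kernel matrix between the rows of x and the rows of y."""
    x, y = _rows(x), _rows(y)
    return np.exp(-0.5 * np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1))

def linear(x, y):
    """Linear kernel matrix <x, y> between the rows of x and y."""
    return _rows(x) @ _rows(y).T

# Hypothetical combinators mirroring the definitions above.
def shift(k, c1, c2=None):
    c2 = c1 if c2 is None else c2
    return lambda x, y: k(_rows(x) - c1, _rows(y) - c2)

def stretch(k, c1, c2=None):
    c2 = c1 if c2 is None else c2
    return lambda x, y: k(_rows(x) / c1, _rows(y) / c2)

def periodic(k, period):
    w = 2 * np.pi / period  # k.periodic(2 pi / w) in the definition above

    def features(z):
        # Map every input onto the circle (sin(w z), cos(w z)) before applying k.
        z = _rows(z)
        return np.concatenate([np.sin(w * z), np.cos(w * z)], axis=1)

    return lambda x, y: k(features(x), features(y))

def diff(k, h=1e-4):
    """Central-difference estimate of d/dx d/dy k(x, y) for one-dimensional inputs."""
    def k_diff(x, y):
        x, y = _rows(x), _rows(y)
        return (k(x + h, y + h) - k(x + h, y - h)
                - k(x - h, y + h) + k(x - h, y - h)) / (4 * h ** 2)
    return k_diff

x = np.linspace(0, 1, 3)
K_stretched = stretch(eq, 2)(x, x)  # EQ() with length scale 2
```

For EQ, the mixed derivative has the closed form (1 - (x - y)^2) exp(-(x - y)^2 / 2), which gives a convenient check on the finite-difference estimate.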
Displaying Kernels and Means
Kernels and means have a display method. The display method accepts a callable formatter that will be applied before any value is printed. This comes in handy when pretty printing kernels.
Example:
>>> print((2.12345 * EQ()).display(lambda x: f"{x:.2f}"))
2.12 * EQ()
Properties of Kernels and Means
- Kernels and means can be equated to check for equality. This will attempt basic algebraic manipulations. If the kernels and means are not equal or equality cannot be proved, then False is returned.
  Example of equating kernels:
  >>> 2 * EQ() == EQ() + EQ()
  True
  >>> EQ() + Matern12() == Matern12() + EQ()
  True
  >>> 2 * Matern12() == EQ() + Matern12()
  False
  >>> EQ() + Matern12() + Linear() == Linear() + Matern12() + EQ() # Too hard: cannot prove equality!
  False
- The stationarity of a kernel k can always be determined by querying k.stationary.
  Example of querying the stationarity:
  >>> EQ().stationary
  True
  >>> (EQ() + Linear()).stationary
  False
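Stationarity means that the kernel only depends on x - y, so shifting both inputs by the same amount leaves the kernel matrix unchanged. The NumPy sketch below tests this numerically for EQ and Linear. looks_stationary is a hypothetical helper, not part of MLKernels. A numerical check like this can only refute stationarity on the sampled points; it can never prove it.

```python
import numpy as np

def eq(x, y):
    return np.exp(-0.5 * (x[:, None] - y[None, :]) ** 2)

def linear(x, y):
    return x[:, None] * y[None, :]

def looks_stationary(k, x, shifts=(-1.0, 0.5, 3.0)):
    """Hypothetical heuristic: does k(x + c, x + c) match k(x, x) for a few shifts c?"""
    K = k(x, x)
    return all(np.allclose(k(x + c, x + c), K) for c in shifts)

x = np.linspace(0, 1, 5)
```

The outcomes agree with the queries above: EQ passes, while Linear and EQ + Linear fail.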
Structured Matrix Types
MLKernels uses an extension of LAB to accelerate linear algebra with structured linear algebra primitives.
Example:
>>> import lab as B
>>> k = 2 * Delta()
>>> x = np.linspace(0, 5, 10)
>>> from mlkernels import pairwise
>>> k(x) # Preserve structure.
<diagonal matrix: shape=10x10, dtype=float64
diag=[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]>
>>> B.dense(k(x)) # Do not preserve structure.
array([[2., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 2., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 2., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 2., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 2., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 2., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 2., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 2., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 2., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 2.]])
These structured matrices are compatible with LAB: they know how to efficiently add, multiply, and do other linear algebra operations.
>>> import lab as B
>>> B.matmul(pairwise(k, x), pairwise(k, x))
<diagonal matrix: shape=10x10, dtype=float64
diag=[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]>
As in the above example, you can convert structured primitives to regular NumPy/TensorFlow/PyTorch/JAX arrays by calling B.dense:
>>> import lab as B
>>> B.dense(B.matmul(pairwise(k, x), pairwise(k, x)))
array([[4., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 4., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 4., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 4., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 4., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 4., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 4., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 4., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 4., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 4.]])
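The gain from structure is easiest to see for diagonal matrices. The NumPy sketch below illustrates the idea only; it is not how LAB or the matrix package implement their types. Keeping just the diagonal turns a matrix product and a linear solve into O(n) element-wise operations. The dense equivalents on full n x n arrays cost O(n^3).

```python
import numpy as np

n = 10
diag = np.full(n, 2.0)  # the diagonal of 2 * Delta() at n distinct inputs, as printed above

# Structured: multiplying two diagonal matrices is an element-wise product of
# their diagonals (O(n) work), and the result is again diagonal.
product_diag = diag * diag

# Structured: solving D z = b only needs an element-wise division (O(n) work).
b = np.arange(1.0, n + 1)
z = b / diag

# Dense counterpart, used here only for comparison.
D = np.diag(diag)
```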
Implementing Your Own Kernel
An example is most helpful:
import lab as B
from algebra.util import identical
from matrix import Dense
from plum import dispatch
from mlkernels import Kernel, pairwise, elwise


class EQWithLengthScale(Kernel):
    """Exponentiated quadratic kernel with a length scale.

    Args:
        scale (scalar): Length scale of the kernel.
    """

    def __init__(self, scale):
        self.scale = scale

    def _compute(self, dists2):
        # This computes the kernel given squared distances. We use `B` to provide a
        # backend-agnostic implementation.
        return B.exp(-0.5 * dists2 / self.scale ** 2)

    def render(self, formatter):
        # This method determines how the kernel is displayed.
        return f"EQWithLengthScale({formatter(self.scale)})"

    @property
    def _stationary(self):
        # This method can be defined to return `True` to indicate that the kernel is
        # stationary. By default, kernels are assumed to not be stationary.
        return True

    @dispatch
    def __eq__(self, other: "EQWithLengthScale"):
        # If `other` is also an `EQWithLengthScale`, then this method checks whether
        # `self` and `other` can be treated as identical for the purpose of
        # algebraic simplifications. In this case, `self` and `other` are identical
        # for the purpose of algebraic simplification if `self.scale` and
        # `other.scale` are. We use `algebra.util.identical` to check this condition.
        return identical(self.scale, other.scale)


# It remains to implement pairwise and element-wise computation of the kernel.
@pairwise.dispatch
def pairwise(k: EQWithLengthScale, x: B.Numeric, y: B.Numeric):
    return Dense(k._compute(B.pw_dists2(x, y)))


@elwise.dispatch
def elwise(k: EQWithLengthScale, x: B.Numeric, y: B.Numeric):
    return k._compute(B.ew_dists2(x, y))
>>> k = EQWithLengthScale(2)
>>> k
EQWithLengthScale(2)
>>> k == EQWithLengthScale(2)
True
>>> 2 * k == k + EQWithLengthScale(2)
True
>>> k == Linear()
False
>>> k_composite = (2 * k + Linear()) * RQ(2.0)
>>> k_composite
(2 * EQWithLengthScale(2) + Linear()) * RQ(2.0)
>>> k_composite(np.linspace(0, 1, 3))
<dense matrix: shape=3x3, dtype=float64
mat=[[2. 1.717 1.13 ]
[1.717 2.25 2.16 ]
[1.13 2.16 3. ]]>
Of course, in practice we do not need to implement variants of kernels which include length scales, because we can always adjust the length scale by stretching a base kernel:
>>> EQ().stretch(2)(np.linspace(0, 1, 3))
<dense matrix: shape=3x3, dtype=float64
mat=[[1. 0.969 0.882]
[0.969 1. 0.969]
[0.882 0.969 1. ]]>
>>> EQWithLengthScale(2)(np.linspace(0, 1, 3))
<dense matrix: shape=3x3, dtype=float64
mat=[[1. 0.969 0.882]
[0.969 1. 0.969]
[0.882 0.969 1. ]]>