-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
exponentiated_quadratic.py
128 lines (108 loc) · 5.01 KB
/
exponentiated_quadratic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# Copyright 2018 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""The ExponentiatedQuadratic kernel."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v2 as tf
from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import tensor_util
from tensorflow_probability.python.math.psd_kernels.internal import util
from tensorflow_probability.python.math.psd_kernels.positive_semidefinite_kernel import PositiveSemidefiniteKernel
# Public names exported by `from <this module> import *`.
__all__ = [
    'ExponentiatedQuadratic',
]
class ExponentiatedQuadratic(PositiveSemidefiniteKernel):
  """The ExponentiatedQuadratic kernel.

  Also known as the "squared exponential", "Gaussian" or "radial basis
  function" kernel, it has the form

  ```none
  k(x, y) = amplitude**2 * exp(-||x - y||**2 / (2 * length_scale**2))
  ```

  where `||x - y||` is the Euclidean (L2) distance between `x` and `y`, taken
  over their rightmost `feature_ndims` dimensions.
  """

  def __init__(self,
               amplitude=None,
               length_scale=None,
               feature_ndims=1,
               validate_args=False,
               name='ExponentiatedQuadratic'):
    """Creates an ExponentiatedQuadratic kernel.

    Args:
      amplitude: Floating point `Tensor` scaling the kernel output. Since the
        exponential factor never exceeds 1, the kernel is bounded above by
        `amplitude**2`, reached when `x == y`. Must be broadcastable with
        `length_scale` and with the inputs to `apply` and `matrix`, and must
        be positive. `None` behaves like a value of 1.
        Default value: None
      length_scale: Floating point `Tensor` giving the characteristic length
        against which `||x - y||` is measured. Smaller values make the kernel
        fall off more quickly with distance. Must be broadcastable with
        `amplitude` and with the inputs to `apply` and `matrix`, and must be
        positive. `None` behaves like a value of 1.
        Default value: None
      feature_ndims: Python `int` number of rightmost input dimensions summed
        over when forming the squared distance `||x - y||**2`.
        Default value: 1
      validate_args: Python `bool`. When `True`, `amplitude` and
        `length_scale` are checked to be positive, possibly at some runtime
        cost.
        Default value: False
      name: Python `str` name prefixed to Ops created by this class.
        Default value: 'ExponentiatedQuadratic'
    """
    with tf.name_scope(name):
      # Both parameters are converted using one common dtype.
      common_dtype = util.maybe_get_common_dtype([amplitude, length_scale])
      self._amplitude = tensor_util.convert_nonref_to_tensor(
          amplitude, name='amplitude', dtype=common_dtype)
      self._length_scale = tensor_util.convert_nonref_to_tensor(
          length_scale, name='length_scale', dtype=common_dtype)
      super(ExponentiatedQuadratic, self).__init__(
          feature_ndims,
          dtype=common_dtype,
          name=name,
          validate_args=validate_args)

  @property
  def amplitude(self):
    """Amplitude parameter; `None` acts as an amplitude of 1."""
    return self._amplitude

  @property
  def length_scale(self):
    """Length scale parameter; `None` acts as a length scale of 1."""
    return self._length_scale

  def _batch_shape(self):
    # An unset (None) parameter contributes a scalar shape to the broadcast.
    scalar = tf.TensorShape([])
    amplitude_shape = (
        scalar if self.amplitude is None else self.amplitude.shape)
    length_scale_shape = (
        scalar if self.length_scale is None else self.length_scale.shape)
    return tf.broadcast_static_shape(amplitude_shape, length_scale_shape)

  def _batch_shape_tensor(self):
    def _dynamic_shape_or_scalar(param):
      # An empty shape vector stands in for an unset (None) parameter.
      return [] if param is None else tf.shape(param)

    return tf.broadcast_dynamic_shape(
        _dynamic_shape_or_scalar(self.amplitude),
        _dynamic_shape_or_scalar(self.length_scale))

  def _apply(self, x1, x2, example_ndims=0):
    # Build the kernel in log space:
    #   log k = -0.5 * ||x1 - x2||**2 / length_scale**2 + 2 * log(amplitude)
    # and exponentiate once at the end.
    squared_distance = util.sum_rightmost_ndims_preserving_shape(
        tf.math.squared_difference(x1, x2), self.feature_ndims)
    log_kernel = -0.5 * squared_distance

    if self.length_scale is not None:
      # Pad the parameter's shape with `example_ndims` ones so it broadcasts
      # against the per-example distances.
      scale = util.pad_shape_with_ones(
          tf.convert_to_tensor(self.length_scale), example_ndims)
      log_kernel = log_kernel / scale**2

    if self.amplitude is not None:
      amp = util.pad_shape_with_ones(
          tf.convert_to_tensor(self.amplitude), example_ndims)
      # Adding 2 * log(amplitude) in log space multiplies the result by
      # amplitude**2.
      log_kernel = log_kernel + 2. * tf.math.log(amp)

    return tf.exp(log_kernel)

  def _parameter_control_dependencies(self, is_init):
    if not self.validate_args:
      return []
    named_params = (('amplitude', self.amplitude),
                    ('length_scale', self.length_scale))
    # With is_init=True only non-ref (non-variable) parameters are checked.
    # Otherwise only ref parameters are checked, since their values may
    # change after construction.
    return [
        assert_util.assert_positive(
            value,
            message='{} must be positive.'.format(param_name))
        for param_name, value in named_params
        if value is not None and is_init != tensor_util.is_ref(value)
    ]