-
Notifications
You must be signed in to change notification settings - Fork 56
/
test_matrix.py
121 lines (91 loc) · 3.13 KB
/
test_matrix.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""Test cases for ``Kernel.matrix``"""
from typing import Callable, Optional
import numpy as np
import pytest
import probnum as pn
from probnum.typing import ShapeType
@pytest.fixture(name="kernmat")
def fixture_kernmat(
    kernel: pn.randprocs.kernels.Kernel, x0: np.ndarray, x1: Optional[np.ndarray]
) -> np.ndarray:
    """Kernel matrix of ``kernel`` computed via ``Kernel.matrix(x0, x1)``.

    When ``x1`` is ``None`` and ``x0`` holds 100 or more points (product of all
    but the last axis), the requesting test is skipped instead.
    """
    # NOTE(review): skip reason is "Runs too long"; presumably the comparison
    # against the naive reference is too slow for large self-kernel matrices.
    too_expensive = x1 is None and np.prod(x0.shape[:-1]) >= 100
    if too_expensive:
        pytest.skip("Runs too long")

    return kernel.matrix(x0, x1)
@pytest.fixture(name="kernmat_naive")
def fixture_kernmat_naive(
    kernel_call_naive: Callable[[np.ndarray, Optional[np.ndarray]], np.ndarray],
    x0: np.ndarray,
    x1: Optional[np.ndarray],
) -> np.ndarray:
    """Reference kernel matrix built by broadcasting a naive kernel call.

    ``kernel_call_naive`` receives ``x0[:, None, :]`` and ``x1[None, :, :]``, so
    every point of the first set is paired with every point of the second set.
    If ``x1`` is ``None``, ``x0`` is paired with itself, and the requesting test
    is skipped when ``x0`` holds 100 or more points.
    """
    if x1 is None:
        # NOTE(review): presumably the naive self-kernel evaluation is too slow
        # for large inputs -- hence the "Runs too long" skip.
        if np.prod(x0.shape[:-1]) >= 100:
            pytest.skip("Runs too long")
        x1 = x0

    # Assumes x0/x1 are (num_points, input_dim) arrays -- TODO confirm in conftest.
    rows = x0[:, None, :]
    cols = x1[None, :, :]
    return kernel_call_naive(x0=rows, x1=cols)
def test_type(kernmat: np.ndarray):
"""Check whether a kernel evaluates to a numpy scalar or array."""
assert isinstance(kernmat, (np.ndarray, np.number))
def test_shape(
    kernel: pn.randprocs.kernels.Kernel,
    x0: np.ndarray,
    x1: Optional[np.ndarray],
    kernmat: np.ndarray,
    kernmat_naive: np.ndarray,
):
    """Test that ``Kernel.matrix`` has the same shape as the naive reference.

    On failure, the message names the kernel type, the input shapes, and the
    expected versus actual result shapes.
    """
    # Leading space keeps "x0.shape=(...) and x1.shape=(...)" readable.
    x1_info = "" if x1 is None else f" and x1.shape={x1.shape}"
    assert kernmat.shape == kernmat_naive.shape, (
        f"Kernel {type(kernel)} does not have the right shape if evaluated at inputs "
        f"with x0.shape={x0.shape}{x1_info}: expected {kernmat_naive.shape}, "
        f"got {kernmat.shape}."
    )
def test_kernel_matrix_against_naive(
kernmat: np.ndarray,
kernmat_naive: np.ndarray,
):
"""Test the computation of the kernel matrix against a naive computation."""
np.testing.assert_allclose(
kernmat,
kernmat_naive,
rtol=10 ** -12,
atol=10 ** -12,
)
@pytest.mark.parametrize(
    "x0_shape,x1_shape",
    [
        ((2, 5), (3, 5)),
        ((4, 4), (4, 2)),
    ],
)
def test_invalid_shape(
    kernel: pn.randprocs.kernels.Kernel,
    x0_shape: ShapeType,
    x1_shape: ShapeType,
):
    """Test that ``Kernel.matrix`` raises a ``ValueError`` for invalid input shapes.

    Each parametrized batch shape has two leading dimensions, so the arrays
    passed to ``Kernel.matrix`` are three-dimensional. Within each pair, the
    ``x0`` and ``x1`` batch shapes also cannot be broadcast against each other.
    Both the single-argument and the two-argument call must raise.
    """
    input_dim = kernel.input_dim

    with pytest.raises(ValueError):
        kernel.matrix(np.zeros(x0_shape + (input_dim,)))

    with pytest.raises(ValueError):
        kernel.matrix(
            np.zeros(x0_shape + (input_dim,)),
            np.ones(x1_shape + (input_dim,)),
        )
@pytest.mark.parametrize("shape", [(), (1,), (10,)])
def test_wrong_input_dimension(kernel: pn.randprocs.kernels.Kernel, shape: ShapeType):
    """Check that a trailing axis of length ``input_dim + 1`` is rejected.

    ``Kernel.matrix`` must raise a ``ValueError`` in three cases: the sole
    argument is malformed, only ``x0`` is malformed, or only ``x1`` is malformed.
    """
    valid_shape = shape + (kernel.input_dim,)
    invalid_shape = shape + (kernel.input_dim + 1,)

    with pytest.raises(ValueError):
        kernel.matrix(np.zeros(invalid_shape))

    with pytest.raises(ValueError):
        kernel.matrix(np.ones(invalid_shape), np.zeros(valid_shape))

    with pytest.raises(ValueError):
        kernel.matrix(np.ones(valid_shape), np.zeros(invalid_shape))