-
-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
test_ndarray_backend.py
97 lines (70 loc) · 2.62 KB
/
test_ndarray_backend.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import unittest
import numpy as np
import numpy.testing as npt
from pymc3.tests import backend_fixtures as bf
from pymc3.backends import base, ndarray
class TestNDArray0dSampling(bf.SamplingTestCase):
    """Run the shared sampling tests against ``ndarray.NDArray`` with shape ``()``.

    All test methods come from ``bf.SamplingTestCase``; this subclass only
    supplies the configuration attributes it reads.
    """

    # Shape of the traced variable used by the fixture.
    shape = ()
    # Backend name left unset (presumably no storage location is needed for
    # this backend -- confirm against backend_fixtures).
    name = None
    backend = ndarray.NDArray
class TestNDArray1dSampling(bf.SamplingTestCase):
    """Run the shared sampling tests against ``ndarray.NDArray`` with shape ``2``.

    All test methods come from ``bf.SamplingTestCase``; this subclass only
    supplies the configuration attributes it reads.
    """

    # Shape of the traced variable used by the fixture.
    shape = 2
    # Backend name left unset (presumably no storage location is needed for
    # this backend -- confirm against backend_fixtures).
    name = None
    backend = ndarray.NDArray
class TestNDArray2dSampling(bf.SamplingTestCase):
    """Run the shared sampling tests against ``ndarray.NDArray`` with shape ``(2, 3)``.

    All test methods come from ``bf.SamplingTestCase``; this subclass only
    supplies the configuration attributes it reads.
    """

    # Shape of the traced variable used by the fixture.
    shape = (2, 3)
    # Backend name left unset (presumably no storage location is needed for
    # this backend -- confirm against backend_fixtures).
    name = None
    backend = ndarray.NDArray
class TestNDArray0dSelection(bf.SelectionTestCase):
    """Run the shared selection tests against ``ndarray.NDArray`` with shape ``()``.

    All test methods come from ``bf.SelectionTestCase``; this subclass only
    supplies the configuration attributes it reads.
    """

    # Shape of the traced variable used by the fixture.
    shape = ()
    # Backend name left unset (presumably no storage location is needed for
    # this backend -- confirm against backend_fixtures).
    name = None
    backend = ndarray.NDArray
class TestNDArray1dSelection(bf.SelectionTestCase):
    """Run the shared selection tests against ``ndarray.NDArray`` with shape ``2``.

    All test methods come from ``bf.SelectionTestCase``; this subclass only
    supplies the configuration attributes it reads.
    """

    # Shape of the traced variable used by the fixture.
    shape = 2
    # Backend name left unset (presumably no storage location is needed for
    # this backend -- confirm against backend_fixtures).
    name = None
    backend = ndarray.NDArray
class TestNDArray2dSelection(bf.SelectionTestCase):
    """Run the shared selection tests against ``ndarray.NDArray`` with shape ``(2, 3)``.

    All test methods come from ``bf.SelectionTestCase``; this subclass only
    supplies the configuration attributes it reads.
    """

    # Shape of the traced variable used by the fixture.
    shape = (2, 3)
    # Backend name left unset (presumably no storage location is needed for
    # this backend -- confirm against backend_fixtures).
    name = None
    backend = ndarray.NDArray
class TestMultiTrace(bf.ModelBackendSetupTestCase):
    """Check that ``base.MultiTrace`` and ``base.merge_traces`` reject duplicates.

    ``setUp`` runs the parent fixture twice and keeps each resulting
    ``self.strace`` as ``strace0`` and ``strace1``. Combining the two is
    expected to raise ``ValueError`` (per the test names, because the traces
    are not unique -- presumably they share a chain id; confirm in
    backend_fixtures).
    """

    backend = ndarray.NDArray
    shape = ()
    name = None

    def setUp(self):
        """Build two independent traces by running the parent setUp twice."""
        # Explicit two-argument super() keeps this compatible with Python 2.
        parent = super(TestMultiTrace, self)
        parent.setUp()
        # The second parent.setUp() rebinds self.strace, so save the first now.
        self.strace0 = self.strace
        parent.setUp()
        self.strace1 = self.strace

    def test_multitrace_nonunique(self):
        """Wrapping both traces in one MultiTrace must raise ValueError."""
        with self.assertRaises(ValueError):
            base.MultiTrace([self.strace0, self.strace1])

    def test_merge_traces_nonunique(self):
        """Merging two single-trace MultiTraces over the duplicates must raise."""
        # Each wrapper holds one trace, so building them is expected to succeed;
        # only the merge is inside the assertRaises context.
        wrapped = [base.MultiTrace([trace])
                   for trace in (self.strace0, self.strace1)]
        with self.assertRaises(ValueError):
            base.merge_traces(wrapped)
class TestSqueezeCat(unittest.TestCase):
    """Exercise ``base._squeeze_cat`` for each combine/squeeze flag pairing.

    The second positional argument is the "combine" flag and the third is the
    "squeeze" flag, as indicated by the test names.
    """

    def setUp(self):
        # Two distinct integer arrays so concatenation order is observable.
        self.x = np.arange(10)
        self.y = np.arange(10, 20)

    def _check(self, arrays, combine, squeeze, expected):
        """Assert ``base._squeeze_cat(arrays, combine, squeeze)`` equals ``expected``."""
        npt.assert_equal(base._squeeze_cat(arrays, combine, squeeze), expected)

    def test_combine_false_squeeze_false(self):
        # Neither flag set: the input list is returned as a list of arrays.
        self._check([self.x, self.y], False, False, [self.x, self.y])

    def test_combine_true_squeeze_false(self):
        # Combine only: a one-element list holding the concatenation.
        self._check([self.x, self.y], True, False,
                    [np.concatenate([self.x, self.y])])

    def test_combine_false_squeeze_true_more_than_one_item(self):
        # Squeeze has nothing to collapse when there are two items.
        self._check([self.x, self.y], False, True, [self.x, self.y])

    def test_combine_false_squeeze_true_one_item(self):
        # Squeeze on a single item yields that item rather than a list.
        self._check([self.x], False, True, self.x)

    def test_combine_true_squeeze_true(self):
        # Both flags: the bare concatenated array.
        self._check([self.x, self.y], True, True,
                    np.concatenate([self.x, self.y]))