/
test_theanof.py
50 lines (42 loc) · 1.41 KB
/
test_theanof.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import itertools
import unittest
import numpy as np
import theano
from ..theanof import DataGenerator, GeneratorOp, generator
def integers():
    """Yield 0, 1, 2, ... forever, each as an ``np.float32`` scalar."""
    for value in itertools.count():
        yield np.float32(value)
def integers_ndim(ndim):
    """Yield arrays of shape ``(2,) * ndim`` filled with 0, 1, 2, ... forever.

    The k-th yielded value (0-based) is ``np.ones((2,) * ndim) * k``; for
    ``ndim == 0`` that is a 0-d value rather than an array of length 2.
    """
    # The shape never changes, so build the tuple once per generator.
    shape = (2,) * ndim
    for step in itertools.count():
        yield np.ones(shape) * step
class TestGenerator(unittest.TestCase):
    """Tests for ``DataGenerator``/``GeneratorOp`` and the ``generator`` helper."""

    def test_basic(self):
        """Successive calls of the compiled function return successive values."""
        # Named ``data_gen`` (not ``generator``) so the ``generator`` helper
        # imported at module level is not shadowed inside this test.
        data_gen = DataGenerator(integers())
        gop = GeneratorOp(data_gen)()
        # The op's test value matches the generator's first element.
        self.assertEqual(gop.tag.test_value, np.float32(0))
        f = theano.function([], gop)
        self.assertEqual(f(), np.float32(0))
        self.assertEqual(f(), np.float32(1))
        # Advance past values 2..99 without checking each one.
        for _ in range(2, 100):
            f()
        self.assertEqual(f(), np.float32(100))

    def test_ndim(self):
        """Generator outputs of rank 0 through 9 are returned unchanged."""
        for ndim in range(10):
            label = "ndim=%d" % ndim
            # Reference values come from an independent generator instance.
            expected = list(itertools.islice(integers_ndim(ndim), 0, 2))
            data_gen = DataGenerator(integers_ndim(ndim))
            gop = GeneratorOp(data_gen)()
            f = theano.function([], gop)
            self.assertEqual(ndim, expected[0].ndim, msg=label)
            np.testing.assert_equal(f(), expected[0], err_msg=label)
            np.testing.assert_equal(f(), expected[1], err_msg=label)

    def test_cloning_available(self):
        """A generator-backed variable can be replaced via ``theano.clone``."""
        gop = generator(integers())
        res = gop ** 2
        # Substitute a shared variable holding 10 for the generator output,
        # so the cloned graph computes 10 ** 2.
        shared = theano.shared(np.float32(10))
        res1 = theano.clone(res, {gop: shared})
        f = theano.function([], res1)
        self.assertEqual(f(), np.float32(100))