-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_syn.py
91 lines (78 loc) · 3.01 KB
/
test_syn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import numpy as np
from numpy import testing as t
from unittest import TestCase
from pp.mdp import GridWorldMDP
import itertools
from syn import *
class TestGenPredictGoal(TestCase):
def test_no_crash(self):
g = GridWorldMDP(15, 15)
goals = [g.coor_to_state(9, 9), g.coor_to_state(1, 1),
g.coor_to_state(3, 3)]
data = gen_predict_goal(g, goals=goals, k=5)
self.assertEqual(data.N, len(data.Y))
class TestGenPredictPolicy(TestCase):
def test_no_crash(self):
g = GridWorldMDP(15, 15)
goals = [g.coor_to_state(9, 9), g.coor_to_state(1, 1),
g.coor_to_state(3, 3)]
data = gen_predict_policy(g, goals=goals)
self.assertEqual(data.N, len(data.Y))
def test_no_crash2(self):
g = GridWorldMDP(15, 15)
goals = [g.coor_to_state(9, 9), g.coor_to_state(1, 1),
g.coor_to_state(3, 3)]
data = gen_predict_policy(g, goals=goals, samples=30)
self.assertEqual(data.N, len(data.Y))
for y in data.Y:
self.assertTrue(0 <= y < g.A)
for z in data.Z:
self.assertTrue(0 <= z < len(goals))
class TestGenPredictPolicy2(TestCase):
def test_no_crash(self):
g = GridWorldMDP(15, 15)
goals = [g.coor_to_state(9, 9), g.coor_to_state(1, 1),
g.coor_to_state(3, 3)]
data = gen_predict_policy2(g, goals=goals)
self.assertEqual(data.N, len(data.Y))
class TestGenPredictTraj(TestCase):
def test_no_crash(self):
g = GridWorldMDP(15, 15)
goals = [g.coor_to_state(9, 9), g.coor_to_state(1, 1),
g.coor_to_state(3, 3)]
data = gen_predict_traj(g, goals=goals, k=3, l=3)
self.assertEqual(data.N, len(data.Y))
class TestGenPredictActions(TestCase):
def test_no_crash(self):
g = GridWorldMDP(15, 15)
goals = [g.coor_to_state(9, 9), g.coor_to_state(1, 1),
g.coor_to_state(3, 3)]
data = gen_predict_actions(g, goals=goals, k=3, l=3)
self.assertEqual(data.N, len(data.Y))
class TestData(TestCase):
def test_batch_one(self):
data = Data([1], [2])
x, y = data.get_batch(1)
self.assertEqual(len(x), len(y))
self.assertEqual(len(x), 1)
self.assertEqual(x, [1])
self.assertEqual(y, [2])
def test_batch_twelve(self):
data = Data(range(20), np.arange(20))
for i in range(10):
x, y = data.get_batch(12)
self.assertEqual(len(x), len(y))
self.assertEqual(len(x), 12)
t.assert_equal(x, y)
def test_batch_twelve_aux(self):
data = Data(range(20), np.arange(20), np.arange(20))
for i in range(10):
x, y, z = data.get_batch(12, enable_aux=True)
self.assertEqual(len(x), len(y))
self.assertEqual(len(x), 12)
t.assert_equal(x, y)
t.assert_equal(y, z)
class TestPuddlesWorld(TestCase):
def test_puddles(self):
g = puddles_world(10, p=0.5)
import pdb; pdb.set_trace()