Skip to content

Commit fa74236

Browse files
Add tests for NMT base and NMT master (#504)
1 parent c4560da commit fa74236

File tree

1 file changed

+122
-4
lines changed

1 file changed

+122
-4
lines changed

test/test_nmt.py

Lines changed: 122 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,139 @@
11
import time
22
import unittest
33

4+
import can
45
import canopen
6+
from canopen.nmt import COMMAND_TO_STATE, NMT_STATES, NMT_COMMANDS, NmtError
57
from .util import SAMPLE_EDS
68

79

10+
class TestNmtBase(unittest.TestCase):
11+
def setUp(self):
12+
node_id = 2
13+
self.node_id = node_id
14+
self.nmt = canopen.nmt.NmtBase(node_id)
15+
16+
def test_send_command(self):
17+
dataset = (
18+
"OPERATIONAL",
19+
"PRE-OPERATIONAL",
20+
"SLEEP",
21+
"STANDBY",
22+
"STOPPED",
23+
)
24+
for cmd in dataset:
25+
with self.subTest(cmd=cmd):
26+
code = NMT_COMMANDS[cmd]
27+
self.nmt.send_command(code)
28+
expected = NMT_STATES[COMMAND_TO_STATE[code]]
29+
self.assertEqual(self.nmt.state, expected)
30+
31+
def test_state_getset(self):
32+
for state in NMT_STATES.values():
33+
with self.subTest(state=state):
34+
self.nmt.state = state
35+
self.assertEqual(self.nmt.state, state)
36+
37+
def test_state_set_invalid(self):
38+
with self.assertRaisesRegex(ValueError, "INVALID"):
39+
self.nmt.state = "INVALID"
40+
41+
42+
class TestNmtMaster(unittest.TestCase):
43+
NODE_ID = 2
44+
COB_ID = 0x700 + NODE_ID
45+
PERIOD = 0.01
46+
TIMEOUT = PERIOD * 2
47+
48+
def setUp(self):
49+
bus = can.ThreadSafeBus(
50+
interface="virtual",
51+
channel="test",
52+
receive_own_messages=True,
53+
)
54+
net = canopen.Network(bus)
55+
net.connect()
56+
with self.assertLogs():
57+
node = net.add_node(self.NODE_ID, SAMPLE_EDS)
58+
59+
self.bus = bus
60+
self.net = net
61+
self.node = node
62+
63+
def tearDown(self):
64+
self.net.disconnect()
65+
66+
def test_nmt_master_no_heartbeat(self):
67+
with self.assertRaisesRegex(NmtError, "heartbeat"):
68+
self.node.nmt.wait_for_heartbeat(self.TIMEOUT)
69+
with self.assertRaisesRegex(NmtError, "boot-up"):
70+
self.node.nmt.wait_for_bootup(self.TIMEOUT)
71+
72+
def test_nmt_master_on_heartbeat(self):
73+
# Skip the special INITIALISING case.
74+
for code in [st for st in NMT_STATES if st != 0]:
75+
with self.subTest(code=code):
76+
task = self.net.send_periodic(self.COB_ID, [code], self.PERIOD)
77+
try:
78+
actual = self.node.nmt.wait_for_heartbeat(self.TIMEOUT)
79+
finally:
80+
task.stop()
81+
expected = NMT_STATES[code]
82+
self.assertEqual(actual, expected)
83+
84+
def test_nmt_master_on_heartbeat_initialising(self):
85+
task = self.net.send_periodic(self.COB_ID, [0], self.PERIOD)
86+
self.addCleanup(task.stop)
87+
self.node.nmt.wait_for_bootup(self.TIMEOUT)
88+
state = self.node.nmt.wait_for_heartbeat(self.TIMEOUT)
89+
self.assertEqual(state, "PRE-OPERATIONAL")
90+
91+
@unittest.expectedFailure
92+
def test_nmt_master_on_heartbeat_unknown_state(self):
93+
task = self.net.send_periodic(self.COB_ID, [0xcb], self.PERIOD)
94+
self.addCleanup(task.stop)
95+
state = self.node.nmt.wait_for_heartbeat(self.TIMEOUT)
96+
# Expect the high bit to be masked out, and and unknown state string to
97+
# be returned.
98+
self.assertEqual(state, "UNKNOWN STATE '75'")
99+
100+
def test_nmt_master_add_heartbeat_callback(self):
101+
from threading import Event
102+
event = Event()
103+
state = None
104+
def hook(st):
105+
nonlocal state
106+
state = st
107+
event.set()
108+
self.node.nmt.add_heartbeat_callback(hook)
109+
self.net.send_message(self.COB_ID, bytes([127]))
110+
self.assertTrue(event.wait(self.TIMEOUT))
111+
self.assertEqual(state, 127)
112+
113+
def test_nmt_master_node_guarding(self):
114+
self.node.nmt.start_node_guarding(self.PERIOD)
115+
msg = self.bus.recv(self.TIMEOUT)
116+
self.assertIsNotNone(msg)
117+
self.assertEqual(msg.arbitration_id, self.COB_ID)
118+
self.assertEqual(msg.dlc, 0)
119+
120+
self.node.nmt.stop_node_guarding()
121+
self.assertIsNone(self.bus.recv(self.TIMEOUT))
122+
123+
8124
class TestNmtSlave(unittest.TestCase):
9125
def setUp(self):
10126
self.network1 = canopen.Network()
11127
self.network1.connect("test", interface="virtual")
12-
self.remote_node = self.network1.add_node(2, SAMPLE_EDS)
128+
with self.assertLogs():
129+
self.remote_node = self.network1.add_node(2, SAMPLE_EDS)
13130

14131
self.network2 = canopen.Network()
15132
self.network2.connect("test", interface="virtual")
16-
self.local_node = self.network2.create_node(2, SAMPLE_EDS)
17-
self.remote_node2 = self.network1.add_node(3, SAMPLE_EDS)
18-
self.local_node2 = self.network2.create_node(3, SAMPLE_EDS)
133+
with self.assertLogs():
134+
self.local_node = self.network2.create_node(2, SAMPLE_EDS)
135+
self.remote_node2 = self.network1.add_node(3, SAMPLE_EDS)
136+
self.local_node2 = self.network2.create_node(3, SAMPLE_EDS)
19137

20138
def tearDown(self):
21139
self.network1.disconnect()

0 commit comments

Comments
 (0)