Skip to content

Commit

Permalink
Add get_classical_any_host() and tests (#160)
Browse files Browse the repository at this point in the history
  • Loading branch information
san-gh committed Jun 7, 2023
1 parent 187b861 commit af3020a
Show file tree
Hide file tree
Showing 3 changed files with 241 additions and 0 deletions.
141 changes: 141 additions & 0 deletions integration_tests/test_get_classical_any_host.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import unittest
from qunetsim.components import Host, Network
from qunetsim.backends import EQSNBackend
import time

network = Network.get_instance()
hosts = {}

class TestGetClassicalAnyHost(unittest.TestCase):

MAX_WAIT_TIME = 10

@classmethod
def setUpClass(cls):
global network
global hosts
nodes = ["Alice", "Bob"]
backend = EQSNBackend()
network.start(nodes=nodes, backend=backend)
hosts = {'alice': Host('Alice', backend),
'bob': Host('Bob', backend)}
hosts['alice'].add_connection('Bob')
hosts['bob'].add_connection('Alice')
hosts['alice'].start()
hosts['bob'].start()
for h in hosts.values():
network.add_host(h)

def setUp(self) -> None:
hosts['bob']._classical_messages.empty()
hosts['alice']._classical_messages.empty()

@classmethod
def tearDownClass(cls):
global network
global hosts
network.stop(stop_hosts=True)



def test_get_all_with_wait_time(self):
def listen_with_wait_time(s):
time.sleep(2)
msgs = hosts['bob'].get_classical_any_host(seq_num=None, wait=self.MAX_WAIT_TIME)
self.assertEqual([x.content for x in msgs],['3','2','1'])

def send_some_with_delay(s):
hosts['alice'].send_classical(hosts['bob'].host_id,str(1),await_ack=True)
hosts['alice'].send_classical(hosts['bob'].host_id,str(2),await_ack=True)
time.sleep(5)
hosts['alice'].send_classical(hosts['bob'].host_id,str(3),await_ack=True)

t1 = hosts['bob'].run_protocol(listen_with_wait_time)
t2 = hosts['alice'].run_protocol(send_some_with_delay)

t1.join()
t2.join()

def test_get_seq_with_wait_time(self):
def listen_with_wait_time(s):
time.sleep(2)
msg = hosts['bob'].get_classical_any_host(seq_num=2, wait=self.MAX_WAIT_TIME)
self.assertEqual(msg.content,'3')

def send_some_with_delay(s):
hosts['alice'].send_classical(hosts['bob'].host_id,str(1),await_ack=True)
hosts['alice'].send_classical(hosts['bob'].host_id,str(2),await_ack=True)
time.sleep(5)
hosts['alice'].send_classical(hosts['bob'].host_id,str(3),await_ack=True)

t1 = hosts['bob'].run_protocol(listen_with_wait_time)
t2 = hosts['alice'].run_protocol(send_some_with_delay)

t1.join()
t2.join()

def test_get_seq_with_wait_time_none_value(self):
def listen_with_wait_time(s):
time.sleep(2)
msg = hosts['bob'].get_classical_any_host(seq_num=3, wait=self.MAX_WAIT_TIME)
self.assertEqual(msg,None)#seq_num not present

def send_some_with_delay(s):
hosts['alice'].send_classical(hosts['bob'].host_id,str(1),await_ack=True)
hosts['alice'].send_classical(hosts['bob'].host_id,str(2),await_ack=True)
time.sleep(5)
hosts['alice'].send_classical(hosts['bob'].host_id,str(3),await_ack=True)

t1 = hosts['bob'].run_protocol(listen_with_wait_time)
t2 = hosts['alice'].run_protocol(send_some_with_delay)

t1.join()
t2.join()

def test_get_all_with_wait_time_empty_arr(self):
def listen_with_wait_time(s):
msgs = hosts['bob'].get_classical_any_host(None, wait=self.MAX_WAIT_TIME)
self.assertEqual(msgs,[])

def send_after_max_wait(s):
time.sleep(self.MAX_WAIT_TIME+2)
hosts['alice'].send_classical(hosts['bob'].host_id,str(1),await_ack=True)

t1 = hosts['bob'].run_protocol(listen_with_wait_time)
t2 = hosts['alice'].run_protocol(send_after_max_wait)

t1.join()
t2.join()

def test_get_all_no_wait_time(self):
# no msgs yet
self.assertEqual(hosts['bob']._classical_messages.last_msg_added_to_host, None)
rec_msgs = hosts['bob'].get_classical_any_host(None, wait=0)
self.assertEqual(hosts['bob']._classical_messages.last_msg_added_to_host, None)
self.assertEqual(len(rec_msgs), 0)

# with some msgs
for c in range(5):
hosts['alice'].send_classical(hosts['bob'].host_id, str(c), await_ack=True)
rec_msgs = hosts['bob'].get_classical_any_host(None, wait=0)
self.assertEqual(hosts['bob']._classical_messages.last_msg_added_to_host, 'Alice')
self.assertEqual(len(rec_msgs), 5)

def test_get_seq_no_wait_time(self):
# no msgs yet
self.assertEqual(hosts['bob']._classical_messages.last_msg_added_to_host, None)
rec_msg = hosts['bob'].get_classical_any_host(0, wait=0)
self.assertEqual(hosts['bob']._classical_messages.last_msg_added_to_host, None)
self.assertEqual(rec_msg, None)

# with some msgs
for c in range(5):
hosts['alice'].send_classical(hosts['bob'].host_id, str(c), await_ack=True)
rec_msg = hosts['bob'].get_classical_any_host(4, wait=0)
self.assertEqual(hosts['bob']._classical_messages.last_msg_added_to_host, 'Alice')
self.assertEqual(rec_msg.content, '4')

def test_wait_data_type(self):
self.assertRaises(Exception, hosts['bob'].get_classical_any_host, None, "1")


10 changes: 10 additions & 0 deletions qunetsim/components/host.py
Original file line number Diff line number Diff line change
Expand Up @@ -1454,6 +1454,16 @@ def get_classical(self, host_id, seq_num=None, wait=0):
cla = self._classical_messages.get_all_from_sender(host_id, wait)
return sorted(cla, key=lambda x: x.seq_num, reverse=True)

def get_classical_any_host(self, seq_num=None, wait=0):
if not isinstance(wait, float) and not isinstance(wait, int):
raise Exception('wait parameter should be a number')

if seq_num is not None:
return self._classical_messages.get_with_seq_num_from_any_sender(seq_num,wait)

cla = self._classical_messages.get_all_from_any_sender(wait)
return sorted(cla, key=lambda x: x.seq_num, reverse=True)

def get_next_classical(self, sender_id, wait=-1):
"""
Gets the next classical message available from a sender.
Expand Down
90 changes: 90 additions & 0 deletions qunetsim/objects/storage/classical_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@ class ClassicalStorage(object):
GET_NEXT = 1
GET_ALL = 2
GET_WITH_SEQ_NUM = 3
GET_ALL_MSGS_ANY_HOST = 4
GET_WITH_SEQ_NUM_ANY_HOST = 5

def __init__(self):
self._host_to_msg_dict = {}
self._host_to_read_index = {}
self.last_msg_added_to_host = None

# read write lock, for threaded access
self._lock = RWLock()
Expand Down Expand Up @@ -43,6 +46,12 @@ def _check_all_requests(self):
ret = self._get_all_from_sender(args[1])
elif args[2] == ClassicalStorage.GET_WITH_SEQ_NUM:
ret = self._get_with_seq_num_from_sender(args[1], args[3])
elif args[2] == ClassicalStorage.GET_ALL_MSGS_ANY_HOST:
ret = self._get_all_from_sender(self.last_msg_added_to_host) \
if self.last_msg_added_to_host is not None else None
elif args[2] == ClassicalStorage.GET_WITH_SEQ_NUM_ANY_HOST:
ret = self._get_with_seq_num_from_sender(self.last_msg_added_to_host, args[3]) \
if self.last_msg_added_to_host is not None else None
else:
raise ValueError("Internal Error, this request does not exist!")

Expand Down Expand Up @@ -84,6 +93,7 @@ def empty(self):
self._lock.acquire_write()
self._host_to_msg_dict = {}
self._host_to_read_index = {}
self.last_msg_added_to_host = None
self._lock.release_write()

def _add_new_host_id(self, host_id):
Expand Down Expand Up @@ -129,6 +139,7 @@ def add_msg_to_storage(self, message):
if sender_id not in list(self._host_to_msg_dict):
self._add_new_host_id(sender_id)
self._host_to_msg_dict[sender_id].append(message)
self.last_msg_added_to_host = sender_id
self._check_all_requests()
self._lock.release_write()

Expand Down Expand Up @@ -269,6 +280,85 @@ def _get_with_seq_num_from_sender(self, sender_id, seq_num):
return None
msg = self._host_to_msg_dict[sender_id][seq_num]
return msg

def get_all_from_any_sender(self,wait=0):
"""
Get all stored messages from any sender. If delete option is set,
the returned messages are removed from the storage.
Args:
wait (int): Default is 0. The maximum blocking time. -1 to block forever.
Returns:
List of messages of the sender. If there are none, an empty list is
returned.
"""

# Block forever if wait is -1
if wait == -1:
wait = None

self._lock.acquire_write()
msg = None
if self.last_msg_added_to_host is not None:
msg = self.get_all_from_sender(self.last_msg_added_to_host)

if wait == 0:
self._lock.release_write()
return msg if msg is not None else []

q = queue.Queue()
request = [q, None, ClassicalStorage.GET_ALL_MSGS_ANY_HOST]
req_id = self._add_request(request)
self._lock.release_write()

try:
msg = q.get(timeout=wait)
except queue.Empty:
pass


if msg is None:
self._lock.acquire_write()
self._remove_request(req_id)
self._lock.release_write()
return []
return msg

def get_with_seq_num_from_any_sender(self, seq_num, wait=0):
'''
Returns:
Message object, if such a message exists, or none.
'''
# Block forever if wait is -1
if wait == -1:
wait = None


self._lock.acquire_write()
next_msg = None
if self.last_msg_added_to_host is not None:
next_msg = self.get_with_seq_num_from_sender(self.last_msg_added_to_host,seq_num)

if wait == 0:
self._lock.release_write()
return next_msg

q = queue.Queue()
request = [q, None, ClassicalStorage.GET_WITH_SEQ_NUM_ANY_HOST, seq_num]
req_id = self._add_request(request)
self._lock.release_write()

try:
next_msg = q.get(timeout=wait)
except queue.Empty:
pass

if next_msg is None:
self._lock.acquire_write()
self._remove_request(req_id)
self._lock.release_write()
return next_msg

def get_all(self):
"""
Expand Down

0 comments on commit af3020a

Please sign in to comment.