Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions tensorflowonspark/TFSparkNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,15 @@ def _mapfn(iter):
tb_pid = 0
tb_port = 0
if tensorboard and job_name == tb_job_name and task_index == 0:
tb_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
tb_sock.bind(('', 0))
tb_port = tb_sock.getsockname()[1]
tb_sock.close()
if 'TENSORBOARD_PORT' in os.environ:
# use port defined in env var
tb_port = int(os.environ['TENSORBOARD_PORT'])
else:
# otherwise, find a free port
tb_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
tb_sock.bind(('', 0))
tb_port = tb_sock.getsockname()[1]
tb_sock.close()
logdir = log_dir if log_dir else "tensorboard_%d" % executor_id

# search for tensorboard in python/bin, PATH, and PYTHONPATH
Expand Down Expand Up @@ -250,11 +255,15 @@ def _mapfn(iter):

# if not already done, register everything we need to set up the cluster
if node_meta is None:
# first, find a free port for TF
tmp_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
tmp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
tmp_sock.bind(('', port))
port = tmp_sock.getsockname()[1]
if 'TENSORFLOW_PORT' in os.environ:
# use port defined in env var
port = int(os.environ['TENSORFLOW_PORT'])
else:
# otherwise, find a free port
tmp_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
tmp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
tmp_sock.bind(('', port))
port = tmp_sock.getsockname()[1]

node_meta = {
'executor_id': executor_id,
Expand Down
23 changes: 12 additions & 11 deletions test/test_reservation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
else:
import mock


class ReservationTest(unittest.TestCase):
def test_reservation_class(self):
"""Test core reservation class, expecting 2 reservations"""
Expand Down Expand Up @@ -56,23 +57,23 @@ def test_reservation_server(self):
self.assertEqual(s.done, True)

def test_reservation_enviroment_exists_get_server_ip_return_environment_value(self):
tfso_server = Server(5)
with mock.patch.dict(os.environ,{'TFOS_SERVER_HOST':'my_host_ip'}):
assert tfso_server.get_server_ip() == "my_host_ip"
tfos_server = Server(5)
with mock.patch.dict(os.environ, {'TFOS_SERVER_HOST': 'my_host_ip'}):
assert tfos_server.get_server_ip() == "my_host_ip"

def test_reservation_enviroment_not_exists_get_server_ip_return_actual_host_ip(self):
tfso_server = Server(5)
assert tfso_server.get_server_ip() == util.get_ip_address()
tfos_server = Server(5)
assert tfos_server.get_server_ip() == util.get_ip_address()

def test_reservation_enviroment_exists_start_listening_socket_return_socket_listening_to_environment_port_value(self):
tfso_server = Server(1)
tfos_server = Server(1)
with mock.patch.dict(os.environ, {'TFOS_SERVER_PORT': '9999'}):
assert tfso_server.start_listening_socket().getsockname()[1] == 9999
assert tfos_server.start_listening_socket().getsockname()[1] == 9999

def test_reservation_enviroment_not_exists_start_listening_socket_return_socket(self):
tfso_server = Server(1)
print(tfso_server.start_listening_socket().getsockname()[1])
assert type(tfso_server.start_listening_socket().getsockname()[1]) == int
tfos_server = Server(1)
print(tfos_server.start_listening_socket().getsockname()[1])
assert type(tfos_server.start_listening_socket().getsockname()[1]) == int

def test_reservation_server_multi(self):
"""Test reservation server, expecting multiple reservations"""
Expand Down Expand Up @@ -111,4 +112,4 @@ def reserve(num):


if __name__ == '__main__':
unittest.main()
unittest.main()