Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions autotest/utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3251,11 +3251,11 @@ def gpr_zdt1_ppw():


if __name__ == "__main__":
geostat_draws_test('.')
#geostat_draws_test('.')
#fac2real_wrapped_test('.')
#maha_pdc_test('.')
#ppu_geostats_test(".")
#pypestworker_test()
pypestworker_test()
#gpr_zdt1_test()
#gpr_compare_invest()
#gpr_constr_test()
Expand Down
65 changes: 40 additions & 25 deletions pyemu/utils/os_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from ..pyemu_warnings import PyemuWarning
from ..pst import pst_handler
from ..logger import Logger

ext = ""
bin_path = os.path.join("..", "bin")
Expand Down Expand Up @@ -783,6 +784,9 @@ def __init__(self, pst, host, port, timeout=0.25,verbose=True, socket_timeout=No
self.socket_timeout = socket_timeout
self.par_values = None
self.max_reconnect_attempts = 10
self.logger_filename = "pypestworker_{0}.txt".format(datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f"))
self.logger = Logger(self.logger_filename,echo=verbose)
self.message("PyPestWorker starting with timeout:{0} and socket_timeout:{1}".format(self.timeout, self.socket_timeout))
self._process_pst()
self.connect()
self._lock = threading.Lock()
Expand All @@ -792,6 +796,7 @@ def __init__(self, pst, host, port, timeout=0.25,verbose=True, socket_timeout=No


def _process_pst(self):
self.message("processing control file")
if isinstance(self._pst_arg,str):
self._pst = pst_handler.Pst(self._pst_arg)
elif isinstance(self._pst_arg,pst_handler.Pst):
Expand All @@ -802,19 +807,19 @@ def _process_pst(self):


def connect(self,is_reconnect=False):
self.message("trying to connect to {0}:{1}...".format(self.host,self.port),echo=True)
self.message("trying to connect to {0}:{1}...".format(self.host,self.port))
self.s = None
c = 0
while True:
try:
time.sleep(self.timeout)
c += 1
if is_reconnect and c > self.max_reconnect_attempts:
print("max reconnect attempts reached...")
self.message("max reconnect attempts reached...",True)
return False
self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.s.connect((self.host, self.port))
self.message("connected to {0}:{1}".format(self.host,self.port),echo=True)
self.message("connected to {0}:{1}".format(self.host,self.port))
break

except ConnectionRefusedError:
Expand All @@ -827,24 +832,28 @@ def connect(self,is_reconnect=False):


def message(self,msg,echo=False):
self.logger.statement(msg)
if self.verbose or echo:
print(str(datetime.now())+" : "+msg)


def recv(self,dtype=None):
n = self.net_pack.recv(self.s,dtype=dtype)
if n > 0:
self.message("recv'd message type:{0}".format(NetPack.netpack_type[self.net_pack.mtype]))
self.message("recv'd message type:{0}, group:{1}, run_id:{2}, desc:{3}"\
.format(NetPack.netpack_type[self.net_pack.mtype],
self.net_pack.group, self.net_pack.runid,self.net_pack.desc))
return n


def send(self,mtype,group,runid,desc="",data=0):
try:
self.net_pack.send(self.s,mtype,group,runid,desc,data)
except Exception as e:
print("WARNING: error sending message:{0}".format(str(e)))
self.message("WARNING: error sending message:{0}".format(str(e)), True)
return False
self.message("sent message type:{0}".format(NetPack.netpack_type[mtype]))
self.message("sent message type:{0}, group: {1}, run_id:{2}, desc:{3}".\
format(NetPack.netpack_type[mtype], group, runid, desc))
return True

def listen(self,lock=None,send_lock=None):
Expand All @@ -855,18 +864,18 @@ def listen(self,lock=None,send_lock=None):
try:
n = self.recv()
except Exception as e:
print("WARNING: recv exception:"+str(e)+"...trying to reconnect...")
self.message("WARNING: recv exception:"+str(e)+"...trying to reconnect...", True)
success = self.connect(is_reconnect=True)
if not success:
print("...exiting")
self.message("...exiting")
time.sleep(self.timeout)
# set the teminate flag so that the get_pars() look will exit
self._lock.acquire()
self.net_pack.mtype = 14
self._lock.release()
return
else:
print("...reconnected successfully...")
self.message("...reconnected successfully...", True)
continue

if n > 0:
Expand All @@ -882,26 +891,34 @@ def listen(self,lock=None,send_lock=None):
elif self.net_pack.mtype == 4:
if self._send_lock is not None:
self._send_lock.acquire()
self.send(mtype=5, group=self.net_pack.group,
success = self.send(mtype=5, group=self.net_pack.group,
runid=self.net_pack.runid,
desc="sending cwd", data=os.getcwd())
if self._send_lock is not None:
self._send_lock.release()
if not success:
self.message("failed cwd send...trying to reconnect...", True)
success = self.connect(is_reconnect=True)
if not success:
self.message("...exiting", True)
time.sleep(self.timeout)
return
else:
self.message("reconnect successfully...", True)
continue

elif self.net_pack.mtype == 8:
self.par_names = self.net_pack.data_pak
diff = set(self.par_names).symmetric_difference(set(self._pst.par_names))
if len(diff) > 0:
print("WARNING: pst par names != master par names")
self.message("WARNING: the following par names are not common\n"+\
" between the control file and the master:{0}".format(','.join(diff)))
" between the control file and the master:{0}".format(','.join(diff)), True)
elif self.net_pack.mtype == 9:
self.obs_names = self.net_pack.data_pak
diff = set(self.obs_names).symmetric_difference(set(self._pst.obs_names))
if len(diff) > 0:
print("WARNING: pst obs names != master obs names")
self.message("WARNING: the following obs names are not common\n"+\
" between the control file and the master:{0}".format(','.join(diff)))
" between the control file and the master:{0}".format(','.join(diff)), True)

elif self.net_pack.mtype == 6:
if self._send_lock is not None:
Expand All @@ -912,14 +929,14 @@ def listen(self,lock=None,send_lock=None):
if self._send_lock is not None:
self._send_lock.release()
if not success:
print("...trying to reconnect...")
self.message("failed linpack send...trying to reconnect...", True)
success = self.connect(is_reconnect=True)
if not success:
print("...exiting")
self.message("...exiting",True)
time.sleep(self.timeout)
return
else:
print("reconnect successfully...")
self.message("reconnect successfully...", True)
continue

elif self.net_pack.mtype == 15:
Expand All @@ -931,25 +948,23 @@ def listen(self,lock=None,send_lock=None):
if self._send_lock is not None:
self._send_lock.release()
if not success:
print("...trying to reconnect...")
self.message("failed ping back...trying to reconnect...", True)
success = self.connect(is_reconnect=True)
if not success:
print("...exiting")
self.message("...exiting",True)
time.sleep(self.timeout)
return
else:
print("reconnect successfully...")
self.message("reconnect successfully...", True)
continue
elif self.net_pack.mtype == 14:
#print("recv'd terminate signal")
self.message("recv'd terminate signal")
self.message("recv'd terminate signal", True)
return
elif self.net_pack.mtype == 16:
print("master is requesting run kill...")
self.message("master is requesting run kill...")
self.message("master is requesting run kill...", True)

else:
print("WARNING: unsupported request received: {0}".format(NetPack.netpack_type[self.net_pack.mtype]))
self.message("WARNING: unsupported request received: {0}".format(NetPack.netpack_type[self.net_pack.mtype]), True)


def get_parameters(self):
Expand Down
Loading