diff --git a/pylammpsmpi/mpi/lmpmpi.py b/pylammpsmpi/mpi/lmpmpi.py index b933d35d..8d06f730 100644 --- a/pylammpsmpi/mpi/lmpmpi.py +++ b/pylammpsmpi/mpi/lmpmpi.py @@ -499,6 +499,9 @@ def _gather_data_from_all_processors(data): input_dict = None input_dict = MPI.COMM_WORLD.bcast(input_dict, root=0) if input_dict["c"] == "close": + if MPI.COMM_WORLD.rank == 0: + socket.close() + context.term() job.close() break output = select_cmd(input_dict["c"])(input_dict["d"]) diff --git a/pylammpsmpi/utils/lammps.py b/pylammpsmpi/utils/lammps.py index 437cdf4c..2f61f9b5 100644 --- a/pylammpsmpi/utils/lammps.py +++ b/pylammpsmpi/utils/lammps.py @@ -30,13 +30,14 @@ def __init__( self._oversubscribe = oversubscribe self._cmdargs = cmdargs self._socket = None + self._context = None def start_process(self): executable = os.path.join( os.path.dirname(os.path.abspath(__file__)), "../mpi", "lmpmpi.py" ) - context = zmq.Context() - self._socket = context.socket(zmq.PAIR) + self._context = zmq.Context() + self._socket = self._context.socket(zmq.PAIR) port_selected = self._socket.bind_to_random_port("tcp://*") cmds = ["mpiexec"] if self._oversubscribe: @@ -677,10 +678,16 @@ def close(self): self._send(command="close") try: self._process.kill() + self._process.stdout.close() + self._process.stdin.close() + self._process.wait() + self._socket.close() + self._context.term() except AttributeError: pass self._process = None self._socket = None + self._context = None # TODO def __del__(self):