Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 48 additions & 44 deletions pylammpsmpi/mpi/lmpmpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,8 @@
# taken directly from atom.cpp -> extract()
}

# Lammps executable
args = ["-screen", "none"]
if len(sys.argv) > 3:
args.extend(sys.argv[3:])
job = lammps(cmdargs=args)


def extract_compute(funct_args):
def extract_compute(job, funct_args):
def convert_data(val, type, length, width):
data = []
if type == 2:
Expand Down Expand Up @@ -88,42 +82,42 @@ def convert_data(val, type, length, width):
raise ValueError("Local style is currently not supported")


def get_version(funct_args):
def get_version(job, funct_args):
if MPI.COMM_WORLD.rank == 0:
return job.version()


def get_file(funct_args):
def get_file(job, funct_args):
job.file(*funct_args)
return 1


def commands_list(funct_args):
def commands_list(job, funct_args):
job.commands_list(*funct_args)
return 1


def commands_string(funct_args):
def commands_string(job, funct_args):
job.commands_string(*funct_args)
return 1


def extract_setting(funct_args):
def extract_setting(job, funct_args):
if MPI.COMM_WORLD.rank == 0:
return job.extract_setting(*funct_args)


def extract_global(funct_args):
def extract_global(job, funct_args):
if MPI.COMM_WORLD.rank == 0:
return job.extract_global(*funct_args)


def extract_box(funct_args):
def extract_box(job, funct_args):
if MPI.COMM_WORLD.rank == 0:
return job.extract_box(*funct_args)


def extract_atom(funct_args):
def extract_atom(job, funct_args):
if MPI.COMM_WORLD.rank == 0:
# extract atoms return an internal data type
# this has to be reformatted
Expand Down Expand Up @@ -153,12 +147,12 @@ def extract_atom(funct_args):
return np.array(data)


def extract_fix(funct_args):
def extract_fix(job, funct_args):
if MPI.COMM_WORLD.rank == 0:
return job.extract_fix(*funct_args)


def extract_variable(funct_args):
def extract_variable(job, funct_args):
# in the args - if the third one,
# which is the type is 1 - a lammps array is returned
if funct_args[2] == 1:
Expand All @@ -177,21 +171,21 @@ def extract_variable(funct_args):
return data


def get_natoms(funct_args):
def get_natoms(job, funct_args):
if MPI.COMM_WORLD.rank == 0:
return job.get_natoms()


def set_variable(funct_args):
def set_variable(job, funct_args):
return job.set_variable(*funct_args)


def reset_box(funct_args):
def reset_box(job, funct_args):
job.reset_box(*funct_args)
return 1


def gather_atoms(funct_args):
def gather_atoms(job, funct_args):
# extract atoms return an internal data type
# this has to be reformatted
name = str(funct_args[0])
Expand All @@ -217,7 +211,7 @@ def gather_atoms(funct_args):
return np.array(data)


def gather_atoms_concat(funct_args):
def gather_atoms_concat(job, funct_args):
# extract atoms return an internal data type
# this has to be reformatted
name = str(funct_args[0])
Expand All @@ -243,7 +237,7 @@ def gather_atoms_concat(funct_args):
return np.array(data)


def gather_atoms_subset(funct_args):
def gather_atoms_subset(job, funct_args):
# convert to ctypes
name = str(funct_args[0])
lenids = int(funct_args[1])
Expand Down Expand Up @@ -280,76 +274,76 @@ def gather_atoms_subset(funct_args):
return np.array(data)


def create_atoms(funct_args):
def create_atoms(job, funct_args):
job.create_atoms(*funct_args)
return 1


def has_exceptions(funct_args):
def has_exceptions(job, funct_args):
return job.has_exceptions


def has_gzip_support(funct_args):
def has_gzip_support(job, funct_args):
return job.has_gzip_support


def has_png_support(funct_args):
def has_png_support(job, funct_args):
return job.has_png_support


def has_jpeg_support(funct_args):
def has_jpeg_support(job, funct_args):
return job.has_jpeg_support


def has_ffmpeg_support(funct_args):
def has_ffmpeg_support(job, funct_args):
return job.has_ffmpeg_support


def installed_packages(funct_args):
def installed_packages(job, funct_args):
return job.installed_packages


def set_fix_external_callback(funct_args):
def set_fix_external_callback(job, funct_args):
job.set_fix_external_callback(*funct_args)
return 1


def get_neighlist(funct_args):
def get_neighlist(job, funct_args):
if MPI.COMM_WORLD.rank == 0:
return job.get_neighlist(*funct_args)


def find_pair_neighlist(funct_args):
def find_pair_neighlist(job, funct_args):
if MPI.COMM_WORLD.rank == 0:
return job.find_pair_neighlist(*funct_args)


def find_fix_neighlist(funct_args):
def find_fix_neighlist(job, funct_args):
if MPI.COMM_WORLD.rank == 0:
return job.find_fix_neighlist(*funct_args)


def find_compute_neighlist(funct_args):
def find_compute_neighlist(job, funct_args):
if MPI.COMM_WORLD.rank == 0:
return job.find_compute_neighlist(*funct_args)


def get_neighlist_size(funct_args):
def get_neighlist_size(job, funct_args):
if MPI.COMM_WORLD.rank == 0:
return job.get_neighlist_size(*funct_args)


def get_neighlist_element_neighbors(funct_args):
def get_neighlist_element_neighbors(job, funct_args):
if MPI.COMM_WORLD.rank == 0:
return job.get_neighlist_element_neighbors(*funct_args)


def get_thermo(funct_args):
def get_thermo(job, funct_args):
if MPI.COMM_WORLD.rank == 0:
return np.array(job.get_thermo(*funct_args))


def scatter_atoms(funct_args):
def scatter_atoms(job, funct_args):
name = str(funct_args[0])
py_vector = funct_args[1]
# now see if its an integer or double type- but before flatten
Expand All @@ -366,7 +360,7 @@ def scatter_atoms(funct_args):
return 1


def scatter_atoms_subset(funct_args):
def scatter_atoms_subset(job, funct_args):
name = str(funct_args[0])
lenids = int(funct_args[2])
ids = funct_args[3]
Expand Down Expand Up @@ -396,7 +390,7 @@ def scatter_atoms_subset(funct_args):
return 1


def command(funct_args):
def command(job, funct_args):
job.command(funct_args)
return 1

Expand Down Expand Up @@ -464,13 +458,19 @@ def _gather_data_from_all_processors(data):
return data


if __name__ == "__main__":
def _run_lammps_mpi(argument_lst):
if MPI.COMM_WORLD.rank == 0:
context = zmq.Context()
socket = context.socket(zmq.PAIR)
argument_lst = sys.argv
port_selected = argument_lst[argument_lst.index("--zmqport") + 1]
socket.connect("tcp://localhost:" + port_selected)
else:
context, socket = None, None
# Lammps executable
args = ["-screen", "none"]
if len(argument_lst) > 3:
args.extend(argument_lst[3:])
job = lammps(cmdargs=args)
while True:
if MPI.COMM_WORLD.rank == 0:
input_dict = cloudpickle.loads(socket.recv())
Expand All @@ -485,8 +485,12 @@ def _gather_data_from_all_processors(data):
context.term()
job.close()
break
output = select_cmd(input_dict["c"])(input_dict["d"])
output = select_cmd(input_dict["c"])(job=job, funct_args=input_dict["d"])
if MPI.COMM_WORLD.rank == 0 and output is not None:
# with open('process.txt', 'a') as file:
# print('Output:', output, file=file)
socket.send(cloudpickle.dumps({"r": output}))


if __name__ == "__main__":
_run_lammps_mpi(argument_lst=sys.argv)