Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 27 additions & 17 deletions pylammpsmpi/mpi/lmpmpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,10 @@ def convert_data(val, type, length, width):
val = job.extract_compute(*filtered_args)
return convert_data(val=val, type=type, length=length, width=width)
elif style == 1: # per atom property
val = job.numpy.extract_compute(*filtered_args)
val_gather = MPI.COMM_WORLD.gather(val, root=0)
val = _gather_data_from_all_processors(
data=job.numpy.extract_compute(*filtered_args)
)
if MPI.COMM_WORLD.rank == 0:
# val_gather.shape [number of cores, atoms on specific core]
# the number of atoms on specific cores can vary
val = []
for vl in val_gather:
for v in vl:
val.append(v)
length = job.get_natoms()
return convert_data(val=val, type=type, length=length, width=width)
else: # Todo
Expand Down Expand Up @@ -165,15 +160,20 @@ def extract_fix(funct_args):
def extract_variable(funct_args):
# in the args - if the third one,
# which is the type is 1 - a lammps array is returned
if MPI.COMM_WORLD.rank == 0:
# if type is 1 - reformat file
try:
data = job.extract_variable(*funct_args)
except ValueError:
return []
if funct_args[2] == 1:
data = np.array(data)
return data
if funct_args[2] == 1:
data = _gather_data_from_all_processors(
data=job.numpy.extract_variable(*funct_args)
)
if MPI.COMM_WORLD.rank == 0:
return np.array(data)
else:
if MPI.COMM_WORLD.rank == 0:
# if type is 1 - reformat file
try:
data = job.extract_variable(*funct_args)
except ValueError:
return []
return data


def get_natoms(funct_args):
Expand Down Expand Up @@ -472,6 +472,16 @@ def select_cmd(argument):
return switcher.get(argument)


def _gather_data_from_all_processors(data):
data_gather = MPI.COMM_WORLD.gather(data, root=0)
if MPI.COMM_WORLD.rank == 0:
data = []
for vl in data_gather:
for v in vl:
data.append(v)
return data


if __name__ == "__main__":
while True:
if MPI.COMM_WORLD.rank == 0:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pylammpsmpi_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_extract_variable(self):
self.assertEqual(np.round(x, 2), 1.13)

x = self.lmp.extract_variable("fx", "all", 1)
self.assertEqual(len(x), 128)
self.assertEqual(len(x), 256)
self.assertEqual(np.round(x[0], 2), -0.26)

def test_scatter_atoms(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pylammpsmpi_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_extract_variable(self):
x = self.lmp.extract_variable("tt", "all", 0)
self.assertEqual(np.round(x, 2), 1.13)
x = self.lmp.extract_variable("fx", "all", 1)
self.assertEqual(len(x), 128)
self.assertEqual(len(x), 256)
self.assertEqual(np.round(x[0], 2), -0.26)

def test_scatter_atoms(self):
Expand Down