From 0bbcff464dca3d44b954d79e8fb986797a880a91 Mon Sep 17 00:00:00 2001 From: Jan Janssen Date: Tue, 21 Feb 2023 17:17:54 -0700 Subject: [PATCH 1/4] extract_vairable() --- pylammpsmpi/mpi/lmpmpi.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/pylammpsmpi/mpi/lmpmpi.py b/pylammpsmpi/mpi/lmpmpi.py index bf11caf3..ca284359 100644 --- a/pylammpsmpi/mpi/lmpmpi.py +++ b/pylammpsmpi/mpi/lmpmpi.py @@ -165,15 +165,25 @@ def extract_fix(funct_args): def extract_variable(funct_args): # in the args - if the third one, # which is the type is 1 - a lammps array is returned - if MPI.COMM_WORLD.rank == 0: - # if type is 1 - reformat file - try: - data = job.extract_variable(*funct_args) - except ValueError: - return [] - if funct_args[2] == 1: - data = np.array(data) - return data + if funct_args[2] == 1: + data = job.numpy.extract_variable(*funct_args) + data_gather = MPI.COMM_WORLD.gather(data, root=0) + if MPI.COMM_WORLD.rank == 0: + data = [] + for vl in data_gather: + for v in vl: + data.append(v) + if funct_args[2] == 1: + data = np.array(data) + return data + else: + if MPI.COMM_WORLD.rank == 0: + # if type is 1 - reformat file + try: + data = job.extract_variable(*funct_args) + except ValueError: + return [] + return data def get_natoms(funct_args): From 103c9ffb591c22106baf35265f150aeca2ffef06 Mon Sep 17 00:00:00 2001 From: Jan Janssen Date: Tue, 21 Feb 2023 17:23:53 -0700 Subject: [PATCH 2/4] Update lmpmpi.py --- pylammpsmpi/mpi/lmpmpi.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/pylammpsmpi/mpi/lmpmpi.py b/pylammpsmpi/mpi/lmpmpi.py index ca284359..d067d79c 100644 --- a/pylammpsmpi/mpi/lmpmpi.py +++ b/pylammpsmpi/mpi/lmpmpi.py @@ -77,15 +77,10 @@ def convert_data(val, type, length, width): val = job.extract_compute(*filtered_args) return convert_data(val=val, type=type, length=length, width=width) elif style == 1: # per atom property - val = job.numpy.extract_compute(*filtered_args) - val_gather = MPI.COMM_WORLD.gather(val, root=0) + val = _gather_data_from_all_processors( + data=job.numpy.extract_compute(*filtered_args) + ) if MPI.COMM_WORLD.rank == 0: - # val_gather.shape [number of cores, atoms on specific core] - # the number of atoms on specific cores can vary - val = [] - for vl in val_gather: - for v in vl: - val.append(v) length = job.get_natoms() return convert_data(val=val, type=type, length=length, width=width) else: # Todo @@ -166,16 +161,11 @@ def extract_variable(funct_args): # in the args - if the third one, # which is the type is 1 - a lammps array is returned if funct_args[2] == 1: - data = job.numpy.extract_variable(*funct_args) - data_gather = MPI.COMM_WORLD.gather(data, root=0) + data = _gather_data_from_all_processors( + data=job.numpy.extract_variable(*funct_args) + ) if MPI.COMM_WORLD.rank == 0: - data = [] - for vl in data_gather: - for v in vl: - data.append(v) - if funct_args[2] == 1: - data = np.array(data) - return data + return np.array(data) else: if MPI.COMM_WORLD.rank == 0: # if type is 1 - reformat file @@ -482,6 +472,16 @@ def select_cmd(argument): return switcher.get(argument) +def _gather_data_from_all_processors(data): + data_gather = MPI.COMM_WORLD.gather(data, root=0) + if MPI.COMM_WORLD.rank == 0: + data = [] + for vl in data_gather: + for v in vl: + data.append(v) + return data + + if __name__ == "__main__": while True: if MPI.COMM_WORLD.rank == 0: From cd84d1a343988786aee453df698181f0a5b1b3f8 Mon Sep 17 00:00:00 2001 From: Jan Janssen Date: Tue, 21 Feb 2023 17:30:11 -0700 Subject: [PATCH 3/4] Update test_pylammpsmpi_local.py --- tests/test_pylammpsmpi_local.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_pylammpsmpi_local.py b/tests/test_pylammpsmpi_local.py index e63ad8f1..f2a831f5 100644 --- a/tests/test_pylammpsmpi_local.py +++ b/tests/test_pylammpsmpi_local.py @@ -55,7 +55,7 @@ def test_extract_variable(self): x = self.lmp.extract_variable("tt", "all", 0) self.assertEqual(np.round(x, 2), 1.13) x = self.lmp.extract_variable("fx", "all", 1) - self.assertEqual(len(x), 128) + self.assertEqual(len(x), 256) self.assertEqual(np.round(x[0], 2), -0.26) def test_scatter_atoms(self): From 6a339c1f7006cc0fc8d810b29f3cf6d831c3a7d1 Mon Sep 17 00:00:00 2001 From: Jan Janssen Date: Tue, 21 Feb 2023 17:30:25 -0700 Subject: [PATCH 4/4] Update test_pylammpsmpi_cluster.py --- tests/test_pylammpsmpi_cluster.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_pylammpsmpi_cluster.py b/tests/test_pylammpsmpi_cluster.py index d1a7c015..356fe6f6 100644 --- a/tests/test_pylammpsmpi_cluster.py +++ b/tests/test_pylammpsmpi_cluster.py @@ -59,7 +59,7 @@ def test_extract_variable(self): self.assertEqual(np.round(x, 2), 1.13) x = self.lmp.extract_variable("fx", "all", 1) - self.assertEqual(len(x), 128) + self.assertEqual(len(x), 256) self.assertEqual(np.round(x[0], 2), -0.26) def test_scatter_atoms(self):