Skip to content

Commit

Permalink
Clearing worker processor instances of reactor state after every inte…
Browse files Browse the repository at this point in the history
…raction (#1729)
  • Loading branch information
zachmprince committed Jun 12, 2024
1 parent 8cfc244 commit e8a80be
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 5 deletions.
2 changes: 1 addition & 1 deletion armi/bookkeeping/memoryProfiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ def __init__(self):
self.processMemoryInMB: Optional[float] = None
if _havePsutil:
self.percentNodeRamUsed = psutil.virtual_memory().percent
self.processMemoryInMB = psutil.Process().memory_info().rss / (1012.0**2)
self.processMemoryInMB = psutil.Process().memory_info().rss / (1024.0**2)

def __isub__(self, other):
if self.percentNodeRamUsed is not None and other.percentNodeRamUsed is not None:
Expand Down
10 changes: 10 additions & 0 deletions armi/operators/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,9 @@ def _interactAll(self, interactionName, activeInterfaces, *args):
)
)

# Allow inherited classes to clean up things after an interaction
self._finalizeInteract()

runLog.header(
"=========== Completed {} Event ===========\n".format(
interactionName + cycleNodeTag
Expand All @@ -536,6 +539,13 @@ def _interactAll(self, interactionName, activeInterfaces, *args):

return halt

def _finalizeInteract(self):
"""Member called after each interface has completed its interaction.
Useful for cleaning up data.
"""
pass

def printInterfaceSummary(self, interface, interactionName, statePointIndex, *args):
"""
Log which interaction point is about to be executed.
Expand Down
28 changes: 24 additions & 4 deletions armi/operators/operatorMPI.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ def workerOperate(self):
note = context.MPI_COMM.bcast("wait", root=0)
if note != "wait":
raise RuntimeError('did not get "wait". Got {0}'.format(note))
elif cmd == "reset":
runLog.extra("Workers are being reset.")
else:
# we don't understand the command on our own. check the interfaces
# this allows all interfaces to have their own custom operation code.
Expand Down Expand Up @@ -193,14 +195,28 @@ def workerOperate(self):
pm = getPluginManager()
resetFlags = pm.hook.mpiActionRequiresReset(cmd=cmd)
# only reset if all the plugins agree to reset
if all(resetFlags):
if all(resetFlags) or cmd == "reset":
self._resetWorker()

# might be an mpi action which has a reactor and everything, preventing
# garbage collection
del cmd
gc.collect()

def _finalizeInteract(self):
    """
    Tell the workers to drop their reactor data after an interaction.

    Broadcasting the reset command keeps reactor state on the workers from
    being carried along into the next interaction.

    Notes
    -----
    Only the root processor calls this. The workers handle the "reset"
    broadcast themselves when they receive it.
    """
    resetCommand = "reset"
    context.MPI_COMM.bcast(resetCommand, root=0)
    # Logged on the root right after the broadcast has been issued.
    runLog.extra("Workers have been reset.")

def _resetWorker(self):
"""
Clear out the reactor on the workers to start anew.
Expand All @@ -214,12 +230,16 @@ def _resetWorker(self):
.. warning:: This should build empty non-core systems too.
"""
xsGroups = self.getInterface("xsGroups")
if xsGroups:
xsGroups.clearRepresentativeBlocks()
# Nothing to do if we never had anything
if self.r is None:
return

cs = self.cs
bp = self.r.blueprints
spatialGrid = self.r.core.spatialGrid
xsGroups = self.getInterface("xsGroups")
if xsGroups:
xsGroups.clearRepresentativeBlocks()
self.detach()
self.r = reactors.Reactor(cs.caseTitle, bp)
core = reactors.Core("Core")
Expand Down
32 changes: 32 additions & 0 deletions armi/tests/test_mpiFeatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,13 @@ def workerOperate(self, cmd):
return False


class MockInterface(Interface):
    """Bare-bones interface whose init interaction does nothing."""

    name = "mockInterface"

    def interactInit(self):
        """Do nothing; this interface only needs to exist so it can be interacted with."""
        return None


class MpiOperatorTests(unittest.TestCase):
"""Testing the MPI parallelization operator."""

Expand Down Expand Up @@ -140,6 +147,31 @@ def test_primaryCritical(self):
else:
self.o.operate()

@unittest.skipIf(context.MPI_SIZE <= 1 or MPI_EXE is None, "Parallel test only")
def test_finalizeInteract(self):
    """Test to make sure workers are reset after interface interactions.

    The root rank runs the init interaction and then broadcasts "quit" and
    "finished"; every other rank runs the worker loop. Each rank then checks
    its own captured log for the reset message expected on that rank.
    """
    # Register a single do-nothing interface so there is something to interact with
    interface = MockInterface(self.o.r, self.o.cs)
    self.o.addInterface(interface)

    with mockRunLogs.BufferLog() as mock:
        if context.MPI_RANK == 0:
            self.o.interactAllInit()
            # NOTE(review): presumably these two broadcasts let the workers
            # leave workerOperate() -- confirm against the worker loop.
            context.MPI_COMM.bcast("quit", root=0)
            context.MPI_COMM.bcast("finished", root=0)
        else:
            self.o.workerOperate()

        # The root and the workers log different messages about the reset
        logMessage = (
            "Workers have been reset."
            if context.MPI_RANK == 0
            else "Workers are being reset."
        )
        # Count matching log lines without building an intermediate list
        numCalls = sum(
            1 for line in mock.getStdout().splitlines() if logMessage in line
        )
        self.assertGreaterEqual(numCalls, 1)


# these two must be defined up here so that they can be pickled
class BcastAction1(mpiActions.MpiAction):
Expand Down

0 comments on commit e8a80be

Please sign in to comment.