From 22bedf627fb97e0d2f59200e9dc7c41a631a925d Mon Sep 17 00:00:00 2001 From: Vasileios Karakasis Date: Sat, 26 Sep 2020 22:58:56 +0200 Subject: [PATCH] Restore environment correctly in case of module load failures --- reframe/core/runtime.py | 8 ++++++-- unittests/test_environments.py | 21 +++++++++++++++++---- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/reframe/core/runtime.py b/reframe/core/runtime.py index 4c74dcd95f..5e4132e4e7 100644 --- a/reframe/core/runtime.py +++ b/reframe/core/runtime.py @@ -220,8 +220,12 @@ def loadenv(*environs): def emit_loadenv_commands(*environs): - env_snapshot, commands = loadenv(*environs) - env_snapshot.restore() + env_snapshot = snapshot() + try: + _, commands = loadenv(*environs) + finally: + env_snapshot.restore() + return commands diff --git a/unittests/test_environments.py b/unittests/test_environments.py index ac61f062ea..3d5029842d 100644 --- a/unittests/test_environments.py +++ b/unittests/test_environments.py @@ -3,6 +3,7 @@ # # SPDX-License-Identifier: BSD-3-Clause +import contextlib import os import pytest @@ -225,8 +226,8 @@ def test_env_immutability(base_environ, env0): prgenv.ldflags = ['-lm'] -def test_env_emit_load_commands(base_environ, user_runtime, - modules_system, env0): +def test_emit_loadenv_commands(base_environ, user_runtime, + modules_system, env0): ms = rt.runtime().modules_system expected_commands = [ ms.emit_load_commands('testmod_foo')[0], @@ -237,8 +238,8 @@ def test_env_emit_load_commands(base_environ, user_runtime, assert expected_commands == rt.emit_loadenv_commands(env0) -def test_env_emit_load_commands_with_confict(base_environ, user_runtime, - modules_system, env0): +def test_emit_loadenv_commands_with_confict(base_environ, user_runtime, + modules_system, env0): # Load a conflicting module modules_system.load_module('testmod_bar') ms = rt.runtime().modules_system @@ -250,3 +251,15 @@ def test_env_emit_load_commands_with_confict(base_environ, user_runtime, 'export _var3=${_var1}', ] assert expected_commands == rt.emit_loadenv_commands(env0) + + +def test_emit_loadenv_failure(user_runtime): + snap = rt.snapshot() + environ = env.Environment('test', modules=['testmod_foo', 'testmod_xxx']) + + # Suppress the module load error and verify that the original environment + # is preserved + with contextlib.suppress(EnvironError): + rt.emit_loadenv_commands(environ) + + assert rt.snapshot() == snap