Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 54 additions & 127 deletions reframe/core/environments.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ class Environment:

It is simply a collection of modules to be loaded and environment variables
to be set when this environment is loaded by the framework.
Users may not create or modify directly environments.
'''
name = fields.TypedField('name', typ.Str[r'(\w|-)+'])
modules = fields.TypedField('modules', typ.List[str])
Expand All @@ -23,11 +22,6 @@ def __init__(self, name, modules=[], variables=[]):
self._name = name
self._modules = list(modules)
self._variables = collections.OrderedDict(variables)
self._loaded = False
self._saved_variables = {}
self._conflicted = []
self._preloaded = set()
self._module_ops = []

@property
def name(self):
Expand Down Expand Up @@ -63,100 +57,6 @@ def is_loaded(self):
all(os.environ.get(k, None) == os_ext.expandvars(v)
for k, v in self._variables.items()))

def load(self):
# conflicted module list must be filled at the time of load
rt = runtime()
for m in self._modules:
if rt.modules_system.is_module_loaded(m):
self._preloaded.add(m)

conflicted = rt.modules_system.load_module(m, force=True)
for c in conflicted:
self._module_ops.append(('u', c))

self._module_ops.append(('l', m))
self._conflicted += conflicted

for k, v in self._variables.items():
if k in os.environ:
self._saved_variables[k] = os.environ[k]

os.environ[k] = os_ext.expandvars(v)

self._loaded = True

def unload(self):
if not self._loaded:
return

for k, v in self._variables.items():
if k in self._saved_variables:
os.environ[k] = self._saved_variables[k]
elif k in os.environ:
del os.environ[k]

# Unload modules in reverse order
for m in reversed(self._modules):
if m not in self._preloaded:
runtime().modules_system.unload_module(m)

# Reload the conflicted packages, previously removed
for m in self._conflicted:
runtime().modules_system.load_module(m)

self._loaded = False

def emit_load_commands(self):
rt = runtime()
emit_fn = {
'l': rt.modules_system.emit_load_commands,
'u': rt.modules_system.emit_unload_commands
}
module_ops = self._module_ops or [('l', m) for m in self._modules]

# Emit module commands
ret = []
for op, m in module_ops:
ret += emit_fn[op](m)

# Emit variable set commands
for k, v in self._variables.items():
ret.append('export %s=%s' % (k, v))

return ret

def emit_unload_commands(self):
rt = runtime()

# Invert the logic of module operations, since we are unloading the
# environment
emit_fn = {
'l': rt.modules_system.emit_unload_commands,
'u': rt.modules_system.emit_load_commands
}

ret = []
for var in self._variables.keys():
ret.append('unset %s' % var)

if self._module_ops:
module_ops = reversed(self._module_ops)
else:
module_ops = (('l', m) for m in reversed(self._modules))

for op, m in module_ops:
ret += emit_fn[op](m)

return ret

def __eq__(self, other):
if not isinstance(other, type(self)):
return NotImplemented

return (self._name == other._name and
set(self._modules) == set(other._modules) and
self._variables == other._variables)

def details(self):
'''Return a detailed description of this environment.'''
variables = '\n'.join(' '*8 + '- %s=%s' % (k, v)
Expand All @@ -168,6 +68,14 @@ def details(self):
]
return '\n'.join(lines)

def __eq__(self, other):
if not isinstance(other, type(self)):
return NotImplemented

return (self.name == other.name and
set(self.modules) == set(other.modules) and
self.variables == other.variables)

def __str__(self):
return self.name

Expand All @@ -177,44 +85,63 @@ def __repr__(self):
self.modules, self.variables)


def swap_environments(src, dst):
src.unload()
dst.load()


class EnvironmentSnapshot(Environment):
class _EnvironmentSnapshot(Environment):
def __init__(self, name='env_snapshot'):
self._name = name
self._modules = runtime().modules_system.loaded_modules()
self._variables = dict(os.environ)
self._conflicted = []
super().__init__(name,
runtime().modules_system.loaded_modules(),
os.environ.items())

def load(self):
def restore(self):
'''Restore this environment snapshot.'''
os.environ.clear()
os.environ.update(self._variables)
self._loaded = True

@property
def is_loaded(self):
raise NotImplementedError('is_loaded is not a valid property '
'of an environment snapshot')
def __eq__(self, other):
if not isinstance(other, Environment):
return NotImplemented

# Order of variables is not important when comparing snapshots
for k, v in self.variables.items():
if other.variables[k] != v:
return False

return (self.name == other.name and
set(self.modules) == set(other.modules))


def snapshot():
'''Create an environment snapshot'''
return _EnvironmentSnapshot()

def unload(self):
raise NotImplementedError('cannot unload an environment snapshot')

def load(*environs):
'''Load environments in the current Python context.

class save_environment:
'''A context manager for saving and restoring the current environment.'''
Returns a tuple containing a snapshot of the environment at entry to this
function and a list of shell commands required to load ``environs``.
'''
env_snapshot = snapshot()
commands = []
rt = runtime()
for env in environs:
for m in env.modules:
conflicted = rt.modules_system.load_module(m, force=True)
for c in conflicted:
commands += rt.modules_system.emit_unload_commands(c)

commands += rt.modules_system.emit_load_commands(m)

for k, v in env.variables.items():
os.environ[k] = os_ext.expandvars(v)
commands.append('export %s=%s' % (k, v))

def __init__(self):
self.environ_save = EnvironmentSnapshot()
return env_snapshot, commands

def __enter__(self):
return self.environ_save

def __exit__(self, exc_type, exc_value, traceback):
# Restore the environment and propagate any exception thrown
self.environ_save.load()
def emit_load_commands(*environs):
env_snapshot, commands = load(*environs)
env_snapshot.restore()
return commands


class ProgEnvironment(Environment):
Expand Down
Loading