Skip to content

Commit

Permalink
Merge pull request #407 from tenpy/yaml_py_eval
Browse files Browse the repository at this point in the history
allow !py_eval python snippets when loading yaml files
  • Loading branch information
Jakob-Unfried committed May 2, 2024
2 parents 630c36e + 8bab6d3 commit e18ea2e
Show file tree
Hide file tree
Showing 8 changed files with 180 additions and 16 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/code-coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ jobs:
python -m pip install --upgrade coverage
- name: Build and install tenpy
# also installs extra dependencies defined in pyproject.toml
run: |
python -m build .
python -m pip install .
python -m pip install ".[io, test, extra]"
- name: Run pytest with coverage
# pytest configuration in pyproject.toml
Expand Down
3 changes: 2 additions & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,10 @@ jobs:
python -m pip install --upgrade pytest
- name: Build and install tenpy
# also installs extra dependencies defined in pyproject.toml
run: |
python -m build .
python -m pip install .
python -m pip install ".[io, test, extra]"
- name: Run pytest
# configuration in pyproject.toml
Expand Down
6 changes: 4 additions & 2 deletions doc/commandline-help.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ Command line interface to run a TeNPy simulation.

positional arguments:
parameters_file
Yaml (*.yml) file with the simulation parameters/options. Multiple files get merged
according to MERGE; see tenpy.tools.misc.merge_recursive for details.
Yaml (*.yml) file with the simulation parameters/options. We support an additional yaml
tag ``!py_eval: VALUE`` that gets initialized by python's ``eval(VALUE)`` with `np`, `scipy`
and `tenpy` defined. Multiple files get merged according to MERGE; see
tenpy.tools.misc.merge_recursive for details.

options:
-h, --help
Expand Down
19 changes: 19 additions & 0 deletions doc/intro/options.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,22 @@ simulation, e.g., in the `sim_params` of the `results` returned by :meth:`~tenpy
If you add extra options to your configuration that TeNPy doesn't read out by the end of the simulation, it will (usually) issue a warning.
Getting such a warnings is an indicator for a typo in your configuration, or an option being in the wrong config dictionary.
Python snippets in yaml files
-----------------------------
When defining the parameters in the yaml file, you might want to evaluate small formulas e.g., set a parameter to a certain fraction of $\pi$,
or expanding a long list ``[2**i for i in range(5, 10)]`` without explicitly writing all the entries.
For those cases, it can be convenient to have small python snippets inside the yaml file, which we allow by loading the
yaml files with :func:`tenpy.tools.params.load_yaml_with_py_eval`.

It defines a ``!py_eval`` yaml tag, which should be followed by a string of python code to be evaluated with python's ``eval()`` function.
A good method to pass the python code is to use a literal string in yaml, as shown in the simple examples below.

.. code :: yaml
a: !py_eval |
2**np.arange(6, 10)
b: !py_eval |
[10, 15] + list(range(20, 31, 2)) + [35, 40]
c: !py_eval "2*np.pi * 0.3"
7 changes: 4 additions & 3 deletions doc/intro/simulations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Of course, you can also directly run the simulation from inside python, the comm
import tenpy
import yaml

simulation_params = yaml.load("parameters.yml")
simulation_params = tenpy.load_yaml_with_py_eval("parameters.yml")
# instead of using yaml, you can also define a usual python dictionary
tenpy.run_simulation(**simulation_params)

Expand Down Expand Up @@ -86,8 +86,9 @@ returned by the simulation (or saved to file):
from pprint import pprint
import yaml
with open('parameters.yml', 'r') as f:
simulation_parameters = yaml.safe_load(f)
with open('parameters.yml', 'r') as stream:
simulation_parameters = tenpy.load_yaml_with_py_eval(stream)
# alternative: simulation_parameters = tenpy.load_yaml_with_py_eval('parameters.yml')
results = tenpy.run_simulation(simulation_parameters)
pprint(results['simulation_parameters'])
Expand Down
11 changes: 6 additions & 5 deletions tenpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
from .tools.hdf5_io import save, load, save_to_hdf5, load_from_hdf5
from .tools.misc import (setup_logging, consistency_check, TenpyInconsistencyError,
TenpyInconsistencyWarning, BetaWarning)
from .tools.params import Config, asConfig
from .tools.params import Config, asConfig, load_yaml_with_py_eval



Expand Down Expand Up @@ -134,6 +134,7 @@
# from tenpy.tools
'save', 'load', 'save_to_hdf5', 'load_from_hdf5', 'setup_logging', 'consistency_check',
'TenpyInconsistencyError', 'TenpyInconsistencyWarning', 'BetaWarning', 'Config', 'asConfig',
'load_yaml_with_py_eval',
# from tenpy.__init__, i.e. defined below
'show_config', 'console_main',
]
Expand Down Expand Up @@ -173,7 +174,7 @@ def console_main(*command_line_args):

args = parser.parse_args(args=command_line_args if command_line_args else None)
# import extra modules
context = {'tenpy': globals(), 'np': np, 'scipy': scipy}
context = {'tenpy': sys.modules[__name__], 'np': np, 'scipy': scipy}
if args.import_module:
sys.path.insert(0, '.')
for module_name in args.import_module:
Expand All @@ -182,11 +183,9 @@ def console_main(*command_line_args):
# load parameters_file
options = {}
if args.parameters_file:
import yaml
options_files = []
for fn in args.parameters_file:
with open(fn, 'r') as stream:
options = yaml.safe_load(stream)
options = load_yaml_with_py_eval(fn, context)
options_files.append(options)
if len(options_files) > 1:
options = tools.misc.merge_recursive(*options_files, conflict=args.merge)
Expand Down Expand Up @@ -267,6 +266,8 @@ def formatter(prog):
parser.add_argument('parameters_file',
nargs='*',
help="Yaml (*.yml) file with the simulation parameters/options. "
"We support an additional yaml tag !py_eval: VALUE that gets initialized "
"by python's ``eval(VALUE)`` with `np`, `scipy` and `tenpy` defined. "
"Multiple files get merged according to MERGE; "
"see tenpy.tools.misc.merge_recursive for details.")
opt_help = textwrap.dedent("""\
Expand Down
91 changes: 87 additions & 4 deletions tenpy/tools/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# Copyright (C) TeNPy Developers, GNU GPLv3

import warnings
import numpy
import numpy as np
from collections.abc import MutableMapping
import pprint
Expand All @@ -14,7 +15,7 @@

from .hdf5_io import ATTR_FORMAT

__all__ = ["Config", "asConfig"]
__all__ = ["Config", "asConfig", "load_yaml_with_py_eval"]


class Config(MutableMapping):
Expand Down Expand Up @@ -94,6 +95,8 @@ def save_yaml(self, filename):
def from_yaml(cls, filename, name=None):
"""Load a `Config` instance from a YAML file containing the :attr:`options`.
The yaml file can have additional ``!py_eval`` tags, see :func:`load_yaml_with_py_eval`.
.. warning ::
Like pickle, it is not safe to load a yaml file from an untrusted source! A malicious
file can call any Python function and should thus be treated with extreme caution.
Expand All @@ -113,9 +116,7 @@ def from_yaml(cls, filename, name=None):
"""
if name is None:
name = os.path.basename(filename)
import yaml
with open(filename, 'r') as stream:
config = yaml.safe_load(stream)
config = load_yaml_with_py_eval(filename)
return cls(config, name)

def save_hdf5(self, hdf5_saver, h5gr, subpath):
Expand Down Expand Up @@ -423,3 +424,85 @@ def asConfig(config, name):
if isinstance(config, Config):
return config
return Config(config, name)



def _yaml_eval_constructor(loader, node):
"""Yaml constructor to support `!py_eval` tag in yaml files."""
cmd = loader.construct_scalar(node)
if not isinstance(cmd, str):
raise ValueError("expect string argument to `!py_eval`")
try:
res = eval(cmd, loader.eval_context)
except:
print("\nError while yaml parsing the following !py_eval command:\n", cmd, "\n")
raise
return res


try:
import yaml
except ImportError:
yaml = None

if yaml is None:
_YamlLoaderWithPyEval = None
else:
class _YamlLoaderWithPyEval(yaml.FullLoader):
eval_context = {}

yaml.add_constructor("!py_eval", _yaml_eval_constructor, Loader=_YamlLoaderWithPyEval)


def load_yaml_with_py_eval(filename, context={'np': numpy}):
"""Load a yaml file with support for an additional `!py_eval` tag.
When defining yaml parameter files, it's sometimes convenient to just have python snippets
in there, e.g. to get fractions of pi or expand last lists.
This function loads a yaml file supporting such (short) python snippets
that get evaluated by python's ``eval(snippet)``.
It expects one string of python code following the ``!py_eval`` tag.
The most reliable method to pass the python code is to use a literal
string in yaml, as shown in the example below.
.. code :: yaml
a: !py_eval |
2**np.arange(6, 10)
b: !py_eval |
[10, 15] + list(range(20, 31, 2)) + [35, 40]
c: !py_eval "2*np.pi * 0.3"
Note that a subsequent ``yaml.dump()`` might contain ugly parts if you construct
generic python objects, e.g., a numpy array scalar like ``np.arange(10)[0]``.
If you want to avoid this, you can explicitly convert back to lists before.
.. warning ::
Like pickle, it is not safe to load a yaml file from an untrusted source! A malicious
file can call any Python function and should thus be treated with extreme caution.
Parameters
----------
filename : str
Filename of the file to load.
context : dict
The context of ``globals()`` passed to `eval`.
Returns
-------
config :
Data (typically nested dictionary) as defined in the yaml file.
"""
if _YamlLoaderWithPyEval is None:
raise RuntimeError('Could not import yaml. Consider installing the pyyaml package.')

_YamlLoaderWithPyEval.eval_context = context

with open(filename, 'r') as stream:
config = yaml.load(stream, Loader=_YamlLoaderWithPyEval)
return config
56 changes: 56 additions & 0 deletions tests/test_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,3 +350,59 @@ def test_output_filename_from_dict():
},
parts_order=['alg.dt', ('model.Lx', 'model.Ly')])
assert fn == 'result_dt_0.50_3x4.h5'


yaml_example = """
simulation_class : GroundStateSearch
directory: results
model_class : SpinModel
model_params :
bc_MPS: infinite
bc_y: cylinder
lattice: !py_eval tenpy.models.lattice.Square
Lx: 2
Ly: 4
S: .5
Jx: !py_eval "[J ** 2 for J in range(6)]"
hx: !py_eval |
np.linspace(0, 5, 21, endpoint=True)
log_params:
to_file: INFO
to_stdout: INFO
initial_state_params:
method : lat_product_state
product_state : [[[up]]]
algorithm_class: TwoSiteDMRGEngine
algorithm_params:
mixer: True
trunc_params:
svd_min: 1.e-10
chi_max: 200
max_E_err: 1.e-10
sequential:
recursive_keys:
- model_params.hx
- model_params.Jx
"""


def test_yaml_load(tmp_path):
yaml = pytest.importorskip('yaml')
file = tmp_path / 'simulation.yaml'
with open(file, 'w') as f:
print(yaml_example, file=f)
simulation_params = tenpy.load_yaml_with_py_eval(file, context=dict(np=np, tenpy=tenpy))
assert simulation_params['simulation_class'] == 'GroundStateSearch'
assert simulation_params['model_params']['Jx'] == [0, 1, 4, 9, 16, 25]
np.testing.assert_array_almost_equal_nulp(
simulation_params['model_params']['hx'],
np.linspace(0, 5, 21, endpoint=True),
10
)
assert simulation_params['model_params']['lattice'] is tenpy.Square

0 comments on commit e18ea2e

Please sign in to comment.