Skip to content

Commit

Permalink
Merge 4ae72c5 into 74d8917
Browse files Browse the repository at this point in the history
  • Loading branch information
richardotis committed Aug 2, 2019
2 parents 74d8917 + 4ae72c5 commit 10987e5
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 16 deletions.
6 changes: 5 additions & 1 deletion pycalphad/core/cache.py
Expand Up @@ -6,7 +6,11 @@
http://code.activestate.com/recipes/578078-py26-and-py30-backport-of-python-33s-lru-cache/
"""

from collections import Iterable, Mapping, namedtuple
try:
from collections.abc import Iterable, Mapping
except ImportError:
from collections import Iterable, Mapping
from collections import namedtuple
from threading import RLock
from functools import update_wrapper

Expand Down
13 changes: 8 additions & 5 deletions pycalphad/core/utils.py
Expand Up @@ -6,13 +6,16 @@
from pycalphad.core.halton import halton
from pycalphad.core.constants import MIN_SITE_FRACTION
from sympy.utilities.lambdify import lambdify
from sympy.printing.lambdarepr import LambdaPrinter
from sympy import Symbol
import numpy as np
import operator
import functools
import itertools
import collections
try:
from collections.abc import Iterable, Mapping
except ImportError:
from collections import Iterable, Mapping

try:
# Only available in numpy 1.10 and newer
Expand Down Expand Up @@ -137,7 +140,7 @@ def unpack_condition(tup):
return np.arange(tup[0], tup[1], tup[2], dtype=np.float)
else:
raise ValueError('Condition tuple is length {}'.format(len(tup)))
elif isinstance(tup, collections.Iterable):
elif isinstance(tup, Iterable):
return [float(x) for x in tup]
else:
return [float(tup)]
Expand Down Expand Up @@ -242,14 +245,14 @@ def unpack_kwarg(kwarg_obj, default_arg=None):
"""
new_dict = collections.defaultdict(lambda: default_arg)

if isinstance(kwarg_obj, collections.Mapping):
if isinstance(kwarg_obj, Mapping):
new_dict.update(kwarg_obj)
# kwarg_obj is a list containing a dict and a default
# For now at least, we don't treat ndarrays the same as other iterables
# ndarrays are assumed to be numeric arrays containing "default values", so don't match here
elif isinstance(kwarg_obj, collections.Iterable) and not isinstance(kwarg_obj, np.ndarray):
elif isinstance(kwarg_obj, Iterable) and not isinstance(kwarg_obj, np.ndarray):
for element in kwarg_obj:
if isinstance(element, collections.Mapping):
if isinstance(element, Mapping):
new_dict.update(element)
else:
# element=element syntax to silence var-from-loop warning
Expand Down
33 changes: 23 additions & 10 deletions pycalphad/io/tdb.py
Expand Up @@ -11,6 +11,7 @@
from sympy import sympify, And, Or, Not, Intersection, Union, EmptySet, Interval, Piecewise
from sympy import Symbol, GreaterThan, StrictGreaterThan, LessThan, StrictLessThan, Complement, S
from sympy import Mul, Pow, Rational
from sympy.abc import _clash
from sympy.printing.str import StrPrinter
from sympy.core.mul import _keep_coeff
from sympy.printing.precedence import precedence
Expand All @@ -36,6 +37,16 @@
ast.Load, ast.Mult, ast.Name, ast.Num, ast.Pow, ast.Sub,
ast.UAdd, ast.UnaryOp, ast.USub]

# Avoid symbol names clashing with objects in sympy (gh-233)
clashing_namespace = {}
clashing_namespace.update(_clash)
clashing_namespace['CC'] = Symbol('CC')
clashing_namespace['FF'] = Symbol('FF')
clashing_namespace['T'] = v.T
clashing_namespace['P'] = v.P
clashing_namespace['R'] = v.R


def _sympify_string(math_string):
"Convert math string into SymPy object."
# drop pound symbols ('#') since they denote function names
Expand All @@ -49,12 +60,7 @@ def _sympify_string(math_string):
expr_string = \
re.sub(r'(?<!\w)EXP(?!\w)', 'exp', expr_string,
flags=re.IGNORECASE)
# Convert raw variables into StateVariable objects
variable_fixes = {
Symbol('T'): v.T,
Symbol('P'): v.P,
Symbol('R'): v.R
}

# sympify uses eval, so we need to sanitize the input
nodes = ast.parse(expr_string)
nodes = ast.Expression(nodes.body[0].value)
Expand All @@ -63,7 +69,8 @@ def _sympify_string(math_string):
if type(node) not in _AST_WHITELIST: #pylint: disable=W1504
raise ValueError('Expression from TDB file not in whitelist: '
'{}'.format(expr_string))
return sympify(expr_string).xreplace(variable_fixes)

return sympify(expr_string, locals=clashing_namespace)

def _parse_action(func):
"""
Expand All @@ -84,7 +91,13 @@ def _parse_action(func):
Source: Florian Brucker on StackOverflow
http://stackoverflow.com/questions/10177276/pyparsing-setparseaction-function-is-getting-no-arguments
"""
num_args = len(inspect.getargspec(func).args)
if sys.version_info[0] > 2:
func_items = inspect.signature(func).parameters.items()
func_args = [name for name, param in func_items
if param.kind == param.POSITIONAL_OR_KEYWORD]
else:
func_args = inspect.getargspec(func).args
num_args = len(func_args)
if num_args > 3:
raise ValueError('Input function must take at most 3 parameters.')

Expand Down Expand Up @@ -628,7 +641,7 @@ def _apply_new_symbol_names(dbf, symbol_name_map):
dbf.symbols = {name: S(expr).xreplace({Symbol(s): Symbol(v) for s, v in symbol_name_map.items()}) for name, expr in dbf.symbols.items()}
# finally propagate through to the parameters
for p in dbf._parameters.all():
dbf._parameters.update({'parameter': S(p['parameter']).xreplace({Symbol(s): Symbol(v) for s, v in symbol_name_map.items()})}, eids=[p.eid])
dbf._parameters.update({'parameter': S(p['parameter']).xreplace({Symbol(s): Symbol(v) for s, v in symbol_name_map.items()})}, doc_ids=[p.doc_id])


def write_tdb(dbf, fd, groupby='subsystem', if_incompatible='warn'):
Expand Down Expand Up @@ -914,7 +927,7 @@ def read_tdb(dbf, fd):
try:
tokens = grammar.parseString(command)
_TDB_PROCESSOR[tokens[0]](dbf, *tokens[1:])
except ParseException:
except:
print("Failed while parsing: " + command)
print("Tokens: " + str(tokens))
raise
Expand Down
5 changes: 5 additions & 0 deletions pycalphad/tests/test_database.py
Expand Up @@ -560,3 +560,8 @@ def test_database_parameter_with_species_that_is_not_a_stoichiometric_formula():
sbminus3 = dbf._parameters.search(tinydb.where('constituent_array') == ((species_dict['SB-3'],),))
assert len(sbminus3) == 1
assert sbminus3[0]['parameter'].args[0][0] == 10000


def test_database_sympy_namespace_clash():
"""Symbols that clash with sympy special objects are replaced (gh-233)"""
Database.from_string("""FUNCTION TEST 0.01 T*LN(CC)+FF; 6000 N TW !""", fmt='tdb')

0 comments on commit 10987e5

Please sign in to comment.