-
Notifications
You must be signed in to change notification settings - Fork 238
dsl: Expand derivative #2559
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
dsl: Expand derivative #2559
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2559 +/- ##
===========================================
- Coverage 92.00% 63.37% -28.64%
===========================================
Files 245 245
Lines 48802 48977 +175
Branches 4307 4332 +25
===========================================
- Hits 44902 31038 -13864
- Misses 3208 17061 +13853
- Partials 692 878 +186
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Do you have some MFE (minimal failing example) of where this is popping up? |
0f1da56
to
c89d3cb
Compare
A near copy of sympy.core.expr.Expr.as_independent | ||
with a bug fixed | ||
""" | ||
from sympy import Symbol |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can these imports be moved to the top of the file?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All these are external sympy imports so they shouldn't need to be local yes
else: | ||
want = Mul | ||
|
||
# sift out deps into symbolic and other and ignore |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nitpick: Start comment with capital level
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't really want to address nitpicks or style issues in this method as it has been lifted directly from sympy (as I acknowledge in the docstring) unless the policy is that copied code should be modified to fit the style. Honestly I'd rather just fix the bug upsteam (planned) but we still need to depend on older versions of sympy.
other.append(d) | ||
|
||
def has(e): | ||
"""return the standard has() if there are no literal symbols, else |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nitpick: Tweak docstring for consistency
return has_other | ||
return has_other or e.has(*(e.free_symbols & sym)) | ||
|
||
if (want is not func or |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nitpick: Weird indent multiline
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here
d1476d6
to
ab53c5d
Compare
Here is a wonderful list of failing examples that this PR now addresses: import devito as dv
import sympy as sym
grid = dv.Grid(shape=(11,), extent=(1,))
x = grid.dimensions[0]
u = dv.Function(name='u', grid=grid, space_order=4)
a = u.dx
b = a.subs({u: -5*u.dx + 4*u + 3})
# Substitution works!
print(b)
# Unintuitive results:
# .as_independent method on devito.finite_differences.differentiable.Mul doesn't work
print((-a).as_independent(x))
# Result:
## (1, -Derivative(u(x), x))
# devito.Derivatives aren't reconstructable from func and args (as per sympy docs)
du = u.dx
print(du.func(*du.args))
print(du.func(*du.args).args)
# Result
## Derivative(u(x), (x, 1))
## (u(x), ((x, 1), 1))
# default simplification causes different results:
du11 = dv.Derivative(u, x, x)
print(du11)
print(du11.deriv_order)
# Result
## Derivative(u(x), (x, 2))
## (1, 1)
du2 = dv.Derivative(u, (x, 2))
print(du2)
print(du2.deriv_order)
# Result
## Derivative(u(x), (x, 2))
## (2,)
# Whut!?
print(dv.Derivative(u, x, deriv_order=(2,4)))
# Result
## Derivative(u(x), (x, 2))
# Also pretty wild
print(dv.Derivative(u, (x, 0)))
# Result
## Derivative(u(x), x)
# Really? Double check this...
print(dv.Derivative(-1))
# Result
## Traceback (most recent call last):
## File "<stdin>", line 1, in <module>
## File "/media/devito/devito_py312/src/devito/devito/finite_differences/derivative.py", line 98, in __new__
## raise ValueError("`expr` must be a Differentiable object")
## ValueError: `expr` must be a Differentiable object
# Maybe it's sufficient to handle the case:
print(dv.Derivative(sym.sympify(-1)))
# Cannot expand
print(b.expand())
# Result
## ...
## RecursionError: maximum recursion depth exceeded |
with pytest.raises(TypeError): | ||
_ = Derivative(u, (x, a)) | ||
|
||
def test_expand_mul(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do these work when using the derivative shortcuts?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, that is actually what is being tested (see the setup_class
method)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mostly nitpicky comments, looks like a nice cleanup
tests/test_derivatives.py
Outdated
assert du11 == du2 | ||
assert du11.deriv_order == du2.deriv_order | ||
|
||
@pytest.mark.xfail(raises=ValueError) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's not the right way to test an exception. You should do
with pytest.raises(ValueError):
Derivative(self.u, self.x, deriv_order=(2, 4))
This will mark the test a successful if it raises the right exception instead of as "accepted failure"
Same below
except Exception as e: | ||
raise ValueError("`expr` must be a Differentiable object") from e | ||
|
||
# Validate `dims`. It can be: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we put those in classmethods with appropriate names, like cls._process_dims
, cls._process_order
, ...
# Default finite difference orders depending on input dimension (.dt or .dx) | ||
# It's possible that the expr is a `sympy.Number` at this point, which | ||
# has derivative 0, unless we're taking a 0th derivative. | ||
if isinstance(expr, sympy.Number): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should maybe be more general as "check if expr depends on dims" so something like if not expr.free_symbols & set(dims)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems like an optimisation, but actually is not the right thing to do here. Consider the case where you want to make a substitution by using a temporary. Now you cannot construct the Derivative
as it will be prematurely simplified to zero, unless you explicitly construct the correct type of temporary.
I think in the case of just a number this is a safe thing to do.
A near copy of sympy.core.expr.Expr.as_independent | ||
with a bug fixed | ||
""" | ||
from sympy import Symbol |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All these are external sympy imports so they shouldn't need to be local yes
@@ -113,6 +113,73 @@ def is_Staggered(self): | |||
def is_TimeDependent(self): | |||
return any(i.is_Time for i in self.dimensions) | |||
|
|||
def as_independent(self, *deps, **hint): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems quite brutal that we need that almost full copy paste here. Do we know which lines actually cause issue?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it's horrible, it is this line:
if (want is not func or func is not Add and func is not Mul):
which becomes
if (want is not func or not issubclass(func, Add) and not issubclass(func, Mul)):
raise ValueError(f'Expected `(dim, deriv_order)`, got {dims[0]}') | ||
elif len(dims) == 2 and not isinstance(dims[1], Iterable) and is_integer(dims[1]): | ||
# special case of single dimension and order | ||
dims = (dims, ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just use as_tuple
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That doesn't work in this case, dim
is already a tuple
, but I want a nested tuple
@@ -497,3 +505,57 @@ def _eval_fd(self, expr, **kwargs): | |||
res = res.xreplace(e) | |||
|
|||
return res | |||
|
|||
def _eval_expand_nest(self, **hints): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would some kind of recursing singledispatch
here tidy things up?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think sympy is doing the right thing for us here
|
||
if func is Add: # all terms were treated as commutative | ||
return (Add(*indep), _unevaluated_Add(*depend)) | ||
else: # handle noncommutative by stopping at first dependent term |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nitpick: This file has a lot of comments which aren't in our nominal style
cls.b = a.subs({cls.u: -5*cls.u.dx + 4*cls.u + 3}) | ||
|
||
def test_reconstructible(self): | ||
''' Check that devito.Derivatives are reconstructible from func and args |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ultra nitpick: Docstrings not matching wider style in the codebase
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reference somewhere?
# TODO: Delete this | ||
if kwargs.get('preprocessed', False): | ||
from warnings import warn | ||
warn('I removed the `preprocessed` kwarg') | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# TODO: Delete this | |
if kwargs.get('preprocessed', False): | |
from warnings import warn | |
warn('I removed the `preprocessed` kwarg') |
4b05609
to
0700fb2
Compare
0700fb2
to
8d66b66
Compare
I want to be able to manipulate Devito derivative objects, more like sympy derivative objects.
(possibly only optionally)