Skip to content

dsl: Expand derivative #2559

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open

Conversation

JDBetteridge
Copy link
Contributor

I want to be able to manipulate Devito derivative objects, more like sympy derivative objects.

(possibly only optionally)

@mloubout mloubout added the API api (symbolics, types, ...) label Mar 28, 2025
Copy link

codecov bot commented Mar 28, 2025

Codecov Report

Attention: Patch coverage is 24.78261% with 173 lines in your changes missing coverage. Please review.

Project coverage is 63.37%. Comparing base (ce8de4c) to head (8d66b66).

Files with missing lines Patch % Lines
tests/test_derivatives.py 16.84% 79 Missing ⚠️
devito/finite_differences/derivative.py 43.01% 48 Missing and 5 partials ⚠️
devito/finite_differences/differentiable.py 2.38% 41 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (ce8de4c) and HEAD (8d66b66). Click for more details.

HEAD has 8 uploads less than BASE
Flag BASE (ce8de4c) HEAD (8d66b66)
16 9
pytest-gpu-nvc-nvidiaX 1 0
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #2559       +/-   ##
===========================================
- Coverage   92.00%   63.37%   -28.64%     
===========================================
  Files         245      245               
  Lines       48802    48977      +175     
  Branches     4307     4332       +25     
===========================================
- Hits        44902    31038    -13864     
- Misses       3208    17061    +13853     
- Partials      692      878      +186     
Flag Coverage Δ
pytest-gpu-aomp-amdgpuX 72.40% <26.66%> (-0.18%) ⬇️
pytest-gpu-nvc-nvidiaX ?

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@mloubout
Copy link
Contributor

Do you have some MFE (minimal failing example) of where this is popping up?

@JDBetteridge JDBetteridge changed the title JDBetteridge/expand derivative dsl: JDBetteridge/expand derivative Apr 2, 2025
@mloubout mloubout changed the title dsl: JDBetteridge/expand derivative dsl: Expand derivative Apr 3, 2025
@JDBetteridge JDBetteridge force-pushed the JDBetteridge/expand_derivative branch 2 times, most recently from 0f1da56 to c89d3cb Compare April 3, 2025 21:59
A near copy of sympy.core.expr.Expr.as_independent
with a bug fixed
"""
from sympy import Symbol
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can these imports be moved to the top of the file?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All these are external sympy imports so they shouldn't need to be local yes

else:
want = Mul

# sift out deps into symbolic and other and ignore
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: Start comment with capital level

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really want to address nitpicks or style issues in this method as it has been lifted directly from sympy (as I acknowledge in the docstring) unless the policy is that copied code should be modified to fit the style. Honestly I'd rather just fix the bug upsteam (planned) but we still need to depend on older versions of sympy.

other.append(d)

def has(e):
"""return the standard has() if there are no literal symbols, else
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: Tweak docstring for consistency

return has_other
return has_other or e.has(*(e.free_symbols & sym))

if (want is not func or
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: Weird indent multiline

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here

@JDBetteridge JDBetteridge force-pushed the JDBetteridge/expand_derivative branch from d1476d6 to ab53c5d Compare May 2, 2025 17:45
@JDBetteridge
Copy link
Contributor Author

Here is a wonderful list of failing examples that this PR now addresses:

import devito as dv
import sympy as sym

grid = dv.Grid(shape=(11,), extent=(1,))
x = grid.dimensions[0]

u = dv.Function(name='u', grid=grid, space_order=4)

a = u.dx
b = a.subs({u: -5*u.dx + 4*u + 3})

# Substitution works!
print(b)

# Unintuitive results:

# .as_independent method on devito.finite_differences.differentiable.Mul doesn't work
print((-a).as_independent(x))
# Result:
## (1, -Derivative(u(x), x))

# devito.Derivatives aren't reconstructable from func and args (as per sympy docs)
du = u.dx
print(du.func(*du.args))
print(du.func(*du.args).args)
# Result
## Derivative(u(x), (x, 1))
## (u(x), ((x, 1), 1))

# default simplification causes different results:
du11 = dv.Derivative(u, x, x)
print(du11)
print(du11.deriv_order)
# Result
## Derivative(u(x), (x, 2))
## (1, 1)
du2 = dv.Derivative(u, (x, 2))
print(du2)
print(du2.deriv_order)
# Result
## Derivative(u(x), (x, 2))
## (2,)

# Whut!?
print(dv.Derivative(u, x, deriv_order=(2,4)))
# Result
## Derivative(u(x), (x, 2))

# Also pretty wild
print(dv.Derivative(u, (x, 0)))
# Result
## Derivative(u(x), x)

# Really? Double check this...
print(dv.Derivative(-1))
# Result
## Traceback (most recent call last):
##   File "<stdin>", line 1, in <module>
##   File "/media/devito/devito_py312/src/devito/devito/finite_differences/derivative.py", line 98, in __new__
##     raise ValueError("`expr` must be a Differentiable object")
## ValueError: `expr` must be a Differentiable object

# Maybe it's sufficient to handle the case:
print(dv.Derivative(sym.sympify(-1)))

# Cannot expand
print(b.expand())
# Result
## ...
## RecursionError: maximum recursion depth exceeded

with pytest.raises(TypeError):
_ = Derivative(u, (x, a))

def test_expand_mul(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do these work when using the derivative shortcuts?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that is actually what is being tested (see the setup_class method)

Copy link
Contributor

@mloubout mloubout left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly nitpicky comments, looks like a nice cleanup

assert du11 == du2
assert du11.deriv_order == du2.deriv_order

@pytest.mark.xfail(raises=ValueError)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's not the right way to test an exception. You should do

with pytest.raises(ValueError):
    Derivative(self.u, self.x, deriv_order=(2, 4))

This will mark the test a successful if it raises the right exception instead of as "accepted failure"

Same below

except Exception as e:
raise ValueError("`expr` must be a Differentiable object") from e

# Validate `dims`. It can be:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we put those in classmethods with appropriate names, like cls._process_dims, cls._process_order, ...

# Default finite difference orders depending on input dimension (.dt or .dx)
# It's possible that the expr is a `sympy.Number` at this point, which
# has derivative 0, unless we're taking a 0th derivative.
if isinstance(expr, sympy.Number):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should maybe be more general as "check if expr depends on dims" so something like if not expr.free_symbols & set(dims)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like an optimisation, but actually is not the right thing to do here. Consider the case where you want to make a substitution by using a temporary. Now you cannot construct the Derivative as it will be prematurely simplified to zero, unless you explicitly construct the correct type of temporary.

I think in the case of just a number this is a safe thing to do.

A near copy of sympy.core.expr.Expr.as_independent
with a bug fixed
"""
from sympy import Symbol
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All these are external sympy imports so they shouldn't need to be local yes

@@ -113,6 +113,73 @@ def is_Staggered(self):
def is_TimeDependent(self):
return any(i.is_Time for i in self.dimensions)

def as_independent(self, *deps, **hint):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems quite brutal that we need that almost full copy paste here. Do we know which lines actually cause issue?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's horrible, it is this line:
if (want is not func or func is not Add and func is not Mul):
which becomes
if (want is not func or not issubclass(func, Add) and not issubclass(func, Mul)):

raise ValueError(f'Expected `(dim, deriv_order)`, got {dims[0]}')
elif len(dims) == 2 and not isinstance(dims[1], Iterable) and is_integer(dims[1]):
# special case of single dimension and order
dims = (dims, )
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just use as_tuple

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That doesn't work in this case, dim is already a tuple, but I want a nested tuple

@@ -497,3 +505,57 @@ def _eval_fd(self, expr, **kwargs):
res = res.xreplace(e)

return res

def _eval_expand_nest(self, **hints):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would some kind of recursing singledispatch here tidy things up?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think sympy is doing the right thing for us here


if func is Add: # all terms were treated as commutative
return (Add(*indep), _unevaluated_Add(*depend))
else: # handle noncommutative by stopping at first dependent term
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: This file has a lot of comments which aren't in our nominal style

cls.b = a.subs({cls.u: -5*cls.u.dx + 4*cls.u + 3})

def test_reconstructible(self):
''' Check that devito.Derivatives are reconstructible from func and args
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ultra nitpick: Docstrings not matching wider style in the codebase

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reference somewhere?

Comment on lines +96 to +101
# TODO: Delete this
if kwargs.get('preprocessed', False):
from warnings import warn
warn('I removed the `preprocessed` kwarg')

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# TODO: Delete this
if kwargs.get('preprocessed', False):
from warnings import warn
warn('I removed the `preprocessed` kwarg')

@JDBetteridge JDBetteridge force-pushed the JDBetteridge/expand_derivative branch 3 times, most recently from 4b05609 to 0700fb2 Compare June 16, 2025 13:00
@JDBetteridge JDBetteridge force-pushed the JDBetteridge/expand_derivative branch from 0700fb2 to 8d66b66 Compare June 30, 2025 21:38
@JDBetteridge JDBetteridge mentioned this pull request Jul 3, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
API api (symbolics, types, ...)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants