
Rewrite #115

Merged · 31 commits · Aug 8, 2020
Conversation

patrick-kidger (Contributor)

Hey Ricky. Following on from our discussion in #113, I've made a number of changes.

Highlights:

  • Standardised (and tidied up) a lot of the code
    • Adaptive RK methods now share code
    • Tests now share code
  • Added adjoint_params as an argument to odeint_adjoint. (Usage sketch below.)
  • Updated a lot of documentation.
  • Extended the tests. (They now take about 10 minutes to run in total.)
  • Sped things up by switching the internals from operating on tuples-of-tensors to just tensors, with a compatibility layer now going the other way for the input/output. (A sketch of that layer follows below.)
    • This removes a surprising amount of overhead from the Python interpreter. I see a large (~50%) speedup on some problems I'm considering at the moment, specifically NCDE/ODE-RNN models; as you'd expect, these are more heavily affected by solver overhead than func-dominated conv models.
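
A minimal usage sketch of adjoint_params, assuming the semantics that adjoint gradients are computed only for the tensors passed in (the Func module and all values here are made up for illustration):

import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint

class Func(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)

    def forward(self, t, y):
        return self.linear(y)

func = Func()
y0 = torch.randn(5, 2)
t = torch.linspace(0., 1., 10)

# Hypothetical restriction: adjoint gradients for the weight only, not the bias.
sol = odeint_adjoint(func, y0, t, adjoint_params=(func.linear.weight,))
sol[-1].sum().backward()
print(func.linear.weight.grad)  # populated
print(func.linear.bias.grad)    # None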

The interface should be stable, and there are no backward compatibility issues that I'm aware of.
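
To illustrate the tuple-to-tensor compatibility layer just mentioned, here is a minimal sketch of the flatten/unflatten round trip (illustrative only, not the library's actual code, which lives around _flat_to_shape):

import torch

def flatten(tensors):
    # One flat vector for the solver internals to operate on.
    return torch.cat([t.reshape(-1) for t in tensors])

def unflatten(vector, shapes):
    # Back to a tuple-of-tensors at the user-facing boundary.
    tensors = []
    offset = 0
    for shape in shapes:
        numel = 1
        for s in shape:
            numel *= s
        tensors.append(vector[offset:offset + numel].reshape(shape))
        offset += numel
    return tuple(tensors)

state = (torch.randn(10, 2), torch.randn(10000))
shapes = [t.shape for t in state]
flat = flatten(state)  # shape (10020,)
assert all(torch.equal(a, b) for a, b in zip(unflatten(flat, shapes), state))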

A miscellany of other things:

  • I'm using Python>=3.6 features. This is now the minimum version supported by PyTorch, so I think this should be fine; it isn't crucial and can be changed if you like.
  • One thing which I've left alone, but would like to change, is to standardise the default rtol/atol values for odeint and odeint_adjoint: at present they are different. Thoughts?
  • The 'tsit5' and 'adams' solvers have been removed from the list of solvers: the former is known to be broken, and the latter consistently fails tests (even before these changes), so it probably has a bug too.
  • 'fixed_adams' has been renamed to 'implicit_adams' for consistency. (The old name also remains for backward compatibility.)
  • Fixed a bug in _select_initial_step (giving a NaN step size) when both y0 and f0 are simultaneously zero. Previously this raised a spurious error; now it works correctly.
  • Fixed a bug in _select_initial_step: it computed h0 = 0.01 * max(d0 / d1) rather than h0 = 0.01 * max(d0) / max(d1). (A sketch of the corrected heuristic follows this list.) This will of course produce a slight change in the end results, but with the default settings I haven't noticed any difference beyond a few decimal places. The number of NFEs seems to shift either up or down a bit depending on the problem. (In particular, the adjoint ran afoul of this bug: skipping the details, the initial step was always taken to be h1, and 100 * h0 was never used.)
  • Added a norm argument to the adaptive solvers, to set the norm wrt which the accept/reject step is done. (Currently used for example in the adjoint.)
  • I've bumped the version number.
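
For reference, a sketch of the corrected initial-step heuristic described above, following the standard Hairer-Norsett-Wanner algorithm that _select_initial_step is based on (simplified: a single flat state and a sup norm stand in for the real configuration):

import torch

def select_initial_step_sketch(func, t0, y0, order, rtol, atol,
                               norm=lambda x: x.abs().max()):
    scale = atol + y0.abs() * rtol
    f0 = func(t0, y0)
    d0 = norm(y0 / scale)
    d1 = norm(f0 / scale)
    # The fix: h0 = 0.01 * max(d0) / max(d1) over the tuple state, not
    # 0.01 * max(d0 / d1); with one flat tensor it reduces to the line below.
    if d0 < 1e-5 or d1 < 1e-5:
        h0 = torch.tensor(1e-6)
    else:
        h0 = 0.01 * d0 / d1
    # Probe with an explicit Euler step to estimate the second derivative:
    y1 = y0 + h0 * f0
    f1 = func(t0 + h0, y1)
    d2 = norm((f1 - f0) / scale) / h0
    if d1 <= 1e-15 and d2 <= 1e-15:
        h1 = torch.max(torch.tensor(1e-6), h0 * 1e-3)
    else:
        h1 = (0.01 / torch.max(d1, d2)) ** (1.0 / (order + 1))
    # In the adjoint case the buggy (too large) h0 meant h1 always won this
    # min, so 100 * h0 was never used.
    return torch.min(100 * h0, h1)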

I realise this is quite a long list of things. Let me know your thoughts.

Commits:

  • …, _flat_to_shape, grid_points. Renamed adams methods to be clearer about the difference between each of them. Demagickified runge_kutta_step.
  • Fixed bug breaking explicit adams
  • Fixed wrong atol in odeint
  • Added dtype casting after interacting with tableaus etc.
  • Added correct dtype to tableau etc. tensors
  • Added version number to source.
rtqichen (Owner) left a comment

Hi Patrick,

Thanks for the bug fixes, and the changes in general look awesome. I had a quick skim and left a couple comments for now. I'll have a closer look (and play around with it a bit first) in the next week or so.

I'm a bit worried about switching the internals from tuples-of-tensors to just tensors, because the tuple representation was put in precisely so there wouldn't need to be concatenations of large matrices in every function evaluation. Beyond compute, this uses extra memory, since each function evaluation will require whole copies of the state (plus parameters in the adjoint solve).

But I can also see how the overhead of multiple CUDA kernels for the additions and multiplications can slow things down, especially in solving the adjoint state. I'm going to play around with this a bit first. If it doesn't increase the wall-clock speed of single solves for some worst case examples too much, then I'm all for it since it's also easier to code the backend.

I tried benchmarking this for my experiments, but adjoint currently fails when options is provided but does not contain the norm keyword. Likely some part of the logic in adjoint.py needs to be fixed. Can you take a look? Example below.

(Looking back, max over RMS was a bit of a weird choice. I think I tried treating each Tensor state as a separate ODE.)

Here's the example:

import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint

class ODEfunc(nn.Module):

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2, 10), nn.Linear(10, 2))

    def forward(self, t, state):
        # Tuple-of-tensors state: only x evolves; y is a large static tensor.
        x, y = state
        dx = self.net(x)
        dy = torch.zeros_like(y)
        return dx, dy

f = ODEfunc()
initial_state = (torch.randn(10, 2), torch.randn(10000))
options = {}  # options provided, but without a 'norm' entry: this triggers the failure
sol = odeint_adjoint(f, initial_state, torch.linspace(0.0, 1.0, 20), options=options)
print(sol)

sol[0].sum().backward()
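
For reference: explicitly supplying a norm should sidestep the failure. A sketch, assuming the new interface accepts a callable on the flattened state under the 'norm' key in options:

# Hypothetical workaround, reusing f and initial_state from above:
options = {'norm': lambda state: state.abs().max()}  # e.g. the sup norm
sol = odeint_adjoint(f, initial_state, torch.linspace(0.0, 1.0, 20), options=options)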

Review comments on torchdiffeq/_impl/solvers.py and torchdiffeq/_impl/adjoint.py (outdated; resolved).
patrick-kidger (Contributor, Author) commented Aug 5, 2020

I'm very glad things seem positive overall.

Regarding tensors vs tuples-of-tensors: this occurred to me; I found that the extra cat didn't add much memory, e.g. 740MB -> 790MB for a 600k-parameter convolutional model.

Fixed the adjoint bug, thanks.

And hah, I did wonder about the norms. The adaptive RK solvers used the mixed norm whilst the documentation claimed they used the sup norm. Meanwhile, the convergence check for implicit adams uses a sup norm (which I've left alone), and tsit5 used L2. 😄
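
For concreteness, sketches of the norms being contrasted (illustrative definitions, not the library's exact code):

import torch

def rms_norm(x):
    # Root-mean-square over all elements of a tensor.
    return x.pow(2).mean().sqrt()

def sup_norm(x):
    # Largest absolute value, as the documentation described.
    return x.abs().max()

def mixed_norm(tensors):
    # 'Max over RMS': RMS within each tensor of a tuple state, then the max
    # across tensors, i.e. each tensor treated as a separate ODE.
    return max(rms_norm(t) for t in tensors)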

I've realised that we use torch 1.3.0 features (torch.promote_types) so I've upped the required version in setup.py. Fine?

Regarding comparisons on benchmarks: I found that the combination of smaller problems + adjoint got a bit slower (in the sense of NFEs; they got faster by wall clock time) due to the second bugfix for _select_initial_step. The particular case of the adjoint equations produces a rather small 100 * h0, and h1 can actually be a better choice. For now I've just accepted that, but we could try to do something about it?

rtqichen (Owner) left a comment

Happy to report I'm only seeing improvements in speed in my experiments. I left a couple minor comments, but let's merge this!

It should be fine to use this more conservative / robust option for initial_step_size as the default. People interested in speed can simply set the initial step size.
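
For example, a sketch of doing exactly that, assuming the adaptive solvers accept the initial step under a first_step key in options:

import torch
from torchdiffeq import odeint

def f(t, y):
    return -y  # simple linear decay, for illustration

y0 = torch.tensor([1.0])
t = torch.linspace(0., 1., 5)
# Bypass the automatic initial-step heuristic entirely:
sol = odeint(f, y0, t, method='dopri5', options={'first_step': 0.1})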

As for standardizing atol and rtol values, what do you think about setting them to 1.49012e-8? This would be in line with scipy's odeint, and jax's as well, though it's likely designed for double precision.

Review comments on torchdiffeq/_impl/dopri8.py and torchdiffeq/__init__.py (outdated; resolved).
patrick-kidger (Contributor, Author)

Excellent!

Regarding atol, rtol: honestly I'm not sure. The value you suggest is indeed calibrated to double precision (1.49012e-8 = sqrt(finfo(float64).eps)), and I'm not sure I'd want to automatically change the precision based on the input dtype; that feels a bit too magic to me. I'd note that scipy's solve_ivp has default values of rtol=1e-3, atol=1e-6, so standards seem to be all over the show.
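
(Checking that arithmetic:)

import torch
print(torch.finfo(torch.float64).eps ** 0.5)  # 1.4901161193847656e-08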

To least surprise people transitioning from the previous version, I think we should use values similar to those already used (maybe taking the more stringent of odeint's and odeint_adjoint's tolerances for each parameter). But that's not a strong feeling: if there's a good reason to change the default tolerances, then now's a good time to do it, IMO.

Given the new version, do you think it'd be nice to publicise this on reddit/twitter? Off the top of my head, at least relative to the documentation before, the new features are as follows (a usage sketch of a few of them follows the list):

  • New solvers:
    • dopri8,
    • bosh3,
    • adaptive_heun.
  • Efficient handling of discontinuous vector fields (for example when using different parameters on different pieces), via
    • grid_points and eps for adaptive solvers,
    • eps for fixed solvers.
  • Finer control over which tensors adjoint gradients are calculated for via adjoint_params.
  • Custom control over what settings to use in the adjoint via adjoint_rtol, adjoint_atol, adjoint_method, adjoint_options.
  • Setting particular step locations for fixed solvers via grid_constructor.
  • Setting the norm that accept/reject steps are calculated with respect to for adaptive solvers.
  • Up to 50% speed improvements due to a rewrite of the internals!
  • New examples on latent ODEs and continuous normalising flows.
  • Updated documentation.
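
A sketch of a few of the features above (the Func module and all values are made up for illustration; parameter names are as listed, with their exact semantics assumed):

import torch
import torch.nn as nn
from torchdiffeq import odeint, odeint_adjoint

class Func(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)

    def forward(self, t, y):
        return self.linear(y)

func = Func()
y0 = torch.randn(3, 2)
t = torch.linspace(0., 1., 10)

# Separate tolerances and method for the backward (adjoint) solve:
sol = odeint_adjoint(func, y0, t, rtol=1e-7, atol=1e-9,
                     adjoint_rtol=1e-5, adjoint_atol=1e-7,
                     adjoint_method='dopri5')

# Telling an adaptive solver about a discontinuity at t=0.5 (made-up location):
sol = odeint(func, y0, t, method='dopri5',
             options={'grid_points': torch.tensor([0.5]), 'eps': 1e-6})

# A fixed solver stepping over custom locations via grid_constructor:
options = {'grid_constructor': lambda func, y0, t: torch.linspace(0., 1., 100)}
sol = odeint(func, y0, t, method='euler', options=options)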

rtqichen (Owner) commented Aug 7, 2020

Let's use odeint's default values for odeint_adjoint. Anyone seriously using this library and using adjoint would likely use their own tolerance values, and keeping odeint's values seems like a safe bet. (Though I wonder whether atol=1e-8 even does anything for single precision.)
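
For reference, single-precision machine epsilon is roughly an order of magnitude larger than 1e-8:

import torch
print(torch.finfo(torch.float32).eps)  # ~1.1920929e-07, so atol=1e-8 is below
                                       # float32 resolution at unit scale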

I have one more major feature that I'd like to add before publicizing these accumulated changes, if that's okay with you? It should be done by the end of September. We can also show some nice motivations/visualizations for the grid_points feature if we publicize after that research work of yours is ready. Publicizing the new changes is a great idea though.

patrick-kidger (Contributor, Author)

Your logic makes sense; odeint's values it is.

For publicising, sure thing. I'm curious what the "major feature" is, though! I understand if it's under wraps for the time being. :)

@rtqichen rtqichen merged commit cb0de6e into rtqichen:master Aug 8, 2020
rtqichen (Owner) commented Aug 8, 2020

Thanks for the PR!
