
The gradient of odeint_adjoint is zero with multiple GPUs #119

Closed · AtlantixJJ opened this issue Sep 11, 2020 · 8 comments

@AtlantixJJ commented Sep 11, 2020

Using exactly the same code, I got the following results:

  1. Single GPU: odeint and odeint_adjoint worked just fine.
  2. Multiple GPUs: odeint worked fine, but odeint_adjoint always resulted in zero gradients.
  3. Using the adjoint sensitivity in torchdyn, multiple GPUs work fine.

My PyTorch version is 1.5.0, torchdiffeq version is 0.1.0, CUDA version is 10.0.130, and Python version is 3.7.7.
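For concreteness, a minimal sketch of the kind of setup this refers to (the module names ODEFunc and ODEBlock are only illustrative; it assumes PyTorch 1.5, torchdiffeq 0.1.0, and at least two visible GPUs):

import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint

class ODEFunc(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 4)

    def forward(self, t, y):
        return torch.tanh(self.net(y))

class ODEBlock(nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func
        self.t = torch.tensor([0.0, 1.0])

    def forward(self, y0):
        # odeint_adjoint collects self.func.parameters() internally
        return odeint_adjoint(self.func, y0, self.t.to(y0))[-1]

model = nn.DataParallel(ODEBlock(ODEFunc())).cuda()
y0 = torch.randn(8, 4, device='cuda', requires_grad=True)
model(y0).sum().backward()
# On a single GPU this is a nonzero tensor; under DataParallel on
# PyTorch 1.5 no gradient reaches the ODE function's parameters.
print(model.module.func.net.weight.grad)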

I noticed that in your implementation of the adjoint method, you put the odeint call under torch.no_grad, while torchdyn did not.

This is your code:

class OdeintAdjointMethod(torch.autograd.Function):

    @staticmethod
    def forward(ctx, shapes, func, y0, t, rtol, atol, method, options, adjoint_rtol, adjoint_atol, adjoint_method,
                adjoint_options, t_requires_grad, *adjoint_params):

        ctx.shapes = shapes
        ctx.func = func
        ctx.adjoint_rtol = adjoint_rtol
        ctx.adjoint_atol = adjoint_atol
        ctx.adjoint_method = adjoint_method
        ctx.adjoint_options = adjoint_options
        ctx.t_requires_grad = t_requires_grad

        with torch.no_grad():
            y = odeint(func, y0, t, rtol=rtol, atol=atol, method=method, options=options)
        ctx.save_for_backward(t, y, *adjoint_params)
        return y

This is their code (from https://github.com/DiffEqML/torchdyn/blob/master/torchdyn/sensitivity/adjoint.py):

    def _define_autograd_adjoint(self):
        class autograd_adjoint(torch.autograd.Function):
            @staticmethod
            def forward(ctx, h0, flat_params, s_span):
                sol = odeint(self.func, h0, self.s_span, rtol=self.rtol, atol=self.atol,
                             method=self.method, options=self.options)
                ctx.save_for_backward(self.s_span, self.flat_params, sol)
                return sol[-1]

            @staticmethod
            def backward(ctx, *grad_output):
                s, flat_params, sol = ctx.saved_tensors
                self.f_params = tuple(self.func.parameters())
                adj0 = self._init_adjoint_state(sol, grad_output)
                adj_sol = odeint(self.adjoint_dynamics, adj0, self.s_span.flip(0),
                               rtol=self.rtol, atol=self.atol, method=self.method, options=self.options)
                λ = adj_sol[1]
                μ = adj_sol[2]
                return (λ, μ, None)
        return autograd_adjoint

Also, I found that your FFJORD code worked with a single GPU but failed with multiple GPUs:

[screenshot of the failing run, 2020-09-11 11:37 AM]

The command is:

export CUDA_VISIBLE_DEVICES=6,7
python train_cnf.py --data mnist --dims 64,64,64 --strides 1,1,1,1 --num_blocks 2 --layer_type concat --multiscale True --rademacher True
@rtqichen (Owner) commented Sep 11, 2020

Thanks for reporting this!

Went down a rabbit hole, but I found the source of the problem: nn.DataParallel's replicate removes func.parameters(). This broke backward compatibility for us in PyTorch 1.5. See the discussion here: pytorch/pytorch#38493.
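A minimal sketch of the broken behaviour (assuming PyTorch 1.5.x and at least two visible GPUs):

import torch
import torch.nn as nn

func = nn.Linear(2, 2).cuda()
replicas = nn.parallel.replicate(func, [0, 1])

print(len(list(func.parameters())))         # 2: weight and bias
print(len(list(replicas[1].parameters())))  # 0 on PyTorch 1.5

# odeint_adjoint builds adjoint_params from func.parameters() inside each
# replica, so the adjoint ODE is solved with an empty parameter list and
# the parameter gradients are never filled in.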

I'd think torchdyn's handling of parameters would have the same problem in this regard, though.

I'll add a fix for this soon.

@AtlantixJJ (Author)

Thank you! Your software is really brilliant!

@rtqichen (Owner)

Should be fixed with commit d58887f. Got ffjord running.

You can install the latest version using

python -m pip install git+https://github.com/rtqichen/torchdiffeq

Let me know if it still doesn't work for you.
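
For reference, a sketch of the general approach for recovering a replica's parameters (the actual change in d58887f may differ in detail):

import torch
import torch.nn as nn

def find_parameters(module):
    # Under nn.DataParallel on PyTorch >= 1.5, a replica's weights live as
    # plain tensor attributes rather than registered parameters, so
    # module.parameters() is empty; walk the tensor attributes instead.
    assert isinstance(module, nn.Module)
    if getattr(module, '_is_replica', False):
        def find_tensor_attributes(m):
            return [(k, v) for k, v in m.__dict__.items()
                    if torch.is_tensor(v) and v.requires_grad]
        gen = module._named_members(get_members_fn=find_tensor_attributes)
        return [param for _, param in gen]
    return list(module.parameters())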

@AtlantixJJ (Author)

I ran the command and still cannot get ffjord running.
I tried to find out whether I installed the correct version, but the installation looks odd: the file location of torchdiffeq cannot be opened as a directory. So I don't know whether something in the installation broke or your fix does not work for me.

In [1]: import torchdiffeq

In [2]: torchdiffeq.__version__
Out[2]: '0.1.0'

In [3]: torchdiffeq.__file__
Out[3]: '/home/b146466/anaconda3/envs/xjj/lib/python3.7/site-packages/torchdiffeq-0.1.0-py3.7.egg/torchdiffeq/__init__.py'

In [4]: cd /home/b146466/anaconda3/envs/xjj/lib/python3.7/site-packages/torchdiffeq-0.1.0-py3.7.egg/torchdiffeq
[Errno 20] Not a directory: '/home/b146466/anaconda3/envs/xjj/lib/python3.7/site-packages/torchdiffeq-0.1.0-py3.7.egg/torchdiffeq'
/home/b146466

@AtlantixJJ (Author)

I copied the latest torchdiffeq folder directly into ffjord and uninstalled the system torchdiffeq, but ffjord's train_cnf.py still gives zero gradients with multiple GPUs.

@rtqichen (Owner)

Can you install using

git clone git@github.com:rtqichen/torchdiffeq.git
cd torchdiffeq
pip install -e .

and try again?

I haven't updated the version yet.

@rtqichen (Owner)

Oh, I was testing on PyTorch 1.6! Can you try updating?

If not, I'll take another look tomorrow with 1.5.

@AtlantixJJ (Author)

PyTorch 1.6 works with your solution. You are amazing! Thank you!
