
Make jit compilation optional for function and use nn.Module #314

Merged · 16 commits · Nov 20, 2019

Conversation

vincentqb
Contributor

@vincentqb vincentqb commented Oct 25, 2019

Removing the @torch.jit.script decoration and relying on the recursive scripting now available in nn.Module has the advantages of

  • shorter library import time
  • clearer messages when debugging
  • extended support for environments without jit

Since torch.jit.ScriptModule is also being phased out, torchaudio is starting the process of migrating to nn.Module where possible. The deprecation is explained here. Once python 2 is deprecated, we will migrate the remaining modules to nn.Module; see here for the syntax currently supported in python 2 and 3.
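
For illustration, a minimal sketch of the opt-in pattern this enables (amplitude_stand_in is a hypothetical placeholder for a torchaudio functional):

import torch

def amplitude_stand_in(x: torch.Tensor, multiplier: float = 10.0) -> torch.Tensor:
    # No @torch.jit.script decorator: eager by default, scriptable on demand.
    return multiplier * torch.log10(torch.clamp(x, min=1e-10))

# Eager use works everywhere, including environments without jit.
out_eager = amplitude_stand_in(torch.rand(2, 100))

# Users who want TorchScript opt in explicitly.
scripted = torch.jit.script(amplitude_stand_in)
out_jit = scripted(torch.rand(2, 100))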

Additional details

Current changes:

  • Removing the annotation gives users the choice to run these functions with or without jit.
  • torch.nn.Module is preferred for transforms and also makes transforms jitable.
  • Added tests checking that behavior is the same in eager mode vs JIT mode.
  • Fixed existing tests.

Questions:

  • Should we leave jit ignore in? see here (Yes, safe for now.)
  • Does jit support torch.no_grad()? see here (Not yet.)
  • Transforms currently do not use generators. Does nn.Module support generators? (Generators are not yet jitable.)

CC #326

@zou3519

zou3519 commented Oct 25, 2019

Questions:

  • Should we leave jit ignore in? see here

Looks safe to leave it in. See the docs for it: https://pytorch.org/docs/stable/jit.html#torch.jit.ignore.
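
For reference, a minimal sketch of the pattern (names are illustrative): a method marked @torch.jit.ignore is left as ordinary Python, and scripted code dispatches back to the interpreter when calling it.

import torch

class MyModule(torch.nn.Module):
    @torch.jit.ignore
    def debug_log(self, x: torch.Tensor) -> None:
        # Not compiled; runs as plain Python when called from scripted code.
        print("shape:", x.shape)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.debug_log(x)
        return x * 2

scripted = torch.jit.script(MyModule())
scripted(torch.rand(3))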

  • Does jit support torch.no_grad()? see here.

I don't think so, but let's ask. cc @wanchaol, I remember you (or maybe someone else) were looking into no_grad support for the JIT; does it support it now?

@zou3519 zou3519 left a comment

Are there tests that check that the behavior of everything is the same in eager mode vs JIT mode?

@wanchaol

  • Does jit support torch.no_grad()? see here.

I don't think so, but let's ask. cc @wanchaol, I remember you (or maybe someone else) were looking into no_grad support for the JIT; does it support it now?

Yeah, I was looking into torch.set_grad_enabled() and torch.no_grad() here; given that it's a context manager rather than a regular function, the JIT cannot support it yet.
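
Until it does, a workaround sketch (assuming the context manager stays in eager code): grad mode is global state that scripted calls still respect, so you can wrap the scripted call from outside.

import torch

@torch.jit.script
def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2

x = torch.rand(3, requires_grad=True)
with torch.no_grad():
    # The context manager lives outside TorchScript.
    y = double(x)
assert not y.requires_grad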

@vincentqb
Contributor Author

Are there tests that check that the behavior of everything is the same in eager mode vs JIT mode?

Need to add that indeed :)

@vincentqb vincentqb changed the title from "Avoid forcing jit functions, and use nn.Module." to "[WIP] Avoid forcing jit functions, and use nn.Module." on Oct 28, 2019
@vincentqb vincentqb requested a review from fmassa November 4, 2019 19:35
Member

@fmassa fmassa left a comment

Change looks safe to me.

Now that we have recursive scripting, it's not necessary anymore to torch.jit.script every function beforehand, as long as the main entrypoint is scripted.

So one could do something like

transforms = T.Compose([...])

if use_script:
    transforms = torch.jit.script(transforms)

and this should work
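
A fuller variant of that sketch, using nn.Sequential so the container itself is scriptable (the specific transforms are illustrative):

import torch
import torchaudio

transforms = torch.nn.Sequential(
    torchaudio.transforms.MelSpectrogram(sample_rate=16000),
    torchaudio.transforms.AmplitudeToDB(),
)

use_script = True
if use_script:
    transforms = torch.jit.script(transforms)

features = transforms(torch.rand(1, 16000))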

@vincentqb
Contributor Author

vincentqb commented Nov 5, 2019

torch.jit.script has recently become recursive, so we no longer need to decorate every function.

Advantages of removing the @torch.jit.script decoration:

  • Shorter library import time
  • Clearer messages when debugging
  • Extended support for environments without jit

We need to add tests for jit and non-jit. These could use a code generator to either

  • test "jit vs non-jit", or
  • test "non-jit vs reference" and "jit vs reference".

Things to keep in mind:

  • generators that could be used in datasets are not currently jitable (though jit can call into python code, e.g. via jit ignore?)
  • We have requests from the community to make DataLoaders jitable.

Decision from offline discussion: move forward with removing the decorator and adding nn.Module where possible.

@vincentqb
Contributor Author

The jit tests will need to be updated.

@vincentqb
Contributor Author

I've added tests that the jit version is close to the python version, and we already have tests that the python version is close to ground truth. We could also test the jit version against ground truth, but that would add extra code and resources to run.
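
A minimal sketch of such a consistency test (stand_in_functional is a hypothetical placeholder for a torchaudio functional):

import unittest
import torch

def stand_in_functional(x: torch.Tensor) -> torch.Tensor:
    return torch.log1p(x.abs())

class TestScriptConsistency(unittest.TestCase):
    def test_eager_vs_jit(self):
        x = torch.rand(2, 100)
        eager_out = stand_in_functional(x)
        jit_out = torch.jit.script(stand_in_functional)(x)
        # The scripted version should match the eager version.
        self.assertTrue(torch.allclose(eager_out, jit_out))

if __name__ == "__main__":
    unittest.main()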

@vincentqb
Contributor Author

I had to keep one torch.jit.ScriptModule due to python 2 and constants. We'll be able to get rid of that once python 2 is deprecated.

@vincentqb vincentqb changed the title from "[WIP] Avoid forcing jit functions, and use nn.Module." to "Make optional jit compilation for function and use nn.Module" on Nov 19, 2019
@vincentqb vincentqb changed the title from "Make optional jit compilation for function and use nn.Module" to "Make jit compilation optional for function and use nn.Module" on Nov 19, 2019
@zhangguanheng66 zhangguanheng66 left a comment

Since we are rolling back torch.jit.ScriptModule and bringing back nn.Module, we should explain the reasoning to our users. It would be better to include more details in the PR description.

@cpuhrsch
Contributor

I think adding these tests for now is fine, but a more rigorous approach, in my opinion, would be to run all tests once jitted and once not. You could generate tests for that: pytest allows you to parameterize tests, and one parameter could be whether the function is jitted. Of course that also means the tests will run much longer (twice, plus the compilation itself).
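
For instance, a sketch of that parameterization (double is an illustrative stand-in):

import pytest
import torch

def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2

@pytest.mark.parametrize("scripted", [False, True])
def test_double(scripted):
    # Run the same test body once eagerly and once through TorchScript.
    fn = torch.jit.script(double) if scripted else double
    x = torch.rand(3)
    assert torch.equal(fn(x), x * 2)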

@vincentqb
Contributor Author

Since we are rolling back torch.jit.ScriptModule and bringing back nn.Module, we should explain the reasoning to our users. It would be better to include more details in the PR description.

Sounds good. I've updated the description of the PR.

@vincentqb
Contributor Author

I think adding these tests for now is fine, but a more rigorous approach, in my opinion, would be to run all tests once jitted and once not. You could generate tests for that.

Some functions and modules do not have tests against a known solution and are only tested for shape or a successful run. Some are not tested at all (e.g. augmentations.py).

The main worry with jit is that it compiles something that yields results different from what the original code would give. The consistency test with jit at least ensures the compilation is done correctly.

pytest allows you to parameterize tests, and one parameter could be whether the function is jitted.

Are you suggesting migrating completely to pytest? If the reason is simply test parameterization, unittest can do that too :) It would be nice to keep only one of the two frameworks, say pytest or unittest. Since pytorch uses unittest, we should probably stick with that one.

Of course that also means the tests will run much longer (twice, plus the compilation itself).

If resources are not an issue, I'm ok with having the tests re-run with jit. I'd suggest doing that in a separate PR.

Side note: is there a way with jit to compile a whole python module?

import torch
import torchaudio.functional as F

# run tests
F = torch.jit.script(F)
# run tests again

I'm not saying this is the best way to parametrize, but it seems fun to do :)
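
For what it's worth, torch.jit.script works on functions, methods, and nn.Module instances rather than on python modules, so one approximation is to script each public function individually (a sketch; names are illustrative, and anything not yet scriptable is skipped):

import types

import torch
import torchaudio.functional as F

scripted_F = types.SimpleNamespace()
for name in dir(F):
    fn = getattr(F, name)
    if callable(fn) and not name.startswith("_"):
        try:
            # Script each function on its own and mirror it in a namespace.
            setattr(scripted_F, name, torch.jit.script(fn))
        except Exception:
            # Skip anything the compiler cannot handle.
            pass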

Contributor

@cpuhrsch cpuhrsch left a comment

LGTM
