
Feature/vonmises upstream #33418

Closed

Conversation

Contributor

@ahmadsalim ahmadsalim commented Feb 17, 2020

Third try of #33177 😄


dr-ci bot commented Feb 17, 2020

💊 CircleCI build failures summary and remediations

As of commit a3dfc73:

  • 1/1 failures introduced in this PR

Detailed failure analysis

One may explore the probable reasons each build failed interactively on the Dr. CI website.

🕵️ 1 new failure recognized by patterns

The following build failures do not appear to be due to upstream breakage:

See CircleCI build pytorch_linux_xenial_py3_6_gcc5_4_test (1/1)

Step: "Test" (full log | pattern match details)

Feb 25 15:14:21 RuntimeError: test_quantization failed!
Feb 25 15:14:21 Ran 36 tests in 58.018s 
Feb 25 15:14:21  
Feb 25 15:14:21 FAILED (errors=1, skipped=1) 
Feb 25 15:14:21  
Feb 25 15:14:21 Generating XML reports... 
Feb 25 15:14:21 Traceback (most recent call last): 
Feb 25 15:14:21   File "test/run_test.py", line 493, in <module> 
Feb 25 15:14:21     main() 
Feb 25 15:14:21   File "test/run_test.py", line 486, in main 
Feb 25 15:14:21     raise RuntimeError(message) 
Feb 25 15:14:21 RuntimeError: test_quantization failed! 
Feb 25 15:14:21 + cleanup 
Feb 25 15:14:21 + retcode=1 
Feb 25 15:14:21 + set +x 
Feb 25 15:14:21 =================== sccache compilation log =================== 
Feb 25 15:14:21 =========== If your build fails, please take a look at the log above for possible reasons =========== 
Feb 25 15:14:21 Compile requests                  7 
Feb 25 15:14:21 Compile requests executed         6 
Feb 25 15:14:21 Cache hits                        0 
Feb 25 15:14:21 Cache misses                      6 
Feb 25 15:14:21 Cache timeouts                    0 


Add tests for VonMises logprob and sample

Signed-off-by: Ahmad Salim Al-Sibahi <ahmad@di.ku.dk>
Fix linting issues and JIT compilation

Fix proposal_r

Fix issues with JIT not working with torch.Size
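For background on what the commits above test: the von Mises distribution on the circle has log-density log p(x) = κ·cos(x − μ) − log(2π·I₀(κ)), where I₀ is the modified Bessel function of order 0. A framework-free sketch in pure Python (I₀ approximated by its power series; this is an illustration, not the PR's actual `torch.distributions.VonMises` implementation):

```python
import math

def von_mises_log_prob(x, loc, concentration, terms=30):
    """Log-density of the von Mises distribution.

    log p(x) = kappa * cos(x - mu) - log(2 * pi * I0(kappa)),
    with I0 approximated by its power series
    I0(k) = sum_j ((k/2)^(2j)) / (j!)^2.
    """
    half = concentration / 2.0
    term = 1.0
    i0 = 1.0
    for j in range(1, terms):
        term *= (half / j) ** 2  # term_j = term_{j-1} * ((k/2)/j)^2
        i0 += term
    return concentration * math.cos(x - loc) - math.log(2 * math.pi * i0)
```

The density is symmetric about `loc` and integrates to 1 over any interval of length 2π, which is what the `logprob` tests in this PR check against reference values.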
@yf225 yf225 requested review from fritzo and ezyang February 19, 2020 21:05
@yf225 yf225 added labels module: distributions (Related to torch.distributions) and triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Feb 19, 2020
Collaborator

@fritzo fritzo left a comment

LGTM after moving tests down to avoid rng seed churn

test/test_distributions.py (review thread, outdated, resolved)
torch/distributions/von_mises.py (review thread, outdated, resolved)
Collaborator

@fritzo fritzo left a comment

LGTM

Test errors appear unrelated, all relating to TestFuser

======================================================================
ERROR: test_abs_cpu (__main__.TestFuser)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\jit_utils.py", line 592, in wrapper
    fn(*args, **kwargs)
  File "test_jit_fuser.py", line 79, in test_abs_cpu
    self._test_fused_abs()
  File "test_jit_fuser.py", line 74, in _test_fused_abs
    self.assertAllFused(func.graph_for(a))
  File "C:\Users\circleci\project\build\win_tmp\build\torch\jit\__init__.py", line 2094, in _graph_for
    self(*args, **kwargs)
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_utils.py", line 87, in prof_func_call
    return prof_callable(func_call, *args, **kwargs)
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_utils.py", line 84, in prof_callable
    return callable(*args, **kwargs)
RuntimeError: Caught an unknown exception!

======================================================================
ERROR: test_chunk_correctness (__main__.TestFuser)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\jit_utils.py", line 592, in wrapper
    fn(*args, **kwargs)
  File "test_jit_fuser.py", line 223, in test_chunk_correctness
    return self._test_chunk_correctness(self, 'cpu')
  File "test_jit_fuser.py", line 218, in _test_chunk_correctness
    self.checkScript(fn, [tensor])
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\jit_utils.py", line 398, in checkScript
    opt_script_outputs = scripted_fn(*recording_inputs)
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_utils.py", line 87, in prof_func_call
    return prof_callable(func_call, *args, **kwargs)
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_utils.py", line 84, in prof_callable
    return callable(*args, **kwargs)
RuntimeError: Caught an unknown exception!

======================================================================
ERROR: test_scalar (__main__.TestFuser)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\jit_utils.py", line 592, in wrapper
    fn(*args, **kwargs)
  File "test_jit_fuser.py", line 851, in test_scalar
    ge = self.checkScript(fn, (x, y))
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\jit_utils.py", line 398, in checkScript
    opt_script_outputs = scripted_fn(*recording_inputs)
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_utils.py", line 87, in prof_func_call
    return prof_callable(func_call, *args, **kwargs)
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_utils.py", line 84, in prof_callable
    return callable(*args, **kwargs)
RuntimeError: Caught an unknown exception!

======================================================================
ERROR: test_where_and_typing (__main__.TestFuser)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\jit_utils.py", line 592, in wrapper
    fn(*args, **kwargs)
  File "test_jit_fuser.py", line 906, in test_where_and_typing
    self.assertAllFused(script_f.graph_for(x, y), except_for={'prim::TupleConstruct'})
  File "C:\Users\circleci\project\build\win_tmp\build\torch\jit\__init__.py", line 2094, in _graph_for
    self(*args, **kwargs)
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_utils.py", line 87, in prof_func_call
    return prof_callable(func_call, *args, **kwargs)
  File "C:\Users\circleci\project\build\win_tmp\build\torch\testing\_internal\common_utils.py", line 84, in prof_callable
    return callable(*args, **kwargs)
RuntimeError: Caught an unknown exception!

----------------------------------------------------------------------
Ran 46 tests in 12.462s

FAILED (errors=4, skipped=10)
Traceback (most recent call last):
  File "run_test.py", line 486, in <module>
    main()
  File "run_test.py", line 479, in main
    raise RuntimeError(message)
RuntimeError: test_jit_fuser failed!

Collaborator

fritzo commented Feb 20, 2020

@pytestbot merge this please

Contributor

@facebook-github-bot facebook-github-bot left a comment

@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


Contributor

@ezyang merged this pull request in 24659d2.

hczhu pushed a commit that referenced this pull request Feb 28, 2020
Summary:
Third try of #33177 😄
Pull Request resolved: #33418

Differential Revision: D20069683

Pulled By: ezyang

fbshipit-source-id: f58e45e91b672bfde2e41a4480215ba4c613f9de
ttumiel pushed a commit to ttumiel/pytorch that referenced this pull request Mar 4, 2020
Summary:
Third try of pytorch#33177 😄
Pull Request resolved: pytorch#33418

Differential Revision: D20069683

Pulled By: ezyang

fbshipit-source-id: f58e45e91b672bfde2e41a4480215ba4c613f9de
Diff context for the review thread below:

```python
    return result


@torch.jit.script
```
Contributor

I didn't dig too much into the rest of this PR, but would it be OK to remove the @torch.jit.script here?

If this function is used in some other TorchScript method/module it will automatically be compiled. Having a @torch.jit.script decorator inside the core PyTorch library starts up the JIT compiler (which does some expensive initialization) any time someone does `import torch`, so it impacts all users of PyTorch instead of only those using the JIT.

Contributor

Yes, please go for it. Actually, I didn't realize we aren't supposed to put torch.jit.script decorators in the core code; we should add a lint rule for this.

Collaborator

@fritzo fritzo Mar 13, 2020

The reason I recommended @torch.jit.script here was to handle the data-dependent control flow when jit-tracing code involving this distribution. In Pyro we use this pattern of @torch.jit.scripting low-level helpers that cannot be traced. Do you have a recommended workaround for PyTorch core code, such that we can support torch.jit.trace?
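The data-dependent control flow in question is the rejection loop used to sample from a von Mises distribution: the number of iterations depends on the random draws, so a trace would hard-code whatever iteration count happened during tracing. A pure-Python sketch of the classic Best–Fisher (1979) rejection sampler (illustrative only; the PR's `_rejection_sample` helper is tensorized, this is not its actual code):

```python
import math
import random

def sample_von_mises(loc, concentration, rng=random):
    """Draw one von Mises sample via Best-Fisher rejection sampling.

    The while-loop below runs a random number of times -- exactly the
    data-dependent control flow that torch.jit.trace cannot capture,
    and the reason this helper needs scripting rather than tracing.
    """
    tau = 1.0 + math.sqrt(1.0 + 4.0 * concentration ** 2)
    rho = (tau - math.sqrt(2.0 * tau)) / (2.0 * concentration)
    r = (1.0 + rho ** 2) / (2.0 * rho)
    while True:  # rejection loop: iteration count depends on the draws
        u1, u2 = rng.random(), rng.random()
        z = math.cos(math.pi * u1)
        f = (1.0 + r * z) / (r + z)
        c = concentration * (r - f)
        if c * (2.0 - c) - u2 > 0.0 or math.log(c / u2) + 1.0 - c >= 0.0:
            break  # accept
    u3 = rng.random()
    theta = math.copysign(math.acos(f), u3 - 0.5) + loc
    return math.atan2(math.sin(theta), math.cos(theta))  # wrap to (-pi, pi]
```

Scripting the helper lets the compiler represent the loop symbolically, so traced models that call it remain correct for any sample path.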

Contributor

I had a PR that did pretty much what you want (do nothing, but if tracing then compile the function) in #25746; I could revive that as something like @torch.jit._lazy_script

Collaborator

@driazati sure, a @torch.jit._lazy_script would be great.

driazati pushed a commit that referenced this pull request Mar 18, 2020
Some users maintain libraries of code that is largely trace-able but not
script-able. However, some functions may need to be `@torch.jit.script`ed if
they contain control flow so the tracer will use the compiler version.
This however impacts library start up time as in #33418, so this PR adds
a workaround in the form of a `@torch.jit._lazy_script_while_tracing`
that will only initialize the compiler if the function is called while
actually tracing.
facebook-github-bot pushed a commit that referenced this pull request Mar 24, 2020
Summary:
Stacked PRs
 * #34938 - [jit] Remove stray `script`
 * **#34935 - [jit] Add lazy script decorator**

Some users maintain libraries of code that is largely trace-able but not
script-able. However, some functions may need to be `torch.jit.script`ed if
they contain control flow so the tracer will use the compiler version.
This however impacts library start up time as in #33418, so this PR adds
a workaround in the form of a `torch.jit._lazy_script_while_tracing`
that will only initialize the compiler if the function is called while
actually tracing.

Pull Request resolved: #34935

Pulled By: driazati

Differential Revision: D20569778

fbshipit-source-id: d87c88c02b1abc86b283729ab8db94285d7d4853
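The lazy-compilation mechanism described in the commits above can be sketched framework-free: a decorator that leaves the function eager and only swaps in a "compiled" version the first time it is called while a tracing flag is set. All names here (`_tracing`, `set_tracing`, `_expensive_compile`, `lazy_script_while_tracing`) are illustrative stand-ins, not PyTorch's actual internals:

```python
import functools

_tracing = False  # stand-in for "is torch.jit.trace currently active?"

def set_tracing(value):
    """Toggle the (illustrative) global tracing flag."""
    global _tracing
    _tracing = value

def _expensive_compile(fn):
    """Stand-in for starting the TorchScript compiler (expensive)."""
    fn.compiled = True  # mark the function so compilation is observable
    return fn

def lazy_script_while_tracing(fn):
    """Leave `fn` eager; compile it only the first time it runs under tracing."""
    compiled = None

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        nonlocal compiled
        if _tracing:
            if compiled is None:       # compiler starts here, not at import
                compiled = _expensive_compile(fn)
            return compiled(*args, **kwargs)
        return fn(*args, **kwargs)     # eager path never touches the compiler

    return wrapper
```

The key property is that `import` of the library and all eager calls pay nothing; the compiler's expensive initialization is deferred to the first call made during a trace.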
Labels: Merged · module: distributions (Related to torch.distributions) · open source · triaged (looked at by a team member and prioritized into an appropriate module)

8 participants