
[ONNX] Add binary_cross_entropy_with_logits op to ONNX opset version 12 #49675

Merged
merged 232 commits into pytorch:onnx_ms_1 from deyu/bce_with_logits_sy12 on Jan 20, 2021

Conversation

hwangdeyu
Collaborator

Fixes #47997
Exporting the operator binary_cross_entropy_with_logits to ONNX opset version 12.
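
For context, a minimal sketch of the kind of export this enables (the module, tensor shapes, and file name below are illustrative, not taken from this PR):

```
import torch
import torch.nn.functional as F

class BCEWithLogits(torch.nn.Module):
    def forward(self, logits, target):
        # The op this PR teaches the exporter to handle at opset >= 12
        return F.binary_cross_entropy_with_logits(logits, target)

logits = torch.randn(3, 5)
target = torch.empty(3, 5).random_(2)

# opset_version must be >= 12 for the new symbolic to apply
torch.onnx.export(BCEWithLogits(), (logits, target), "bce_with_logits.onnx",
                  opset_version=12)
```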

@facebook-github-bot
Contributor

facebook-github-bot commented Dec 21, 2020

💊 CI failures summary and remediations

As of commit 0e09ee9 (more details on the Dr. CI page):



5 failures not recognized by patterns:

| Job | Step | Action |
| --- | --- | --- |
| CircleCI pytorch_linux_xenial_py3_clang5_asan_test2 | Run tests | 🔁 rerun |
| CircleCI pytorch_linux_bionic_py3_8_gcc9_coverage_test2 | Run tests | 🔁 rerun |
| CircleCI pytorch_xla_linux_bionic_py3_6_clang9_test | Run tests | 🔁 rerun |
| CircleCI pytorch_linux_bionic_py3_8_gcc9_coverage_test1 | Run tests | 🔁 rerun |
| CircleCI pytorch_linux_xenial_py3_clang5_asan_test1 | Run tests | 🔁 rerun |

❄️ 3 failures tentatively classified as flaky, but reruns have not yet been triggered to confirm:

See CircleCI build pytorch_linux_xenial_cuda10_2_cudnn7_py3_jit_legacy_test (1/3)

Step: "Unknown" (full log | diagnosis details | 🔁 rerun) ❄️

```
Waiting for a VM assignment: ...
Build-agent version 1.0.50617-fbf8220b (2021-01-14T08:47:18+0000)
Creating a dedicated VM with ubuntu-1604:202007-01 image
Waiting for a VM assignment: ...

We timed out preparing a VM for this build, potentially due to our infrastructure or cloud provider. Please retry the build in a few minutes.

Unexpected capacity error: error caused by capacity
```

See CircleCI build pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test1 (2/3)

Step: "Unknown" (full log | diagnosis details | 🔁 rerun) ❄️

```
Waiting for a VM assignment: ...
Build-agent version 1.0.50617-fbf8220b (2021-01-14T08:47:18+0000)
Creating a dedicated VM with ubuntu-1604:202007-01 image
Waiting for a VM assignment: ...

We timed out preparing a VM for this build, potentially due to our infrastructure or cloud provider. Please retry the build in a few minutes.

Unexpected capacity error: error caused by capacity
```

See CircleCI build pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test2 (3/3)

Step: "Unknown" (full log | diagnosis details | 🔁 rerun) ❄️

```
Waiting for a VM assignment: ...
Build-agent version 1.0.50617-fbf8220b (2021-01-14T08:47:18+0000)
Creating a dedicated VM with ubuntu-1604:202007-01 image
Waiting for a VM assignment: ...

We timed out preparing a VM for this build, potentially due to our infrastructure or cloud provider. Please retry the build in a few minutes.

Unexpected capacity error: error caused by capacity
```


ci.pytorch.org: 1 failed


This comment was automatically generated by Dr. CI.

@BowenBao BowenBao changed the title Add binary_cross_entropy_with_logits op to ONNX opset version 12 [ONNX] Add binary_cross_entropy_with_logits op to ONNX opset version 12 Dec 22, 2020
@BowenBao
Collaborator

Please rebase with onnx_ms_1 to resolve CI issues.

test/onnx/test_pytorch_onnx_onnxruntime.py (outdated; resolved)
torch/onnx/symbolic_opset12.py (outdated; resolved)
test/onnx/test_pytorch_onnx_onnxruntime.py (resolved)
test/onnx/test_pytorch_onnx_onnxruntime.py (outdated; resolved)
bertmaher and others added 17 commits December 23, 2020 15:17
Summary:
Pull Request resolved: pytorch#49396

Pull Request resolved: pytorch#49271

Two things:

1. These throw exceptions in their constructor, which causes a segfault (*), so
   move the exceptions to ::make.
2. They technically support FP types but the rules are complicated so let's not
   bother.

(*) The reason for the segfault: all Exprs including these inherit from
KernelScopedObject, whose constructor adds the object to a list for destruction
at the end of the containing KernelArena's lifetime.  But if the derived-class
constructor throws, the object is deleted even though it's still in the
KernelArena's list.  So when the KernelArena is itself deleted, it double-frees
the pointer and dies.  I've also fixed And, Or, and Xor in this diff.
ghstack-source-id: 118594998

Test Plan: `buck test //caffe2/test:jit`

Reviewed By: bwasti

Differential Revision: D25512052

fbshipit-source-id: 42670b3be0cc1600dc5cda6811f7f270a2c88bba
Summary:
Pull Request resolved: pytorch#49340

This refines the fusion group to include only certain types of operations. We cannot safely handle "canRunNatively" types, and the memonger pass causes regressions on some internal models, so it was disabled (to be revisited with proper memory optimization once Tensor pools are implemented).

Test Plan:
```
buck test mode/no-gpu caffe2/test:static_runtime
buck test //caffe2/benchmarks/static_runtime:static_runtime_cpptest
```

Reviewed By: ZolotukhinM

Differential Revision: D25520105

fbshipit-source-id: add61d103e4f8b4615f5402e760893ef759a60a9
Summary: Pull Request resolved: pytorch#48992

Differential Revision: D25388100

Test Plan: Imported from OSS

Reviewed By: heitorschueroff

Pulled By: ZolotukhinM

fbshipit-source-id: d95713af2220cf4f99ac92f59f8e5b902f2f3822
Summary:
BC-breaking note:

This PR changes the behavior of the any and all functions to always return a bool tensor. Previously these functions were only defined on bool and uint8 tensors, and when called on uint8 tensors they would also return a uint8 tensor. (When called on a bool tensor they would return a bool tensor.)
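
For example (a small illustration of the new behavior; the tensor values are arbitrary):

```
import torch

x = torch.tensor([0, 1, 2], dtype=torch.uint8)
print(torch.any(x).dtype)  # torch.bool (previously torch.uint8)
print(torch.all(x).dtype)  # torch.bool
```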

PR summary:

pytorch#44790 (comment)

Fixes 2 and 3

Also Fixes pytorch#48352

Changes
* Output dtype is always `bool` (consistent with numpy). **BC-breaking** (previously it matched the input dtype)
* Uses vectorized version for all dtypes on CPU
* Enables test for complex
* Update doc for `torch.all` and `torch.any`

TODO
* [x] Update docs
* [x] Benchmark
* [x] Raise issue on XLA

Pull Request resolved: pytorch#47878

Reviewed By: H-Huang

Differential Revision: D25421263

Pulled By: mruberry

fbshipit-source-id: c6c681ef94004d2bcc787be61a72aa059b333e69
…L_LAUNCH_CHECK() (pytorch#49424)

Summary:
Pull Request resolved: pytorch#49424

As per the conversation in this [comment](https://www.internalfb.com/intern/diff/D25541113/?dest_fbid=393026838623691&transaction_id=3818008671564312) on D25541113 (pytorch@e2510a0), although THError does more than just log errors associated with CUDA kernel launches, we're going to replace it with C10_CUDA_KERNEL_LAUNCH_CHECK so as to be consistent throughout the code base.
Standardization FTW.

This commit is purposefully sent in as a single file change so it can be easily reverted if it introduces a regression.

Test Plan:
Checked that the code still builds with
```
buck build //caffe2/aten:ATen-cu
```
Also ran basic aten tests
```
buck test //caffe2/aten:atest
```

Reviewed By: r-barnes

Differential Revision: D25567863

fbshipit-source-id: 1093bfe2b6ca6b9a3bfb79dcdc5d713f6025eb77
Summary:
Signed-off-by: caozhong <zhong.z.cao@intel.com>

Pull Request resolved: pytorch#48827

Reviewed By: agolynski

Differential Revision: D25375988

Pulled By: ailzhang

fbshipit-source-id: a8d5ab4572d991d6d96dfe758011517651ff0a6b
…ings.warn (pytorch#49313)

Summary:
Adding a flag torch_jit_disable_warning_prints to optimize interpreter performance by suppressing a (potentially large) number of warnings.warn calls.

This is to work around TorchScript's warning-behavior mismatch with Python. Python by default triggers a warning once per location, but TorchScript doesn't support that, which causes the same warning to trigger and print once per inference run, hurting performance.

Pull Request resolved: pytorch#49313

Reviewed By: SplitInfinity

Differential Revision: D25534274

Pulled By: gmagogsfm

fbshipit-source-id: eaeb57a335c3e6c7eb259671645db05d781e80a2
…s in async execution (pytorch#49322)

Summary:
Pull Request resolved: pytorch#49322

In some cases async execution might lose dependencies (alias-like ops) or produce suboptimal scheduling when there is a choice of which parts to schedule first. An example of the latter behavior can happen in ModelParallel training, where a copy can get lower priority than the rest of the execution on the given GPU, causing other GPUs to starve.

This operator makes it possible to address these issues by introducing extra explicit dependencies between ops.

Test Plan:
Unit-test/
E2E testing in the future diffs.

Reviewed By: xianjiec

Differential Revision: D24933471

fbshipit-source-id: 1668994c7856d73926cde022378a99e1e8db3567
Summary: Pull Request resolved: pytorch#49415

Test Plan: Imported from OSS

Reviewed By: zdevito

Differential Revision: D25565341

Pulled By: jamesr66a

fbshipit-source-id: 2290ab62572632788809ba16319578bf0c0260ee
…reapply) (pytorch#49408)

Summary:
Pull Request resolved: pytorch#49408

Nearly every non-test callsite doesn't need to capture any variables anyway, and this saves 48 bytes per callback.
ghstack-source-id: 118665808

Test Plan:
Wait for GitHub CI since we had C++14-specific issues with
this one in previous PR pytorch#48629

Reviewed By: malfet

Differential Revision: D25563207

fbshipit-source-id: 6a2831205917d465f8248ca37429ba2428d5626d
Summary:
Since NCCL is an optional CUDA dependency, remove nccl.cpp from the core filelist

Pull Request resolved: pytorch#49429

Reviewed By: nikithamalgifb

Differential Revision: D25569883

Pulled By: malfet

fbshipit-source-id: 61371a4c6b0438e4e0a7f094975b9a9f9ffa4032
Summary:
Fixes pytorch#47462, but not completely.

Update breathe to the latest version to get fixes for the "Unable to resolve..." issues. There are still some build errors, but far fewer than before.

Pull Request resolved: pytorch#49407

Reviewed By: izdeby

Differential Revision: D25562163

Pulled By: glaringlee

fbshipit-source-id: 91bfd9e9ac70723816309f489022d72853f5fdc5
Summary:
Pull Request resolved: pytorch#49447

Adding an out variant for `permute`. It's better than fixing the copy inside contiguous because 1) we can leverage the c2 math library, and 2) contiguous creates a tensor inside the function which isn't managed by the MemoryPlanner in StaticRuntime.

Test Plan:
Benchmark:
```
After:
I1214 12:35:32.218775 991920 PyTorchPredictorBenchLib.cpp:209] PyTorch run finished. Milliseconds per iter: 0.0902339. Iters per second: 11082.3

Before:
I1214 12:35:43.368770 992620 PyTorchPredictorBenchLib.cpp:209] PyTorch run finished. Milliseconds per iter: 0.0961521. Iters per second: 10400.2
```

Reviewed By: yinghai

Differential Revision: D25541666

fbshipit-source-id: 013ed0d4080cd01de4d3e1b031ab51e5032e6651
Summary: Pull Request resolved: pytorch#49388

Test Plan: Imported from OSS

Reviewed By: zou3519

Differential Revision: D25553672

Pulled By: glaringlee

fbshipit-source-id: e9f2233bd678a90768844af2d8d5e2994d59e304
…ets (pytorch#49113)

Summary: Pull Request resolved: pytorch#49113

Reviewed By: ajyu

Differential Revision: D25388512

fbshipit-source-id: 3daa5b9387a3a10b6c220688df06540c4d844aea
pytorch#49346)

Summary:
Pull Request resolved: pytorch#49346

This is a less ambitious redo of pytorch#49129.

We make the

```
xq_slice = xq[:, [0], :, :]
```

indexing syntax work if `xq` is a quantized Tensor.  For now, we are
making the code not crash, with an inefficient `dq -> index -> q`
implementation.  A future PR can optimize performance by removing
the unnecessary memory copies (which will require some non-trivial
changes to TensorIterator).
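
A minimal sketch of the pattern that now works (the shapes, scale, and zero_point below are illustrative, not from this PR):

```
import torch

x = torch.randn(1, 3, 4, 4)
xq = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)

# Advanced indexing on a quantized tensor; currently dq -> index -> q under the hood
xq_slice = xq[:, [0], :, :]
print(xq_slice.shape)  # torch.Size([1, 1, 4, 4])
```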

Test Plan:
```
python test/test_quantization.py TestQuantizedOps.test_advanced_indexing
```

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D25539365

fbshipit-source-id: 98485875aaaf5743e1a940e170258057691be4fa
Summary:
Pull Request resolved: pytorch#49373

Unescaping the string in the RPC error message to provide a better error message

Test Plan: CI

Reviewed By: xush6528

Differential Revision: D25511730

fbshipit-source-id: 054f46d5ffbcb1350012362a023fafb1fe57fca1
@hwangdeyu hwangdeyu requested a review from albanD as a code owner January 6, 2021 06:09
Collaborator

@BowenBao BowenBao left a comment


LGTM, thanks! Minor comments about improving the helper function.

torch/onnx/symbolic_opset12.py (outdated; resolved)
torch/onnx/symbolic_opset12.py (resolved)
@BowenBao BowenBao merged commit 566406c into pytorch:onnx_ms_1 Jan 20, 2021
BowenBao added a commit that referenced this pull request Jan 21, 2021
…12 (#49675)

Fixes #47997
Exporting the operator binary_cross_entropy_with_logits to ONNX opset version 12.

[ghstack-poisoned]
BowenBao added a commit that referenced this pull request Jan 21, 2021
…et version 12 (#49675)"

Fixes #47997
Exporting the operator binary_cross_entropy_with_logits to ONNX opset version 12.

[ghstack-poisoned]
BowenBao added a commit that referenced this pull request Jan 22, 2021
…et version 12 (#49675)"

Fixes #47997
Exporting the operator binary_cross_entropy_with_logits to ONNX opset version 12.

[ghstack-poisoned]
BowenBao added a commit that referenced this pull request Jan 22, 2021
…et version 12 (#49675)"

Fixes #47997
Exporting the operator binary_cross_entropy_with_logits to ONNX opset version 12.

Differential Revision: [D26023936](https://our.internmc.facebook.com/intern/diff/D26023936)

[ghstack-poisoned]
BowenBao added a commit that referenced this pull request Jan 25, 2021
…et version 12 (#49675)"


Fixes #47997
Exporting the operator binary_cross_entropy_with_logits to ONNX opset version 12.

[ghstack-poisoned]
BowenBao added a commit that referenced this pull request Jan 25, 2021
…et version 12 (#49675)"


Fixes #47997
Exporting the operator binary_cross_entropy_with_logits to ONNX opset version 12.

Differential Revision: [D26050885](https://our.internmc.facebook.com/intern/diff/D26050885)

[ghstack-poisoned]
BowenBao added a commit that referenced this pull request Jan 26, 2021
…et version 12 (#49675)"


Fixes #47997
Exporting the operator binary_cross_entropy_with_logits to ONNX opset version 12.

Differential Revision: [D26050885](https://our.internmc.facebook.com/intern/diff/D26050885)

[ghstack-poisoned]
facebook-github-bot pushed a commit that referenced this pull request Jan 28, 2021
…12 (#49675) (#50908)

Summary:
Pull Request resolved: #50908

Fixes #47997
Exporting the operator binary_cross_entropy_with_logits to ONNX opset version 12.

Test Plan: Imported from OSS

Reviewed By: pbelevich

Differential Revision: D26050885

Pulled By: SplitInfinity

fbshipit-source-id: e4167895eed804739aa50481679500a4d564b360
BowenBao added a commit to BowenBao/pytorch that referenced this pull request Jan 28, 2021
…12 (pytorch#49675)

Fixes pytorch#47997
Exporting the operator binary_cross_entropy_with_logits to ONNX opset version 12.

ghstack-source-id: 4d3467df7821e1499788cc18ae6f57c973c28d49
Pull Request resolved: pytorch#50908
@hwangdeyu hwangdeyu deleted the deyu/bce_with_logits_sy12 branch August 30, 2021 03:17