[ONNX] Add binary_cross_entropy_with_logits op to ONNX opset version 12 #49675

Merged 232 commits on Jan 20, 2021

Commits (232)
170908b
[ONNX] Support onnx if/loop sequence output in opset 13 - (#49270)
BowenBao Dec 17, 2020
203e181
Symbolic function for torch.square (#49446)
jiafatom Dec 17, 2020
29dd23f
[ONNX] Support onnx if/loop sequence output in opset 13 - (#49270)
BowenBao Dec 17, 2020
f7c63eb
Symbolic function for torch.square (#49446)
jiafatom Dec 17, 2020
086fcf6
[te] Fix bugs with shift operators (#49396)
bertmaher Dec 15, 2020
43aa3be
[static runtime] refine fusion group (#49340)
bwasti Dec 15, 2020
837ac43
[JIT] Support multiple outputs in subgraph matcher. (#48992)
Dec 15, 2020
4bdc202
[numpy] torch.{all/any} : output dtype is always bool (#47878)
kshitij12345 Dec 15, 2020
6428593
Replace THError() check in THCTensorMathReduce.cu with C10_CUDA_KERNE…
Dec 15, 2020
3e6bdd1
Fix include files for out-of-tree compilation (#48827)
CaoZhongZ Dec 15, 2020
9058e5f
Add flag torch_jit_disable_warning_prints to allow disabling all warn…
gmagogsfm Dec 15, 2020
f360b23
[DPER] Introduce barrier operation to force synchronization of thread…
kennyhorror Dec 16, 2020
4c667a1
[FX] Rename Node._uses and refactor Node.all_input_nodes (#49415)
Dec 16, 2020
1c9a0bf
[PyTorch] Use plain old function pointer for RecordFunctionCallback (…
swolchok Dec 16, 2020
4558c13
[CMake] Use libtorch_cuda list defined in bzl file (#49429)
malfet Dec 16, 2020
6275612
update breathe (#49407)
mattip Dec 16, 2020
7439f10
[StaticRuntime] Permute_out (#49447)
Dec 16, 2020
edea937
fix optimizer.pyi typo 'statue'->'state' (#49388)
Dec 16, 2020
1f1c0f5
[StaticRuntime] Fusion pass for ClipRanges/GatherRanges/LengthsToOffs…
Dec 16, 2020
b1547e4
quantized tensor: add preliminary support for advanced indexing, try …
vkuzo Dec 16, 2020
28a5455
Unescape string in RPC error message (#49373)
rohan-varma Dec 16, 2020
bc97e02
[StaticRuntime][ATen] Add out variant for narrow_copy (#49449)
Dec 16, 2020
1479e05
Revert D25554109: [StaticRuntime][ATen] Add out variant for narrow_copy
Dec 16, 2020
00e3716
Making ops c10 full: optional out arguments (#49083)
smessmer Dec 16, 2020
bdfa87e
Making ops c10-full: optional lists (#49088)
smessmer Dec 16, 2020
076d62f
[PyTorch] Avoid move-constructing a List in listConstruct (#49355)
swolchok Dec 16, 2020
197266d
Enhanced generators with grad-mode decorators (#49017)
ivannz Dec 16, 2020
5165093
webdataset prototype - ListDirFilesIterableDataset (#48944)
Dec 16, 2020
9745362
webdataset prototype - LoadFilesFromDiskIterableDataset (#48955)
Dec 16, 2020
bf3d1b4
CUDA BFloat embedding (#44848)
zasdfgbnm Dec 16, 2020
a213e48
Instantiate PackedConvWeight to avoid linking error (#49442)
iseeyuan Dec 16, 2020
d73c1f4
.circleci: downgrade conda-package-handling to 1.6.0 (#49434)
seemethere Dec 16, 2020
98c4a4d
[Docs] Updating init_process_group docs to indicate correct rank rang…
osalpekar Dec 16, 2020
e9c93eb
[c10d Store] Store Python Docs Fixes (#49130)
osalpekar Dec 16, 2020
6cb4910
Add sinc operator (#48740)
soulitzer Dec 16, 2020
7b4218c
Revert "Revert D24923679: Fixed einsum compatibility/performance issu…
heitorschueroff Dec 16, 2020
6a56da9
[caffe2][autograd] Avoid extensive -Wunused-variable warnings on _any…
jdonald Dec 16, 2020
5125131
Revert D25421263: [pytorch][PR] [numpy] torch.{all/any} : output dtyp…
ngimel Dec 16, 2020
c7ce84b
Reland "Add test for empty tensors for batch matmuls" (#48797)
zasdfgbnm Dec 16, 2020
1352101
Adding support for CuDNN-based LSTM with projections (#47725)
Dec 16, 2020
0991d63
Move inplace_is_vmap_compatible to BatchedTensorImpl.h (#49118)
zou3519 Dec 16, 2020
b2acf95
Update accumulate_grad to support vmap (#49119)
zou3519 Dec 16, 2020
da5c385
Update TensorPipe submodule (#49467)
lw Dec 16, 2020
94344a2
Add docs/README.md to make existing doc build info more discoverable …
rgommers Dec 16, 2020
6315a7e
Updated derivative rules for complex svd and pinverse (#47761)
IvanYashchuk Dec 16, 2020
bbaa6bb
[quant][docs] Add fx graph mode quantization to quantization docs (#4…
jerryzh168 Dec 16, 2020
0d82603
stft: Change require_complex warning to an error (#49022)
peterbell10 Dec 16, 2020
0176da6
Revert D25564477: [pytorch][PR] Add sinc operator
soulitzer Dec 16, 2020
8dcd580
Making ops c10-full: Storage arguments (#49146)
smessmer Dec 16, 2020
6f50a18
Allow zero annealing epochs (#47579)
Daniil-Osokin Dec 16, 2020
3bbc766
Revert D25507480: [quant][docs] Add fx graph mode quantization to qua…
Dec 16, 2020
e7b6a29
Fix link in distributed contributing doc and add link (#49141)
rohan-varma Dec 16, 2020
470a9cf
Add note to torch docs for sinh/cosh (#49413)
soulitzer Dec 16, 2020
ce124c2
Refine `ConvParams::use_nnpack()` (#49464)
malfet Dec 16, 2020
0998854
T66557700 Support default argument values of a method (#48863)
frankseide Dec 16, 2020
c971a62
[PyTorch] Merge CoinflipTLS into RecordFunctionTLS (#49359)
swolchok Dec 16, 2020
4df68b3
[PyTorch] Avoid extra Tensor refcounting in _cat_out_cpu (#49364)
swolchok Dec 16, 2020
bff610b
[PyTorch] Use .sizes() instead of .size() in _cat_out_cpu (#49368)
swolchok Dec 16, 2020
51e4cc9
[PyTorch] Use .sizes() isntead of .size() in cat_serial_kernel_impl (…
swolchok Dec 16, 2020
e70d3f0
[PyTorch] Make tls_local_dispatch_key_set inlineable (reapply) (#49412)
swolchok Dec 16, 2020
fb4da16
BFloat16: add explicit dtype support for to_mkldnn and to_dense (#48881)
XiaobingSuper Dec 17, 2020
15bc45f
Introduce tools.codegen.api.translate (#49122)
ezyang Dec 17, 2020
c482c5d
Revert D25569586: stft: Change require_complex warning to an error
Dec 17, 2020
fb0a942
[NNC] Dont inline outputs buffers on cpu (#49488)
Dec 17, 2020
c694e7d
Prevent accidentally writing old style ops (#49510)
smessmer Dec 17, 2020
b39b6cb
.circleci: Only downgrade if we have conda (#49519)
seemethere Dec 17, 2020
3be7381
Fix bad error message when int overflow (#48250)
Kiyosora Dec 17, 2020
12c9616
Relax the atol/rtol of layernorm math kernel test. (#49507)
Dec 17, 2020
2aa0817
Fix CUDA extension ninja build (#49344)
zasdfgbnm Dec 17, 2020
dc052aa
[extensions] fix `is_ninja_available` during cuda extension building …
stas00 Dec 17, 2020
6362b78
[NNC] Add Support For is_nan (#48973)
Dec 17, 2020
a0d6342
[NNC] add support for masked_fill (#48974)
Dec 17, 2020
08fd21f
Add fusion support of aten::to (#48976)
Dec 17, 2020
5ac65cb
eager quant: remove fake_quant after add/mul nodes during QAT (#49213)
vkuzo Dec 17, 2020
6c5a43d
fx quant: move {input|output}_quantized_idxs cfg from convert to prep…
vkuzo Dec 17, 2020
f7a7355
fx quant: do not insert observers at quantized inputs (#49239)
vkuzo Dec 17, 2020
f604f1b
fx quant: fix fq when input is quantized and node does not need fq (#…
vkuzo Dec 17, 2020
b7a36d0
fx quant: make sure observer is inserted before a quantized output (#…
vkuzo Dec 17, 2020
1aa640b
add files to SLOW_TESTS for target determinator (#49500)
Dec 17, 2020
5aed6b3
[reland] Support torch.distributed.irecv(src=None, ...) (#49383)
pritamdamania Dec 17, 2020
46971a5
Set caffe2::pthreadpool() size in ParallelOpenMP (#45566)
dbalchev Dec 17, 2020
6a59ef2
Add torch._foreach_zero_ API (#47286)
Dec 17, 2020
3b1186d
Bring back math_silu_backward which works for all backends. (#49439)
Dec 17, 2020
99ba415
[quant][be] Add typing for quantization_mappings.py (#49179)
jerryzh168 Dec 17, 2020
5494a81
Add BFloat16 support for isinf and isfinite (#49356)
zasdfgbnm Dec 17, 2020
276e68e
Change aten::native_layer_norm signature to match torch.layer_norm de…
rdspring1 Dec 17, 2020
54636e1
Adding fix for invalid annotation types for dictionary (#49425)
nikithamalgifb Dec 17, 2020
c18bc82
[pt] fuse ClipRangesGatherSigridHash (#49181)
ajyu Dec 17, 2020
2e3adbd
Revert D25574962: [pytorch][PR] Updated derivative rules for complex …
Dec 17, 2020
0a2ba5d
Remove set_quantizer_ from native_functions.yaml (#49463)
smessmer Dec 17, 2020
87a4bc5
[C2] Revive unsafe CoalesceOp (#49402)
kennyhorror Dec 17, 2020
9df6183
[AutoAccept][Codemod][FBSourceClangFormatLinter] Daily `arc lint --ta…
Dec 17, 2020
728a912
PyLong_{As/From}{Long/UnsignedLong} lint checks (#49280)
peterjc123 Dec 17, 2020
b8c8d33
[reland][quant][docs] Add fx graph mode quantization to quantization …
jerryzh168 Dec 17, 2020
2853fa3
Refactor RPC matchBuiltInOp to get rid of exception swallowing (#49009)
rohan-varma Dec 17, 2020
0567619
Revert D25105217: [pytorch][PR] Fix bad error message when int overflow
ezyang Dec 17, 2020
83f6ad5
Set is_non_overlapping_and_dense_ flag in OpaqueTensorImpl constructo…
SS-JIA Dec 17, 2020
0deecfc
Test distributed collectives profiling with Gloo on GPU (#49072)
rohan-varma Dec 17, 2020
8d6bce8
Revert D25152559: T66557700 Support default argument values of a method
iseeyuan Dec 17, 2020
e8b6219
[te] Add fast log approximation based on sleef
bwasti Dec 17, 2020
14bb5d0
[quant][eagermode][fix] Fix quantization for DeQuantStub (#49428)
jerryzh168 Dec 17, 2020
cfd0951
.github: Add action workflow to update S3 HTMLS (#49509)
seemethere Dec 17, 2020
1c90741
[FileStore] Implemented numKeys and Added Tests (#49556)
osalpekar Dec 17, 2020
9611cf3
[FileStore] Updating Docs to Reflect FileStore changes (#49557)
osalpekar Dec 17, 2020
ddddf93
Revert D25445815: [te] Add fast log approximation based on sleef
ezyang Dec 17, 2020
0e10eb7
Add dict comprehension (#47774)
Dec 17, 2020
4e1b7d2
Revert D25547962: [PyTorch] Make tls_local_dispatch_key_set inlineabl…
Dec 18, 2020
7c49006
Revert D25546409: [PyTorch] Use .sizes() isntead of .size() in cat_se…
Dec 18, 2020
917cdeb
Revert D25545777: [PyTorch] Use .sizes() instead of .size() in _cat_o…
Dec 18, 2020
c04718b
Revert D25544731: [PyTorch] Avoid extra Tensor refcounting in _cat_ou…
Dec 18, 2020
0ec5fb3
Revert D25542799: [PyTorch] Merge CoinflipTLS into RecordFunctionTLS
Dec 18, 2020
309d517
[te][reapply] Add fast log approximation based on sleef (#49575)
bwasti Dec 18, 2020
723010e
[ddp launch] solve zombie problem (#49305)
stas00 Dec 18, 2020
4c9c61e
Add more list peephole idioms (#48268)
Dec 18, 2020
29e296d
disable concat nested namespace check (#49571)
Dec 18, 2020
f7ed11e
Add type inference for dequantization.tensors (#49517)
houseroad Dec 18, 2020
1f570a0
FLOPS Roofline Analysis Feature for PyTorch Profiler. (#46506)
xuzhao9 Dec 18, 2020
8c28731
Disables method variant grad and grad grad checks (#49576)
Dec 18, 2020
14c3255
Use store based barrier in init_process_group. (#49419)
pritamdamania Dec 18, 2020
ce7608a
Fix CustomAutogradTest.ReentrantPriority rerun failures (#49581)
malfet Dec 18, 2020
a90a450
Set USE_KINETO=1 (#49201)
Dec 18, 2020
0f3059d
Revert D25480770: Set USE_KINETO=1
Dec 18, 2020
4fc5d14
Support integral types for kAbs in SimpleIREvaluator (#49357)
Dec 18, 2020
f6d0b3c
Add op bench for caffe2 quantile op (#49598)
ShijunK Dec 18, 2020
25d77c8
add checkout PR tip step for quick checks (#49590)
Dec 18, 2020
41dbb0e
Refactor VmapPhysicalView::newLogicalToPhysical (#49482)
zou3519 Dec 18, 2020
2aa7bd0
fixed the first line of torch.rst to match the __init__.py file's fir…
jonykarki Dec 18, 2020
e2bc618
Fix Module backward hooks for all Tensor inputs/outputs (#46163)
albanD Dec 18, 2020
e253a31
Remove deadlines for Caffe2 hypothesis_test when running on GPU. (#49…
Dec 18, 2020
a7d4333
[FX] Enforce args is tuple and kwargs is dict (#49526)
Dec 18, 2020
efb851b
Renaming CAFFE2_API to TORCH_API (#49496)
janeyx99 Dec 18, 2020
2e88a18
[PyTorch Mobile] Export Operator List from Mobile CompilationUnit ins…
dhruvbird Dec 18, 2020
422e2d1
New profiler API (#48280)
Dec 18, 2020
d770127
Adding support for bitwise augassignment operators (#44621)
nikithamalgifb Dec 18, 2020
3f2a6c5
Test pipeline parallelism works with DDP. (#48470)
pritamdamania Dec 18, 2020
7e483a7
[FX] Emit named tuple construction node when NamedTuple appears as an…
Dec 18, 2020
e530504
[package] implicitly extern stdlib before mocking (#49306)
zdevito Dec 18, 2020
113ca4d
Upload test times to S3 (#49190)
samestep Dec 18, 2020
c4d42b4
Cleanup APIs for pipeline parallelism. (#48630)
pritamdamania Dec 18, 2020
8bef7b7
[torchscript] Fix constant propagation schemas (#49605)
IvanKobzarev Dec 18, 2020
0839efa
Add sinc operator (#48740)
soulitzer Dec 18, 2020
3086f7f
Output stacks (support for SVG visualization) (#48438)
Dec 19, 2020
c5e477a
`torch.reciprocal`: promote integer inputs to float (#49102)
soulitzer Dec 19, 2020
55296c4
[NNC] Disable masked fill (#49622)
Dec 19, 2020
65b8aa3
[Issue #46210] added torch.fx.len() to provide support for len(); add…
huiguoo Dec 19, 2020
c8968bf
Inline coverage report combining/reporting (#49615)
malfet Dec 19, 2020
a17f9e6
[Gradient Compression] Implement the original layerwise PowerSGD (#49…
Dec 19, 2020
b375c45
Improve documentation for pipeline parallelism. (#48638)
pritamdamania Dec 19, 2020
8c75384
Add benchmark for torch.distributed.pipeline.sync.Pipe (#49577)
pritamdamania Dec 19, 2020
4dd4d0c
Bump tensorpipe version (#49599)
Dec 19, 2020
32073ec
Fix lint (#49629)
mrshenli Dec 19, 2020
cd8ef1a
[quant][graphmode][fx] Allow user to specify qconfig for call_method …
jerryzh168 Dec 19, 2020
9b0a4c6
Revert D25511543: [Gradient Compression] Implement the original layer…
mrshenli Dec 19, 2020
8d2580f
[PyTorch Mobile] Preserve bundled input related methods when calling …
bearzx Dec 19, 2020
c6210c3
Disable test on windows (#49636)
Dec 19, 2020
20c0038
Remove DataPtr extractor from CUDAFuture (#48840)
lw Dec 19, 2020
0eecd3d
disable kthvalue overlap (#48254)
guol-fnst Dec 19, 2020
5ea5c01
Resubmit: [Gradient Compression] Implement the original layerwise Pow…
Dec 20, 2020
8e25d99
Updated derivative rules for complex svd and pinverse (#47761)
IvanYashchuk Dec 20, 2020
50e3afc
[Gradient Compression] Add error feedback to layerwise PowerSGD (#49418)
Dec 21, 2020
53750d2
[Gradient Compression] Replace the assertions in PowerSGD comm hook b…
Dec 21, 2020
9098cc7
Add support for torch.tensor_split to accept a tensor for `indices` a…
Dec 21, 2020
c7f5af6
[AutoAccept][Codemod][FBSourceClangFormatLinter] Daily `arc lint --ta…
Dec 21, 2020
b410315
[WIP][DataLoader] CollateIterableDataset prototype (#48933)
ejguan Dec 21, 2020
cf9ad1f
[WIP][DataLoader] Prototype of BatchIterableDataset (#49186)
ejguan Dec 21, 2020
dc3bbaa
[WIP][DataLoader] Prototype of SamplerIterableDataset (#49363)
ejguan Dec 21, 2020
b3355fd
[Mask R-CNN]Add Int8 AABB Generate proposals Op (#49574)
anshuljain1 Dec 21, 2020
6da4c09
Fix sinc docs typo (#49667)
soulitzer Dec 21, 2020
0edf70b
Added linalg.solve (#48456)
IvanYashchuk Dec 21, 2020
0c19f79
Fix return type Any for Ternary ops (#49165)
ejguan Dec 21, 2020
ffc1c0c
Fix typo in add_pr_curve docstrings. (#49648)
theodumont Dec 21, 2020
0b652f9
Fixed a typo in dataloader.py. (#49437)
tmcclintock Dec 21, 2020
0a6a102
[NNC] Intermediate allocs flattened and dependency support (#49554)
nickgg Dec 21, 2020
11d2494
Implementing NumPy-like function torch.broadcast_to (#48997)
RockingJavaBean Dec 21, 2020
38ff78f
Sparse-sparse matrix multiplication (CPU/CUDA) (#39526)
aocsa Dec 21, 2020
209bddb
[BE] Introduce `set_cwd` context manager (#49657)
malfet Dec 21, 2020
2e52d1d
add close() method to tqdm mock (#46040)
pmeier Dec 21, 2020
3dafed5
Dynamic GRU quantization support (#49448)
raghuramank100 Dec 21, 2020
6f66ee4
converted current debugging statements in LLVM codegen to jit-logging…
huiguoo Dec 21, 2020
a20a1f9
added macros in jit logging to check whether loggings are enabled; re…
huiguoo Dec 21, 2020
8cb4a36
change block codegen to handle new inlining in NNC (#47687)
Dec 21, 2020
2af5914
Clean up backward compatibility skip list (#49691)
houseroad Dec 21, 2020
83c91f9
Enable product for bool tensor (#48637)
Kiyosora Dec 21, 2020
56115b7
Fix test_cuda_init_race skip rules (#49693)
malfet Dec 21, 2020
97d64bc
Add base forward grad logic (#49097)
albanD Dec 21, 2020
b77390b
Do not use negative values in GCD computation. (#49379)
navahgar Dec 21, 2020
4ab6172
[jit][tracer] allow traced modules to return dicts with tuple values …
bradleyhd Dec 21, 2020
eb6a2ab
Move device guard from MultiTensorApply.cuh (#46664)
Dec 22, 2020
4164cb2
Use store based barrier only for certain store types. (#49694)
pritamdamania Dec 22, 2020
ccde23b
Fix TCPStore type coercion (#49685)
H-Huang Dec 22, 2020
1e9a97f
replacing THC_CLASS and THC_API with TORCH_CUDA_API (#49690)
janeyx99 Dec 22, 2020
220afd2
Revert D25607503: Add base forward grad logic
Dec 22, 2020
1b63e24
[TensorExpr] Change `LoopNest::vectorize` to accept `For*` instead of…
Dec 22, 2020
d1fac89
[TensorExpr] Move `SimpleIREval` implementation from .h to .cpp. (#49…
Dec 22, 2020
ca537cd
unbreak mypy torch/quantization (#49549)
vkuzo Dec 22, 2020
6d8e9d3
fx quant: types for fusion_patterns.py (#49606)
vkuzo Dec 22, 2020
1de10d5
fx quant: add types to observed_module.py (#49607)
vkuzo Dec 22, 2020
1869bd7
fx quant: fix types on _find_quants (#49616)
vkuzo Dec 22, 2020
c8aefec
[FX] Fix python code having spurious newlines from placeholders (#49720)
Dec 22, 2020
f41aa50
[pt][ATen] Optimize bmm (#49506)
Dec 22, 2020
db3f718
[PyTorch] Remove direct reference to native symbols in sparse related…
iseeyuan Dec 22, 2020
4cebcbd
[Gradient Compression] Warm-start of PowerSGD (#49451)
Dec 22, 2020
10b5558
NewModuleTest: Don't call both check_jacobian and gradcheck (#49566)
zou3519 Dec 22, 2020
25c852b
[fix] inplace remainder/% (#49390)
kshitij12345 Dec 22, 2020
bc28081
Complex backward for torch.sqrt (#49461)
anjali411 Dec 22, 2020
03214d5
[ROCm] add 4.0 to nightly builds (#49632)
jeffdaily Dec 22, 2020
a813673
Make PyTorch partially cross-compilable for Apple M1 (#49701)
malfet Dec 22, 2020
2d2a1f6
[onnxifi] Get rid of class member (#49380)
khabinov Dec 22, 2020
e241e1a
Reland: Add base forward grad logic (#49734)
albanD Dec 22, 2020
1c39e42
Fix get_overlap_status for tensors without storage (#49638)
asuhan Dec 22, 2020
4406379
Minor doc fix: change truncating to rounding in TF32 docs (#49625)
Dec 22, 2020
40e15e5
remove unused THCBlas (#49725)
ngimel Dec 22, 2020
5e176cb
only upload s3 stats on master, nightly, and release branch (#49645)
Dec 22, 2020
73985d9
Merge pull request #1 from pytorch/onnx_ms_1
hwangdeyu Dec 23, 2020
2bfe745
Merge branch 'onnx_ms_1' of github.com:hwangdeyu/pytorch into onnx_ms_1
Dec 23, 2020
525ac26
[ONNX] Support onnx if/loop sequence output in opset 13 - (#49270)
BowenBao Dec 17, 2020
9259b03
Symbolic function for torch.square (#49446)
jiafatom Dec 17, 2020
1baebbb
[ONNX] Add checks in ONNXSetDynamicInputShape (#49783)
jiafatom Jan 4, 2021
4898616
[ONNX] Enable export af aten::__derive_index (#49514)
neginraoof Jan 5, 2021
eef5191
[ONNX] Update symbolic for unfold (#49378)
KsenijaS Jan 5, 2021
97a8af1
[ONNX] Update the sequence of initializers in exported graph so that …
fatcat-z Jan 5, 2021
616da7c
[ONNX] Enable opset 13 ops (#49612)
neginraoof Jan 6, 2021
b3ae16c
Merge branch 'onnx_ms_1' of https://github.com/pytorch/pytorch into p…
Jan 6, 2021
2c69cf3
t push origin onnx_ms_1:Merge branch 'pytorch-onnx_ms_1' into onnx_ms_1
Jan 6, 2021
c92808f
Reland: Add base forward grad logic (#49734)
albanD Dec 22, 2020
d144220
add binary_cross_entropy_with_logits op to ONNX opset version 12
Dec 21, 2020
e6dd64a
fix format
Dec 21, 2020
0992510
add comprehensive tests
Dec 22, 2020
cdc08ce
fix comments:fix reduction message, delete duplicate test
Dec 23, 2020
0e09ee9
Merge remote-tracking branch 'origin1/onnx_ms_1' into deyu/bce_with_l…
Jan 14, 2021
d2ebe7e
Merge remote-tracking branch 'origin1/onnx_ms_1' into deyu/bce_with_l…
Jan 15, 2021
5275cc5
replace mustBeNone() to symblic_help fuction _is_none()
Jan 18, 2021
46 changes: 46 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -5158,6 +5158,52 @@ def forward(self, input, target):
        target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
        self.run_test(NLLModel(), (input, target))

    @skipIfUnsupportedMinOpsetVersion(12)
    def test_binary_cross_entropy_with_logits(self):
        x = torch.randn(5)
        y = torch.empty(5).random_(2)
        self._bce_logits_loss(x, y)

        x = torch.randn(2, 3, 5, 7)
        y = torch.empty(2, 3, 5, 7).random_(2)
        weight = torch.tensor([2])
        self._bce_logits_loss(x, y, weight)

        x = torch.FloatTensor([[-0.4089, -1.2471, 0.5907], [-0.4897, -0.8267, -0.7349], [0.5241, -0.1246, -0.4751]])
        y = torch.FloatTensor([[0, 1, 1], [0, 0, 1], [1, 0, 1]])
        pos_weight = torch.empty([3]).random_(2)
        self._bce_logits_loss(x, y, pos_weight)

        x = torch.randn(3, 3, 4)
        y = torch.empty(3, 3, 4).random_(2)
        weight = torch.tensor([3])
        pos_weight = torch.empty([3, 4]).random_(2)
        self._bce_logits_loss(x, y, weight, pos_weight)

    def _bce_logits_loss(self, x, y, weight=None, pos_weight=None):
        class BCEWithLogitsLossNoneWeights(torch.nn.Module):
            def forward(self, input, target, weight, pos_weight):
                return torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=weight,
                                                                            pos_weight=pos_weight, reduction='none')

        self.run_test(BCEWithLogitsLossNoneWeights(), input=(x, y, weight, pos_weight))

        class BCEWithLogitsLossMeanWeights(torch.nn.Module):
            def forward(self, input, target, weight, pos_weight):
                return torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=weight,
                                                                            pos_weight=pos_weight, reduction='mean')

        self.run_test(BCEWithLogitsLossMeanWeights(), input=(x, y, weight, pos_weight))

        class BCEWithLogitsLossSumWeights(torch.nn.Module):
            def forward(self, input, target, weight, pos_weight):
                return torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=weight,
                                                                            pos_weight=pos_weight, reduction='sum')

        self.run_test(BCEWithLogitsLossSumWeights(), input=(x, y, weight, pos_weight))

    def test_torch_mm(self):
        class M(torch.nn.Module):
            def forward(self, mat1, mat2):
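The tests above check the exported graph against PyTorch by running both through `self.run_test`. As a plain-Python sketch of the loss being exercised (function and variable names here are illustrative, not part of the PR), each element contributes `-w * (pw * y * log(s) + (1 - y) * log(1 - s))` with `s = sigmoid(x)`, followed by the chosen reduction:

```python
import math

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def bce_with_logits(xs, ys, weight=None, pos_weight=None, reduction="none"):
    # Per-element loss: -w * (pw * y * log(s) + (1 - y) * log(1 - s)), s = sigmoid(x)
    out = []
    for i, (x, y) in enumerate(zip(xs, ys)):
        s = sigmoid(x)
        pw = pos_weight[i] if pos_weight is not None else 1.0
        w = weight[i] if weight is not None else 1.0
        out.append(-w * (pw * y * math.log(s) + (1.0 - y) * math.log(1.0 - s)))
    # Same three reductions the tests cover: 'none', 'mean', 'sum'
    if reduction == "none":
        return out
    if reduction == "mean":
        return sum(out) / len(out)
    if reduction == "sum":
        return sum(out)
    raise ValueError("unsupported reduction: " + reduction)
```

For example, a logit of 0 with either target gives a loss of `log 2`, and `pos_weight` scales only the positive-target term.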
28 changes: 28 additions & 0 deletions torch/onnx/symbolic_opset12.py
@@ -52,6 +52,34 @@ def nll_loss2d(g, self, target, weight, reduction, ignore_index):
    return nll_loss(g, self, target, weight, reduction, ignore_index)


@parse_args('v', 'v', 'v', 'v', 'i')
def binary_cross_entropy_with_logits(g, input, target, weight, pos_weight, reduction):
    from torch.onnx.symbolic_opset9 import sigmoid, log, sub, neg, mul, add
    p = g.op("Constant", value_t=torch.tensor([1]))
    sig_x = sigmoid(g, input)
    log_sig_x = log(g, sig_x)
    sub_1_x = sub(g, p, sig_x)
    sub_1_y = sub(g, p, target)
    log_1_x = log(g, sub_1_x)
    if pos_weight is None or sym_help._is_none(pos_weight):
        output = neg(g, add(g, mul(g, target, log_sig_x), mul(g, sub_1_y, log_1_x)))
    else:
        output = neg(g, add(g, mul(g, mul(g, target, log_sig_x), pos_weight), mul(g, sub_1_y, log_1_x)))

    if weight is not None and not sym_help._is_none(weight):
        output = mul(g, weight, output)

    reduction = sym_help._maybe_get_const(reduction, 'i')
    if reduction == 0:
        return output
    elif reduction == 1:
        return g.op("ReduceMean", output)
    elif reduction == 2:
        return g.op("ReduceSum", output)
    else:
        return sym_help._onnx_unsupported("binary_cross_entropy_with_logits with reduction other than none, mean, or sum")


def celu(g, self, alpha):
    alpha = sym_help._maybe_get_const(alpha, 'f')
    # if the input is of type double cast it to float
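The symbolic builds the loss from elementwise ONNX ops rather than a single fused node: `-(pos_weight * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x)))`, optionally scaled by `weight`, then reduced. That decomposition can be checked numerically outside the graph against the algebraically equivalent log1p-based form; this is a minimal sketch with illustrative helper names, not code from the PR:

```python
import math

def loss_decomposed(x, y, pw=1.0):
    # Mirrors the ops emitted by the symbolic:
    # -(pw * y * log(sig_x) + (1 - y) * log(1 - sig_x))
    s = 1.0 / (1.0 + math.exp(-x))
    return -(pw * y * math.log(s) + (1.0 - y) * math.log(1.0 - s))

def loss_closed_form(x, y, pw=1.0):
    # Equivalent form obtained by substituting log(sigmoid(x)) = -log1p(exp(-x))
    # and log(1 - sigmoid(x)) = -x - log1p(exp(-x)):
    return (1.0 - y) * x + (pw * y + 1.0 - y) * math.log1p(math.exp(-x))
```

The two agree to floating-point precision over logits of either sign, both targets, and nontrivial `pos_weight`, which is the identity the exported graph relies on.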