[ONNX] Update documentation (#58712)
* Add introductory paragraph explaining what ONNX is and what the
  torch.onnx module does.
* In "Tracing vs Scripting" and doc-string for torch.onnx.export(),
  clarify that exporting always happens on ScriptModules and that
  tracing and scripting are the two ways to produce a ScriptModule.
* Remove examples of using Caffe2 to run exported models.
  Caffe2's website says it's deprecated, so it's probably best not to
  encourage people to use it by including it in examples.
* Remove a lot of content that's redundant:
  * The example of how to mix tracing and scripting, and instead
    link to Introduction to TorchScript, which includes very similar
    content.
  * "Type annotations" section. Link to TorchScript docs which explain
    that in more detail.
  * "Using dictionaries to handle Named Arguments as model inputs"
    section. It's redundant with the description of the `args` argument
    to `export()`, which appears on the same page once the HTML
    is generated.
  * Remove the list of supported Tensor indexing patterns. If it's not
    in the list of unsupported patterns, users can assume it's
    supported, so having both is redundant.
* Remove the list of supported operators and models.
  I think the list of supported operators is not very useful.
  A list of supported model architectures may be useful, but in
  reality it's already very out of date. We should add it back
  if/when we have a system for keeping it up to date.
  * "Operator Export Type" section. It's redundant with the description
    of the `operator_export_type` arg to `export()`, which appears on
    the same page once the HTML is generated.
  * "Use external data format" section. It's redundant with the
    description of the `use_external_data_format` arg to `export()`.
  * "Training" section. It's redundant with the
    description of the `training` arg to `export()`.
* Move the content about different operator implementations producing
  different results from the "Limitations" section into the doc for the
  `operator_export_type` arg.
* Document "quantized" -> "caffe2" behavior of
  OperatorExportTypes.ONNX_ATEN_FALLBACK.
* Combine the text about using torch.Tensor.item() and the text about
  using NumPy types into a section titled
  "Avoid NumPy and built-in Python types", since they're both
  fundamentally about the same issue.
* Rename "Write PyTorch model in Torch way" to "Avoiding Pitfalls".
* Lots of minor fixes: spelling, grammar, brevity, fixing links, adding
  links.
* Clarify limitation on input and output types. Phrasing it in terms of
  PyTorch types is much more accessible than in terms of TorchScript
  types. Also clarify what actually happens when dict and str are used
  as inputs and outputs.
* In "Supported operators", use torch function and class names and link
  to them. This is more user-friendly than using the internal aten
  op names.
* Remove references to VariableType.h, which doesn't appear to contain
  the information that it once did. Instead refer to the generated
  .pyi files.
* Remove the text in the FAQ about appending to lists within loops.
  I think this limitation is no longer present
  (perhaps since #51577).
* Minor fixes to some code I read along the way.
* Explain the current rationale for the weird ::prim_PythonOp op name.
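The tracing-vs-scripting clarification above can be sketched as follows. `M` is a hypothetical module used only for illustration; the point is that both paths yield a ScriptModule, which is what `torch.onnx.export()` ultimately operates on:

```python
import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1

# Tracing runs the model once with example inputs and records the ops.
traced = torch.jit.trace(M(), torch.randn(2, 3))

# Scripting compiles the model's Python source directly.
scripted = torch.jit.script(M())

# Either result is a ScriptModule, and either can be passed to
# torch.onnx.export().
assert isinstance(traced, torch.jit.ScriptModule)
assert isinstance(scripted, torch.jit.ScriptModule)
```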

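The "Avoid NumPy and built-in Python types" point can be illustrated with a small sketch (hypothetical functions, not taken from the docs): calling `torch.Tensor.item()` produces a plain Python number, so under tracing the value is frozen into the graph as a constant instead of being recomputed per input.

```python
import torch

def scale_bad(x):
    # .item() yields a Python float; tracing bakes it in as a constant.
    return x * x.sum().item()

def scale_good(x):
    # Keeping everything as tensors lets the traced graph stay dynamic.
    return x * x.sum()

example = torch.tensor([1.0, 2.0])
traced_bad = torch.jit.trace(scale_bad, example)   # sum() == 3.0 frozen in
traced_good = torch.jit.trace(scale_good, example)

other = torch.tensor([2.0, 2.0])
print(traced_bad(other))   # still multiplies by the stale constant 3.0
print(traced_good(other))  # recomputes sum() == 4.0 for the new input
```

The same hazard applies to NumPy scalars and other built-in Python numbers mixed into the model, which is why the two discussions were merged into one section.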
Co-authored-by: Gary Miguel <garymiguel@microsoft.com>

ghstack-source-id: 13223efb59e88e505b2d5fbb534b1e234d547d95
Pull Request resolved: #60249
BowenBao committed Jun 25, 2021
1 parent a6d475e commit 639039a
Showing 4 changed files with 558 additions and 1,218 deletions.
28 changes: 15 additions & 13 deletions caffe2/contrib/aten/docs/sample.py
@@ -1,3 +1,5 @@
+import tempfile
+
 import numpy as np

 from torch import nn
@@ -10,7 +12,8 @@
 class MyFunction(Function):
     @staticmethod
     def forward(ctx, x, y):
-        return x*x + y
+        return x * x + y
+
     @staticmethod
     def symbolic(graph, x, y):
         x2 = graph.at("mul", x, x)
@@ -26,21 +29,20 @@ def forward(self, x, y):
     x = nn.ReLU()(x)
     return MyFunction.apply(x, y)

+f = tempfile.NamedTemporaryFile()
 torch.onnx.export(MyModule(),
-                  (Variable(torch.ones(3,4)), Variable(torch.ones(3,4))),
-                  "output.onnx",
-                  verbose=True)
+                  (Variable(torch.ones(3, 4)), Variable(torch.ones(3, 4))),
+                  f, verbose=True)

 # prints the graph for debugging:
-# graph(%1 : Float(3, 4)
-#       %2 : Float(3, 4)) {
-#   %3 : Float(3, 4) = Relu(%1), uses = [%4.i0, %4.i1];
-#   %4 : UNKNOWN_TYPE = ATen[operator=mul](%3, %3), uses = [%5.i0];
-#   %5 : Float(3, 4) = ATen[operator=add](%4, %2), uses = [%0.i0];
-#   return (%5);
-# }
+# graph(%input : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
+#       %y : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
+#   %2 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = onnx::Relu(%input)
+#   %3 : Tensor = onnx::ATen[operator="mul"](%2, %2)
+#   %4 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = onnx::ATen[operator="add"](%3, %y)
+#   return (%4)

-graph = onnx.load("output.onnx")
+graph = onnx.load(f.name)

 a = np.random.randn(3, 4).astype(np.float32)
 b = np.random.randn(3, 4).astype(np.float32)
@@ -50,5 +52,5 @@ def forward(self, x, y):
 c2_out = prepared_backend.run(W)[0]

 x = np.maximum(a, 0)
-r = x*x + b
+r = x * x + b
 np.testing.assert_array_almost_equal(r, c2_out)
