
Conversation

@ezyang (Contributor) commented Aug 29, 2017

This is a big PR. I don't recommend looking at the commits individually, although you may find it useful to blame lines here and there. (Naming disclaimer: the project uses "toffee" in some places, but this will need to be renamed.)

A lot of folks contributed to this PR. Shout outs to @zdevito, @apaszke, @prigoyal, @killeent, @houseroad; as well as indirect contributions from @Yangqing, @bddppq and @jerryzh168.

User-facing interface. At the moment, most functionality is made available from the torch.jit and torch.toffee modules. Some refactoring of these interfaces is ongoing: torch.jit will be made into a private namespace, and torch.toffee will grow end-to-end conversion functions.

Key new pieces of code.

  • The compiler IR. There is a new, mutable, SSA, in-memory compiler IR representation which represents a computation in PyTorch. We maintain accurate def-use chains. The IR is defined in torch/csrc/jit/ir.h. Some unusual things about the IR:
    • Deep learning networks generally have multiple inputs AND outputs. However, computation nodes in our IR only have a single output; instead, we project out the particular output we are interested in using a Select node. The "Select invariant" states that any multiple return node (that's most nodes) always has exactly as many Select nodes as outputs. There are no first-class tuples in the language.
    • Computations in the graph are associated with stages, specifying whether or not they occurred in forwards/backwards/backwards-backwards. A computation in some stage cannot depend on a computation from a later stage; in fact, earlier stage computations will always be earlier in the IR than later stage computations.
    • Nodes in the graph are dynamically typed; while there are some special nodes (CppOp and PythonOp), most nodes are represented as strings with arbitrary dictionaries of attribute names to attribute values. Names are represented as interned strings in torch/csrc/jit/interned_strings.h; the dynamic attribute dictionaries are implemented in torch/csrc/jit/attributes.h.
    • There is a linter to enforce invariants. Run the linter early and often!
    • There are Python bindings using pybind11 for the IR in torch/csrc/jit/python_ir.h. These bindings are used in the definition of exporters for autograd functions.
  • The forward tracer. The forward tracer instruments invocations to autograd Functions in the forward pass and records the execution into the compiler IR.
    • We implemented this tracing logic twice: once for C++ functions and once for Python functions. The key functions are tracedApply in torch/csrc/autograd/function.cpp (C++ side) and trace_create in torch/csrc/autograd/python_function.cpp (Python side). Functions traced through these paths are recorded as CppOps and PythonOps, respectively.
    • There is a calling convention change for C++ autograd functions; instead of writing Transpose(0, 1).apply({ggI})[0] you write apply_fn<Transpose>(0, 1)(ggI).
    • The tracer can only trace what code was actually executed; e.g., it only sees the conditional branch that was taken.
    • Variables are marked as participating in a trace or not. A variable cannot participate in multiple traces simultaneously. You MUST dispose of a trace object (GC-wise) before you can use it again in another trace.
  • The Toffee exporter. The Toffee exporter takes a raw trace from the forward tracer (consisting of PythonOps and CppOps) and converts it to a Toffee format IR, which can then be exported to protobufs. The code lives in torch/csrc/toffee/export.cpp. Feel free to ignore the logic inside the passes_state_transparently conditional; this is for handling backwards traces.
    • Transformations are implemented by invoking, for each original op, the primspec associated with it (a new method on autograd Functions). A primspec takes as input the Toffee graph being constructed and the inputs to the node (in the new Toffee IR graph), and extends the graph with the Toffee version of the op; a hedged sketch of such a hook appears right after this list. Most primspecs are written in Python, but C++ ops like BatchNorm and Convolution have theirs in C++.
    • After this transformation, we export the graph to protobufs using the nanopb library. We've checked in the nanopb-protoc generated files for Toffee IR in torch/csrc/toffee.pb.h and .cpp; nanopb has a pretty user-unfriendly interface, so there is a small wrapper around it in torch/csrc/toffee.h which simulates the Google protobuf C++ interface. See 619e20dee57e9642a8357bdb5f23e51f8b88d05e for how it works.
    • The manually checked-in autogenerated files can be updated with gen_toffee.sh. If there are major changes to the protobuf, you may also have to update the C++ wrapper code.
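
To make the primspec mechanism concrete, here is a hedged sketch of such a hook, modeled on the PRelu primspec that appears in this PR's diff; the class shown is purely illustrative, and g stands for the Toffee graph object the exporter passes in:

  # Illustrative only: a primspec receives the Toffee graph under construction
  # plus the already-converted inputs of the node, and appends the Toffee
  # counterpart of the op to that graph.
  from torch.autograd import Function

  class MyPRelu(Function):
      @staticmethod
      def primspec(g, input, weight):
          # g.create builds a Toffee node; g.appendNode adds it to the graph.
          return g.appendNode(g.create("PRelu", [input, weight]))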

Experimental code. There are some experimental pieces which are not exercised by the model exporter but which we'd like to merge in. I don't recommend spending too much time on these pieces of code; they are at high risk of being rewritten or heavily refactored, and they shouldn't affect anything in PyTorch proper.

  • Optimization passes. We have optimization passes for:
    • An "init" pass (soon to be removed; torch/csrc/jit/init_pass.cpp)
    • Dead code elimination (torch/csrc/jit/dead_code_elimination.cpp)
    • Operator fusion (torch/csrc/jit/graph_fusion.cpp). Once operators are fused into fusion groups, the fusion compiler (torch/csrc/jit/fusion_compiler.cpp) handles actually compiling this to CUDA. This has only been lightly tested and is expected to be broken at the moment.
  • The trace executor. The trace executor takes a compiler IR trace and converts it back into an autograd closure, which we can then execute using the autograd engine (see the sketch after this list). The main functionality for this is in torch/csrc/autograd/functions/jit_closure.cpp.
  • The backwards tracer. The backwards tracer is responsible for intercepting invocations of backward functions which were traced, and recording their backward traces to the trace (so forwards and backwards can be co-optimized). A good chunk of the functionality here is implemented in torch/csrc/autograd/functions/special.cpp.
    • One significant source of complexity in the implementation is handling autograd functions whose backward passes are not traceable. In this case, we must fall back on the old PyTorch style (namely, generate an autograd trace when we compute forwards, and run that trace in backwards).
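
To give a feel for the trace-and-reexecute flow the trace executor enables, here is a hedged sketch; the _tracer_enter/_tracer_exit and run_forward entry points are taken from an intermediate commit message further down in this thread and may not match the final code exactly:

  import torch
  import torch._C
  from torch.autograd import Variable

  x = Variable(torch.Tensor([4]), requires_grad=True)
  y = Variable(torch.Tensor([7]), requires_grad=True)

  # Record a trace of the forward computation into the compiler IR...
  torch._C._tracer_enter((x, y))
  z = x * y
  trace = torch._C._tracer_exit((z,))

  # ...then hand the trace back to the executor, which turns it into an
  # autograd closure and runs it again on the same inputs.
  (z2,) = Variable._execution_engine.run_forward(trace, (x, y))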

Miscellaneous changes. (In the order GitHub renders the files.)

  • New dependencies: pybind11 and nanopb. pybind11 is a pretty nice library for wrapping C++ for Python that can handle many conversions automatically. We intend to keep migrating PyTorch's Python wrapping code over to pybind11 where applicable. nanopb is a very small protobuf library; we use it to generate protobufs for the exporter. This spares users the pain of getting protobufs to work.
  • Submodules. The PyTorch repo now has submodules: gloo (pre-existing), pybind11 and nanopb. The choice between subtreeing and submoduling a project is primarily a question of how likely it is that we will need to apply local patches to it inside PyTorch, versus being able to go through the upstream cycle. There is also a technical reason why gloo was submoduled; see "gloo subtree has submodule, which causes submodule init to fail" (#2426).
  • Expect tests in the test suite. A new test method, assertExpected, lets you assert that a string matches a file which we have saved to disk. If the format changes, you can simply pass --accept to accept the new output. This makes it easy to see before and after in diffs (a sketch of an expect test follows this list).
  • A new test suite for the JIT. While the JIT is not a public-facing component in this diff, it is a key part of the export pipeline, and so tests for it are in test_jit.py.
  • There are C++ ports of chunk and split added to torch/csrc/autograd/functions/tensor.cpp
  • torch/csrc/autograd/engine.cpp Autograd is now reentrant; you can call backward from within a backwards function. The implementation is the diff to engine.cpp; our strategy is to call back into the scheduling loop, so the worker thread keeps servicing requests even when we have kicked off a subgraph task.
  • torch/csrc/autograd/engine.cpp Autograd callbacks have been generalized to support both pre and post callbacks. A pre callback is called before a function is invoked and can be used to rewrite inputs; a post callback is called after the function is invoked, once its outputs are available.
  • New TemplateEnv class in torch/csrc/jit/code_template.h which can be used to format C++ code fragments. It supports key-based template variables and can handle indentation correctly.
  • New utility auto_unique_ptr, which constructs its object on first dereference (rather than eagerly).
  • New utility ResourceGuard, an RAII class which can be explicitly released within its lexical scope, or automatically released when the scope ends.
  • New utility fmap, which applies a function to all elements of a vector (functionally, so no mutation).
  • New parameter to state_dict(), keep_vars, which causes state_dict to return parameters as Variables rather than tensors; this is used internally by the tracer (a tiny example follows this list).
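
As an illustration of the expect-test workflow, here is a hedged sketch; assertExpected and the --accept flag are the pieces added by this PR, while the test class and the string being checked are made up:

  from common import TestCase, run_tests  # helpers in test/common.py

  class TestIRDump(TestCase):
      def test_simple_graph(self):
          ir_text = "%2 = Mul %0, %1\nRet %2"
          # Compared against an .expect file checked into the repo; rerun the
          # test with --accept to regenerate the file after a format change.
          self.assertExpected(ir_text)

  if __name__ == '__main__':
      run_tests()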
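
And a tiny example of the new keep_vars flag (the module here is arbitrary):

  import torch
  import torch.nn as nn

  model = nn.Linear(3, 2)
  plain = model.state_dict()                    # values are plain tensors (the default)
  with_vars = model.state_dict(keep_vars=True)  # values stay Variables, as the tracer needs
  print(type(plain['weight']), type(with_vars['weight']))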

@ezyang (Contributor, Author) commented Aug 29, 2017

Jenkins build is broken because it doesn't understand submodules, I'm on it.

@ezyang (Contributor, Author) commented Aug 31, 2017

Latest batch of commits from 'jit' branch are cdecfae..d618864.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Also, add a new trace_fn field to attach forward IR to Variables.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Simple test:

  import torch
  from torch.autograd import Variable
  import torch._C as _C

  x = Variable(torch.Tensor([4]), requires_grad=True)
  y = Variable(torch.Tensor([7]), requires_grad=True)
  z = x * y
  z.sum().backward()

  print(x.grad)
  print(y.grad)

  x.data[0] = 2
  y.data[0] = 3

  (z,) = z._execution_engine.run_forward((x, y), (z,))
  z.sum().backward()

  print(x.grad)
  print(y.grad)

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Previously, our AST was a DAG, where shared Nodes indicated a computation
should be reused.  This commit rewrites the IR into a new functional
representation which represents sharing explicitly using variable
bindings.

We offer a few justifications for this new style:

1. The new representation is not all that different from the
old one; it is about as easy to construct, and the lack of an
explicit graph doesn't negatively impact our ability to interpret
the graph, since we've chosen, as a matter of design, to NOT have
the IR participate in the actual execution of a graph.

2. The new let-binding representation has an implicit ordering,
which we can use to conveniently keep track of the original order
the trace showed up as.  This automatically gives us a topsort,
and gives us an easier to read textual representation of our
IR:

  %14 = Embedding %11, %0, -1, None, 2, False, False
  %15 = Dropout %14, 0.2, True, False
  %16 = Index %12, 0
  %17 = Index %12, 1
  %18 = Index %13, 0
  %19 = Index %13, 1
  %20 = Index %15, 0
  %21 = Linear %20, %1, %3
  %22 = Linear %16, %2, %4

3. It moves us closer to a Futhark style language
(http://futhark-lang.org/publications/pldi17.pdf).

Major aspects of the diff

- Node is replaced with Expr and Arg, a pair of mutually recursive
  structures which represent our new language.  In BNF, the language
  looks like this:

    a ::= c | %i
    e ::= %i, ... = e
        | PyOp e, ...
        | Ret %i, ...

  Technically, Ret is not actually a return (no control flow is involved),
  it just tuples up a series of tensors (identified by variables).

  One important invariant is that locals are always tensors; they
  are never constants (this is asymmetric with Args.)

- Arguments support Python constants.  This is an important piece because
  many operators take extra Python literals like integers and tuples in
  order to specify extra parameters about how an operator operates.  Adding
  this was essential to getting word_language_model to work.

- As both Expr and Arg have multiple variants, there is new infrastructure
  for doing case on the variants using ExprVisitor and ArgVisitor.  The
  strategy here is adapted from WebAssembly's visitors, although we have
  generalized to permit arbitrary argument forwarding, which is necessary
  to support tail-recursive visitor calls.  TCO is important because our
  interpreter may recurse arbitrarily deep into a stack of nested lets.
  If users wish, they can also manually case on the type tag.

- Tracing is now turned on and off using _tracer_enter/_tracer_exit in
  torch._C.  _tracer_enter accepts a list of variables which are to be
  treated as arguments; _tracer_exit accepts the list of traced variables
  which should be returned when you reexecute the trace, and returns
  the trace expression which can be reexecuted.  GlobalTracingState
  is a global variable which tracks whether or not we are tracing.

- You use run_forward to execute a trace on some set of parameters.

- When tracing, variables keep track, via trace_local, of the names of
  their corresponding variables in the IR.

Here is a simple runner which leaks memory but can be used to JIT models:

  import types

  import torch.autograd.function as F
  import torch._C
  from torch.autograd import Variable

  def jit(model):
      real_forward = model.forward
      def forward(self, *args):
          def flatten(x):
              return tuple(F._iter_variables(x))
          if not hasattr(self, "saved_trace"):
              # First call: run the real forward under the tracer and save the trace.
              torch._C._tracer_enter(tuple(self.parameters()) + flatten(args))
              out = real_forward(*args)
              self.saved_trace = torch._C._tracer_exit(flatten(out))
              self.saved_outs = out
              return out
          else:
              # Later calls: re-execute the saved trace on the new inputs.
              flat_out = Variable._execution_engine.run_forward(self.saved_trace, tuple(self.parameters()) + flatten(args))
              return F._unflatten(flat_out, self.saved_outs)
      # Bind the wrapper as the model's forward method.
      model.forward = types.MethodType(forward, model)
      return model

Major problems:

- Sanity checking is spotty at best, especially when users pass in variables.

- The interpreter leaks tensor memory from the store.  When we add back def-use
  we should be able to deallocate tensors as soon as we know they are no longer
  necessary.

- The interpreter needs to reach feature parity with the old execution engine.
  From there, we need to see if backwards can be subsumed as well.

- I still have no confidence that memory is managed correctly everywhere.
  This requires a close look.

- Rather than return an *open* expression as a trace, we should return a
  *lambda* instead, which knows about how many formal parameters it
  requires.

- The IR is not introspectable from Python at the moment, but this is simply a
  matter of implementing all the binding code.

- The tracer is NOT reentrant (you can't trace while you're inside a trace.)
  Furthermore, no sanity checking is done if you try to incorrectly reuse
  things from one trace in another.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Although ANF-style developments traditionally stratify syntactic
classes into atomic (Arg) and complex (Expr) expressions, where
atomic expressions could be variables, constants or lambdas, Zach has
successfully convinced me that we should do away with the variant here and
always require arguments to be variables.  There are a few reasons for
this:

1) Tensor constants, not currently supported, could be modeled using a
"Constant" instruction, removing the need for them to be representable
directly inline.  An inline constant is marginally more convenient
for peephole optimizations, but since we have gone full ANF, we are going
to need to be able to see across def-uses in any case, and it is not
too much worse to need to handle constants this way.  By the way,
Swift Intermediate Language also made a similar choice, see
the slide on "Literal Instructions" in
http://llvm.org/devmtg/2015-10/slides/GroffLattner-SILHighLevelIR.pdf

2) Scalar constants, which are quite important for passing non-tensor
arguments to Python operators, are now stored out-of-band as NON
first-class values.  This more closely matches the ToffeeIR design,
and makes it clear what parameters are "first class" (tensors only)
and which ones are not.  However, we need to be able to unswizzle
the separate scalar/tensor lists into a unified list in the correct
format; this is what PyFunctionCConv is for.

Also, Locals got renamed into Tuple.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
This prevents nested lets, which are not allowed in ANF.  We
basically have SSA now.

There's some niftiness with the visitor returning a lambda which
then gets fed the actual argument. I like it.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
It is not an /expression/ we trace, but it is a /graph/: that is,
a closed expression which knows its parameters.  Knowing the list
of parameters is useful and removes a hack when interpreting.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
ezyang and others added 16 commits September 5, 2017 07:21
- BC BREAKING: export now also takes a mandatory file-ish argument, specifying
  the file to export the protobuf to.  I rewrote the tests to use BytesIO to
  get the string out so they could parse it again (see the sketch after this
  list).

- BC BREAKING: export no longer returns the tensors that were computed.  To
  get these, use the internal _export function.

- Multiple inputs to models are now supported by passing a tuple to input.
  (Old API of a single Variable still works.)

- Keyword arguments to models are now supported via kwargs keyword arg.

- Renamed embed_params to export_params, and it now defaults to True.

- Toffee tests now live in their own test_toffee.py file.  I had to
  rename a pile of expect files for this.

- Removed defunct torch.toffee imports from autograd to solve module import
  cycle.

- Helper function _with_file_like to abstract over opening file-ish arguments,
  taken from torch.save()
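
A hedged sketch of the reworked export call described above; torch.toffee.export as the entry point and the exact argument order are assumptions based on this PR's description, and the model and input are placeholders:

  import io
  import torch
  import torch.nn as nn
  import torch.toffee  # assumed entry point for the exporter
  from torch.autograd import Variable

  model = nn.Linear(3, 2)
  x = Variable(torch.randn(1, 3))

  buf = io.BytesIO()                            # any file-ish object works
  torch.toffee.export(model, x, buf, export_params=True)
  proto_bytes = buf.getvalue()                  # the serialized Toffee protobuf

  # Multiple model inputs can be passed as a tuple, and keyword arguments to
  # the model's forward go through the kwargs keyword argument.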

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
This adds the PyTorch API user documentation for Toffee.
To make the example work, I also converted all "inplace"
ops to export out-of-place in Toffee.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
This was a case of two wrongs making a right.  There were a pair of
related bugs:

- We incorrectly translated Transpose as if it were a Permute,
  but Torch's transpose is actually a *swap* between two dimensions
  (a small illustration follows below).

- Why didn't we ever notice it?  In all of our tests, a transpose
  was *solely* done to get a weight matrix into the correct form.
  But Caffe2's FC operator *implicitly* does a transpose on
  the weight matrix.

This commit fixes both of these problems.
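
To make the distinction concrete, here is a tiny illustration (not code from this PR) of how transpose differs from a full permutation:

  import torch

  x = torch.randn(2, 3, 4)
  print(x.transpose(0, 1).size())   # swaps dims 0 and 1      -> (3, 2, 4)
  print(x.permute(2, 0, 1).size())  # reorders all dimensions -> (4, 2, 3)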

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
uses suffixes to disambiguate attribute types
This was a doozy!

- 'namespace' is a C++ reserved keyword, so if you have a field named
  this, nanopb will blithely export some malformed C++.  I submitted
  a PR for this: https://github.com/ProjectToffee/ToffeeIR/pull/88

- Zach added support for singular tensor and graph.  While attempting
  to add support for these, I realized that it was actually impossible
  to support them under the default protobuf translation.  The gory
  details are in Note [Callback for nested messages].  The singular
  callbacks needed a new helper which I dubbed msg; it's just
  the singular version of list.

- While I was working on the API, I braino'd with the tensor()
  method.  It turns out this is totally not the right way to think
  about it; it's more string_from_tensor().  So I renamed it.
  I also renamed add_tensor to set_raw_data; add_tensor is a misnomer
  since it implies you can add multiple tensors, which is not true.

- version turned into producer_version.  Actually, this is a bit
  questionable and might change soon.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
- Conv no longer supports bias, so we create an explicit broadcasted
  addition afterwards.  There is one minor problem, however, which is that
  ConvTranspose in Caffe2 has mandatory bias.  So there's a hack.
  See Note [Caffe2ConvTranspose] for the details.
- Squeeze: dims -> axes
- Transpose: axes -> perm
- Reshape lost its extra output (yay!)

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
- kernels -> kernel_shape
- Use the new hybrid dict/tuple result object from Toffee
- Write g and t as singulars, not plural
- nanopb generated files update
- Bugfix for msg() micropb helper
- Start recording producer_version/producer_tag
- Use ir_version from proto description
- Value -> value (Constant)
- Remove special-casing for transposed convolution; we now rely
  on the Caffe2 Toffee backend to do something reasonable
- Batchnorm order is no more

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Signed-off-by: Edward Z. Yang <ezyang@fb.com>


Signed-off-by: Edward Z. Yang <ezyang@fb.com>
@ezyang (Contributor, Author) commented Sep 5, 2017

@soumith When CI turns green this is ready to merge.

@soumith merged commit 4fc54af into pytorch:master Sep 5, 2017
houseroad added a commit to houseroad/pytorch that referenced this pull request Feb 7, 2020
…0e6a5e (pytorch#33075)

Summary:
Pull Request resolved: pytorch#33075

Previous import was 65020daafa9183c769938b4512ce543fd5740f8f

Included changes:
- **[8b3f7e2e](onnx/onnx@8b3f7e2e)**: Update Dropout and  BatchNorm to be Training Friendly (pytorch#2568) <Lara Haidar>
- **[61f0bbc5](onnx/onnx@61f0bbc5)**: Fix a bug in ScatterND shape inference (pytorch#2577) <Bowen Bao>
- **[05bce9cf](onnx/onnx@05bce9cf)**: add utility function to make reference attribute whose name is not the same as the attribute it refers. (pytorch#2583) <Ke Zhang>
- **[71181c83](onnx/onnx@71181c83)**: Clarify spec for constant of shape with dim_n = 0 (pytorch#2567) <Negin Raoof>
- **[eadba733](onnx/onnx@eadba733)**: Update sigs.md with link to calendar page (pytorch#2579) <Prasanth Pulavarthi>
- **[08562f8e](onnx/onnx@08562f8e)**: Update working-groups.md (pytorch#2580) <Prasanth Pulavarthi>
- **[0e718913](onnx/onnx@0e718913)**: Fix Slice op's shape inference logic (pytorch#2526) <Hariharan Seshadri>
- **[12111410](onnx/onnx@12111410)**: Add missing spaces to Random*Like doc (pytorch#2572) <Takeshi Watanabe>
- **[7e6e61d6](onnx/onnx@7e6e61d6)**: Contributing: fix typos (pytorch#2571) <Maher Jendoubi>
- **[bbd604ef](onnx/onnx@bbd604ef)**: Add Einsum op (pytorch#2504) <Negin Raoof>
- **[fd3ab73a](onnx/onnx@fd3ab73a)**: Clarify split supports zero length splits (pytorch#2544) <Negin Raoof>
- **[6dd73774](onnx/onnx@6dd73774)**: Fix circleci build and drop unsupported Windows builds (pytorch#2565) <Wei-Sheng Chin>
- **[b3d201a2](onnx/onnx@b3d201a2)**: Fix the formula of intermediate zero calculation for DynamicQuantizeLinear (pytorch#2556) <Yufeng Li>
- **[3613eb25](onnx/onnx@3613eb25)**: Add wording to clarify. (pytorch#2555) <Dwayne Robinson>
- **[dfa4384c](onnx/onnx@dfa4384c)**: Fix shape inference for Split with split attribute (pytorch#2328) <Shinichiro Hamaji>
- **[684fc1bc](onnx/onnx@684fc1bc)**: Keep symbolic dims in Concat with a single input (pytorch#2418) <Shinichiro Hamaji>

Test Plan: ci

Reviewed By: hl475

Differential Revision: D19784487

fbshipit-source-id: 3f445d88a014e9d07d1522b9457dd467c100abec
samnordmann pushed a commit to samnordmann/pytorch that referenced this pull request Mar 13, 2023
Fixes pytorch#2564

Co-authored-by: Jacob Hinkle <jhinkle@nvidia.com>