Skip to content

Commit

Permalink
Update base for Update on "[pytorch][PR] Add AVX512 support in ATen &…
Browse files Browse the repository at this point in the history
… remove AVX support"

### Remaining Tasks

- [ ] Collate results of benchmarks on two Intel Xeon machines (with & without CUDA, to check if CPU throttling causes issues with GPUs) - make graphs, including Roofline model plots (Intel Advisor can't make them with libgomp, though, but with Intel OpenMP).

### Summary

1. This draft PR produces binaries with with 3 types of ATen kernels - default, AVX2, AVX512 . Using the environment variable `ATEN_AVX512_256=TRUE`  also results in 3 types of kernels, but the compiler can use 32 ymm registers for AVX2, instead of the default 16. ATen kernels for `CPU_CAPABILITY_AVX` have been removed.

2. `nansum` is not using AVX512 kernel right now, as it has poorer accuracy for Float16, than does AVX2 or DEFAULT, whose respective accuracies aren't very good either (#59415).
It was more convenient to disable AVX512 dispatch for all dtypes of `nansum` for now.

3. On Windows , ATen Quantized AVX512 kernels are not being used, as quantization tests are flaky. If `--continue-through-failure` is used, then `test_compare_model_outputs_functional_static` fails. But if this test is skipped, `test_compare_model_outputs_conv_static` fails. If both these tests are skipped, then a third one fails. These are hard to debug right now due to not having access to a Windows machine with AVX512 support, so it was more convenient to disable AVX512 dispatch of all ATen Quantized kernels on Windows for now.

4. One test is currently being skipped -
[test_lstm` in `quantization.bc](#59098) - It fails only on Cascade Lake machines, irrespective of the `ATEN_CPU_CAPABILITY` used, because FBGEMM uses `AVX512_VNNI` on machines that support it. The value of `reduce_range` should be used as `False` on such machines.

The list of the changes is at https://gist.github.com/imaginary-person/4b4fda660534f0493bf9573d511a878d.


Credits to @ezyang for proposing `AVX512_256` - these use AVX2 intrinsics but benefit from 32 registers, instead of the 16 ymm registers that AVX2 uses.
Credits to @limo1996 for the initial proposal, and for optimizing `hsub_pd` & `hadd_pd`, which didn't have direct AVX512 equivalents, and are being used in some kernels. He also refactored `vec/functional.h` to remove duplicated code.
Credits to @quickwritereader for helping fix 4 failing complex multiplication & division tests.

### Testing
1. `vec_test_all_types` was modified to test basic AVX512 support, as tests already existed for AVX2.
Only one test had to be modified, as it was hardcoded for AVX2.
2.  `pytorch_linux_bionic_py3_8_gcc9_coverage_test1` & `pytorch_linux_bionic_py3_8_gcc9_coverage_test2` are now using `linux.2xlarge` instances, as they support AVX512. They were used for testing AVX512 kernels, as AVX512 kernels are being used by default in both of the CI checks. Windows CI checks had already been using machines with AVX512 support.

### Would the downclocking caused by AVX512 pose an issue?

I think it's important to note that AVX2 causes downclocking as well, and the additional downclocking caused by AVX512 may not hamper performance on some Skylake machines & beyond, because of the double vector-size. I think that [this post with verifiable references is a must-read](https://community.intel.com/t5/Software-Tuning-Performance/Unexpected-power-vs-cores-profile-for-MKL-kernels-on-modern-Xeon/m-p/1133869/highlight/true#M6450). Also, AVX512 would _probably not_ hurt performance on a high-end machine, [but measurements are recommended](https://lemire.me/blog/2018/09/07/avx-512-when-and-how-to-use-these-new-instructions/). In case it does, `ATEN_AVX512_256=TRUE` can be used for building PyTorch, as AVX2 can then use 32 ymm registers instead of the default 16. [FBGEMM uses `AVX512_256` only on Xeon D processors](pytorch/FBGEMM#209), which are said to have poor AVX512 performance.

This [official data](https://www.intel.com/content/dam/www/public/us/en/documents/specification-updates/xeon-scalable-spec-update.pdf) is for the Intel Skylake family, and the first link helps understand its significance. Cascade Lake & Ice Lake SP Xeon processors are said to be even better when it comes to AVX512 performance.

Here is the corresponding data for [Cascade Lake](https://cdrdv2.intel.com/v1/dl/getContent/338848) -

![CASCADE LAKE AVX2](https://user-images.githubusercontent.com/76181208/120666172-ffec3f80-c451-11eb-8ea1-8933ccc12a1b.PNG)
![CASCADE LAKE AVX512](https://user-images.githubusercontent.com/76181208/120666190-04b0f380-c452-11eb-9faa-38d233c874c8.PNG)

The corresponding data isn't publicly available for Intel Xeon SP 3rd gen (Ice Lake SP), but [Intel mentioned that the 3rd gen has frequency improvements pertaining to AVX512](https://newsroom.intel.com/wp-content/uploads/sites/11/2021/04/3rd-Gen-Intel-Xeon-Scalable-Platform-Press-Presentation-281884.pdf). Ice Lake SP machines also have 48 KB L1D caches, so that's another reason for AVX512 performance to be better on them.


### Is PyTorch always faster with AVX512?

No, but then PyTorch is not always faster with AVX2 either. Please refer to #60202. The benefit from vectorization is apparent with with small tensors that fit in caches or in kernels that are more compute heavy. For instance, AVX512 or AVX2 would yield no benefit for adding two 64 MB tensors, but adding two 1 MB tensors would do well with AVX2, and even more so with AVX512.

It seems that memory-bound computations, such as adding two 64 MB tensors can be slow with vectorization (depending upon the number of threads used), as the effects of downclocking can then be observed.

Original pull request: #56992

Differential Revision: [D29266289](https://our.internmc.facebook.com/intern/diff/D29266289/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D29266289/)!

[ghstack-poisoned]
  • Loading branch information
ezyang committed Jul 20, 2021
2 parents 641f6ef + 59a5312 commit 6040712
Show file tree
Hide file tree
Showing 23 changed files with 323 additions and 503 deletions.
6 changes: 4 additions & 2 deletions aten/src/ATen/native/BatchLinearAlgebra.cpp
Expand Up @@ -1272,7 +1272,8 @@ Tensor cholesky(const Tensor &self, bool upper) {
"and\n"
"U = torch.cholesky(A, upper=True)\n",
"should be replaced with\n",
"U = torch.linalg.cholesky(A.transpose(-2, -1).conj()).transpose(-2, -1).conj()"
"U = torch.linalg.cholesky(A).transpose(-2, -1).conj().\n"
"This transform will produce equivalent results for all valid (symmetric positive definite) inputs."
);
if (self.numel() == 0) {
return at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
Expand Down Expand Up @@ -1310,7 +1311,8 @@ Tensor& cholesky_out(const Tensor &self, bool upper, Tensor &result) {
"and\n"
"U = torch.cholesky(A, upper=True)\n",
"should be replaced with\n",
"U = torch.linalg.cholesky(A.transpose(-2, -1).conj()).transpose(-2, -1).conj()"
"U = torch.linalg.cholesky(A).transpose(-2, -1).conj().\n"
"This transform will produce equivalent results for all valid (symmetric positive definite) inputs."
);
checkSameDevice("cholesky", result, self);
checkLinalgCompatibleDtype("cholesky", result, self);
Expand Down
78 changes: 21 additions & 57 deletions aten/src/ATen/native/BinaryOps.cpp
Expand Up @@ -127,6 +127,18 @@ TORCH_META_FUNC2(bitwise_right_shift, Tensor) (
build_borrowing_binary_op(maybe_get_output(), self, other);
}

TORCH_META_FUNC2(bitwise_and, Tensor) (const Tensor& self, const Tensor& other) {
build_borrowing_binary_op(maybe_get_output(), self, other);
}

TORCH_META_FUNC2(bitwise_or, Tensor) (const Tensor& self, const Tensor& other) {
build_borrowing_binary_op(maybe_get_output(), self, other);
}

TORCH_META_FUNC2(bitwise_xor, Tensor) (const Tensor& self, const Tensor& other) {
build_borrowing_binary_op(maybe_get_output(), self, other);
}

TORCH_META_FUNC2(fmod, Tensor) (const Tensor& self, const Tensor& other) {
build_borrowing_binary_op(maybe_get_output(), self, other);
}
Expand Down Expand Up @@ -366,6 +378,9 @@ TORCH_IMPL_FUNC(func_out) (const Tensor& self, const Tensor& other, const Tensor
func_stub(device_type(), *this); \
}

CREATE_BINARY_TORCH_IMPL_FUNC(bitwise_and_out, bitwise_and_stub);
CREATE_BINARY_TORCH_IMPL_FUNC(bitwise_or_out, bitwise_or_stub);
CREATE_BINARY_TORCH_IMPL_FUNC(bitwise_xor_out, bitwise_xor_stub);
CREATE_BINARY_TORCH_IMPL_FUNC(maximum_out, maximum_stub);
CREATE_BINARY_TORCH_IMPL_FUNC(minimum_out, minimum_stub);
CREATE_BINARY_TORCH_IMPL_FUNC(fmax_out, fmax_stub);
Expand Down Expand Up @@ -711,33 +726,16 @@ Tensor rsub(const Tensor& self, const Scalar& other, const Scalar& alpha) {
return native::rsub(self, wrapped_scalar_tensor(other), alpha);
}

Tensor& bitwise_and_out(const Tensor& self, const Tensor& other, Tensor& result) {
auto iter = TensorIterator::binary_op(result, self, other);
bitwise_and_stub(iter.device_type(), iter);
return result;
}

Tensor bitwise_and(const Tensor& self, const Tensor& other) {
Tensor result = at::empty({0}, self.options());
at::bitwise_and_out(result, self, other);
return result;
}

Tensor& bitwise_and_(Tensor& self, const Tensor& other) {
return at::bitwise_and_out(self, self, other);
}

Tensor& bitwise_and_out(const Tensor& self, const Scalar& other, Tensor& result) {
return at::bitwise_and_out(result, self, wrapped_scalar_tensor(other));
}

Tensor bitwise_and(const Tensor& self, const Scalar& other) {
Tensor result = at::empty({0}, self.options());
return at::bitwise_and_out(result, self, other);
return at::bitwise_and(self, wrapped_scalar_tensor(other));
}

Tensor& bitwise_and_(Tensor& self, const Scalar& other) {
return at::bitwise_and_out(self, self, other);
return self.bitwise_and_(wrapped_scalar_tensor(other));
}

// Legacy and interfaces. They are aliased to bitwise_and* functions
Expand All @@ -757,33 +755,16 @@ Tensor& __iand__(Tensor& self, const Scalar& other) {
return self.bitwise_and_(other);
}

Tensor& bitwise_or_out(const Tensor& self, const Tensor& other, Tensor& result) {
auto iter = TensorIterator::binary_op(result, self, other);
bitwise_or_stub(iter.device_type(), iter);
return result;
}

Tensor bitwise_or(const Tensor& self, const Tensor& other) {
Tensor result = at::empty({0}, self.options());
at::bitwise_or_out(result, self, other);
return result;
}

Tensor& bitwise_or_(Tensor& self, const Tensor& other) {
return at::bitwise_or_out(self, self, other);
}

Tensor& bitwise_or_out(const Tensor& self, const Scalar& other, Tensor& result) {
return at::bitwise_or_out(result, self, wrapped_scalar_tensor(other));
}

Tensor bitwise_or(const Tensor& self, const Scalar& other) {
Tensor result = at::empty({0}, self.options());
return at::bitwise_or_out(result, self, other);
return at::bitwise_or(self, wrapped_scalar_tensor(other));
}

Tensor& bitwise_or_(Tensor& self, const Scalar& other) {
return at::bitwise_or_out(self, self, other);
return self.bitwise_or_(wrapped_scalar_tensor(other));
}

// Legacy or interfaces. They are aliased to bitwise_or* functions
Expand All @@ -803,33 +784,16 @@ Tensor& __ior__(Tensor& self, const Scalar& other) {
return self.bitwise_or_(other);
}

Tensor& bitwise_xor_out(const Tensor& self, const Tensor& other, Tensor& result) {
auto iter = TensorIterator::binary_op(result, self, other);
bitwise_xor_stub(iter.device_type(), iter);
return result;
}

Tensor bitwise_xor(const Tensor& self, const Tensor& other) {
Tensor result = at::empty({0}, self.options());
at::bitwise_xor_out(result, self, other);
return result;
}

Tensor& bitwise_xor_(Tensor& self, const Tensor& other) {
return at::bitwise_xor_out(self, self, other);
}

Tensor& bitwise_xor_out(const Tensor& self, const Scalar& other, Tensor& result) {
return at::bitwise_xor_out(result, self, wrapped_scalar_tensor(other));
}

Tensor bitwise_xor(const Tensor& self, const Scalar& other) {
Tensor result = at::empty({0}, self.options());
return at::bitwise_xor_out(result, self, other);
return at::bitwise_xor(self, wrapped_scalar_tensor(other));
}

Tensor& bitwise_xor_(Tensor& self, const Scalar& other) {
return at::bitwise_xor_out(self, self, other);
return self.bitwise_xor_(wrapped_scalar_tensor(other));
}

// Legacy xor interfaces. They are aliased to bitwise_xor* functions
Expand Down
6 changes: 3 additions & 3 deletions aten/src/ATen/native/BinaryOps.h
Expand Up @@ -54,9 +54,9 @@ DECLARE_DISPATCH(structured_binary_fn, div_floor_stub);
DECLARE_DISPATCH(structured_binary_fn, div_trunc_stub);
DECLARE_DISPATCH(structured_binary_fn, atan2_stub);
DECLARE_DISPATCH(structured_binary_fn, remainder_stub);
DECLARE_DISPATCH(binary_fn, bitwise_and_stub);
DECLARE_DISPATCH(binary_fn, bitwise_or_stub);
DECLARE_DISPATCH(binary_fn, bitwise_xor_stub);
DECLARE_DISPATCH(structured_binary_fn, bitwise_and_stub);
DECLARE_DISPATCH(structured_binary_fn, bitwise_or_stub);
DECLARE_DISPATCH(structured_binary_fn, bitwise_xor_stub);
DECLARE_DISPATCH(structured_binary_fn, lshift_stub);
DECLARE_DISPATCH(structured_binary_fn, rshift_stub);
DECLARE_DISPATCH(binary_fn, logical_xor_stub);
Expand Down
6 changes: 3 additions & 3 deletions aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
Expand Up @@ -240,7 +240,7 @@ void remainder_kernel(TensorIteratorBase& iter) {
}
}

void bitwise_and_kernel(TensorIterator& iter) {
void bitwise_and_kernel(TensorIteratorBase& iter) {
if (iter.dtype() == ScalarType::Bool) {
cpu_kernel(
iter,
Expand All @@ -261,7 +261,7 @@ void bitwise_and_kernel(TensorIterator& iter) {
}
}

void bitwise_or_kernel(TensorIterator& iter) {
void bitwise_or_kernel(TensorIteratorBase& iter) {
if (iter.dtype() == ScalarType::Bool) {
cpu_kernel(
iter,
Expand All @@ -282,7 +282,7 @@ void bitwise_or_kernel(TensorIterator& iter) {
}
}

void bitwise_xor_kernel(TensorIterator& iter) {
void bitwise_xor_kernel(TensorIteratorBase& iter) {
if (iter.dtype() == ScalarType::Bool) {
// Boolean type does not work with ^ (bitwise XOR) in C++. bitwise_xor wraps this operation for both Boolean and
// integral types.
Expand Down
6 changes: 3 additions & 3 deletions aten/src/ATen/native/cuda/BinaryBitwiseOpsKernels.cu
Expand Up @@ -23,7 +23,7 @@ struct BitwiseAndFunctor<bool> {
}
};

void bitwise_and_kernel_cuda(TensorIterator& iter) {
void bitwise_and_kernel_cuda(TensorIteratorBase& iter) {
AT_DISPATCH_INTEGRAL_TYPES_AND(kBool, iter.dtype(), "bitwise_and_cuda", [&]() {
BitwiseAndFunctor<scalar_t> f;
gpu_kernel_with_scalars(iter, f);
Expand All @@ -44,7 +44,7 @@ struct BitwiseOrFunctor<bool> {
}
};

void bitwise_or_kernel_cuda(TensorIterator& iter) {
void bitwise_or_kernel_cuda(TensorIteratorBase& iter) {
AT_DISPATCH_INTEGRAL_TYPES_AND(kBool, iter.dtype(), "bitwise_or_cuda", [&]() {
BitwiseOrFunctor<scalar_t> f;
gpu_kernel_with_scalars(iter, f);
Expand All @@ -65,7 +65,7 @@ struct BitwiseXorFunctor<bool> {
}
};

void bitwise_xor_kernel_cuda(TensorIterator& iter) {
void bitwise_xor_kernel_cuda(TensorIteratorBase& iter) {
AT_DISPATCH_INTEGRAL_TYPES_AND(kBool, iter.dtype(), "bitwise_xor_cuda", [&]() {
BitwiseXorFunctor<scalar_t> f;
gpu_kernel_with_scalars(iter, f);
Expand Down
22 changes: 16 additions & 6 deletions aten/src/ATen/native/native_functions.yaml
Expand Up @@ -5692,6 +5692,8 @@

- func: bitwise_and.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
structured: True
structured_inherits: TensorIteratorBase
variants: function
dispatch:
CPU, CUDA: bitwise_and_out
Expand All @@ -5700,19 +5702,18 @@
device_check: NoCheck # TensorIterator
variants: function
dispatch:
CPU, CUDA: bitwise_and_out
CompositeExplicitAutograd: bitwise_and_out

- func: bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor
device_check: NoCheck # TensorIterator
variants: method, function
dispatch:
CPU, CUDA: bitwise_and
CompositeExplicitAutograd: bitwise_and

- func: bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor
device_check: NoCheck # TensorIterator
variants: method, function
dispatch:
CPU, CUDA: bitwise_and
structured_delegate: bitwise_and.Tensor_out

- func: bitwise_and_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
device_check: NoCheck # TensorIterator
Expand All @@ -5721,6 +5722,7 @@
- func: bitwise_and_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
device_check: NoCheck # TensorIterator
variants: method
structured_delegate: bitwise_and.Tensor_out

- func: __and__.Scalar(Tensor self, Scalar other) -> Tensor
device_check: NoCheck # TensorIterator
Expand All @@ -5740,6 +5742,8 @@

- func: bitwise_or.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
structured: True
structured_inherits: TensorIteratorBase
variants: function
dispatch:
CPU, CUDA: bitwise_or_out
Expand All @@ -5748,7 +5752,7 @@
device_check: NoCheck # TensorIterator
variants: function
dispatch:
CPU, CUDA: bitwise_or_out
CompositeExplicitAutograd: bitwise_or_out

- func: bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor
device_check: NoCheck # TensorIterator
Expand All @@ -5757,6 +5761,7 @@
- func: bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor
device_check: NoCheck # TensorIterator
variants: method, function
structured_delegate: bitwise_or.Tensor_out

- func: bitwise_or_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
device_check: NoCheck # TensorIterator
Expand All @@ -5765,6 +5770,7 @@
- func: bitwise_or_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
device_check: NoCheck # TensorIterator
variants: method
structured_delegate: bitwise_or.Tensor_out

- func: __or__.Scalar(Tensor self, Scalar other) -> Tensor
device_check: NoCheck # TensorIterator
Expand All @@ -5784,6 +5790,8 @@

- func: bitwise_xor.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
structured: True
structured_inherits: TensorIteratorBase
variants: function
dispatch:
CPU, CUDA: bitwise_xor_out
Expand All @@ -5792,7 +5800,7 @@
device_check: NoCheck # TensorIterator
variants: function
dispatch:
CPU, CUDA: bitwise_xor_out
CompositeExplicitAutograd: bitwise_xor_out

- func: bitwise_xor.Scalar(Tensor self, Scalar other) -> Tensor
device_check: NoCheck # TensorIterator
Expand All @@ -5801,6 +5809,7 @@
- func: bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor
device_check: NoCheck # TensorIterator
variants: method, function
structured_delegate: bitwise_xor.Tensor_out

- func: bitwise_xor_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
device_check: NoCheck # TensorIterator
Expand All @@ -5809,6 +5818,7 @@
- func: bitwise_xor_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
device_check: NoCheck # TensorIterator
variants: method
structured_delegate: bitwise_xor.Tensor_out

- func: __xor__.Scalar(Tensor self, Scalar other) -> Tensor
device_check: NoCheck # TensorIterator
Expand Down
4 changes: 3 additions & 1 deletion torch/_torch_docs.py
Expand Up @@ -2082,7 +2082,9 @@ def merge_dicts(*dicts):
.. code:: python
U = torch.linalg.cholesky(A.transpose(-2, -1).conj()).transpose(-2, -1).conj()
U = torch.linalg.cholesky(A).transpose(-2, -1).conj()
This transform will produce equivalent results for all valid (symmetric positive definite) inputs.
Args:
input (Tensor): the input tensor :math:`A` of size :math:`(*, n, n)` where `*` is zero or more
Expand Down
7 changes: 3 additions & 4 deletions torch/csrc/api/src/nn/modules/rnn.cpp
Expand Up @@ -127,10 +127,9 @@ void RNNImplBase<Derived>::reset() {
layer_params.emplace_back(w_hr);
param_names.emplace_back("weight_hr_l{layer}{suffix}");
}
for(const auto i : c10::irange(param_names.size())) { // NOLINT(modernize-loop-convert)
std::string x = std::regex_replace(param_names[i], std::regex("\\{layer\\}"), c10::str(layer));
x = std::regex_replace(x, std::regex("\\{suffix\\}"), c10::str(suffix));
param_names[i] = x;
for(auto& param_name : param_names) {
std::string x = std::regex_replace(param_name, std::regex("\\{layer\\}"), c10::str(layer));
param_name = std::regex_replace(x, std::regex("\\{suffix\\}"), c10::str(suffix));
}

for(const auto i : c10::irange(param_names.size())) {
Expand Down
6 changes: 6 additions & 0 deletions torch/csrc/autograd/engine.cpp
Expand Up @@ -577,6 +577,12 @@ void GraphTask::exec_post_processing() {
// 2. The callback's results can safely be used on (user-facing) caller_current_streams
// after backward().
c10::MultiStreamGuard g(caller_current_streams_filtered);

// Set the ThreadLocalState before calling the function.
// NB: The ThreadLocalStateGuard doesn't set the grad_mode because GraphTask
// always saves ThreadLocalState without grad_mode.
at::ThreadLocalStateGuard tls_guard(this->thread_locals_);

// WARNING: Don't use a range-for loop here because more callbacks may be
// added in between callback calls, so iterators may become invalidated.
// NOLINTNEXTLINE(modernize-loop-convert)
Expand Down
3 changes: 1 addition & 2 deletions torch/csrc/cuda/utils.cpp
@@ -1,6 +1,5 @@
#include <torch/csrc/python_headers.h>
// NOLINTNEXTLINE(modernize-deprecated-headers)
#include <stdarg.h>
#include <cstdarg>
#include <string>
#include <torch/csrc/cuda/THCP.h>

Expand Down
3 changes: 1 addition & 2 deletions torch/csrc/deploy/deploy.h
@@ -1,10 +1,9 @@
#pragma once
// NOLINTNEXTLINE(modernize-deprecated-headers)
#include <assert.h>
#include <c10/util/irange.h>
#include <torch/csrc/api/include/torch/imethod.h>
#include <torch/csrc/deploy/interpreter/interpreter_impl.h>
#include <torch/csrc/jit/serialization/import.h>
#include <cassert>
#include <fstream>
#include <iostream>
#include <string>
Expand Down

0 comments on commit 6040712

Please sign in to comment.