Update on "[py][vulkan][reland] Add is_vulkan to py api, add vulkan t…
Browse files Browse the repository at this point in the history
…o device type parsing"


Summary:

Test Plan: Imported from OSS

Pulled By: IvanKobzarev

Differential Revision: [D24448984](https://our.internmc.facebook.com/intern/diff/D24448984)

Reland of the PR: #46511

The initial PR broke tests that assert on the error message it changed:

torch/testing/_internal/distributed/nn/api/remote_module_test.py

This PR updates those assertions accordingly.

[ghstack-poisoned]
IvanKobzarev committed Oct 22, 2020
2 parents 3b26ac6 + 13decdd commit 98ba6b5
Showing 41 changed files with 860 additions and 200 deletions.
2 changes: 1 addition & 1 deletion .circleci/scripts/windows_cudnn_install.sh
@@ -6,7 +6,7 @@ if [[ "$CUDA_VERSION" == "10" ]]; then
cudnn_installer_name="cudnn-10.1-windows10-x64-v7.6.4.38"
elif [[ "$CUDA_VERSION" == "11" ]]; then
cuda_complete_version="11.0"
cudnn_installer_name="cudnn-11.0-windows-x64-v8.0.2.39"
cudnn_installer_name="cudnn-11.0-windows-x64-v8.0.4.30"
else
echo "CUDNN for CUDA_VERSION $CUDA_VERSION is not supported yet"
exit 1
6 changes: 4 additions & 2 deletions .clang-tidy
@@ -1,6 +1,7 @@
---
# NOTE there must be no spaces before the '-', so put the comma last.
Checks: '-*,
InheritParentConfig: true
Checks: '
bugprone-*,
-bugprone-forward-declaration-namespace,
-bugprone-macro-parentheses,
@@ -17,6 +18,7 @@ cppcoreguidelines-*,
-cppcoreguidelines-pro-type-union-access,
-cppcoreguidelines-pro-type-vararg,
-cppcoreguidelines-special-member-functions,
-facebook-hte-RelativeInclude,
hicpp-exception-baseclass,
hicpp-avoid-goto,
modernize-*,
@@ -27,7 +29,7 @@ modernize-*,
-modernize-use-trailing-return-type,
performance-*,
-performance-noexcept-move-constructor,
'
'
HeaderFilterRegex: 'torch/csrc/.*'
AnalyzeTemporaryDtors: false
CheckOptions:
5 changes: 0 additions & 5 deletions aten/src/ATen/core/NamedRegistrations.cpp
@@ -113,10 +113,6 @@ TORCH_LIBRARY_IMPL(aten, Named, m) {
m.impl("conj.out", CppFunction::makeFallthrough());
m.impl("contiguous", CppFunction::makeFallthrough());
m.impl("copy_", CppFunction::makeFallthrough());
m.impl("copy_imag", CppFunction::makeFallthrough());
m.impl("copy_imag.out", CppFunction::makeFallthrough());
m.impl("copy_real", CppFunction::makeFallthrough());
m.impl("copy_real.out", CppFunction::makeFallthrough());
m.impl("cos", CppFunction::makeFallthrough());
m.impl("cos.out", CppFunction::makeFallthrough());
m.impl("cos_", CppFunction::makeFallthrough());
@@ -506,6 +502,5 @@ TORCH_LIBRARY_IMPL(aten, Named, m) {
m.impl("is_leaf", CppFunction::makeFallthrough());
m.impl("_version", CppFunction::makeFallthrough());
m.impl("requires_grad_", CppFunction::makeFallthrough());
m.impl("requires_grad", CppFunction::makeFallthrough());
m.impl("retain_grad", CppFunction::makeFallthrough());
}
12 changes: 12 additions & 0 deletions aten/src/ATen/core/dispatch/Dispatcher.cpp
@@ -322,4 +322,16 @@ void Dispatcher::setManuallyBoxedKernelFor_(const OperatorHandle& op, KernelFunc
// NB: Do not need to set manually boxed kernel for backend fallbacks
}

std::vector<OperatorHandle> Dispatcher::findDanglingImpls() const {
return operatorLookupTable_.read([&] (const ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) -> std::vector<OperatorHandle> {
std::vector<OperatorHandle> opsWithDanglingImpls;
for (const auto& op : operatorLookupTable) {
if (!op.second.hasSchema()) {
opsWithDanglingImpls.push_back(op.second);
}
}
return opsWithDanglingImpls;
});
}

}
22 changes: 22 additions & 0 deletions aten/src/ATen/core/dispatch/Dispatcher.h
@@ -211,6 +211,28 @@ class CAFFE2_API Dispatcher final {
return dispatch_key != DispatchKey::BackendSelect;
}

//
// ------------------------------------------------------------------------
//
// Assertions
//
// ------------------------------------------------------------------------

/**
* For testing purposes.
* Returns a list of all operators that were created through calls to registerImpl(),
* without any corresponding calls to registerDef(). After static initialization
* is done this is almost certainly a bug, as the created OperatorHandle won't have
* any schema associated with it and users calling the op through the dispatcher
* won't be able to access it.
*
* Note that we cannot enforce this invariant "as we go" during static initialization,
* due to undefined static initialization order; we have no guarantees over the order
* in which .def() and .impl() calls are registered in the dispatcher at static
* initialization time. So this function should only be called after static initialization.
*/
std::vector<OperatorHandle> findDanglingImpls() const;

private:
Dispatcher();

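For readers unfamiliar with the term: a "dangling impl" is what results when a kernel is registered via TORCH_LIBRARY_IMPL (or an m.impl() call) without any TORCH_LIBRARY block defining the operator's schema. Below is a minimal sketch of the situation findDanglingImpls is meant to surface; the myops namespace and my_op operator are hypothetical names used only for illustration.

#include <torch/library.h>
#include <ATen/core/dispatch/Dispatcher.h>

// Hypothetical CPU kernel for an operator that never gets a schema.
at::Tensor my_op_cpu(const at::Tensor& self) {
  return self.clone();
}

// There is deliberately no TORCH_LIBRARY(myops, m) { m.def("my_op(Tensor self) -> Tensor"); }
// anywhere, so this impl is "dangling": the dispatcher has a kernel but no schema for it.
TORCH_LIBRARY_IMPL(myops, CPU, m) {
  m.impl("my_op", my_op_cpu);
}

// After static initialization has finished, the new API reports it:
//   auto dangling = c10::Dispatcher::singleton().findDanglingImpls();
//   // dangling now contains an OperatorHandle for myops::my_op with hasSchema() == false.

The test added to op_registration_test.cpp in the next file does exactly this kind of check against the real registry.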
16 changes: 16 additions & 0 deletions aten/src/ATen/core/op_registration/op_registration_test.cpp
@@ -2251,6 +2251,22 @@ TEST(NewOperatorRegistrationTest, testDelayedListener) {
EXPECT_EQ(initial_num_deregisters + 1, listener_ptr->num_deregisters_);
}

TEST(NewOperatorRegistrationTest, testImplNoDefGetsCaught) {
auto danglingImpls = Dispatcher::singleton().findDanglingImpls();
std::string error_str = "Discovered operators that have been registered through the dispatcher"
" without explicitly specifying their schemas. Please do so using"
" the TORCH_LIBRARY macro. Suspect operators:\n";
for (auto& op : danglingImpls) {
auto& op_name = op.operator_name();
error_str += "\t" + op_name.name;
if (op_name.overload_name != "") {
error_str += "." + op_name.overload_name;
}
error_str += "\n";
}
ASSERT_EQ(danglingImpls.size(), 0) << error_str;
}

}

#pragma GCC diagnostic pop
6 changes: 3 additions & 3 deletions aten/src/ATen/cpu/vec256/vec256_float_neon.h
@@ -259,12 +259,12 @@ template <> class Vec256<float> {
// Only required because vec256_qint refers to this.
// Once we specialize that implementation for ARM
// this should be removed. TODO (kimishpatel)
const float operator[](int idx) const {
float operator[](int idx) const {
__at_align32__ float tmp[size()];
store(tmp);
return tmp[idx];
};
const float operator[](int idx) {
}
float operator[](int idx) {
__at_align32__ float tmp[size()];
store(tmp);
return tmp[idx];
8 changes: 8 additions & 0 deletions aten/src/ATen/native/cuda/EmbeddingBag.cu
@@ -448,6 +448,13 @@ Tensor _embedding_bag_per_sample_weights_backward_cuda(
dim3 grid((num_samples + warps_per_block - 1) / warps_per_block);

auto output = at::empty({num_samples}, grad.options());

// Early return when there are no samples in the batch. This saves an unnecessary kernel
// launch, and also prevents cudaGetLastError() from complaining about invalid launch args
if (num_samples == 0) {
return output;
}

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad.scalar_type(), "_embedding_bag_per_sample_weights_backward_cuda", [&]() {
_embedding_bag_per_sample_weights_backward_kernel<scalar_t>
@@ -459,6 +466,7 @@ Tensor _embedding_bag_per_sample_weights_backward_cuda(
num_samples,
embedding_features,
output.data_ptr<scalar_t>());
AT_CUDA_CHECK(cudaGetLastError());
}
);
return output;
8 changes: 8 additions & 0 deletions aten/src/ATen/native/quantized/library.cpp
@@ -82,13 +82,16 @@ TORCH_LIBRARY(quantized, m) {
m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv3d_unpack(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> (Tensor unpacked_weights, Tensor? B_origin)"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv2d_stride(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int[]"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv2d_padding(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int[]"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv2d_output_padding(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int[]"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv2d_dilation(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int[]"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv2d_groups(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv2d_transpose(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv3d_stride(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int[]"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv3d_padding(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int[]"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv3d_output_padding(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int[]"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv3d_dilation(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int[]"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv3d_groups(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv3d_transpose(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int"));
// conv_transpose
m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose1d(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose2d(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor"));
@@ -98,6 +101,7 @@ TORCH_LIBRARY(quantized, m) {
m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose2d_unpack(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> (Tensor unpacked_weights, Tensor? B_origin)"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose2d_stride(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int[]"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose2d_padding(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int[]"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose2d_output_padding(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int[]"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose2d_dilation(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int[]"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose2d_groups(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv_transpose2d_transpose(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> int"));
@@ -170,6 +174,10 @@ TORCH_LIBRARY(_quantized, m) {
m.def(TORCH_SELECTIVE_SCHEMA("_quantized::conv2d(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("_quantized::conv2d_relu(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("_quantized::conv2d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv2dPackedParamsBase"));
m.def(TORCH_SELECTIVE_SCHEMA("_quantized::conv_transpose1d(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("_quantized::conv_transpose2d(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("_quantized::conv_transpose1d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv2dPackedParamsBase"));
m.def(TORCH_SELECTIVE_SCHEMA("_quantized::conv_transpose2d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv2dPackedParamsBase"));
m.def(TORCH_SELECTIVE_SCHEMA("_quantized::linear(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i) -> Tensor Y"));
m.def(TORCH_SELECTIVE_SCHEMA("_quantized::linear_dynamic(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, bool reduce_range=False) -> Tensor Y"));
m.def(TORCH_SELECTIVE_SCHEMA("_quantized::linear_prepack(Tensor W, Tensor? B=None) -> __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack"));
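The new entries above only declare schemas; the matching impls live in the conv_transpose kernels registered elsewhere. As a minimal sketch, one of the new schemas can be checked from C++ against the dispatcher; this assumes Dispatcher::findSchema, which looks an operator up by name plus overload name and returns an optional handle.

#include <ATen/core/dispatch/Dispatcher.h>
#include <iostream>

// Sketch only: confirm that the freshly declared accessor schema is visible.
void check_output_padding_schema() {
  auto handle = c10::Dispatcher::singleton().findSchema(
      {"quantized::conv_transpose2d_output_padding", /*overload_name=*/""});
  if (handle.has_value() && handle->hasSchema()) {
    std::cout << "registered: " << handle->schema() << std::endl;
  } else {
    std::cout << "schema not found" << std::endl;
  }
}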
2 changes: 1 addition & 1 deletion aten/src/THC/THCDeviceUtils.cuh
@@ -34,7 +34,7 @@ __host__ __device__ __forceinline__ T THCRoundUp(T a, T b) {
*/
template <typename T>
__device__ __forceinline__ T doLdg(const T* p) {
#ifndef __HIP_PLATFORM_HCC__
#if __CUDA_ARCH__ >= 350 && !defined __HIP_PLATFORM_HCC__
return __ldg(p);
#else
return *p;
2 changes: 1 addition & 1 deletion benchmarks/operator_benchmark/pt/conv_test.py
@@ -3,7 +3,7 @@
import torch
import torch.nn as nn

from . import configs
from pt import configs

"""
Microbenchmarks for Conv1d and ConvTranspose1d operators.
2 changes: 1 addition & 1 deletion benchmarks/operator_benchmark/pt/embeddingbag_test.py
@@ -1,7 +1,7 @@
import operator_benchmark as op_bench
import torch
import numpy
from . import configs
from pt import configs

"""EmbeddingBag Operator Benchmark"""

2 changes: 1 addition & 1 deletion benchmarks/operator_benchmark/pt/linear_test.py
@@ -3,7 +3,7 @@
import torch
import torch.nn as nn

from . import configs
from pt import configs


"""Microbenchmarks for Linear operator."""
2 changes: 1 addition & 1 deletion benchmarks/operator_benchmark/pt/qconv_test.py
@@ -3,7 +3,7 @@
import torch
import torch.nn.quantized as nnq

from . import configs
from pt import configs

"""
Microbenchmarks for qConv operators.
2 changes: 1 addition & 1 deletion benchmarks/operator_benchmark/pt/qembeddingbag_test.py
@@ -3,7 +3,7 @@
import torch
import torch.nn.quantized as nnq
import numpy
from . import configs
from pt import configs

"""
Microbenchmarks for qEmbeddingBag operators.
2 changes: 1 addition & 1 deletion benchmarks/operator_benchmark/pt/qlinear_test.py
@@ -5,7 +5,7 @@
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd

from . import configs
from pt import configs

"""
Microbenchmarks for Quantized Linear operators.
17 changes: 15 additions & 2 deletions docs/libtorch.rst
@@ -5,8 +5,8 @@ The core of pytorch does not depend on Python. A
CMake-based build system compiles the C++ source code into a shared
object, libtorch.so.

Building libtorch
-----------------
Building libtorch using Python
------------------------------

You can use a python script/module located in tools package to build libtorch
::
@@ -34,3 +34,16 @@ To produce libtorch.a rather than libtorch.so, set the environment variable `BUI
To use ninja rather than make, set `CMAKE_GENERATOR="-GNinja" CMAKE_INSTALL="ninja install"`.

Note that we are working on eliminating tools/build_pytorch_libs.sh in favor of a unified cmake build.

Building libtorch using CMake
--------------------------------------

You can build C++ libtorch.so directly with cmake. For example, to build a Release version from the master branch and install it in the directory specified by CMAKE_INSTALL_PREFIX below, you can use
::
git clone -b master --recurse-submodule https://github.com/pytorch/pytorch.git
mkdir pytorch-build
cd pytorch-build
cmake -DBUILD_SHARED_LIBS:BOOL=ON -DCMAKE_BUILD_TYPE:STRING=Release -DPYTHON_EXECUTABLE:PATH=`which python3` -DCMAKE_INSTALL_PREFIX:PATH=../pytorch-install ../pytorch
cmake --build . --target install

To use release branch v1.6.0, for example, replace ``master`` with ``v1.6.0``. You will get errors if you do not have needed dependencies such as Python3's PyYAML package.
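As a quick smoke test of an install produced by either route, a small C++ program like the sketch below can be linked against libtorch, typically through a tiny CMake project that calls find_package(Torch) with CMAKE_PREFIX_PATH pointing at the install directory; the file name here is only an example.

// smoke_test.cpp - minimal check that the installed libtorch is usable.
#include <torch/torch.h>
#include <iostream>

int main() {
  // Build a couple of small tensors and run a basic op to exercise the library.
  torch::Tensor a = torch::rand({2, 3});
  torch::Tensor b = torch::ones({2, 3});
  std::cout << a + b << std::endl;
  return 0;
}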
