Skip to content

Commit

Permalink
[py][vulkan][reland] Add is_vulkan to py api, add vulkan to device ty…
Browse files Browse the repository at this point in the history
…pe parsing (#46655)

Summary: Pull Request resolved: #46655

Test Plan: Imported from OSS

Pulled By: IvanKobzarev

Reviewed By: mrshenli

Differential Revision: D24448984

fbshipit-source-id: 5000846a06077f7a5a06dd51da422d2a42f70820
  • Loading branch information
IvanKobzarev authored and facebook-github-bot committed Oct 22, 2020
1 parent bc1ce58 commit 3112e23
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 3 deletions.
5 changes: 3 additions & 2 deletions c10/core/Device.cpp
Expand Up @@ -30,7 +30,7 @@
namespace c10 {
namespace {
DeviceType parse_type(const std::string& device_string) {
static const std::array<std::pair<std::string, DeviceType>, 10> types = {{
static const std::array<std::pair<std::string, DeviceType>, 11> types = {{
{"cpu", DeviceType::CPU},
{"cuda", DeviceType::CUDA},
{"mkldnn", DeviceType::MKLDNN},
Expand All @@ -41,6 +41,7 @@ DeviceType parse_type(const std::string& device_string) {
{"fpga", DeviceType::FPGA},
{"msnpu", DeviceType::MSNPU},
{"xla", DeviceType::XLA},
{"vulkan", DeviceType::Vulkan},
}};
auto device = std::find_if(
types.begin(),
Expand All @@ -52,7 +53,7 @@ DeviceType parse_type(const std::string& device_string) {
return device->second;
}
AT_ERROR(
"Expected one of cpu, cuda, mkldnn, opengl, opencl, ideep, hip, msnpu, xla device type at start of device string: ", device_string);
"Expected one of cpu, cuda, mkldnn, opengl, opencl, ideep, hip, msnpu, xla, vulkan device type at start of device string: ", device_string);
}
} // namespace

Expand Down
1 change: 1 addition & 0 deletions tools/pyi/gen_pyi.py
Expand Up @@ -577,6 +577,7 @@ def gen_pyi(declarations_path, out):
'is_quantized': ['is_quantized: _bool'],
'is_meta': ['is_meta: _bool'],
'is_mkldnn': ['is_mkldnn: _bool'],
'is_vulkan': ['is_vulkan: _bool'],
'storage_offset': ['def storage_offset(self) -> _int: ...'],
'to': ['def to(self, dtype: _dtype, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
'def to(self, device: Optional[Union[_device, str]]=None, dtype: Optional[_dtype]=None, '
Expand Down
12 changes: 12 additions & 0 deletions torch/csrc/autograd/python_variable.cpp
Expand Up @@ -568,6 +568,17 @@ PyObject *THPVariable_is_mkldnn(THPVariable *self, void *unused)
END_HANDLE_TH_ERRORS
}

PyObject *THPVariable_is_vulkan(THPVariable *self, void *unused)
{
HANDLE_TH_ERRORS
if (check_has_torch_function((PyObject *)self)) {
return handle_torch_function_getter(self, "is_vulkan");
}
auto& self_ = self->cdata;
return torch::autograd::utils::wrap(self_.is_vulkan());
END_HANDLE_TH_ERRORS
}

PyObject *THPVariable_is_quantized(THPVariable *self, void *unused)
{
HANDLE_TH_ERRORS
Expand Down Expand Up @@ -697,6 +708,7 @@ static struct PyGetSetDef THPVariable_properties[] = {
{"is_cuda", (getter)THPVariable_is_cuda, nullptr, nullptr, nullptr},
{"is_sparse", (getter)THPVariable_is_sparse, nullptr, nullptr, nullptr},
{"is_mkldnn", (getter)THPVariable_is_mkldnn, nullptr, nullptr, nullptr},
{"is_vulkan", (getter)THPVariable_is_vulkan, nullptr, nullptr, nullptr},
{"is_complex", (getter)THPVariable_is_complex, nullptr, nullptr, nullptr},
{"is_quantized", (getter)THPVariable_is_quantized, nullptr, nullptr, nullptr},
{"is_meta", (getter)THPVariable_is_meta, nullptr, nullptr, nullptr},
Expand Down
1 change: 1 addition & 0 deletions torch/csrc/jit/frontend/sugared_value.cpp
Expand Up @@ -109,6 +109,7 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
{"is_sparse", "prim"},
{"is_mkldnn", "prim"},
{"is_quantized", "prim"},
{"is_vulkan", "prim"},
{"is_meta", "prim"},
{"is_leaf", "aten"},
{"requires_grad", "prim"},
Expand Down
8 changes: 8 additions & 0 deletions torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp
Expand Up @@ -272,6 +272,14 @@ RegisterOperators reg(
push(stack, a.is_mkldnn());
},
aliasAnalysisFromSchema()),
Operator(
"prim::is_vulkan(Tensor a) -> bool",
[](Stack* stack) {
at::Tensor a;
pop(stack, a);
push(stack, a.is_vulkan());
},
aliasAnalysisFromSchema()),
Operator(
"prim::is_quantized(Tensor a) -> bool",
[](Stack* stack) {
Expand Down
1 change: 1 addition & 0 deletions torch/overrides.py
Expand Up @@ -815,6 +815,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
Tensor.is_mkldnn.__get__: lambda self: -1,
Tensor.is_quantized.__get__: lambda self: -1,
Tensor.is_sparse.__get__: lambda self: -1,
Tensor.is_vulkan.__get__: lambda self: -1,
Tensor.layout.__get__: lambda self: -1,
Tensor.name.__get__: lambda self: -1,
Tensor.names.__get__: lambda self: -1,
Expand Down
Expand Up @@ -244,7 +244,8 @@ def test_invalid_devices(self):

with self.assertRaisesRegex(
RuntimeError,
r"Expected one of cpu, cuda, mkldnn, opengl, opencl, ideep, hip, msnpu, xla device type at start of device string",
r"Expected one of cpu, cuda, mkldnn, opengl, opencl, ideep, hip, msnpu, xla, vulkan"
" device type at start of device string",
):
list(
self._create_remote_module_iter(
Expand Down

0 comments on commit 3112e23

Please sign in to comment.