Skip to content

Commit

Permalink
[py][vulkan][reland] Add is_vulkan to py api, add vulkan to device ty…
Browse files Browse the repository at this point in the history
…pe parsing

Summary:

Test Plan: Imported from OSS

Pulled By: IvanKobzarev

ghstack-source-id: 6bda8848865704a7850e8c81e5f4aea41269f135
Pull Request resolved: #46655
  • Loading branch information
IvanKobzarev committed Oct 22, 2020
1 parent 13decdd commit cc8a73e
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 3 deletions.
5 changes: 3 additions & 2 deletions c10/core/Device.cpp
Expand Up @@ -30,7 +30,7 @@
namespace c10 {
namespace {
DeviceType parse_type(const std::string& device_string) {
static const std::array<std::pair<std::string, DeviceType>, 10> types = {{
static const std::array<std::pair<std::string, DeviceType>, 11> types = {{
{"cpu", DeviceType::CPU},
{"cuda", DeviceType::CUDA},
{"mkldnn", DeviceType::MKLDNN},
Expand All @@ -41,6 +41,7 @@ DeviceType parse_type(const std::string& device_string) {
{"fpga", DeviceType::FPGA},
{"msnpu", DeviceType::MSNPU},
{"xla", DeviceType::XLA},
{"vulkan", DeviceType::Vulkan},
}};
auto device = std::find_if(
types.begin(),
Expand All @@ -52,7 +53,7 @@ DeviceType parse_type(const std::string& device_string) {
return device->second;
}
AT_ERROR(
"Expected one of cpu, cuda, mkldnn, opengl, opencl, ideep, hip, msnpu, xla device type at start of device string: ", device_string);
"Expected one of cpu, cuda, mkldnn, opengl, opencl, ideep, hip, msnpu, xla, vulkan device type at start of device string: ", device_string);
}
} // namespace

Expand Down
1 change: 1 addition & 0 deletions tools/pyi/gen_pyi.py
Expand Up @@ -577,6 +577,7 @@ def gen_pyi(declarations_path, out):
'is_quantized': ['is_quantized: _bool'],
'is_meta': ['is_meta: _bool'],
'is_mkldnn': ['is_mkldnn: _bool'],
'is_vulkan': ['is_vulkan: _bool'],
'storage_offset': ['def storage_offset(self) -> _int: ...'],
'to': ['def to(self, dtype: _dtype, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
'def to(self, device: Optional[Union[_device, str]]=None, dtype: Optional[_dtype]=None, '
Expand Down
12 changes: 12 additions & 0 deletions torch/csrc/autograd/python_variable.cpp
Expand Up @@ -568,6 +568,17 @@ PyObject *THPVariable_is_mkldnn(THPVariable *self, void *unused)
END_HANDLE_TH_ERRORS
}

PyObject *THPVariable_is_vulkan(THPVariable *self, void *unused)
{
HANDLE_TH_ERRORS
if (check_has_torch_function((PyObject *)self)) {
return handle_torch_function_getter(self, "is_vulkan");
}
auto& self_ = self->cdata;
return torch::autograd::utils::wrap(self_.is_vulkan());
END_HANDLE_TH_ERRORS
}

PyObject *THPVariable_is_quantized(THPVariable *self, void *unused)
{
HANDLE_TH_ERRORS
Expand Down Expand Up @@ -697,6 +708,7 @@ static struct PyGetSetDef THPVariable_properties[] = {
{"is_cuda", (getter)THPVariable_is_cuda, nullptr, nullptr, nullptr},
{"is_sparse", (getter)THPVariable_is_sparse, nullptr, nullptr, nullptr},
{"is_mkldnn", (getter)THPVariable_is_mkldnn, nullptr, nullptr, nullptr},
{"is_vulkan", (getter)THPVariable_is_vulkan, nullptr, nullptr, nullptr},
{"is_complex", (getter)THPVariable_is_complex, nullptr, nullptr, nullptr},
{"is_quantized", (getter)THPVariable_is_quantized, nullptr, nullptr, nullptr},
{"is_meta", (getter)THPVariable_is_meta, nullptr, nullptr, nullptr},
Expand Down
1 change: 1 addition & 0 deletions torch/csrc/jit/frontend/sugared_value.cpp
Expand Up @@ -109,6 +109,7 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
{"is_sparse", "prim"},
{"is_mkldnn", "prim"},
{"is_quantized", "prim"},
{"is_vulkan", "prim"},
{"is_meta", "prim"},
{"is_leaf", "aten"},
{"requires_grad", "prim"},
Expand Down
8 changes: 8 additions & 0 deletions torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp
Expand Up @@ -272,6 +272,14 @@ RegisterOperators reg(
push(stack, a.is_mkldnn());
},
aliasAnalysisFromSchema()),
Operator(
"prim::is_vulkan(Tensor a) -> bool",
[](Stack* stack) {
at::Tensor a;
pop(stack, a);
push(stack, a.is_vulkan());
},
aliasAnalysisFromSchema()),
Operator(
"prim::is_quantized(Tensor a) -> bool",
[](Stack* stack) {
Expand Down
1 change: 1 addition & 0 deletions torch/overrides.py
Expand Up @@ -817,6 +817,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
Tensor.is_mkldnn.__get__: lambda self: -1,
Tensor.is_quantized.__get__: lambda self: -1,
Tensor.is_sparse.__get__: lambda self: -1,
Tensor.is_vulkan.__get__: lambda self: -1,
Tensor.layout.__get__: lambda self: -1,
Tensor.name.__get__: lambda self: -1,
Tensor.names.__get__: lambda self: -1,
Expand Down
Expand Up @@ -244,7 +244,8 @@ def test_invalid_devices(self):

with self.assertRaisesRegex(
RuntimeError,
r"Expected one of cpu, cuda, mkldnn, opengl, opencl, ideep, hip, msnpu, xla device type at start of device string",
r"Expected one of cpu, cuda, mkldnn, opengl, opencl, ideep, hip, msnpu, xla, vulkan"
" device type at start of device string",
):
list(
self._create_remote_module_iter(
Expand Down

0 comments on commit cc8a73e

Please sign in to comment.