
Commit

[py][vulkan] Add is_vulkan to py api, add vulkan to device type parsing (#46511)

Summary: Pull Request resolved: #46511

Test Plan: Imported from OSS

Reviewed By: AshkanAliabadi

Differential Revision: D24379422

Pulled By: IvanKobzarev

fbshipit-source-id: afab89bb9e17c50934083598262bbe14ea82e893
IvanKobzarev authored and facebook-github-bot committed Oct 21, 2020
1 parent: a651b87 · commit: e8fbe54
Showing 6 changed files with 26 additions and 2 deletions.
5 changes: 3 additions & 2 deletions c10/core/Device.cpp
@@ -30,7 +30,7 @@
 namespace c10 {
 namespace {
 DeviceType parse_type(const std::string& device_string) {
-  static const std::array<std::pair<std::string, DeviceType>, 10> types = {{
+  static const std::array<std::pair<std::string, DeviceType>, 11> types = {{
       {"cpu", DeviceType::CPU},
       {"cuda", DeviceType::CUDA},
       {"mkldnn", DeviceType::MKLDNN},
@@ -41,6 +41,7 @@ DeviceType parse_type(const std::string& device_string) {
       {"fpga", DeviceType::FPGA},
       {"msnpu", DeviceType::MSNPU},
       {"xla", DeviceType::XLA},
+      {"vulkan", DeviceType::Vulkan},
   }};
   auto device = std::find_if(
       types.begin(),
@@ -52,7 +53,7 @@
     return device->second;
   }
   AT_ERROR(
-      "Expected one of cpu, cuda, mkldnn, opengl, opencl, ideep, hip, msnpu, xla device type at start of device string: ", device_string);
+      "Expected one of cpu, cuda, mkldnn, opengl, opencl, ideep, hip, msnpu, xla, vulkan device type at start of device string: ", device_string);
 }
 } // namespace
 
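
With "vulkan" in the parse table, the string form of the device is accepted on the Python side as well. A minimal sketch of the user-visible effect (assumes a PyTorch build containing this commit; actually placing tensors on the backend additionally requires a Vulkan-enabled build):

    import torch

    # Previously this raised the "Expected one of cpu, cuda, ..." error
    # from parse_type(); now "vulkan" parses like any other device type.
    dev = torch.device("vulkan")
    print(dev.type)  # "vulkan"
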
1 change: 1 addition & 0 deletions tools/pyi/gen_pyi.py
@@ -577,6 +577,7 @@ def gen_pyi(declarations_path, out):
         'is_quantized': ['is_quantized: _bool'],
         'is_meta': ['is_meta: _bool'],
         'is_mkldnn': ['is_mkldnn: _bool'],
+        'is_vulkan': ['is_vulkan: _bool'],
         'storage_offset': ['def storage_offset(self) -> _int: ...'],
         'to': ['def to(self, dtype: _dtype, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
                'def to(self, device: Optional[Union[_device, str]]=None, dtype: Optional[_dtype]=None, '
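
gen_pyi.py emits the Tensor type stubs, so this entry is what makes the new attribute visible to static type checkers. An illustrative sketch of the stub fragment the entry produces (the real file is generated by tools/pyi/gen_pyi.py, not written by hand; `_bool` is the alias the stubs use for `builtins.bool`):

    _bool = bool  # in the generated stubs, _bool aliases builtins.bool

    class Tensor:
        is_quantized: _bool
        is_meta: _bool
        is_mkldnn: _bool
        is_vulkan: _bool  # new entry emitted by this change
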
12 changes: 12 additions & 0 deletions torch/csrc/autograd/python_variable.cpp
@@ -568,6 +568,17 @@ PyObject *THPVariable_is_mkldnn(THPVariable *self, void *unused)
   END_HANDLE_TH_ERRORS
 }
 
+PyObject *THPVariable_is_vulkan(THPVariable *self, void *unused)
+{
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function((PyObject *)self)) {
+    return handle_torch_function_getter(self, "is_vulkan");
+  }
+  auto& self_ = self->cdata;
+  return torch::autograd::utils::wrap(self_.is_vulkan());
+  END_HANDLE_TH_ERRORS
+}
+
 PyObject *THPVariable_is_quantized(THPVariable *self, void *unused)
 {
   HANDLE_TH_ERRORS
@@ -697,6 +708,7 @@ static struct PyGetSetDef THPVariable_properties[] = {
   {"is_cuda", (getter)THPVariable_is_cuda, nullptr, nullptr, nullptr},
   {"is_sparse", (getter)THPVariable_is_sparse, nullptr, nullptr, nullptr},
   {"is_mkldnn", (getter)THPVariable_is_mkldnn, nullptr, nullptr, nullptr},
+  {"is_vulkan", (getter)THPVariable_is_vulkan, nullptr, nullptr, nullptr},
   {"is_complex", (getter)THPVariable_is_complex, nullptr, nullptr, nullptr},
   {"is_quantized", (getter)THPVariable_is_quantized, nullptr, nullptr, nullptr},
   {"is_meta", (getter)THPVariable_is_meta, nullptr, nullptr, nullptr},
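
The new getter is listed in THPVariable_properties next to is_cuda and is_mkldnn, which is what exposes tensor.is_vulkan in eager mode; its check_has_torch_function branch ties the property into the __torch_function__ machinery registered in torch/overrides.py further down. A small usage sketch (the flag is only True for tensors that actually live on the Vulkan backend, which requires a Vulkan-enabled build):

    import torch

    x = torch.ones(2, 3)
    print(x.is_vulkan)  # False: an ordinary CPU tensor

    # On a Vulkan-enabled build, moving the tensor to the backend flips the flag:
    # y = x.to(torch.device("vulkan"))
    # print(y.is_vulkan)  # True
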
1 change: 1 addition & 0 deletions torch/csrc/jit/frontend/sugared_value.cpp
@@ -109,6 +109,7 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
       {"is_sparse", "prim"},
       {"is_mkldnn", "prim"},
       {"is_quantized", "prim"},
+      {"is_vulkan", "prim"},
       {"is_meta", "prim"},
       {"is_leaf", "aten"},
       {"requires_grad", "prim"},
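
This table entry teaches the TorchScript frontend that is_vulkan is a valid Tensor attribute and that it lowers into the prim namespace. A hedged sketch of inspecting the result on a build containing this change (the function name on_vulkan is just illustrative; the exact graph printout varies by version, but a prim::is_vulkan node should appear, backed by the operator registered in the next file):

    import torch

    @torch.jit.script
    def on_vulkan(x: torch.Tensor) -> bool:
        return x.is_vulkan

    # The attribute access compiles to a prim::is_vulkan call rather than
    # failing with "unknown attribute".
    print(on_vulkan.graph)
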
8 changes: 8 additions & 0 deletions torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp
@@ -272,6 +272,14 @@ RegisterOperators reg(
           push(stack, a.is_mkldnn());
         },
         aliasAnalysisFromSchema()),
+    Operator(
+        "prim::is_vulkan(Tensor a) -> bool",
+        [](Stack* stack) {
+          at::Tensor a;
+          pop(stack, a);
+          push(stack, a.is_vulkan());
+        },
+        aliasAnalysisFromSchema()),
     Operator(
         "prim::is_quantized(Tensor a) -> bool",
         [](Stack* stack) {
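
This registration supplies the interpreter implementation the prim::is_vulkan node dispatches to at runtime: it pops the tensor and pushes Tensor::is_vulkan(). Reusing the scripted function from the previous sketch (False for a CPU tensor; a True result would require a Vulkan-backed tensor on a Vulkan-enabled build):

    import torch

    @torch.jit.script
    def on_vulkan(x: torch.Tensor) -> bool:
        return x.is_vulkan

    print(on_vulkan(torch.zeros(4)))  # False, evaluated via prim::is_vulkan
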
1 change: 1 addition & 0 deletions torch/overrides.py
@@ -817,6 +817,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         Tensor.is_mkldnn.__get__: lambda self: -1,
         Tensor.is_quantized.__get__: lambda self: -1,
         Tensor.is_sparse.__get__: lambda self: -1,
+        Tensor.is_vulkan.__get__: lambda self: -1,
         Tensor.layout.__get__: lambda self: -1,
         Tensor.name.__get__: lambda self: -1,
         Tensor.names.__get__: lambda self: -1,
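
Registering the property getter in get_testing_overrides keeps the overridable-API surface in sync: __torch_function__ implementations (and the override-coverage tests) now see Tensor.is_vulkan.__get__ alongside the other is_* getters, matching the check_has_torch_function branch added in python_variable.cpp. A quick sketch of checking that on a build with this change:

    import torch
    from torch.overrides import get_testing_overrides

    overrides = get_testing_overrides()
    print(torch.Tensor.is_vulkan.__get__ in overrides)  # expected: True
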

