Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[py][vulkan][reland] Add is_vulkan to py api, add vulkan to device type parsing #46655

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 3 additions & 2 deletions c10/core/Device.cpp
Expand Up @@ -30,7 +30,7 @@
namespace c10 {
namespace {
DeviceType parse_type(const std::string& device_string) {
static const std::array<std::pair<std::string, DeviceType>, 10> types = {{
static const std::array<std::pair<std::string, DeviceType>, 11> types = {{
{"cpu", DeviceType::CPU},
{"cuda", DeviceType::CUDA},
{"mkldnn", DeviceType::MKLDNN},
Expand All @@ -41,6 +41,7 @@ DeviceType parse_type(const std::string& device_string) {
{"fpga", DeviceType::FPGA},
{"msnpu", DeviceType::MSNPU},
{"xla", DeviceType::XLA},
{"vulkan", DeviceType::Vulkan},
}};
auto device = std::find_if(
types.begin(),
Expand All @@ -52,7 +53,7 @@ DeviceType parse_type(const std::string& device_string) {
return device->second;
}
AT_ERROR(
"Expected one of cpu, cuda, mkldnn, opengl, opencl, ideep, hip, msnpu, xla device type at start of device string: ", device_string);
"Expected one of cpu, cuda, mkldnn, opengl, opencl, ideep, hip, msnpu, xla, vulkan device type at start of device string: ", device_string);
}
} // namespace

Expand Down
1 change: 1 addition & 0 deletions tools/pyi/gen_pyi.py
Expand Up @@ -577,6 +577,7 @@ def gen_pyi(declarations_path, out):
'is_quantized': ['is_quantized: _bool'],
'is_meta': ['is_meta: _bool'],
'is_mkldnn': ['is_mkldnn: _bool'],
'is_vulkan': ['is_vulkan: _bool'],
'storage_offset': ['def storage_offset(self) -> _int: ...'],
'to': ['def to(self, dtype: _dtype, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
'def to(self, device: Optional[Union[_device, str]]=None, dtype: Optional[_dtype]=None, '
Expand Down
12 changes: 12 additions & 0 deletions torch/csrc/autograd/python_variable.cpp
Expand Up @@ -568,6 +568,17 @@ PyObject *THPVariable_is_mkldnn(THPVariable *self, void *unused)
END_HANDLE_TH_ERRORS
}

PyObject *THPVariable_is_vulkan(THPVariable *self, void *unused)
{
HANDLE_TH_ERRORS
if (check_has_torch_function((PyObject *)self)) {
return handle_torch_function_getter(self, "is_vulkan");
}
auto& self_ = self->cdata;
return torch::autograd::utils::wrap(self_.is_vulkan());
END_HANDLE_TH_ERRORS
}

PyObject *THPVariable_is_quantized(THPVariable *self, void *unused)
{
HANDLE_TH_ERRORS
Expand Down Expand Up @@ -697,6 +708,7 @@ static struct PyGetSetDef THPVariable_properties[] = {
{"is_cuda", (getter)THPVariable_is_cuda, nullptr, nullptr, nullptr},
{"is_sparse", (getter)THPVariable_is_sparse, nullptr, nullptr, nullptr},
{"is_mkldnn", (getter)THPVariable_is_mkldnn, nullptr, nullptr, nullptr},
{"is_vulkan", (getter)THPVariable_is_vulkan, nullptr, nullptr, nullptr},
{"is_complex", (getter)THPVariable_is_complex, nullptr, nullptr, nullptr},
{"is_quantized", (getter)THPVariable_is_quantized, nullptr, nullptr, nullptr},
{"is_meta", (getter)THPVariable_is_meta, nullptr, nullptr, nullptr},
Expand Down
1 change: 1 addition & 0 deletions torch/csrc/jit/frontend/sugared_value.cpp
Expand Up @@ -109,6 +109,7 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
{"is_sparse", "prim"},
{"is_mkldnn", "prim"},
{"is_quantized", "prim"},
{"is_vulkan", "prim"},
{"is_meta", "prim"},
{"is_leaf", "aten"},
{"requires_grad", "prim"},
Expand Down
8 changes: 8 additions & 0 deletions torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp
Expand Up @@ -272,6 +272,14 @@ RegisterOperators reg(
push(stack, a.is_mkldnn());
},
aliasAnalysisFromSchema()),
Operator(
"prim::is_vulkan(Tensor a) -> bool",
[](Stack* stack) {
at::Tensor a;
pop(stack, a);
push(stack, a.is_vulkan());
},
aliasAnalysisFromSchema()),
Operator(
"prim::is_quantized(Tensor a) -> bool",
[](Stack* stack) {
Expand Down
1 change: 1 addition & 0 deletions torch/overrides.py
Expand Up @@ -817,6 +817,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
Tensor.is_mkldnn.__get__: lambda self: -1,
Tensor.is_quantized.__get__: lambda self: -1,
Tensor.is_sparse.__get__: lambda self: -1,
Tensor.is_vulkan.__get__: lambda self: -1,
Tensor.layout.__get__: lambda self: -1,
Tensor.name.__get__: lambda self: -1,
Tensor.names.__get__: lambda self: -1,
Expand Down
Expand Up @@ -244,7 +244,8 @@ def test_invalid_devices(self):

with self.assertRaisesRegex(
RuntimeError,
r"Expected one of cpu, cuda, mkldnn, opengl, opencl, ideep, hip, msnpu, xla device type at start of device string",
r"Expected one of cpu, cuda, mkldnn, opengl, opencl, ideep, hip, msnpu, xla, vulkan"
" device type at start of device string",
):
list(
self._create_remote_module_iter(
Expand Down