[spirv] SPIR-V / Vulkan NDArray #4202
@@ -30,7 +30,9 @@ if (WIN32)
    set(CMAKE_CXX_FLAGS
        "${CMAKE_CXX_FLAGS} /Zc:__cplusplus /std:c++17 /bigobj /wd4244 /wd4267 /nologo /Zi /D \"_CRT_SECURE_NO_WARNINGS\" /D \"_ENABLE_EXTENDED_ALIGNED_STORAGE\"")
else()
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17 -g -gcodeview -fsized-deallocation -target x86_64-pc-windows-msvc")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17 -fsized-deallocation -target x86_64-pc-windows-msvc")
Nit: could you use target_compile_options? https://cmake.org/cmake/help/latest/command/target_compile_options.html
At this point the TaichiCore targets have not been defined yet (this is the CXX flags CMake file). Setting CMAKE_CXX_FLAGS here also matches how the flags are handled on other platforms.
  ext_arrays[i] = *(DeviceAllocation *)(host_ctx->args[i]);
} else {
  ext_arrays[i] = kDeviceNullAllocation;
}
Do we need to update ext_array_size here?
Nope, ext_array_size is only for ext arrs (it is used by the host device context blitter). With NDArrays, the runtime does not need to know the buffer size. I will add a comment to clarify this.
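To illustrate the distinction described in this reply, here is a minimal sketch, not the code in this PR. It assumes NDArray arguments arrive as an existing device allocation, while host-side ext arrs only need their byte size recorded for the blitter. `ArrayArg`, `ArrayBindings`, `bind_array_args`, and the simplified `DeviceAllocation` struct are hypothetical stand-ins invented for illustration.

```cpp
#include <cstddef>
#include <cstdint>
#include <unordered_map>

// Hypothetical stand-in for the real device allocation handle.
struct DeviceAllocation {
  std::uint64_t alloc_id;
};

// Hypothetical description of one array argument passed to a kernel.
struct ArrayArg {
  bool is_ndarray;          // true: data already lives in device memory
  DeviceAllocation device;  // meaningful only when is_ndarray is true
  std::size_t host_bytes;   // meaningful only for host-side ext arrs
};

// Hypothetical result: the two maps named in the diff.
struct ArrayBindings {
  std::unordered_map<int, DeviceAllocation> ext_arrays;
  std::unordered_map<int, std::size_t> ext_array_size;
};

inline ArrayBindings bind_array_args(
    const std::unordered_map<int, ArrayArg> &args) {
  ArrayBindings out;
  for (const auto &[i, arg] : args) {
    if (arg.is_ndarray) {
      // NDArray: bind the existing allocation; no size entry is recorded.
      out.ext_arrays[i] = arg.device;
    } else {
      // Ext arr: record only the byte size so a blitter could copy it.
      out.ext_array_size[i] = arg.host_bytes;
    }
  }
  return out;
}
```

In this sketch an NDArray argument never gets an entry in `ext_array_size`, which is the behavior the reply describes.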
auto ctx_blitter = HostDeviceContextBlitter::maybe_make(
    &ti_kernel->ti_kernel_attribs().ctx_attribs, host_ctx, device_,
    host_result_buffer_, ctx_buffer.get(), ctx_buffer_host.get());

std::unordered_map<int, DeviceAllocation> ext_arrays;
std::unordered_map<int, size_t> ext_array_size;
I guess ext_ is no longer an accurate term, because it covers NDArray as well?
Great work!
@@ -181,8 +188,8 @@ class HostDeviceContextBlitter {
const auto dt = ret.dt;
do {
  if (ret.is_array) {
    // void *host_ptr = host_ctx_->get_arg<void *>(i);
    // std::memcpy(host_ptr, device_ptr, ret.stride);
    void *host_ptr = host_ctx_->get_arg<void *>(i);
OOC: do we support array args as return value?
Not yet. But we can leave this in: it should make array return values easier to add later.
size_t size = arg.stride;

for (int ax = 0; ax < 8; ax++) {
  // FIXME: how and when do we determine the size of ext arrs?
FYI, OpenGL got it from kernel->args[i].size, but IMHO we should remove that from the Args class: https://github.com/taichi-dev/taichi/blob/master/taichi/program/callable.h#L22
But the algorithm here also needs to handle a torch tensor with shape [1, 0, 1] (total size is still 0).
Can we unify the shape calculation algo here?
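As a rough illustration of the zero-extent case raised in this thread, the sketch below computes an array's byte size from an explicit rank instead of stopping at the first zero extent, so a shape such as [1, 0, 1] yields 0. It is a sketch under assumptions, not the PR's implementation: `array_size_in_bytes`, `kMaxNumAxes`, and the `shape`/`ndim` parameters are hypothetical names, and it assumes `arg.stride` in the diff holds the per-element byte size.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Hypothetical constant mirroring the `ax < 8` loop bound in the diff.
constexpr int kMaxNumAxes = 8;

// Returns the total byte size of an array argument.
// `shape` holds per-axis extents, `ndim` is the number of valid axes, and
// `elem_bytes` is the size of one element (assumed to correspond to
// `arg.stride` in the diff).
inline std::size_t array_size_in_bytes(
    const std::array<std::int64_t, kMaxNumAxes> &shape,
    int ndim,
    std::size_t elem_bytes) {
  std::size_t size = elem_bytes;
  for (int ax = 0; ax < ndim && ax < kMaxNumAxes; ax++) {
    // Any zero (or invalid negative) extent means the array has no elements.
    if (shape[ax] <= 0) {
      return 0;
    }
    size *= static_cast<std::size_t>(shape[ax]);
  }
  return size;
}
```

Passing the rank explicitly is one way to keep a zero extent from being mistaken for the end of the shape; a helper like this could be shared across backends if the calculation is unified.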
LGTM!
No description provided.