Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[spirv] SPIR-V / Vulkan NDArray #4202

Merged
merged 11 commits into from
Feb 7, 2022
Merged

[spirv] SPIR-V / Vulkan NDArray #4202

merged 11 commits into from
Feb 7, 2022

Conversation

bobcao3
Copy link
Collaborator

@bobcao3 bobcao3 commented Jan 29, 2022

No description provided.

@bobcao3
Copy link
Collaborator Author

bobcao3 commented Jan 29, 2022

/format

@netlify
Copy link

netlify bot commented Jan 29, 2022

✔️ Deploy Preview for docsite-preview canceled.

🔨 Explore the source changes: 5d8763e

🔍 Inspect the deploy log: https://app.netlify.com/sites/docsite-preview/deploys/61f8e516a046830007a86af3

@bobcao3 bobcao3 requested review from k-ye and ailzhang January 29, 2022 10:08
@bobcao3
Copy link
Collaborator Author

bobcao3 commented Jan 29, 2022

/format

@bobcao3 bobcao3 marked this pull request as ready for review January 29, 2022 17:15
@bobcao3
Copy link
Collaborator Author

bobcao3 commented Jan 29, 2022

/format

@@ -30,7 +30,9 @@ if (WIN32)
set(CMAKE_CXX_FLAGS
"${CMAKE_CXX_FLAGS} /Zc:__cplusplus /std:c++17 /bigobj /wd4244 /wd4267 /nologo /Zi /D \"_CRT_SECURE_NO_WARNINGS\" /D \"_ENABLE_EXTENDED_ALIGNED_STORAGE\"")
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17 -g -gcodeview -fsized-deallocation -target x86_64-pc-windows-msvc")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17 -fsized-deallocation -target x86_64-pc-windows-msvc")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At this point the TaichiCore targets have not been defined (as this is the CXXFlags cmake file), also this aligns with the flags behavior on other platforms.

ext_arrays[i] = *(DeviceAllocation *)(host_ctx->args[i]);
} else {
ext_arrays[i] = kDeviceNullAllocation;
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to update ext_array_size here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope, ext_array_size is only for ext arrs (for host device context blitter). For the runtime, with NDArrays, it does not need to know the buffer size. I will add a comment to clarify this.

auto ctx_blitter = HostDeviceContextBlitter::maybe_make(
&ti_kernel->ti_kernel_attribs().ctx_attribs, host_ctx, device_,
host_result_buffer_, ctx_buffer.get(), ctx_buffer_host.get());

std::unordered_map<int, DeviceAllocation> ext_arrays;
std::unordered_map<int, size_t> ext_array_size;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess ext_ is no longer the accurate term, because it covers NDarray as well?

Copy link
Contributor

@ailzhang ailzhang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work!

@@ -181,8 +188,8 @@ class HostDeviceContextBlitter {
const auto dt = ret.dt;
do {
if (ret.is_array) {
// void *host_ptr = host_ctx_->get_arg<void *>(i);
// std::memcpy(host_ptr, device_ptr, ret.stride);
void *host_ptr = host_ctx_->get_arg<void *>(i);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OOC: do we support array args as return value?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not yet at this point. But we can leave this in: should be easier to add later down the line.

size_t size = arg.stride;

for (int ax = 0; ax < 8; ax++) {
// FIXME: how and when do we determine the size of ext arrs?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI opengl got it from kernel->args[i].size but IMHO we should remove that in Args class. https://github.com/taichi-dev/taichi/blob/master/taichi/program/callable.h#L22
But the algorithm here also need to handle the a torch tensor with shape [1, 0, 1] (total size is still 0).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we unify the shape calculation algo here?

@bobcao3
Copy link
Collaborator Author

bobcao3 commented Jan 31, 2022

/format

Copy link
Member

@k-ye k-ye left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

size_t size = arg.stride;

for (int ax = 0; ax < 8; ax++) {
// FIXME: how and when do we determine the size of ext arrs?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we unify the shape calculation algo here?

@bobcao3
Copy link
Collaborator Author

bobcao3 commented Feb 7, 2022

/rerun

@bobcao3 bobcao3 merged commit 49b4318 into taichi-dev:master Feb 7, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants