
[feat] add Hashtable On GPU #74

Merged
merged 1 commit into from
May 11, 2021

Conversation

rhdong
Member

@rhdong rhdong commented May 10, 2021

  • add Hashtable on GPU
  • switch the CUDA toolchain to CUDA 11
  • update the STYLE GUIDE for clang-format on macOS
  • update the Bazel version to 3.7.2

@rhdong rhdong requested a review from Lifann May 10, 2021 15:11
size_t default_value_num =
    is_full_default ? default_value.shape().dim_size(0) : 1;
CUDA_CHECK(cudaStreamCreate(&_stream));
CUDA_CHECK(cudaMalloc((void**)&d_status, sizeof(bool) * len));
Contributor

Is it possible to use cudaMallocManaged here? It allocates memory that is automatically managed by the Unified Memory system.
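
For context, a minimal sketch contrasting the two allocation styles under discussion. The names `d_status` and `len` follow the snippet above; this is illustrative only and assumes a CUDA 11 toolchain, not code from the PR:

```cuda
#include <cuda_runtime.h>
#include <cstdio>

#define CUDA_CHECK(call)                                          \
  do {                                                            \
    cudaError_t err = (call);                                     \
    if (err != cudaSuccess) {                                     \
      fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err)); \
    }                                                             \
  } while (0)

int main() {
  size_t len = 1024;

  // Explicit device allocation (the approach used in the PR):
  // the buffer lives in GPU memory only, so host access requires
  // an explicit cudaMemcpy in each direction.
  bool* d_status = nullptr;
  CUDA_CHECK(cudaMalloc((void**)&d_status, sizeof(bool) * len));
  CUDA_CHECK(cudaFree(d_status));

  // Unified Memory (the reviewer's suggestion): one pointer valid
  // on both host and device; the driver migrates pages on demand,
  // so no explicit copies are needed.
  bool* m_status = nullptr;
  CUDA_CHECK(cudaMallocManaged(&m_status, sizeof(bool) * len));
  m_status[0] = true;  // directly accessible from host code
  CUDA_CHECK(cudaFree(m_status));
  return 0;
}
```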

Member Author

@rhdong rhdong May 11, 2021

I have tried that, but it failed; I will consider your advice in the next version, since this is a stable implementation.

3.7.2
Member

Is the Bazel version 3.7.2 necessary? If it is, please also mention it in the README.

Member Author

Accepted.

~CuckooHashTableOfTensorsGpu() { delete table_; }

void CreateTable(size_t max_size, gpu::TableWrapperBase<K, V>** pptable) {
  if (runtime_dim_ <= 50) {
Member

@Lifann Lifann May 11, 2021

I think it is possible to use mod 50 and a switch statement to achieve better performance.

Member Author

This only runs once, and mod 50 with a switch would hurt readability.
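
For reference, a hypothetical sketch of the reviewer's suggestion: round the runtime dimension up to a fixed bucket and dispatch once through a switch, so each case instantiates a wrapper with a compile-time dimension. `TableWrapper` and `CreateTableImpl` are illustrative names, not the PR's actual types:

```cuda
// Hypothetical dispatch: replace a chain of `if (runtime_dim_ <= N)`
// comparisons with a single switch over 50-wide buckets.
template <class K, class V, size_t DIM>
void CreateTableImpl(size_t max_size, gpu::TableWrapperBase<K, V>** pptable) {
  *pptable = new gpu::TableWrapper<K, V, DIM>(max_size);
}

template <class K, class V>
void CreateTable(size_t max_size, size_t runtime_dim,
                 gpu::TableWrapperBase<K, V>** pptable) {
  // Round runtime_dim up to the next multiple of 50, then dispatch.
  switch ((runtime_dim + 49) / 50) {
    case 1:
      CreateTableImpl<K, V, 50>(max_size, pptable);
      break;
    case 2:
      CreateTableImpl<K, V, 100>(max_size, pptable);
      break;
    default:
      CreateTableImpl<K, V, 200>(max_size, pptable);
      break;
  }
}
```

The author's counterpoint stands either way: `CreateTable` runs once per table, so the simpler if-chain costs nothing measurable at runtime.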

CUDA_CHECK(cudaStreamCreate(&_stream));
CUDA_CHECK(cudaMalloc((void**)&d_keys, sizeof(K) * len));
CUDA_CHECK(cudaMalloc((void**)&d_values, sizeof(V) * runtime_dim_ * len));
CUDA_CHECK(cudaMemcpy((void*)d_keys, (void*)keys.tensor_data().data(),
Member

@Lifann Lifann May 11, 2021

If the operation is registered on Device_GPU, the input Tensor is already on GPU device memory. So it is not strictly needed to do the copy.

Member Author

I tested it; sometimes the inputs were in CPU memory, which caused a crash, so I used memcpy.
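
One way to reconcile the two observations, sketched here as an assumption rather than the PR's approach, is to query where the buffer actually lives with `cudaPointerGetAttributes` and copy only when needed:

```cuda
#include <cuda_runtime.h>

// Sketch (not part of the PR): report whether a buffer is already
// resident in device or managed memory.
bool IsDevicePointer(const void* ptr) {
  cudaPointerAttributes attr;
  cudaError_t err = cudaPointerGetAttributes(&attr, ptr);
  if (err != cudaSuccess) {
    cudaGetLastError();  // clear the sticky error state
    return false;        // unregistered host memory
  }
  return attr.type == cudaMemoryTypeDevice ||
         attr.type == cudaMemoryTypeManaged;
}

// Usage idea: skip the host-to-device copy when the input tensor
// already lives on the GPU.
// if (!IsDevicePointer(keys.tensor_data().data())) {
//   CUDA_CHECK(cudaMemcpy(d_keys, keys.tensor_data().data(),
//                         sizeof(K) * len, cudaMemcpyHostToDevice));
// }
```

Note that `cudaPointerAttributes::type` requires CUDA 10 or later; this PR moves the toolchain to CUDA 11, so that precondition would hold.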

Member

@Lifann Lifann left a comment

I made several review comments. Please check them.

Lifann
Lifann previously approved these changes May 11, 2021
Member

@Lifann Lifann left a comment

LGTM

@rhdong rhdong merged commit 1404bd0 into tensorflow:master May 11, 2021