
[FEA] Port CLX wordpiece tokenizer into cuDF #4981

Closed
harrism opened this issue Apr 21, 2020 · 17 comments

@harrism
Member

harrism commented Apr 21, 2020

Is your feature request related to a problem? Please describe.

The RAPIDS CLX repo has GPU-accelerated tokenizers that will be more widely usable if incorporated into libcudf/cuDF. We would like to port the CLX wordpiece tokenizer to libcudf.

Describe the solution you'd like
The new wordpiece tokenizer should be redesigned to use libcudf types (columns, tables, views) for input and output while maintaining or improving its current high performance.

Since cuDF already has a (different type of) tokenizer, we should ensure consistency in tokenization APIs within the library.

Once it is ported to libcudf, we will need to write new cython bindings and provide a Python interface.

@BartleyR took an action item to provide an easy to run benchmark of the current tokenizer to ensure we don't have regressions when we port.
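
For reference, the tokenizer cuDF already exposes on the Python side is delimiter-based. A minimal sketch of its use follows (assuming a recent cuDF install; exact method availability may vary by version); the ported tokenizer's Python API should stay consistent with this style.

# Minimal sketch of cuDF's existing delimiter-based tokenizer on the Python side.
# The ported wordpiece/subword tokenizer should follow similar conventions.
import cudf

s = cudf.Series(["The quick brown fox", "jumped over the lazy dog"])

# Series.str.tokenize splits on whitespace and returns one token per output row.
tokens = s.str.tokenize()
print(tokens)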

@harrism added the feature request, libcudf, Python, and strings labels Apr 21, 2020
@BartleyR
Member

BartleyR commented Jun 11, 2020

Adding more information about the subword tokenizer and references. Typically, these types of tokenizers accept either a list of filenames or a list of strings/sentences. This version of our subword tokenizer is based on the batch_encode_plus tokenizer in Hugging Face's transformers repo (https://huggingface.co/transformers/main_classes/tokenizer.html#transformers.PreTrainedTokenizer.batch_encode_plus). Our version doesn't support all of the parameters that this one does, but it does support key ones like stride and pad_to_max_length, both of which are required for cyBERT. spaCy wraps Hugging Face's batch_encode_plus tokenizer with its own API (https://github.com/explosion/thinc/blob/master/examples/02_transformers_tagger_bert.ipynb).

We should align the function signature with the Hugging Face version. Arguments that we don't support could default to None and raise a NotImplementedError when set.

Currently, the GPU subword tokenizer has a mandatory max_characters parameter that needs to be set. If truncated=True, then max_sentences and max_tensor_rows are both equal to the length of the input series (the number of rows in the cuDF DataFrame, or the length of a cuDF Series). If truncated=False, then the user needs to provide an estimate for the value of max_characters.

The tokenizer is currently called a wordpiece tokenizer, but we would like to rename it to be a subword tokenizer.
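
To make the alignment concrete, a hypothetical signature sketch follows; the function name, parameter list, and defaults are illustrative assumptions only, not the actual CLX or cuDF API.

# Hypothetical sketch of a batch_encode_plus-aligned signature. Every name and
# default below is an illustrative assumption, not the actual CLX/cuDF API.
def subword_tokenize(strings,                  # cuDF Series of strings/sentences
                     hash_file,                # path to the vocab hash-table file
                     max_length=64,
                     stride=48,
                     do_lower_case=True,
                     truncation=False,
                     pad_to_max_length=True,
                     add_special_tokens=None,   # accepted but unsupported
                     return_token_type_ids=None):
    # Parameters that batch_encode_plus supports but the GPU path does not could
    # default to None and raise only when a caller sets them explicitly.
    if add_special_tokens is not None or return_token_type_ids is not None:
        raise NotImplementedError("parameter not supported by the GPU tokenizer")
    ...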

@harrism
Member Author

harrism commented Jun 15, 2020

We need to work in terms of columns and tables. And I think we should not load files in this tokenizer; that should be done with the file reader/writer APIs (cuIO). So I guess the input would be a column of strings.

@BartleyR
Member

And I think we should not load files in this tokenizer; that should be done with the file reader/writer APIs (cuIO). So I guess the input would be a column of strings.

I agree with this. I'd rather use cuIO to do the file in/out. Column of strings is what we've been using.

@davidwendt
Contributor

davidwendt commented Jun 16, 2020

What about the vocab/hash-table file that is loaded here?
https://github.com/rapidsai/clx/blob/9338eb263536be33117398a67b734cf0c324d262/cpp/src/wordPieceTokenizer.cu#L282

GpuWordPieceTokenizer::GpuWordPieceTokenizer(std::string vocab_file,  // <------
    uint32_t max_num_chars, uint32_t max_inp_chars_per_word):
  device_hash_table(nullptr), device_bin_coefficients(nullptr), device_bin_offsets(nullptr) {

How should this be handled? The file appears to be loaded as 3 vectors (hash-table, coefficients, offsets) and 6 scalars (unk-tok-id, first-tok-id, sep-tok-id, outer-table-a, outer-table-b, num-bins). The bert_hash_table.txt file appears here.
Is it expected that we ship this file in the conda package?

@kkraus14 for his opinion as well on this.

@davidwendt
Contributor

How are the 3 device memory buffers returned in the TokenizerResult object freed?

I see they are allocated here:
https://github.com/rapidsai/clx/blob/66c05d130824f51c23852579299d217d555102b5/cpp/src/main.cu#L35-L40

  cudaMalloc((void**)&result->device_tensor_tokenIDS, result->nrows_tensor*max_sequence_length*sizeof(uint32_t));
  cudaMalloc((void**)&result->device_attention_mask, result->nrows_tensor*max_sequence_length*sizeof(uint32_t));
  cudaMalloc((void**)&result->device_tensor_metadata, result->nrows_tensor*3*sizeof(uint32_t));
  cudaMemcpy(result->device_tensor_tokenIDS, tokenizer.get_tensor_tokenIDS(), result->nrows_tensor*max_sequence_length*sizeof(uint32_t), cudaMemcpyDeviceToDevice);
  cudaMemcpy(result->device_attention_mask, tokenizer.get_attention_mask(), result->nrows_tensor*max_sequence_length*sizeof(uint32_t), cudaMemcpyDeviceToDevice);
  cudaMemcpy(result->device_tensor_metadata, tokenizer.get_tensor_metadata(), result->nrows_tensor*3*sizeof(uint32_t), cudaMemcpyDeviceToDevice);

But I don't see where they are freed:
https://github.com/rapidsai/clx/blob/66c05d130824f51c23852579299d217d555102b5/python/clx/analytics/tokenizer_wrapper.pyx#L87-L100

    device_tokenIDS = device_array_from_ptr(<uintptr_t>result.device_tensor_tokenIDS,
                                            shape=(result.nrows_tensor,max_sequence_length),
                                            dtype=np.int32)
    device_mask = device_array_from_ptr(<uintptr_t>result.device_attention_mask,
                                        shape=(result.nrows_tensor,max_sequence_length),
                                        dtype=np.int32)
    device_metadata = device_array_from_ptr(<uintptr_t>result.device_tensor_metadata,
                                            shape=(result.nrows_tensor,3),
                                            dtype=np.int32)

    token = from_dlpack(device_tokenIDS.toDlpack())
    mask = from_dlpack(device_mask.toDlpack())
    metadata = from_dlpack(device_metadata.toDlpack())
    return token.type(torch.long), mask.type(torch.long), metadata.type(torch.long)

device_array_from_ptr wraps the device pointer in cupy.cuda.memory.UnownedMemory(ptr, datasize, None).
toDlpack() is documented as a zero-copy conversion to a DLPack tensor.
from_dlpack notes that the tensor will share the memory with the object represented in the DLPack capsule.

So if these are all zero-copy conversions, then who takes ownership of these memory buffers and calls cudaFree() on them?

@BartleyR
Member

How are the 3 device memory buffers returned in the TokenizerResult object freed?

It actually doesn't look like they are being freed. But maybe I'm missing something - @Iroy30 or @brhodes10?

@brhodes10

How are the 3 device memory buffers returned in the TokenizerResult object freed?

It actually doesn't look like they are being freed. But maybe I'm missing something - @Iroy30 or @brhodes10?

I cannot see where it is being freed. Willing to take advice on how to fix that.

@kkraus14
Collaborator

I cannot see where it is being freed. Willing to take advice on how to fix that.

Because they're being explicitly allocated with cudaMalloc, we currently can't really use much of our existing machinery to free them. We'd have to either build a Numba finalizer that we register around them, or build a Cython class that calls cudaFree when they go out of scope. Neither is trivial.
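
For illustration only, a rough pure-Python way to attach a finalizer to an externally cudaMalloc'd pointer (this is not the Numba route above, and it assumes libcudart.so is loadable via ctypes):

# Rough illustration only: wrap a raw, externally cudaMalloc'd device pointer so
# that cudaFree is eventually called when the Python wrapper is garbage collected.
# Assumes libcudart.so can be found by the dynamic loader; no error checking.
import ctypes
import weakref

_cudart = ctypes.CDLL("libcudart.so")

class OwnedDevicePtr:
    """Owns a raw device pointer and frees it when collected."""
    def __init__(self, ptr: int):
        self.ptr = ptr
        # weakref.finalize registers cudaFree(ptr) to run when self is collected.
        self._finalizer = weakref.finalize(
            self, _cudart.cudaFree, ctypes.c_void_p(ptr)
        )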

Is waiting for it to be ported into cuDF an option? Within cuDF we can make the initial allocations using rmm::device_buffer, for which we have Cython classes that handle the lifetime scoping cleanly.

@davidwendt
Contributor

@kkraus14 I'm porting this right now in #5511.
I can allocate these using rmm::device_buffer, but I will need some help tying this into the cuDF machinery.

@BartleyR
Member

Is waiting for it to be ported into cuDF an option?

Absolutely. We have a working version right now, so taking care of this at integration time is fine.

@kkraus14
Collaborator

I can allocate these using rmm::device_buffer but I will need some help tying this into the cudf machinery.

ACK. We can do C++ first and do Cython/Python as a follow-up, unless you want to do it all at once for unit testing purposes.

If you make them unique_ptr<rmm::device_buffer>, then we can just call the Cython code at https://github.com/rapidsai/rmm/blob/branch-0.15/python/rmm/_lib/device_buffer.pyx#L126-L130, which returns an rmm.DeviceBuffer Python object that uses the smart pointer to control the memory lifetime. We can then immediately build a cudf.Column on top of that rmm.DeviceBuffer, or we can push it through CuPy, depending on the expected output.

@davidwendt
Contributor

davidwendt commented Jun 23, 2020

@BartleyR What is the expected output here? The current CLX code is returning Torch tensors (?)
https://github.com/rapidsai/clx/blob/66c05d130824f51c23852579299d217d555102b5/python/clx/analytics/tokenizer_wrapper.pyx#L56-L59

import torch
from torch.utils.dlpack import from_dlpack

# ...
    token = from_dlpack(device_tokenIDS.toDlpack())
    mask = from_dlpack(device_mask.toDlpack())
    metadata = from_dlpack(device_metadata.toDlpack())
    return token.type(torch.long), mask.type(torch.long), metadata.type(torch.long)

Would it be OK if this returned 3 cudf.Column objects or just 3 rmm.DeviceBuffer objects instead?

@BartleyR
Member

Would it be OK if this returned 3 cudf.Column objects or just 3 rmm.DeviceBuffer objects instead?

It was returning Torch tensors because that is our current pipeline. After this, those tensors would be directly fed into a PyTorch model. However, @raykallen and I have been talking about making it more generic so it could be fed to a TF workflow as well. I'm fine with either from a pipeline perspective, but I'm not sure if one of them is more performant when going to a tensor as a next step. We don't envision there being a use case where you would not go directly to a tensor then into a BERT model for inference, but I can't say with 100% certainty. So I would prefer to err on the side of (1) efficiency and (2) generalizability.
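
A minimal sketch, assuming the ported tokenizer ends up returning cuDF Series/columns (the exact DLPack helper may differ by cuDF version): the DLPack hop keeps the hand-off to PyTorch or TF zero-copy, so a cuDF return type should not add a copy before the tensor step.

# Minimal sketch, assuming the ported tokenizer returns its outputs as cuDF Series.
# DLPack keeps the hand-off to PyTorch zero-copy, so a cuDF return type shouldn't
# cost an extra copy before the tensor step. Helper names may vary by cuDF version.
import cudf
import torch
from torch.utils.dlpack import from_dlpack

token_ids = cudf.Series([101, 2023, 2003, 102], dtype="int32")  # stand-in output
tensor = from_dlpack(token_ids.to_dlpack()).type(torch.long)
print(tensor)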

@davidwendt
Contributor

Ok, thanks. I think a rmm::DeviceBuffer may be the most general and efficient here. I will start with that and see how it goes.

@harrism
Member Author

harrism commented Jun 24, 2020

I thought we already discussed that the input and output would become cudf::columns. Then the API would be similar to any libcudf API and the caller is responsible for deleting the results when finished.

@BartleyR
Member

I thought we already discussed that the input and output would become cudf::columns. Then the API would be similar to any libcudf API and the caller is responsible for deleting the results when finished.

@davidwendt Mark is right and I forgot we discussed this.

@harrism
Member Author

harrism commented Jul 19, 2020

Closed by #5511

@harrism closed this as completed Jul 19, 2020