[REVIEW] Port of clx subword tokenizer to cudf #5511

Merged Jul 9, 2020

Changes from all commits (68 commits)
650b730
initial port of clx subword tokenizer
davidwendt Jun 18, 2020
d2f0028
Merge branch 'branch-0.15' into port-subword-tokenizer
davidwendt Jun 22, 2020
b8dbc5f
update changelog
davidwendt Jun 22, 2020
6b12e6c
fix style violations
davidwendt Jun 22, 2020
5a9a7a0
move some source to details
davidwendt Jun 22, 2020
cbf6118
Merge branch 'branch-0.15' into port-subword-tokenizer
davidwendt Jun 22, 2020
0bf1527
pass stream down
davidwendt Jun 22, 2020
4f4c92d
fix kernel name
davidwendt Jun 23, 2020
aa7277d
Merge branch 'branch-0.15' into port-subword-tokenizer
davidwendt Jun 23, 2020
431fc56
refactor 3 tokenizer classes to one
davidwendt Jun 23, 2020
c0e4d8c
rename basic-tokenizer to normalizer
davidwendt Jun 23, 2020
e759147
rename full-tokenizer to wordpiece-tokenizer
davidwendt Jun 23, 2020
17b7c7a
meant to rename to data-normalizer
davidwendt Jun 23, 2020
14ec5a9
create bigger test data
davidwendt Jun 23, 2020
4968b50
Merge branch 'branch-0.15' into port-subword-tokenizer
davidwendt Jun 23, 2020
eccaa2a
rename TokenizerResult to tokenizer_result
davidwendt Jun 23, 2020
5aa3ef5
add cython for subword_tokenizer
davidwendt Jun 23, 2020
c689f91
declare a cython interface
davidwendt Jun 23, 2020
ad2aa6f
Merge branch 'branch-0.15' into port-subword-tokenizer
davidwendt Jun 23, 2020
650d0c1
fix style violation
davidwendt Jun 24, 2020
1c66503
Merge branch 'branch-0.15' into port-subword-tokenizer
davidwendt Jun 24, 2020
01c468e
return columns instead of device-buffers
davidwendt Jun 24, 2020
94fe0ba
update cython/python for new return type
davidwendt Jun 24, 2020
bb0de69
move kernels to eliminate a header file
davidwendt Jun 24, 2020
a1b542d
add consts to various declarations
davidwendt Jun 24, 2020
27a34e8
Merge branch 'branch-0.15' into port-subword-tokenizer
davidwendt Jun 24, 2020
dad537c
rename tokenIDS to token_ids
davidwendt Jun 24, 2020
9a9576e
fix style violation
davidwendt Jun 24, 2020
8af4a0e
add some nvtx ranges
davidwendt Jun 24, 2020
d57a497
fix some comments; add some TODOs
davidwendt Jun 25, 2020
b677996
reduce cp-data to 1MB header
davidwendt Jun 26, 2020
9a59675
Merge branch 'branch-0.15' into port-subword-tokenizer
davidwendt Jun 26, 2020
f31bcf2
try thrust in place of cub in normalize
davidwendt Jun 26, 2020
fecd5e5
add load-hashed-vocab api declaration
davidwendt Jun 26, 2020
676d39d
use thrust for internal update function
davidwendt Jun 26, 2020
26ddf70
add load-hash-vocab api
davidwendt Jun 26, 2020
d3e0120
Merge branch 'branch-0.15' into port-subword-tokenizer
davidwendt Jun 28, 2020
b04cf58
move per-context-cache utility to detail header
davidwendt Jun 29, 2020
4241b3e
make cp/aux tables singletons
davidwendt Jun 29, 2020
a5477e4
change device_vector to uvector
davidwendt Jun 29, 2020
14694ac
compute row2log values in device code
davidwendt Jun 29, 2020
36462a2
add more doxygen; removed commented out code
davidwendt Jun 29, 2020
b8f3c40
add more consts
davidwendt Jun 29, 2020
8685412
Merge branch 'branch-0.15' into port-subword-tokenizer
davidwendt Jun 29, 2020
a4c5ef7
Merge branch 'branch-0.15' into port-subword-tokenizer
davidwendt Jun 30, 2020
e5934cb
Merge branch 'branch-0.15' into port-subword-tokenizer
davidwendt Jun 30, 2020
e36346e
decl types in python subword_tokenize def
davidwendt Jun 30, 2020
61de291
Merge branch 'branch-0.15' into port-subword-tokenizer
davidwendt Jul 2, 2020
d516907
add pytest
davidwendt Jul 6, 2020
c12de54
fix copydoc
davidwendt Jul 6, 2020
6aa298a
use std::generate in place of for
davidwendt Jul 6, 2020
cc6f31f
remove forceinline decl
davidwendt Jul 6, 2020
15ac715
change define to constexpr
davidwendt Jul 6, 2020
aea146c
fix doxygen param order
davidwendt Jul 6, 2020
15c3c77
change log to tensor
davidwendt Jul 6, 2020
47a4af7
use grid_1d for kernel launch parms
davidwendt Jul 6, 2020
3ac75bc
Merge branch 'branch-0.15' into port-subword-tokenizer
davidwendt Jul 6, 2020
107a042
add more @params
davidwendt Jul 7, 2020
6b64fbe
add comments about hashed vocab values
davidwendt Jul 7, 2020
fa49311
add more comments explaining numbers used
davidwendt Jul 7, 2020
0357beb
add more consts
davidwendt Jul 7, 2020
aaaf986
remove commented out code
davidwendt Jul 7, 2020
f1ab8da
minor fixes like east consts
davidwendt Jul 7, 2020
e022568
add more gtests varying stride and do_truncate parms
davidwendt Jul 7, 2020
5ac2faf
rework tensor-output kernel to use fixed block-size
davidwendt Jul 7, 2020
b4bff95
Merge branch 'branch-0.15' into port-subword-tokenizer
davidwendt Jul 7, 2020
91bd0e2
Merge branch 'branch-0.15' into port-subword-tokenizer
davidwendt Jul 8, 2020
08a33ed
return cupy arrays instead of Series
davidwendt Jul 8, 2020
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -28,6 +28,7 @@
- PR #5488 Add plumbings for `.str.replace_tokens`
- PR #5502 Add Unsigned int types support in dlpack
- PR #5497 Add `.str.isinteger` & `.str.isfloat`
- PR #5511 Port of clx subword tokenizer to cudf
- PR #5528 Add unsigned int reading and writing support to parquet
- PR #5510 Add support for `cudf.Index` to create Indexes
- PR #5536 Parquet reader - add support for multiple sources
4 changes: 4 additions & 0 deletions cpp/CMakeLists.txt
@@ -566,6 +566,10 @@ add_library(cudf
src/text/tokenize.cu
src/text/ngrams_tokenize.cu
src/text/replace.cu
src/text/subword/load_hash_file.cu
src/text/subword/data_normalizer.cu
src/text/subword/wordpiece_tokenizer.cu
src/text/subword/subword_tokenize.cu
src/scalar/scalar.cpp
src/scalar/scalar_factories.cpp
src/dictionary/add_keys.cu
8 changes: 8 additions & 0 deletions cpp/benchmarks/CMakeLists.txt
@@ -213,3 +213,11 @@ set(CSV_WRITER_BENCH_SRC
"${CMAKE_CURRENT_SOURCE_DIR}/io/csv/csv_writer_benchmark.cpp")

ConfigureBench(CSV_WRITER_BENCH "${CSV_WRITER_BENCH_SRC}")

###################################################################################################
# - subword tokenizer benchmark -------------------------------------------------------------------

set(SUBWORD_TOKENIZER_BENCH_SRC
"${CMAKE_CURRENT_SOURCE_DIR}/text/subword_benchmark.cpp")

ConfigureBench(SUBWORD_TOKENIZER_BENCH "${SUBWORD_TOKENIZER_BENCH_SRC}")
81 changes: 81 additions & 0 deletions cpp/benchmarks/text/subword_benchmark.cpp
@@ -0,0 +1,81 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <benchmark/benchmark.h>
#include <cudf/strings/strings_column_view.hpp>
#include <nvtext/subword_tokenize.hpp>

#include <tests/utilities/column_utilities.hpp>
#include <tests/utilities/column_wrapper.hpp>

#include <fstream>
#include <iostream>
#include <vector>

#define MAX_NUM_SENTENCES 101
#define MAX_NUM_CHARS 150000
#define MAX_ROWS_TENSOR 300

static std::string create_hash_vocab_file()
{
std::string dir_template("/tmp");
if (const char* env_p = std::getenv("WORKSPACE")) dir_template = env_p;
std::string hash_file = dir_template + "/hash_vocab.txt";
// create a fake hashed vocab text file for this test
// this only works with words in the strings in the benchmark code below
std::vector<std::pair<int, int>> coefficients(23, {65559, 0});
std::ofstream outfile(hash_file, std::ofstream::out);
outfile << "1\n0\n" << coefficients.size() << "\n";
for (auto c : coefficients) outfile << c.first << " " << c.second << "\n";
std::vector<uint64_t> hash_table(23, 0);
outfile << hash_table.size() << "\n";
hash_table[0] = 3015668L;
hash_table[1] = 6205475701751155871L;
hash_table[5] = 6358029;
hash_table[16] = 451412625363L;
hash_table[20] = 6206321707968235495L;
for (auto h : hash_table) outfile << h << "\n";
outfile << "100\n101\n102\n\n";
return hash_file;
}

static void BM_cuda_tokenizer_cudf(benchmark::State& state)
{
uint32_t nrows = MAX_NUM_SENTENCES - 1;
std::vector<const char*> h_strings(nrows, "This is a test ");
cudf::test::strings_column_wrapper strings(h_strings.begin(), h_strings.end());
// cudf::test::strings_column_wrapper strings{"This is a test."};
std::string hash_file = create_hash_vocab_file();
std::vector<uint32_t> offsets{14};
uint32_t max_sequence_length = 64;
uint32_t stride = 48;
uint32_t do_truncate = 0;
uint32_t do_lower = 1;
for (auto _ : state) {
auto result = nvtext::subword_tokenize(cudf::strings_column_view{strings},
hash_file,
max_sequence_length,
stride,
do_lower,
do_truncate,
MAX_NUM_SENTENCES,
MAX_NUM_CHARS,
MAX_ROWS_TENSOR);
}
}
BENCHMARK(BM_cuda_tokenizer_cudf);

BENCHMARK_MAIN();
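As a reviewer aid, here is a hedged sketch of how the benchmark above could be parameterized over the input row count using Google Benchmark's argument mechanism. The function name and the max_* sizing arithmetic are illustrative assumptions only and are not part of this PR's diff; it reuses create_hash_vocab_file from the file above.

// Hypothetical variant: vary the number of input rows via benchmark arguments.
// The sizing below assumes the same 15-character test string used above, so each
// row produces a single tensor row when do_truncate is false.
static void BM_cuda_tokenizer_cudf_rows(benchmark::State& state)
{
  auto const nrows = static_cast<uint32_t>(state.range(0));
  std::vector<const char*> h_strings(nrows, "This is a test ");
  cudf::test::strings_column_wrapper strings(h_strings.begin(), h_strings.end());
  std::string hash_file = create_hash_vocab_file();
  for (auto _ : state) {
    auto result = nvtext::subword_tokenize(cudf::strings_column_view{strings},
                                           hash_file,
                                           64,          // max_sequence_length
                                           48,          // stride
                                           true,        // do_lower_case
                                           false,       // do_truncate
                                           nrows + 1,   // max_num_strings
                                           nrows * 16,  // max_num_chars
                                           nrows + 1);  // max_rows_tensor
  }
}
BENCHMARK(BM_cuda_tokenizer_cudf_rows)->Arg(100)->Arg(1000);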
43 changes: 43 additions & 0 deletions cpp/include/cudf/strings/detail/utilities.cuh
@@ -21,6 +21,8 @@

#include <rmm/thrust_rmm_allocator.h>
#include <thrust/scan.h>
#include <mutex>
#include <unordered_map>

namespace cudf {
namespace strings {
@@ -60,6 +62,47 @@ std::unique_ptr<column> make_offsets_child_column(
return offsets_column;
}

// This template is a thin wrapper around per-context singleton objects.
// It maintains a single object for each CUDA context.
template <typename TableType>
class per_context_cache {
public:
// Find the object cached for the current CUDA context.
// If there is no object available in the cache, the initializer `init` is
// called to create a new one and cache it for later use.
template <typename Initializer>
TableType* find_or_initialize(const Initializer& init)
{
CUcontext c;
cuCtxGetCurrent(&c);
auto finder = cache_.find(c);
if (finder == cache_.end()) {
TableType* result = init();
cache_[c] = result;
return result;
} else
return finder->second;
}

private:
std::unordered_map<CUcontext, TableType*> cache_;
};

// This template is a thread-safe version of per_context_cache.
template <typename TableType>
class thread_safe_per_context_cache : public per_context_cache<TableType> {
public:
template <typename Initializer>
TableType* find_or_initialize(const Initializer& init)
{
std::lock_guard<std::mutex> guard(mutex);
return per_context_cache<TableType>::find_or_initialize(init);
}

private:
std::mutex mutex;
};

} // namespace detail
} // namespace strings
} // namespace cudf
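To show how the cache templates above are meant to be consumed, here is a minimal usage sketch; `aux_table` and `load_aux_table()` are hypothetical placeholders and not part of this PR.

#include <cudf/strings/detail/utilities.cuh>

// Hypothetical device-resident table with an expensive one-time initializer.
struct aux_table {
  // device pointers, sizes, etc.
};
aux_table* load_aux_table();  // expensive: file I/O plus host-to-device copies

aux_table* get_aux_table()
{
  // One cached instance per CUDA context; the initializer lambda runs only the
  // first time this is called for a given context, guarded by a mutex.
  static cudf::strings::detail::thread_safe_per_context_cache<aux_table> cache;
  return cache.find_or_initialize([] { return load_aux_table(); });
}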
45 changes: 45 additions & 0 deletions cpp/include/nvtext/detail/load_hash_file.hpp
@@ -0,0 +1,45 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <cudf/column/column.hpp>
#include <nvtext/subword_tokenize.hpp>

#include <stdint.h>
#include <string.h>

namespace nvtext {
namespace detail {

/**
* @brief Load the hashed vocabulary file into device memory.
*
* The returned object can be used to call subword_tokenize without
* incurring the cost of loading the same file each time.
*
* @param filename_hashed_vocabulary A path to the preprocessed vocab.txt file.
* Note that this is the file AFTER python/perfect_hash.py has been used
* for preprocessing.
* @param stream CUDA stream used for device memory operations and kernel launches.
* @param mr Memory resource to allocate any returned objects.
* @return vocabulary hash-table elements
*/
hashed_vocabulary load_vocabulary_file(std::string const& filename_hashed_vocabulary,
cudaStream_t stream,
rmm::mr::device_memory_resource* mr);

} // namespace detail
} // namespace nvtext
174 changes: 174 additions & 0 deletions cpp/include/nvtext/subword_tokenize.hpp
@@ -0,0 +1,174 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <cudf/column/column.hpp>
#include <cudf/column/column_view.hpp>
#include <cudf/strings/strings_column_view.hpp>

#include <stdint.h>
#include <string.h>

namespace nvtext {

/**
* @brief The vocabulary data for use with the subword_tokenize function.
*/
struct hashed_vocabulary {
uint16_t first_token_id{};
uint16_t separator_token_id{};
uint16_t unknown_token_id{};
uint32_t outer_hash_a{};
uint32_t outer_hash_b{};
uint16_t num_bins{};
std::unique_ptr<cudf::column> table; // uint64
std::unique_ptr<cudf::column> bin_coefficients; // uint64
std::unique_ptr<cudf::column> bin_offsets; // uint16
};

/**
* @brief Load the hashed vocabulary file into device memory.
*
* The returned object can be used to call subword_tokenize without
* incurring the cost of loading the same file each time.
*
* @throw cudf::logic_error if the `filename_hashed_vocabulary` could not be opened.
*
* @param filename_hashed_vocabulary A path to the preprocessed vocab.txt file.
* Note that this is the file AFTER python/perfect_hash.py has been used
* for preprocessing.
* @param mr Memory resource to allocate any returned objects.
* @return vocabulary hash-table elements
*/
hashed_vocabulary load_vocabulary_file(
std::string const& filename_hashed_vocabulary,
rmm::mr::device_memory_resource* mr = rmm::mr::get_default_resource());

/**
* @brief Result object for the subword_tokenize functions.
*/
struct tokenizer_result {
/**
* @brief The number of rows for the output token-ids.
*/
uint32_t nrows_tensor{};
/**
* @brief The number of token-ids in each row.
*/
uint32_t sequence_length{};
/**
* @brief A vector of token-ids for each row.
*
* The data is a flat matrix (nrows_tensor x sequence_length) of token-ids.
* This column is of type UINT32 with no null entries.
*/
std::unique_ptr<cudf::column> tensor_token_ids;
/**
* @brief This mask identifies which tensor-token-ids are valid.
*
* This column is of type UINT32 with no null entries.
*/
std::unique_ptr<cudf::column> tensor_attention_mask;
/**
* @brief The metadata for each tensor row.
*
* There are three elements per tensor row: [row-id, start_pos, stop_pos].
* This column is of type UINT32 with no null entries.
*/
std::unique_ptr<cudf::column> tensor_metadata;
};

/**
* @brief Creates a tokenizer that cleans the text, splits it into tokens and
* returns token-ids from an input vocabulary.
*
* The strings are first normalized by converting them to lower-case, removing
* punctuation, and replacing a select set of multi-byte characters and
* whitespace characters.
*
* The strings are then tokenized by using whitespace as a delimiter.
* Consecutive delimiters are ignored. Each token is then assigned
* a 4-byte token-id mapped from the provided vocabulary table.
*
* Essentially, each string is converted into one or more vectors of token-ids
* in the output column. The total number of these vectors multiplied by
* `max_sequence_length` is the size of the output column.
*
* @throw cudf::logic_error if `stride > max_sequence_length`
* @throw cudf::logic_error if `max_sequence_length * max_rows_tensor` is
* larger than the max value for cudf::size_type
*
* @param strings The input strings to tokenize.
* @param filename_hashed_vocabulary A path to the preprocessed vocab.txt file.
* Note that this is the file AFTER python/perfect_hash.py has been used
* for preprocessing.
* @param max_sequence_length Limit of the number of token-ids per row in final tensor
* for each string.
* @param stride Each row in the output token-ids will replicate `max_sequence_length - stride`
* of the token-ids from the previous row, unless it is the first string.
* @param do_lower_case If true, the tokenizer will convert uppercase characters in the
* input stream to lower-case and strip accents from those characters.
* If false, accented and uppercase characters are not transformed.
* @param do_truncate If true, the tokenizer will discard all the token-ids after
* `max_sequence_length` for each input string. If false, it will use a new row
* in the output token-ids to continue generating the output.
* @param max_num_strings Maximum number of input strings for instantiating the tokenizer.
* Used for allocating temporary working memory on the GPU.
* If the input contains a larger number of strings, behavior is undefined.
* @param max_num_chars Maximum number of characters for instantiating the tokenizer.
* Used for allocating temporary working memory on the GPU.
* If the input contains a larger number of characters, behavior is undefined.
* @param max_rows_tensor Maximum number of rows for the output token-ids expected
* to be generated by the tokenizer.
* Used for allocating temporary working memory on the GPU device.
* If the output generates a larger number of rows, behavior is undefined.
* @param mr Memory resource to allocate any returned objects.
* @return token-ids, attention-mask, and metadata
*/
tokenizer_result subword_tokenize(
cudf::strings_column_view const& strings,
std::string const& filename_hashed_vocabulary,
uint32_t max_sequence_length,
uint32_t stride,
bool do_lower_case,
bool do_truncate,
uint32_t max_num_strings,
uint32_t max_num_chars,
uint32_t max_rows_tensor,
rmm::mr::device_memory_resource* mr = rmm::mr::get_default_resource());

/**
* @copydoc subword_tokenize()
*
* This function differs from the one above only in the hashed vocabulary parameter.
* The file can be pre-loaded using the @ref load_vocabulary_file API and then
* passed in place of the file name in a call to this API.
*
* @param vocabulary_table The vocabulary table pre-loaded into this object.
*/
tokenizer_result subword_tokenize(
cudf::strings_column_view const& strings,
hashed_vocabulary const& vocabulary_table,
uint32_t max_sequence_length,
uint32_t stride,
bool do_lower_case,
bool do_truncate,
uint32_t max_num_strings,
uint32_t max_num_chars,
uint32_t max_rows_tensor,
rmm::mr::device_memory_resource* mr = rmm::mr::get_default_resource());

} // namespace nvtext
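For context when reviewing, a minimal end-to-end sketch of the two-step API declared above; the vocabulary path and the three max_* sizing values are illustrative assumptions, not values taken from this PR.

#include <cudf/strings/strings_column_view.hpp>
#include <nvtext/subword_tokenize.hpp>

nvtext::tokenizer_result tokenize_example(cudf::strings_column_view const& input)
{
  // Load the preprocessed (perfect-hash) vocabulary once and reuse it so the
  // file is not re-read on every tokenize call.
  auto vocabulary = nvtext::load_vocabulary_file("/tmp/hashed_vocab.txt");

  uint32_t const max_sequence_length = 64;
  uint32_t const stride              = 48;     // must not exceed max_sequence_length
  bool const do_lower_case           = true;
  bool const do_truncate             = false;
  auto const num_strings             = static_cast<uint32_t>(input.size());

  auto result = nvtext::subword_tokenize(input,
                                         vocabulary,
                                         max_sequence_length,
                                         stride,
                                         do_lower_case,
                                         do_truncate,
                                         num_strings + 1,   // max_num_strings
                                         1000000,           // max_num_chars
                                         num_strings * 2);  // max_rows_tensor
  // result.tensor_token_ids is a flat (nrows_tensor x sequence_length) UINT32 column:
  // the token-id for tensor row r, position s is at index r * sequence_length + s.
  // result.tensor_attention_mask flags the valid positions, and result.tensor_metadata
  // stores [row-id, start_pos, stop_pos] for each tensor row.
  return result;
}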