splitting registration and refactoring vocab.py module (#1352)
parmeet committed Jul 2, 2021
1 parent aa75fe0 commit 7ab50af
Showing 12 changed files with 493 additions and 455 deletions.
1 change: 1 addition & 0 deletions .circleci/unittest/linux/scripts/environment.yml
@@ -18,5 +18,6 @@ dependencies:
- sphinx
- sphinx-rtd-theme
- tqdm
- expecttest
- https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.0.0/de_core_news_sm-3.0.0.tar.gz#egg=de_core_news_sm==3.0.0
- https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0.tar.gz#egg=en_core_web_sm==3.0.0
1 change: 1 addition & 0 deletions .circleci/unittest/windows/scripts/environment.yml
@@ -21,5 +21,6 @@ dependencies:
- tqdm
- certifi
- future
- expecttest
- https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.0.0/de_core_news_sm-3.0.0.tar.gz#egg=de_core_news_sm==3.0.0
- https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0.tar.gz#egg=en_core_web_sm==3.0.0
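Both CI environments add expecttest, PyTorch's snapshot-testing helper, presumably for the tests accompanying the vocab refactor. A minimal sketch of the library's typical usage (the test class and values here are illustrative, not taken from this commit):

import expecttest

class VocabRefactorTest(expecttest.TestCase):
    def test_token_order(self):
        tokens = sorted(["world", "hello"])
        # assertExpectedInline compares the actual value against an inline
        # snapshot; running with EXPECTTEST_ACCEPT=1 rewrites the snapshot.
        self.assertExpectedInline(str(tokens), """['hello', 'world']""")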
@@ -9,6 +9,8 @@
#include <torch/script.h>
#include <vectors.h> // @manual
#include <vocab.h> // @manual
#include <vocab_factory.h>

namespace torchtext {

namespace py = pybind11;
@@ -155,126 +157,8 @@ PYBIND11_MODULE(_torchtext, m) {
&_load_token_and_vectors_from_file);
m.def("_load_vocab_from_file", &_load_vocab_from_file);
m.def("_build_vocab_from_text_file", &build_vocab_from_text_file);
m.def("_build_vocab_from_text_file_using_python_tokenizer", &_build_vocab_from_text_file_using_python_tokenizer);
}

TORCH_LIBRARY_FRAGMENT(torchtext, m) {
m.class_<Regex>("Regex")
.def(torch::init<std::string>())
.def("Sub", &Regex::Sub)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<Regex> &self) -> std::string {
return _serialize_regex(self);
},
// __setstate__
[](std::string state) -> c10::intrusive_ptr<Regex> {
return _deserialize_regex(std::move(state));
});

m.class_<RegexTokenizer>("RegexTokenizer")
.def(torch::init<std::vector<std::string>, std::vector<std::string>,
bool>())
.def("forward", &RegexTokenizer::forward)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<RegexTokenizer> &self)
-> RegexTokenizerStates {
return _serialize_regex_tokenizer(self);
},
// __setstate__
[](RegexTokenizerStates states)
-> c10::intrusive_ptr<RegexTokenizer> {
return _deserialize_regex_tokenizer(std::move(states));
});

m.class_<SentencePiece>("SentencePiece")
.def(torch::init<std::string>())
.def("Encode", &SentencePiece::Encode)
.def("EncodeAsIds", &SentencePiece::EncodeAsIds)
.def("DecodeIds", &SentencePiece::DecodeIds)
.def("EncodeAsPieces", &SentencePiece::EncodeAsPieces)
.def("DecodePieces", &SentencePiece::DecodePieces)
.def("GetPieceSize", &SentencePiece::GetPieceSize)
.def("unk_id", &SentencePiece::unk_id)
.def("PieceToId", &SentencePiece::PieceToId)
.def("IdToPiece", &SentencePiece::IdToPiece)
.def_pickle(
          // The underlying content of SentencePiece contains a byte string,
          // and returning it as std::string causes a UTF-8 decoding error.
// Since TorchScript does not support byte string, we use byte Tensor
// to pass around the data.
// __getstate__
[](const c10::intrusive_ptr<SentencePiece> &self) -> torch::Tensor {
auto *data =
static_cast<void *>(const_cast<char *>(self->content_.data()));
auto numel = static_cast<int64_t>(self->content_.size());
return torch::from_blob(data, {numel}, {torch::kUInt8}).clone();
},
// __setstate__
[](torch::Tensor state) -> c10::intrusive_ptr<SentencePiece> {
auto *data = static_cast<char *>(state.data_ptr());
auto numel = state.size(0);
return c10::make_intrusive<SentencePiece>(std::string(data, numel));
});

m.class_<Vectors>("Vectors")
.def(torch::init<std::vector<std::string>, std::vector<std::int64_t>,
torch::Tensor, torch::Tensor>())
.def("__getitem__", &Vectors::__getitem__)
.def("lookup_vectors", &Vectors::lookup_vectors)
.def("__setitem__", &Vectors::__setitem__)
.def("__len__", &Vectors::__len__)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<Vectors> &self) -> VectorsStates {
return _serialize_vectors(self);
},
// __setstate__
[](VectorsStates states) -> c10::intrusive_ptr<Vectors> {
return _deserialize_vectors(states);
});

m.class_<Vocab>("Vocab")
.def(torch::init<StringList, c10::optional<int64_t>>())
.def("__contains__",
[](const c10::intrusive_ptr<Vocab> &self, const std::string &item)
-> bool { return self->__contains__(c10::string_view{item}); })
.def("__getitem__",
[](const c10::intrusive_ptr<Vocab> &self, const std::string &item)
-> int64_t { return self->__getitem__(c10::string_view{item}); })
.def("insert_token", &Vocab::insert_token)
.def("__len__", &Vocab::__len__)
.def("set_default_index", &Vocab::set_default_index)
.def("get_default_index", &Vocab::get_default_index)
.def("append_token", &Vocab::append_token)
.def("lookup_token", &Vocab::lookup_token)
.def("lookup_tokens", &Vocab::lookup_tokens)
.def("lookup_indices",
[](const c10::intrusive_ptr<Vocab> &self,
const std::vector<std::string> &items) {
std::vector<int64_t> indices(items.size());
int64_t counter = 0;
for (const auto &item : items) {
indices[counter++] = self->__getitem__(c10::string_view{item});
}
return indices;
})
.def("get_stoi", &Vocab::get_stoi)
.def("get_itos", &Vocab::get_itos)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<Vocab> &self) -> VocabStates {
return _serialize_vocab(self);
},
// __setstate__
[](VocabStates states) -> c10::intrusive_ptr<Vocab> {
return _deserialize_vocab(states);
});

m.def("torchtext::generate_sp_model", &generate_sp_model);
m.def("torchtext::load_sp_model", &load_sp_model);
m.def("torchtext::load_sp_model_string", &load_sp_model_string);
m.def("_build_vocab_from_text_file_using_python_tokenizer",
&_build_vocab_from_text_file_using_python_tokenizer);
}

} // namespace torchtext
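The change above removes the TORCH_LIBRARY_FRAGMENT block from this translation unit, leaving only the pybind11 registrations; the torchbind registrations move to the new register_torchbindings.cpp below. The two registries are consumed differently from Python — a minimal sketch, assuming torchtext's usual module layout:

import torch
import torchtext  # importing the package loads both sets of bindings

# pybind11 route: functions registered in PYBIND11_MODULE live on the
# C extension module torchtext._torchtext.
from torchtext import _torchtext
assert hasattr(_torchtext, "_load_vocab_from_file")

# torchbind route: classes registered in TORCH_LIBRARY_FRAGMENT are
# reachable through torch.classes, in eager mode and in TorchScript.
regex = torch.classes.torchtext.Regex(r"\s+")
print(regex.Sub("a   b", " "))  # collapses runs of whitespace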
129 changes: 129 additions & 0 deletions torchtext/csrc/register_torchbindings.cpp
@@ -0,0 +1,129 @@
#include <iostream>
#include <regex.h>
#include <regex_tokenizer.h> // @manual
#include <sentencepiece.h> // @manual
#include <torch/script.h>
#include <vectors.h> // @manual
#include <vocab.h> // @manual
namespace torchtext {

TORCH_LIBRARY_FRAGMENT(torchtext, m) {
m.class_<Regex>("Regex")
.def(torch::init<std::string>())
.def("Sub", &Regex::Sub)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<Regex> &self) -> std::string {
return _serialize_regex(self);
},
// __setstate__
[](std::string state) -> c10::intrusive_ptr<Regex> {
return _deserialize_regex(std::move(state));
});

m.class_<RegexTokenizer>("RegexTokenizer")
.def(torch::init<std::vector<std::string>, std::vector<std::string>,
bool>())
.def("forward", &RegexTokenizer::forward)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<RegexTokenizer> &self)
-> RegexTokenizerStates {
return _serialize_regex_tokenizer(self);
},
// __setstate__
[](RegexTokenizerStates states)
-> c10::intrusive_ptr<RegexTokenizer> {
return _deserialize_regex_tokenizer(std::move(states));
});

m.class_<SentencePiece>("SentencePiece")
.def(torch::init<std::string>())
.def("Encode", &SentencePiece::Encode)
.def("EncodeAsIds", &SentencePiece::EncodeAsIds)
.def("DecodeIds", &SentencePiece::DecodeIds)
.def("EncodeAsPieces", &SentencePiece::EncodeAsPieces)
.def("DecodePieces", &SentencePiece::DecodePieces)
.def("GetPieceSize", &SentencePiece::GetPieceSize)
.def("unk_id", &SentencePiece::unk_id)
.def("PieceToId", &SentencePiece::PieceToId)
.def("IdToPiece", &SentencePiece::IdToPiece)
.def_pickle(
          // The underlying content of SentencePiece contains a byte string,
          // and returning it as std::string causes a UTF-8 decoding error.
// Since TorchScript does not support byte string, we use byte Tensor
// to pass around the data.
// __getstate__
[](const c10::intrusive_ptr<SentencePiece> &self) -> torch::Tensor {
auto *data =
static_cast<void *>(const_cast<char *>(self->content_.data()));
auto numel = static_cast<int64_t>(self->content_.size());
return torch::from_blob(data, {numel}, {torch::kUInt8}).clone();
},
// __setstate__
[](torch::Tensor state) -> c10::intrusive_ptr<SentencePiece> {
auto *data = static_cast<char *>(state.data_ptr());
auto numel = state.size(0);
return c10::make_intrusive<SentencePiece>(std::string(data, numel));
});

m.class_<Vectors>("Vectors")
.def(torch::init<std::vector<std::string>, std::vector<std::int64_t>,
torch::Tensor, torch::Tensor>())
.def("__getitem__", &Vectors::__getitem__)
.def("lookup_vectors", &Vectors::lookup_vectors)
.def("__setitem__", &Vectors::__setitem__)
.def("__len__", &Vectors::__len__)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<Vectors> &self) -> VectorsStates {
return _serialize_vectors(self);
},
// __setstate__
[](VectorsStates states) -> c10::intrusive_ptr<Vectors> {
return _deserialize_vectors(states);
});

m.class_<Vocab>("Vocab")
.def(torch::init<StringList, c10::optional<int64_t>>())
.def("__contains__",
[](const c10::intrusive_ptr<Vocab> &self, const std::string &item)
-> bool { return self->__contains__(c10::string_view{item}); })
.def("__getitem__",
[](const c10::intrusive_ptr<Vocab> &self, const std::string &item)
-> int64_t { return self->__getitem__(c10::string_view{item}); })
.def("insert_token", &Vocab::insert_token)
.def("__len__", &Vocab::__len__)
.def("set_default_index", &Vocab::set_default_index)
.def("get_default_index", &Vocab::get_default_index)
.def("append_token", &Vocab::append_token)
.def("lookup_token", &Vocab::lookup_token)
.def("lookup_tokens", &Vocab::lookup_tokens)
.def("lookup_indices",
[](const c10::intrusive_ptr<Vocab> &self,
const std::vector<std::string> &items) {
std::vector<int64_t> indices(items.size());
int64_t counter = 0;
for (const auto &item : items) {
indices[counter++] = self->__getitem__(c10::string_view{item});
}
return indices;
})
.def("get_stoi", &Vocab::get_stoi)
.def("get_itos", &Vocab::get_itos)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<Vocab> &self) -> VocabStates {
return _serialize_vocab(self);
},
// __setstate__
[](VocabStates states) -> c10::intrusive_ptr<Vocab> {
return _deserialize_vocab(states);
});

m.def("torchtext::generate_sp_model", &generate_sp_model);
m.def("torchtext::load_sp_model", &load_sp_model);
m.def("torchtext::load_sp_model_string", &load_sp_model_string);
}

} // namespace torchtext
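Keeping the torchbind registrations in their own translation unit means they carry no pybind11 dependency, so the classes stay usable on TorchScript serialization paths. A sketch of what the def_pickle hooks above enable (the wrapper module here is illustrative, not from this commit):

import torch

class CleanText(torch.nn.Module):
    """Holds a torchbind Regex as an attribute of a scriptable module."""

    def __init__(self):
        super().__init__()
        self.regex = torch.classes.torchtext.Regex(r"\s+")

    def forward(self, line: str) -> str:
        return self.regex.Sub(line, " ")

scripted = torch.jit.script(CleanText())
scripted.save("clean_text.pt")              # invokes Regex's __getstate__
restored = torch.jit.load("clean_text.pt")  # invokes Regex's __setstate__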
59 changes: 0 additions & 59 deletions torchtext/csrc/vocab.cpp
@@ -191,17 +191,6 @@ void parse_raw_text_file_chunk(const std::string &file_path, size_t offset,
}
}

// sorting using a custom object
struct CompareTokens {
bool operator()(const std::pair<std::string, int64_t> &a,
const std::pair<std::string, int64_t> &b) {
if (a.second == b.second) {
return a.first < b.first;
}
return a.second > b.second;
}
};

StringList
_concat_tokens(std::vector<std::shared_ptr<IndexDict>> chunk_counters,
const int64_t min_freq, const int64_t num_lines,
@@ -345,54 +334,6 @@ Vocab _build_vocab_from_text_file(const std::string &file_path,
return Vocab(std::move(tokens));
}

Vocab _build_vocab_from_text_file_using_python_tokenizer(
const std::string &file_path, const int64_t min_freq,
py::object tokenizer) {
// find number of lines
int64_t num_lines = _infer_lines(file_path);
// Read text from file and add tokens
std::ifstream fin(file_path, std::ios::in);
TORCH_CHECK(fin.is_open(), "Cannot open input file " + file_path);

IndexDict counter;
std::string line;
for (int64_t i = 0; i < num_lines; i++) {
std::getline(fin, line);
std::vector<std::string> token_list =
tokenizer(line).cast<std::vector<std::string>>();

for (size_t i = 0; i < token_list.size(); i++) {
std::string token = token_list[i];

if (counter.find(token) == counter.end()) {
counter[token] = 1;
} else {
counter[token] += 1;
}
}
}

// create tokens-frequency pairs
std::vector<std::pair<std::string, int64_t>> token_freq_pairs;
for (const auto &item : counter) {
if (item.second >= min_freq) {
token_freq_pairs.push_back(item);
}
}

// sort tokens by frequency
CompareTokens compare_tokens;
std::sort(token_freq_pairs.begin(), token_freq_pairs.end(), compare_tokens);

// Create final list of tokens
StringList tokens;
for (const auto &token_freq_pair : token_freq_pairs) {
tokens.push_back(token_freq_pair.first);
}

return Vocab(std::move(tokens));
}

VocabStates _serialize_vocab(const c10::intrusive_ptr<Vocab> &self) {
std::vector<int64_t> integers;
StringList strings = self->itos_;
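With _build_vocab_from_text_file_using_python_tokenizer gone from the C++ layer, its logic — plain token counting driven by a Python callable — can live in the refactored Python vocab module instead (the Python side of this commit is not shown in this excerpt). A hedged sketch that mirrors the deleted function, including the CompareTokens ordering of frequency descending with alphabetical tie-breaking:

from collections import Counter
from typing import Callable, List

def build_vocab_from_text_file_using_python_tokenizer(
    file_path: str,
    tokenizer: Callable[[str], List[str]],
    min_freq: int = 1,
) -> List[str]:
    # Count tokens line by line, as the deleted C++ loop did.
    counter: Counter = Counter()
    with open(file_path, encoding="utf-8") as fin:
        for line in fin:
            counter.update(tokenizer(line.rstrip("\n")))
    # Keep tokens meeting min_freq, then sort by frequency (descending),
    # breaking ties alphabetically -- the CompareTokens ordering.
    pairs = [(tok, freq) for tok, freq in counter.items() if freq >= min_freq]
    pairs.sort(key=lambda kv: (-kv[1], kv[0]))
    return [tok for tok, _ in pairs]  # the C++ version wraps this in a Vocab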
20 changes: 14 additions & 6 deletions torchtext/csrc/vocab.h
@@ -1,10 +1,8 @@
#pragma once
#include <algorithm>
#include <c10/util/string_view.h>
#include <pybind11/pybind11.h>
#include <torch/script.h>

namespace py = pybind11;

namespace torchtext {

typedef std::vector<std::string> StringList;
@@ -14,6 +12,19 @@ typedef std::tuple<std::string, std::vector<int64_t>, std::vector<std::string>,
std::vector<torch::Tensor>>
VocabStates;

// sorting using a custom object
struct CompareTokens {
bool operator()(const std::pair<std::string, int64_t> &a,
const std::pair<std::string, int64_t> &b) {
if (a.second == b.second) {
return a.first < b.first;
}
return a.second > b.second;
}
};

int64_t _infer_lines(const std::string &file_path);

struct Vocab : torch::CustomClassHolder {
static const int32_t MAX_VOCAB_SIZE = 30000000;
int64_t unk_index_;
@@ -79,7 +90,4 @@ Vocab _build_vocab_from_text_file(const std::string &file_path,
const int64_t min_freq,
const int64_t num_cpus,
torch::jit::script::Module tokenizer);
Vocab _build_vocab_from_text_file_using_python_tokenizer(
const std::string &file_path, const int64_t min_freq, py::object tokenizer);

} // namespace torchtext
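
For context, the Vocab methods bound earlier map onto the public Python API roughly as follows — a sketch against the vocab() factory of this era (exact signatures may have differed):

from collections import Counter, OrderedDict
from torchtext.vocab import vocab as make_vocab

counts = Counter(["hello", "world", "hello"])
ordered = OrderedDict(sorted(counts.items(), key=lambda kv: (-kv[1], kv[0])))
v = make_vocab(ordered, min_freq=1)
v.insert_token("<unk>", 0)                 # Vocab::insert_token
v.set_default_index(v["<unk>"])            # Vocab::set_default_index
print(v["hello"])                          # Vocab::__getitem__
print(v.lookup_indices(["hello", "oov"]))  # the lookup_indices lambda above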
