Commit
Merge pull request #895 from pytorch/master
seemethere committed Jul 20, 2020
2 parents fb6fbb7 + 5bdc40e commit c851c3e
Showing 5 changed files with 155 additions and 122 deletions.
1 change: 1 addition & 0 deletions test/asset/vocab_test.txt
@@ -1,3 +1,4 @@
<new_unk>
a
b
c
104 changes: 35 additions & 69 deletions test/experimental/test_vocab.py
@@ -18,78 +18,78 @@ def tearDown(self):
torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()

def test_has_unk(self):
c = OrderedDict({})
c = OrderedDict()
v = Vocab(c)

# check if unk is mapped to the first index
self.assertEqual(v['not_in_it'], 0)
self.assertEqual(v['<unk>'], 0)

def test_new_unk(self):
c = OrderedDict({})
v = Vocab(c, specials=('<new_unk>',), unk_token="<new_unk>")
c = OrderedDict()
v = Vocab(c, unk_token="<new_unk>")

# check if new_unk is mapped to the first index
self.assertEqual(v['<new_unk>'], 0)
self.assertEqual(v['not_in_it'], 0)

def test_vocab_get_item(self):
token_to_freq = {'a': 2, 'b': 2}
token_to_freq = {'<unk>': 2, 'a': 2, 'b': 2}
sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
c = OrderedDict(sorted_by_freq_tuples)
v = Vocab(c)
v = Vocab(c, min_freq=2)

self.assertEqual(v['<unk>'], 0)
self.assertEqual(v['<pad>'], 1)
self.assertEqual(v['a'], 2)
self.assertEqual(v['b'], 3)
self.assertEqual(v['a'], 1)
self.assertEqual(v['b'], 2)

def test_vocab_set_item(self):
c = OrderedDict({'a': 2})
def test_vocab_insert_token(self):
c = OrderedDict({'<unk>': 2, 'a': 2})

# add item to end
v = Vocab(c)
v.insert_token('b', 3)
v.insert_token('b', 2)

self.assertEqual(v['<unk>'], 0)
self.assertEqual(v['<pad>'], 1)
self.assertEqual(v['a'], 2)
self.assertEqual(v['b'], 3)
expected_itos = ['<unk>', 'a', 'b']
expected_stoi = {x: index for index, x in enumerate(expected_itos)}

self.assertEqual(v.get_itos(), expected_itos)
self.assertEqual(dict(v.get_stoi()), expected_stoi)

# add item to middle
v = Vocab(c, specials_first=False)
v = Vocab(c)
v.insert_token('b', 0)

self.assertEqual(v['b'], 0)
self.assertEqual(v['a'], 1)
self.assertEqual(v['<unk>'], 2)
self.assertEqual(v['<pad>'], 3)
expected_itos = ['b', '<unk>', 'a']
expected_stoi = {x: index for index, x in enumerate(expected_itos)}

self.assertEqual(v.get_itos(), expected_itos)
self.assertEqual(dict(v.get_stoi()), expected_stoi)

def test_vocab_append_token(self):
c = OrderedDict({'a': 2})
v = Vocab(c)
v.append_token('b')

self.assertEqual(len(v), 4)
self.assertEqual(v['b'], 3)
self.assertEqual(len(v), 3)
self.assertEqual(v['b'], 2)

def test_vocab_len(self):
token_to_freq = {'a': 2, 'b': 2, 'c': 2}
sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
c = OrderedDict(sorted_by_freq_tuples)
v = Vocab(c)

self.assertEqual(len(v), 5)
self.assertEqual(len(v), 4)

def test_vocab_basic(self):
token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)

c = OrderedDict(sorted_by_freq_tuples)
v = Vocab(c, min_freq=3, specials=['<unk>', '<pad>', '<bos>'])
v = Vocab(c, min_freq=3)

expected_itos = ['<unk>', '<pad>', '<bos>',
'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
expected_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world', '<unk>']
expected_stoi = {x: index for index, x in enumerate(expected_itos)}

self.assertEqual(v.get_itos(), expected_itos)
@@ -100,50 +100,28 @@ def test_vocab_jit(self):
sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)

c = OrderedDict(sorted_by_freq_tuples)
v = Vocab(c, min_freq=3, specials=['<unk>', '<pad>', '<bos>'])
v = Vocab(c, min_freq=3)
jit_v = torch.jit.script(v)

expected_itos = ['<unk>', '<pad>', '<bos>',
'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
expected_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world', '<unk>']
expected_stoi = {x: index for index, x in enumerate(expected_itos)}

self.assertEqual(jit_v.get_itos(), expected_itos)
self.assertEqual(dict(jit_v.get_stoi()), expected_stoi)

def test_vocab_specials_order(self):
token_to_freq = {'a': 2, 'b': 2, 'c': 2}
sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
c = OrderedDict(sorted_by_freq_tuples)

# add specials into vocabulary at first
v = Vocab(c, specials=['<pad>', '<unk>'])
expected_itos = ['<pad>', '<unk>', 'a', 'b', 'c']
expected_stoi = {x: index for index, x in enumerate(expected_itos)}

self.assertEqual(v.get_itos(), expected_itos)
self.assertEqual(dict(v.get_stoi()), expected_stoi)

# add specials into vocabulary at last
v = Vocab(c, specials=['<pad>', '<unk>'], specials_first=False)
expected_itos = ['a', 'b', 'c', '<pad>', '<unk>']
expected_stoi = {x: index for index, x in enumerate(expected_itos)}

self.assertEqual(v.get_itos(), expected_itos)
self.assertEqual(dict(v.get_stoi()), expected_stoi)

def test_vocab_lookup_token(self):
token_to_freq = {'a': 2, 'b': 2, 'c': 2}
sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
c = OrderedDict(sorted_by_freq_tuples)
v = Vocab(c, specials_first=False)
v = Vocab(c)

self.assertEqual(v.lookup_token(0), 'a')

def test_vocab_lookup_tokens(self):
token_to_freq = {'a': 2, 'b': 2, 'c': 2}
sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
c = OrderedDict(sorted_by_freq_tuples)
v = Vocab(c, specials_first=False)
v = Vocab(c)

indices = [1, 0, 2]
expected_tokens = ['b', 'a', 'c']
@@ -154,7 +132,7 @@ def test_vocab_lookup_indices(self):
token_to_freq = {'a': 2, 'b': 2, 'c': 2}
sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
c = OrderedDict(sorted_by_freq_tuples)
v = Vocab(c, specials_first=False)
v = Vocab(c)

tokens = ['b', 'a', 'c']
expected_indices = [1, 0, 2]
@@ -168,18 +146,7 @@ def test_errors(self):

with self.assertRaises(ValueError):
# Test proper error raised when setting unk token to None
Vocab(c, specials=['<unk>', '<bos>'], unk_token=None)

with self.assertRaises(ValueError):
# Test proper error raised when specials token doesn't contain unk_token
Vocab(c, specials=['<pad>', '<bos>'])

with self.assertRaises(ValueError):
# Test proper error raised when ordered_dict contains a special token
updated_token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2, '<pad>': 1}
updated_sorted_by_freq_tuples = sorted(updated_token_to_freq.items(), key=lambda x: x[1], reverse=True)
updated_c = OrderedDict(updated_sorted_by_freq_tuples)
Vocab(updated_c, specials=['<unk>', '<pad>', '<bos>'])
Vocab(c, unk_token=None)

with self.assertRaises(RuntimeError):
# Test proper error raised when setting a token out of bounds
@@ -198,8 +165,7 @@ def test_vocab_load_and_save(self):
c = OrderedDict(sorted_by_freq_tuples)
v = Vocab(c, min_freq=3)

expected_itos = ['<unk>', '<pad>',
'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
expected_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world', '<unk>']
expected_stoi = {x: index for index, x in enumerate(expected_itos)}

self.assertEqual(v.get_itos(), expected_itos)
@@ -216,9 +182,9 @@ def test_vocab_from_file(self):
asset_name = 'vocab_test.txt'
asset_path = get_asset_path(asset_name)
f = open(asset_path, 'r')
v = vocab_from_file_object(f, specials=('<unk>', '<pad>', '<eos>'), specials_first=False)
v = vocab_from_file_object(f, unk_token='<new_unk>')

expected_itos = ['a', 'b', 'c', '<unk>', '<pad>', '<eos>']
expected_itos = ['<new_unk>', 'a', 'b', 'c']
expected_stoi = {x: index for index, x in enumerate(expected_itos)}

self.assertEqual(v.get_itos(), expected_itos)
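
The updated tests above reflect the removal of the specials and specials_first arguments from the experimental Vocab: the ordered dict alone determines the token order, and the unk_token (default '<unk>') is appended automatically when it is not already present in the dict. A minimal usage sketch of that behaviour, assuming the class is importable as torchtext.experimental.vocab.Vocab (the import path is not shown in this diff):

# Sketch only; assumes the experimental Vocab is importable from torchtext.experimental.vocab.
from collections import OrderedDict
from torchtext.experimental.vocab import Vocab

token_to_freq = {'hello': 4, 'world': 3, 'freq_too_low': 2}
sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
ordered_dict = OrderedDict(sorted_by_freq_tuples)

# Tokens below min_freq are dropped; the unk token is appended when it is
# not already present in the ordered dict.
v = Vocab(ordered_dict, min_freq=3)
print(v.get_itos())        # ['hello', 'world', '<unk>'] for this input
print(v['world'])          # 1
print(v['not_in_vocab'])   # 2, out-of-vocabulary tokens map to the unk index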
62 changes: 49 additions & 13 deletions torchtext/csrc/vectors.cpp
@@ -7,8 +7,14 @@ using c10::Dict;
namespace torchtext {
namespace {

typedef std::tuple<std::string, std::vector<int64_t>, std::vector<std::string>,
std::vector<torch::Tensor>>
VectorsStates;

struct Vectors : torch::CustomClassHolder {
public:
const std::string version_str_ = "0.0.1";

Dict<std::string, torch::Tensor> stovec_;
std::vector<std::string> tokens_;
torch::Tensor vectors_;
@@ -70,6 +76,45 @@ struct Vectors : torch::CustomClassHolder {
int64_t __len__() { return stovec_.size(); }
};

VectorsStates _set_vectors_states(const c10::intrusive_ptr<Vectors> &self) {
std::vector<int64_t> integers;
std::vector<std::string> strings = self->tokens_;
std::vector<torch::Tensor> tensors{self->vectors_, self->unk_tensor_};

VectorsStates states =
std::make_tuple(self->version_str_, std::move(integers),
std::move(strings), std::move(tensors));

return states;
}

c10::intrusive_ptr<Vectors> _get_vectors_from_states(VectorsStates states) {
auto state_size = std::tuple_size<decltype(states)>::value;
if (state_size != 4) {
throw std::runtime_error(
"Expected deserialized Vectors to have 4 states but found only " +
std::to_string(state_size) + " states.");
}

auto &version_str = std::get<0>(states);
auto &integers = std::get<1>(states);
auto &strings = std::get<2>(states);
auto &tensors = std::get<3>(states);

// check integers are empty
if (integers.size() != 0) {
throw std::runtime_error("Expected `integers` states to be empty.");
}

if (version_str.compare("0.0.1") >= 0) {
return c10::make_intrusive<Vectors>(
std::move(strings), std::move(tensors[0]), std::move(tensors[1]));
}

throw std::runtime_error(
"Found unexpected version for serialized Vector: " + version_str + ".");
}

// Registers our custom class with torch.
static auto vectors =
torch::class_<Vectors>("torchtext", "Vectors")
@@ -81,21 +126,12 @@ static auto vectors =
.def("__len__", &Vectors::__len__)
.def_pickle(
// __setstate__
[](const c10::intrusive_ptr<Vectors> &self)
-> std::tuple<std::vector<std::string>, torch::Tensor,
torch::Tensor> {
std::tuple<std::vector<std::string>, torch::Tensor, torch::Tensor>
states(self->tokens_, self->vectors_, self->unk_tensor_);
return states;
[](const c10::intrusive_ptr<Vectors> &self) -> VectorsStates {
return _set_vectors_states(self);
},
// __getstate__
[](std::tuple<std::vector<std::string>, torch::Tensor,
torch::Tensor>
states) -> c10::intrusive_ptr<Vectors> {
return c10::make_intrusive<Vectors>(
std::move(std::get<0>(states)),
std::move(std::get<1>(states)),
std::move(std::get<2>(states)));
[](VectorsStates states) -> c10::intrusive_ptr<Vectors> {
return _get_vectors_from_states(states);
});

} // namespace
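
The new def_pickle implementation above serializes Vectors as a versioned 4-tuple of (version string, integers, strings, tensors) instead of the previous raw (tokens, vectors, unk_tensor) tuple, so the loader can detect incompatible formats instead of silently misreading them. A rough Python analogue of the same versioned-state pattern (hypothetical class for illustration only; torchtext implements this in C++):

# Hypothetical sketch of the versioned __getstate__/__setstate__ pattern from vectors.cpp.
# The class and names below are illustrative and are not part of the torchtext API.
import pickle
import torch

class VersionedVectors:
    VERSION = "0.0.1"

    def __init__(self, tokens, vectors, unk_tensor):
        self.tokens = tokens
        self.vectors = vectors
        self.unk_tensor = unk_tensor

    def __getstate__(self):
        # (version, integers, strings, tensors); integers is reserved for future fields.
        return (self.VERSION, [], list(self.tokens), [self.vectors, self.unk_tensor])

    def __setstate__(self, states):
        if len(states) != 4:
            raise RuntimeError("Expected 4 states, found %d." % len(states))
        version, integers, strings, tensors = states
        if integers:
            raise RuntimeError("Expected `integers` states to be empty.")
        if version < "0.0.1":  # lexicographic check, mirroring version_str.compare in the C++ code
            raise RuntimeError("Found unexpected version: " + version + ".")
        self.__init__(strings, tensors[0], tensors[1])

vecs = VersionedVectors(['hello', 'world'], torch.randn(2, 4), torch.zeros(4))
restored = pickle.loads(pickle.dumps(vecs))
print(restored.tokens, restored.vectors.shape)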