This repository has been archived by the owner on Jan 19, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
vocabulary.h
132 lines (102 loc) · 5.19 KB
/
vocabulary.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#pragma once

#include <cstddef>
#include <cstdint>
#include <iosfwd>
#include <limits>
#include <list>
#include <memory>
#include <string>
#include <vector>

#include "pool.h"
namespace yzw2v {
namespace vocab {
// Upper bound associated with token length.
// NOTE(review): Token stores its length in a uint8_t (max 255), so 256 cannot be
// a representable Token length; presumably this is a buffer size (e.g. 255 chars
// plus a terminator) -- confirm against the tokenizer.
// Spelled std::size_t: <cstddef> guarantees the std:: name, whereas the global
// ::size_t is not guaranteed by any header this file includes directly.
static constexpr std::size_t MAX_TOKEN_LENGTH = 256;
// Sentinel id equal to the largest uint32_t value; presumably returned/used to
// mean "token not present" -- confirm in the implementation.
static constexpr uint32_t INVALID_TOKEN_ID = std::numeric_limits<uint32_t>::max();
// A non-owning view of a token's characters: a pointer to the first byte plus
// an 8-bit length. The (defaulted) destructor does not release `begin_`, so the
// pointed-to bytes must outlive every Token that refers to them.
class Token {
public:
    // Empty token: begin_ == nullptr, length_ == 0 (see member initializers).
    // Still declared explicitly because the other constructors suppress the
    // implicit default constructor.
    Token() noexcept = default;

    // Copy/move construction, assignment and destruction are left implicit
    // (Rule of Zero): the members are a pointer and an integer, so the
    // compiler-generated versions are trivial and noexcept.

    // NOTE(review): presumably `begin` must be NUL-terminated -- confirm in the
    // implementation.
    explicit Token(const char* const begin) noexcept;
    // Token covering the byte range [begin, end).
    // NOTE(review): the length is stored in uint8_t (max 255); confirm that
    // longer ranges cannot reach this constructor.
    Token(const char* const begin, const char* const end) noexcept;
    // Token of `length` bytes starting at `begin`.
    Token(const char* const begin, const uint8_t length) noexcept;

    bool operator ==(const Token& other) const noexcept;
    bool operator !=(const Token& other) const noexcept;
    bool operator <(const Token& other) const noexcept;
    bool operator >(const Token& other) const noexcept;
    bool operator <=(const Token& other) const noexcept;
    bool operator >=(const Token& other) const noexcept;

    const char* cbegin() const noexcept;
    const char* cend() const noexcept;
    const char* begin() const noexcept;
    const char* end() const noexcept;
    uint8_t length() const noexcept;

private:
    const char* begin_{nullptr};
    uint8_t length_{0};
};
static const yzw2v::vocab::Token PARAGRAPH_TOKEN{"</s>"};
static const uint32_t PARAGRAPH_TOKEN_ID = 0;
// A vocabulary entry: a Token paired with a uint32_t counter (presumably the
// token's occurrence count -- confirm in the implementation).
struct TokenInfo {
    Token token;
    // Initialized to 0 so that a default-initialized entry (e.g. `TokenInfo info;`)
    // does not carry an indeterminate count; the defaulted constructor below
    // would otherwise leave this member uninitialized.
    uint32_t count{0};

    // Explicitly declared because the two-argument constructor suppresses the
    // implicit default constructor. Copy/move/destruction stay implicit (Rule of
    // Zero); they are noexcept because Token's and uint32_t's are.
    TokenInfo() noexcept = default;
    TokenInfo(const Token& token_, const uint32_t count_) noexcept;
};
class Vocabulary {
public:
using const_iterator = std::vector<TokenInfo>::const_iterator;
using const_reverse_iterator = std::vector<TokenInfo>::const_reverse_iterator;
explicit Vocabulary(const uint32_t max_number_of_tokens);
uint32_t Add(const Token& token);
bool Has(const Token& token) const noexcept;
uint32_t ID(const Token& token) const noexcept;
bool Has(const uint32_t id) const noexcept;
const TokenInfo& Token(const uint32_t id) const noexcept;
uint32_t Count(const uint32_t id) const noexcept;
uint32_t size() const noexcept;
float LoadFactor() const noexcept;
uint64_t TextWordCount() const noexcept;
void Sort() noexcept;
const_iterator cbegin() noexcept;
const_iterator cend() noexcept;
const_reverse_iterator crbegin() noexcept;
const_reverse_iterator crend() noexcept;
private:
uint32_t max_number_of_tokens_;
uint32_t hash_table_size_;
mem::Pool pool_;
std::vector<uint32_t> hash_;
std::vector<TokenInfo> tokens_;
public:
static void WriteTSVWithFilter(const Vocabulary& vocab, const std::string& path,
const uint32_t min_token_freq);
static void WriteBinaryWithFilter(const Vocabulary& vocab, const std::string& path,
const uint32_t min_token_freq);
static void ReadBinaryWithFilter(const std::string& path, const uint32_t min_token_freq,
Vocabulary& vocab);
};
// Free-function construction and (de)serialization helpers for Vocabulary.
// NOTE(review): the group summaries below come from names and signatures only;
// confirm behavior against the implementation. Overloads taking a `Vocabulary&`
// write into a caller-supplied object; the others return a new one.
// Top-level `const` on by-value parameters is omitted here: it is not part of
// the function type, so these declarations match the original definitions.

// Building a vocabulary from the file at `path`.
void CollectIntoVocabulary(const std::string& path, uint32_t min_token_freq,
                           Vocabulary& vocab);
Vocabulary CollectVocabulary(const std::string& path, uint32_t min_token_freq,
                             uint32_t max_number_of_tokens);

// Writing `vocab` to `path`; the *WithFilter variants take a min_token_freq.
void WriteTSV(const Vocabulary& vocab, const std::string& path);
void WriteTSVWithFilter(const Vocabulary& vocab, const std::string& path,
                        uint32_t min_token_freq);
void WriteBinary(const Vocabulary& vocab, const std::string& path);
void WriteBinaryWithFilter(const Vocabulary& vocab, const std::string& path,
                           uint32_t min_token_freq);

// Reading a binary vocabulary from `path`.
Vocabulary ReadBinary(const std::string& path);
void ReadBinary(const std::string& path, Vocabulary& vocab);
Vocabulary ReadBinaryWithFilter(const std::string& path, uint32_t min_token_freq);
void ReadBinaryWithFilter(const std::string& path, uint32_t min_token_freq,
                          Vocabulary& vocab);
} // namespace vocab
} // namespace yzw2v
std::ostream& operator<<(std::ostream& out, const yzw2v::vocab::Token& token);