-
Notifications
You must be signed in to change notification settings - Fork 3
/
vocab.go
116 lines (99 loc) · 2.45 KB
/
vocab.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
package bert
import (
"bufio"
"os"
"github.com/sunhailin-Leo/triton-service-go/utils"
)
// Provider is an interface implemented by types that can expose a vocab
// as a Dict.
type Provider interface {
// Vocab returns the Dict exposed by the implementation.
Vocab() Dict
}
// ID is used to identify vocab items. Its underlying type is int32; use
// Int64 to obtain the value as an int64.
type ID int32
// Int64 returns the ID widened from its int32 representation to int64.
// The numeric value, including its sign, is unchanged.
func (id ID) Int64() int64 {
	widened := int64(id)
	return widened
}
// Dict is a container for tokens, mapping each token string to its ID.
// NOTE: python uses an OrderedDict; the implications of using an unordered
// Go map here instead are unknown.
type Dict struct {
// tokens maps a token string to the ID assigned to it.
tokens map[string]ID
}
// VocabFromFile reads the newline-delimited file at path into a Dict.
//
// Every line is passed to Add in file order, so IDs are assigned by Add as
// the lines are read.
//
// It returns an empty Dict together with the underlying, unwrapped error when
// the file cannot be opened or when scanning stops before the end of the file,
// for example with bufio.ErrTooLong when a single line exceeds the scanner's
// token size limit. A partially read vocabulary is never returned.
func VocabFromFile(path string) (Dict, error) {
	f, err := os.Open(path)
	if err != nil {
		// TODO wrap w/ stdlib
		return Dict{}, err
	}
	// The file is only read, so an error from Close is deliberately ignored.
	defer func() { _ = f.Close() }()

	voc := Dict{tokens: map[string]ID{}}
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		voc.Add(scanner.Text())
	}
	// Scan reports false both at EOF and on failure; only Err tells the two
	// apart. Without this check a read error would yield a silently
	// truncated vocabulary.
	if err := scanner.Err(); err != nil {
		return Dict{}, err
	}
	return voc, nil
}
// VocabFromSlice builds a Dict from an in-memory list of vocab entries, such
// as one taken from configuration.
//
// Entries are passed to Add in slice order. An empty or nil slice yields an
// empty Dict and utils.ErrEmptyVocab.
func VocabFromSlice(vocabArr []string) (Dict, error) {
	if len(vocabArr) < 1 {
		return Dict{}, utils.ErrEmptyVocab
	}

	dict := Dict{tokens: make(map[string]ID, len(vocabArr))}
	for _, entry := range vocabArr {
		dict.Add(entry)
	}
	return dict, nil
}
// New builds a vocab Dict from tokens, giving each token the ID equal to its
// index in the slice. When a token occurs more than once, the index of its
// last occurrence is the one that is kept.
func New(tokens []string) Dict {
	table := make(map[string]ID, len(tokens))
	for idx, tok := range tokens {
		table[tok] = ID(idx)
	}
	return Dict{tokens: table}
}
// Add stores token in the vocabulary with an ID equal to the vocabulary's
// size at the time of the call. It is not thread-safe.
//
// If token is already present its ID is overwritten with the current size and
// the size itself does not change.
func (v Dict) Add(token string) {
	next := ID(len(v.tokens))
	v.tokens[token] = next
}
// GetID looks up token in the vocab and returns its ID. A token that is not
// in the vocab yields ID(-1).
func (v Dict) GetID(token string) ID {
	if found, present := v.tokens[token]; present {
		return found
	}
	return ID(-1)
}
// Size reports how many distinct tokens the vocabulary currently holds.
func (v Dict) Size() int {
	count := len(v.tokens)
	return count
}
// LongestSubstring returns the longest prefix of token that is itself an
// entry in the vocab, or "" when no non-empty prefix is. Only substrings
// starting at the first byte of token are considered.
func (v Dict) LongestSubstring(token string) string {
	// Greedy scan from the full token down to its first byte; a trie would
	// be the next step if this ever shows up as a hot spot.
	for end := len(token); end >= 1; end-- {
		candidate := token[:end]
		if _, known := v.tokens[candidate]; known {
			return candidate
		}
	}
	return ""
}
// ConvertItems maps each item to its ID and returns the IDs in the same
// order as items. The result always has len(items) elements.
//
// An item that is not in the vocab is reported as the zero ID (0), not the
// ID(-1) that GetID returns for unknown tokens.
func (v Dict) ConvertItems(items []string) []ID {
	out := make([]ID, 0, len(items))
	for _, item := range items {
		out = append(out, v.tokens[item])
	}
	return out
}
// ConvertTokens converts tokens to their IDs. It delegates to ConvertItems
// and returns that result unchanged.
func (v Dict) ConvertTokens(tokens []string) []ID {
	ids := v.ConvertItems(tokens)
	return ids
}
// IsInVocab reports whether token is an entry in the vocab.
func (v Dict) IsInVocab(token string) bool {
	_, known := v.tokens[token]
	return known
}