/
peer_storage.go
129 lines (105 loc) · 2.29 KB
/
peer_storage.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
package mtpwrap
import (
"context"
"errors"
"sort"
"sync"
"github.com/gotd/contrib/storage"
)
// MemStorage is the default peer storage for MTP. It keeps every peer in an
// in-memory map, so it is not a persistent store.
//
// MemStorage is also its own storage.PeerIterator: Iterate returns the
// storage itself, and the fields after the mutex hold that iteration's
// cursor. Consequently only one iteration can be active at a time.
type MemStorage struct {
	s  map[string]storage.Peer // peers indexed by their string key
	mu sync.RWMutex            // guards s; read-held by an active iteration until Close

	// Iteration state, initialised by Iterate and ended by Close.
	iterating bool     // true between a successful Iterate and the matching Close
	keys      []string // sorted snapshot of the keys of s taken by Iterate
	keyIdx    int      // cursor into keys; -1 until the first Next
	iterErr   error    // reason the iteration stopped early, reported by Err
}
// NewMemStorage returns an empty MemStorage ready for use.
//
// Use this constructor rather than a zero-value MemStorage: the zero value has
// a nil map, and Add/Assign would panic writing to it.
func NewMemStorage() *MemStorage {
	ms := new(MemStorage)
	ms.s = make(map[string]storage.Peer)
	return ms
}
// Add stores value under the string form of storage.KeyFromPeer(value),
// replacing any peer already stored under that key. It always returns nil.
func (ms *MemStorage) Add(ctx context.Context, value storage.Peer) error {
	// Storing under a derived key is exactly Assign with that key.
	derived := storage.KeyFromPeer(value).String()
	return ms.Assign(ctx, derived, value)
}
// Find looks a peer up by its structured key. It is equivalent to calling
// Resolve with key.String(), and likewise returns storage.ErrPeerNotFound
// when nothing is stored under that key.
func (ms *MemStorage) Find(ctx context.Context, key storage.PeerKey) (storage.Peer, error) {
	flat := key.String()
	return ms.Resolve(ctx, flat)
}
// Assign stores value under key, overwriting any existing entry. Unlike Add,
// the key is used verbatim instead of being derived from the peer. It always
// returns nil.
func (ms *MemStorage) Assign(_ context.Context, key string, value storage.Peer) error {
	ms.mu.Lock()
	defer ms.mu.Unlock()

	ms.s[key] = value

	return nil
}
// Resolve returns the peer stored under key, or a zero storage.Peer together
// with storage.ErrPeerNotFound when there is none.
//
// NOTE: Resolve takes the read lock. Between Iterate and Close that lock is
// already read-held, so calling Resolve there is a recursive read lock, which
// can deadlock if another goroutine is waiting in Add/Assign.
func (ms *MemStorage) Resolve(_ context.Context, key string) (storage.Peer, error) {
	ms.mu.RLock()
	defer ms.mu.RUnlock()

	if found, ok := ms.s[key]; ok {
		return found, nil
	}
	return storage.Peer{}, storage.ErrPeerNotFound
}
// Iterate returns an iterator over all stored peers in ascending (lexical)
// order of their string keys. The iterator is the storage itself, so only one
// iteration may be active at a time; calling Iterate while one is in progress
// fails with an "already iterating" error. It returns ctx.Err() if ctx is
// already done.
//
// The key set is snapshotted when Iterate is called, and a read lock is then
// held on the storage until Close.
//
// NOTE: because that read lock is held for the whole iteration, calling Add or
// Assign before Close deadlocks. Calling Resolve or Find before Close can
// deadlock if another goroutine is blocked in Add/Assign, since
// sync.RWMutex forbids recursive read locking. Always Close the iterator.
func (ms *MemStorage) Iterate(ctx context.Context) (storage.PeerIterator, error) {
	if err := ctx.Err(); err != nil {
		return nil, err
	}
	// Fast path: an active iteration holds the read lock until Close, so
	// report it without blocking (or self-deadlocking) on the write lock.
	if ms.IsIterating() {
		return nil, errors.New("already iterating")
	}

	ms.mu.Lock()
	// Re-check under the write lock: another Iterate may have started since
	// the unlocked check above, and its snapshot and cursor must not be
	// overwritten.
	if ms.iterating {
		ms.mu.Unlock()
		return nil, errors.New("already iterating")
	}
	snapshot := make([]string, 0, len(ms.s))
	for k := range ms.s {
		snapshot = append(snapshot, k)
	}
	sort.Strings(snapshot)
	ms.keys = snapshot
	ms.keyIdx = -1 // Next pre-increments, so the first call lands on index 0
	ms.iterErr = nil
	ms.iterating = true
	ms.mu.Unlock()

	// sync.RWMutex cannot downgrade a write lock to a read lock, so a writer
	// may run in between; at worst it adds or replaces a peer after the
	// snapshot was taken. Value reads the map under this read lock, which is
	// released by Close.
	ms.mu.RLock()
	return ms, nil
}
// Next advances the iterator and reports whether a peer is available through
// Value.
//
// It returns false when:
//   - ctx is done (Err then returns ctx.Err()),
//   - no iteration is in progress, or
//   - the key snapshot is exhausted.
func (ms *MemStorage) Next(ctx context.Context) bool {
	if err := ctx.Err(); err != nil {
		ms.iterErr = err
		return false
	}
	// Outside an active iteration ms.keys may be a stale snapshot from an
	// earlier, already closed iteration; never report it as available.
	if !ms.IsIterating() {
		return false
	}
	// Stop at len(keys) so repeated calls after exhaustion don't keep counting.
	if ms.keyIdx < len(ms.keys) {
		ms.keyIdx++
	}
	return ms.keyIdx < len(ms.keys)
}
// Err returns the error that stopped the current iteration early, or nil.
// Next records ctx.Err() here when its context is done, and Iterate resets it.
func (ms *MemStorage) Err() (err error) {
	err = ms.iterErr
	return err
}
// Value returns the peer at the iterator's current position.
//
// It returns the zero storage.Peer when no iteration is in progress, or when
// the cursor is outside the snapshot. That happens before the first call to
// Next and after Next has returned false; previously both cases panicked with
// an index out of range.
func (ms *MemStorage) Value() storage.Peer {
	if !ms.IsIterating() {
		return storage.Peer{}
	}
	if ms.keyIdx < 0 || ms.keyIdx >= len(ms.keys) {
		return storage.Peer{}
	}
	// The read lock taken by Iterate is still held here, so reading s is safe.
	return ms.s[ms.keys[ms.keyIdx]]
}
// Close ends the current iteration. It releases the read lock taken by
// Iterate and drops the key snapshot. Calling Close when no iteration is in
// progress is a no-op. It always returns nil.
func (ms *MemStorage) Close() error {
	if !ms.IsIterating() {
		return nil
	}
	ms.mu.RUnlock()

	ms.mu.Lock()
	defer ms.mu.Unlock()
	ms.iterating = false
	// Release the snapshot instead of keeping every key alive until the next
	// Iterate. It also ensures a stray Next after Close finds nothing to
	// advance over.
	ms.keys = nil
	ms.keyIdx = -1
	return nil
}
// IsIterating reports whether an iteration started by Iterate has not yet
// been ended by Close.
//
// The flag is read without taking ms.mu, because Value and Close call this
// while the read lock from Iterate is held. An answer given to a goroutine
// that does not own the active iteration may therefore already be stale.
func (ms *MemStorage) IsIterating() (active bool) {
	active = ms.iterating
	return active
}