// Source: getput.go, from a fork of anacrolix/dht.

package getput
import (
"context"
"crypto/sha1"
"errors"
"math"
"sync"
"github.com/anacrolix/log"
"github.com/anacrolix/torrent/bencode"
"github.com/anacrolix/dht/v2"
"github.com/anacrolix/dht/v2/bep44"
k_nearest_nodes "github.com/anacrolix/dht/v2/k-nearest-nodes"
"github.com/anacrolix/dht/v2/krpc"
"github.com/anacrolix/dht/v2/traversal"
)
// GetResult is a single item produced by a get traversal.
type GetResult struct {
// Seq is the sequence number reported for a mutable item. It is left at its zero value when
// an immutable item is delivered by the traversal.
Seq int64
// V is the bencoded value exactly as it was received in the reply.
V bencode.Bytes
// Mutable is true when the item was matched against the hash of the public key and salt and
// passed signature verification, and false when it was matched by the hash of V alone.
Mutable bool
}
// startGetTraversal starts a traversal toward target in which every visited node is sent a get
// query through s. seq is forwarded unchanged to each get query, and salt is used when matching
// mutable items against target.
//
// Items that validate against target are sent on vChan, which is unbuffered and never closed;
// a send is abandoned when the per-query context is done. An item is treated as immutable when
// the SHA-1 of its value equals target, and as mutable when the SHA-1 of the public key followed
// by salt equals target and the signature verifies.
//
// If the starting nodes cannot be obtained, the operation is stopped before the error is
// returned, so callers that bail out on err do not leave a traversal running.
func startGetTraversal(
	target bep44.Target, s *dht.Server, seq *int64, salt []byte,
) (
	vChan chan GetResult, op *traversal.Operation, err error,
) {
	vChan = make(chan GetResult)
	op = traversal.Start(traversal.OperationInput{
		Alpha:  15,
		Target: target,
		DoQuery: func(ctx context.Context, addr krpc.NodeAddr) traversal.QueryResult {
			logger := log.ContextLogger(ctx)
			res := s.Get(ctx, dht.NewAddr(addr.UDP()), target, seq, dht.QueryRateLimiting{})
			err := res.ToError()
			if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, dht.TransactionTimeout) {
				logger.Levelf(log.Debug, "error querying %v: %v", addr, err)
			}
			// deliver hands an item to the consumer, giving up if the query is cancelled first.
			deliver := func(item GetResult) {
				select {
				case vChan <- item:
				case <-ctx.Done():
				}
			}
			if r := res.Reply.R; r != nil {
				rv := r.V
				switch {
				case sha1.Sum(rv) == target:
					deliver(GetResult{
						V:       rv,
						Mutable: false,
					})
				// The reply comes from an untrusted peer: a mutable item without a sequence
				// number must be rejected rather than dereferenced.
				case r.Seq != nil &&
					sha1.Sum(append(r.K[:], salt...)) == target &&
					bep44.Verify(r.K[:], salt, *r.Seq, rv, r.Sig[:]):
					deliver(GetResult{
						Seq:     *r.Seq,
						V:       rv,
						Mutable: true,
					})
				case rv != nil:
					logger.Levelf(log.Debug, "get response item hash didn't match target: %q", rv)
				}
			}
			tqr := res.TraversalQueryResult(addr)
			// Filter replies from nodes that don't have a string token. "The token value should
			// be a short binary string." (BEP 5). The assertion result must be checked with the
			// ok flag: assigning the failed assertion's zero string back into the interface
			// would make it non-nil and defeat the filter.
			if _, ok := tqr.ClosestData.(string); !ok {
				tqr.ClosestData = nil
				tqr.ResponseFrom = nil
			}
			return tqr
		},
		NodeFilter: s.TraversalNodeFilter,
	})
	nodes, err := s.TraversalStartingNodes()
	if err != nil {
		// Callers return immediately on error without stopping op, so stop it here.
		op.Stop()
		return
	}
	op.AddNodes(nodes)
	return
}
// Get runs a get traversal for target using s and collects the items it yields.
//
// The first immutable item received ends the traversal and is returned. Mutable items are
// accumulated until the traversal stalls, keeping the one with the highest sequence number (a
// later item with an equal sequence number replaces an earlier one). If the traversal stalls
// before any item arrives, err is "value not found". If ctx is done first, err is ctx.Err() and
// ret holds whatever had been selected so far.
//
// seq and salt are passed through to the traversal. stats is nil when the traversal could not
// be started.
func Get(
	ctx context.Context, target bep44.Target, s *dht.Server, seq *int64, salt []byte,
) (
	ret GetResult, stats *traversal.Stats, err error,
) {
	results, op, err := startGetTraversal(target, s, seq, salt)
	if err != nil {
		return
	}
	// Start below every possible sequence number so the first mutable item always wins.
	ret.Seq = math.MinInt64
	found := false
collect:
	for {
		select {
		case item := <-results:
			log.ContextLogger(ctx).Levelf(log.Debug, "received %#v", item)
			found = true
			if !item.Mutable {
				// An immutable item is fully determined by its hash; nothing better can arrive.
				ret = item
				break collect
			}
			if item.Seq >= ret.Seq {
				ret = item
			}
		case <-op.Stalled():
			if !found {
				err = errors.New("value not found")
			}
			break collect
		case <-ctx.Done():
			err = ctx.Err()
			break collect
		}
	}
	op.Stop()
	stats = op.Stats()
	return
}
// seqToPut builds the item to store given a sequence number. Put calls it once, passing the
// highest mutable sequence number observed during its traversal (or zero if none was seen).
type seqToPut = func(seq int64) bep44.Put
// Put stores an item at the nodes closest to target.
//
// It first runs a get traversal toward target to find the closest nodes and their write tokens,
// tracking the highest sequence number among the mutable items seen. Once the traversal stalls
// it calls seqToPut with that sequence number (zero if none was seen) and sends the resulting
// item to every closest node concurrently, waiting for all sends to finish. Failures of
// individual puts are logged and do not affect the returned error.
//
// If ctx is done before the traversal stalls, Put stops the traversal and returns ctx.Err()
// without building or sending the item. stats is nil only when the traversal could not be
// started.
func Put(
	ctx context.Context, target krpc.ID, s *dht.Server, salt []byte, seqToPut seqToPut,
) (
	stats *traversal.Stats, err error,
) {
	logger := log.ContextLogger(ctx)
	vChan, op, err := startGetTraversal(target, s,
		// When we do a get traversal for a put, we don't care what seq the peers have?
		nil,
		// This is duplicated with the put, but we need it to filter responses for autoSeq.
		salt)
	if err != nil {
		return
	}
	var autoSeq int64
traverse:
	for {
		select {
		case v := <-vChan:
			if v.Mutable && v.Seq > autoSeq {
				autoSeq = v.Seq
			}
			// There are more optimizations that can be done here. We can set CAS automatically,
			// and we can skip updating the sequence number if the existing content already
			// matches (and presumably republish the existing seq).
		case <-op.Stalled():
			break traverse
		case <-ctx.Done():
			err = ctx.Err()
			break traverse
		}
	}
	op.Stop()
	if err != nil {
		// The caller cancelled: don't build the item or start puts on a dead context.
		stats = op.Stats()
		return
	}
	var wg sync.WaitGroup
	put := seqToPut(autoSeq)
	op.Closest().Range(func(elem k_nearest_nodes.Elem) {
		wg.Add(1)
		go func() {
			defer wg.Done()
			// This is enforced by startGetTraversal.
			token := elem.Data.(string)
			res := s.Put(ctx, dht.NewAddr(elem.Addr.UDP()), put, token, dht.QueryRateLimiting{})
			err := res.ToError()
			if err != nil {
				logger.Levelf(log.Warning, "error putting to %v [token=%q]: %v", elem.Addr, token, err)
			} else {
				logger.Levelf(log.Debug, "put to %v [token=%q]", elem.Addr, token)
			}
		}()
	})
	wg.Wait()
	stats = op.Stats()
	return
}