-
Notifications
You must be signed in to change notification settings - Fork 0
/
shard_tracker.go
165 lines (136 loc) · 4.23 KB
/
shard_tracker.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
package amp
import (
"crypto/rand"
"encoding/binary"
"fmt"
"sync"
"github.com/brolightningnetwork/broln/lntypes"
"github.com/brolightningnetwork/broln/lnwire"
"github.com/brolightningnetwork/broln/record"
"github.com/brolightningnetwork/broln/routing/shards"
)
// Shard is an implementation of the shards.PaymentShard interface specific
// to AMP payments. ShardTracker.NewShard constructs Shard values.
type Shard struct {
	// child is the AMP child derived for this shard; its Hash is the
	// hash reported by Hash.
	child *Child

	// mpp is the MPP record reported by MPP.
	mpp *record.MPP

	// amp is the AMP record reported by AMP.
	amp *record.AMP
}

// A compile time check to ensure Shard implements the shards.PaymentShard
// interface.
var _ shards.PaymentShard = (*Shard)(nil)
// Hash reports the hash of the HTLC that carries this AMP shard, which is the
// hash of the child the shard was derived from.
func (s *Shard) Hash() lntypes.Hash {
	derived := s.child
	return derived.Hash
}
// MPP reports the MPP record, if any, that the final hop of this shard's
// route should carry.
func (s *Shard) MPP() *record.MPP {
	mppRecord := s.mpp
	return mppRecord
}
// AMP reports the AMP record, if any, that the final hop of this shard's
// route should carry.
func (s *Shard) AMP() *record.AMP {
	ampRecord := s.amp
	return ampRecord
}
// ShardTracker is an implementation of the shards.ShardTracker interface
// that is able to generate payment shards according to the AMP splitting
// algorithm. It can be used to generate new hashes to use for HTLCs, and also
// cancel shares used for failed payment shards.
//
// NewShard, CancelShard and GetHash all hold the embedded mutex while they
// access sharer and shards.
type ShardTracker struct {
	// setID is placed into the AMP record of every shard created.
	setID [32]byte

	// paymentAddr is placed into the MPP record of every shard created.
	paymentAddr [32]byte

	// totalAmt is the total amount placed into the MPP record of every
	// shard created.
	totalAmt lnwire.MilliSatoshi

	// sharer holds the share that has not yet been handed out. NewShard
	// splits it (or zeroes it for the last shard), and CancelShard merges
	// canceled children back into it.
	sharer Sharer

	// shards maps each attempt ID to the child derived for that attempt.
	shards map[uint64]*Child

	// Mutex serializes access to sharer and shards. It is embedded, so
	// Lock and Unlock are promoted onto ShardTracker.
	sync.Mutex
}

// A compile time check to ensure ShardTracker implements the
// shards.ShardTracker interface.
var _ shards.ShardTracker = (*ShardTracker)(nil)
// NewShardTracker builds a ShardTracker for a single AMP payment. The
// tracker's sharer is seeded from root. setID ends up in every shard's AMP
// record, and payAddr and totalAmt end up in every shard's MPP record, so all
// of these must be correct for the per-shard TLV records to be correct.
func NewShardTracker(root, setID, payAddr [32]byte,
	totalAmt lnwire.MilliSatoshi) *ShardTracker {

	// Seed the sharer from the root share before building the tracker.
	seed := Share(root)

	return &ShardTracker{
		setID:       setID,
		paymentAddr: payAddr,
		totalAmt:    totalAmt,
		sharer:      SeedSharerFromRoot(&seed),
		shards:      make(map[uint64]*Child),
	}
}
// NewShard registers a new attempt with the ShardTracker and returns a
// new shard representing this attempt. This attempt's shard should be canceled
// if it ends up not being used by the overall payment, i.e. if the attempt
// fails.
//
// When last is false, the tracker's current share is split into two. The new
// shard is derived from one part, and the other part becomes the tracker's
// current share for future shards. When last is true, the shard is derived
// directly from the current share, which is then replaced by the zero share so
// it cannot be split further. In both cases the child index is chosen at
// random.
//
// The derived child is recorded under pid, overwriting any entry already
// stored for that pid. The returned shard carries an MPP record built from the
// tracker's total amount and payment address, and an AMP record built from the
// child's share, the tracker's set ID and the child's index.
func (s *ShardTracker) NewShard(pid uint64, last bool) (shards.PaymentShard,
	error) {

	s.Lock()
	defer s.Unlock()

	// Use a random child index.
	var childIndex [4]byte
	if _, err := rand.Read(childIndex[:]); err != nil {
		return nil, fmt.Errorf("unable to generate AMP child index: %w",
			err)
	}
	idx := binary.BigEndian.Uint32(childIndex[:])

	// Depending on whether we are requesting the last shard or not, either
	// split the current share into two, or get a Child directly from the
	// current sharer.
	var child *Child
	if last {
		child = s.sharer.Child(idx)

		// If this was the last shard, set the current share to the
		// zero share to indicate we cannot split it further.
		s.sharer = s.sharer.Zero()
	} else {
		left, sharer, err := s.sharer.Split()
		if err != nil {
			return nil, fmt.Errorf("unable to split AMP share: %w", err)
		}

		s.sharer = sharer
		child = left.Child(idx)
	}

	// Track the new child and return the shard.
	s.shards[pid] = child

	mpp := record.NewMPP(s.totalAmt, s.paymentAddr)
	amp := record.NewAMP(
		child.ChildDesc.Share, s.setID, child.ChildDesc.Index,
	)

	return &Shard{
		child: child,
		mpp:   mpp,
		amp:   amp,
	}, nil
}
// CancelShard cancels the shard corresponding to the given attempt ID. The
// attempt is removed from the tracker, and its child's share is merged back
// into the tracker's current share, from which later shards are derived. An
// error is returned if no shard is registered for pid.
func (s *ShardTracker) CancelShard(pid uint64) error {
	s.Lock()
	defer s.Unlock()

	c, ok := s.shards[pid]
	if !ok {
		return fmt.Errorf("AMP shard for attempt %v not found", pid)
	}
	delete(s.shards, pid)

	// Now that we are canceling this shard, we XOR the share back into our
	// current share.
	s.sharer = s.sharer.Merge(c)

	return nil
}
// GetHash looks up the hash of the shard registered for attempt pid. It fails
// if the tracker has no shard recorded under that attempt ID.
func (s *ShardTracker) GetHash(pid uint64) (lntypes.Hash, error) {
	s.Lock()
	defer s.Unlock()

	if child, found := s.shards[pid]; found {
		return child.Hash, nil
	}

	return lntypes.Hash{}, fmt.Errorf("AMP shard for attempt %v "+
		"not found", pid)
}