Commit

Merge branch 'experiment/functional' into release/0.5.0

ross-pure committed Mar 24, 2021
2 parents fbad62f + 51df5ec commit cc63fd6
Showing 14 changed files with 184 additions and 204 deletions.
151 changes: 46 additions & 105 deletions channel/channel.go
@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net"
"sync/atomic"
"time"

"github.com/renproject/aw/codec"
@@ -215,49 +216,47 @@ func (ch Channel) Remote() id.Signatory {
}

func (ch *Channel) readLoop(ctx context.Context) error {
buf := make([]byte, ch.opts.MaxMessageSize)
read := func(r reader, drain <-chan struct{}) {
draining := uint64(0)

var r reader
var rOk bool
// If the drain channel is written to, this signals that this reader is
// now expired and we should begin draining it.
go func() {
<-drain

var m wire.Msg
var mOk bool
atomic.StoreUint64(&draining, 1)

syncData := make([]byte, ch.opts.MaxMessageSize)

for {
if !rOk {
select {
case <-ctx.Done():
return ctx.Err()
case r, rOk = <-ch.readers:
ch.opts.Logger.Debug("replaced reader", zap.String("remote", ch.remote.String()), zap.String("addr", r.Conn.RemoteAddr().String()))
// Set the read deadline here, instead of per-message, so that the
// remote peer cannot easily "slow loris" the local peer by
// periodically sending messages into the draining connection.
if err := r.Conn.SetReadDeadline(time.Now().Add(ch.opts.DrainTimeout)); err != nil {
ch.opts.Logger.Error("drain: set deadline", zap.Error(err))
return
}
}
}()

buf := make([]byte, ch.opts.MaxMessageSize)
bufSyncData := make([]byte, ch.opts.MaxMessageSize)

if !mOk {
for {
n, err := r.Decoder(r.Reader, buf[:])
if err != nil {
ch.opts.Logger.Error("decode", zap.Error(err))
// If reading from the reader fails, then clear the reader. This
// will cause the next iteration to wait until a new underlying
// network connection is attached to the Channel.
draining := atomic.LoadUint64(&draining)
ch.opts.Logger.Error("decode", zap.Uint64("draining", draining), zap.Error(err))
close(r.q)
r = reader{}
rOk = false
continue
return
}

// Check that the underlying connection is not exceeding its rate
// limit.
if !ch.rateLimiter.AllowN(time.Now(), n) {
ch.opts.Logger.Error("rate limit exceeded", zap.String("remote", ch.remote.String()), zap.String("addr", r.Conn.RemoteAddr().String()))
close(r.q)
r = reader{}
rOk = false
continue
return
}

m := wire.Msg{}

// Unmarshal the message from binary. If this is successful, then
// we mark the message as available (and will attempt to write it to
// the inbound message channel).
@@ -266,8 +265,6 @@ func (ch *Channel) readLoop(ctx context.Context) error {
continue
}

mOk = true

// An aggressive filtering strategy would involve pre-filtering
// synchronisation messages before reading the synchronisation data.
// However, in practice, this does not provide much of an advantage
@@ -276,44 +273,41 @@ func (ch *Channel) readLoop(ctx context.Context) error {
// rate-limiting, and (b) filtering that happens in the client
// results in bad channels being killed quickly anyway.
if m.Type == wire.MsgTypeSync {
n, err := r.Decoder(r.Reader, syncData)
n, err := r.Decoder(r.Reader, bufSyncData)
if err != nil {
ch.opts.Logger.Error("decode sync data", zap.Error(err))
// If reading from the reader fails, then clear the reader. This
// will cause the next iteration to wait until a new underlying
// network connection is attached to the Channel.
close(r.q)
r = reader{}
rOk = false
mOk = false
continue
return
}
m.SyncData = make([]byte, n)
copy(m.SyncData, syncData[:n])
copy(m.SyncData, bufSyncData[:n])
}

select {
case <-ctx.Done():
if r.q != nil {
close(r.q)
}
return
case ch.inbound <- m:
}
}
}

// At this point, a message is guaranteed to be available, so we attempt
// to write it to the inbound message channel.
drain := make(chan struct{}, 1)
for {
select {
case <-ctx.Done():
if r.q != nil {
close(r.q)
}
return ctx.Err()
case ch.inbound <- m:
// If we succeed, then we clear the message. This will allow us to
// progress and try to read the next message.
m = wire.Msg{}
mOk = false
case v, vOk := <-ch.readers:
// If a new underlying network connection is attached to the
// Channel before we can write the message to the inbound message
// channel, we do not clear the message. This will force us to
// re-attempt writing the message in the next iteration.
ch.drainReader(ctx, r, m, mOk)
r, rOk = v, vOk
m, mOk = wire.Msg{}, false
case r := <-ch.readers:
ch.opts.Logger.Debug("replaced reader", zap.String("remote", ch.remote.String()), zap.String("addr", r.Conn.RemoteAddr().String()))

drain <- struct{}{} // Write to the previous drain channel.
drain = make(chan struct{}, 1) // Create a new drain channel.
go read(r, drain)
}
}
}
@@ -404,56 +398,3 @@ func (ch *Channel) writeLoop(ctx context.Context) {
}
}
}

func (ch *Channel) drainReader(ctx context.Context, r reader, m wire.Msg, mOk bool) {
f := func() {
defer func() {
if r.q != nil {
close(r.q)
}
}()
if mOk {
select {
case <-ctx.Done():
return
case ch.inbound <- m:
}
}

// Set the deadline here, instead of per-message, so that the remote
// peer cannot easily "slow loris" the local peer by periodically sending
// messages into the draining connection.
if err := r.Conn.SetDeadline(time.Now().Add(ch.opts.DrainTimeout)); err != nil {
ch.opts.Logger.Error("drain: set deadline", zap.Error(err))
return
}

buf := make([]byte, ch.opts.MaxMessageSize)
msg := wire.Msg{}
for {
n, err := r.Decoder(r.Reader, buf[:])
if err != nil {
// We do not log this as an error, because it is entirely
// expected when draining.
ch.opts.Logger.Info("drain: decode", zap.Error(err))
return
}
if _, _, err := msg.Unmarshal(buf[:n], len(buf)); err != nil {
ch.opts.Logger.Error("drain: unmarshal", zap.Error(err))
return
}
select {
case <-ctx.Done():
return
case ch.inbound <- msg:
}
}
}
if ch.opts.DrainInBackground {
ch.opts.Logger.Debug("drain: background", zap.String("addr", r.Conn.RemoteAddr().String()))
go f()
return
}
ch.opts.Logger.Debug("drain: foreground", zap.String("addr", r.Conn.RemoteAddr().String()))
f()
}
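
The heart of this diff is the new shape of readLoop: instead of tracking the current reader and a pending message through the r/rOk/m/mOk state variables, each attached reader gets its own read closure plus a drain channel, and the previous closure is told to drain whenever a replacement connection arrives. The following is a minimal standalone sketch of that pattern, not the aw implementation itself: the connection type, buffer size, and message handling are simplified stand-ins.

package main

import (
	"log"
	"net"
	"sync/atomic"
	"time"
)

// readUntilDrained reads raw messages from conn and forwards them to inbound.
// When drain fires, it flips an atomic flag (so errors can report whether the
// reader was already draining) and sets a single read deadline for the whole
// drain, so a stale connection cannot "slow loris" the loop with a trickle of
// messages.
func readUntilDrained(conn net.Conn, inbound chan<- []byte, drain <-chan struct{}, drainTimeout time.Duration) {
	var draining uint64

	go func() {
		<-drain
		atomic.StoreUint64(&draining, 1)
		if err := conn.SetReadDeadline(time.Now().Add(drainTimeout)); err != nil {
			log.Printf("drain: set deadline: %v", err)
		}
	}()

	buf := make([]byte, 1024)
	for {
		n, err := conn.Read(buf)
		if err != nil {
			// Expected once the drain deadline expires or the peer hangs up.
			log.Printf("read (draining=%d): %v", atomic.LoadUint64(&draining), err)
			return
		}
		msg := make([]byte, n)
		copy(msg, buf[:n])
		inbound <- msg
	}
}

func main() {
	local, remote := net.Pipe()
	defer local.Close()
	defer remote.Close()

	inbound := make(chan []byte, 1)
	drain := make(chan struct{})
	go readUntilDrained(local, inbound, drain, 100*time.Millisecond)

	go remote.Write([]byte("hello"))
	log.Printf("received: %s", <-inbound)

	close(drain) // expire the reader; it exits once the drain deadline passes
	time.Sleep(200 * time.Millisecond)
}

In the Channel itself, the case r := <-ch.readers branch above plays the role of main here: it signals the previous reader's drain channel and starts a fresh read goroutine for the replacement, which is why the drainReader helper could be deleted and the tests below no longer pass a DrainInBackground flag.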
56 changes: 9 additions & 47 deletions channel/channel_test.go
@@ -17,12 +17,10 @@ import (

var _ = Describe("Channels", func() {

run := func(ctx context.Context, remote id.Signatory, drainInBg bool) (*channel.Channel, <-chan wire.Msg, chan<- wire.Msg) {
run := func(ctx context.Context, remote id.Signatory) (*channel.Channel, <-chan wire.Msg, chan<- wire.Msg) {
inbound, outbound := make(chan wire.Msg), make(chan wire.Msg)
ch := channel.New(
channel.DefaultOptions().
WithDrainInBackground(drainInBg).
WithDrainTimeout(3000*time.Millisecond),
channel.DefaultOptions().WithDrainTimeout(1500*time.Millisecond),
remote,
inbound,
outbound)
@@ -65,7 +63,6 @@ var _ = Describe("Channels", func() {
max := uint64(0)
received := make(map[uint64]int, n)
for iter := uint64(0); iter < n; iter++ {
time.Sleep(time.Millisecond)
select {
case msg := <-inbound:
data := binary.BigEndian.Uint64(msg.Data)
@@ -101,8 +98,8 @@ var _ = Describe("Channels", func() {

localPrivKey := id.NewPrivKey()
remotePrivKey := id.NewPrivKey()
localCh, localInbound, localOutbound := run(ctx, remotePrivKey.Signatory(), true)
remoteCh, remoteInbound, remoteOutbound := run(ctx, localPrivKey.Signatory(), true)
localCh, localInbound, localOutbound := run(ctx, remotePrivKey.Signatory())
remoteCh, remoteInbound, remoteOutbound := run(ctx, localPrivKey.Signatory())

// Remote channel will listen for incoming connections.
listen(ctx, remoteCh, remotePrivKey.Signatory(), localPrivKey.Signatory(), 3333)
@@ -138,8 +135,8 @@ var _ = Describe("Channels", func() {

localPrivKey := id.NewPrivKey()
remotePrivKey := id.NewPrivKey()
localCh, localInbound, localOutbound := run(ctx, remotePrivKey.Signatory(), true)
remoteCh, remoteInbound, remoteOutbound := run(ctx, localPrivKey.Signatory(), true)
localCh, localInbound, localOutbound := run(ctx, remotePrivKey.Signatory())
remoteCh, remoteInbound, remoteOutbound := run(ctx, localPrivKey.Signatory())

// Number of messages that we will test.
n := uint64(1000)
@@ -176,13 +173,13 @@ var _ = Describe("Channels", func() {

localPrivKey := id.NewPrivKey()
remotePrivKey := id.NewPrivKey()
localCh, localInbound, localOutbound := run(ctx, remotePrivKey.Signatory(), true)
remoteCh, remoteInbound, remoteOutbound := run(ctx, localPrivKey.Signatory(), true)
localCh, localInbound, localOutbound := run(ctx, remotePrivKey.Signatory())
remoteCh, remoteInbound, remoteOutbound := run(ctx, localPrivKey.Signatory())

// Number of messages that we will test. This number is higher than
// in other tests, because we need sending/receiving to take long
// enough that replacements will happen.
n := uint64(3000)
n := uint64(10000)
// Send and receive messages in both directions; from local to
// remote, and from remote to local.
q1 := sink(localOutbound, n)
@@ -203,40 +200,5 @@ var _ = Describe("Channels", func() {
<-q4
})
})

Context("when draining connections in the foreground", func() {
It("should send and receive all messages in order", func() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

localPrivKey := id.NewPrivKey()
remotePrivKey := id.NewPrivKey()
localCh, localInbound, localOutbound := run(ctx, remotePrivKey.Signatory(), false)
remoteCh, remoteInbound, remoteOutbound := run(ctx, localPrivKey.Signatory(), false)

// Number of messages that we will test. This number is higher than
// in other tests, because we need sending/receiving to take long
// enough that replacements will happen.
n := uint64(3000)
// Send and receive messages in both directions; from local to
// remote, and from remote to local.
q1 := sink(localOutbound, n)
q2 := stream(remoteInbound, n, true)
q3 := sink(remoteOutbound, n)
q4 := stream(localInbound, n, true)

// Remote channel will listen for incoming connections.
listen(ctx, remoteCh, remotePrivKey.Signatory(), localPrivKey.Signatory(), 3363)
// Local channel will dial the listener (and re-dial once per
// second).
dial(ctx, localCh, localPrivKey.Signatory(), remotePrivKey.Signatory(), 3363, time.Second)

// Wait for sinking and streaming to finish.
<-q1
<-q2
<-q3
<-q4
})
})
})
})
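
For reference, the run helper in these tests also doubles as a template for standing up a Channel outside the test suite. Below is a minimal sketch using only calls visible in this diff (channel.New, DefaultOptions, WithDrainTimeout, Run, id.NewPrivKey, Signatory); the import paths are inferred from the repository layout rather than confirmed by this page, and a real peer connection still has to be attached via dial/listen as the tests do.

package main

import (
	"context"
	"log"
	"time"

	"github.com/renproject/aw/channel"
	"github.com/renproject/aw/wire"
	"github.com/renproject/id"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Identity of the remote peer this channel is bound to.
	remote := id.NewPrivKey().Signatory()

	inbound, outbound := make(chan wire.Msg), make(chan wire.Msg)
	ch := channel.New(
		channel.DefaultOptions().WithDrainTimeout(1500*time.Millisecond),
		remote,
		inbound,
		outbound)

	go func() {
		// Run drives the read/write loops until ctx is cancelled.
		if err := ch.Run(ctx); err != nil {
			log.Printf("channel run: %v", err)
		}
	}()

	// A network connection must still be attached (see listen/dial in the
	// tests above) before messages flow through inbound/outbound.
}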
11 changes: 8 additions & 3 deletions channel/client.go
@@ -83,16 +83,21 @@ func (client *Client) Bind(remote id.Signatory) {
ctx, cancel := context.WithCancel(context.Background())
ch := New(client.opts, remote, inbound, outbound)
go func() {
defer close(inbound)
if err := ch.Run(ctx); err != nil {
client.opts.Logger.Error("run", zap.Error(err))
}
}()
go func() {
for msg := range inbound {
for {
select {
case <-ctx.Done():
case client.inbound <- Msg{Msg: msg, From: remote}:
return
case msg := <-inbound:
select {
case <-ctx.Done():
return
case client.inbound <- Msg{Msg: msg, From: remote}:
}
}
}
}()
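
The change to Bind above swaps for msg := range inbound for a select loop, so the forwarding goroutine shuts down on context cancellation rather than relying on the inbound channel ever being closed. Here is a generic sketch of that pattern, detached from the aw types; the msg struct and channel names are illustrative only.

package main

import (
	"context"
	"fmt"
	"time"
)

type msg struct{ data string }

// forward relays messages from src to dst until ctx is cancelled. Unlike
// ranging over src, it never depends on src being closed, so the goroutine
// cannot leak if the producer simply stops sending.
func forward(ctx context.Context, src <-chan msg, dst chan<- msg) {
	for {
		select {
		case <-ctx.Done():
			return
		case m := <-src:
			// A second select so cancellation also unblocks a stuck send.
			select {
			case <-ctx.Done():
				return
			case dst <- m:
			}
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()

	src, dst := make(chan msg), make(chan msg)
	go forward(ctx, src, dst)

	go func() { src <- msg{data: "hello"} }()
	fmt.Println(<-dst) // prints {hello}

	<-ctx.Done() // forward exits on cancellation even though src is never closed
}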
22 changes: 12 additions & 10 deletions channel/client_test.go
@@ -40,19 +40,21 @@ var _ = Describe("Client", func() {
defer time.Sleep(time.Millisecond) // Wait for the receiver to be shut down.
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
receiver := make(chan wire.Msg)
client.Receive(ctx, func(signatory id.Signatory, msg wire.Msg) error {
receiver <- msg
return nil
})
//for iter := uint64(0); iter < n; iter++ {
// time.Sleep(time.Millisecond)
// select {
// case <-ctx.Done():
// Expect(ctx.Err()).ToNot(HaveOccurred())
// case msg := <-receiver:
// data := binary.BigEndian.Uint64(msg.Data)
// Expect(data).To(Equal(iter))
// }
//}
for iter := uint64(0); iter < n; iter++ {
time.Sleep(time.Millisecond)
select {
case <-ctx.Done():
Expect(ctx.Err()).ToNot(HaveOccurred())
case msg := <-receiver:
data := binary.BigEndian.Uint64(msg.Data)
Expect(data).To(Equal(iter))
}
}
}()
return quit
}
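
The re-enabled test above bridges the callback-style client.Receive into a plain channel so that messages can be asserted in order. That bridge is a generally useful trick when testing callback APIs; a small sketch of it, with a made-up subscription signature rather than the aw one, is shown here.

package main

import (
	"context"
	"fmt"
)

// collect adapts a callback-based subscription into a channel, which makes
// ordered, select-based assertions easy in tests.
func collect(ctx context.Context, subscribe func(func(uint64))) <-chan uint64 {
	out := make(chan uint64)
	subscribe(func(data uint64) {
		select {
		case <-ctx.Done():
		case out <- data:
		}
	})
	return out
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// A fake subscription that delivers 0..4 from a goroutine.
	subscribe := func(handle func(uint64)) {
		go func() {
			for i := uint64(0); i < 5; i++ {
				handle(i)
			}
		}()
	}

	received := collect(ctx, subscribe)
	for i := uint64(0); i < 5; i++ {
		if got := <-received; got != i {
			fmt.Printf("out of order: got %d, want %d\n", got, i)
		}
	}
	fmt.Println("all messages received in order")
}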
