-
Notifications
You must be signed in to change notification settings - Fork 211
/
deadline_adjuster.go
163 lines (147 loc) · 3.79 KB
/
deadline_adjuster.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
package server
import (
"context"
"errors"
"fmt"
"io"
"time"
"github.com/jonboulle/clockwork"
"github.com/libp2p/go-yamux/v4"
)
// deadlineAdjusterChunkSize is the default chunkSize of a deadlineAdjuster:
// the maximum number of bytes handed to a single underlying Read/Write call,
// and the amount of progress (in each direction) that has to be exceeded
// before the stream deadline is refreshed again.
const deadlineAdjusterChunkSize = 4 << 10 // 4 KiB
// deadlineAdjusterError wraps a timeout error hit while reading from or
// writing to a deadlineAdjuster. It records how far the transfer had got, so
// the failure can be diagnosed from the error message alone.
type deadlineAdjusterError struct {
	what         string        // operation label, e.g. "read" or "write"
	innerErr     error         // underlying timeout error, exposed via Unwrap
	elapsed      time.Duration // time since the adjuster's start time
	totalRead    int           // bytes read before the failure
	totalWritten int           // bytes written before the failure
	timeout      time.Duration // how far the deadline is pushed on each refresh
	hardTimeout  time.Duration // cap on total I/O time, counted from the first read/write
}

// Unwrap returns the underlying error so errors.Is / errors.As can look
// through the wrapper.
func (e *deadlineAdjusterError) Unwrap() error { return e.innerErr }

// Error renders the operation, elapsed time, byte counts and configured
// timeouts, followed by the underlying error.
func (e *deadlineAdjusterError) Error() string {
	const format = "%s: %v elapsed, %d bytes read, %d bytes written, timeout %v, hard timeout %v: %v"
	return fmt.Sprintf(format,
		e.what, e.elapsed,
		e.totalRead, e.totalWritten,
		e.timeout, e.hardTimeout,
		e.innerErr)
}
// deadlineAdjuster wraps a peerStream and keeps pushing its deadline forward
// while data keeps moving, never beyond a hard deadline.
type deadlineAdjuster struct {
	peerStream

	// Configured durations.
	timeout, hardTimeout time.Duration

	// Running byte counters for the two directions.
	totalRead, totalWritten int

	start time.Time       // set at construction; basis for the elapsed time in errors
	clock clockwork.Clock // time source consulted by adjust and augmentError

	chunkSize int // maximum bytes per underlying Read/Write call

	// Byte counts that must be exceeded before the deadline is refreshed again.
	nextAdjustRead, nextAdjustWrite int

	hardDeadline time.Time // zero until the first adjust() call sets it
}
// Compile-time check that *deadlineAdjuster satisfies io.ReadWriteCloser.
var _ io.ReadWriteCloser = (*deadlineAdjuster)(nil)
// newDeadlineAdjuster wraps stream in a deadlineAdjuster.
//
// Each Read/Write pushes the stream deadline to timeout from now once enough
// data has moved. The deadline is never extended past a hard deadline, which
// lies hardTimeout after the first I/O call.
func newDeadlineAdjuster(stream peerStream, timeout, hardTimeout time.Duration) *deadlineAdjuster {
	clock := clockwork.NewRealClock()
	return &deadlineAdjuster{
		peerStream:  stream,
		timeout:     timeout,
		hardTimeout: hardTimeout,
		// Read the start time from the same clock that adjust/augmentError use
		// for "now", so elapsed-time reporting never mixes two time sources.
		start:     clock.Now(),
		clock:     clock,
		chunkSize: deadlineAdjusterChunkSize,
		// -1 so that the very first adjust() call (counters still at 0)
		// sets a deadline.
		nextAdjustRead:  -1,
		nextAdjustWrite: -1,
	}
}
// augmentError decorates timeout errors with transfer statistics.
//
// If err matches context.DeadlineExceeded or yamux.ErrTimeout (via errors.Is),
// it is wrapped in a *deadlineAdjusterError labelled with what. Any other
// error, including nil, is returned as-is.
func (dadj *deadlineAdjuster) augmentError(what string, err error) error {
	isTimeout := errors.Is(err, context.DeadlineExceeded) || errors.Is(err, yamux.ErrTimeout)
	if isTimeout {
		return &deadlineAdjusterError{
			what:         what,
			innerErr:     err,
			elapsed:      dadj.clock.Now().Sub(dadj.start),
			totalRead:    dadj.totalRead,
			totalWritten: dadj.totalWritten,
			timeout:      dadj.timeout,
			hardTimeout:  dadj.hardTimeout,
		}
	}
	return err
}
// Close closes the stream. This method is safe to call multiple times.
func (dadj *deadlineAdjuster) Close() error {
// FIXME: unsure if this is really needed (inherited from the older Server code)
_ = dadj.peerStream.SetDeadline(time.Time{})
return dadj.peerStream.Close()
}
// adjust runs before every chunk of I/O.
//
// The first call fixes the hard deadline at hardTimeout from now. Later calls
// fail with yamux.ErrTimeout once that hard deadline has passed.
//
// The stream deadline is pushed to timeout from now, capped at the hard
// deadline. This happens only when more than chunkSize bytes have been read or
// written since the previous refresh, so SetDeadline is not called on every
// chunk.
func (dadj *deadlineAdjuster) adjust() error {
	now := dadj.clock.Now()
	switch {
	case dadj.hardDeadline.IsZero():
		dadj.hardDeadline = now.Add(dadj.hardTimeout)
	case now.After(dadj.hardDeadline):
		// emulate yamux timeout error
		return yamux.ErrTimeout
	}

	readDue := dadj.totalRead > dadj.nextAdjustRead
	if readDue {
		dadj.nextAdjustRead = dadj.totalRead + dadj.chunkSize
	}
	writeDue := dadj.totalWritten > dadj.nextAdjustWrite
	if writeDue {
		dadj.nextAdjustWrite = dadj.totalWritten + dadj.chunkSize
	}
	if !readDue && !writeDue {
		return nil
	}

	deadline := now.Add(dadj.timeout)
	if deadline.After(dadj.hardDeadline) {
		deadline = dadj.hardDeadline
	}
	// We ignore the error returned by SetDeadline b/c the call
	// doesn't work for mock hosts
	_ = dadj.SetDeadline(deadline)
	return nil
}
// Read refreshes the stream deadline via adjust and then performs a single
// underlying read of at most chunkSize bytes into p.
//
// Only one underlying read is issued per call. Continuing after a read that
// exactly filled a chunk would block even though data had already been
// received. When the peer has sent a multiple of chunkSize bytes and is
// waiting for a reply, that stalls the caller until more data arrives or the
// deadline expires. io.Reader conventionally returns whatever is available,
// and callers that need a full buffer (io.ReadFull, bufio) already loop. The
// deadline is still refreshed on every call.
//
// Timeout errors are wrapped by augmentError; other errors (e.g. io.EOF) pass
// through unchanged. As io.Reader permits, n may be non-zero when err is
// non-nil.
func (dadj *deadlineAdjuster) Read(p []byte) (int, error) {
	if len(p) == 0 {
		return 0, nil
	}
	if err := dadj.adjust(); err != nil {
		return 0, dadj.augmentError("read", err)
	}
	n, err := dadj.peerStream.Read(p[:min(len(p), dadj.chunkSize)])
	dadj.totalRead += n
	if err != nil {
		return n, dadj.augmentError("read", err)
	}
	return n, nil
}
// Write sends p in pieces of at most chunkSize bytes. Before each piece it
// calls adjust to refresh the stream deadline. It keeps going until all of p
// has been written or an error occurs. n counts the bytes written before any
// error. Timeout errors are wrapped by augmentError.
func (dadj *deadlineAdjuster) Write(p []byte) (n int, err error) {
	for n < len(p) {
		if adjErr := dadj.adjust(); adjErr != nil {
			return n, dadj.augmentError("write", adjErr)
		}
		end := min(len(p), n+dadj.chunkSize)
		written, writeErr := dadj.peerStream.Write(p[n:end])
		n += written
		dadj.totalWritten += written
		if writeErr != nil {
			return n, dadj.augmentError("write", writeErr)
		}
	}
	return n, nil
}