forked from jacexh/ultron
/
attacker.go
181 lines (156 loc) · 4.05 KB
/
attacker.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
package ultron
import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/url"
"time"
)
type (
	// Attacker is the transaction interface: a named unit of work that can
	// be fired repeatedly.
	Attacker interface {
		// Name returns the name of the transaction.
		Name() string
		// Fire executes the transaction once; a non-nil error marks the
		// execution as failed.
		Fire(context.Context) error
	}
	// HTTPPrepareFunc builds the http.Request to send. It must be supplied
	// by the caller; HTTPAttacker is responsible for sending the request.
	HTTPPrepareFunc func(context.Context) (*http.Request, error)
	// HTTPCheckFunc validates an http.Response and may be customized by the
	// caller. The byte slice is the response body already read by
	// HTTPAttacker. Returning an error marks the request as failed.
	HTTPCheckFunc func(context.Context, *http.Response, []byte) error
	// HTTPAttacker is the built-in Attacker implementation based on the
	// net/http package.
	HTTPAttacker struct {
		client      *http.Client    // client used to send the prepared request
		name        string          // value returned by Name
		prepareFunc HTTPPrepareFunc // builds the request for every Fire call
		checkFuncs  []HTTPCheckFunc // response validators, run in order
	}
	// HTTPAttackerOption is a configuration option for an HTTPAttacker.
	HTTPAttackerOption func(*HTTPAttacker)
)
const (
	// defaultUserAgent is the User-Agent header value that Fire sets on a
	// prepared request that does not carry one already.
	defaultUserAgent = "github.com/wosai/ultron"
)
var (
	// defaultHTTPClient is the http.Client that NewHTTPAttacker assigns to a
	// new attacker. Its transport is tuned for load testing, see
	// http://tleyden.github.io/blog/2016/11/21/tuning-the-go-http-client-library-for-load-testing/
	defaultHTTPClient = &http.Client{
		Timeout: 45 * time.Second,
		Transport: &http.Transport{
			// A nil Proxy means no proxy is used at all.
			Proxy: nil,
			// The deprecated net.Dialer.DualStack field is intentionally
			// not set: fast fallback is enabled by default and the field
			// no longer has any effect.
			DialContext: (&net.Dialer{
				Timeout:   30 * time.Second,
				KeepAlive: 30 * time.Second,
			}).DialContext,
			DisableKeepAlives: false,
			// Allow a large idle pool so connections are reused across
			// requests to the same host.
			MaxIdleConns:          1000,
			MaxIdleConnsPerHost:   1000,
			IdleConnTimeout:       30 * time.Second,
			TLSHandshakeTimeout:   10 * time.Second,
			ExpectContinueTimeout: 1 * time.Second,
		},
	}

	// Compile-time check that *HTTPAttacker implements Attacker.
	_ Attacker = (*HTTPAttacker)(nil)
)
// NewHTTPAttacker creates an HTTPAttacker with the given name. The attacker
// starts with the package default HTTP client and an empty list of check
// functions; opts are then applied in order on top of those defaults.
func NewHTTPAttacker(name string, opts ...HTTPAttackerOption) *HTTPAttacker {
	ha := new(HTTPAttacker)
	ha.name = name
	ha.client = defaultHTTPClient
	ha.checkFuncs = []HTTPCheckFunc{}

	// Options run last so that they can override any default set above.
	ha.Apply(opts...)
	return ha
}
// Name returns the attacker's name.
func (ha *HTTPAttacker) Name() string {
	return ha.name
}
// Fire performs a single attack: it builds a request with the configured
// prepare function, sends it with the attacker's client and runs every check
// function against the response.
//
// It returns the first error produced by the prepare function, the client,
// reading the body or a check function. When no check function is configured
// the body is drained and the result of closing it is returned.
//
// Fire panics if no prepare function has been configured.
func (ha *HTTPAttacker) Fire(ctx context.Context) error {
	if ha.prepareFunc == nil {
		panic("call Apply(WithPrepareFunc()) first")
	}

	// NOTE(review): AllocateStorageInContext / ClearStorageInContext are
	// defined outside this file; presumably they attach per-call storage to
	// ctx and release it afterwards — confirm against their definitions.
	ctx = AllocateStorageInContext(ctx)
	defer ClearStorageInContext(ctx)

	req, err := ha.prepareFunc(ctx)
	if err != nil {
		return err
	}
	if req == nil {
		// Report a misbehaving prepare function instead of crashing on a
		// nil pointer dereference below.
		return fmt.Errorf("attacker %q: prepare func returned a nil request", ha.name)
	}
	req = req.WithContext(ctx)

	// http.Client accepts requests without a header map, but writing to a
	// nil map panics, so make sure one exists before setting the default.
	if req.Header == nil {
		req.Header = make(http.Header)
	}
	if req.Header.Get("User-Agent") == "" {
		req.Header.Set("User-Agent", defaultUserAgent)
	}

	res, err := ha.client.Do(req)
	if err != nil {
		return err
	}

	if len(ha.checkFuncs) == 0 {
		// No checker defined: drain the body so that the connection can be
		// reused. Draining is best effort, only the close result is reported.
		_, _ = io.Copy(io.Discard, res.Body)
		return res.Body.Close()
	}

	body, err := io.ReadAll(res.Body)
	// Close on every path, including a failed read, so that the connection
	// is never leaked. The read error, if any, is the one worth reporting.
	_ = res.Body.Close()
	if err != nil {
		return err
	}

	for _, check := range ha.checkFuncs {
		if err = check(ctx, res, body); err != nil {
			return err
		}
	}
	return nil
}
// Apply runs the given options against the attacker, in the order in which
// they were passed.
func (ha *HTTPAttacker) Apply(opts ...HTTPAttackerOption) {
	for i := 0; i < len(opts); i++ {
		configure := opts[i]
		configure(ha)
	}
}
// WithClient returns an option that makes the attacker send its requests
// with the given client.
func WithClient(client *http.Client) HTTPAttackerOption {
	var opt HTTPAttackerOption = func(attacker *HTTPAttacker) {
		attacker.client = client
	}
	return opt
}
// WithPrepareFunc returns an option that installs prepare as the function
// building the request for every Fire call. Applying the option panics when
// prepare is nil; creating the option does not.
func WithPrepareFunc(prepare HTTPPrepareFunc) HTTPAttackerOption {
	return func(attacker *HTTPAttacker) {
		if prepare != nil {
			attacker.prepareFunc = prepare
			return
		}
		panic("invalid HTTPPrepareFunc")
	}
}
// WithCheckFuncs returns an option that appends checks to the attacker's
// response validators. Applying the option panics if any of the checks is
// nil, in which case none of them is appended.
func WithCheckFuncs(checks ...HTTPCheckFunc) HTTPAttackerOption {
	return func(attacker *HTTPAttacker) {
		// Validate everything first so that a bad entry leaves the attacker
		// untouched.
		for i := 0; i < len(checks); i++ {
			if checks[i] == nil {
				panic("invalid HTTPCheckFunc")
			}
		}
		attacker.checkFuncs = append(attacker.checkFuncs, checks...)
	}
}
// WithDisableKeepAlives returns an option that enables or disables HTTP
// keep-alives for the attacker.
//
// The option only takes effect when the attacker's client uses an
// *http.Transport; any other RoundTripper is left untouched. The transport
// and the client are cloned before the change, so a client shared with other
// attackers (such as the package default) is never modified.
func WithDisableKeepAlives(disable bool) HTTPAttackerOption {
	return func(h *HTTPAttacker) {
		tran, ok := h.client.Transport.(*http.Transport)
		if !ok || tran == nil {
			return
		}

		// Work on private copies: mutating the shared transport would
		// silently change the behavior of every attacker using it.
		tran = tran.Clone()
		tran.DisableKeepAlives = disable

		client := *h.client
		client.Transport = tran
		h.client = &client
	}
}
// WithTimeout returns an option that sets the overall request timeout of the
// attacker's client to t.
//
// The client is copied before the timeout is changed, so a client shared
// with other attackers (such as the package default) keeps its own timeout.
// The copy still shares the original Transport.
func WithTimeout(t time.Duration) HTTPAttackerOption {
	return func(h *HTTPAttacker) {
		// Shallow copy: http.Client holds only configuration fields, and
		// sharing the Transport keeps the connection pool shared.
		client := *h.client
		client.Timeout = t
		h.client = &client
	}
}
// WithProxy returns an option that installs proxy as the proxy selection
// function of the attacker's transport.
//
// The option only takes effect when the attacker's client uses an
// *http.Transport; any other RoundTripper is left untouched. The transport
// and the client are cloned before the change, so a client shared with other
// attackers (such as the package default) is never modified.
func WithProxy(proxy func(*http.Request) (*url.URL, error)) HTTPAttackerOption {
	return func(h *HTTPAttacker) {
		transport, ok := h.client.Transport.(*http.Transport)
		if !ok || transport == nil {
			return
		}

		// Work on private copies: mutating the shared transport would
		// reroute the traffic of every attacker using it.
		transport = transport.Clone()
		transport.Proxy = proxy

		client := *h.client
		client.Transport = transport
		h.client = &client
	}
}
// CheckHTTPStatusCode reports a failure when the response status code is 400
// or higher. The context and the body are not inspected.
func CheckHTTPStatusCode(_ context.Context, res *http.Response, body []byte) error {
	code := res.StatusCode
	if code < http.StatusBadRequest {
		return nil
	}
	return fmt.Errorf("bad status code: %d", code)
}