-
Notifications
You must be signed in to change notification settings - Fork 0
/
plugins.go
144 lines (121 loc) · 3.48 KB
/
plugins.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
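// Package plugins keeps a registry of named plugins and runs their request-
// and response-phase filters over fasthttp requests and responses.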
//
// @author: xwc1125
package plugins
import (
"errors"
"fmt"
"sort"
"sync"
"github.com/valyala/fasthttp"
)
// pluginRegistry is the package-wide set of registered plugins, keyed by
// plugin name. RegisterPlugin adds entries to it while holding its mutex.
var pluginRegistry = pluginRegistries{opts: map[string]Plugin{}}

// Sentinel errors related to plugin validation.
var (
	// ErrMissingName is returned by RegisterPlugin for a plugin whose Name() is empty.
	ErrMissingName = errors.New("missing name")

	// The ErrMissing*Method errors report a plugin lacking the named method.
	// NOTE(review): not used by the registration code here — confirm where
	// they are returned.
	ErrMissingParseConfMethod      = errors.New("missing ParseConf method")
	ErrMissingRequestFilterMethod  = errors.New("missing RequestFilter method")
	ErrMissingResponseFilterMethod = errors.New("missing ResponseFilter method")
)

// Phase singletons through which HTTPReqCall and HTTPRespCall run the
// configured plugin chains.
var (
	RequestPhase  = requestPhase{}  // request phase
	ResponsePhase = responsePhase{} // response phase
)
// ErrPluginRegistered is returned by RegisterPlugin when a plugin with the
// same name is already present in the registry.
type ErrPluginRegistered struct {
	name string // name of the plugin that was already registered
}

// Error implements the error interface.
func (e ErrPluginRegistered) Error() string {
	// Sprint inserts no separators between string operands, so the pieces
	// below concatenate exactly.
	return fmt.Sprint("plugin ", e.name, " registered")
}
// pluginRegistries is a map of plugins guarded by an embedded mutex.
// The mutex's Lock/Unlock methods are promoted onto the struct itself.
// Values of this type must not be copied once in use.
type pluginRegistries struct {
	opts map[string]Plugin // registered plugins, keyed by Plugin.Name()
	sync.Mutex
}
// RegisterPlugin adds plugin to the package-level registry under plugin.Name().
//
// It returns an error for a nil plugin, ErrMissingName if the plugin reports
// an empty name, and an ErrPluginRegistered if a plugin with the same name is
// already registered. Registry access is guarded by the registry mutex.
// A successful registration is logged; rejected ones are not.
func RegisterPlugin(plugin Plugin) error {
	if plugin == nil {
		// Without this guard the Name() call below would panic.
		return errors.New("nil plugin")
	}
	// Call Name() once: the same value is validated, used as the key, and logged.
	name := plugin.Name()
	if name == "" {
		return ErrMissingName
	}

	pluginRegistry.Lock()
	defer pluginRegistry.Unlock()

	if _, found := pluginRegistry.opts[name]; found {
		return ErrPluginRegistered{name}
	}
	pluginRegistry.opts[name] = plugin
	// Log only after the plugin is actually stored, so rejected registrations
	// are not reported as successful.
	log().Info("register plugin", "name", name, "version", plugin.Version(), "priority", plugin.Priority())
	return nil
}
// findPlugin returns the registered plugin with the given name, or nil if no
// plugin of that name has been registered.
//
// The registry lock is held for the lookup because RegisterPlugin writes the
// same map under that lock. Reading it unlocked races with a concurrent
// registration.
func findPlugin(name string) Plugin {
	pluginRegistry.Lock()
	defer pluginRegistry.Unlock()
	// A missing key yields Plugin's zero value, i.e. nil.
	return pluginRegistry.opts[name]
}
// getPluginRuntimes resolves each entry of conf to its registered plugin and
// returns the resulting runtimes sorted by the Plugins sort order
// (NOTE(review): Less is defined elsewhere — presumably by plugin priority).
//
// Entries whose plugin is not registered are logged and skipped. The sort is
// stable, so runtimes that compare equal keep the order they have in conf
// rather than an unspecified order.
func getPluginRuntimes(conf RuleConf) []pluginRuntime {
	// At most one runtime per config entry, so pre-size to avoid regrowth.
	plugins := make(Plugins, 0, len(conf))
	for _, c := range conf {
		plugin := findPlugin(c.Name)
		if plugin == nil {
			log().Warn("can't find plugin, skip", "name", c.Name)
			continue
		}
		plugins = append(plugins, pluginRuntime{
			conf:   c,
			plugin: plugin,
		})
	}
	sort.Stable(plugins)
	return plugins
}
// requestPhase runs the RequestFilter hook of each plugin configured for a rule.
type requestPhase struct{}

// filter runs RequestFilter for every plugin in conf, in the order produced by
// getPluginRuntimes.
//
// The first error returned by a plugin is logged and returned. If, after a
// plugin runs, resp holds a status other than 200 OK, the remaining plugins
// are skipped and nil is returned. Presumably this is how a plugin
// short-circuits the chain.
func (*requestPhase) filter(conf RuleConf, req *fasthttp.Request, resp *fasthttp.Response) error {
	for _, rt := range getPluginRuntimes(conf) {
		name := rt.conf.Name
		log().Debug("request run plugin", "plugin", name)

		if err := rt.plugin.RequestFilter(rt.conf.Value, req, resp); err != nil {
			log().Error("plugin run request filter err", "plugin", name, "err", err)
			return err
		}

		if code := resp.StatusCode(); code != fasthttp.StatusOK {
			log().Error("plugin run request filter break", "plugin", name, "statusCode", code)
			return nil
		}
	}
	return nil
}
// HTTPReqCall runs the request-phase plugins configured under key against
// req and resp.
//
// It returns the error from GetRuleConf if the rule lookup fails; otherwise
// it returns the result of the request phase.
func HTTPReqCall(key string, req *fasthttp.Request, resp *fasthttp.Response) error {
	ruleConf, err := GetRuleConf(key)
	if err != nil {
		return err
	}
	return RequestPhase.filter(ruleConf, req, resp)
}
// responsePhase runs the ResponseFilter hook of each plugin configured for a rule.
type responsePhase struct{}

// filter runs ResponseFilter for every plugin in conf, in the order produced
// by getPluginRuntimes.
//
// The first error returned by a plugin is logged, together with the current
// status code, and returned. If, after a plugin runs, resp holds a status
// other than 200 OK, the remaining plugins are skipped and nil is returned.
func (*responsePhase) filter(conf RuleConf, resp *fasthttp.Response) error {
	for _, rt := range getPluginRuntimes(conf) {
		name := rt.conf.Name

		if err := rt.plugin.ResponseFilter(rt.conf.Value, resp); err != nil {
			log().Error("plugin run response filter err", "plugin", name, "statusCode", resp.StatusCode(), "err", err)
			return err
		}

		if code := resp.StatusCode(); code != fasthttp.StatusOK {
			log().Error("plugin run response filter break", "plugin", name, "statusCode", code)
			return nil
		}
	}
	return nil
}
// HTTPRespCall runs the response-phase plugins configured under key against
// resp.
//
// It returns the error from GetRuleConf if the rule lookup fails, otherwise
// the first error reported by a response-phase plugin, or nil.
func HTTPRespCall(key string, resp *fasthttp.Response) error {
	ruleConf, err := GetRuleConf(key)
	if err != nil {
		return err
	}
	return ResponsePhase.filter(ruleConf, resp)
}