forked from gophergala/dnsp
/
server.go
205 lines (191 loc) · 5.39 KB
/
server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
package ddns
import (
"log"
"log/syslog"
"net"
"fmt"
"time"
"strings"
"io/ioutil"
"github.com/miekg/dns"
"github.com/stutiredboy/radix.v2/pool"
"github.com/stutiredboy/radix.v2/redis"
)
// qinfo is one query record handed from the DNS handler to the backend
// logging goroutines through a log channel.
type qinfo struct {
	// name is the queried domain name.
	name string
	// addr is the remote address the query was received from.
	addr net.Addr
}
// Server is a forwarding DNS server that also records every query it sees
// into redis backends and syslog.
type Server struct {
	// c is the client used to forward queries to the upstream name servers.
	c *dns.Client
	// s is the underlying DNS listener.
	s *dns.Server
	// pools holds one redis connection pool per backend, keyed by backend index.
	pools map[int]*pool.Pool

	// currQueries is the running total of handled queries; lastQueries is
	// the total seen at the previous Dump and is used to derive qps.
	currQueries, lastQueries int64
	// currFailed is the running total of queries that could not be queued
	// for logging; lastFailed is the total seen at the previous Dump.
	currFailed, lastFailed int64
	// failedRate is the most recently computed ratio of failed to handled
	// queries.
	failedRate float64

	// sysLog receives a debug line for every logged query.
	sysLog *syslog.Writer
	// logChan maps backend index -> channel index -> queue of pending
	// query records.
	logChan map[int]map[int]chan qinfo
	// lenBackends is the number of configured backends.
	lenBackends int
	// ExpiresIn is passed as the expiry argument of the redis SETEX command
	// when a query is recorded.
	ExpiresIn int
}
// NewServer creates a new Server with the given options.
//
// It validates c, opens one redis pool per entry of c.Backends, allocates
// c.ChanNum buffered query-log channels for each backend, dials the local
// syslog socket and installs the UDP handler that records every query and
// proxies it to c.NameServers. If a setup step fails, the pools opened so far
// are released and the error is returned.
func NewServer(c Configurations) (*Server, error) {
	if err := c.validate(); err != nil {
		return nil, err
	}

	connectTimeout := time.Millisecond * time.Duration(c.ConnectTimeout)
	readTimeout := time.Millisecond * time.Duration(c.ReadTimeout)
	if c.Debug {
		log.Printf("create redis pool with connectTimeout: %s, readTimeout: %s", connectTimeout, readTimeout)
	}

	pools := make(map[int]*pool.Pool)
	logChan := make(map[int]map[int]chan qinfo)

	// closePools releases the connections of every pool opened so far, so a
	// failed construction does not leak them.
	closePools := func() {
		for _, p := range pools {
			p.Empty()
		}
	}

	for index, backend := range c.Backends {
		p, err := pool.NewCustom("tcp", backend, c.PoolNum, connectTimeout, readTimeout, redis.DialTimeout)
		if err != nil {
			closePools()
			return nil, err
		}
		pools[index] = p

		chans := make(map[int]chan qinfo)
		for i := 0; i < c.ChanNum; i++ {
			chans[i] = make(chan qinfo, 10)
		}
		logChan[index] = chans
	}

	sysLog, err := syslog.Dial("unixgram", "/dev/log", syslog.LOG_DEBUG|syslog.LOG_LOCAL5, "ddns")
	if err != nil {
		closePools()
		return nil, err
	}

	s := Server{
		c: &dns.Client{},
		s: &dns.Server{
			Net:  "udp",
			Addr: c.Listen,
		},
		pools:       pools,
		currQueries: 0,
		lastQueries: 0,
		currFailed:  0,
		lastFailed:  0,
		failedRate:  0.0,
		sysLog:      sysLog,
		logChan:     logChan,
		lenBackends: len(c.Backends),
		ExpiresIn:   c.ExpiresIn,
	}

	s.s.Handler = dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
		// Reject a nil message before anything touches r: the debug log
		// below and dns.HandleFailed both dereference it and would panic.
		if r == nil {
			log.Printf("dns Msg is nil, ignore it.")
			return
		}

		// If no upstream proxy is present, drop the query:
		if len(c.NameServers) == 0 {
			log.Printf("no nameservers, drop query")
			dns.HandleFailed(w, r)
			return
		}

		if c.Debug {
			ecs := GetEdns0Subnet(r)
			log.Printf("query %+v from %s msg %+v with ecs %s", r.Question, w.RemoteAddr(), r.MsgHdr, ecs.String())
		}

		if len(r.Question) == 0 {
			log.Printf("no query Question, drop query")
			dns.HandleFailed(w, r)
			return
		}

		// Queue the query info for the backend loggers. The name is
		// normalised (lowercase, no trailing dot) before hashing.
		name := strings.ToLower(strings.TrimSuffix(r.Question[0].Name, "."))
		// Backend and Channel must use different hash method
		backendIndex := backendHash(name) % s.lenBackends
		chanIndex := channelHash(name) % c.ChanNum

		// NOTE(review): the counters below are plain int64 fields updated
		// from the handler; if handlers run concurrently this is a data
		// race and should move to sync/atomic.
		select {
		case s.logChan[backendIndex][chanIndex] <- qinfo{name, w.RemoteAddr()}:
		default:
			// Never block the DNS answer on a full log queue; count the
			// drop instead.
			s.currFailed++
			log.Printf("receive query %s %s, but backend %d channel%d full", r.Question[0].Name, w.RemoteAddr(), backendIndex, chanIndex)
		}

		// increase queries counter
		s.currQueries++

		// Proxy Query: the first upstream that answers wins.
		for _, addr := range c.NameServers {
			in, _, err := s.c.Exchange(r, addr)
			if err != nil {
				if c.Debug {
					log.Printf("upstream %s failed for %s: %s", addr, name, err)
				}
				continue
			}
			if err := w.WriteMsg(in); err != nil {
				log.Printf("write response for %s to %s err: %s", name, w.RemoteAddr(), err)
			}
			return
		}
		dns.HandleFailed(w, r)
	})

	return &s, nil
}
// ListenAndServe starts the wrapped DNS server and returns whatever error
// its ListenAndServe call reports.
func (s *Server) ListenAndServe() error {
	err := s.s.ListenAndServe()
	return err
}
// Shutdown asks the wrapped DNS server to shut down and returns the error,
// if any, from that call.
func (s *Server) Shutdown() error {
	err := s.s.Shutdown()
	return err
}
// Dump logs the server statistics and, when saveto is not empty, also writes
// them to that file.
//
// period is the number of seconds since the previous Dump and is used to
// derive qps; a non-positive period reports qps as 0 instead of dividing by
// zero. The failure rate is recomputed whenever at least one query was
// handled during the interval; otherwise the previous value is kept.
func (s *Server) Dump(period int, saveto string) {
	// Snapshot the counters once so the report and the "last" markers agree
	// even if queries arrive while Dump runs.
	currQueries, currFailed := s.currQueries, s.currFailed
	queries := currQueries - s.lastQueries
	failed := currFailed - s.lastFailed

	var qps int64
	if period > 0 {
		qps = queries / int64(period)
	}
	// Guard on the query delta, not on qps: integer qps truncates to 0 on
	// low traffic, which would otherwise leave a stale failure rate.
	if queries > 0 {
		s.failedRate = float64(failed) / float64(queries)
	}

	log.Printf("total queries: %d, qps: %d, log failed: %d, failed rate: %f", currQueries, qps, currFailed, s.failedRate)

	if saveto != "" {
		stats := fmt.Sprintf("total queries: %d\nlog failed: %d\nfailed rate: %f", currQueries, currFailed, s.failedRate)
		// The mode must be octal: decimal 644 is 01204, not rw-r--r--.
		if err := ioutil.WriteFile(saveto, []byte(stats), 0644); err != nil {
			log.Printf("dump statistics to %s err: %s", saveto, err)
		}
	}

	s.lastQueries = currQueries
	s.lastFailed = currFailed
}
// log2b records a single query: it writes a debug line to syslog and stores
// name -> client IP in the redis backend selected by backendIndex using
// SETEX with s.ExpiresIn. It returns an error when addr cannot be split into
// host and port or when the redis command fails.
//
// name is used as given; the DNS handler lowercases it and strips the
// trailing dot before queueing.
func (s *Server) log2b(name string, addr net.Addr, backendIndex int) error {
	host, _, splitErr := net.SplitHostPort(addr.String())
	if splitErr != nil {
		return splitErr
	}

	line := fmt.Sprintf("query %s from %s", name, host)
	s.sysLog.Debug(line)

	backend := s.pools[backendIndex]
	resp := backend.Cmd("SETEX", name, s.ExpiresIn, host)
	return resp.Err
}
// Log2b consumes the query records queued on the given backend/channel pair
// and writes each one to the backend via log2b, logging any failure. It
// loops forever and never returns.
func (s *Server) Log2b(backendIndex int, chanIndex int) {
	log.Printf("listening to backend %d channel %d" , backendIndex, chanIndex)
	for {
		q := <-s.logChan[backendIndex][chanIndex]
		if err := s.log2b(q.name, q.addr, backendIndex); err != nil {
			log.Printf("backend %d channel %d log2b %s %s raise err: %s", backendIndex, chanIndex, q.name, q.addr, err)
		}
	}
}
// GetEdns0Subnet returns the address carried by the first EDNS0 client
// subnet (ECS) option found in query. It returns nil when query is nil, when
// it has no OPT record, or when the OPT record holds no subnet option.
func GetEdns0Subnet(query *dns.Msg) net.IP {
	// IsEdns0 dereferences the message, so a nil query must be rejected here.
	if query == nil {
		return nil
	}
	opt := query.IsEdns0()
	if opt == nil {
		return nil
	}
	for _, o := range opt.Option {
		if subnet, ok := o.(*dns.EDNS0_SUBNET); ok {
			return subnet.Address
		}
	}
	return nil
}