forked from onloner/goagent
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdnsproxy.py
227 lines (202 loc) · 9.6 KB
/
dnsproxy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
#!/usr/bin/env python
# coding:utf-8
# DNS proxy built on gevent and dnslib (the server class is DNSServer).
#
# TODO:
#   - sort reply rdata by ip latency
#   - reduce socket fd usage
__version__ = '1.0'
import sys
import os
import glob
# Make any *.egg archives in the current working directory importable.
# NOTE(review): presumably this is how bundled third-party packages are
# shipped next to the script -- confirm.
sys.path += glob.glob('*.egg')
import gevent
import gevent.server
import gevent.timeout
import gevent.monkey
# Replace blocking standard-library primitives with gevent's cooperative
# versions; this deliberately runs before the stdlib imports that follow
# (time, socket, select, ...).
gevent.monkey.patch_all(subprocess=True)
import time
import logging
import heapq
import socket
import select
import re
import dnslib
# pygeoip is optional: when it is not installed the name is bound to None,
# and DNSServer checks it before use.
try:
    import pygeoip
except ImportError:
    pygeoip = None
def get_dnsserver_list():
    """Return the system's configured DNS server addresses as a list of str.

    On Windows (``os.name == 'nt'``) the list is obtained from
    ``DnsQueryConfig``; the returned buffer is parsed as a 4-byte address
    count followed by that many packed IPv4 addresses.  On other systems the
    ``nameserver`` lines of /etc/resolv.conf are used.

    Returns an empty list when the Windows query fails, or when neither
    source is available (a warning is logged in both cases).
    """
    if os.name == 'nt':
        import ctypes
        import ctypes.wintypes
        import socket
        import struct
        DNS_CONFIG_DNS_SERVER_LIST = 6
        buf = ctypes.create_string_buffer(2048)
        buflen = ctypes.wintypes.DWORD(len(buf))
        status = ctypes.windll.dnsapi.DnsQueryConfig(DNS_CONFIG_DNS_SERVER_LIST, 0, None, None,
                                                     ctypes.byref(buf), ctypes.byref(buflen))
        if status != 0:
            # A non-zero status means the call failed; the buffer contents
            # cannot be trusted.
            logging.warning('get_dnsserver_list: DnsQueryConfig failed with status %r', status)
            return []
        count = struct.unpack('I', buf[0:4])[0]
        # Never read past the end of the buffer, whatever the count says.
        count = min(count, (len(buf) - 4) // 4)
        return [socket.inet_ntoa(buf[4 * i:4 * i + 4]) for i in range(1, count + 1)]
    elif os.path.isfile('/etc/resolv.conf'):
        with open('/etc/resolv.conf', 'rb') as fp:
            content = fp.read()
        # Match on bytes and decode each address, so the result is a list of
        # native str on both Python 2 and 3 (applying a str pattern to bytes
        # raises TypeError on Python 3).
        matches = re.findall(br'(?m)^nameserver\s+(\S+)', content)
        return [str(m.decode('ascii', 'ignore')) for m in matches]
    else:
        logging.warning("get_dnsserver_list failed: unsupport platform '%s-%s'", sys.platform, os.name)
        return []
class ExpireCache(object):
    """A size-bounded, dictionary-like cache whose entries expire.

    Every entry has an absolute expiry time in whole seconds since the
    epoch.  (expiry, key) pairs are kept in a binary min-heap, so expired
    entries, and the earliest-expiring entries when the cache grows beyond
    ``max_size``, are evicted first.
    """

    def __init__(self, max_size=1024):
        self.__maxsize = max_size
        self.__values = {}          # key -> cached value
        self.__expire_times = {}    # key -> absolute expiry time (int seconds)
        self.__expire_heap = []     # min-heap of (expiry time, key)

    def size(self):
        """Return the number of stored entries, including expired ones not yet cleaned up."""
        return len(self.__values)

    def clear(self):
        """Remove every entry."""
        self.__values.clear()
        self.__expire_times.clear()
        del self.__expire_heap[:]

    def exists(self, key):
        """Return True if *key* is stored; expiry is not checked."""
        return key in self.__values

    def _heap_remove(self, et, key):
        """Remove the item (et, key) from the expiry heap, keeping it a valid heap.

        Raises ValueError if the item is not in the heap.
        """
        heap = self.__expire_heap
        pos = heap.index((et, key))
        last = heap.pop()
        if pos < len(heap):
            # Fill the hole with the former last item, then restore the heap
            # invariant downwards and upwards.  A plain ``del heap[pos]``
            # would shift every later item and corrupt the heap structure.
            heap[pos] = last
            heapq._siftup(heap, pos)
            heapq._siftdown(heap, 0, pos)

    def set(self, key, value, expire):
        """Store *value* under *key*, expiring *expire* seconds from now.

        Setting an existing key replaces its value and its expiry time.
        Calls cleanup(), which may evict expired or surplus entries.
        """
        if key in self.__expire_times:
            self._heap_remove(self.__expire_times[key], key)
        et = int(time.time() + expire)
        self.__expire_times[key] = et
        heapq.heappush(self.__expire_heap, (et, key))
        self.__values[key] = value
        self.cleanup()

    def get(self, key):
        """Return the value stored under *key*.

        Raises KeyError if *key* is unknown, or if its expiry time has passed
        (cleanup() is run first in that case).
        """
        et = self.__expire_times[key]
        if et < time.time():
            self.cleanup()
            raise KeyError(key)
        return self.__values[key]

    def delete(self, key):
        """Remove *key*; raises KeyError if it is not stored."""
        et = self.__expire_times.pop(key)
        self._heap_remove(et, key)
        del self.__values[key]

    def cleanup(self):
        """Evict expired entries, then evict earliest-expiring entries until within max_size."""
        now = int(time.time())
        heap = self.__expire_heap
        expire_times = self.__expire_times
        values = self.__values
        max_size = self.__maxsize
        heappop = heapq.heappop
        while (heap and heap[0][0] <= now) or len(values) > max_size:
            _, key = heappop(heap)
            del values[key], expire_times[key]
class DNSServer(gevent.server.DatagramServer):
    """DNS proxy based on gevent/dnslib.

    Each UDP query is answered from an in-memory ExpireCache when possible.
    Otherwise it is forwarded to every selected upstream server, and the
    first acceptable reply is relayed back to the client.

    Extra keyword arguments (removed before the remaining arguments are
    passed to gevent.server.DatagramServer):
        dns_blacklist: iterable of address strings; an upstream reply whose
            answer records contain any of them is discarded.
        dns_servers: list of upstream server addresses.  Entries containing
            ':' are treated as IPv6.  Queries are sent to port 53.
        dns_timeout: seconds to wait for replies on each attempt
            (default 2, converted with int()).
    """
    # Matches loopback (127.x.x.x) and RFC 1918 private IPv4 addresses,
    # optionally preceded by an IPv6 prefix ending in "0:5efe:".
    is_local_addr = re.compile(r'(?i)(?:[0-9a-f:]+0:5efe:)?(?:127(?:\.\d+){3}|10(?:\.\d+){3}|192\.168(?:\.\d+){2}|172\.(?:1[6-9]|2\d|3[01])(?:\.\d+){2})').match
    # Largest reply datagram read from an upstream.  Queries are forwarded
    # verbatim, so a client that advertises EDNS0 may get replies larger than
    # 512 bytes; a smaller buffer would truncate them and break parsing.
    recv_bufsize = 65535

    def __init__(self, *args, **kwargs):
        dns_blacklist = kwargs.pop('dns_blacklist')
        dns_servers = kwargs.pop('dns_servers')
        dns_timeout = kwargs.pop('dns_timeout', 2)
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as DNSServer is subclassed.
        super(DNSServer, self).__init__(*args, **kwargs)
        self.dns_servers = dns_servers
        self.dns_v4_servers = [x for x in self.dns_servers if ':' not in x]
        self.dns_v6_servers = [x for x in self.dns_servers if ':' in x]
        # Upstreams on local/private addresses; only these are asked about
        # plain (local) hostnames.
        self.dns_intranet_servers = set([x for x in self.dns_servers if self.is_local_addr(x)])
        self.dns_blacklist = set(dns_blacklist)
        self.dns_timeout = int(dns_timeout)
        self.dns_cache = ExpireCache(max_size=65536)
        # Error replies from these servers are accepted immediately instead
        # of waiting for the other upstreams to answer.
        self.dns_trust_servers = set(['8.8.8.8', '8.8.4.4', '2001:4860:4860::8888', '2001:4860:4860::8844'])
        if pygeoip:
            # Using the first GeoIP.dat found, also trust every IPv4 upstream
            # whose country is not 'China'.
            for dirname in ('.', '/usr/share/GeoIP/', '/usr/local/share/GeoIP/'):
                filename = os.path.join(dirname, 'GeoIP.dat')
                if os.path.isfile(filename):
                    geoip = pygeoip.GeoIP(filename)
                    for dnsserver in self.dns_servers:
                        if ':' not in dnsserver and geoip.country_name_by_addr(dnsserver) not in ('China',):
                            self.dns_trust_servers.add(dnsserver)
                    break

    def handle(self, data, address):
        """Answer the DNS query *data* received from the client *address*.

        - A cached reply is returned at once.
        - A plain hostname is forwarded only to intranet upstreams.  If none
          are configured, the client receives an rcode 3 (NXDOMAIN) reply.
        - Any other query is sent to every selected upstream, with up to two
          attempts of ``dns_timeout`` seconds each.

        The first acceptable reply is sent back carrying the request's
        transaction id.  Nothing is sent if no acceptable reply arrives.
        """
        logging.debug('receive from %r data=%r', address, data)
        request = dnslib.DNSRecord.parse(data)
        qname = str(request.q.qname)
        qtype = request.q.qtype
        try:
            reply_data = self.dns_cache.get((qname, qtype))
        except KeyError:
            reply_data = ''
        if reply_data:
            # Cache hit: no upstream sockets are needed.  The first two bytes
            # (the DNS transaction id) are taken from the new request.
            return self.sendto(data[:2] + reply_data[2:], address)
        is_local_hostname = '.' not in qname
        if 'USERDNSDOMAIN' in os.environ:
            # NOTE(review): this replaces the dot test above rather than
            # extending it, so with USERDNSDOMAIN set a dotless name is not
            # treated as local -- confirm that is intended.
            is_local_hostname = qname.lower().endswith('.' + os.environ['USERDNSDOMAIN'].lower())
        if is_local_hostname and not self.dns_intranet_servers:
            logging.warning('qname=%r is a plain hostname, need intranet dns server!!!', qname)
            reply = dnslib.DNSRecord(header=dnslib.DNSHeader(id=request.header.id, rcode=3))
            self.sendto(reply.pack(), address)
            return
        if is_local_hostname:
            dns_v4_servers = [x for x in self.dns_intranet_servers if ':' not in x]
            dns_v6_servers = [x for x in self.dns_intranet_servers if ':' in x]
        else:
            dns_v4_servers = self.dns_v4_servers
            dns_v6_servers = self.dns_v6_servers
        socks = []
        try:
            # Sockets are created inside the try, so that any already opened
            # are closed even if a later step (including creating the IPv6
            # socket, or parsing a reply) raises.
            sock_v4 = sock_v6 = None
            if dns_v4_servers:
                sock_v4 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
                socks.append(sock_v4)
            if dns_v6_servers:
                sock_v6 = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
                socks.append(sock_v6)
            for _ in range(2):
                if reply_data:
                    break
                try:
                    need_reply_servers = set()
                    for dnsserver in dns_v4_servers:
                        sock_v4.sendto(data, (dnsserver, 53))
                        need_reply_servers.add(dnsserver)
                    for dnsserver in dns_v6_servers:
                        sock_v6.sendto(data, (dnsserver, 53))
                        need_reply_servers.add(dnsserver)
                    timeout_at = time.time() + self.dns_timeout
                    while time.time() < timeout_at:
                        if reply_data:
                            break
                        ins, _, _ = select.select(socks, [], [], 0.1)
                        for sock in ins:
                            reply_data, reply_addr = sock.recvfrom(self.recv_bufsize)
                            # IPv4 peers are (host, port) but IPv6 peers are
                            # (host, port, flowinfo, scope_id); keep the host.
                            reply_server = reply_addr[0]
                            reply = dnslib.DNSRecord.parse(reply_data)
                            iplist = [str(x.rdata) for x in reply.rr]
                            if any(x in self.dns_blacklist for x in iplist):
                                logging.warning('query qname=%r reply bad iplist=%r, continue', qname, iplist)
                                reply_data = ''
                                continue
                            if reply.header.rcode and not iplist and need_reply_servers and reply_server not in self.dns_trust_servers:
                                # Error reply from an untrusted server: prefer
                                # to wait for the remaining upstreams.
                                need_reply_servers.discard(reply_server)
                                if need_reply_servers:
                                    logging.warning('query qname=%r qtype=%r reply nonzero rcode=%r, wait other need_reply_servers=%s, continue', qname, qtype, reply.header.rcode, need_reply_servers)
                                    reply_data = ''
                                    continue
                                else:
                                    logging.info('query qname=%r qtype=%r reply nonzero rcode=%r', qname, qtype, reply.header.rcode)
                            ttl = max(x.ttl for x in reply.rr) if reply.rr else 600
                            logging.debug('query qname=%r qtype=%r reply_server=%r reply iplist=%s, ttl=%r', qname, qtype, reply_server, iplist, ttl)
                            if iplist or qname.endswith('.in-addr.arpa'):
                                self.dns_cache.set((qname, qtype), reply_data, ttl * 2)
                            break
                except socket.error as e:
                    logging.warning('handle dns data=%r socket: %r', data, e)
        finally:
            for sock in socks:
                sock.close()
        if reply_data:
            return self.sendto(data[:2] + reply_data[2:], address)
def test():
    """Run the DNS proxy on UDP port 53 with a fixed configuration.

    Configures INFO-level logging, then blocks in serve_forever() using a
    hard-coded upstream server list and address blacklist.
    """
    log_format = '%(levelname)s - %(asctime)s %(message)s'
    logging.basicConfig(level=logging.INFO, format=log_format, datefmt='[%b %d %H:%M:%S]')
    upstreams = ['114.114.114.114', '114.114.115.115', '8.8.8.8', '8.8.4.4']
    # Upstream replies whose answers contain any of these addresses are
    # discarded by the proxy.
    bad_ips = '1.1.1.1|255.255.255.255|74.125.127.102|74.125.155.102|74.125.39.102|74.125.39.113|209.85.229.138|4.36.66.178|8.7.198.45|37.61.54.158|46.82.174.68|59.24.3.173|64.33.88.161|64.33.99.47|64.66.163.251|65.104.202.252|65.160.219.113|66.45.252.237|72.14.205.104|72.14.205.99|78.16.49.15|93.46.8.89|128.121.126.139|159.106.121.75|169.132.13.103|192.67.198.6|202.106.1.2|202.181.7.85|203.161.230.171|203.98.7.65|207.12.88.98|208.56.31.43|209.145.54.50|209.220.30.174|209.36.73.33|209.85.229.138|211.94.66.147|213.169.251.35|216.221.188.182|216.234.179.13|243.185.187.3|243.185.187.39'.split('|')
    logging.info('serving at port 53...')
    server = DNSServer(('', 53), dns_servers=upstreams, dns_blacklist=bad_ips)
    server.serve_forever()
# Running this file as a script starts the proxy through test().
if __name__ == '__main__':
    test()