-
Notifications
You must be signed in to change notification settings - Fork 61
/
main.py
executable file
·146 lines (116 loc) · 4.71 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from __future__ import annotations as _annotations

import logging
from datetime import datetime, timezone
from pathlib import Path
from textwrap import wrap
from typing import Any

from dnslib import QTYPE, RR, DNSLabel, dns
from dnslib.proxy import ProxyResolver
from dnslib.server import DNSServer as LibDNSServer

from .load_records import Zone, load_records
__all__ = 'DNSServer', 'logger'

# Whole seconds since the Unix epoch, taken once at import time. Used as the
# serial of SOA records that only specify mname and rname. Computed from an
# aware UTC datetime because ``datetime.utcnow()`` is deprecated (Python 3.12+)
# and returns a naive value.
SERIAL_NO = int(datetime.now(tz=timezone.utc).timestamp())

# Module logger: INFO and above, formatted as "HH:MM:SS: message".
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
handler.setFormatter(logging.Formatter('%(asctime)s: %(message)s', datefmt='%H:%M:%S'))
logger = logging.getLogger(__name__)
logger.addHandler(handler)
logger.setLevel(logging.INFO)
# Zones-file record type name -> (dnslib rdata class, dnslib QTYPE value).
# "SPF" entries are served as ordinary TXT records.
TYPE_LOOKUP = dict(
    A=(dns.A, QTYPE.A),
    AAAA=(dns.AAAA, QTYPE.AAAA),
    CAA=(dns.CAA, QTYPE.CAA),
    CNAME=(dns.CNAME, QTYPE.CNAME),
    DNSKEY=(dns.DNSKEY, QTYPE.DNSKEY),
    MX=(dns.MX, QTYPE.MX),
    NAPTR=(dns.NAPTR, QTYPE.NAPTR),
    NS=(dns.NS, QTYPE.NS),
    PTR=(dns.PTR, QTYPE.PTR),
    RRSIG=(dns.RRSIG, QTYPE.RRSIG),
    SOA=(dns.SOA, QTYPE.SOA),
    SRV=(dns.SRV, QTYPE.SRV),
    TXT=(dns.TXT, QTYPE.TXT),
    SPF=(dns.TXT, QTYPE.TXT),
)

# Local port the server listens on when none is given.
DEFAULT_PORT = 53
# Upstream server that queries without a local answer are proxied to.
DEFAULT_UPSTREAM = '1.1.1.1'
class Record:
    """One locally-served resource record built from a ``Zone`` entry.

    :param zone: entry whose ``host`` becomes the record name and whose ``type``
        must be a key of ``TYPE_LOOKUP``. Its ``answer`` is either a single
        string (the sole rdata argument) or a sequence of rdata constructor
        arguments.
    :raises KeyError: if ``zone.type`` is not a key of ``TYPE_LOOKUP``.
    """

    def __init__(self, zone: Zone):
        self._rname = DNSLabel(zone.host)
        rd_cls, self._rtype = TYPE_LOOKUP[zone.type]

        args: list[Any]
        if isinstance(zone.answer, str):
            if self._rtype == QTYPE.TXT:
                # TXT rdata is a list of strings, each limited to 255 bytes.
                args = [self._txt_strings(zone.answer)]
            else:
                args = [zone.answer]
        elif self._rtype == QTYPE.SOA and len(zone.answer) == 2:
            # Only mname and rname were given: append default timing values
            # (serial, refresh, retry, expire, minimum). Unpacking also
            # accepts a tuple answer, not just a list.
            args = [*zone.answer, (SERIAL_NO, 3600, 3600 * 3, 3600 * 24, 3600)]
        else:
            args = zone.answer

        # NS and SOA records get a one-day TTL; everything else five minutes.
        ttl = 3600 * 24 if self._rtype in (QTYPE.NS, QTYPE.SOA) else 300

        self.rr = RR(
            rname=self._rname,
            rtype=self._rtype,
            rdata=rd_cls(*args),
            ttl=ttl,
        )

    @staticmethod
    def _txt_strings(text: str) -> list[str]:
        """Split *text* into consecutive pieces of at most 255 UTF-8 bytes each.

        All characters, including whitespace, are preserved exactly and in
        order. Consumers such as SPF concatenate the pieces with no separator,
        so word-wrapping (which drops whitespace at the breaks) would corrupt
        the value. Splits only fall between characters, never inside a
        multi-byte sequence. An empty *text* yields ``['']``, because TXT rdata
        needs at least one string.
        """
        pieces: list[str] = []
        start = 0
        size = 0
        for index, char in enumerate(text):
            width = len(char.encode())
            if size + width > 255:
                pieces.append(text[start:index])
                start, size = index, 0
            size += width
        pieces.append(text[start:])
        return pieces

    def match(self, q):
        """Return True if question *q* asks for this record's name and type (or ANY).

        NOTE(review): a CNAME is only returned for CNAME/ANY queries, not for
        other types at the same name.
        """
        return q.qname == self._rname and (q.qtype == QTYPE.ANY or q.qtype == self._rtype)

    def sub_match(self, q):
        """Return True if this is an SOA record whose name is a suffix of *q*'s name.

        Suffix matching is done by ``DNSLabel.matchSuffix``.
        """
        return self._rtype == QTYPE.SOA and q.qname.matchSuffix(self._rname)

    def __str__(self):
        return str(self.rr)
class Resolver(ProxyResolver):
    """Answer queries from a local zones file, proxying everything else upstream.

    :param zones_file: path handed to ``load_records``; each loaded zone
        becomes a ``Record``.
    :param upstream: address of the upstream DNS server.
    :param upstream_port: port of the upstream server (default 53, as before).
    :param timeout: upstream query timeout in seconds, passed to
        ``ProxyResolver`` (default 5, as before).
    """

    def __init__(self, zones_file: str | Path, upstream: str, upstream_port: int = 53, timeout: float = 5):
        records = load_records(zones_file)
        self.records = [Record(zone) for zone in records.zones]
        logger.info('loaded %d zone record from %s', len(self.records), zones_file)
        super().__init__(address=upstream, port=upstream_port, timeout=timeout)

    def resolve(self, request, handler):
        """Build the reply for *request*.

        Lookup order:

        1. local records matching the query name and type (or ANY);
        2. otherwise, SOA records of any local zone whose name is a suffix of
           the query name;
        3. otherwise, ``ProxyResolver.resolve``, i.e. the upstream server.
        """
        question = request.q
        type_name = QTYPE[question.qtype]
        reply = request.reply()

        for record in self.records:
            if record.match(question):
                reply.add_answer(record.rr)
        if reply.rr:
            logger.info('found zone for %s[%s], %d replies', question.qname, type_name, len(reply.rr))
            return reply

        # No direct record: answer with the SOA of an enclosing local zone
        # instead of proxying.
        # NOTE(review): the SOA is added to the answer section; RFC 2308
        # negative responses normally carry it in the authority section.
        for record in self.records:
            if record.sub_match(question):
                reply.add_answer(record.rr)
        if reply.rr:
            logger.info('found higher level SOA resource for %s[%s]', question.qname, type_name)
            return reply

        logger.info('no local zone found, proxying %s[%s]', question.qname, type_name)
        return super().resolve(request, handler)
class DNSServer:
    """Serve DNS over UDP and TCP on one port, backed by a shared ``Resolver``.

    :param zones_file: zones file loaded by ``Resolver`` when :meth:`start` runs.
    :param port: local port to listen on; ``None`` means ``DEFAULT_PORT``.
        Strings are converted with ``int()``.
    :param upstream: upstream server for names with no local answer;
        ``None`` means ``DEFAULT_UPSTREAM``.
    """

    def __init__(
        self, zones_file: str | Path, *, port: int | str | None = DEFAULT_PORT, upstream: str | None = DEFAULT_UPSTREAM
    ):
        self.zones_file = zones_file
        self.port: int = DEFAULT_PORT if port is None else int(port)
        self.upstream: str = DEFAULT_UPSTREAM if upstream is None else upstream
        self.udp_server: LibDNSServer | None = None
        self.tcp_server: LibDNSServer | None = None

    def start(self):
        """Load the zones file and start the UDP and TCP servers in background threads.

        If creating the TCP server raises, the already-created UDP server's
        socket is closed before the error propagates. Neither attribute is set
        in that case, so a later :meth:`stop` won't act on an unstarted server.
        """
        logger.info('starting DNS server on port %d, upstream DNS server "%s"', self.port, self.upstream)
        resolver = Resolver(self.zones_file, self.upstream)
        udp_server = LibDNSServer(resolver, port=self.port)
        try:
            tcp_server = LibDNSServer(resolver, port=self.port, tcp=True)
        except Exception:
            udp_server.server.server_close()
            raise
        self.udp_server = udp_server
        self.tcp_server = tcp_server
        self.udp_server.start_thread()
        self.tcp_server.start_thread()

    def stop(self):
        """Stop and close the servers created by :meth:`start`.

        Servers that were never created are skipped, so calling this before
        :meth:`start` is a no-op instead of an ``AttributeError``.
        """
        for server in (self.udp_server, self.tcp_server):
            if server is not None:
                server.stop()
                server.server.server_close()

    @property
    def is_running(self) -> bool:
        """Return True if the UDP or TCP server reports itself alive."""
        return any(server is not None and server.isAlive() for server in (self.udp_server, self.tcp_server))