forked from 0xERR0R/blocky
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcommon.go
221 lines (175 loc) · 5.61 KB
/
common.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
package util
import (
"context"
"encoding/binary"
"fmt"
"io"
"net"
"path/filepath"
"regexp"
"sort"
"strings"
"sync/atomic"
"github.com/0xERR0R/blocky/log"
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
)
//nolint:gochecknoglobals
var (
	// LogPrivacy switches on obfuscation of user-sensitive data in log output.
	// It is a package-level global, set when the config is loaded, so that this
	// package does not need to depend on config. Moving the obfuscation code
	// into `log` would be cleaner, but would require moving its dependencies
	// too; this is good enough for now.
	LogPrivacy atomic.Bool

	// alphanumeric matches a single ASCII letter or digit.
	alphanumeric = regexp.MustCompile(`[a-zA-Z0-9]`)
)

// Obfuscate masks user-sensitive data when LogPrivacy is enabled: every ASCII
// letter and digit in the input is replaced by '*'. All other characters
// (dots, dashes, spaces, ...) are kept. When LogPrivacy is disabled the input
// is returned unchanged.
func Obfuscate(in string) string {
	if !LogPrivacy.Load() {
		return in
	}

	return alphanumeric.ReplaceAllString(in, "*")
}
// AnswerToString creates a user-friendly representation of an answer section.
//
// A, AAAA, CNAME and PTR records are rendered as "TYPE (value)". Any other
// record type falls back to the record's own String() representation. The
// entries are joined with ", " and the result is passed through Obfuscate,
// so it is masked when LogPrivacy is enabled.
func AnswerToString(answer []dns.RR) string {
	answers := make([]string, len(answer))

	for i, record := range answer {
		switch v := record.(type) {
		case *dns.A:
			answers[i] = fmt.Sprintf("A (%s)", v.A)
		case *dns.AAAA:
			answers[i] = fmt.Sprintf("AAAA (%s)", v.AAAA)
		case *dns.CNAME:
			answers[i] = fmt.Sprintf("CNAME (%s)", v.Target)
		case *dns.PTR:
			answers[i] = fmt.Sprintf("PTR (%s)", v.Ptr)
		default:
			// String() already returns a string; no fmt wrapper is needed.
			answers[i] = record.String()
		}
	}

	return Obfuscate(strings.Join(answers, ", "))
}
// QuestionToString creates a user-friendly representation of a list of
// questions. Each question is rendered as "TYPE (name)", the entries are
// joined with ", ", and the result is passed through Obfuscate.
func QuestionToString(questions []dns.Question) string {
	parts := make([]string, 0, len(questions))
	for _, q := range questions {
		parts = append(parts, fmt.Sprintf("%s (%s)", dns.TypeToString[q.Qtype], q.Name))
	}

	return Obfuscate(strings.Join(parts, ", "))
}
// CreateAnswerFromQuestion builds a resource record that answers question
// with the given IP address and TTL.
//
// For A and AAAA questions the record is constructed directly, with a header
// from CreateHeader. Any other query type is logged at error level. It is then
// handled by a fallback that formats the line "<name> <ttl> IN <type> <ip>"
// and parses it with dns.NewRR. The result and error of that parse are
// returned as-is.
func CreateAnswerFromQuestion(question dns.Question, ip net.IP, remainingTTL uint32) (dns.RR, error) {
	hdr := CreateHeader(question, remainingTTL)

	if question.Qtype == dns.TypeA {
		return &dns.A{Hdr: hdr, A: ip}, nil
	}

	if question.Qtype == dns.TypeAAAA {
		return &dns.AAAA{Hdr: hdr, AAAA: ip}, nil
	}

	typeName := dns.TypeToString[question.Qtype]
	log.Log().Errorf("Using fallback for unsupported query type %s", typeName)

	return dns.NewRR(fmt.Sprintf("%s %d %s %s %s",
		question.Name, remainingTTL, "IN", typeName, ip))
}
// CreateHeader creates the RR header for an answer to question. The owner
// name and record type are copied from the question, the class is always
// IN, and the TTL is remainingTTL.
func CreateHeader(question dns.Question, remainingTTL uint32) dns.RR_Header {
	var hdr dns.RR_Header
	hdr.Name = question.Name
	hdr.Rrtype = question.Qtype
	hdr.Class = dns.ClassINET
	hdr.Ttl = remainingTTL

	return hdr
}
// ExtractDomain returns the name queried by question, normalized by
// ExtractDomainOnly.
func ExtractDomain(question dns.Question) string {
	name := question.Name

	return ExtractDomainOnly(name)
}
// ExtractDomainOnly normalizes a DNS name. It lower-cases the name and then
// removes one trailing dot, if present. For example, "Example.COM." becomes
// "example.com".
func ExtractDomainOnly(in string) string {
	lowered := strings.ToLower(in)

	return strings.TrimSuffix(lowered, ".")
}
// NewMsgWithQuestion creates a new DNS message whose question is the fully
// qualified form of question (via dns.Fqdn) with query type qType.
func NewMsgWithQuestion(question string, qType dns.Type) *dns.Msg {
	msg := &dns.Msg{}
	msg.SetQuestion(dns.Fqdn(question), uint16(qType))

	return msg
}
// NewMsgWithAnswer creates a new DNS message whose answer section holds the
// single record obtained by parsing the line
// "<domain>\t<ttl>\tIN\t<dnsType>\t<address>" with dns.NewRR.
// If parsing fails, the parse error is returned together with a nil message.
func NewMsgWithAnswer(domain string, ttl uint, dnsType dns.Type, address string) (*dns.Msg, error) {
	line := fmt.Sprintf("%s\t%d\tIN\t%s\t%s", domain, ttl, dnsType, address)

	rr, err := dns.NewRR(line)
	if err != nil {
		return nil, err
	}

	return &dns.Msg{Answer: []dns.RR{rr}}, nil
}
// kv is a single key/value pair taken from a map[string]int.
type kv struct {
	key   string
	value int
}

// IterateValueSorted calls fn once for every entry of in, ordered by value
// from highest to lowest. Entries with equal values are ordered by key in
// descending lexicographic order. Because map keys are unique, this ordering
// is total, so the iteration order is deterministic regardless of Go's
// randomized map iteration. fn is never called for an empty or nil map.
func IterateValueSorted(in map[string]int, fn func(string, int)) {
	// The final size is known, so allocate the backing array exactly once.
	entries := make([]kv, 0, len(in))
	for k, v := range in {
		entries = append(entries, kv{key: k, value: v})
	}

	sort.Slice(entries, func(i, j int) bool {
		a, b := entries[i], entries[j]
		if a.value != b.value {
			return a.value > b.value
		}

		return a.key > b.key
	})

	// Named "entry" so the loop variable does not shadow the kv type.
	for _, entry := range entries {
		fn(entry.key, entry.value)
	}
}
// LogOnError logs message together with err at error level. It uses the
// logger returned by log.FromCtx(ctx). Nothing is logged when err is nil.
func LogOnError(ctx context.Context, message string, err error) {
	if err == nil {
		return
	}

	log.FromCtx(ctx).Error(message, err)
}
// LogOnErrorWithEntry logs message together with err at error level on the
// given logrus entry. Nothing is logged, and logEntry is not touched, when
// err is nil.
func LogOnErrorWithEntry(logEntry *logrus.Entry, message string, err error) {
	if err == nil {
		return
	}

	logEntry.Error(message, err)
}
// FatalOnError does nothing when err is nil. Otherwise it logs message and
// err at fatal level on the global logger (log.Log()), which exits the
// program.
func FatalOnError(message string, err error) {
	if err == nil {
		return
	}

	logger := log.Log()

	// If the logger has been silenced (output set to io.Discard), restore
	// the default configuration so the fatal error is still printed.
	if logger.Out == io.Discard {
		log.ConfigureLogger(logger, log.DefaultConfig())
	}

	logger.Fatal(message, err)
}
// GenerateCacheKey returns the cache key for a query. The key is the query
// type encoded as 2 big-endian bytes, followed by the lower-cased query name.
// ExtractCacheKey reverses this encoding.
func GenerateCacheKey(qType dns.Type, qName string) string {
	const qTypeLength = 2

	// Lower-case first and size the buffer from the result. For some
	// non-ASCII runes strings.ToLower changes the byte length: the Kelvin
	// sign U+212A (3 bytes) becomes "k" (1 byte). Sizing by len(qName) would
	// then truncate the name or leave trailing zero bytes in the key.
	name := strings.ToLower(qName)

	b := make([]byte, qTypeLength+len(name))
	binary.BigEndian.PutUint16(b, uint16(qType))
	copy(b[qTypeLength:], name)

	return string(b)
}
// ExtractCacheKey splits a key produced by GenerateCacheKey back into its
// parts: the query type (first 2 bytes, big-endian) and the query name (the
// rest). The key must be at least 2 bytes long; a shorter key causes an
// index-out-of-range panic.
func ExtractCacheKey(key string) (qType dns.Type, qName string) {
	// Decode straight from the string. Converting to []byte first would
	// copy the whole key, and converting the tail back would copy it again.
	qType = dns.Type(uint16(key[0])<<8 | uint16(key[1]))
	qName = key[2:]

	return qType, qName
}
// CidrContainsIP reports whether ip lies inside the network described by the
// CIDR string cidr (e.g. "10.0.0.0/8"). If cidr cannot be parsed, the result
// is false.
func CidrContainsIP(cidr string, ip net.IP) bool {
	if _, network, err := net.ParseCIDR(cidr); err == nil {
		return network.Contains(ip)
	}

	return false
}
// ClientNameMatchesGroupName reports whether clientName matches group. The
// group may contain filepath.Match wildcards such as "*", "?" and "[...]".
// The comparison is case-insensitive because both sides are lower-cased
// first.
func ClientNameMatchesGroupName(group, clientName string) bool {
	pattern, name := strings.ToLower(group), strings.ToLower(clientName)

	// filepath.Match can only fail with ErrBadPattern. A malformed group
	// pattern is deliberately treated as "no match", so the error is dropped.
	matched, _ := filepath.Match(pattern, name)

	return matched
}