Skip to content

Commit

Permalink
add udp relay
Browse files Browse the repository at this point in the history
  • Loading branch information
xjdrew committed May 18, 2016
1 parent 76420b6 commit f17bf53
Show file tree
Hide file tree
Showing 12 changed files with 482 additions and 211 deletions.
6 changes: 3 additions & 3 deletions k1/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,9 @@ func ParseConfig(filename string) (*KoneConfig, error) {
cfg.TCP.NatPortStart = 10000
cfg.TCP.NatPortEnd = 60000

cfg.TCP.ListenPort = 82
cfg.TCP.NatPortStart = 10000
cfg.TCP.NatPortEnd = 60000
cfg.UDP.ListenPort = 82
cfg.UDP.NatPortStart = 10000
cfg.UDP.NatPortEnd = 60000

cfg.Dns.DnsPort = dnsDefaultPort
cfg.Dns.DnsTtl = dnsDefaultTtl
Expand Down
50 changes: 35 additions & 15 deletions k1/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,17 +88,26 @@ func (d *Dns) resolve(r *dns.Msg) (*dns.Msg, error) {
}
}

func (d *Dns) fillRealIP(record *DomainRecord, r *dns.Msg) {
// resolve
msg, err := d.resolve(r)
if err != nil || len(msg.Answer) == 0 {
return
}
record.SetRealIP(msg)
}

func (d *Dns) doIPv4Query(r *dns.Msg) (*dns.Msg, error) {
one := d.one

domain := dnsutil.TrimDomainName(r.Question[0].Name, ".")
// if is a non-proxy-domain
if one.dnsCache.IsNonProxyDomain(domain) {
if one.dnsTable.IsNonProxyDomain(domain) {
return d.resolve(r)
}

// if have already hijacked
record := one.dnsCache.Get(domain)
record := one.dnsTable.Get(domain)
if record != nil {
return record.Answer(r), nil
}
Expand All @@ -108,7 +117,8 @@ func (d *Dns) doIPv4Query(r *dns.Msg) (*dns.Msg, error) {

// if domain use proxy
if matched && proxy != "" {
if record := one.dnsCache.Set(domain, proxy); record != nil {
if record := one.dnsTable.Set(domain, proxy); record != nil {
go d.fillRealIP(record, r)
return record.Answer(r), nil
}
}
Expand All @@ -119,25 +129,35 @@ func (d *Dns) doIPv4Query(r *dns.Msg) (*dns.Msg, error) {
return msg, err
}

// match by ip
if !matched {
if answer, ok := msg.Answer[0].(*dns.A); ok {
// test ip
_, proxy = one.rule.Proxy(answer.A)

// if ip use proxy
if proxy != "" {
if record := one.dnsCache.Set(domain, proxy); record != nil {
return record.Answer(r), nil
// try match by cname and ip
for _, item := range msg.Answer {
switch answer := item.(type) {
case *dns.A:
// test ip
_, proxy = one.rule.Proxy(answer.A)
break
case *dns.CNAME:
// test cname
matched, proxy = one.rule.Proxy(answer.Target)
if matched && proxy != "" {
break
}
default:
logger.Noticef("[dns] unexpected response %s -> %v", domain, item)
}
}
// if ip use proxy
if proxy != "" {
if record := one.dnsTable.Set(domain, proxy); record != nil {
record.SetRealIP(msg)
return record.Answer(r), nil
}
} else {
logger.Noticef("[dns] unexpected response %s -> %v", domain, msg.Answer[0])
}
}

// set domain as a non-proxy-domain
one.dnsCache.SetNonProxyDomain(domain, msg.Answer[0].Header().Ttl)
one.dnsTable.SetNonProxyDomain(domain, msg.Answer[0].Header().Ttl)

// final
return msg, err
Expand Down
46 changes: 32 additions & 14 deletions k1/dns_cache.go → k1/dns_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,37 @@ func NewIPv4Space(subnet *net.IPNet) *IPv4Space {
return space
}

// hijacked domain
type DomainRecord struct {
ip net.IP // nat ip
domain string // domain name
proxy string // proxy

realIP net.IP // real domin ip
ip net.IP // nat ip
realIP net.IP // real ip

answer *dns.A // cache dns answer

touch time.Time
hit int
}

func (record *DomainRecord) SetRealIP(msg *dns.Msg) {
if record.realIP != nil {
return
}

var ip net.IP
for _, item := range msg.Answer {
switch answer := item.(type) {
case *dns.A:
ip = answer.A
break
}
}
record.realIP = ip
logger.Debugf("[dns] %s real ip: %s", record.domain, ip)
}

func (record *DomainRecord) Answer(request *dns.Msg) *dns.Msg {
rsp := new(dns.Msg)
rsp.SetReply(request)
Expand All @@ -83,7 +101,7 @@ func (record *DomainRecord) Touch() {
record.touch = time.Now()
}

type DnsCache struct {
type DnsTable struct {
// dns ip space
ipSpace *IPv4Space

Expand All @@ -103,15 +121,15 @@ type DnsCache struct {
npdLock sync.Mutex
}

func (c *DnsCache) get(domain string) *DomainRecord {
func (c *DnsTable) get(domain string) *DomainRecord {
record := c.records[domain]
if record != nil {
record.Touch()
}
return record
}

func (c *DnsCache) GetByIP(ip net.IP) *DomainRecord {
func (c *DnsTable) GetByIP(ip net.IP) *DomainRecord {
c.recordsLock.Lock()
defer c.recordsLock.Unlock()
if domain, ok := c.ip2Domain[ip.String()]; ok {
Expand All @@ -120,7 +138,7 @@ func (c *DnsCache) GetByIP(ip net.IP) *DomainRecord {
return nil
}

func (c *DnsCache) Get(domain string) *DomainRecord {
func (c *DnsTable) Get(domain string) *DomainRecord {
c.recordsLock.Lock()
defer c.recordsLock.Unlock()
return c.get(domain)
Expand All @@ -134,7 +152,7 @@ func forgeIPv4Answer(domain string, ip net.IP) *dns.A {
return rr
}

func (c *DnsCache) Set(domain string, proxy string) *DomainRecord {
func (c *DnsTable) Set(domain string, proxy string) *DomainRecord {
c.recordsLock.Lock()
defer c.recordsLock.Unlock()
record := c.records[domain]
Expand Down Expand Up @@ -163,21 +181,21 @@ func (c *DnsCache) Set(domain string, proxy string) *DomainRecord {
return record
}

func (c *DnsCache) IsNonProxyDomain(domain string) bool {
func (c *DnsTable) IsNonProxyDomain(domain string) bool {
c.npdLock.Lock()
defer c.npdLock.Unlock()
_, ok := c.nonProxyDomains[domain]
return ok
}

func (c *DnsCache) SetNonProxyDomain(domain string, ttl uint32) {
func (c *DnsTable) SetNonProxyDomain(domain string, ttl uint32) {
c.npdLock.Lock()
defer c.npdLock.Unlock()
c.nonProxyDomains[domain] = time.Now().Add(time.Duration(ttl) * time.Second)
logger.Debugf("[dns] set non proxy domain: %s, ttl: %d", domain, ttl)
}

func (c *DnsCache) clearExpiredNonProxyDomain(now time.Time) {
func (c *DnsTable) clearExpiredNonProxyDomain(now time.Time) {
c.npdLock.Lock()
defer c.npdLock.Unlock()
for domain, expired := range c.nonProxyDomains {
Expand All @@ -188,7 +206,7 @@ func (c *DnsCache) clearExpiredNonProxyDomain(now time.Time) {
}
}

func (c *DnsCache) clearExpiredDomain(now time.Time) {
func (c *DnsTable) clearExpiredDomain(now time.Time) {
c.recordsLock.Lock()
defer c.recordsLock.Unlock()

Expand All @@ -204,7 +222,7 @@ func (c *DnsCache) clearExpiredDomain(now time.Time) {
}
}

func (c *DnsCache) Serve() error {
func (c *DnsTable) Serve() error {
tick := time.Tick(60 * time.Second)
for now := range tick {
c.clearExpiredDomain(now)
Expand All @@ -213,8 +231,8 @@ func (c *DnsCache) Serve() error {
return nil
}

func NewDnsCache(subnet *net.IPNet) *DnsCache {
c := new(DnsCache)
func NewDnsTable(subnet *net.IPNet) *DnsTable {
c := new(DnsTable)
c.ipSpace = NewIPv4Space(subnet)
c.records = make(map[string]*DomainRecord)
c.ip2Domain = make(map[string]string)
Expand Down
18 changes: 10 additions & 8 deletions k1/filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,22 @@
package k1

import (
"io"

"github.com/xjdrew/kone/tcpip"
)

type udpFilter struct {
type PacketFilter interface {
Filter(wr io.Writer, p tcpip.IPv4Packet)
}

func (uf *udpFilter) Filter(p *tcpip.IPv4Packet) bool {
return false
type PacketFilterFunc func(wr io.Writer, p tcpip.IPv4Packet)

func (f PacketFilterFunc) Filter(wr io.Writer, p tcpip.IPv4Packet) {
f(wr, p)
}

func icmpFilterFunc(p *tcpip.IPv4Packet) bool {
ipPacket := *p
func icmpFilterFunc(wr io.Writer, ipPacket tcpip.IPv4Packet) {
icmpPacket := tcpip.ICMPPacket(ipPacket.Payload())
if icmpPacket.Type() == tcpip.ICMPRequest && icmpPacket.Code() == 0 {
logger.Debugf("icmp echo request: %s -> %s", ipPacket.SourceIP(), ipPacket.DestinationIP())
Expand All @@ -30,10 +34,8 @@ func icmpFilterFunc(p *tcpip.IPv4Packet) bool {

icmpPacket.ResetChecksum()
ipPacket.ResetChecksum()
p = &ipPacket
return true
wr.Write(ipPacket)
} else {
logger.Debugf("icmp: %s -> %s", ipPacket.SourceIP(), ipPacket.DestinationIP())
return false
}
}
30 changes: 16 additions & 14 deletions k1/one.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@ type One struct {
// tun virtual network
subnet *net.IPNet

rule *Rule
dnsCache *DnsCache
dns *Dns
proxies *Proxies
tcpForwarder *TCPForwarder
tun *TunDriver
rule *Rule
dnsTable *DnsTable
proxies *Proxies

dns *Dns
tcpRelay *TCPRelay
udpRelay *UDPRelay
tun *TunDriver
}

func (one *One) Serve() error {
Expand All @@ -37,9 +39,10 @@ func (one *One) Serve() error {
}
}

go runAndWait(one.dnsCache.Serve)
go runAndWait(one.dnsTable.Serve)
go runAndWait(one.dns.Serve)
go runAndWait(one.tcpForwarder.Serve)
go runAndWait(one.tcpRelay.Serve)
go runAndWait(one.udpRelay.Serve)
go runAndWait(one.tun.Serve)
return <-done
}
Expand All @@ -61,7 +64,7 @@ func FromConfig(cfg *KoneConfig) (*One, error) {
one.rule = NewRule(cfg.Rule, cfg.Pattern)

// new dns cache
one.dnsCache = NewDnsCache(subnet)
one.dnsTable = NewDnsTable(subnet)

var err error

Expand All @@ -74,14 +77,13 @@ func FromConfig(cfg *KoneConfig) (*One, error) {
return nil, err
}

if one.tcpForwarder, err = NewTCPForwarder(one, cfg.TCP); err != nil {
return nil, err
}
one.tcpRelay = NewTCPRelay(one, cfg.TCP)
one.udpRelay = NewUDPRelay(one, cfg.UDP)

filters := map[tcpip.IPProtocol]PacketFilter{
tcpip.ICMP: PacketFilterFunc(icmpFilterFunc),
tcpip.TCP: one.tcpForwarder,
//tcpip.UDP: &udpFilter{},
tcpip.TCP: one.tcpRelay,
tcpip.UDP: one.udpRelay,
}

if one.tun, err = NewTunDriver(name, ip, subnet, filters); err != nil {
Expand Down
Loading

0 comments on commit f17bf53

Please sign in to comment.