Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

doh route strategy optimized #2305

Merged
merged 2 commits into from Mar 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 10 additions & 7 deletions app/dispatcher/default.go
Expand Up @@ -258,7 +258,13 @@ func sniffer(ctx context.Context, cReader *cachedReader) (SniffResult, error) {

func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) {
var handler outbound.Handler
if d.router != nil {

skipRoutePick := false
if content := session.ContentFromContext(ctx); content != nil {
skipRoutePick = content.SkipRoutePick
}

if d.router != nil && !skipRoutePick {
if tag, err := d.router.PickRoute(ctx); err == nil {
if h := d.ohm.GetHandler(tag); h != nil {
newError("taking detour [", tag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx))
Expand All @@ -282,12 +288,9 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.
return
}

accessMessage := log.AccessMessageFromContext(ctx)
if accessMessage != nil {
if len(handler.Tag()) > 0 {
accessMessage.Detour = handler.Tag()
} else {
accessMessage.Detour = ""
if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil {
if tag := handler.Tag(); tag != "" {
accessMessage.Detour = tag
}
log.Record(accessMessage)
}
Expand Down
105 changes: 43 additions & 62 deletions app/dns/dohdns.go
Expand Up @@ -6,32 +6,32 @@ import (
"bytes"
"context"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"sync"
"sync/atomic"
"time"

dns_feature "v2ray.com/core/features/dns"

"golang.org/x/net/dns/dnsmessage"
"v2ray.com/core/common"
"v2ray.com/core/common/dice"
"v2ray.com/core/common/net"
"v2ray.com/core/common/protocol/dns"
"v2ray.com/core/common/session"
"v2ray.com/core/common/signal/pubsub"
"v2ray.com/core/common/task"
"v2ray.com/core/features/routing"
"v2ray.com/core/transport/internet"
)

// DoHNameServer implimented DNS over HTTPS (RFC8484) Wire Format,
// which is compatiable with traditional dns over udp(RFC1035),
// thus most of the DOH implimentation is copied from udpns.go
type DoHNameServer struct {
sync.RWMutex
dispatcher routing.Dispatcher
dohDests []net.Destination
ips map[string]record
pub *pubsub.Service
cleanup *task.Periodic
Expand All @@ -45,41 +45,8 @@ type DoHNameServer struct {
// NewDoHNameServer creates DOH client object for remote resolving
func NewDoHNameServer(url *url.URL, dispatcher routing.Dispatcher, clientIP net.IP) (*DoHNameServer, error) {

dohAddr := net.ParseAddress(url.Hostname())
dohPort := "443"
if url.Port() != "" {
dohPort = url.Port()
}

parseIPDest := func(ip net.IP, port string) net.Destination {
strIP := ip.String()
if len(ip) == net.IPv6len {
strIP = fmt.Sprintf("[%s]", strIP)
}
dest, err := net.ParseDestination(fmt.Sprintf("tcp:%s:%s", strIP, port))
common.Must(err)
return dest
}

var dests []net.Destination
if dohAddr.Family().IsDomain() {
// resolve DOH server in advance
ips, err := net.LookupIP(dohAddr.Domain())
if err != nil || len(ips) == 0 {
return nil, err
}
for _, ip := range ips {
dests = append(dests, parseIPDest(ip, dohPort))
}
} else {
ip := dohAddr.IP()
dests = append(dests, parseIPDest(ip, dohPort))
}

newError("DNS: created Remote DOH client for ", url.String(), ", preresolved Dests: ", dests).AtInfo().WriteToLog()
newError("DNS: created Remote DOH client for ", url.String()).AtInfo().WriteToLog()
s := baseDOHNameServer(url, "DOH", clientIP)
s.dispatcher = dispatcher
s.dohDests = dests

// Dispatched connection will be closed (interupted) after each request
// This makes DOH inefficient without a keeped-alive connection
Expand All @@ -88,15 +55,29 @@ func NewDoHNameServer(url *url.URL, dispatcher routing.Dispatcher, clientIP net.
// Recommand to use NewDoHLocalNameServer (DOHL:) if v2ray instance is running on
// a normal network eg. the server side of v2ray
tr := &http.Transport{
MaxIdleConns: 10,
MaxIdleConns: 30,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
DialContext: s.DialContext,
TLSHandshakeTimeout: 30 * time.Second,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
dest, err := net.ParseDestination(network + ":" + addr)
if err != nil {
return nil, err
}

link, err := dispatcher.Dispatch(ctx, dest)
if err != nil {
return nil, err
}
return net.NewConnection(
net.ConnectionInputMulti(link.Writer),
net.ConnectionOutputMulti(link.Reader),
), nil
},
}

dispatchedClient := &http.Client{
Transport: tr,
Timeout: 16 * time.Second,
Timeout: 60 * time.Second,
}

s.httpClient = dispatchedClient
Expand All @@ -107,8 +88,23 @@ func NewDoHNameServer(url *url.URL, dispatcher routing.Dispatcher, clientIP net.
func NewDoHLocalNameServer(url *url.URL, clientIP net.IP) *DoHNameServer {
url.Scheme = "https"
s := baseDOHNameServer(url, "DOHL", clientIP)
tr := &http.Transport{
IdleConnTimeout: 90 * time.Second,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
dest, err := net.ParseDestination(network + ":" + addr)
if err != nil {
return nil, err
}
conn, err := internet.DialSystem(ctx, dest, nil)
if err != nil {
return nil, err
}
return conn, nil
},
}
s.httpClient = &http.Client{
Timeout: time.Second * 180,
Timeout: time.Second * 180,
Transport: tr,
}
newError("DNS: created Local DOH client for ", url.String()).AtInfo().WriteToLog()
return s
Expand All @@ -120,7 +116,7 @@ func baseDOHNameServer(url *url.URL, prefix string, clientIP net.IP) *DoHNameSer
ips: make(map[string]record),
clientIP: clientIP,
pub: pubsub.NewService(),
name: fmt.Sprintf("%s//%s", prefix, url.Host),
name: prefix + "//" + url.Host,
dohURL: url.String(),
}
s.cleanup = &task.Periodic{
Expand All @@ -136,21 +132,6 @@ func (s *DoHNameServer) Name() string {
return s.name
}

// DialContext offer dispatched connection through core routing
func (s *DoHNameServer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {

dest := s.dohDests[dice.Roll(len(s.dohDests))]

link, err := s.dispatcher.Dispatch(ctx, dest)
if err != nil {
return nil, err
}
return net.NewConnection(
net.ConnectionInputMulti(link.Writer),
net.ConnectionOutputMulti(link.Reader),
), nil
}

// Cleanup clears expired items from cache
func (s *DoHNameServer) Cleanup() error {
now := time.Now()
Expand Down Expand Up @@ -255,7 +236,8 @@ func (s *DoHNameServer) sendQuery(ctx context.Context, domain string, option IPO
}

dnsCtx = session.ContextWithContent(dnsCtx, &session.Content{
Protocol: "https",
Protocol: "https",
SkipRoutePick: true,
})

// forced to use mux for DOH
Expand Down Expand Up @@ -297,10 +279,9 @@ func (s *DoHNameServer) dohHTTPSContext(ctx context.Context, b []byte) ([]byte,
}

defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
err = fmt.Errorf("DOH HTTPS server returned with non-OK code %d", resp.StatusCode)
return nil, err
io.Copy(ioutil.Discard, resp.Body) // flush resp.Body so that the conn is reusable
return nil, fmt.Errorf("DOH server returned code %d", resp.StatusCode)
}

return ioutil.ReadAll(resp.Body)
Expand Down
2 changes: 2 additions & 0 deletions common/session/session.go
Expand Up @@ -68,6 +68,8 @@ type Content struct {
SniffingRequest SniffingRequest

Attributes map[string]interface{}

SkipRoutePick bool
}

func (c *Content) SetAttribute(name string, value interface{}) {
Expand Down