Skip to content

Commit

Permalink
add tls support
Browse files Browse the repository at this point in the history
TODO add testcase
  • Loading branch information
徐志强 committed Jan 12, 2020
1 parent c6bd483 commit 01ff2d5
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 9 deletions.
16 changes: 14 additions & 2 deletions clientconn.go
Expand Up @@ -2,6 +2,7 @@ package qrpc

import (
"context"
"crypto/tls"
"errors"
mathrand "math/rand"
"net"
Expand Down Expand Up @@ -98,9 +99,9 @@ func (r *response) Close() {
func NewConnection(addr string, conf ConnectionConfig, f SubFunc) (conn *Connection, err error) {
var rwc net.Conn
if conf.OverlayNetwork != nil {
rwc, err = conf.OverlayNetwork(addr, DialConfig{DialTimeout: conf.DialTimeout, WBufSize: conf.WBufSize, RBufSize: conf.RBufSize})
rwc, err = conf.OverlayNetwork(addr, DialConfig{DialTimeout: conf.DialTimeout, WBufSize: conf.WBufSize, RBufSize: conf.RBufSize, TLSConf: conf.TLSConf})
} else {
rwc, err = dialTCP(addr, DialConfig{DialTimeout: conf.DialTimeout, WBufSize: conf.WBufSize, RBufSize: conf.RBufSize})
rwc, err = dialTCP(addr, DialConfig{DialTimeout: conf.DialTimeout, WBufSize: conf.WBufSize, RBufSize: conf.RBufSize, TLSConf: conf.TLSConf})
}

if err != nil {
Expand Down Expand Up @@ -137,6 +138,17 @@ func dialTCP(addr string, dialConfig DialConfig) (rwc net.Conn, err error) {
}
}

if dialConfig.TLSConf != nil {
tlsConn := tls.Client(rwc, dialConfig.TLSConf)
err = tlsConn.Handshake()
if err != nil {
rwc.Close()
return
}

rwc = tlsConn
}

return
}

Expand Down
7 changes: 6 additions & 1 deletion conf.go
Expand Up @@ -4,6 +4,8 @@ import (
"net"
"time"

"crypto/tls"

"github.com/go-kit/kit/metrics"
)

Expand All @@ -28,10 +30,11 @@ type ServerBinding struct {
MaxCloseRate int // per second
ListenFunc func(network, address string) (net.Listener, error)
Codec CompressorCodec
OverlayNetwork func(net.Listener) Listener
OverlayNetwork func(net.Listener, *tls.Config) Listener
OnKickCB func(w FrameWriter)
LatencyMetric metrics.Histogram
CounterMetric metrics.Counter
TLSConf *tls.Config
ln Listener
}

Expand All @@ -49,11 +52,13 @@ type ConnectionConfig struct {
Handler Handler
OverlayNetwork func(address string, dialConfig DialConfig) (net.Conn, error)
Codec CompressorCodec
TLSConf *tls.Config
}

// DialConfig for dial
type DialConfig struct {
DialTimeout time.Duration
WBufSize int // best effort only, check log for error
RBufSize int // best effort only, check log for error
TLSConf *tls.Config
}
10 changes: 8 additions & 2 deletions server.go
Expand Up @@ -2,6 +2,7 @@ package qrpc

import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
Expand Down Expand Up @@ -188,9 +189,14 @@ func (srv *Server) ListenAll() (err error) {
}

if binding.OverlayNetwork != nil {
srv.bindings[i].ln = binding.OverlayNetwork(ln)
srv.bindings[i].ln = binding.OverlayNetwork(ln, srv.bindings[i].TLSConf)
} else {
srv.bindings[i].ln = ln.(*net.TCPListener)
if srv.bindings[i].TLSConf != nil {
srv.bindings[i].ln = tls.NewListener(ln.(*net.TCPListener), srv.bindings[i].TLSConf)
} else {
srv.bindings[i].ln = ln.(*net.TCPListener)
}

}
}

Expand Down
8 changes: 7 additions & 1 deletion ws/client/conn.go
Expand Up @@ -44,8 +44,14 @@ func DialConn(address string, dialConfig qrpc.DialConfig) (nc net.Conn, err erro
}
return
},
TLSClientConfig: dialConfig.TLSConf,
}
wc, resp, err = dialer.Dial("ws://"+address+"/qrpc", http.Header{})
if dialConfig.TLSConf != nil {
wc, resp, err = dialer.Dial("wss://"+address+"/qrpc", http.Header{})
} else {
wc, resp, err = dialer.Dial("ws://"+address+"/qrpc", http.Header{})
}

if err != nil {
qrpc.Logger().Error("dialer.Dial", zap.Any("resp", resp), zap.Error(err))
return
Expand Down
13 changes: 10 additions & 3 deletions ws/server/overlay.go
Expand Up @@ -2,6 +2,7 @@ package server

import (
"context"
"crypto/tls"
"net"
"net/http"

Expand All @@ -15,14 +16,15 @@ const (
)

// OverlayNetwork impl the overlay network for ws
func OverlayNetwork(l net.Listener) qrpc.Listener {
return newOverlay(l)
func OverlayNetwork(l net.Listener, tlsConfig *tls.Config) qrpc.Listener {
return newOverlay(l, tlsConfig)
}

type qrpcOverWS struct {
l net.Listener
httpServer *http.Server
acceptCh chan *websocket.Conn
tlsConfig *tls.Config
ctx context.Context
cancelFunc context.CancelFunc
}
Expand All @@ -34,7 +36,11 @@ var upgrader = websocket.Upgrader{
Subprotocols: []string{"null"},
}

func newOverlay(l net.Listener) (o *qrpcOverWS) {
func newOverlay(l net.Listener, tlsConfig *tls.Config) (o *qrpcOverWS) {

if tlsConfig != nil {
l = tls.NewListener(l, tlsConfig)
}

mux := &http.ServeMux{}
mux.HandleFunc("/qrpc", func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -58,6 +64,7 @@ func newOverlay(l net.Listener) (o *qrpcOverWS) {
l: l,
httpServer: httpServer,
acceptCh: make(chan *websocket.Conn, backlog),
tlsConfig: tlsConfig,
ctx: ctx,
cancelFunc: cancelFunc,
}
Expand Down

0 comments on commit 01ff2d5

Please sign in to comment.