Skip to content
This repository has been archived by the owner on Jun 9, 2023. It is now read-only.

Commit

Permalink
TransportAuthenticator gets to prepare context for every connection
Browse files Browse the repository at this point in the history
fixes grpc#111
  • Loading branch information
tv42 authored and Peter Sanford committed Jun 24, 2015
1 parent 389d18d commit a34965e
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 11 deletions.
6 changes: 6 additions & 0 deletions credentials/credentials.go
Expand Up @@ -91,6 +91,8 @@ type TransportAuthenticator interface {
ServerHandshake(rawConn net.Conn) (net.Conn, error)
// Info provides the ProtocolInfo of this TransportAuthenticator.
Info() ProtocolInfo
// NewServerConn is called in the server for every new connection.
NewServerConn(ctx context.Context, conn net.Conn) context.Context
Credentials
}

Expand Down Expand Up @@ -167,6 +169,10 @@ func NewTLS(c *tls.Config) TransportAuthenticator {
return tc
}

func (c *tlsCreds) NewServerConn(ctx context.Context, conn net.Conn) context.Context {
return ctx
}

// NewClientTLSFromCert constructs a TLS from the input certificate for client.
func NewClientTLSFromCert(cp *x509.CertPool, serverName string) TransportAuthenticator {
return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp})
Expand Down
8 changes: 7 additions & 1 deletion server.go
Expand Up @@ -205,6 +205,12 @@ func (s *Server) Serve(lis net.Listener) error {
continue
}
}

ctx := context.TODO()
if creds, ok := s.opts.creds.(credentials.TransportAuthenticator); ok {
ctx = creds.NewServerConn(ctx, c)
}

s.mu.Lock()
if s.conns == nil {
s.mu.Unlock()
Expand All @@ -222,7 +228,7 @@ func (s *Server) Serve(lis net.Listener) error {
s.mu.Unlock()

go func() {
st.HandleStreams(func(stream *transport.Stream) {
st.HandleStreams(ctx, func(stream *transport.Stream) {
s.handleStream(st, stream)
})
s.mu.Lock()
Expand Down
14 changes: 8 additions & 6 deletions transport/http2_server.go
Expand Up @@ -86,6 +86,8 @@ type http2Server struct {
streamSendQuota uint32
}

var _ ServerTransport = (*http2Server)(nil)

// newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is
// returned if something goes wrong.
func newHTTP2Server(conn net.Conn, maxStreams uint32) (_ ServerTransport, err error) {
Expand Down Expand Up @@ -137,7 +139,7 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32) (_ ServerTransport, err er
// operateHeader takes action on the decoded headers. It returns the current
// stream if there are remaining headers on the wire (in the following
// Continuation frame).
func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame headerFrame, endStream bool, handle func(*Stream), wg *sync.WaitGroup) (pendingStream *Stream) {
func (t *http2Server) operateHeaders(ctx context.Context, hDec *hpackDecoder, s *Stream, frame headerFrame, endStream bool, handle func(*Stream), wg *sync.WaitGroup) (pendingStream *Stream) {
defer func() {
if pendingStream == nil {
hDec.state = decodeState{}
Expand Down Expand Up @@ -179,9 +181,9 @@ func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame header
t.updateWindow(s, uint32(n))
}
if hDec.state.timeoutSet {
s.ctx, s.cancel = context.WithTimeout(context.TODO(), hDec.state.timeout)
s.ctx, s.cancel = context.WithTimeout(ctx, hDec.state.timeout)
} else {
s.ctx, s.cancel = context.WithCancel(context.TODO())
s.ctx, s.cancel = context.WithCancel(ctx)
}
// Cache the current stream to the context so that the server application
// can find out. Required when the server wants to send some metadata
Expand All @@ -208,7 +210,7 @@ func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame header

// HandleStreams receives incoming streams using the given handler. This is
// typically run in a separate goroutine.
func (t *http2Server) HandleStreams(handle func(*Stream)) {
func (t *http2Server) HandleStreams(ctx context.Context, handle func(*Stream)) {
// Check the validity of client preface.
preface := make([]byte, len(clientPreface))
if _, err := io.ReadFull(t.conn, preface); err != nil {
Expand Down Expand Up @@ -268,9 +270,9 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
fc: fc,
}
endStream := frame.Header().Flags.Has(http2.FlagHeadersEndStream)
curStream = t.operateHeaders(hDec, curStream, frame, endStream, handle, &wg)
curStream = t.operateHeaders(ctx, hDec, curStream, frame, endStream, handle, &wg)
case *http2.ContinuationFrame:
curStream = t.operateHeaders(hDec, curStream, frame, false, handle, &wg)
curStream = t.operateHeaders(ctx, hDec, curStream, frame, false, handle, &wg)
case *http2.DataFrame:
t.handleData(frame)
case *http2.RSTStreamFrame:
Expand Down
2 changes: 1 addition & 1 deletion transport/transport.go
Expand Up @@ -386,7 +386,7 @@ type ServerTransport interface {
// WriteHeader sends the header metedata for the given stream.
WriteHeader(s *Stream, md metadata.MD) error
// HandleStreams receives incoming streams using the given handler.
HandleStreams(func(*Stream))
HandleStreams(context.Context, func(*Stream))
// Close tears down the transport. Once it is called, the transport
// should not be accessed any more. All the pending streams and their
// handlers will be terminated asynchronously.
Expand Down
6 changes: 3 additions & 3 deletions transport/transport_test.go
Expand Up @@ -166,11 +166,11 @@ func (s *server) start(port int, maxStreams uint32, ht hType) {
h := &testStreamHandler{t}
switch ht {
case suspended:
go t.HandleStreams(h.handleStreamSuspension)
go t.HandleStreams(context.TODO(), h.handleStreamSuspension)
case misbehaved:
go t.HandleStreams(h.handleStreamMisbehave)
go t.HandleStreams(context.TODO(), h.handleStreamMisbehave)
default:
go t.HandleStreams(h.handleStream)
go t.HandleStreams(context.TODO(), h.handleStream)
}
}
}
Expand Down

0 comments on commit a34965e

Please sign in to comment.