Skip to content
Permalink
Browse files

Merge pull request #1799 from transcom/graceful_shutdown

Graceful shutdown
  • Loading branch information...
pjdufour-truss committed Mar 12, 2019
2 parents 8382f73 + 9d54267 commit c523ecf2c397259744aa83fe2759300ed6349965
Showing with 142 additions and 34 deletions.
  1. +101 −13 cmd/webserver/main.go
  2. +35 −15 pkg/server/server.go
  3. +6 −6 pkg/server/server_test.go
@@ -2,6 +2,7 @@ package main

import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/hex"
@@ -12,11 +13,14 @@ import (
"net/http"
"net/http/httptest"
"os"
"os/signal"
"path"
"path/filepath"
"regexp"
"strconv"
"strings"
"sync"
"syscall"
"time"

"github.com/aws/aws-sdk-go/aws"
@@ -173,6 +177,7 @@ func initFlags(flag *pflag.FlagSet) {
flag.String("env", "development", "The environment to run in, which configures the database.")
flag.String("interface", "", "The interface spec to listen for connections on. Default is all.")
flag.String("service-name", "app", "The service name identifies the application for instrumentation.")
flag.Duration("graceful-shutdown-timeout", 25*time.Second, "The duration for which the server gracefully wait for existing connections to finish. AWS ECS only gives you 30 seconds before sending SIGKILL.")

flag.String("http-my-server-name", "milmovelocal", "Hostname according to environment.")
flag.String("http-office-server-name", "officelocal", "Hostname according to environment.")
@@ -706,6 +711,26 @@ func checkStorage(v *viper.Viper) error {
return nil
}

func startListener(srv *server.NamedServer, logger *webserverLogger, useTLS bool) {
logger.Info("Starting listener",
zap.String("name", srv.Name),
zap.Duration("idle-timeout", srv.IdleTimeout),
zap.Any("listen-address", srv.Addr),
zap.Int("max-header-bytes", srv.MaxHeaderBytes),
zap.Int("port", srv.Port()),
zap.Bool("tls", useTLS),
)
var err error
if useTLS {
err = srv.ListenAndServeTLS()
} else {
err = srv.ListenAndServe()
}
if err != nil && err != http.ErrServerClosed {
logger.Fatal("server error", zap.String("name", srv.Name), zap.Error(err))
}
}

func main() {

flag := pflag.CommandLine
@@ -1173,11 +1198,10 @@ func main() {
httpHandler = site
}

errChan := make(chan error)

listenInterface := v.GetString("interface")

noTLSServer, err := server.CreateServer(&server.CreateServerInput{
noTLSServer, err := server.CreateNamedServer(&server.CreateNamedServerInput{
Name: "no-tls",
Host: listenInterface,
Port: v.GetInt("no-tls-port"),
Logger: zapLogger,
@@ -1186,11 +1210,10 @@ func main() {
if err != nil {
logger.Fatal("error creating no-tls server", zap.Error(err))
}
go func() {
errChan <- noTLSServer.ListenAndServe()
}()
go startListener(noTLSServer, logger, false)

tlsServer, err := server.CreateServer(&server.CreateServerInput{
tlsServer, err := server.CreateNamedServer(&server.CreateNamedServerInput{
Name: "tls",
Host: listenInterface,
Port: v.GetInt("tls-port"),
Logger: zapLogger,
@@ -1201,11 +1224,10 @@ func main() {
if err != nil {
logger.Fatal("error creating tls server", zap.Error(err))
}
go func() {
errChan <- tlsServer.ListenAndServeTLS()
}()
go startListener(tlsServer, logger, true)

mutualTLSServer, err := server.CreateServer(&server.CreateServerInput{
mutualTLSServer, err := server.CreateNamedServer(&server.CreateNamedServerInput{
Name: "mutual-tls",
Host: listenInterface,
Port: v.GetInt("mutual-tls-port"),
Logger: zapLogger,
@@ -1217,11 +1239,77 @@ func main() {
if err != nil {
logger.Fatal("error creating mutual-tls server", zap.Error(err))
}
go startListener(mutualTLSServer, logger, true)

// make sure we flush any pending startup messages
logger.Sync()

// Create a buffered channel that accepts 1 signal at a time.
quit := make(chan os.Signal, 1)

// Only send the SIGINT and SIGTERM signals to the quit channel
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)

// Wait until the quit channel receieves a signal
sig := <-quit

logger.Info("received signal for graceful shutdown of server", zap.Any("signal", sig))

// flush message that we received signal
logger.Sync()

gracefulShutdownTimeout := v.GetDuration("graceful-shutdown-timeout")

ctx, cancel := context.WithTimeout(context.Background(), gracefulShutdownTimeout)
defer cancel()

logger.Info("Waiting for listeners to be shutdown", zap.Duration("timeout", gracefulShutdownTimeout))

// flush message that we are waiting on listeners
logger.Sync()

wg := &sync.WaitGroup{}
var shutdownErrors sync.Map

wg.Add(1)
go func() {
errChan <- mutualTLSServer.ListenAndServeTLS()
shutdownErrors.Store(noTLSServer, noTLSServer.Shutdown(ctx))
wg.Done()
}()

logger.Fatal("listener error", zap.Error(<-errChan))
wg.Add(1)
go func() {
shutdownErrors.Store(tlsServer, tlsServer.Shutdown(ctx))
wg.Done()
}()

wg.Add(1)
go func() {
shutdownErrors.Store(mutualTLSServer, mutualTLSServer.Shutdown(ctx))
wg.Done()
}()

wg.Wait()
logger.Info("All listeners are shutdown")
logger.Sync()

shutdownError := false
shutdownErrors.Range(func(key, value interface{}) bool {
if srv, ok := key.(*server.NamedServer); ok {
if err, ok := value.(error); ok {
logger.Error("shutdown error", zap.String("name", srv.Name), zap.String("addr", srv.Addr), zap.Int("port", srv.Port()), zap.Error(err))
shutdownError = true
} else {
logger.Info("shutdown server", zap.String("name", srv.Name), zap.String("addr", srv.Addr), zap.Int("port", srv.Port()))
}
}
return true
})
logger.Sync()

if shutdownError {
os.Exit(1)
}
}

// fileHandler serves up a single file
@@ -5,6 +5,8 @@ import (
"crypto/x509"
"fmt"
"net/http"
"strconv"
"strings"
"time"

"github.com/pkg/errors"
@@ -35,8 +37,9 @@ var curvePreferences = []tls.CurveID{
tls.X25519,
}

// CreateServerInput contains the input for the CreateServer function.
type CreateServerInput struct {
// CreateNamedServerInput contains the input for the CreateServer function.
type CreateNamedServerInput struct {
Name string
Host string
Port int
Logger *zap.Logger
@@ -46,13 +49,26 @@ type CreateServerInput struct {
ClientCAs *x509.CertPool // CaCertPool
}

// Server wraps *http.Server to override the definition of ListenAndServeTLS, but bypasses some restrictions.
type Server struct {
// NamedServer wraps *http.Server to override the definition of ListenAndServeTLS, but bypasses some restrictions.
type NamedServer struct {
*http.Server
Name string
}

// Port returns the port the server binds to. Returns -1 if any error.
func (s *NamedServer) Port() int {
if !strings.Contains(s.Addr, ":") {
return -1
}
port, err := strconv.Atoi(strings.SplitN(s.Addr, ":", 2)[1])
if err != nil {
return -1
}
return port
}

// ListenAndServeTLS is similar to (*http.Server).ListenAndServeTLS, but bypasses some restrictions.
func (s *Server) ListenAndServeTLS() error {
func (s *NamedServer) ListenAndServeTLS() error {
listener, err := tls.Listen("tcp", s.Addr, s.TLSConfig)
if err != nil {
return err
@@ -61,8 +77,8 @@ func (s *Server) ListenAndServeTLS() error {
return s.Serve(listener)
}

// CreateServer returns a no-tls, tls, or mutual-tls Server based on the input given and an error, if any.
func CreateServer(input *CreateServerInput) (*Server, error) {
// CreateNamedServer returns a no-tls, tls, or mutual-tls Server based on the input given and an error, if any.
func CreateNamedServer(input *CreateNamedServerInput) (*NamedServer, error) {

address := fmt.Sprintf("%s:%d", input.Host, input.Port)

@@ -106,13 +122,17 @@ func CreateServer(input *CreateServerInput) (*Server, error) {
tlsConfig.BuildNameToCertificate()
}

return &Server{Server: &http.Server{
Addr: address,
ErrorLog: standardLog,
Handler: input.HTTPHandler,
IdleTimeout: idleTimeout,
MaxHeaderBytes: maxHeaderSize,
TLSConfig: tlsConfig,
}}, nil
srv := &NamedServer{
Name: input.Name,
Server: &http.Server{
Addr: address,
ErrorLog: standardLog,
Handler: input.HTTPHandler,
IdleTimeout: idleTimeout,
MaxHeaderBytes: maxHeaderSize,
TLSConfig: tlsConfig,
},
}
return srv, nil

}
@@ -54,7 +54,7 @@ func (suite *serverSuite) TestParseSingleTLSCert() {

suite.Nil(err)

httpsServer, err := CreateServer(&CreateServerInput{
httpsServer, err := CreateNamedServer(&CreateNamedServerInput{
Host: "127.0.0.1",
Port: 8443,
ClientAuth: tls.NoClientCert,
@@ -90,7 +90,7 @@ func (suite *serverSuite) TestParseMultipleTLSCerts() {

suite.Nil(err)

httpsServer, err := CreateServer(&CreateServerInput{
httpsServer, err := CreateNamedServer(&CreateNamedServerInput{
Host: "127.0.0.1",
Port: 8443,
ClientAuth: tls.NoClientCert,
@@ -119,7 +119,7 @@ func (suite *serverSuite) TestTLSConfigWithClientAuth() {
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caFile)

_, err = CreateServer(&CreateServerInput{
_, err = CreateNamedServer(&CreateNamedServerInput{
Host: "127.0.0.1",
Port: 8443,
ClientAuth: tls.RequireAndVerifyClientCert,
@@ -139,7 +139,7 @@ func (suite *serverSuite) TestTLSConfigWithMissingCA() {

suite.Nil(err)

_, err = CreateServer(&CreateServerInput{
_, err = CreateNamedServer(&CreateNamedServerInput{
Host: "127.0.0.1",
Port: 8443,
ClientAuth: tls.RequireAndVerifyClientCert,
@@ -163,7 +163,7 @@ func (suite *serverSuite) TestTLSConfigWithMisconfiguredCA() {
certOk := caCertPool.AppendCertsFromPEM(caFile)
suite.False(certOk)

_, err = CreateServer(&CreateServerInput{
_, err = CreateNamedServer(&CreateNamedServerInput{
Host: "127.0.0.1",
Port: 8443,
ClientAuth: tls.RequireAndVerifyClientCert,
@@ -176,7 +176,7 @@ func (suite *serverSuite) TestTLSConfigWithMisconfiguredCA() {
}

func (suite *serverSuite) TestHTTPServerConfig() {
httpsServer, err := CreateServer(&CreateServerInput{
httpsServer, err := CreateNamedServer(&CreateNamedServerInput{
Host: "127.0.0.1",
Port: 8080,
HTTPHandler: suite.httpHandler,

0 comments on commit c523ecf

Please sign in to comment.
You can’t perform that action at this time.