Skip to content
Permalink
Browse files

graceful shutdown

  • Loading branch information...
pjdufour-truss committed Mar 5, 2019
1 parent 0a60bb0 commit 78301ce5414c6551ff46c40e014c6ba68b7529ca
Showing with 117 additions and 28 deletions.
  1. +82 −13 cmd/webserver/main.go
  2. +35 −15 pkg/server/server.go
@@ -2,6 +2,7 @@ package main

import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/hex"
@@ -12,11 +13,14 @@ import (
"net/http"
"net/http/httptest"
"os"
"os/signal"
"path"
"path/filepath"
"regexp"
"strconv"
"strings"
"sync"
"syscall"
"time"

"github.com/aws/aws-sdk-go/aws"
@@ -171,6 +175,7 @@ func initFlags(flag *pflag.FlagSet) {
flag.String("env", "development", "The environment to run in, which configures the database.")
flag.String("interface", "", "The interface spec to listen for connections on. Default is all.")
flag.String("service-name", "app", "The service name identifies the application for instrumentation.")
flag.Duration("graceful-shutdown-timeout", 25*time.Second, "The duration for which the server gracefully wait for existing connections to finish")

flag.String("http-my-server-name", "milmovelocal", "Hostname according to environment.")
flag.String("http-office-server-name", "officelocal", "Hostname according to environment.")
@@ -703,6 +708,20 @@ func checkStorage(v *viper.Viper) error {
return nil
}

func startListener(srv *server.NamedServer, logger *webserverLogger) {
logger.Info("Starting listener",
zap.String("name", srv.Name),
zap.Duration("idle-timeout", srv.IdleTimeout),
zap.Any("listen-address", srv.Addr),
zap.Int("max-header-bytes", srv.MaxHeaderBytes),
zap.Int("port", srv.Port()),
)
err := srv.ListenAndServe()
if err != nil && err != http.ErrServerClosed {
logger.Fatal("server error", zap.String("name", srv.Name), zap.Error(err))
}
}

func main() {

flag := pflag.CommandLine
@@ -1144,11 +1163,10 @@ func main() {
httpHandler = site
}

errChan := make(chan error)

listenInterface := v.GetString("interface")

noTLSServer, err := server.CreateServer(&server.CreateServerInput{
noTLSServer, err := server.CreateNamedServer(&server.CreateNamedServerInput{
Name: "no-tls",
Host: listenInterface,
Port: v.GetInt("no-tls-port"),
Logger: zapLogger,
@@ -1157,11 +1175,10 @@ func main() {
if err != nil {
logger.Fatal("error creating no-tls server", zap.Error(err))
}
go func() {
errChan <- noTLSServer.ListenAndServe()
}()
go startListener(noTLSServer, logger)

tlsServer, err := server.CreateServer(&server.CreateServerInput{
tlsServer, err := server.CreateNamedServer(&server.CreateNamedServerInput{
Name: "tls",
Host: listenInterface,
Port: v.GetInt("tls-port"),
Logger: zapLogger,
@@ -1172,11 +1189,10 @@ func main() {
if err != nil {
logger.Fatal("error creating tls server", zap.Error(err))
}
go func() {
errChan <- tlsServer.ListenAndServeTLS()
}()
go startListener(tlsServer, logger)

mutualTLSServer, err := server.CreateServer(&server.CreateServerInput{
mutualTLSServer, err := server.CreateNamedServer(&server.CreateNamedServerInput{
Name: "mutual-tls",
Host: listenInterface,
Port: v.GetInt("mutual-tls-port"),
Logger: zapLogger,
@@ -1188,11 +1204,64 @@ func main() {
if err != nil {
logger.Fatal("error creating mutual-tls server", zap.Error(err))
}
go startListener(mutualTLSServer, logger)

quit := make(chan os.Signal, 1)

signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)

sig := <-quit

logger.Info("recevied signal for graceful shutdown of server", zap.Any("signal", sig))

gracefulShutdownTimeout := v.GetDuration("graceful-shutdown-timeout")

ctx, cancel := context.WithTimeout(context.Background(), gracefulShutdownTimeout)
defer cancel()

logger.Info("Waiting for listeners to be shutdown", zap.Duration("timeout", gracefulShutdownTimeout))

wg := &sync.WaitGroup{}
var shutdownErrors sync.Map

wg.Add(1)
go func() {
shutdownErrors.Store(noTLSServer, noTLSServer.Shutdown(ctx))
wg.Done()
}()

wg.Add(1)
go func() {
shutdownErrors.Store(tlsServer, tlsServer.Shutdown(ctx))
wg.Done()
}()

wg.Add(1)
go func() {
errChan <- mutualTLSServer.ListenAndServeTLS()
shutdownErrors.Store(mutualTLSServer, mutualTLSServer.Shutdown(ctx))
wg.Done()
}()

logger.Fatal("listener error", zap.Error(<-errChan))
wg.Wait()
logger.Info("All listeners are shutdown")
logger.Sync()

shutdownError := false
shutdownErrors.Range(func(key, value interface{}) bool {
srv := key.(*server.NamedServer)
if err, ok := value.(error); ok {
logger.Error("shutdown error", zap.String("name", srv.Name), zap.String("addr", srv.Addr), zap.Int("port", srv.Port()), zap.Error(err))
shutdownError = true
} else {
logger.Info("shutdown server", zap.String("name", srv.Name), zap.String("addr", srv.Addr), zap.Int("port", srv.Port()))
}
return true
})
logger.Sync()

if shutdownError {
os.Exit(1)
}
}

// fileHandler serves up a single file
@@ -5,6 +5,8 @@ import (
"crypto/x509"
"fmt"
"net/http"
"strconv"
"strings"
"time"

"github.com/pkg/errors"
@@ -35,8 +37,9 @@ var curvePreferences = []tls.CurveID{
tls.X25519,
}

// CreateServerInput contains the input for the CreateServer function.
type CreateServerInput struct {
// CreateNamedServerInput contains the input for the CreateServer function.
type CreateNamedServerInput struct {
Name string
Host string
Port int
Logger *zap.Logger
@@ -46,13 +49,26 @@ type CreateServerInput struct {
ClientCAs *x509.CertPool // CaCertPool
}

// Server wraps *http.Server to override the definition of ListenAndServeTLS, but bypasses some restrictions.
type Server struct {
// NamedServer wraps *http.Server to override the definition of ListenAndServeTLS, but bypasses some restrictions.
type NamedServer struct {
*http.Server
Name string
}

// Port returns the port the server binds to. Returns -1 if any error.
func (s *NamedServer) Port() int {
if !strings.Contains(s.Addr, ":") {
return -1
}
port, err := strconv.Atoi(strings.SplitN(s.Addr, ":", 2)[1])
if err != nil {
return -1
}
return port
}

// ListenAndServeTLS is similar to (*http.Server).ListenAndServeTLS, but bypasses some restrictions.
func (s *Server) ListenAndServeTLS() error {
func (s *NamedServer) ListenAndServeTLS() error {
listener, err := tls.Listen("tcp", s.Addr, s.TLSConfig)
if err != nil {
return err
@@ -61,8 +77,8 @@ func (s *Server) ListenAndServeTLS() error {
return s.Serve(listener)
}

// CreateServer returns a no-tls, tls, or mutual-tls Server based on the input given and an error, if any.
func CreateServer(input *CreateServerInput) (*Server, error) {
// CreateNamedServer returns a no-tls, tls, or mutual-tls Server based on the input given and an error, if any.
func CreateNamedServer(input *CreateNamedServerInput) (*NamedServer, error) {

address := fmt.Sprintf("%s:%d", input.Host, input.Port)

@@ -106,13 +122,17 @@ func CreateServer(input *CreateServerInput) (*Server, error) {
tlsConfig.BuildNameToCertificate()
}

return &Server{Server: &http.Server{
Addr: address,
ErrorLog: standardLog,
Handler: input.HTTPHandler,
IdleTimeout: idleTimeout,
MaxHeaderBytes: maxHeaderSize,
TLSConfig: tlsConfig,
}}, nil
srv := &NamedServer{
Name: input.Name,
Server: &http.Server{
Addr: address,
ErrorLog: standardLog,
Handler: input.HTTPHandler,
IdleTimeout: idleTimeout,
MaxHeaderBytes: maxHeaderSize,
TLSConfig: tlsConfig,
},
}
return srv, nil

}

0 comments on commit 78301ce

Please sign in to comment.
You can’t perform that action at this time.