From 45788689be7226cc14f367a763d4e8b2d5b93ba6 Mon Sep 17 00:00:00 2001 From: Ichinose Shogo Date: Fri, 16 Dec 2016 08:50:47 +0900 Subject: [PATCH] use Server.Shutdown method to avoid overhead The HTTP Server has support for graceful shutdown in Go 1.8 https://beta.golang.org/doc/go1.8#http_shutdown go-gracedown is now just a wrapper of the net/http package to maintain interface compatibility. --- .travis.yml | 1 + gracedown.go | 211 +++++++++---------------------------- gracedown_fallback.go | 208 ++++++++++++++++++++++++++++++++++++ gracedown_fallback_test.go | 127 ++++++++++++++++++++++ gracedown_test.go | 115 -------------------- 5 files changed, 386 insertions(+), 276 deletions(-) create mode 100644 gracedown_fallback.go create mode 100644 gracedown_fallback_test.go diff --git a/.travis.yml b/.travis.yml index acf477d..48235d9 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,4 +4,5 @@ go: - "1.4" - "1.5" - "1.6" + - "1.7" - tip diff --git a/gracedown.go b/gracedown.go index e5f5853..3e9cede 100644 --- a/gracedown.go +++ b/gracedown.go @@ -1,7 +1,9 @@ +// +build go1.8 + package gracedown import ( - "crypto/tls" + "context" "net" "net/http" "sync" @@ -15,13 +17,9 @@ type Server struct { KillTimeOut time.Duration - wg sync.WaitGroup - mu sync.Mutex - originalConnState func(conn net.Conn, newState http.ConnState) - connStateOnce sync.Once - closed int32 // accessed atomically. - idlePool map[net.Conn]struct{} - listeners map[net.Listener]struct{} + mu sync.Mutex + closed int32 // accessed atomically. + doneChan chan struct{} } // NewWithServer wraps an existing http.Server. @@ -29,178 +27,69 @@ func NewWithServer(s *http.Server) *Server { return &Server{ Server: s, KillTimeOut: 10 * time.Second, - idlePool: map[net.Conn]struct{}{}, - listeners: map[net.Listener]struct{}{}, - } -} - -// ListenAndServe provides a graceful equivalent of net/http.Server.ListenAndServe -func (srv *Server) ListenAndServe() error { - addr := srv.Server.Addr - if addr == "" { - addr = ":http" - } - ln, err := net.Listen("tcp", addr) - if err != nil { - return err } - return srv.Serve(ln) } -// ListenAndServeTLS provides a graceful equivalent of net/http.Server.ListenAndServeTLS -func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error { - // direct lift from net/http/server.go - addr := srv.Addr - if addr == "" { - addr = ":https" - } - - config := cloneTLSConfig(srv.TLSConfig) - if !strSliceContains(config.NextProtos, "http/1.1") { - config.NextProtos = append(config.NextProtos, "http/1.1") - } - - configHasCert := len(config.Certificates) > 0 || config.GetCertificate != nil - if !configHasCert || certFile != "" || keyFile != "" { - var err error - config.Certificates = make([]tls.Certificate, 1) - config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - return err - } - } +func (srv *Server) Serve(l net.Listener) error { + err := srv.Server.Serve(l) - ln, err := net.Listen("tcp", addr) - if err != nil { - return err + // Wait for closing all connections. + if err == http.ErrServerClosed { + ch := srv.getDoneChan() + <-ch + return nil } - - return srv.Serve(tls.NewListener(ln, config)) + return err } -// cloneTLSConfig returns a shallow clone of the exported -// fields of cfg, ignoring the unexported sync.Once, which -// contains a mutex and must not be copied. -// -// The cfg must not be in active use by tls.Server, or else -// there can still be a race with tls.Server updating SessionTicketKey -// and our copying it, and also a race with the server setting -// SessionTicketsDisabled=false on failure to set the random -// ticket key. -// -// If cfg is nil, a new zero tls.Config is returned. -// -// Direct lift from net/http/transport.go -func cloneTLSConfig(cfg *tls.Config) *tls.Config { - if cfg == nil { - return &tls.Config{} - } - return &tls.Config{ - Rand: cfg.Rand, - Time: cfg.Time, - Certificates: cfg.Certificates, - NameToCertificate: cfg.NameToCertificate, - GetCertificate: cfg.GetCertificate, - RootCAs: cfg.RootCAs, - NextProtos: cfg.NextProtos, - ServerName: cfg.ServerName, - ClientAuth: cfg.ClientAuth, - ClientCAs: cfg.ClientCAs, - InsecureSkipVerify: cfg.InsecureSkipVerify, - CipherSuites: cfg.CipherSuites, - PreferServerCipherSuites: cfg.PreferServerCipherSuites, - SessionTicketsDisabled: cfg.SessionTicketsDisabled, - SessionTicketKey: cfg.SessionTicketKey, - ClientSessionCache: cfg.ClientSessionCache, - MinVersion: cfg.MinVersion, - MaxVersion: cfg.MaxVersion, - CurvePreferences: cfg.CurvePreferences, - } +func (srv *Server) getDoneChan() <-chan struct{} { + srv.mu.Lock() + defer srv.mu.Unlock() + return srv.getDoneChanLocked() } -func strSliceContains(ss []string, s string) bool { - for _, v := range ss { - if v == s { - return true - } +func (srv *Server) getDoneChanLocked() chan struct{} { + if srv.doneChan == nil { + srv.doneChan = make(chan struct{}) } - return false + return srv.doneChan } -// Serve provides a graceful equivalent of net/http.Server.Serve -func (srv *Server) Serve(l net.Listener) error { - // remember net.Listener - srv.mu.Lock() - srv.listeners[l] = struct{}{} - srv.mu.Unlock() - defer func() { - srv.mu.Lock() - delete(srv.listeners, l) - srv.mu.Unlock() - }() - - // replace ConnState - srv.connStateOnce.Do(func() { - srv.originalConnState = srv.Server.ConnState - srv.Server.ConnState = srv.connState - }) - - err := srv.Server.Serve(l) - - go func() { - // wait for closing keep-alive connection by sending `Connection: Close` header. - time.Sleep(srv.KillTimeOut) - - // time out, close all idle connections - srv.mu.Lock() - for conn := range srv.idlePool { - conn.Close() - } - srv.mu.Unlock() - }() - - // wait all connections have done - srv.wg.Wait() - - if atomic.LoadInt32(&srv.closed) != 0 { - // ignore closed network error when srv.Close() is called - return nil +func (srv *Server) closeDoneChanLocked() { + ch := srv.getDoneChanLocked() + select { + case <-ch: + // Already closed. Don't close again. + default: + // Safe to close here. We're the only closer, guarded + // by s.mu. + close(ch) } - return err } // Close shuts down the default server used by ListenAndServe, ListenAndServeTLS and // Serve. It returns true if it's the first time Close is called. func (srv *Server) Close() bool { - if atomic.CompareAndSwapInt32(&srv.closed, 0, 1) { - srv.Server.SetKeepAlivesEnabled(false) - srv.mu.Lock() - listeners := srv.listeners - srv.listeners = map[net.Listener]struct{}{} - srv.mu.Unlock() - for l := range listeners { - l.Close() - } - return true + if !atomic.CompareAndSwapInt32(&srv.closed, 0, 1) { + return false } - return false -} -func (srv *Server) connState(conn net.Conn, newState http.ConnState) { - srv.mu.Lock() - switch newState { - case http.StateNew: - srv.wg.Add(1) - case http.StateActive: - delete(srv.idlePool, conn) - case http.StateIdle: - srv.idlePool[conn] = struct{}{} - case http.StateClosed, http.StateHijacked: - delete(srv.idlePool, conn) - srv.wg.Done() - } - srv.mu.Unlock() - if srv.originalConnState != nil { - srv.originalConnState(conn, newState) + // immediately closes all connection. + if srv.KillTimeOut == 0 { + srv.Server.Close() + return true } + + // graceful shutdown + go func() { + ctx, cancel := context.WithTimeout(context.Background(), srv.KillTimeOut) + defer cancel() + srv.Shutdown(ctx) + + srv.mu.Lock() + defer srv.mu.Unlock() + srv.closeDoneChanLocked() + }() + + return true } diff --git a/gracedown_fallback.go b/gracedown_fallback.go new file mode 100644 index 0000000..5c558a2 --- /dev/null +++ b/gracedown_fallback.go @@ -0,0 +1,208 @@ +// +build !go1.8 + +package gracedown + +import ( + "crypto/tls" + "net" + "net/http" + "sync" + "sync/atomic" + "time" +) + +// Server provides a graceful equivalent of net/http.Server. +type Server struct { + *http.Server + + KillTimeOut time.Duration + + wg sync.WaitGroup + mu sync.Mutex + originalConnState func(conn net.Conn, newState http.ConnState) + connStateOnce sync.Once + closed int32 // accessed atomically. + idlePool map[net.Conn]struct{} + listeners map[net.Listener]struct{} +} + +// NewWithServer wraps an existing http.Server. +func NewWithServer(s *http.Server) *Server { + return &Server{ + Server: s, + KillTimeOut: 10 * time.Second, + idlePool: map[net.Conn]struct{}{}, + listeners: map[net.Listener]struct{}{}, + } +} + +// ListenAndServe provides a graceful equivalent of net/http.Server.ListenAndServe +func (srv *Server) ListenAndServe() error { + addr := srv.Server.Addr + if addr == "" { + addr = ":http" + } + ln, err := net.Listen("tcp", addr) + if err != nil { + return err + } + return srv.Serve(ln) +} + +// ListenAndServeTLS provides a graceful equivalent of net/http.Server.ListenAndServeTLS +func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error { + // direct lift from net/http/server.go + addr := srv.Addr + if addr == "" { + addr = ":https" + } + + config := cloneTLSConfig(srv.TLSConfig) + if !strSliceContains(config.NextProtos, "http/1.1") { + config.NextProtos = append(config.NextProtos, "http/1.1") + } + + configHasCert := len(config.Certificates) > 0 || config.GetCertificate != nil + if !configHasCert || certFile != "" || keyFile != "" { + var err error + config.Certificates = make([]tls.Certificate, 1) + config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return err + } + } + + ln, err := net.Listen("tcp", addr) + if err != nil { + return err + } + + return srv.Serve(tls.NewListener(ln, config)) +} + +// cloneTLSConfig returns a shallow clone of the exported +// fields of cfg, ignoring the unexported sync.Once, which +// contains a mutex and must not be copied. +// +// The cfg must not be in active use by tls.Server, or else +// there can still be a race with tls.Server updating SessionTicketKey +// and our copying it, and also a race with the server setting +// SessionTicketsDisabled=false on failure to set the random +// ticket key. +// +// If cfg is nil, a new zero tls.Config is returned. +// +// Direct lift from net/http/transport.go +func cloneTLSConfig(cfg *tls.Config) *tls.Config { + if cfg == nil { + return &tls.Config{} + } + return &tls.Config{ + Rand: cfg.Rand, + Time: cfg.Time, + Certificates: cfg.Certificates, + NameToCertificate: cfg.NameToCertificate, + GetCertificate: cfg.GetCertificate, + RootCAs: cfg.RootCAs, + NextProtos: cfg.NextProtos, + ServerName: cfg.ServerName, + ClientAuth: cfg.ClientAuth, + ClientCAs: cfg.ClientCAs, + InsecureSkipVerify: cfg.InsecureSkipVerify, + CipherSuites: cfg.CipherSuites, + PreferServerCipherSuites: cfg.PreferServerCipherSuites, + SessionTicketsDisabled: cfg.SessionTicketsDisabled, + SessionTicketKey: cfg.SessionTicketKey, + ClientSessionCache: cfg.ClientSessionCache, + MinVersion: cfg.MinVersion, + MaxVersion: cfg.MaxVersion, + CurvePreferences: cfg.CurvePreferences, + } +} + +func strSliceContains(ss []string, s string) bool { + for _, v := range ss { + if v == s { + return true + } + } + return false +} + +// Serve provides a graceful equivalent of net/http.Server.Serve +func (srv *Server) Serve(l net.Listener) error { + // remember net.Listener + srv.mu.Lock() + srv.listeners[l] = struct{}{} + srv.mu.Unlock() + defer func() { + srv.mu.Lock() + delete(srv.listeners, l) + srv.mu.Unlock() + }() + + // replace ConnState + srv.connStateOnce.Do(func() { + srv.originalConnState = srv.Server.ConnState + srv.Server.ConnState = srv.connState + }) + + err := srv.Server.Serve(l) + + go func() { + // wait for closing keep-alive connection by sending `Connection: Close` header. + time.Sleep(srv.KillTimeOut) + + // time out, close all idle connections + srv.mu.Lock() + for conn := range srv.idlePool { + conn.Close() + } + srv.mu.Unlock() + }() + + // wait all connections have done + srv.wg.Wait() + + if atomic.LoadInt32(&srv.closed) != 0 { + // ignore closed network error when srv.Close() is called + return nil + } + return err +} + +// Close shuts down the default server used by ListenAndServe, ListenAndServeTLS and +// Serve. It returns true if it's the first time Close is called. +func (srv *Server) Close() bool { + if atomic.CompareAndSwapInt32(&srv.closed, 0, 1) { + srv.Server.SetKeepAlivesEnabled(false) + srv.mu.Lock() + listeners := srv.listeners + srv.listeners = map[net.Listener]struct{}{} + srv.mu.Unlock() + for l := range listeners { + l.Close() + } + return true + } + return false +} + +func (srv *Server) connState(conn net.Conn, newState http.ConnState) { + srv.mu.Lock() + switch newState { + case http.StateNew: + srv.wg.Add(1) + case http.StateActive: + delete(srv.idlePool, conn) + case http.StateIdle: + srv.idlePool[conn] = struct{}{} + case http.StateClosed, http.StateHijacked: + delete(srv.idlePool, conn) + srv.wg.Done() + } + srv.mu.Unlock() + if srv.originalConnState != nil { + srv.originalConnState(conn, newState) + } +} diff --git a/gracedown_fallback_test.go b/gracedown_fallback_test.go new file mode 100644 index 0000000..feafba4 --- /dev/null +++ b/gracedown_fallback_test.go @@ -0,0 +1,127 @@ +// +build !go1.8 + +package gracedown + +import ( + "net" + "net/http" + "testing" + "time" +) + +func TestShutdown_KeepAlive(t *testing.T) { + // prepare test server + handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + }) + ts := NewWithServer(&http.Server{ + Handler: handler, + }) + + // start server + l := newLocalListener() + go func() { + ts.Serve(l) + }() + url := "http://" + l.Addr().String() + + client := &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + Dial: (&net.Dialer{ + Timeout: 5 * time.Second, + KeepAlive: 5 * time.Second, + }).Dial, + TLSHandshakeTimeout: 10 * time.Second, + DisableKeepAlives: false, // keep-alives are ENABLE!! + MaxIdleConnsPerHost: 1, + }, + } + + // 1st request will be success + resp, err := client.Get(url) + if err != nil { + t.Errorf("unexpected error: %v", err) + } else { + resp.Body.Close() + } + + // start shutting down process + ts.Close() + + // 2nd request will be success, because this request uses the Keep-Alive connection + resp, err = client.Get(url) + if err != nil { + t.Errorf("unexpected error: %v", err) + } else { + resp.Body.Close() + } + + // 3rd request will be failure, because the Keep-Alive connection is closed + resp, err = client.Get(url) + if err == nil { + t.Error("want error, but not") + } +} + +func TestShutdown_KillKeepAlive(t *testing.T) { + // prepare test server + handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + }) + ts := NewWithServer(&http.Server{ + Handler: handler, + }) + ts.KillTimeOut = time.Second // force close after a second + + // start server + done := make(chan error, 1) + l := newLocalListener() + go func() { + done <- ts.Serve(l) + }() + url := "http://" + l.Addr().String() + + client := &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + Dial: (&net.Dialer{ + Timeout: 5 * time.Second, + KeepAlive: 5 * time.Second, + }).Dial, + TLSHandshakeTimeout: 10 * time.Second, + DisableKeepAlives: false, // keep-alives are ENABLE!! + MaxIdleConnsPerHost: 1, + }, + } + + // 1st request will be success + resp, err := client.Get(url) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + resp.Body.Close() + + // start shutting down process + start := time.Now() + ts.Close() + + select { + case err := <-done: + end := time.Now() + dt := end.Sub(start) + t.Logf("kill timeout: %v", dt) + if dt < ts.KillTimeOut { + t.Errorf("too fast kill timeout") + } + if err != nil { + t.Errorf("unexpected err: %v", err) + } + case <-time.After(ts.KillTimeOut + 5*time.Second): + t.Errorf("timeout") + } + + // 2nd request will be failure, because the server has already shut down + resp, err = client.Get(url) + if err == nil { + t.Error("want error, but not") + } +} diff --git a/gracedown_test.go b/gracedown_test.go index 98f814c..dbef3ab 100644 --- a/gracedown_test.go +++ b/gracedown_test.go @@ -120,118 +120,3 @@ func TestShutdown_NoKeepAlive(t *testing.T) { t.Errorf("timeout") } } - -func TestShutdown_KeepAlive(t *testing.T) { - // prepare test server - handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - }) - ts := NewWithServer(&http.Server{ - Handler: handler, - }) - - // start server - l := newLocalListener() - go func() { - ts.Serve(l) - }() - url := "http://" + l.Addr().String() - - client := &http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyFromEnvironment, - Dial: (&net.Dialer{ - Timeout: 5 * time.Second, - KeepAlive: 5 * time.Second, - }).Dial, - TLSHandshakeTimeout: 10 * time.Second, - DisableKeepAlives: false, // keep-alives are ENABLE!! - MaxIdleConnsPerHost: 1, - }, - } - - // 1st request will be success - resp, err := client.Get(url) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - resp.Body.Close() - - // start shutting down process - ts.Close() - - // 2nd request will be success, because this request uses the Keep-Alive connection - resp, err = client.Get(url) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - resp.Body.Close() - - // 3rd request will be failure, because the Keep-Alive connection is closed - resp, err = client.Get(url) - if err == nil { - t.Error("want error, but not") - } -} - -func TestShutdown_KillKeepAlive(t *testing.T) { - // prepare test server - handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - }) - ts := NewWithServer(&http.Server{ - Handler: handler, - }) - ts.KillTimeOut = time.Second // force close after a second - - // start server - done := make(chan error, 1) - l := newLocalListener() - go func() { - done <- ts.Serve(l) - }() - url := "http://" + l.Addr().String() - - client := &http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyFromEnvironment, - Dial: (&net.Dialer{ - Timeout: 5 * time.Second, - KeepAlive: 5 * time.Second, - }).Dial, - TLSHandshakeTimeout: 10 * time.Second, - DisableKeepAlives: false, // keep-alives are ENABLE!! - MaxIdleConnsPerHost: 1, - }, - } - - // 1st request will be success - resp, err := client.Get(url) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - resp.Body.Close() - - // start shutting down process - start := time.Now() - ts.Close() - - select { - case err := <-done: - end := time.Now() - dt := end.Sub(start) - t.Logf("kill timeout: %v", dt) - if dt < ts.KillTimeOut { - t.Errorf("too fast kill timeout") - } - if err != nil { - t.Errorf("unexpected err: %v", err) - } - case <-time.After(ts.KillTimeOut + 5*time.Second): - t.Errorf("timeout") - } - - // 2nd request will be failure, because the server has already shut down - resp, err = client.Get(url) - if err == nil { - t.Error("want error, but not") - } -}