From aad55e73a0024ac20891b9c4ff6f4dfd5a45f6a4 Mon Sep 17 00:00:00 2001 From: Roman Sharkov Date: Wed, 21 Mar 2018 15:19:08 +0100 Subject: [PATCH] Implement the autoconnect feature Autoconnect is an optional feature, enabled by default, that makes client.Request, client.TimedRequest and client.RestoreSession automatically try to establish a connection if there's none until the timeout is reached. Also make client.Signal automatically try to connect without retrials directly returning a DisconnectedErr if the connection couldn't be established at the first trial. Add a new server-side hook: BeforeUpgrade to be able to intercept connection attempts and either prevent, delay or monitor them. This hook is rather useful for testing. --- client.go | 4 +- client/client.go | 124 +++++------------- client/connect.go | 87 ++++++++++++ client/errors.go | 8 -- client/options.go | 43 +++++- client/requestSessionRestoration.go | 2 +- client/sendNamelessRequest.go | 5 - client/sendRequest.go | 5 - client/tryAutoconnect.go | 56 ++++++++ client/verifyProtocolVersion.go | 14 +- errors.go | 44 ++++--- server.go | 15 +++ test/clientRestoreSessionDisconnected_test.go | 18 ++- test/clientSignalDisconnected_test.go | 7 +- test/cllientRequestDisconnected_test.go | 16 ++- 15 files changed, 297 insertions(+), 151 deletions(-) create mode 100644 client/connect.go delete mode 100644 client/errors.go create mode 100644 client/tryAutoconnect.go diff --git a/client.go b/client.go index f240ab9..62983c6 100644 --- a/client.go +++ b/client.go @@ -51,7 +51,7 @@ func (clt *Client) write(data []byte) error { defer clt.connLock.Unlock() if !clt.connected { return DisconnectedErr{ - msg: "Can't write to a disconnected client agent", + cause: fmt.Errorf("Can't write to a disconnected client agent"), } } return clt.conn.WriteMessage(websocket.BinaryMessage, data) @@ -115,7 +115,7 @@ func (clt *Client) CreateSession(attachment interface{}) error { if !clt.connected { clt.connLock.RUnlock() return DisconnectedErr{ - msg: "Can't create session on disconnected client agent", + cause: fmt.Errorf("Can't create session on disconnected client agent"), } } clt.connLock.RUnlock() diff --git a/client/client.go b/client/client.go index 38537a1..bf850dd 100644 --- a/client/client.go +++ b/client/client.go @@ -8,7 +8,6 @@ import ( "fmt" "log" - "net/url" "sync" "time" @@ -19,10 +18,12 @@ const supportedProtocolVersion = "1.2" // Client represents an instance of one of the servers clients type Client struct { - serverAddr string - isConnected int32 - defaultTimeout time.Duration - hooks Hooks + serverAddr string + isConnected int32 + defaultReqTimeout time.Duration + reconnInterval time.Duration + autoconnect bool + hooks Hooks sessionLock sync.RWMutex session *webwire.Session @@ -32,9 +33,10 @@ type Client struct { // because performing multiple requests and/or signals simultaneously is fine. // The Connect, RestoreSession, CloseSession and Close methods are locked exclusively // because they should temporarily block any other interaction with this client instance. - apiLock sync.RWMutex - connLock sync.Mutex - conn *websocket.Conn + apiLock sync.RWMutex + connectLock sync.Mutex + connLock sync.Mutex + conn *websocket.Conn requestManager reqman.RequestManager @@ -47,10 +49,17 @@ type Client struct { func NewClient(serverAddress string, opts Options) *Client { opts.SetDefaults() + autoconnect := true + if opts.Autoconnect == OptDisabled { + autoconnect = false + } + return &Client{ serverAddress, 0, opts.DefaultRequestTimeout, + opts.ReconnectionInterval, + autoconnect, opts.Hooks, sync.RWMutex{}, @@ -58,6 +67,7 @@ func NewClient(serverAddress string, opts Options) *Client { sync.RWMutex{}, sync.Mutex{}, + sync.Mutex{}, nil, reqman.NewRequestManager(), @@ -83,82 +93,8 @@ func (clt *Client) IsConnected() bool { // Connect connects the client to the configured server and // returns an error in case of a connection failure. // Automatically tries to restore the previous session -func (clt *Client) Connect() (err error) { - clt.apiLock.Lock() - defer clt.apiLock.Unlock() - - if atomic.LoadInt32(&clt.isConnected) > 0 { - return nil - } - - if err := clt.verifyProtocolVersion(); err != nil { - return err - } - - connURL := url.URL{Scheme: "ws", Host: clt.serverAddr, Path: "/"} - - clt.connLock.Lock() - clt.conn, _, err = websocket.DefaultDialer.Dial(connURL.String(), nil) - if err != nil { - return webwire.NewConnDialErr(err) - } - clt.connLock.Unlock() - - // Setup reader thread - go func() { - defer clt.close() - for { - _, message, err := clt.conn.ReadMessage() - if err != nil { - if websocket.IsUnexpectedCloseError( - err, - websocket.CloseGoingAway, - websocket.CloseAbnormalClosure, - ) { - // Error while reading message - clt.errorLog.Print("Failed reading message:", err) - break - } else { - // Shutdown client due to clean disconnection - break - } - } - // Try to handle the message - if err = clt.handleMessage(message); err != nil { - clt.warningLog.Print("Failed handling message:", err) - } - } - }() - - atomic.StoreInt32(&clt.isConnected, 1) - - // Read the current sessions key if there is any - clt.sessionLock.RLock() - if clt.session == nil { - clt.sessionLock.RUnlock() - return nil - } - sessionKey := clt.session.Key - clt.sessionLock.RUnlock() - - // Try to restore session if necessary - restoredSession, err := clt.requestSessionRestoration([]byte(sessionKey)) - if err != nil { - // Just log a warning and still return nil, even if session restoration failed, - // because we only care about the connection establishment in this method - clt.warningLog.Printf("Couldn't restore session on reconnection: %s", err) - - // Reset the session - clt.sessionLock.Lock() - clt.session = nil - clt.sessionLock.Unlock() - return nil - } - - clt.sessionLock.Lock() - clt.session = restoredSession - clt.sessionLock.Unlock() - return nil +func (clt *Client) Connect() error { + return clt.connect() } // Request sends a request containing the given payload to the server @@ -172,6 +108,10 @@ func (clt *Client) Request( clt.apiLock.RLock() defer clt.apiLock.RUnlock() + if err := clt.tryAutoconnect(clt.defaultReqTimeout); err != nil { + return webwire.Payload{}, err + } + reqType := webwire.MsgRequestBinary switch payload.Encoding { case webwire.EncodingUtf8: @@ -179,7 +119,7 @@ func (clt *Client) Request( case webwire.EncodingUtf16: reqType = webwire.MsgRequestUtf16 } - return clt.sendRequest(reqType, name, payload, clt.defaultTimeout) + return clt.sendRequest(reqType, name, payload, clt.defaultReqTimeout) } // TimedRequest sends a request containing the given payload to the server @@ -195,6 +135,10 @@ func (clt *Client) TimedRequest( clt.apiLock.RLock() defer clt.apiLock.RUnlock() + if err := clt.tryAutoconnect(timeout); err != nil { + return webwire.Payload{}, err + } + reqType := webwire.MsgRequestBinary switch payload.Encoding { case webwire.EncodingUtf8: @@ -210,8 +154,8 @@ func (clt *Client) Signal(name string, payload webwire.Payload) error { clt.apiLock.RLock() defer clt.apiLock.RUnlock() - if atomic.LoadInt32(&clt.isConnected) < 1 { - return DisconnectedErr{} + if err := clt.connect(); err != nil { + return err } msgBytes := webwire.NewSignalMessage(name, payload) @@ -269,6 +213,10 @@ func (clt *Client) RestoreSession(sessionKey []byte) error { } clt.sessionLock.RUnlock() + if err := clt.tryAutoconnect(clt.defaultReqTimeout); err != nil { + return err + } + restoredSession, err := clt.requestSessionRestoration(sessionKey) if err != nil { return err @@ -300,7 +248,7 @@ func (clt *Client) CloseSession() error { if _, err := clt.sendNamelessRequest( webwire.MsgCloseSession, webwire.Payload{}, - clt.defaultTimeout, + clt.defaultReqTimeout, ); err != nil { return err } diff --git a/client/connect.go b/client/connect.go new file mode 100644 index 0000000..ba97db8 --- /dev/null +++ b/client/connect.go @@ -0,0 +1,87 @@ +package client + +import ( + "fmt" + "net/url" + "sync/atomic" + + "github.com/gorilla/websocket" + webwire "github.com/qbeon/webwire-go" +) + +func (clt *Client) connect() (err error) { + clt.connectLock.Lock() + defer clt.connectLock.Unlock() + if atomic.LoadInt32(&clt.isConnected) > 0 { + return nil + } + + if err := clt.verifyProtocolVersion(); err != nil { + return err + } + + connURL := url.URL{Scheme: "ws", Host: clt.serverAddr, Path: "/"} + + clt.connLock.Lock() + clt.conn, _, err = websocket.DefaultDialer.Dial(connURL.String(), nil) + if err != nil { + return webwire.NewDisconnectedErr(fmt.Errorf("Dial failure: %s", err)) + } + clt.connLock.Unlock() + + // Setup reader thread + go func() { + defer clt.close() + for { + _, message, err := clt.conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError( + err, + websocket.CloseGoingAway, + websocket.CloseAbnormalClosure, + ) { + // Error while reading message + clt.errorLog.Print("Failed reading message:", err) + break + } else { + // Shutdown client due to clean disconnection + break + } + } + // Try to handle the message + if err = clt.handleMessage(message); err != nil { + clt.warningLog.Print("Failed handling message:", err) + } + } + }() + + atomic.StoreInt32(&clt.isConnected, 1) + + // Read the current sessions key if there is any + clt.sessionLock.RLock() + if clt.session == nil { + clt.sessionLock.RUnlock() + return nil + } + sessionKey := clt.session.Key + clt.sessionLock.RUnlock() + + // Try to restore session if necessary + restoredSession, err := clt.requestSessionRestoration([]byte(sessionKey)) + if err != nil { + // Just log a warning and still return nil, even if session restoration failed, + // because we only care about the connection establishment in this method + clt.warningLog.Printf("Couldn't restore session on reconnection: %s", err) + + // Reset the session + clt.sessionLock.Lock() + clt.session = nil + clt.sessionLock.Unlock() + return nil + } + + clt.sessionLock.Lock() + clt.session = restoredSession + clt.sessionLock.Unlock() + return nil +} diff --git a/client/errors.go b/client/errors.go deleted file mode 100644 index d4c3fc6..0000000 --- a/client/errors.go +++ /dev/null @@ -1,8 +0,0 @@ -package client - -// DisconnectedErr is an error type indicating that the client isn't connected to the server -type DisconnectedErr struct{} - -func (err DisconnectedErr) Error() string { - return "Client is disconnected" -} diff --git a/client/options.go b/client/options.go index 910aff1..8499f3b 100644 --- a/client/options.go +++ b/client/options.go @@ -6,22 +6,59 @@ import ( "time" ) +// OptionToggle represents the value of a togglable option +type OptionToggle int + +const ( + // OptUnset defines unset togglable options + OptUnset OptionToggle = iota + + // OptDisabled defines disabled togglable options + OptDisabled + + // OptEnabled defines enabled togglable options + OptEnabled +) + // Options represents the options used during the creation a new client instance type Options struct { - Hooks Hooks + // Hooks define the callback hook functions provided by the user to define behavior + // on certain events + Hooks Hooks + + // DefaultRequestTimeout defines the default request timeout duration used in client.Request DefaultRequestTimeout time.Duration - WarnLog io.Writer - ErrorLog io.Writer + + // ReconnectionInterval defines the interval at which autoconnect should poll for a connection. + // If undefined then the default value of 2 seconds is applied + ReconnectionInterval time.Duration + + // If autoconnect is enabled, client.Request, client.TimedRequest and client.RestoreSession + // won't immediately return a disconnected error if there's no active connection to the server, + // instead they will automatically try to reestablish the connection + // before the timeout is triggered and a timeout error is returned. + // Autoconnect is enabled by default + Autoconnect OptionToggle + WarnLog io.Writer + ErrorLog io.Writer } // SetDefaults sets default values for undefined required options func (opts *Options) SetDefaults() { opts.Hooks.SetDefaults() + if opts.Autoconnect == OptUnset { + opts.Autoconnect = OptEnabled + } + if opts.DefaultRequestTimeout < 1 { opts.DefaultRequestTimeout = 60 * time.Second } + if opts.ReconnectionInterval < 1 { + opts.ReconnectionInterval = 2 * time.Second + } + if opts.WarnLog == nil { opts.WarnLog = os.Stdout } diff --git a/client/requestSessionRestoration.go b/client/requestSessionRestoration.go index eaec2b8..efb18ca 100644 --- a/client/requestSessionRestoration.go +++ b/client/requestSessionRestoration.go @@ -17,7 +17,7 @@ func (clt *Client) requestSessionRestoration(sessionKey []byte) (*webwire.Sessio Encoding: webwire.EncodingBinary, Data: sessionKey, }, - clt.defaultTimeout, + clt.defaultReqTimeout, ) if err != nil { return nil, err diff --git a/client/sendNamelessRequest.go b/client/sendNamelessRequest.go index a23476c..6e9820b 100644 --- a/client/sendNamelessRequest.go +++ b/client/sendNamelessRequest.go @@ -1,7 +1,6 @@ package client import ( - "sync/atomic" "time" "github.com/gorilla/websocket" @@ -13,10 +12,6 @@ func (clt *Client) sendNamelessRequest( payload webwire.Payload, timeout time.Duration, ) (webwire.Payload, error) { - if atomic.LoadInt32(&clt.isConnected) < 1 { - return webwire.Payload{}, DisconnectedErr{} - } - request := clt.requestManager.Create(timeout) reqIdentifier := request.Identifier() diff --git a/client/sendRequest.go b/client/sendRequest.go index 7b348f5..0ed4eb1 100644 --- a/client/sendRequest.go +++ b/client/sendRequest.go @@ -1,7 +1,6 @@ package client import ( - "sync/atomic" "time" "github.com/gorilla/websocket" @@ -14,10 +13,6 @@ func (clt *Client) sendRequest( payload webwire.Payload, timeout time.Duration, ) (webwire.Payload, error) { - if atomic.LoadInt32(&clt.isConnected) < 1 { - return webwire.Payload{}, DisconnectedErr{} - } - request := clt.requestManager.Create(timeout) reqIdentifier := request.Identifier() diff --git a/client/tryAutoconnect.go b/client/tryAutoconnect.go new file mode 100644 index 0000000..6b7580d --- /dev/null +++ b/client/tryAutoconnect.go @@ -0,0 +1,56 @@ +package client + +import ( + "sync/atomic" + "time" + + webwire "github.com/qbeon/webwire-go" +) + +func (clt *Client) tryAutoconnect(timeout time.Duration) error { + if atomic.LoadInt32(&clt.isConnected) > 0 { + return nil + } + + if clt.autoconnect { + stopTrying := make(chan error, 1) + connected := make(chan error, 1) + go func() { + for { + select { + case <-stopTrying: + return + default: + } + + err := clt.connect() + switch err := err.(type) { + case nil: + close(connected) + return + case webwire.DisconnectedErr: + time.Sleep(clt.reconnInterval) + default: + // Unexpected error + connected <- err + return + } + } + }() + + // TODO: implement autoconnect + select { + case err := <-connected: + return err + case <-time.After(timeout): + // Stop reconnection trial loop and return timeout error + close(stopTrying) + return webwire.ReqTimeoutErr{} + } + } else { + if err := clt.connect(); err != nil { + return err + } + } + return nil +} diff --git a/client/verifyProtocolVersion.go b/client/verifyProtocolVersion.go index 344c63a..60a7549 100644 --- a/client/verifyProtocolVersion.go +++ b/client/verifyProtocolVersion.go @@ -22,22 +22,24 @@ func (clt *Client) verifyProtocolVersion() error { "WEBWIRE", "http://"+clt.serverAddr+"/", nil, ) if err != nil { - return fmt.Errorf("Couldn't create HTTP metadata request: %s", err) + panic(fmt.Errorf("Couldn't create HTTP metadata request: %s", err)) } response, err := httpClient.Do(request) if err != nil { - return fmt.Errorf("Endpoint metadata request failed: %s", err) + return webwire.NewDisconnectedErr(fmt.Errorf( + "Endpoint metadata request failed: %s", err, + )) } // Read response body defer response.Body.Close() encodedData, err := ioutil.ReadAll(response.Body) if err != nil { - return fmt.Errorf("Couldn't read metadata response body: %s", err) + return webwire.NewProtocolErr(fmt.Errorf("Couldn't read metadata response body: %s", err)) } if response.StatusCode == http.StatusServiceUnavailable { - return fmt.Errorf("Endpoint unavailable: %s", response.Status) + return webwire.NewDisconnectedErr(fmt.Errorf("Endpoint unavailable: %s", response.Status)) } // Unmarshal response @@ -45,11 +47,11 @@ func (clt *Client) verifyProtocolVersion() error { ProtocolVersion string `json:"protocol-version"` } if err := json.Unmarshal(encodedData, &metadata); err != nil { - return fmt.Errorf( + return webwire.NewProtocolErr(fmt.Errorf( "Couldn't parse HTTP metadata response ('%s'): %s", string(encodedData), err, - ) + )) } // Verify metadata diff --git a/errors.go b/errors.go index c7cc12f..a9b4ca2 100644 --- a/errors.go +++ b/errors.go @@ -29,23 +29,6 @@ func NewConnIncompErr(requiredVersion, supportedVersion string) ConnIncompErr { } } -// ConnDialErr represents a connection error type indicating that the dialing failed. -type ConnDialErr struct { - msg string -} - -func (err ConnDialErr) Error() string { - return err.msg -} - -// NewConnDialErr constructs and returns a new connection dial error -// based on the actual error message -func NewConnDialErr(err error) ConnDialErr { - return ConnDialErr{ - msg: err.Error(), - } -} - // ReqTransErr represents a connection error type indicating that the dialing failed. type ReqTransErr struct { msg string @@ -124,9 +107,32 @@ func (err MaxSessConnsReachedErr) Error() string { // DisconnectedErr represents an error type indicating that the targeted client is disconnected type DisconnectedErr struct { - msg string + cause error +} + +// NewDisconnectedErr constructs a new DisconnectedErr error based on the actual error +func NewDisconnectedErr(err error) DisconnectedErr { + return DisconnectedErr{ + cause: err, + } } func (err DisconnectedErr) Error() string { - return err.msg + return err.cause.Error() +} + +// ProtocolErr represents an error type indicating an error in the protocol implementation +type ProtocolErr struct { + cause error +} + +// NewProtocolErr constructs a new ProtocolErr error based on the actual error +func NewProtocolErr(err error) ProtocolErr { + return ProtocolErr{ + cause: err, + } +} + +func (err ProtocolErr) Error() string { + return err.cause.Error() } diff --git a/server.go b/server.go index e0cff23..b81f02d 100644 --- a/server.go +++ b/server.go @@ -22,6 +22,11 @@ type Hooks struct { // using the HTTP OPTION method. OnOptions func(resp http.ResponseWriter) + // BeforeUpgrade is an optional hook. + // It's invoked right before the upgrade of the HTTP connection to a WebSocket connection + // and can be used to intercept, prevent or monitor connection attempts + BeforeUpgrade func(resp http.ResponseWriter, req *http.Request) bool + // OnClientConnected is an optional hook. // It's invoked when a new client establishes a connection to the server OnClientConnected func(client *Client) @@ -69,6 +74,12 @@ type Hooks struct { // SetDefaults sets undefined required hooks func (hooks *Hooks) SetDefaults() { + if hooks.BeforeUpgrade == nil { + hooks.BeforeUpgrade = func(_ http.ResponseWriter, _ *http.Request) bool { + return true + } + } + if hooks.OnClientConnected == nil { hooks.OnClientConnected = func(_ *Client) {} } @@ -395,6 +406,10 @@ func (srv *Server) ServeHTTP( return } + if !srv.hooks.BeforeUpgrade(resp, req) { + return + } + // Establish connection conn, err := srv.upgrader.Upgrade(resp, req, nil) if err != nil { diff --git a/test/clientRestoreSessionDisconnected_test.go b/test/clientRestoreSessionDisconnected_test.go index 209b1a0..b7b8f65 100644 --- a/test/clientRestoreSessionDisconnected_test.go +++ b/test/clientRestoreSessionDisconnected_test.go @@ -10,23 +10,31 @@ import ( ) // TestClientRestoreSessionDisconnected tests manual session restoration on disconnected client +// and expects client.RestoreSession to automatically establish a connection func TestClientRestoreSessionDisconnected(t *testing.T) { // Initialize webwire server _, addr := setupServer( t, - wwr.ServerOptions{}, + wwr.ServerOptions{ + SessionsEnabled: true, + Hooks: wwr.Hooks{ + OnSessionCreated: func(_ *wwr.Client) error { return nil }, + OnSessionLookup: func(_ string) (*wwr.Session, error) { return nil, nil }, + OnSessionClosed: func(_ *wwr.Client) error { return nil }, + }, + }, ) - // Initialize client + // Initialize client and skip manual connection establishment client := wwrclt.NewClient( addr, wwrclt.Options{ - DefaultRequestTimeout: 2 * time.Second, + DefaultRequestTimeout: 100 * time.Millisecond, }, ) - err := client.RestoreSession([]byte("somekey")) - if _, isDisconnErr := err.(wwrclt.DisconnectedErr); !isDisconnErr { + err := client.RestoreSession([]byte("inexistentkey")) + if _, isSessNotFoundErr := err.(wwr.SessNotFoundErr); !isSessNotFoundErr { t.Fatalf( "Expected disconnected error, got: %s | %s", reflect.TypeOf(err), diff --git a/test/clientSignalDisconnected_test.go b/test/clientSignalDisconnected_test.go index 404bde4..aab72e3 100644 --- a/test/clientSignalDisconnected_test.go +++ b/test/clientSignalDisconnected_test.go @@ -17,7 +17,7 @@ func TestClientSignalDisconnected(t *testing.T) { wwr.ServerOptions{}, ) - // Initialize client + // Initialize client and skip manual connection establishment client := wwrclt.NewClient( addr, wwrclt.Options{ @@ -26,10 +26,9 @@ func TestClientSignalDisconnected(t *testing.T) { ) // Send request and await reply - err := client.Signal("", wwr.Payload{Data: []byte("testdata")}) - if _, isDisconnErr := err.(wwrclt.DisconnectedErr); !isDisconnErr { + if err := client.Signal("", wwr.Payload{Data: []byte("testdata")}); err != nil { t.Fatalf( - "Expected disconnected error, got: %s | %s", + "Expected signal to automatically connect, got error: %s | %s", reflect.TypeOf(err), err, ) diff --git a/test/cllientRequestDisconnected_test.go b/test/cllientRequestDisconnected_test.go index 89c7854..3f699eb 100644 --- a/test/cllientRequestDisconnected_test.go +++ b/test/cllientRequestDisconnected_test.go @@ -1,6 +1,7 @@ package test import ( + "context" "reflect" "testing" "time" @@ -14,10 +15,16 @@ func TestClientRequestDisconnected(t *testing.T) { // Initialize webwire server given only the request _, addr := setupServer( t, - wwr.ServerOptions{}, + wwr.ServerOptions{ + Hooks: wwr.Hooks{ + OnRequest: func(_ context.Context) (wwr.Payload, error) { + return wwr.Payload{}, nil + }, + }, + }, ) - // Initialize client + // Initialize client and skip manual connection establishment client := wwrclt.NewClient( addr, wwrclt.Options{ @@ -26,10 +33,9 @@ func TestClientRequestDisconnected(t *testing.T) { ) // Send request and await reply - _, err := client.Request("", wwr.Payload{Data: []byte("testdata")}) - if _, isDisconnErr := err.(wwrclt.DisconnectedErr); !isDisconnErr { + if _, err := client.Request("", wwr.Payload{Data: []byte("testdata")}); err != nil { t.Fatalf( - "Expected disconnected error, got: %s | %s", + "Expected request to automatically connect, got error: %s | %s", reflect.TypeOf(err), err, )