diff --git a/client.go b/client.go index 84d8419..d360d2f 100644 --- a/client.go +++ b/client.go @@ -65,7 +65,7 @@ func (clt *Client) Signal(name string, payload Payload) error { // Returns an error if there's already another session active func (clt *Client) CreateSession(attachment interface{}) error { if !clt.srv.sessionsEnabled { - return fmt.Errorf("Sessions disabled") + return SessionsDisabled{} } clt.lock.Lock() @@ -131,7 +131,7 @@ func (clt *Client) notifySessionClosed() error { // Does nothing if there's no active session func (clt *Client) CloseSession() error { if !clt.srv.sessionsEnabled { - return fmt.Errorf("Sessions disabled") + return SessionsDisabled{} } if clt.Session == nil { return nil diff --git a/client/client.go b/client/client.go index 62a35eb..5eeec88 100644 --- a/client/client.go +++ b/client/client.go @@ -302,7 +302,7 @@ func (clt *Client) CloseSession() error { webwire.Payload{}, clt.defaultTimeout, ); err != nil { - return fmt.Errorf("Session destruction request failed: %s", err) + return err } } diff --git a/client/handle.go b/client/handle.go index c3eb830..5b74ed7 100644 --- a/client/handle.go +++ b/client/handle.go @@ -58,6 +58,10 @@ func (clt *Client) handleMaxSessConnsReached(reqIdent [8]byte) { clt.requestManager.Fail(reqIdent, webwire.MaxSessConnsReached{}) } +func (clt *Client) handleSessionsDisabled(reqIdent [8]byte) { + clt.requestManager.Fail(reqIdent, webwire.SessionsDisabled{}) +} + func (clt *Client) handleReply(reqID [8]byte, payload webwire.Payload) { clt.requestManager.Fulfill(reqID, payload) } @@ -97,6 +101,8 @@ func (clt *Client) handleMessage(message []byte) error { clt.handleSessionNotFound(extractMessageIdentifier(message)) case webwire.MsgMaxSessConnsReached: clt.handleMaxSessConnsReached(extractMessageIdentifier(message)) + case webwire.MsgSessionsDisabled: + clt.handleSessionsDisabled(extractMessageIdentifier(message)) case webwire.MsgErrorReply: clt.handleFailure(extractMessageIdentifier(message), message[9:]) case webwire.MsgReplyInternalError: diff --git a/errors.go b/errors.go index 6c8eb6a..0f8fd16 100644 --- a/errors.go +++ b/errors.go @@ -99,6 +99,13 @@ func (err ReqErr) Error() string { return err.Message } +// SessionsDisabled represents an error type indicating that the server has sessions disabled +type SessionsDisabled struct{} + +func (err SessionsDisabled) Error() string { + return "Sessions are disabled for this server" +} + // SessNotFound represents a session restoration error type indicating that the server didn't // find the session to be restored type SessNotFound struct{} diff --git a/message.go b/message.go index 8cf4deb..d35ba4b 100644 --- a/message.go +++ b/message.go @@ -71,6 +71,10 @@ const ( // when the maximum number of concurrent connections for a certain session was reached MsgMaxSessConnsReached = byte(4) + // MsgSessionsDisabled is sent by the server in response to a session restoration request + // if sessions are disabled for the target server + MsgSessionsDisabled = byte(5) + // MsgSessionCreated is sent by the server // to notify the client about the session creation MsgSessionCreated = byte(21) @@ -786,6 +790,8 @@ func (msg *Message) createFailCallback(client *Client, srv *Server) { msgType = MsgMaxSessConnsReached case SessNotFound: msgType = MsgSessionNotFound + case SessionsDisabled: + msgType = MsgSessionsDisabled default: msgType = MsgReplyInternalError } diff --git a/server.go b/server.go index 73e22ec..08199d8 100644 --- a/server.go +++ b/server.go @@ -104,6 +104,7 @@ func (hooks *Hooks) SetDefaults() { // ServerOptions represents the options used during the creation of a new WebWire server instance type ServerOptions struct { Hooks Hooks + SessionsEnabled bool MaxSessionConnections uint WarnLog io.Writer ErrorLog io.Writer @@ -147,11 +148,16 @@ type Server struct { func NewServer(opts ServerOptions) *Server { opts.SetDefaults() - sessionsEnabled := false - if opts.Hooks.OnSessionCreated != nil && - opts.Hooks.OnSessionLookup != nil && - opts.Hooks.OnSessionClosed != nil { - sessionsEnabled = true + if opts.SessionsEnabled { + if opts.Hooks.OnSessionCreated == nil { + panic("Expected OnSessionCreated hook to be defined because sessions are enabled") + } + if opts.Hooks.OnSessionLookup == nil { + panic("Expected OnSessionLookup hook to be defined because sessions are enabled") + } + if opts.Hooks.OnSessionClosed == nil { + panic("Expected OnSessionClosed hook to be defined because sessions are enabled") + } } srv := Server{ @@ -164,7 +170,7 @@ func NewServer(opts ServerOptions) *Server { opsLock: sync.Mutex{}, clients: make([]*Client, 0), clientsLock: &sync.Mutex{}, - sessionsEnabled: sessionsEnabled, + sessionsEnabled: opts.SessionsEnabled, SessionRegistry: newSessionRegistry(opts.MaxSessionConnections), // Internals @@ -194,11 +200,7 @@ func NewServer(opts ServerOptions) *Server { // and returns an error if the ongoing connection cannot be proceeded func (srv *Server) handleSessionRestore(msg *Message) error { if !srv.sessionsEnabled { - // TODO: Implement dedicated error message type for "sessions disabled" errors - msg.fail(ReqErr{ - Code: "SESSIONS_DISABLED", - Message: "Sessions are disabled on this server instance", - }) + msg.fail(SessionsDisabled{}) return nil } @@ -244,11 +246,7 @@ func (srv *Server) handleSessionRestore(msg *Message) error { // and returns an error if the ongoing connection cannot be proceeded func (srv *Server) handleSessionClosure(msg *Message) error { if !srv.sessionsEnabled { - // TODO: Implement dedicated error message type for "max connection reached" errors - msg.fail(ReqErr{ - Code: "SESSIONS_DISABLED", - Message: "Sessions are disabled on this server instance", - }) + msg.fail(SessionsDisabled{}) return nil } diff --git a/test/activeSessionRegistry_test.go b/test/activeSessionRegistry_test.go index 90f302f..3cabe2e 100644 --- a/test/activeSessionRegistry_test.go +++ b/test/activeSessionRegistry_test.go @@ -15,6 +15,7 @@ func TestActiveSessionRegistry(t *testing.T) { srv, addr := setupServer( t, webwire.ServerOptions{ + SessionsEnabled: true, Hooks: webwire.Hooks{ OnRequest: func(ctx context.Context) (webwire.Payload, error) { // Extract request message and requesting client from the context diff --git a/test/authentication_test.go b/test/authentication_test.go index 4d980e5..00a7716 100644 --- a/test/authentication_test.go +++ b/test/authentication_test.go @@ -64,6 +64,7 @@ func TestAuthentication(t *testing.T) { _, addr := setupServer( t, wwr.ServerOptions{ + SessionsEnabled: true, Hooks: wwr.Hooks{ OnSignal: func(ctx context.Context) { defer clientSignalReceived.Done() diff --git a/test/clientAutomaticSessionRestoration_test.go b/test/clientAutomaticSessionRestoration_test.go index 1d1d92d..1da4e16 100644 --- a/test/clientAutomaticSessionRestoration_test.go +++ b/test/clientAutomaticSessionRestoration_test.go @@ -21,6 +21,7 @@ func TestClientAutomaticSessionRestoration(t *testing.T) { _, addr := setupServer( t, webwire.ServerOptions{ + SessionsEnabled: true, Hooks: webwire.Hooks{ OnRequest: func(ctx context.Context) (webwire.Payload, error) { // Extract request message and requesting client from the context diff --git a/test/clientInitiatedSessionDestruction_test.go b/test/clientInitiatedSessionDestruction_test.go index a3dc048..b3e3c14 100644 --- a/test/clientInitiatedSessionDestruction_test.go +++ b/test/clientInitiatedSessionDestruction_test.go @@ -29,6 +29,7 @@ func TestClientInitiatedSessionDestruction(t *testing.T) { _, addr := setupServer( t, webwire.ServerOptions{ + SessionsEnabled: true, Hooks: webwire.Hooks{ OnRequest: func(ctx context.Context) (webwire.Payload, error) { // Extract request message and requesting client from the context diff --git a/test/clientOfflineSessionClosure_test.go b/test/clientOfflineSessionClosure_test.go index 2a6f285..7d3e99c 100644 --- a/test/clientOfflineSessionClosure_test.go +++ b/test/clientOfflineSessionClosure_test.go @@ -21,6 +21,7 @@ func TestClientOfflineSessionClosure(t *testing.T) { _, addr := setupServer( t, webwire.ServerOptions{ + SessionsEnabled: true, Hooks: webwire.Hooks{ OnRequest: func(ctx context.Context) (webwire.Payload, error) { // Extract request message and requesting client from the context diff --git a/test/clientOnSessionClosed_test.go b/test/clientOnSessionClosed_test.go index 3c1a74e..6e00e89 100644 --- a/test/clientOnSessionClosed_test.go +++ b/test/clientOnSessionClosed_test.go @@ -18,6 +18,7 @@ func TestClientOnSessionClosed(t *testing.T) { _, addr := setupServer( t, webwire.ServerOptions{ + SessionsEnabled: true, Hooks: webwire.Hooks{ OnRequest: func(ctx context.Context) (webwire.Payload, error) { // Extract request message and requesting client from the context diff --git a/test/clientOnSessionCreated_test.go b/test/clientOnSessionCreated_test.go index 05f5b09..df02604 100644 --- a/test/clientOnSessionCreated_test.go +++ b/test/clientOnSessionCreated_test.go @@ -19,6 +19,7 @@ func TestClientOnSessionCreated(t *testing.T) { _, addr := setupServer( t, webwire.ServerOptions{ + SessionsEnabled: true, Hooks: webwire.Hooks{ OnRequest: func(ctx context.Context) (webwire.Payload, error) { // Extract request message and requesting client from the context diff --git a/test/clientSessionInfo_test.go b/test/clientSessionInfo_test.go index 59ed1e3..6b2e861 100644 --- a/test/clientSessionInfo_test.go +++ b/test/clientSessionInfo_test.go @@ -27,6 +27,7 @@ func TestClientSessionInfo(t *testing.T) { _, addr := setupServer( t, webwire.ServerOptions{ + SessionsEnabled: true, Hooks: webwire.Hooks{ OnRequest: func(ctx context.Context) (webwire.Payload, error) { // Extract request message and requesting client from the context diff --git a/test/clientSessionRestoration_test.go b/test/clientSessionRestoration_test.go index 15c750c..3932eac 100644 --- a/test/clientSessionRestoration_test.go +++ b/test/clientSessionRestoration_test.go @@ -21,6 +21,7 @@ func TestClientSessionRestoration(t *testing.T) { _, addr := setupServer( t, webwire.ServerOptions{ + SessionsEnabled: true, Hooks: webwire.Hooks{ OnRequest: func(ctx context.Context) (webwire.Payload, error) { // Extract request message and requesting client from the context diff --git a/test/disabledSessions_test.go b/test/disabledSessions_test.go new file mode 100644 index 0000000..e81b9c1 --- /dev/null +++ b/test/disabledSessions_test.go @@ -0,0 +1,72 @@ +package test + +import ( + "context" + "reflect" + "testing" + "time" + + wwr "github.com/qbeon/webwire-go" + wwrclt "github.com/qbeon/webwire-go/client" +) + +// TestDisabledSessions verifies the server is connectable, +// and is able to receives requests and signals, create sessions +// and identify clients during request- and signal handling +func TestDisabledSessions(t *testing.T) { + verifyError := func(err error) { + if _, isDisabledErr := err.(wwr.SessionsDisabled); !isDisabledErr { + t.Errorf( + "Expected SessionsDisabled error, got: %s | %s", + reflect.TypeOf(err), + err, + ) + } + } + + // Initialize webwire server + _, addr := setupServer( + t, + wwr.ServerOptions{ + SessionsEnabled: false, + Hooks: wwr.Hooks{ + OnRequest: func(ctx context.Context) (wwr.Payload, error) { + // Extract request message and requesting client from the context + msg := ctx.Value(wwr.Msg).(wwr.Message) + + // Try to create a new session and expect an error + createErr := msg.Client.CreateSession(nil) + verifyError(createErr) + + // Try to create a new session and expect an error + closeErr := msg.Client.CloseSession() + verifyError(closeErr) + + return wwr.Payload{}, nil + }, + }, + }, + ) + + // Initialize client + client := wwrclt.NewClient( + addr, + wwrclt.Options{ + DefaultRequestTimeout: 2 * time.Second, + }, + ) + defer client.Close() + + if err := client.Connect(); err != nil { + t.Fatalf("Couldn't connect: %s", err) + } + + // Send authentication request and await reply + _, err := client.Request("login", wwr.Payload{Data: []byte("testdata")}) + if err != nil { + t.Fatalf("Request failed: %s", err) + } + + sessRestErr := client.RestoreSession([]byte("testkey")) + verifyError(sessRestErr) +} diff --git a/test/maxConcSessConn_test.go b/test/maxConcSessConn_test.go index ab0c92c..e5e8994 100644 --- a/test/maxConcSessConn_test.go +++ b/test/maxConcSessConn_test.go @@ -22,6 +22,7 @@ func TestMaxConcSessConn(t *testing.T) { _, addr := setupServer( t, wwr.ServerOptions{ + SessionsEnabled: true, MaxSessionConnections: concurrentConns, Hooks: wwr.Hooks{ OnClientConnected: func(client *wwr.Client) { diff --git a/test/restoreInexistentSession_test.go b/test/restoreInexistentSession_test.go index f427bd6..da0442e 100644 --- a/test/restoreInexistentSession_test.go +++ b/test/restoreInexistentSession_test.go @@ -15,6 +15,7 @@ func TestRestoreInexistentSession(t *testing.T) { _, addr := setupServer( t, wwr.ServerOptions{ + SessionsEnabled: true, Hooks: wwr.Hooks{ // Permanently store the session OnSessionCreated: func(_ *wwr.Client) error { diff --git a/test/serverInitiatedSessionDestruction_test.go b/test/serverInitiatedSessionDestruction_test.go index d9a3cc5..7acffa6 100644 --- a/test/serverInitiatedSessionDestruction_test.go +++ b/test/serverInitiatedSessionDestruction_test.go @@ -28,6 +28,7 @@ func TestServerInitiatedSessionDestruction(t *testing.T) { _, addr := setupServer( t, webwire.ServerOptions{ + SessionsEnabled: true, Hooks: webwire.Hooks{ OnRequest: func(ctx context.Context) (webwire.Payload, error) { // Extract request message and requesting client from the context