Skip to content

Commit

Permalink
Merge e9b8794 into a92378f
Browse files Browse the repository at this point in the history
  • Loading branch information
jirenius committed Mar 1, 2021
2 parents a92378f + e9b8794 commit 8d84f23
Show file tree
Hide file tree
Showing 15 changed files with 243 additions and 11 deletions.
35 changes: 26 additions & 9 deletions nats/nats.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,21 @@ func (c *Client) IsClosed() bool {

// Close closes the client connection.
func (c *Client) Close() {
stopped := c.close()
if stopped == nil {
return
}

<-stopped
c.Debugf("NATS listener stopped")
}

func (c *Client) close() chan struct{} {
c.mu.Lock()
defer c.mu.Unlock()

if c.mq == nil {
c.mu.Unlock()
return
return nil
}

if !c.mq.IsClosed() {
Expand All @@ -134,10 +144,7 @@ func (c *Client) Close() {
stopped := c.stopped
c.stopped = nil

c.mu.Unlock()

<-stopped
c.Debugf("NATS listener stopped")
return stopped
}

// SetClosedHandler sets the handler when the connection is closed
Expand All @@ -156,21 +163,26 @@ func (c *Client) onClose(conn *nats.Conn) {
func (c *Client) SendRequest(subj string, payload []byte, cb mq.Response) {
inbox := nats.NewInbox()

// Validate max control line size
if len(subj)+len(inbox) > nats.MAX_CONTROL_LINE_SIZE {
go cb("", nil, mq.ErrSubjectTooLong)
return
}

c.mu.Lock()
defer c.mu.Unlock()

sub, err := c.mq.ChanSubscribe(inbox, c.mqCh)
if err != nil {
cb("", nil, err)
go cb("", nil, err)
return
}

c.Tracef("<== (%s) %s: %s", inboxSubstr(inbox), subj, payload)

err = c.mq.PublishRequest(subj, inbox, payload)
if err != nil {
sub.Unsubscribe()
cb("", nil, err)
go cb("", nil, err)
return
}

Expand All @@ -181,6 +193,11 @@ func (c *Client) SendRequest(subj string, payload []byte, cb mq.Response) {
// Subscribe to all events on a resource namespace.
// The namespace has the format "event."+resource
func (c *Client) Subscribe(namespace string, cb mq.Response) (mq.Unsubscriber, error) {
// Validate max control line size
if len(namespace) > nats.MAX_CONTROL_LINE_SIZE-2 {
return nil, mq.ErrSubjectTooLong
}

c.mu.Lock()
defer c.mu.Unlock()

Expand Down
2 changes: 2 additions & 0 deletions server/apiHandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,8 @@ func httpError(w http.ResponseWriter, err error, enc APIEncoder) {
code = http.StatusServiceUnavailable
case reserr.CodeForbidden:
code = http.StatusForbidden
case reserr.CodeSubjectTooLong:
code = http.StatusRequestURITooLong
default:
code = http.StatusBadRequest
}
Expand Down
3 changes: 3 additions & 0 deletions server/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ const (
// WSTimeout is the wait time for WebSocket connections to close on shutdown.
WSTimeout = 3 * time.Second

// MQTimeout is the wait time for the messaging client to close on shutdown.
MQTimeout = 3 * time.Second

// WSConnWorkerQueueSize is the size of the queue for each connection worker.
WSConnWorkerQueueSize = 256

Expand Down
6 changes: 5 additions & 1 deletion server/mq/mq.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type Client interface {
Connect() error

// SendRequest sends an asynchronous request on a subject, expecting the Response
// callback to be called once.
// callback to be called once on a separate go routine.
SendRequest(subject string, payload []byte, cb Response)

// Subscribe to all events on a resource namespace.
Expand All @@ -37,3 +37,7 @@ type Client interface {
// ErrRequestTimeout is the error the client should pass to the Response
// when a call to SendRequest times out
var ErrRequestTimeout = reserr.ErrTimeout

// ErrSubjectTooLong is the error the client should pass to the Response when
// the subject exceeds the maximum control line size
var ErrSubjectTooLong = reserr.ErrSubjectTooLong
17 changes: 16 additions & 1 deletion server/mqClient.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package server

import (
"time"

"github.com/resgateio/resgate/server/rescache"
)

Expand All @@ -25,7 +27,20 @@ func (s *Service) startMQClient() error {

// stopMQClient closes the connection to the nats server
func (s *Service) stopMQClient() {
s.mq.Close()
s.Debugf("Closing messaging client...")
done := make(chan struct{})
go func() {
defer close(done)
s.mq.Close()
}()

select {
case <-done:
s.Debugf("Messaging client gracefully closed")
case <-time.After(MQTimeout):
s.Errorf("Closing messaging client timed out. Continuing shutdown.")
}

s.Debugf("Stopping cache workers...")
s.cache.Stop()
s.Debugf("Cache workers stopped")
Expand Down
2 changes: 2 additions & 0 deletions server/reserr/reserr.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ const (
CodeTimeout = "system.timeout"
CodeInvalidRequest = "system.invalidRequest"
CodeUnsupportedProtocol = "system.unsupportedProtocol"
CodeSubjectTooLong = "system.subjectTooLong"
// HTTP only error codes
CodeBadRequest = "system.badRequest"
CodeMethodNotAllowed = "system.methodNotAllowed"
Expand All @@ -68,6 +69,7 @@ var (
ErrTimeout = &Error{Code: CodeTimeout, Message: "Request timeout"}
ErrInvalidRequest = &Error{Code: CodeInvalidRequest, Message: "Invalid request"}
ErrUnsupportedProtocol = &Error{Code: CodeUnsupportedProtocol, Message: "Unsupported protocol"}
ErrSubjectTooLong = &Error{Code: CodeSubjectTooLong, Message: "Subject too long"}
// HTTP only errors
ErrBadRequest = &Error{Code: CodeBadRequest, Message: "Bad request"}
ErrMethodNotAllowed = &Error{Code: CodeMethodNotAllowed, Message: "Method not allowed"}
Expand Down
9 changes: 9 additions & 0 deletions test/01subscribe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -416,3 +416,12 @@ func TestSubscribe_MultipleClientsSubscribingResource_FetchedFromCache(t *testin
})
}
}

func TestSubscribe_LongResourceID_ReturnsErrSubjectTooLong(t *testing.T) {
runTest(t, func(s *Session) {
c := s.Connect()
c.Request("subscribe.test."+generateString(10000), nil).
GetResponse(t).
AssertError(t, reserr.ErrSubjectTooLong)
})
}
11 changes: 11 additions & 0 deletions test/02get_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package test
import (
"encoding/json"
"testing"

"github.com/resgateio/resgate/server/reserr"
)

// Test that events are not sent to a model fetched with a client get request
Expand Down Expand Up @@ -89,3 +91,12 @@ func TestGet_WithCIDPlaceholder_ReplacesCID(t *testing.T) {
c.AssertNoEvent(t, "test."+cid+".model")
})
}

func TestGet_LongResourceID_ReturnsErrSubjectTooLong(t *testing.T) {
runTest(t, func(s *Session) {
c := s.Connect()
c.Request("get.test."+generateString(10000), nil).
GetResponse(t).
AssertError(t, reserr.ErrSubjectTooLong)
})
}
24 changes: 24 additions & 0 deletions test/04call_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -508,3 +508,27 @@ func TestCall_WithCIDPlaceholder_ReplacesCID(t *testing.T) {
AssertResult(t, json.RawMessage(`{"payload":"zoo"}`))
})
}

func TestCall_LongResourceMethod_ReturnsErrSubjectTooLong(t *testing.T) {
runTest(t, func(s *Session) {
c := s.Connect()
creq := c.Request("call.test."+generateString(10000), nil)

s.GetRequest(t).
AssertSubject(t, "access.test").
RespondSuccess(json.RawMessage(`{"get":true,"call":"*"}`))

creq.GetResponse(t).
AssertError(t, reserr.ErrSubjectTooLong)
})
}

func TestCall_LongResourceID_ReturnsErrSubjectTooLong(t *testing.T) {
runTest(t, func(s *Session) {
c := s.Connect()
creq := c.Request("call.test."+generateString(10000)+".method", nil)

creq.GetResponse(t).
AssertError(t, reserr.ErrSubjectTooLong)
})
}
20 changes: 20 additions & 0 deletions test/05auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,23 @@ func TestAuth_WithCIDPlaceholder_ReplacesCID(t *testing.T) {
AssertResult(t, json.RawMessage(`{"payload":"zoo"}`))
})
}

func TestAuth_LongResourceMethod_ReturnsErrSubjectTooLong(t *testing.T) {
runTest(t, func(s *Session) {
c := s.Connect()
creq := c.Request("auth.test."+generateString(10000), nil)

creq.GetResponse(t).
AssertError(t, reserr.ErrSubjectTooLong)
})
}

func TestAuth_LongResourceID_ReturnsErrSubjectTooLong(t *testing.T) {
runTest(t, func(s *Session) {
c := s.Connect()
creq := c.Request("auth.test."+generateString(10000)+".method", nil)

creq.GetResponse(t).
AssertError(t, reserr.ErrSubjectTooLong)
})
}
16 changes: 16 additions & 0 deletions test/12query_subscribe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,19 @@ func TestDifferentQueriesTriggersGetAndAccessRequests(t *testing.T) {
subscribeToTestQueryModel(t, s, c, "q=foo&f=baz", "q=foo&f=baz")
})
}

// Test subscribing to query model
func TestSubscribingToQuery_LongQuery_ReturnModel(t *testing.T) {
runTest(t, func(s *Session) {
event := json.RawMessage(`{"foo":"bar"}`)

c := s.Connect()
str := generateString(10000)
subscribeToTestQueryModel(t, s, c, "q="+str, "q="+str)

// Send event on non-query model and validate no event is sent to client
s.ResourceEvent("test.model", "custom", event)
c.AssertNoEvent(t, "test.model")
c.AssertNoNATSRequest(t, "test.model")
})
}
36 changes: 36 additions & 0 deletions test/14http_get_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,3 +379,39 @@ func TestHTTPGet_HeaderAuth_ExpectedResponse(t *testing.T) {
})
}
}

func TestHTTPGet_LongResourceID_ReturnsStatus414(t *testing.T) {
longStr := generateString(10000)
runTest(t, func(s *Session) {
hreq := s.HTTPRequest("GET", "/api/test/"+longStr, nil)

// Validate http response
hreq.GetResponse(t).
AssertError(t, reserr.ErrSubjectTooLong).
AssertStatusCode(t, http.StatusRequestURITooLong)
})
}

func TestHTTPGet_LongModelQuery_ReturnsModel(t *testing.T) {
query := "q=" + generateString(10000)
model := resourceData("test.model")

runTest(t, func(s *Session) {
hreq := s.HTTPRequest("GET", "/api/test/model?"+query, nil)

// Handle model get and access request
mreqs := s.GetParallelRequests(t, 2)
mreqs.
GetRequest(t, "access.test.model").
AssertPathPayload(t, "token", nil).
AssertPathPayload(t, "query", query).
RespondSuccess(json.RawMessage(`{"get":true}`))
mreqs.
GetRequest(t, "get.test.model").
AssertPathPayload(t, "query", query).
RespondSuccess(json.RawMessage(`{"model":` + model + `,"query":"` + query + `"}`))

// Validate http response
hreq.GetResponse(t).Equals(t, http.StatusOK, json.RawMessage(model))
})
}
51 changes: 51 additions & 0 deletions test/15http_post_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,3 +388,54 @@ func TestHTTPPost_HeaderAuth_ExpectedResponse(t *testing.T) {
})
}
}

func TestHTTPPost_LongResourceID_ReturnsStatus414(t *testing.T) {
longStr := generateString(10000)
runTest(t, func(s *Session) {
hreq := s.HTTPRequest("POST", "/api/test/"+longStr+"/method", nil)

// Validate http response
hreq.GetResponse(t).
AssertError(t, reserr.ErrSubjectTooLong).
AssertStatusCode(t, http.StatusRequestURITooLong)
})
}

func TestHTTPPost_LongResourceMethod_ReturnsStatus414(t *testing.T) {
longStr := generateString(10000)
runTest(t, func(s *Session) {
hreq := s.HTTPRequest("POST", "/api/test/"+longStr, nil)

s.GetRequest(t).
AssertSubject(t, "access.test").
RespondSuccess(json.RawMessage(`{"get":true,"call":"*"}`))

// Validate http response
hreq.GetResponse(t).
AssertError(t, reserr.ErrSubjectTooLong).
AssertStatusCode(t, http.StatusRequestURITooLong)
})
}

func TestHTTPCall_LongModelQuery_ReturnsResult(t *testing.T) {
query := "q=" + generateString(10000)
successResponse := json.RawMessage(`{"foo":"bar"}`)

runTest(t, func(s *Session) {
hreq := s.HTTPRequest("POST", "/api/test/model/method?"+query, nil)

// Get access request
s.GetRequest(t).
AssertSubject(t, "access.test.model").
AssertPathPayload(t, "query", query).
RespondSuccess(json.RawMessage(`{"get":true,"call":"*"}`))
// Get call request
s.GetRequest(t).
AssertSubject(t, "call.test.model.method").
AssertPathPayload(t, "query", query).
RespondSuccess(successResponse)

// Validate http response
hreq.GetResponse(t).Equals(t, http.StatusOK, successResponse)
})
}
14 changes: 14 additions & 0 deletions test/natstest.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"testing"
"time"

"github.com/nats-io/nats.go"
"github.com/resgateio/resgate/logger"
"github.com/resgateio/resgate/server/mq"
"github.com/resgateio/resgate/server/reserr"
Expand Down Expand Up @@ -108,6 +109,14 @@ func (c *NATSTestClient) Close() {
// SendRequest sends an asynchronous request on a subject, expecting the Response
// callback to be called once.
func (c *NATSTestClient) SendRequest(subj string, payload []byte, cb mq.Response) {
// Validate max control line size
// 7 = nats inbox prefix length
// 22 = nuid size
if len(subj)+7+22 > nats.MAX_CONTROL_LINE_SIZE {
go cb("", nil, mq.ErrSubjectTooLong)
return
}

c.mu.Lock()
defer c.mu.Unlock()

Expand Down Expand Up @@ -136,6 +145,11 @@ func (c *NATSTestClient) SendRequest(subj string, payload []byte, cb mq.Response
// Subscribe to all events on a resource namespace.
// The namespace has the format "event."+resource
func (c *NATSTestClient) Subscribe(namespace string, cb mq.Response) (mq.Unsubscriber, error) {
// Validate max control line size
if len(namespace) > nats.MAX_CONTROL_LINE_SIZE-2 {
return nil, mq.ErrSubjectTooLong
}

c.mu.Lock()
defer c.mu.Unlock()

Expand Down

0 comments on commit 8d84f23

Please sign in to comment.