diff --git a/README.md b/README.md index d98f00c3..61fc41da 100644 --- a/README.md +++ b/README.md @@ -211,6 +211,12 @@ r.Table("test").Insert(doc, r.InsertOpts{ As shown above in the Between example optional arguments are passed to the function as a struct. Each function that has optional arguments as a related struct. This structs are named in the format FunctionNameOpts, for example BetweenOpts is the related struct for Between. +#### Cancelling queries + +For query cancellation use `Context` argument at `RunOpts`. If `Context` is `nil` and `ReadTimeout` or `WriteTimeout` is not 0 from `ConnectionOpts`, `Context` will be formed by summation of these timeouts. + +For unlimited timeouts for `Changes()` pass `context.Background()`. + ## Results Different result types are returned depending on what function is used to execute the query. diff --git a/cluster.go b/cluster.go index 2f14e825..64ff21f2 100644 --- a/cluster.go +++ b/cluster.go @@ -10,6 +10,7 @@ import ( "github.com/Sirupsen/logrus" "github.com/cenkalti/backoff" "github.com/hailocab/go-hostpool" + "golang.org/x/net/context" ) // A Cluster represents a connection to a RethinkDB cluster, a cluster is created @@ -57,7 +58,7 @@ func NewCluster(hosts []Host, opts *ConnectOpts) (*Cluster, error) { } // Query executes a ReQL query using the cluster to connect to the database -func (c *Cluster) Query(q Query) (cursor *Cursor, err error) { +func (c *Cluster) Query(ctx context.Context, q Query) (cursor *Cursor, err error) { for i := 0; i < c.numRetries(); i++ { var node *Node var hpr hostpool.HostPoolResponse @@ -67,7 +68,7 @@ func (c *Cluster) Query(q Query) (cursor *Cursor, err error) { return nil, err } - cursor, err = node.Query(q) + cursor, err = node.Query(ctx, q) hpr.Mark(err) if !shouldRetryQuery(q, err) { @@ -79,7 +80,7 @@ func (c *Cluster) Query(q Query) (cursor *Cursor, err error) { } // Exec executes a ReQL query using the cluster to connect to the database -func (c *Cluster) Exec(q Query) (err error) { +func (c *Cluster) Exec(ctx context.Context, q Query) (err error) { for i := 0; i < c.numRetries(); i++ { var node *Node var hpr hostpool.HostPoolResponse @@ -89,7 +90,7 @@ func (c *Cluster) Exec(q Query) (err error) { return err } - err = node.Exec(q) + err = node.Exec(ctx, q) hpr.Mark(err) if !shouldRetryQuery(q, err) { @@ -204,7 +205,7 @@ func (c *Cluster) listenForNodeChanges() error { return fmt.Errorf("Error building query: %s", err) } - cursor, err := node.Query(q) + cursor, err := node.Query(context.Background(), q) // no need for timeout due to Changes() if err != nil { hpr.Mark(err) return err @@ -279,7 +280,7 @@ func (c *Cluster) connectNodes(hosts []Host) error { continue } - _, cursor, err := conn.Query(q) + _, cursor, err := conn.Query(nil, q) // nil = connection opts' timeout if err != nil { attemptErr = err Log.Warnf("Error fetching cluster status: %s", err) diff --git a/connection.go b/connection.go index 75141e13..e8adc25e 100644 --- a/connection.go +++ b/connection.go @@ -10,6 +10,7 @@ import ( "sync/atomic" "time" + "golang.org/x/net/context" p "gopkg.in/gorethink/gorethink.v3/ql2" ) @@ -104,7 +105,11 @@ func (c *Connection) Close() error { // Cursor which should be used to view the query's response. // // This function is used internally by Run which should be used for most queries. -func (c *Connection) Query(q Query) (*Response, *Cursor, error) { +func (c *Connection) Query(ctx context.Context, q Query) (*Response, *Cursor, error) { + if ctx == nil { + ctx = c.contextFromConnectionOpts() + } + if c == nil { return nil, nil, ErrConnectionClosed } @@ -131,30 +136,51 @@ func (c *Connection) Query(q Query) (*Response, *Cursor, error) { } c.mu.Unlock() - err := c.sendQuery(q) - if err != nil { - return nil, nil, err - } + var response *Response + var cursor *Cursor + var errchan chan error = make(chan error, 1) + go func() { + err := c.sendQuery(q) + if err != nil { + errchan <- err + return + } - if noreply, ok := q.Opts["noreply"]; ok && noreply.(bool) { - return nil, nil, nil - } + if noreply, ok := q.Opts["noreply"]; ok && noreply.(bool) { + errchan <- nil + return + } - for { - response, err := c.readResponse() - if err != nil { - return nil, nil, err + for { + response, err := c.readResponse() + if err != nil { + errchan <- err + return + } + + if response.Token == q.Token { + // If this was the requested response process and return + response, cursor, err = c.processResponse(ctx, q, response) + errchan <- err + return + } else if _, ok := c.cursors[response.Token]; ok { + // If the token is in the cursor cache then process the response + c.processResponse(ctx, q, response) + } else { + putResponse(response) + } } + }() - if response.Token == q.Token { - // If this was the requested response process and return - return c.processResponse(q, response) - } else if _, ok := c.cursors[response.Token]; ok { - // If the token is in the cursor cache then process the response - c.processResponse(q, response) - } else { - putResponse(response) + select { + case err := <-errchan: + return response, cursor, err + case <-ctx.Done(): + if q.Type != p.Query_STOP { + stopQuery := newStopQuery(q.Token) + c.Query(c.contextFromConnectionOpts(), stopQuery) } + return nil, nil, ErrQueryTimeout } } @@ -167,7 +193,7 @@ type ServerResponse struct { func (c *Connection) Server() (ServerResponse, error) { var response ServerResponse - _, cur, err := c.Query(Query{ + _, cur, err := c.Query(c.contextFromConnectionOpts(), Query{ Type: p.Query_SERVER_INFO, }) if err != nil { @@ -255,7 +281,7 @@ func (c *Connection) readResponse() (*Response, error) { return response, nil } -func (c *Connection) processResponse(q Query, response *Response) (*Response, *Cursor, error) { +func (c *Connection) processResponse(ctx context.Context, q Query, response *Response) (*Response, *Cursor, error) { switch response.Type { case p.Response_CLIENT_ERROR: return c.processErrorResponse(q, response, RQLClientError{rqlServerError{response, q.Term}}) @@ -264,11 +290,11 @@ func (c *Connection) processResponse(q Query, response *Response) (*Response, *C case p.Response_RUNTIME_ERROR: return c.processErrorResponse(q, response, createRuntimeError(response.ErrorType, response, q.Term)) case p.Response_SUCCESS_ATOM, p.Response_SERVER_INFO: - return c.processAtomResponse(q, response) + return c.processAtomResponse(ctx, q, response) case p.Response_SUCCESS_PARTIAL: - return c.processPartialResponse(q, response) + return c.processPartialResponse(ctx, q, response) case p.Response_SUCCESS_SEQUENCE: - return c.processSequenceResponse(q, response) + return c.processSequenceResponse(ctx, q, response) case p.Response_WAIT_COMPLETE: return c.processWaitResponse(q, response) default: @@ -287,9 +313,9 @@ func (c *Connection) processErrorResponse(q Query, response *Response, err error return response, cursor, err } -func (c *Connection) processAtomResponse(q Query, response *Response) (*Response, *Cursor, error) { +func (c *Connection) processAtomResponse(ctx context.Context, q Query, response *Response) (*Response, *Cursor, error) { // Create cursor - cursor := newCursor(c, "Cursor", response.Token, q.Term, q.Opts) + cursor := newCursor(ctx, c, "Cursor", response.Token, q.Term, q.Opts) cursor.profile = response.Profile cursor.extend(response) @@ -297,7 +323,7 @@ func (c *Connection) processAtomResponse(q Query, response *Response) (*Response return response, cursor, nil } -func (c *Connection) processPartialResponse(q Query, response *Response) (*Response, *Cursor, error) { +func (c *Connection) processPartialResponse(ctx context.Context, q Query, response *Response) (*Response, *Cursor, error) { cursorType := "Cursor" if len(response.Notes) > 0 { switch response.Notes[0] { @@ -318,7 +344,7 @@ func (c *Connection) processPartialResponse(q Query, response *Response) (*Respo cursor, ok := c.cursors[response.Token] if !ok { // Create a new cursor if needed - cursor = newCursor(c, cursorType, response.Token, q.Term, q.Opts) + cursor = newCursor(ctx, c, cursorType, response.Token, q.Term, q.Opts) cursor.profile = response.Profile c.cursors[response.Token] = cursor @@ -330,12 +356,12 @@ func (c *Connection) processPartialResponse(q Query, response *Response) (*Respo return response, cursor, nil } -func (c *Connection) processSequenceResponse(q Query, response *Response) (*Response, *Cursor, error) { +func (c *Connection) processSequenceResponse(ctx context.Context, q Query, response *Response) (*Response, *Cursor, error) { c.mu.Lock() cursor, ok := c.cursors[response.Token] if !ok { // Create a new cursor if needed - cursor = newCursor(c, "Cursor", response.Token, q.Term, q.Opts) + cursor = newCursor(ctx, c, "Cursor", response.Token, q.Term, q.Opts) cursor.profile = response.Profile } diff --git a/connection_helper.go b/connection_helper.go index 68460707..6590de8a 100644 --- a/connection_helper.go +++ b/connection_helper.go @@ -1,6 +1,9 @@ package gorethink -import "encoding/binary" +import ( + "encoding/binary" + "golang.org/x/net/context" +) // Write 'data' to conn func (c *Connection) writeData(data []byte) error { @@ -39,3 +42,12 @@ func (c *Connection) writeQuery(token int64, q []byte) error { return c.writeData(data) } + +func (c *Connection) contextFromConnectionOpts() context.Context { + sum := c.opts.ReadTimeout + c.opts.WriteTimeout + if sum == 0 { + return context.Background() + } + ctx, _ := context.WithTimeout(context.Background(), sum) + return ctx +} diff --git a/cursor.go b/cursor.go index 41bd5fbc..fe670951 100644 --- a/cursor.go +++ b/cursor.go @@ -7,6 +7,7 @@ import ( "reflect" "sync" + "golang.org/x/net/context" "gopkg.in/gorethink/gorethink.v3/encoding" p "gopkg.in/gorethink/gorethink.v3/ql2" ) @@ -16,7 +17,7 @@ var ( errCursorClosed = errors.New("connection closed, cannot read cursor") ) -func newCursor(conn *Connection, cursorType string, token int64, term *Term, opts map[string]interface{}) *Cursor { +func newCursor(ctx context.Context, conn *Connection, cursorType string, token int64, term *Term, opts map[string]interface{}) *Cursor { if cursorType == "" { cursorType = "Cursor" } @@ -35,6 +36,7 @@ func newCursor(conn *Connection, cursorType string, token int64, term *Term, opt opts: opts, buffer: make([]interface{}, 0), responses: make([]json.RawMessage, 0), + ctx: ctx, } return cursor @@ -64,6 +66,7 @@ type Cursor struct { cursorType string term *Term opts map[string]interface{} + ctx context.Context mu sync.RWMutex lastErr error @@ -145,15 +148,7 @@ func (c *Cursor) Close() error { // Stop any unfinished queries if !c.finished { - q := Query{ - Type: p.Query_STOP, - Token: c.token, - Opts: map[string]interface{}{ - "noreply": true, - }, - } - - _, _, err = conn.Query(q) + _, _, err = conn.Query(c.ctx, newStopQuery(c.token)) } if c.releaseConn != nil { @@ -552,7 +547,7 @@ func (c *Cursor) fetchMore() error { } c.mu.Unlock() - _, _, err = c.conn.Query(q) + _, _, err = c.conn.Query(c.ctx, q) c.mu.Lock() } diff --git a/errors.go b/errors.go index 3c8a1682..c5d60f94 100644 --- a/errors.go +++ b/errors.go @@ -25,6 +25,8 @@ var ( // ErrConnectionClosed is returned when trying to send a query with a closed // connection. ErrConnectionClosed = errors.New("gorethink: the connection is closed") + // ErrQueryTimeout is returned when query context deadline exceeded. + ErrQueryTimeout = errors.New("gorethink: query timeout") ) func printCarrots(t Term, frames []*p.Frame) string { diff --git a/mock.go b/mock.go index af34d4e5..7d91fcac 100644 --- a/mock.go +++ b/mock.go @@ -7,6 +7,7 @@ import ( "sync" "time" + "golang.org/x/net/context" p "gopkg.in/gorethink/gorethink.v3/ql2" ) @@ -290,7 +291,7 @@ func (m *Mock) IsConnected() bool { return true } -func (m *Mock) Query(q Query) (*Cursor, error) { +func (m *Mock) Query(ctx context.Context, q Query) (*Cursor, error) { found, query := m.findExpectedQuery(q) if found < 0 { @@ -328,7 +329,7 @@ func (m *Mock) Query(q Query) (*Cursor, error) { } // Build cursor and return - c := newCursor(nil, "", query.Query.Token, query.Query.Term, query.Query.Opts) + c := newCursor(ctx, nil, "", query.Query.Token, query.Query.Term, query.Query.Opts) c.finished = true c.fetching = false c.isAtom = true @@ -345,8 +346,8 @@ func (m *Mock) Query(q Query) (*Cursor, error) { return c, nil } -func (m *Mock) Exec(q Query) error { - _, err := m.Query(q) +func (m *Mock) Exec(ctx context.Context, q Query) error { + _, err := m.Query(ctx, q) return err } diff --git a/node.go b/node.go index ca0ce89c..6d521cba 100644 --- a/node.go +++ b/node.go @@ -3,6 +3,7 @@ package gorethink import ( "sync" + "golang.org/x/net/context" p "gopkg.in/gorethink/gorethink.v3/ql2" ) @@ -83,27 +84,27 @@ func (n *Node) SetMaxOpenConns(openConns int) { // processed by the server. Note that this guarantee only applies to queries // run on the given connection func (n *Node) NoReplyWait() error { - return n.pool.Exec(Query{ + return n.pool.Exec(nil, Query{ // nil = connection opts' timeout Type: p.Query_NOREPLY_WAIT, }) } // Query executes a ReQL query using this nodes connection pool. -func (n *Node) Query(q Query) (cursor *Cursor, err error) { +func (n *Node) Query(ctx context.Context, q Query) (cursor *Cursor, err error) { if n.Closed() { return nil, ErrInvalidNode } - return n.pool.Query(q) + return n.pool.Query(ctx, q) } // Exec executes a ReQL query using this nodes connection pool. -func (n *Node) Exec(q Query) (err error) { +func (n *Node) Exec(ctx context.Context, q Query) (err error) { if n.Closed() { return ErrInvalidNode } - return n.pool.Exec(q) + return n.pool.Exec(ctx, q) } // Server returns the server name and server UUID being used by a connection. diff --git a/pool.go b/pool.go index 1a7bfa72..9b6ae4e9 100644 --- a/pool.go +++ b/pool.go @@ -6,6 +6,7 @@ import ( "net" "sync" + "golang.org/x/net/context" "gopkg.in/fatih/pool.v2" ) @@ -136,14 +137,14 @@ func (p *Pool) SetMaxOpenConns(n int) { // Query execution functions // Exec executes a query without waiting for any response. -func (p *Pool) Exec(q Query) error { +func (p *Pool) Exec(ctx context.Context, q Query) error { c, pc, err := p.conn() if err != nil { return err } defer pc.Close() - _, _, err = c.Query(q) + _, _, err = c.Query(ctx, q) if c.isBad() { pc.MarkUnusable() @@ -153,13 +154,13 @@ func (p *Pool) Exec(q Query) error { } // Query executes a query and waits for the response -func (p *Pool) Query(q Query) (*Cursor, error) { +func (p *Pool) Query(ctx context.Context, q Query) (*Cursor, error) { c, pc, err := p.conn() if err != nil { return nil, err } - _, cursor, err := c.Query(q) + _, cursor, err := c.Query(ctx, q) if err == nil { cursor.releaseConn = releaseConn(c, pc) diff --git a/query.go b/query.go index ca53dca6..0343947c 100644 --- a/query.go +++ b/query.go @@ -6,6 +6,7 @@ import ( "strconv" "strings" + "golang.org/x/net/context" p "gopkg.in/gorethink/gorethink.v3/ql2" ) @@ -256,8 +257,8 @@ func (t Term) OptArgs(args interface{}) Term { type QueryExecutor interface { IsConnected() bool - Query(Query) (*Cursor, error) - Exec(Query) error + Query(context.Context, Query) (*Cursor, error) + Exec(context.Context, Query) error newQuery(t Term, opts map[string]interface{}) (Query, error) } @@ -313,6 +314,8 @@ type RunOpts struct { MaxBatchBytes interface{} `gorethink:"max_batch_bytes,omitempty"` MaxBatchSeconds interface{} `gorethink:"max_batch_seconds,omitempty"` FirstBatchScaledownFactor interface{} `gorethink:"first_batch_scaledown_factor,omitempty"` + + Context context.Context `gorethink:"-"` } func (o RunOpts) toMap() map[string]interface{} { @@ -332,8 +335,10 @@ func (o RunOpts) toMap() map[string]interface{} { // } func (t Term) Run(s QueryExecutor, optArgs ...RunOpts) (*Cursor, error) { opts := map[string]interface{}{} + var ctx context.Context = nil // if it's nil connection will form context from connection opts if len(optArgs) >= 1 { opts = optArgs[0].toMap() + ctx = optArgs[0].Context } if s == nil || !s.IsConnected() { @@ -345,7 +350,7 @@ func (t Term) Run(s QueryExecutor, optArgs ...RunOpts) (*Cursor, error) { return nil, err } - return s.Query(q) + return s.Query(ctx, q) } // RunWrite runs a query using the given connection but unlike Run automatically @@ -424,6 +429,8 @@ type ExecOpts struct { FirstBatchScaledownFactor interface{} `gorethink:"first_batch_scaledown_factor,omitempty"` NoReply interface{} `gorethink:"noreply,omitempty"` + + Context context.Context `gorethink:"-"` } func (o ExecOpts) toMap() map[string]interface{} { @@ -438,8 +445,10 @@ func (o ExecOpts) toMap() map[string]interface{} { // }) func (t Term) Exec(s QueryExecutor, optArgs ...ExecOpts) error { opts := map[string]interface{}{} + var ctx context.Context = nil // if it's nil connection will form context from connection opts if len(optArgs) >= 1 { opts = optArgs[0].toMap() + ctx = optArgs[0].Context } if s == nil || !s.IsConnected() { @@ -451,5 +460,5 @@ func (t Term) Exec(s QueryExecutor, optArgs ...ExecOpts) error { return err } - return s.Exec(q) + return s.Exec(ctx, q) } diff --git a/query_helpers.go b/query_helpers.go new file mode 100644 index 00000000..5d894b86 --- /dev/null +++ b/query_helpers.go @@ -0,0 +1,15 @@ +package gorethink + +import ( + p "gopkg.in/gorethink/gorethink.v3/ql2" +) + +func newStopQuery(token int64) Query { + return Query{ + Type: p.Query_STOP, + Token: token, + Opts: map[string]interface{}{ + "noreply": true, + }, + } +} diff --git a/session.go b/session.go index 9c3ba8db..dbdb14a9 100644 --- a/session.go +++ b/session.go @@ -5,6 +5,7 @@ import ( "sync" "time" + "golang.org/x/net/context" p "gopkg.in/gorethink/gorethink.v3/ql2" ) @@ -265,7 +266,7 @@ func (s *Session) NoReplyWait() error { return ErrConnectionClosed } - return s.cluster.Exec(Query{ + return s.cluster.Exec(nil, Query{ // nil = connection opts' defaults Type: p.Query_NOREPLY_WAIT, }) } @@ -287,7 +288,7 @@ func (s *Session) Database() string { } // Query executes a ReQL query using the session to connect to the database -func (s *Session) Query(q Query) (*Cursor, error) { +func (s *Session) Query(ctx context.Context, q Query) (*Cursor, error) { s.mu.RLock() defer s.mu.RUnlock() @@ -295,11 +296,11 @@ func (s *Session) Query(q Query) (*Cursor, error) { return nil, ErrConnectionClosed } - return s.cluster.Query(q) + return s.cluster.Query(ctx, q) } // Exec executes a ReQL query using the session to connect to the database -func (s *Session) Exec(q Query) error { +func (s *Session) Exec(ctx context.Context, q Query) error { s.mu.RLock() defer s.mu.RUnlock() @@ -307,7 +308,7 @@ func (s *Session) Exec(q Query) error { return ErrConnectionClosed } - return s.cluster.Exec(q) + return s.cluster.Exec(ctx, q) } // Server returns the server name and server UUID being used by a connection.