Skip to content

Commit

Permalink
api: context support
Browse files Browse the repository at this point in the history
This patch adds the support of using context in API. The API is based
on using request objects. Added tests that cover almost all cases of
using the context in a query. Added benchamrk tests are equivalent to
other, that use the same query but without any context.

Closes #48
  • Loading branch information
vr009 committed Jul 20, 2022
1 parent e1bb59c commit b539f85
Show file tree
Hide file tree
Showing 5 changed files with 448 additions and 58 deletions.
199 changes: 143 additions & 56 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package tarantool
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -125,8 +126,11 @@ type Connection struct {
c net.Conn
mutex sync.Mutex
// Schema contains schema loaded on connection.
Schema *Schema
Schema *Schema
// requestId contains the last request ID for requests with nil context.
requestId uint32
// contextRequestId contains the last request ID for requests with context.
contextRequestId uint32
// Greeting contains first message sent by Tarantool.
Greeting *Greeting

Expand All @@ -143,16 +147,57 @@ type Connection struct {

var _ = Connector(&Connection{}) // Check compatibility with connector interface.

type futureList struct {
first *Future
last **Future
}

func (list *futureList) findFuture(reqid uint32, fetch bool) *Future {
root := &list.first
for {
fut := *root
if fut == nil {
return nil
}
if fut.requestId == reqid {
if fetch {
*root = fut.next
if fut.next == nil {
list.last = root
} else {
fut.next = nil
}
}
return fut
}
root = &fut.next
}
}

func (list *futureList) addFuture(fut *Future) {
*list.last = fut
list.last = &fut.next
}

func (list *futureList) clear(err error, conn *Connection) {
fut := list.first
list.first = nil
list.last = &list.first
for fut != nil {
fut.SetError(err)
conn.markDone(fut)
fut, fut.next = fut.next, nil
}
}

type connShard struct {
rmut sync.Mutex
requests [requestsMap]struct {
first *Future
last **Future
}
bufmut sync.Mutex
buf smallWBuf
enc *msgpack.Encoder
_pad [16]uint64 //nolint: unused,structcheck
rmut sync.Mutex
requests [requestsMap]futureList
requestsWithCtx [requestsMap]futureList
bufmut sync.Mutex
buf smallWBuf
enc *msgpack.Encoder
_pad [16]uint64 //nolint: unused,structcheck
}

// Greeting is a message sent by Tarantool on connect.
Expand Down Expand Up @@ -262,12 +307,13 @@ type SslOpts struct {
// and will not finish to make attempts on authorization failures.
func Connect(addr string, opts Opts) (conn *Connection, err error) {
conn = &Connection{
addr: addr,
requestId: 0,
Greeting: &Greeting{},
control: make(chan struct{}),
opts: opts,
dec: msgpack.NewDecoder(&smallBuf{}),
addr: addr,
requestId: 0,
contextRequestId: 1,
Greeting: &Greeting{},
control: make(chan struct{}),
opts: opts,
dec: msgpack.NewDecoder(&smallBuf{}),
}
maxprocs := uint32(runtime.GOMAXPROCS(-1))
if conn.opts.Concurrency == 0 || conn.opts.Concurrency > maxprocs*128 {
Expand All @@ -283,8 +329,11 @@ func Connect(addr string, opts Opts) (conn *Connection, err error) {
conn.shard = make([]connShard, conn.opts.Concurrency)
for i := range conn.shard {
shard := &conn.shard[i]
for j := range shard.requests {
shard.requests[j].last = &shard.requests[j].first
requestsLists := []*[requestsMap]futureList{&shard.requests, &shard.requestsWithCtx}
for _, requests := range requestsLists {
for j := range requests {
requests[j].last = &requests[j].first
}
}
}

Expand Down Expand Up @@ -387,6 +436,13 @@ func (conn *Connection) Handle() interface{} {
return conn.opts.Handle
}

func (conn *Connection) cancelFuture(fut *Future, err error) {
if fut = conn.fetchFuture(fut.requestId); fut != nil {
fut.SetError(err)
conn.markDone(fut)
}
}

func (conn *Connection) dial() (err error) {
var connection net.Conn
network := "tcp"
Expand Down Expand Up @@ -580,15 +636,10 @@ func (conn *Connection) closeConnection(neterr error, forever bool) (err error)
}
for i := range conn.shard {
conn.shard[i].buf.Reset()
requests := &conn.shard[i].requests
for pos := range requests {
fut := requests[pos].first
requests[pos].first = nil
requests[pos].last = &requests[pos].first
for fut != nil {
fut.SetError(neterr)
conn.markDone(fut)
fut, fut.next = fut.next, nil
requestsLists := []*[requestsMap]futureList{&conn.shard[i].requests, &conn.shard[i].requestsWithCtx}
for _, requests := range requestsLists {
for pos := range requests {
requests[pos].clear(neterr, conn)
}
}
}
Expand Down Expand Up @@ -721,7 +772,7 @@ func (conn *Connection) reader(r *bufio.Reader, c net.Conn) {
}
}

func (conn *Connection) newFuture() (fut *Future) {
func (conn *Connection) newFuture(ctx context.Context) (fut *Future) {
fut = NewFuture()
if conn.rlimit != nil && conn.opts.RLimitAction == RLimitDrop {
select {
Expand All @@ -736,7 +787,7 @@ func (conn *Connection) newFuture() (fut *Future) {
return
}
}
fut.requestId = conn.nextRequestId()
fut.requestId = conn.nextRequestId(ctx != nil)
shardn := fut.requestId & (conn.opts.Concurrency - 1)
shard := &conn.shard[shardn]
shard.rmut.Lock()
Expand All @@ -761,11 +812,20 @@ func (conn *Connection) newFuture() (fut *Future) {
return
}
pos := (fut.requestId / conn.opts.Concurrency) & (requestsMap - 1)
pair := &shard.requests[pos]
*pair.last = fut
pair.last = &fut.next
if conn.opts.Timeout > 0 {
fut.timeout = time.Since(epoch) + conn.opts.Timeout
if ctx != nil {
select {
case <-ctx.Done():
fut.SetError(fmt.Errorf("context is done"))
shard.rmut.Unlock()
return
default:
}
shard.requestsWithCtx[pos].addFuture(fut)
} else {
shard.requests[pos].addFuture(fut)
if conn.opts.Timeout > 0 {
fut.timeout = time.Since(epoch) + conn.opts.Timeout
}
}
shard.rmut.Unlock()
if conn.rlimit != nil && conn.opts.RLimitAction == RLimitWait {
Expand All @@ -785,12 +845,40 @@ func (conn *Connection) newFuture() (fut *Future) {
return
}

func (conn *Connection) contextWatchdog(fut *Future, ctx context.Context) {
select {
case <-fut.done:
default:
select {
case <-ctx.Done():
conn.cancelFuture(fut, fmt.Errorf("context is done"))
default:
select {
case <-fut.done:
case <-ctx.Done():
conn.cancelFuture(fut, fmt.Errorf("context is done"))
}
}
}
}

func (conn *Connection) send(req Request) *Future {
fut := conn.newFuture()
fut := conn.newFuture(req.Ctx())
if fut.ready == nil {
return fut
}
if req.Ctx() != nil {
select {
case <-req.Ctx().Done():
conn.cancelFuture(fut, fmt.Errorf("context is done"))
return fut
default:
}
}
conn.putFuture(fut, req)
if req.Ctx() != nil {
go conn.contextWatchdog(fut, req.Ctx())
}
return fut
}

Expand Down Expand Up @@ -877,25 +965,11 @@ func (conn *Connection) fetchFuture(reqid uint32) (fut *Future) {
func (conn *Connection) getFutureImp(reqid uint32, fetch bool) *Future {
shard := &conn.shard[reqid&(conn.opts.Concurrency-1)]
pos := (reqid / conn.opts.Concurrency) & (requestsMap - 1)
pair := &shard.requests[pos]
root := &pair.first
for {
fut := *root
if fut == nil {
return nil
}
if fut.requestId == reqid {
if fetch {
*root = fut.next
if fut.next == nil {
pair.last = root
} else {
fut.next = nil
}
}
return fut
}
root = &fut.next
// futures with even requests id belong to requests list with nil context
if reqid%2 == 0 {
return shard.requests[pos].findFuture(reqid, fetch)
} else {
return shard.requestsWithCtx[pos].findFuture(reqid, fetch)
}
}

Expand Down Expand Up @@ -984,8 +1058,12 @@ func (conn *Connection) read(r io.Reader) (response []byte, err error) {
return
}

func (conn *Connection) nextRequestId() (requestId uint32) {
return atomic.AddUint32(&conn.requestId, 1)
func (conn *Connection) nextRequestId(context bool) (requestId uint32) {
if context {
return atomic.AddUint32(&conn.contextRequestId, 2)
} else {
return atomic.AddUint32(&conn.requestId, 2)
}
}

// Do performs a request asynchronously on the connection.
Expand All @@ -1000,6 +1078,15 @@ func (conn *Connection) Do(req Request) *Future {
return fut
}
}
if req.Ctx() != nil {
select {
case <-req.Ctx().Done():
fut := NewFuture()
fut.SetError(fmt.Errorf("context is done"))
return fut
default:
}
}
return conn.send(req)
}

Expand Down
19 changes: 19 additions & 0 deletions prepared.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package tarantool

import (
"context"
"fmt"

"gopkg.in/vmihailenco/msgpack.v2"
Expand Down Expand Up @@ -58,6 +59,12 @@ func (req *PrepareRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error
return fillPrepare(enc, req.expr)
}

// Context sets a passed context to the request.
func (req *PrepareRequest) Context(ctx context.Context) *PrepareRequest {
req.ctx = ctx
return req
}

// UnprepareRequest helps you to create an unprepare request object for
// execution by a Connection.
type UnprepareRequest struct {
Expand All @@ -83,6 +90,12 @@ func (req *UnprepareRequest) Body(res SchemaResolver, enc *msgpack.Encoder) erro
return fillUnprepare(enc, *req.stmt)
}

// Context sets a passed context to the request.
func (req *UnprepareRequest) Context(ctx context.Context) *UnprepareRequest {
req.ctx = ctx
return req
}

// ExecutePreparedRequest helps you to create an execute prepared request
// object for execution by a Connection.
type ExecutePreparedRequest struct {
Expand Down Expand Up @@ -117,6 +130,12 @@ func (req *ExecutePreparedRequest) Body(res SchemaResolver, enc *msgpack.Encoder
return fillExecutePrepared(enc, *req.stmt, req.args)
}

// Context sets a passed context to the request.
func (req *ExecutePreparedRequest) Context(ctx context.Context) *ExecutePreparedRequest {
req.ctx = ctx
return req
}

func fillPrepare(enc *msgpack.Encoder, expr string) error {
enc.EncodeMapLen(1)
enc.EncodeUint64(KeySQLText)
Expand Down
Loading

0 comments on commit b539f85

Please sign in to comment.