From b539f852b7f94d4a7b580da36a2a6a0a8612f753 Mon Sep 17 00:00:00 2001 From: Rianov Viacheslav Date: Fri, 15 Jul 2022 10:32:13 +0300 Subject: [PATCH] api: context support This patch adds the support of using context in API. The API is based on using request objects. Added tests that cover almost all cases of using the context in a query. Added benchamrk tests are equivalent to other, that use the same query but without any context. Closes #48 --- connection.go | 199 +++++++++++++++++++++++--------- prepared.go | 19 ++++ request.go | 69 ++++++++++++ tarantool_test.go | 213 ++++++++++++++++++++++++++++++++++- test_helpers/request_mock.go | 6 + 5 files changed, 448 insertions(+), 58 deletions(-) diff --git a/connection.go b/connection.go index 6de1e9d01..06de6c8e3 100644 --- a/connection.go +++ b/connection.go @@ -5,6 +5,7 @@ package tarantool import ( "bufio" "bytes" + "context" "errors" "fmt" "io" @@ -125,8 +126,11 @@ type Connection struct { c net.Conn mutex sync.Mutex // Schema contains schema loaded on connection. - Schema *Schema + Schema *Schema + // requestId contains the last request ID for requests with nil context. requestId uint32 + // contextRequestId contains the last request ID for requests with context. + contextRequestId uint32 // Greeting contains first message sent by Tarantool. Greeting *Greeting @@ -143,16 +147,57 @@ type Connection struct { var _ = Connector(&Connection{}) // Check compatibility with connector interface. +type futureList struct { + first *Future + last **Future +} + +func (list *futureList) findFuture(reqid uint32, fetch bool) *Future { + root := &list.first + for { + fut := *root + if fut == nil { + return nil + } + if fut.requestId == reqid { + if fetch { + *root = fut.next + if fut.next == nil { + list.last = root + } else { + fut.next = nil + } + } + return fut + } + root = &fut.next + } +} + +func (list *futureList) addFuture(fut *Future) { + *list.last = fut + list.last = &fut.next +} + +func (list *futureList) clear(err error, conn *Connection) { + fut := list.first + list.first = nil + list.last = &list.first + for fut != nil { + fut.SetError(err) + conn.markDone(fut) + fut, fut.next = fut.next, nil + } +} + type connShard struct { - rmut sync.Mutex - requests [requestsMap]struct { - first *Future - last **Future - } - bufmut sync.Mutex - buf smallWBuf - enc *msgpack.Encoder - _pad [16]uint64 //nolint: unused,structcheck + rmut sync.Mutex + requests [requestsMap]futureList + requestsWithCtx [requestsMap]futureList + bufmut sync.Mutex + buf smallWBuf + enc *msgpack.Encoder + _pad [16]uint64 //nolint: unused,structcheck } // Greeting is a message sent by Tarantool on connect. @@ -262,12 +307,13 @@ type SslOpts struct { // and will not finish to make attempts on authorization failures. func Connect(addr string, opts Opts) (conn *Connection, err error) { conn = &Connection{ - addr: addr, - requestId: 0, - Greeting: &Greeting{}, - control: make(chan struct{}), - opts: opts, - dec: msgpack.NewDecoder(&smallBuf{}), + addr: addr, + requestId: 0, + contextRequestId: 1, + Greeting: &Greeting{}, + control: make(chan struct{}), + opts: opts, + dec: msgpack.NewDecoder(&smallBuf{}), } maxprocs := uint32(runtime.GOMAXPROCS(-1)) if conn.opts.Concurrency == 0 || conn.opts.Concurrency > maxprocs*128 { @@ -283,8 +329,11 @@ func Connect(addr string, opts Opts) (conn *Connection, err error) { conn.shard = make([]connShard, conn.opts.Concurrency) for i := range conn.shard { shard := &conn.shard[i] - for j := range shard.requests { - shard.requests[j].last = &shard.requests[j].first + requestsLists := []*[requestsMap]futureList{&shard.requests, &shard.requestsWithCtx} + for _, requests := range requestsLists { + for j := range requests { + requests[j].last = &requests[j].first + } } } @@ -387,6 +436,13 @@ func (conn *Connection) Handle() interface{} { return conn.opts.Handle } +func (conn *Connection) cancelFuture(fut *Future, err error) { + if fut = conn.fetchFuture(fut.requestId); fut != nil { + fut.SetError(err) + conn.markDone(fut) + } +} + func (conn *Connection) dial() (err error) { var connection net.Conn network := "tcp" @@ -580,15 +636,10 @@ func (conn *Connection) closeConnection(neterr error, forever bool) (err error) } for i := range conn.shard { conn.shard[i].buf.Reset() - requests := &conn.shard[i].requests - for pos := range requests { - fut := requests[pos].first - requests[pos].first = nil - requests[pos].last = &requests[pos].first - for fut != nil { - fut.SetError(neterr) - conn.markDone(fut) - fut, fut.next = fut.next, nil + requestsLists := []*[requestsMap]futureList{&conn.shard[i].requests, &conn.shard[i].requestsWithCtx} + for _, requests := range requestsLists { + for pos := range requests { + requests[pos].clear(neterr, conn) } } } @@ -721,7 +772,7 @@ func (conn *Connection) reader(r *bufio.Reader, c net.Conn) { } } -func (conn *Connection) newFuture() (fut *Future) { +func (conn *Connection) newFuture(ctx context.Context) (fut *Future) { fut = NewFuture() if conn.rlimit != nil && conn.opts.RLimitAction == RLimitDrop { select { @@ -736,7 +787,7 @@ func (conn *Connection) newFuture() (fut *Future) { return } } - fut.requestId = conn.nextRequestId() + fut.requestId = conn.nextRequestId(ctx != nil) shardn := fut.requestId & (conn.opts.Concurrency - 1) shard := &conn.shard[shardn] shard.rmut.Lock() @@ -761,11 +812,20 @@ func (conn *Connection) newFuture() (fut *Future) { return } pos := (fut.requestId / conn.opts.Concurrency) & (requestsMap - 1) - pair := &shard.requests[pos] - *pair.last = fut - pair.last = &fut.next - if conn.opts.Timeout > 0 { - fut.timeout = time.Since(epoch) + conn.opts.Timeout + if ctx != nil { + select { + case <-ctx.Done(): + fut.SetError(fmt.Errorf("context is done")) + shard.rmut.Unlock() + return + default: + } + shard.requestsWithCtx[pos].addFuture(fut) + } else { + shard.requests[pos].addFuture(fut) + if conn.opts.Timeout > 0 { + fut.timeout = time.Since(epoch) + conn.opts.Timeout + } } shard.rmut.Unlock() if conn.rlimit != nil && conn.opts.RLimitAction == RLimitWait { @@ -785,12 +845,40 @@ func (conn *Connection) newFuture() (fut *Future) { return } +func (conn *Connection) contextWatchdog(fut *Future, ctx context.Context) { + select { + case <-fut.done: + default: + select { + case <-ctx.Done(): + conn.cancelFuture(fut, fmt.Errorf("context is done")) + default: + select { + case <-fut.done: + case <-ctx.Done(): + conn.cancelFuture(fut, fmt.Errorf("context is done")) + } + } + } +} + func (conn *Connection) send(req Request) *Future { - fut := conn.newFuture() + fut := conn.newFuture(req.Ctx()) if fut.ready == nil { return fut } + if req.Ctx() != nil { + select { + case <-req.Ctx().Done(): + conn.cancelFuture(fut, fmt.Errorf("context is done")) + return fut + default: + } + } conn.putFuture(fut, req) + if req.Ctx() != nil { + go conn.contextWatchdog(fut, req.Ctx()) + } return fut } @@ -877,25 +965,11 @@ func (conn *Connection) fetchFuture(reqid uint32) (fut *Future) { func (conn *Connection) getFutureImp(reqid uint32, fetch bool) *Future { shard := &conn.shard[reqid&(conn.opts.Concurrency-1)] pos := (reqid / conn.opts.Concurrency) & (requestsMap - 1) - pair := &shard.requests[pos] - root := &pair.first - for { - fut := *root - if fut == nil { - return nil - } - if fut.requestId == reqid { - if fetch { - *root = fut.next - if fut.next == nil { - pair.last = root - } else { - fut.next = nil - } - } - return fut - } - root = &fut.next + // futures with even requests id belong to requests list with nil context + if reqid%2 == 0 { + return shard.requests[pos].findFuture(reqid, fetch) + } else { + return shard.requestsWithCtx[pos].findFuture(reqid, fetch) } } @@ -984,8 +1058,12 @@ func (conn *Connection) read(r io.Reader) (response []byte, err error) { return } -func (conn *Connection) nextRequestId() (requestId uint32) { - return atomic.AddUint32(&conn.requestId, 1) +func (conn *Connection) nextRequestId(context bool) (requestId uint32) { + if context { + return atomic.AddUint32(&conn.contextRequestId, 2) + } else { + return atomic.AddUint32(&conn.requestId, 2) + } } // Do performs a request asynchronously on the connection. @@ -1000,6 +1078,15 @@ func (conn *Connection) Do(req Request) *Future { return fut } } + if req.Ctx() != nil { + select { + case <-req.Ctx().Done(): + fut := NewFuture() + fut.SetError(fmt.Errorf("context is done")) + return fut + default: + } + } return conn.send(req) } diff --git a/prepared.go b/prepared.go index 9508f0546..013490f41 100644 --- a/prepared.go +++ b/prepared.go @@ -1,6 +1,7 @@ package tarantool import ( + "context" "fmt" "gopkg.in/vmihailenco/msgpack.v2" @@ -58,6 +59,12 @@ func (req *PrepareRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error return fillPrepare(enc, req.expr) } +// Context sets a passed context to the request. +func (req *PrepareRequest) Context(ctx context.Context) *PrepareRequest { + req.ctx = ctx + return req +} + // UnprepareRequest helps you to create an unprepare request object for // execution by a Connection. type UnprepareRequest struct { @@ -83,6 +90,12 @@ func (req *UnprepareRequest) Body(res SchemaResolver, enc *msgpack.Encoder) erro return fillUnprepare(enc, *req.stmt) } +// Context sets a passed context to the request. +func (req *UnprepareRequest) Context(ctx context.Context) *UnprepareRequest { + req.ctx = ctx + return req +} + // ExecutePreparedRequest helps you to create an execute prepared request // object for execution by a Connection. type ExecutePreparedRequest struct { @@ -117,6 +130,12 @@ func (req *ExecutePreparedRequest) Body(res SchemaResolver, enc *msgpack.Encoder return fillExecutePrepared(enc, *req.stmt, req.args) } +// Context sets a passed context to the request. +func (req *ExecutePreparedRequest) Context(ctx context.Context) *ExecutePreparedRequest { + req.ctx = ctx + return req +} + func fillPrepare(enc *msgpack.Encoder, expr string) error { enc.EncodeMapLen(1) enc.EncodeUint64(KeySQLText) diff --git a/request.go b/request.go index a83094145..e942dcd7d 100644 --- a/request.go +++ b/request.go @@ -1,6 +1,7 @@ package tarantool import ( + "context" "errors" "reflect" "strings" @@ -537,6 +538,8 @@ type Request interface { Code() int32 // Body fills an encoder with a request body. Body(resolver SchemaResolver, enc *msgpack.Encoder) error + // Ctx returns a context of the request. + Ctx() context.Context } // ConnectedRequest is an interface that provides the info about a Connection @@ -549,6 +552,7 @@ type ConnectedRequest interface { type baseRequest struct { requestCode int32 + ctx context.Context } // Code returns a IPROTO code for the request. @@ -556,6 +560,11 @@ func (req *baseRequest) Code() int32 { return req.requestCode } +// Ctx returns a context of the request. +func (req *baseRequest) Ctx() context.Context { + return req.ctx +} + type spaceRequest struct { baseRequest space interface{} @@ -613,6 +622,12 @@ func (req *PingRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { return fillPing(enc) } +// Context sets a passed context to the request. +func (req *PingRequest) Context(ctx context.Context) *PingRequest { + req.ctx = ctx + return req +} + // SelectRequest allows you to create a select request object for execution // by a Connection. type SelectRequest struct { @@ -683,6 +698,12 @@ func (req *SelectRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { return fillSelect(enc, spaceNo, indexNo, req.offset, req.limit, req.iterator, req.key) } +// Context sets a passed context to the request. +func (req *SelectRequest) Context(ctx context.Context) *SelectRequest { + req.ctx = ctx + return req +} + // InsertRequest helps you to create an insert request object for execution // by a Connection. type InsertRequest struct { @@ -716,6 +737,12 @@ func (req *InsertRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { return fillInsert(enc, spaceNo, req.tuple) } +// Context sets a passed context to the request. +func (req *InsertRequest) Context(ctx context.Context) *InsertRequest { + req.ctx = ctx + return req +} + // ReplaceRequest helps you to create a replace request object for execution // by a Connection. type ReplaceRequest struct { @@ -749,6 +776,12 @@ func (req *ReplaceRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error return fillInsert(enc, spaceNo, req.tuple) } +// Context sets a passed context to the request. +func (req *ReplaceRequest) Context(ctx context.Context) *ReplaceRequest { + req.ctx = ctx + return req +} + // DeleteRequest helps you to create a delete request object for execution // by a Connection. type DeleteRequest struct { @@ -789,6 +822,12 @@ func (req *DeleteRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { return fillDelete(enc, spaceNo, indexNo, req.key) } +// Context sets a passed context to the request. +func (req *DeleteRequest) Context(ctx context.Context) *DeleteRequest { + req.ctx = ctx + return req +} + // UpdateRequest helps you to create an update request object for execution // by a Connection. type UpdateRequest struct { @@ -840,6 +879,12 @@ func (req *UpdateRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { return fillUpdate(enc, spaceNo, indexNo, req.key, req.ops) } +// Context sets a passed context to the request. +func (req *UpdateRequest) Context(ctx context.Context) *UpdateRequest { + req.ctx = ctx + return req +} + // UpsertRequest helps you to create an upsert request object for execution // by a Connection. type UpsertRequest struct { @@ -884,6 +929,12 @@ func (req *UpsertRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { return fillUpsert(enc, spaceNo, req.tuple, req.ops) } +// Context sets a passed context to the request. +func (req *UpsertRequest) Context(ctx context.Context) *UpsertRequest { + req.ctx = ctx + return req +} + // CallRequest helps you to create a call request object for execution // by a Connection. type CallRequest struct { @@ -915,6 +966,12 @@ func (req *CallRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { return fillCall(enc, req.function, req.args) } +// Context sets a passed context to the request. +func (req *CallRequest) Context(ctx context.Context) *CallRequest { + req.ctx = ctx + return req +} + // NewCall16Request returns a new empty Call16Request. It uses request code for // Tarantool 1.6. // Deprecated since Tarantool 1.7.2. @@ -961,6 +1018,12 @@ func (req *EvalRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { return fillEval(enc, req.expr, req.args) } +// Context sets a passed context to the request. +func (req *EvalRequest) Context(ctx context.Context) *EvalRequest { + req.ctx = ctx + return req +} + // ExecuteRequest helps you to create an execute request object for execution // by a Connection. type ExecuteRequest struct { @@ -989,3 +1052,9 @@ func (req *ExecuteRequest) Args(args interface{}) *ExecuteRequest { func (req *ExecuteRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { return fillExecute(enc, req.expr, req.args) } + +// Context sets a passed context to the request. +func (req *ExecuteRequest) Context(ctx context.Context) *ExecuteRequest { + req.ctx = ctx + return req +} diff --git a/tarantool_test.go b/tarantool_test.go index 06771338c..f5360ba6b 100644 --- a/tarantool_test.go +++ b/tarantool_test.go @@ -1,10 +1,12 @@ package tarantool_test import ( + "context" "fmt" "log" "os" "reflect" + "runtime" "strings" "sync" "testing" @@ -100,16 +102,45 @@ func BenchmarkClientSerialRequestObject(b *testing.B) { if err != nil { b.Error(err) } + req := NewSelectRequest(spaceNo). + Index(indexNo). + Offset(0). + Limit(1). + Iterator(IterEq). + Key([]interface{}{uint(1111)}) b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := conn.Do(req).Get() + if err != nil { + b.Error(err) + } + } +} + +func BenchmarkClientSerialRequestObjectWithContext(b *testing.B) { + var err error + + conn := test_helpers.ConnectWithValidation(b, server, opts) + defer conn.Close() + + _, err = conn.Replace(spaceNo, []interface{}{uint(1111), "hello", "world"}) + if err != nil { + b.Error(err) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + b.ResetTimer() + for i := 0; i < b.N; i++ { req := NewSelectRequest(spaceNo). Index(indexNo). - Offset(0). Limit(1). Iterator(IterEq). - Key([]interface{}{uint(1111)}) + Key([]interface{}{uint(1111)}). + Context(ctx) _, err := conn.Do(req).Get() if err != nil { b.Error(err) @@ -342,6 +373,131 @@ func BenchmarkClientParallel(b *testing.B) { }) } +func benchmarkClientParallelRequestObject(multiplier int, b *testing.B) { + conn := test_helpers.ConnectWithValidation(b, server, opts) + defer conn.Close() + + _, err := conn.Replace(spaceNo, []interface{}{uint(1111), "hello", "world"}) + if err != nil { + b.Fatal("No connection available") + } + + req := NewSelectRequest(spaceNo). + Index(indexNo). + Limit(1). + Iterator(IterEq). + Key([]interface{}{uint(1111)}) + + b.SetParallelism(multiplier) + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _ = conn.Do(req) + _, err := conn.Do(req).Get() + if err != nil { + b.Error(err) + } + } + }) +} + +func benchmarkClientParallelRequestObjectWithContext(multiplier int, b *testing.B) { + conn := test_helpers.ConnectWithValidation(b, server, opts) + defer conn.Close() + + _, err := conn.Replace(spaceNo, []interface{}{uint(1111), "hello", "world"}) + if err != nil { + b.Fatal("No connection available") + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + req := NewSelectRequest(spaceNo). + Index(indexNo). + Limit(1). + Iterator(IterEq). + Key([]interface{}{uint(1111)}). + Context(ctx) + + b.SetParallelism(multiplier) + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _ = conn.Do(req) + _, err := conn.Do(req).Get() + if err != nil { + b.Error(err) + } + } + }) +} + +func benchmarkClientParallelRequestObjectMixed(multiplier int, b *testing.B) { + conn := test_helpers.ConnectWithValidation(b, server, opts) + defer conn.Close() + + _, err := conn.Replace(spaceNo, []interface{}{uint(1111), "hello", "world"}) + if err != nil { + b.Fatal("No connection available") + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + req := NewSelectRequest(spaceNo). + Index(indexNo). + Limit(1). + Iterator(IterEq). + Key([]interface{}{uint(1111)}) + + reqWithCtx := NewSelectRequest(spaceNo). + Index(indexNo). + Limit(1). + Iterator(IterEq). + Key([]interface{}{uint(1111)}). + Context(ctx) + + b.SetParallelism(multiplier) + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _ = conn.Do(req) + _, err := conn.Do(reqWithCtx).Get() + if err != nil { + b.Error(err) + } + } + }) +} + +func BenchmarkClientParallelRequestObject(b *testing.B) { + multipliers := []int{10, 50, 500, 1000} + conn := test_helpers.ConnectWithValidation(b, server, opts) + defer conn.Close() + + _, err := conn.Replace(spaceNo, []interface{}{uint(1111), "hello", "world"}) + if err != nil { + b.Fatal("No connection available") + } + + for _, m := range multipliers { + goroutinesNum := runtime.GOMAXPROCS(0) * m + + b.Run(fmt.Sprintf("Plain %d goroutines", goroutinesNum), func(b *testing.B) { + benchmarkClientParallelRequestObject(m, b) + }) + + b.Run(fmt.Sprintf("With Context %d goroutines", goroutinesNum), func(b *testing.B) { + benchmarkClientParallelRequestObjectWithContext(m, b) + }) + + b.Run(fmt.Sprintf("Mixed %d goroutines", goroutinesNum), func(b *testing.B) { + benchmarkClientParallelRequestObjectMixed(m, b) + }) + } +} + func BenchmarkClientParallelMassive(b *testing.B) { conn := test_helpers.ConnectWithValidation(b, server, opts) defer conn.Close() @@ -2081,6 +2237,59 @@ func TestClientRequestObjects(t *testing.T) { } } +func TestClientRequestObjectsWithNilContext(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + req := NewPingRequest().Context(nil) //nolint + resp, err := conn.Do(req).Get() + if err != nil { + t.Fatalf("Failed to Ping: %s", err.Error()) + } + if resp == nil { + t.Fatalf("Response is nil after Ping") + } + if len(resp.Data) != 0 { + t.Errorf("Response Body len != 0") + } +} + +func TestClientRequestObjectsWithPassedCanceledContext(t *testing.T) { + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + ctx, cancel := context.WithCancel(context.Background()) + req := NewPingRequest().Context(ctx) + cancel() + resp, err := conn.Do(req).Get() + if err.Error() != "context is done" { + t.Fatalf("Failed to catch an error from done context") + } + if resp != nil { + t.Fatalf("Response is not nil after the occured error") + } +} + +func TestClientRequestObjectsWithContext(t *testing.T) { + var err error + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + ctx, cancel := context.WithCancel(context.Background()) + req := NewPingRequest().Context(ctx) + fut := conn.Do(req) + cancel() + resp, err := fut.Get() + if resp != nil { + t.Fatalf("response must be nil") + } + if err == nil { + t.Fatalf("catched nil error") + } + if err.Error() != "context is done" { + t.Fatalf("wrong error catched: %v", err) + } +} + func TestComplexStructs(t *testing.T) { var err error diff --git a/test_helpers/request_mock.go b/test_helpers/request_mock.go index 00674a3a7..630d57e66 100644 --- a/test_helpers/request_mock.go +++ b/test_helpers/request_mock.go @@ -1,6 +1,8 @@ package test_helpers import ( + "context" + "github.com/tarantool/go-tarantool" "gopkg.in/vmihailenco/msgpack.v2" ) @@ -23,3 +25,7 @@ func (sr *StrangerRequest) Body(resolver tarantool.SchemaResolver, enc *msgpack. func (sr *StrangerRequest) Conn() *tarantool.Connection { return &tarantool.Connection{} } + +func (sr *StrangerRequest) Ctx() context.Context { + return nil +}