Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Concurrent requests #57

Merged
merged 3 commits into from
Sep 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 145 additions & 77 deletions trino/trino.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,11 @@
//
// The driver should be used via the database/sql package:
//
// import "database/sql"
// import _ "github.com/trinodb/trino-go-client/trino"
//
// dsn := "http://user@localhost:8080?catalog=default&schema=test"
// db, err := sql.Open("trino", dsn)
// import "database/sql"
nineinchnick marked this conversation as resolved.
Show resolved Hide resolved
// import _ "github.com/trinodb/trino-go-client/trino"
//
// dsn := "http://user@localhost:8080?catalog=default&schema=test"
// db, err := sql.Open("trino", dsn)
package trino

import (
Expand Down Expand Up @@ -372,7 +371,6 @@ var customClientRegistry = struct {
// }
// trino.RegisterCustomClient("foobar", foobarClient)
// db, err := sql.Open("trino", "https://user@localhost:8080?custom_client=foobar")
//
func RegisterCustomClient(key string, client *http.Client) error {
if _, err := strconv.ParseBool(key); err == nil {
return fmt.Errorf("trino: custom client key %q is reserved", key)
Expand Down Expand Up @@ -549,11 +547,15 @@ func newErrQueryFailedFromResponse(resp *http.Response) *ErrQueryFailed {
}

type driverStmt struct {
conn *Conn
query string
user string
statsCh chan QueryProgressInfo
doneCh chan struct{}
conn *Conn
query string
user string
nextURIs chan string
httpResponses chan *http.Response
queryResponses chan queryResponse
statsCh chan QueryProgressInfo
errors chan error
doneCh chan struct{}
}

var (
Expand All @@ -565,12 +567,26 @@ var (

// Close closes statement just before releasing connection
func (st *driverStmt) Close() error {
if st.doneCh != nil {
close(st.doneCh)
if st.doneCh == nil {
return nil
}
close(st.doneCh)
if st.statsCh != nil {
<-st.statsCh
st.statsCh = nil
}
go func() {
// drain errors chan to allow goroutines to write to it
for range st.errors {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not quite get it. What would happen if we did not run an active drain loop here? Would goroutines writing to st.errors block?
Then what would happen if we want to write to st.errors and Close() was not called? Do we assume that fetch() is then running and the errors channel is being read from?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. If an error happens in either of the two goroutines, that goroutine writes the error to the errors channel.
  2. fetch() consumes the first error and returns it to the user, who's responsible for closing the whole statement.
  3. Close() closes up all channels, which will allow goroutines to finish, but before they do that, more errors can happen. These are ignored because the errors channel is drained.

Both goroutines exit after the first error, but there could be a very rare case when we receive a malformed response AND the next HTTP request will fail. The HTTP request can fail first, but there can be an error in the other goroutine when trying to decode the previous response. Draining this errors channel handles this.

If the user doesn't call Close(), the goroutines can remain blocked trying to write to a channel that nobody is reading from (the channels are unbuffered), either the regular one for results or the errors channel.

Dealing with all of this is forced by running tests with the race detector (go test -race), which fails if there's any possible race. This requires closing all the channels in the right order.

}
}()
for range st.queryResponses {
}
for range st.httpResponses {
}
close(st.nextURIs)
close(st.errors)
st.doneCh = nil
return nil
}

Expand Down Expand Up @@ -598,7 +614,7 @@ func (st *driverStmt) ExecContext(ctx context.Context, args []driver.NamedValue)
}
// consume all results, if there are any
for err == nil {
err = rows.fetch(true)
err = rows.fetch()
}

if err != nil && err != io.EOF {
Expand Down Expand Up @@ -709,7 +725,7 @@ func (st *driverStmt) QueryContext(ctx context.Context, args []driver.NamedValue
statsCh: st.statsCh,
doneCh: st.doneCh,
}
if err = rows.fetch(false); err != nil {
if err = rows.fetch(); err != nil && err != io.EOF {
return nil, err
}
return rows, nil
Expand Down Expand Up @@ -782,9 +798,89 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
return nil, fmt.Errorf("trino: %v", err)
}

st.doneCh = make(chan struct{})
st.nextURIs = make(chan string)
st.httpResponses = make(chan *http.Response)
st.queryResponses = make(chan queryResponse)
st.errors = make(chan error)
go func() {
defer close(st.httpResponses)
for {
select {
case nextURI := <-st.nextURIs:
if nextURI == "" {
return
}
hs := make(http.Header)
hs.Add(trinoUserHeader, st.user)
req, err := st.conn.newRequest("GET", nextURI, nil, hs)
if err != nil {
st.errors <- err
return
}
resp, err := st.conn.roundTrip(ctx, req)
if err != nil {
if ctx.Err() == context.Canceled {
st.errors <- context.Canceled
return
}
st.errors <- err
return
}
select {
case st.httpResponses <- resp:
case <-st.doneCh:
return
}
case <-st.doneCh:
return
}
}
}()
go func() {
defer close(st.queryResponses)
for {
select {
case resp := <-st.httpResponses:
if resp == nil {
return
}
var qresp queryResponse
d := json.NewDecoder(resp.Body)
d.UseNumber()
err = d.Decode(&qresp)
if err != nil {
st.errors <- fmt.Errorf("trino: %v", err)
return
}
err = resp.Body.Close()
if err != nil {
st.errors <- err
return
}
err = handleResponseError(resp.StatusCode, qresp.Error)
if err != nil {
st.errors <- err
return
}
select {
case st.nextURIs <- qresp.NextURI:
case <-st.doneCh:
return
}
select {
case st.queryResponses <- qresp:
case <-st.doneCh:
return
}
case <-st.doneCh:
return
}
}
}()
st.nextURIs <- sr.NextURI
if st.conn.progressUpdater != nil {
st.statsCh = make(chan QueryProgressInfo)
st.doneCh = make(chan struct{})

// progress updater go func
go func() {
Expand Down Expand Up @@ -812,7 +908,6 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
st.conn.progressUpdaterPeriod.LastCallbackTime = time.Now()
st.conn.progressUpdaterPeriod.LastQueryState = sr.Stats.State
}

return &sr, handleResponseError(resp.StatusCode, sr.Error)
}

Expand Down Expand Up @@ -875,7 +970,7 @@ func (qr *driverRows) Columns() []string {
return []string{}
}
if qr.columns == nil {
if err := qr.fetch(false); err != nil {
if err := qr.fetch(); err != nil && err != io.EOF {
qr.err = err
return []string{}
}
Expand Down Expand Up @@ -917,7 +1012,7 @@ func (qr *driverRows) Next(dest []driver.Value) error {
qr.err = io.EOF
return qr.err
}
if err := qr.fetch(true); err != nil {
if err := qr.fetch(); err != nil {
qr.err = err
return err
}
Expand All @@ -927,6 +1022,9 @@ func (qr *driverRows) Next(dest []driver.Value) error {
return qr.err
}
for i, v := range qr.coltype {
if i > len(dest)-1 {
break
}
vv, err := v.ConvertValue(qr.data[qr.rowindex][i])
if err != nil {
qr.err = err
Expand All @@ -947,7 +1045,7 @@ func (qr driverRows) LastInsertId() (int64, error) {

// RowsAffected returns the number of rows affected by the query.
func (qr driverRows) RowsAffected() (int64, error) {
return qr.rowsAffected, qr.err
return qr.rowsAffected, nil
}

type queryResponse struct {
Expand Down Expand Up @@ -1016,71 +1114,34 @@ func handleResponseError(status int, respErr stmtError) error {
}
}

func (qr *driverRows) fetch(allowEOF bool) error {
if qr.nextURI == "" {
if allowEOF {
return io.EOF
}
return nil
}

for qr.nextURI != "" {
var qresp queryResponse
err := qr.executeFetchRequest(&qresp)
if err != nil {
return err
}

qr.rowindex = 0
qr.data = qresp.Data
qr.nextURI = qresp.NextURI
qr.rowsAffected = qresp.UpdateCount
qr.scheduleProgressUpdate(qresp.ID, qresp.Stats)

if len(qr.data) == 0 {
if qr.nextURI != "" {
continue
}
if allowEOF {
qr.err = io.EOF
return qr.err
func (qr *driverRows) fetch() error {
var qresp queryResponse
var err error
for {
select {
case qresp = <-qr.stmt.queryResponses:
if qresp.ID == "" {
return io.EOF
}
}
if qr.columns == nil && len(qresp.Columns) > 0 {
err = qr.initColumns(&qresp)
if err != nil {
return err
}
}
return nil
}
return nil
}

func (qr *driverRows) executeFetchRequest(qresp *queryResponse) error {
hs := make(http.Header)
hs.Add(trinoUserHeader, qr.stmt.user)
req, err := qr.stmt.conn.newRequest("GET", qr.nextURI, nil, hs)
if err != nil {
return err
}
resp, err := qr.stmt.conn.roundTrip(qr.ctx, req)
if err != nil {
if qr.ctx.Err() == context.Canceled {
qr.Close()
qr.rowindex = 0
qr.data = qresp.Data
qr.rowsAffected = qresp.UpdateCount
qr.scheduleProgressUpdate(qresp.ID, qresp.Stats)
if len(qr.data) != 0 {
return nil
}
case err = <-qr.stmt.errors:
if err == context.Canceled {
qr.Close()
}
qr.err = err
return err
}
return err
}
defer resp.Body.Close()

d := json.NewDecoder(resp.Body)
d.UseNumber()
err = d.Decode(&qresp)
if err != nil {
return fmt.Errorf("trino: %v", err)
}
return handleResponseError(resp.StatusCode, qresp.Error)
}

func unmarshalArguments(signature *typeSignature) error {
Expand Down Expand Up @@ -1112,6 +1173,9 @@ func unmarshalArguments(signature *typeSignature) error {
}

func (qr *driverRows) initColumns(qresp *queryResponse) error {
if qr.columns != nil || len(qresp.Columns) == 0 {
return nil
}
var err error
for i := range qresp.Columns {
err = unmarshalArguments(&(qresp.Columns[i].TypeSignature))
Expand All @@ -1122,6 +1186,10 @@ func (qr *driverRows) initColumns(qresp *queryResponse) error {
qr.columns = make([]string, len(qresp.Columns))
qr.coltype = make([]*typeConverter, len(qresp.Columns))
for i, col := range qresp.Columns {
err = unmarshalArguments(&(qresp.Columns[i].TypeSignature))
if err != nil {
return fmt.Errorf("error decoding column type signature: %w", err)
}
qr.columns[i] = col.Name
qr.coltype[i], err = newTypeConverter(col.Type, col.TypeSignature)
if err != nil {
Expand Down
27 changes: 27 additions & 0 deletions trino/trino_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,7 @@ func TestFetchNoStackOverflow(t *testing.T) {
if buf == nil {
buf = new(bytes.Buffer)
json.NewEncoder(buf).Encode(&stmtResponse{
ID: "fake-query",
NextURI: ts.URL + "/v1/statement/20210817_140827_00000_arvdv/1",
})
}
Expand Down Expand Up @@ -1385,3 +1386,29 @@ func TestSlice3TypeConversion(t *testing.T) {
})
}
}

// BenchmarkQuery measures a full round trip of a large SELECT through
// database/sql using the "trino" driver: each iteration issues the query,
// iterates over every returned row without scanning it, and closes the rows.
// The server URI is read from integrationServerFlag.
func BenchmarkQuery(b *testing.B) {
c := &Config{
ServerURI: *integrationServerFlag,
SessionProperties: map[string]string{"query_priority": "1"},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is it for?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure, I copied this config from other tests :-)

}

dsn, err := c.FormatDSN()
require.NoError(b, err)

db, err := sql.Open("trino", dsn)
require.NoError(b, err)

// Close the shared handle once the benchmark (all b.N iterations) is done.
b.Cleanup(func() {
assert.NoError(b, db.Close())
})

q := `SELECT * FROM tpch.sf1.orders LIMIT 10000000`
for n := 0; n < b.N; n++ {
rows, err := db.Query(q)
require.NoError(b, err)
// Drain the result set; rows are advanced but never scanned.
for rows.Next() {
}
// NOTE(review): the error returned by rows.Close() is discarded and
// rows.Err() is not checked, so a failure while iterating would go
// unnoticed here - confirm this is intended.
rows.Close()
}
}