Skip to content

Commit

Permalink
Fetch and decode query results concurrently
Browse files Browse the repository at this point in the history
  • Loading branch information
nineinchnick committed Sep 22, 2022
1 parent 79e6ed1 commit 9d1fc1f
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 77 deletions.
224 changes: 147 additions & 77 deletions trino/trino.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,11 @@
//
// The driver should be used via the database/sql package:
//
//	import "database/sql"
//	import _ "github.com/trinodb/trino-go-client/trino"
//
//	dsn := "http://user@localhost:8080?catalog=default&schema=test"
//	db, err := sql.Open("trino", dsn)
package trino

import (
Expand Down Expand Up @@ -136,6 +135,8 @@ const (
kerberosRealmConfig = "KerberosRealm"
kerberosConfigPathConfig = "KerberosConfigPath"
SSLCertPathConfig = "SSLCertPath"

CHAN_SIZE = 10
)

var (
Expand Down Expand Up @@ -372,7 +373,6 @@ var customClientRegistry = struct {
// }
// trino.RegisterCustomClient("foobar", foobarClient)
// db, err := sql.Open("trino", "https://user@localhost:8080?custom_client=foobar")
//
func RegisterCustomClient(key string, client *http.Client) error {
if _, err := strconv.ParseBool(key); err == nil {
return fmt.Errorf("trino: custom client key %q is reserved", key)
Expand Down Expand Up @@ -549,11 +549,15 @@ func newErrQueryFailedFromResponse(resp *http.Response) *ErrQueryFailed {
}

// driverStmt is a statement bound to a connection. Besides the query text it
// holds the channels that connect the background fetch and decode goroutines
// with the rows that consume their output.
type driverStmt struct {
	conn  *Conn
	query string
	// user is sent as the Trino user header on every follow-up request.
	user string
	// nextURIs carries the URI of each result page that still has to be
	// requested; an empty string means there are no more pages.
	nextURIs chan string
	// httpResponses carries raw HTTP responses waiting to be decoded.
	httpResponses chan *http.Response
	// queryResponses carries decoded result pages, consumed by driverRows.fetch.
	queryResponses chan queryResponse
	// statsCh carries progress information; it is only created when the
	// connection has a progress updater configured.
	statsCh chan QueryProgressInfo
	// errors carries failures from the background goroutines to the reader.
	errors chan error
	// doneCh is closed by Close to tell the background goroutines to stop.
	doneCh chan struct{}
}

var (
Expand All @@ -565,12 +569,26 @@ var (

// Close closes statement just before releasing connection
func (st *driverStmt) Close() error {
if st.doneCh != nil {
close(st.doneCh)
if st.doneCh == nil {
return nil
}
close(st.doneCh)
if st.statsCh != nil {
<-st.statsCh
st.statsCh = nil
}
go func() {
// drain errors chan to allow goroutines to write to it
for range st.errors {
}
}()
for range st.queryResponses {
}
for range st.httpResponses {
}
close(st.nextURIs)
close(st.errors)
st.doneCh = nil
return nil
}

Expand Down Expand Up @@ -598,7 +616,7 @@ func (st *driverStmt) ExecContext(ctx context.Context, args []driver.NamedValue)
}
// consume all results, if there are any
for err == nil {
err = rows.fetch(true)
err = rows.fetch()
}

if err != nil && err != io.EOF {
Expand Down Expand Up @@ -709,7 +727,7 @@ func (st *driverStmt) QueryContext(ctx context.Context, args []driver.NamedValue
statsCh: st.statsCh,
doneCh: st.doneCh,
}
if err = rows.fetch(false); err != nil {
if err = rows.fetch(); err != nil && err != io.EOF {
return nil, err
}
return rows, nil
Expand Down Expand Up @@ -782,9 +800,89 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
return nil, fmt.Errorf("trino: %v", err)
}

st.doneCh = make(chan struct{})
st.nextURIs = make(chan string, CHAN_SIZE)
st.httpResponses = make(chan *http.Response, CHAN_SIZE)
st.queryResponses = make(chan queryResponse, CHAN_SIZE)
st.errors = make(chan error)
st.nextURIs <- sr.NextURI
go func() {
defer close(st.httpResponses)
for {
select {
case nextURI := <-st.nextURIs:
if nextURI == "" {
return
}
hs := make(http.Header)
hs.Add(trinoUserHeader, st.user)
req, err := st.conn.newRequest("GET", nextURI, nil, hs)
if err != nil {
st.errors <- err
return
}
resp, err := st.conn.roundTrip(ctx, req)
if err != nil {
if ctx.Err() == context.Canceled {
st.errors <- context.Canceled
return
}
st.errors <- err
return
}
select {
case st.httpResponses <- resp:
case <-st.doneCh:
return
}
case <-st.doneCh:
return
}
}
}()
go func() {
defer close(st.queryResponses)
for {
select {
case resp := <-st.httpResponses:
if resp == nil {
return
}
var qresp queryResponse
d := json.NewDecoder(resp.Body)
d.UseNumber()
err = d.Decode(&qresp)
if err != nil {
st.errors <- fmt.Errorf("trino: %v", err)
return
}
err = resp.Body.Close()
if err != nil {
st.errors <- err
return
}
err = handleResponseError(resp.StatusCode, qresp.Error)
if err != nil {
st.errors <- err
return
}
select {
case st.nextURIs <- qresp.NextURI:
case <-st.doneCh:
return
}
select {
case st.queryResponses <- qresp:
case <-st.doneCh:
return
}
case <-st.doneCh:
return
}
}
}()
if st.conn.progressUpdater != nil {
st.statsCh = make(chan QueryProgressInfo)
st.doneCh = make(chan struct{})

// progress updater go func
go func() {
Expand Down Expand Up @@ -812,7 +910,6 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
st.conn.progressUpdaterPeriod.LastCallbackTime = time.Now()
st.conn.progressUpdaterPeriod.LastQueryState = sr.Stats.State
}

return &sr, handleResponseError(resp.StatusCode, sr.Error)
}

Expand Down Expand Up @@ -875,7 +972,7 @@ func (qr *driverRows) Columns() []string {
return []string{}
}
if qr.columns == nil {
if err := qr.fetch(false); err != nil {
if err := qr.fetch(); err != nil && err != io.EOF {
qr.err = err
return []string{}
}
Expand Down Expand Up @@ -917,7 +1014,7 @@ func (qr *driverRows) Next(dest []driver.Value) error {
qr.err = io.EOF
return qr.err
}
if err := qr.fetch(true); err != nil {
if err := qr.fetch(); err != nil {
qr.err = err
return err
}
Expand All @@ -927,6 +1024,9 @@ func (qr *driverRows) Next(dest []driver.Value) error {
return qr.err
}
for i, v := range qr.coltype {
if i > len(dest)-1 {
break
}
vv, err := v.ConvertValue(qr.data[qr.rowindex][i])
if err != nil {
qr.err = err
Expand All @@ -947,7 +1047,7 @@ func (qr driverRows) LastInsertId() (int64, error) {

// RowsAffected returns the number of rows affected by the query, as last
// recorded from a response's update count. The returned error is always nil.
func (qr driverRows) RowsAffected() (int64, error) {
	return qr.rowsAffected, nil
}

type queryResponse struct {
Expand Down Expand Up @@ -1016,71 +1116,34 @@ func handleResponseError(status int, respErr stmtError) error {
}
}

func (qr *driverRows) fetch(allowEOF bool) error {
if qr.nextURI == "" {
if allowEOF {
return io.EOF
}
return nil
}

for qr.nextURI != "" {
var qresp queryResponse
err := qr.executeFetchRequest(&qresp)
if err != nil {
return err
}

qr.rowindex = 0
qr.data = qresp.Data
qr.nextURI = qresp.NextURI
qr.rowsAffected = qresp.UpdateCount
qr.scheduleProgressUpdate(qresp.ID, qresp.Stats)

if len(qr.data) == 0 {
if qr.nextURI != "" {
continue
}
if allowEOF {
qr.err = io.EOF
return qr.err
func (qr *driverRows) fetch() error {
var qresp queryResponse
var err error
for {
select {
case qresp = <-qr.stmt.queryResponses:
if qresp.ID == "" {
return io.EOF
}
}
if qr.columns == nil && len(qresp.Columns) > 0 {
err = qr.initColumns(&qresp)
if err != nil {
return err
}
}
return nil
}
return nil
}

func (qr *driverRows) executeFetchRequest(qresp *queryResponse) error {
hs := make(http.Header)
hs.Add(trinoUserHeader, qr.stmt.user)
req, err := qr.stmt.conn.newRequest("GET", qr.nextURI, nil, hs)
if err != nil {
return err
}
resp, err := qr.stmt.conn.roundTrip(qr.ctx, req)
if err != nil {
if qr.ctx.Err() == context.Canceled {
qr.Close()
qr.rowindex = 0
qr.data = qresp.Data
qr.rowsAffected = qresp.UpdateCount
qr.scheduleProgressUpdate(qresp.ID, qresp.Stats)
if len(qr.data) != 0 {
return nil
}
case err = <-qr.stmt.errors:
if err == context.Canceled {
qr.Close()
}
qr.err = err
return err
}
return err
}
defer resp.Body.Close()

d := json.NewDecoder(resp.Body)
d.UseNumber()
err = d.Decode(&qresp)
if err != nil {
return fmt.Errorf("trino: %v", err)
}
return handleResponseError(resp.StatusCode, qresp.Error)
}

func unmarshalArguments(signature *typeSignature) error {
Expand Down Expand Up @@ -1112,6 +1175,9 @@ func unmarshalArguments(signature *typeSignature) error {
}

func (qr *driverRows) initColumns(qresp *queryResponse) error {
if qr.columns != nil || len(qresp.Columns) == 0 {
return nil
}
var err error
for i := range qresp.Columns {
err = unmarshalArguments(&(qresp.Columns[i].TypeSignature))
Expand All @@ -1122,6 +1188,10 @@ func (qr *driverRows) initColumns(qresp *queryResponse) error {
qr.columns = make([]string, len(qresp.Columns))
qr.coltype = make([]*typeConverter, len(qresp.Columns))
for i, col := range qresp.Columns {
err = unmarshalArguments(&(qresp.Columns[i].TypeSignature))
if err != nil {
return fmt.Errorf("error decoding column type signature: %w", err)
}
qr.columns[i] = col.Name
qr.coltype[i], err = newTypeConverter(col.Type, col.TypeSignature)
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions trino/trino_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,7 @@ func TestFetchNoStackOverflow(t *testing.T) {
if buf == nil {
buf = new(bytes.Buffer)
json.NewEncoder(buf).Encode(&stmtResponse{
ID: "fake-query",
NextURI: ts.URL + "/v1/statement/20210817_140827_00000_arvdv/1",
})
}
Expand Down

0 comments on commit 9d1fc1f

Please sign in to comment.