Skip to content

Commit

Permalink
Fetch and decode query results concurrently
Browse files Browse the repository at this point in the history
  • Loading branch information
nineinchnick committed Sep 22, 2022
1 parent d60ab69 commit 019254e
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 71 deletions.
212 changes: 141 additions & 71 deletions trino/trino.go
Original file line number Diff line number Diff line change
Expand Up @@ -547,11 +547,15 @@ func newErrQueryFailedFromResponse(resp *http.Response) *ErrQueryFailed {
}

// driverStmt carries the per-statement state shared by exec, Close and
// driverRows.fetch. As rendered here the field list is a merged diff view:
// the first five fields are the pre-commit layout and the nine that follow
// are the post-commit layout.
type driverStmt struct {
conn *Conn
query string
user string
statsCh chan QueryProgressInfo
doneCh chan struct{}
conn *Conn
query string
user string
// nextURIs feeds result-page URIs to the HTTP fetch goroutine started in exec.
nextURIs chan string
// httpResponses hands raw responses from the fetch goroutine to the decoder goroutine.
httpResponses chan *http.Response
// queryResponses delivers decoded pages to driverRows.fetch.
queryResponses chan queryResponse
// statsCh is only created in exec when the connection has a progress updater.
statsCh chan QueryProgressInfo
// errors reports failures from either goroutine to driverRows.fetch.
errors chan error
// doneCh is closed by Close to tell the goroutines to stop.
doneCh chan struct{}
}

var (
Expand All @@ -563,12 +567,26 @@ var (

// Close closes statement just before releasing connection.
// The body below is a merged diff view: the first two statements are the
// pre-commit guard, everything after them is the post-commit shutdown
// sequence.
func (st *driverStmt) Close() error {
if st.doneCh != nil {
close(st.doneCh)
if st.doneCh == nil {
// nothing to shut down: doneCh is reset to nil at the end of Close
return nil
}
// tell the goroutines that select on doneCh to stop
close(st.doneCh)
if st.statsCh != nil {
<-st.statsCh
st.statsCh = nil
}
go func() {
// drain errors chan to allow goroutines to write to it
for range st.errors {
}
}()
// each range ends only once the producing goroutine has closed the channel;
// anything still in flight is discarded
for range st.queryResponses {
}
for range st.httpResponses {
}
close(st.nextURIs)
// closing errors ends the drain goroutine started above
close(st.errors)
st.doneCh = nil
return nil
}

Expand Down Expand Up @@ -596,7 +614,7 @@ func (st *driverStmt) ExecContext(ctx context.Context, args []driver.NamedValue)
}
// consume all results, if there are any
for err == nil {
err = rows.fetch(true)
err = rows.fetch()
}

if err != nil && err != io.EOF {
Expand Down Expand Up @@ -707,7 +725,7 @@ func (st *driverStmt) QueryContext(ctx context.Context, args []driver.NamedValue
statsCh: st.statsCh,
doneCh: st.doneCh,
}
if err = rows.fetch(false); err != nil {
if err = rows.fetch(); err != nil && err != io.EOF {
return nil, err
}
return rows, nil
Expand Down Expand Up @@ -780,9 +798,89 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
return nil, fmt.Errorf("trino: %v", err)
}

// Set up the fetch pipeline:
//   nextURIs -> (HTTP GET) -> httpResponses -> (JSON decode) -> queryResponses
// All channels are unbuffered. Failures from either stage are sent on
// st.errors, and closing st.doneCh stops both stages.
st.doneCh = make(chan struct{})
st.nextURIs = make(chan string)
st.httpResponses = make(chan *http.Response)
st.queryResponses = make(chan queryResponse)
st.errors = make(chan error)
// Stage 1: request every URI received on nextURIs and forward the raw response.
go func() {
defer close(st.httpResponses)
for {
select {
case nextURI := <-st.nextURIs:
// an empty URI (also what a closed channel yields) ends this stage
if nextURI == "" {
return
}
hs := make(http.Header)
hs.Add(trinoUserHeader, st.user)
req, err := st.conn.newRequest("GET", nextURI, nil, hs)
if err != nil {
st.errors <- err
return
}
resp, err := st.conn.roundTrip(ctx, req)
if err != nil {
// send the bare context.Canceled value; driverRows.fetch compares it with ==
if ctx.Err() == context.Canceled {
st.errors <- context.Canceled
return
}
st.errors <- err
return
}
select {
case st.httpResponses <- resp:
case <-st.doneCh:
return
}
case <-st.doneCh:
return
}
}
}()
// Stage 2: decode each response, hand its NextURI back to stage 1, then
// publish the decoded page.
go func() {
defer close(st.queryResponses)
for {
select {
case resp := <-st.httpResponses:
// nil is what a closed httpResponses yields, i.e. stage 1 has exited
if resp == nil {
return
}
var qresp queryResponse
d := json.NewDecoder(resp.Body)
d.UseNumber()
// NOTE(review): err is not declared inside this closure, so this appears to
// assign to the enclosing function's err from another goroutine — confirm
err = d.Decode(&qresp)
if err != nil {
// NOTE(review): this return path skips the resp.Body.Close() call below
st.errors <- fmt.Errorf("trino: %v", err)
return
}
err = resp.Body.Close()
if err != nil {
st.errors <- err
return
}
err = handleResponseError(resp.StatusCode, qresp.Error)
if err != nil {
st.errors <- err
return
}
// queue the next page before publishing this one, presumably so the next
// request overlaps with the consumer handling the current page; an empty
// NextURI makes stage 1 exit
select {
case st.nextURIs <- qresp.NextURI:
case <-st.doneCh:
return
}
select {
case st.queryResponses <- qresp:
case <-st.doneCh:
return
}
case <-st.doneCh:
return
}
}
}()
// start the pipeline with sr.NextURI; the send blocks until stage 1 receives it
st.nextURIs <- sr.NextURI
if st.conn.progressUpdater != nil {
st.statsCh = make(chan QueryProgressInfo)
st.doneCh = make(chan struct{})

// progress updater go func
go func() {
Expand Down Expand Up @@ -810,7 +908,6 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
st.conn.progressUpdaterPeriod.LastCallbackTime = time.Now()
st.conn.progressUpdaterPeriod.LastQueryState = sr.Stats.State
}

return &sr, handleResponseError(resp.StatusCode, sr.Error)
}

Expand Down Expand Up @@ -873,7 +970,7 @@ func (qr *driverRows) Columns() []string {
return []string{}
}
if qr.columns == nil {
if err := qr.fetch(false); err != nil {
if err := qr.fetch(); err != nil && err != io.EOF {
qr.err = err
return []string{}
}
Expand Down Expand Up @@ -915,7 +1012,7 @@ func (qr *driverRows) Next(dest []driver.Value) error {
qr.err = io.EOF
return qr.err
}
if err := qr.fetch(true); err != nil {
if err := qr.fetch(); err != nil {
qr.err = err
return err
}
Expand All @@ -925,6 +1022,9 @@ func (qr *driverRows) Next(dest []driver.Value) error {
return qr.err
}
for i, v := range qr.coltype {
if i > len(dest)-1 {
break
}
vv, err := v.ConvertValue(qr.data[qr.rowindex][i])
if err != nil {
qr.err = err
Expand All @@ -945,7 +1045,7 @@ func (qr driverRows) LastInsertId() (int64, error) {

// RowsAffected returns the number of rows affected by the query.
// rowsAffected is copied from the UpdateCount of the latest response in fetch.
// Shown as a merged diff view: the first return is the pre-commit line and the
// second is the post-commit line, which no longer reports qr.err.
func (qr driverRows) RowsAffected() (int64, error) {
return qr.rowsAffected, qr.err
return qr.rowsAffected, nil
}

type queryResponse struct {
Expand Down Expand Up @@ -1014,71 +1114,34 @@ func handleResponseError(status int, respErr stmtError) error {
}
}

// Merged diff view: the pre-commit fetch(allowEOF) and executeFetchRequest
// lines are interleaved with the post-commit fetch() lines, so this span does
// not read as one function from top to bottom.
//
// Pre-commit fetch: requests pages via executeFetchRequest while qr.nextURI is
// non-empty.
func (qr *driverRows) fetch(allowEOF bool) error {
if qr.nextURI == "" {
if allowEOF {
return io.EOF
}
return nil
}

for qr.nextURI != "" {
var qresp queryResponse
err := qr.executeFetchRequest(&qresp)
if err != nil {
return err
}

qr.rowindex = 0
qr.data = qresp.Data
qr.nextURI = qresp.NextURI
qr.rowsAffected = qresp.UpdateCount
qr.scheduleProgressUpdate(qresp.ID, qresp.Stats)

if len(qr.data) == 0 {
if qr.nextURI != "" {
continue
}
if allowEOF {
qr.err = io.EOF
return qr.err
// Post-commit fetch: waits on the statement's channels instead of issuing
// requests itself. It loops until a response carrying data arrives (returns
// nil), a response with an empty ID arrives (returns io.EOF), or an error is
// received from qr.stmt.errors.
func (qr *driverRows) fetch() error {
var qresp queryResponse
var err error
for {
select {
case qresp = <-qr.stmt.queryResponses:
// an empty ID is treated as end of results; presumably this is the zero
// value received once queryResponses has been closed — confirm
if qresp.ID == "" {
return io.EOF
}
}
if qr.columns == nil && len(qresp.Columns) > 0 {
err = qr.initColumns(&qresp)
if err != nil {
return err
}
}
return nil
}
return nil
}

// Pre-commit executeFetchRequest: performs one GET on qr.nextURI and decodes
// the body into qresp.
func (qr *driverRows) executeFetchRequest(qresp *queryResponse) error {
hs := make(http.Header)
hs.Add(trinoUserHeader, qr.stmt.user)
req, err := qr.stmt.conn.newRequest("GET", qr.nextURI, nil, hs)
if err != nil {
return err
}
resp, err := qr.stmt.conn.roundTrip(qr.ctx, req)
if err != nil {
if qr.ctx.Err() == context.Canceled {
qr.Close()
qr.rowindex = 0
qr.data = qresp.Data
qr.rowsAffected = qresp.UpdateCount
qr.scheduleProgressUpdate(qresp.ID, qresp.Stats)
// a page without rows is skipped; keep waiting for the next one
if len(qr.data) != 0 {
return nil
}
case err = <-qr.stmt.errors:
if err == context.Canceled {
qr.Close()
}
qr.err = err
return err
}
return err
}
defer resp.Body.Close()

d := json.NewDecoder(resp.Body)
d.UseNumber()
err = d.Decode(&qresp)
if err != nil {
return fmt.Errorf("trino: %v", err)
}
return handleResponseError(resp.StatusCode, qresp.Error)
}

func unmarshalArguments(signature *typeSignature) error {
Expand Down Expand Up @@ -1110,6 +1173,9 @@ func unmarshalArguments(signature *typeSignature) error {
}

func (qr *driverRows) initColumns(qresp *queryResponse) error {
if qr.columns != nil || len(qresp.Columns) == 0 {
return nil
}
var err error
for i := range qresp.Columns {
err = unmarshalArguments(&(qresp.Columns[i].TypeSignature))
Expand All @@ -1120,6 +1186,10 @@ func (qr *driverRows) initColumns(qresp *queryResponse) error {
qr.columns = make([]string, len(qresp.Columns))
qr.coltype = make([]*typeConverter, len(qresp.Columns))
for i, col := range qresp.Columns {
err = unmarshalArguments(&(qresp.Columns[i].TypeSignature))
if err != nil {
return fmt.Errorf("error decoding column type signature: %w", err)
}
qr.columns[i] = col.Name
qr.coltype[i], err = newTypeConverter(col.Type, col.TypeSignature)
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions trino/trino_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,7 @@ func TestFetchNoStackOverflow(t *testing.T) {
if buf == nil {
buf = new(bytes.Buffer)
json.NewEncoder(buf).Encode(&stmtResponse{
ID: "fake-query",
NextURI: ts.URL + "/v1/statement/20210817_140827_00000_arvdv/1",
})
}
Expand Down

0 comments on commit 019254e

Please sign in to comment.