Skip to content

Commit

Permalink
Pass parent context to requests
Browse files Browse the repository at this point in the history
  • Loading branch information
nineinchnick authored and losipiuk committed Apr 23, 2024
1 parent 38fd110 commit bdabe1c
Showing 1 changed file with 20 additions and 15 deletions.
35 changes: 20 additions & 15 deletions trino/trino.go
Original file line number Diff line number Diff line change
Expand Up @@ -448,8 +448,8 @@ func (c *Conn) Close() error {
return nil
}

func (c *Conn) newRequest(method, url string, body io.Reader, hs http.Header) (*http.Request, error) {
req, err := http.NewRequest(method, url, body)
func (c *Conn) newRequest(ctx context.Context, method, url string, body io.Reader, hs http.Header) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, method, url, body)
if err != nil {
return nil, fmt.Errorf("trino: %w", err)
}
Expand Down Expand Up @@ -485,14 +485,7 @@ func (c *Conn) roundTrip(ctx context.Context, req *http.Request) (*http.Response
case <-ctx.Done():
return nil, ctx.Err()
case <-timer.C:
timeout := DefaultQueryTimeout
if deadline, ok := ctx.Deadline(); ok {
timeout = time.Until(deadline)
}
client := c.httpClient
client.Timeout = timeout
req.Cancel = ctx.Done()
resp, err := client.Do(req)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, &ErrQueryFailed{Reason: err}
}
Expand Down Expand Up @@ -845,13 +838,19 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
}
}

req, err := st.conn.newRequest("POST", st.conn.baseURL+"/v1/statement", strings.NewReader(query), hs)
var cancel context.CancelFunc = func() {}
if _, ok := ctx.Deadline(); !ok {
ctx, cancel = context.WithTimeout(ctx, DefaultQueryTimeout)
}
req, err := st.conn.newRequest(ctx, "POST", st.conn.baseURL+"/v1/statement", strings.NewReader(query), hs)
if err != nil {
cancel()
return nil, err
}

resp, err := st.conn.roundTrip(ctx, req)
if err != nil {
cancel()
return nil, err
}

Expand All @@ -861,6 +860,7 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
d.UseNumber()
err = d.Decode(&sr)
if err != nil {
cancel()
return nil, fmt.Errorf("trino: %w", err)
}

Expand All @@ -879,8 +879,12 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
}
hs := make(http.Header)
hs.Add(trinoUserHeader, st.user)
req, err := st.conn.newRequest("GET", nextURI, nil, hs)
req, err := st.conn.newRequest(ctx, "GET", nextURI, nil, hs)
if err != nil {
if ctx.Err() == context.Canceled {
st.errors <- context.Canceled
return
}
st.errors <- err
return
}
Expand All @@ -905,6 +909,7 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
}()
go func() {
defer close(st.queryResponses)
defer cancel()
for {
select {
case resp := <-st.httpResponses:
Expand Down Expand Up @@ -1011,12 +1016,12 @@ func (qr *driverRows) Close() error {
if qr.stmt.user != "" {
hs.Add(trinoUserHeader, qr.stmt.user)
}
req, err := qr.stmt.conn.newRequest("DELETE", qr.stmt.conn.baseURL+"/v1/query/"+url.PathEscape(qr.queryID), nil, hs)
ctx, cancel := context.WithTimeout(context.WithoutCancel(qr.ctx), DefaultCancelQueryTimeout)
defer cancel()
req, err := qr.stmt.conn.newRequest(ctx, "DELETE", qr.stmt.conn.baseURL+"/v1/query/"+url.PathEscape(qr.queryID), nil, hs)
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(context.Background(), DefaultCancelQueryTimeout)
defer cancel()
resp, err := qr.stmt.conn.roundTrip(ctx, req)
if err != nil {
qferr, ok := err.(*ErrQueryFailed)
Expand Down

0 comments on commit bdabe1c

Please sign in to comment.